diff --git a/src/pathfinding_service/api.py b/src/pathfinding_service/api.py index c9d73c4f..c7bbd679 100644 --- a/src/pathfinding_service/api.py +++ b/src/pathfinding_service/api.py @@ -18,7 +18,6 @@ from gevent.pywsgi import WSGIServer from marshmallow import fields from marshmallow_dataclass import add_schema -from networkx.exception import NetworkXNoPath, NodeNotFound from web3 import Web3 import pathfinding_service.exceptions as exceptions @@ -133,17 +132,17 @@ def post(self, token_network_address: str) -> Tuple[dict, int]: value = getattr(path_req, arg) if value is not None: optional_args[arg] = value - try: - paths = token_network.get_paths( - source=path_req.from_, - target=path_req.to, - value=path_req.value, - address_to_reachability=self.pathfinding_service.address_to_reachability, - max_paths=path_req.max_paths, - **optional_args, - ) - except (NetworkXNoPath, NodeNotFound): - # this is for assertion via the scenario player + + paths = token_network.get_paths( + source=path_req.from_, + target=path_req.to, + value=path_req.value, + address_to_reachability=self.pathfinding_service.address_to_reachability, + max_paths=path_req.max_paths, + **optional_args, + ) + # this is for assertion via the scenario player + if len(paths) == 0: if self.debug_mode: last_requests.append( dict( @@ -154,7 +153,9 @@ def post(self, token_network_address: str) -> Tuple[dict, int]: ) ) raise exceptions.NoRouteFound( - from_=to_checksum_address(path_req.from_), to=to_checksum_address(path_req.to) + from_=to_checksum_address(path_req.from_), + to=to_checksum_address(path_req.to), + value=path_req.value, ) # this is for assertion via the scenario player diff --git a/src/pathfinding_service/model/token_network.py b/src/pathfinding_service/model/token_network.py index 90195aae..e1823bf0 100644 --- a/src/pathfinding_service/model/token_network.py +++ b/src/pathfinding_service/model/token_network.py @@ -7,6 +7,7 @@ import structlog from eth_utils import to_checksum_address from networkx import DiGraph +from networkx.exception import NetworkXNoPath, NodeNotFound from pathfinding_service.constants import ( DEFAULT_SETTLE_TO_REVEAL_TIMEOUT_RATIO, @@ -30,6 +31,23 @@ log = structlog.get_logger(__name__) +def prune_graph( + graph: DiGraph, address_to_reachability: Dict[Address, AddressReachability] +) -> DiGraph: + """ Prunes the given `graph` of all channels where the participants are not reachable. """ + pruned_graph = DiGraph() + for p1, p2 in graph.edges: + nodes_online = ( + address_to_reachability[p1] == AddressReachability.REACHABLE + and address_to_reachability[p2] == AddressReachability.REACHABLE + ) + if nodes_online: + pruned_graph.add_edge(p1, p2, view=graph[p1][p2]["view"]) + pruned_graph.add_edge(p2, p1, view=graph[p2][p1]["view"]) + + return pruned_graph + + def window(seq: Sequence, n: int = 2) -> Iterable[tuple]: """Returns a sliding window (of width n) over data from the iterable s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ... @@ -295,8 +313,9 @@ def edge_weight( no_refund_weight = 1 return 1 + diversity_weight + fee_weight + no_refund_weight - def _get_single_path( # pylint: disable=too-many-arguments + def _get_single_path( # pylint: disable=too-many-arguments, too-many-locals self, + graph: DiGraph, source: Address, target: Address, value: PaymentAmount, @@ -306,14 +325,14 @@ def _get_single_path( # pylint: disable=too-many-arguments fee_penalty: float, ) -> Optional[Path]: # update edge weights - for node1, node2 in self.G.edges(): - edge = self.G[node1][node2] - backwards_edge = self.G[node2][node1] + for node1, node2 in graph.edges(): + edge = graph[node1][node2] + backwards_edge = graph[node2][node1] edge["weight"] = self.edge_weight(visited, edge, backwards_edge, value, fee_penalty) # find next path all_paths: Iterable[List[Address]] = nx.shortest_simple_paths( - G=self.G, source=source, target=target, weight="weight" + G=graph, source=source, target=target, weight="weight" ) try: # skip duplicates and invalid paths @@ -349,8 +368,8 @@ def get_paths( # pylint: disable=too-many-arguments log.debug( "Finding paths for payment", - source=to_checksum_address(source), - target=to_checksum_address(target), + source=source, + target=target, value=value, max_paths=max_paths, diversity_penalty=diversity_penalty, @@ -358,16 +377,26 @@ def get_paths( # pylint: disable=too-many-arguments reachabilities=address_to_reachability, ) + # TODO: improve the pruning + # Currently we make a snapshot of the currently reachable nodes, so the serached graph + # becomes smaller + pruned_graph = prune_graph(graph=self.G, address_to_reachability=address_to_reachability) + while len(paths) < max_paths: - path = self._get_single_path( - source=source, - target=target, - value=value, - address_to_reachability=address_to_reachability, - visited=visited, - disallowed_paths=[p.nodes for p in paths], - fee_penalty=fee_penalty, - ) + try: + path = self._get_single_path( + graph=pruned_graph, + source=source, + target=target, + value=value, + address_to_reachability=address_to_reachability, + visited=visited, + disallowed_paths=[p.nodes for p in paths], + fee_penalty=fee_penalty, + ) + except (NetworkXNoPath, NodeNotFound): + return [] + if path is None: break paths.append(path) diff --git a/src/raiden_libs/utils.py b/src/raiden_libs/utils.py index 33e43224..c97651cb 100644 --- a/src/raiden_libs/utils.py +++ b/src/raiden_libs/utils.py @@ -1,7 +1,7 @@ from typing import Union from coincurve import PrivateKey, PublicKey -from eth_utils import keccak, to_bytes +from eth_utils import decode_hex, keccak from raiden.utils.typing import Address @@ -15,7 +15,7 @@ def public_key_to_address(public_key: PublicKey) -> Address: def private_key_to_address(private_key: Union[str, bytes]) -> Address: """ Converts a private key to an Ethereum address. """ if isinstance(private_key, str): - private_key = to_bytes(hexstr=private_key) + private_key = decode_hex(private_key) assert isinstance(private_key, bytes) privkey = PrivateKey(private_key) diff --git a/tests/pathfinding/test_graphs.py b/tests/pathfinding/test_graphs.py index 6048c9bb..e2d6be35 100644 --- a/tests/pathfinding/test_graphs.py +++ b/tests/pathfinding/test_graphs.py @@ -3,7 +3,6 @@ from copy import deepcopy from typing import Dict, List -import networkx import pytest from eth_utils import decode_hex, to_canonical_address, to_checksum_address @@ -120,14 +119,14 @@ def test_routing_simple( } # Not connected. - with pytest.raises(NetworkXNoPath): - token_network_model.get_paths( - source=addresses[0], - target=addresses[5], - value=PaymentAmount(10), - max_paths=1, - address_to_reachability=address_to_reachability, - ) + no_paths = token_network_model.get_paths( + source=addresses[0], + target=addresses[5], + value=PaymentAmount(10), + max_paths=1, + address_to_reachability=address_to_reachability, + ) + assert [] == no_paths @pytest.mark.usefixtures("populate_token_network_case_1") @@ -139,8 +138,8 @@ def test_capacity_check( """ Test that the mediation fees are included in the capacity check """ # First get a path without mediation fees. This must return the shortest path: 4->1->0 paths = token_network_model.get_paths( - addresses[4], - addresses[0], + source=addresses[4], + target=addresses[0], value=PaymentAmount(35), max_paths=1, address_to_reachability=address_to_reachability, @@ -156,8 +155,8 @@ def test_capacity_check( # value of 35. But 35 + 1 exceeds the capacity for channel 4->1, which is # 35. So we should now get the next best route instead. paths = model_with_fees.get_paths( - addresses[4], - addresses[0], + source=addresses[4], + target=addresses[0], value=PaymentAmount(35), max_paths=1, address_to_reachability=address_to_reachability, @@ -175,8 +174,8 @@ def test_routing_result_order( ): hex_addrs = [to_checksum_address(addr) for addr in addresses] paths = token_network_model.get_paths( - addresses[0], - addresses[2], + source=addresses[0], + target=addresses[2], value=PaymentAmount(10), max_paths=5, address_to_reachability=address_to_reachability, @@ -342,7 +341,7 @@ def test_reachability_target( ) -@pytest.mark.skip('Just run it locally for now') +@pytest.mark.skip("Just run it locally for now") @pytest.mark.usefixtures("populate_token_network_random") def test_routing_benchmark(token_network_model: TokenNetwork): # pylint: disable=too-many-locals value = PaymentAmount(100)