diff --git a/src/pathfinding_service/api.py b/src/pathfinding_service/api.py index c9d73c4f..c7bbd679 100644 --- a/src/pathfinding_service/api.py +++ b/src/pathfinding_service/api.py @@ -18,7 +18,6 @@ from gevent.pywsgi import WSGIServer from marshmallow import fields from marshmallow_dataclass import add_schema -from networkx.exception import NetworkXNoPath, NodeNotFound from web3 import Web3 import pathfinding_service.exceptions as exceptions @@ -133,17 +132,17 @@ def post(self, token_network_address: str) -> Tuple[dict, int]: value = getattr(path_req, arg) if value is not None: optional_args[arg] = value - try: - paths = token_network.get_paths( - source=path_req.from_, - target=path_req.to, - value=path_req.value, - address_to_reachability=self.pathfinding_service.address_to_reachability, - max_paths=path_req.max_paths, - **optional_args, - ) - except (NetworkXNoPath, NodeNotFound): - # this is for assertion via the scenario player + + paths = token_network.get_paths( + source=path_req.from_, + target=path_req.to, + value=path_req.value, + address_to_reachability=self.pathfinding_service.address_to_reachability, + max_paths=path_req.max_paths, + **optional_args, + ) + # this is for assertion via the scenario player + if len(paths) == 0: if self.debug_mode: last_requests.append( dict( @@ -154,7 +153,9 @@ def post(self, token_network_address: str) -> Tuple[dict, int]: ) ) raise exceptions.NoRouteFound( - from_=to_checksum_address(path_req.from_), to=to_checksum_address(path_req.to) + from_=to_checksum_address(path_req.from_), + to=to_checksum_address(path_req.to), + value=path_req.value, ) # this is for assertion via the scenario player diff --git a/src/pathfinding_service/model/token_network.py b/src/pathfinding_service/model/token_network.py index 90195aae..e1823bf0 100644 --- a/src/pathfinding_service/model/token_network.py +++ b/src/pathfinding_service/model/token_network.py @@ -7,6 +7,7 @@ import structlog from eth_utils import to_checksum_address from networkx import DiGraph +from networkx.exception import NetworkXNoPath, NodeNotFound from pathfinding_service.constants import ( DEFAULT_SETTLE_TO_REVEAL_TIMEOUT_RATIO, @@ -30,6 +31,23 @@ log = structlog.get_logger(__name__) +def prune_graph( + graph: DiGraph, address_to_reachability: Dict[Address, AddressReachability] +) -> DiGraph: + """ Prunes the given `graph` of all channels where the participants are not reachable. """ + pruned_graph = DiGraph() + for p1, p2 in graph.edges: + nodes_online = ( + address_to_reachability[p1] == AddressReachability.REACHABLE + and address_to_reachability[p2] == AddressReachability.REACHABLE + ) + if nodes_online: + pruned_graph.add_edge(p1, p2, view=graph[p1][p2]["view"]) + pruned_graph.add_edge(p2, p1, view=graph[p2][p1]["view"]) + + return pruned_graph + + def window(seq: Sequence, n: int = 2) -> Iterable[tuple]: """Returns a sliding window (of width n) over data from the iterable s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ... @@ -295,8 +313,9 @@ def edge_weight( no_refund_weight = 1 return 1 + diversity_weight + fee_weight + no_refund_weight - def _get_single_path( # pylint: disable=too-many-arguments + def _get_single_path( # pylint: disable=too-many-arguments, too-many-locals self, + graph: DiGraph, source: Address, target: Address, value: PaymentAmount, @@ -306,14 +325,14 @@ def _get_single_path( # pylint: disable=too-many-arguments fee_penalty: float, ) -> Optional[Path]: # update edge weights - for node1, node2 in self.G.edges(): - edge = self.G[node1][node2] - backwards_edge = self.G[node2][node1] + for node1, node2 in graph.edges(): + edge = graph[node1][node2] + backwards_edge = graph[node2][node1] edge["weight"] = self.edge_weight(visited, edge, backwards_edge, value, fee_penalty) # find next path all_paths: Iterable[List[Address]] = nx.shortest_simple_paths( - G=self.G, source=source, target=target, weight="weight" + G=graph, source=source, target=target, weight="weight" ) try: # skip duplicates and invalid paths @@ -349,8 +368,8 @@ def get_paths( # pylint: disable=too-many-arguments log.debug( "Finding paths for payment", - source=to_checksum_address(source), - target=to_checksum_address(target), + source=source, + target=target, value=value, max_paths=max_paths, diversity_penalty=diversity_penalty, @@ -358,16 +377,26 @@ def get_paths( # pylint: disable=too-many-arguments reachabilities=address_to_reachability, ) + # TODO: improve the pruning + # Currently we make a snapshot of the currently reachable nodes, so the serached graph + # becomes smaller + pruned_graph = prune_graph(graph=self.G, address_to_reachability=address_to_reachability) + while len(paths) < max_paths: - path = self._get_single_path( - source=source, - target=target, - value=value, - address_to_reachability=address_to_reachability, - visited=visited, - disallowed_paths=[p.nodes for p in paths], - fee_penalty=fee_penalty, - ) + try: + path = self._get_single_path( + graph=pruned_graph, + source=source, + target=target, + value=value, + address_to_reachability=address_to_reachability, + visited=visited, + disallowed_paths=[p.nodes for p in paths], + fee_penalty=fee_penalty, + ) + except (NetworkXNoPath, NodeNotFound): + return [] + if path is None: break paths.append(path) diff --git a/src/raiden_libs/utils.py b/src/raiden_libs/utils.py index 33e43224..c97651cb 100644 --- a/src/raiden_libs/utils.py +++ b/src/raiden_libs/utils.py @@ -1,7 +1,7 @@ from typing import Union from coincurve import PrivateKey, PublicKey -from eth_utils import keccak, to_bytes +from eth_utils import decode_hex, keccak from raiden.utils.typing import Address @@ -15,7 +15,7 @@ def public_key_to_address(public_key: PublicKey) -> Address: def private_key_to_address(private_key: Union[str, bytes]) -> Address: """ Converts a private key to an Ethereum address. """ if isinstance(private_key, str): - private_key = to_bytes(hexstr=private_key) + private_key = decode_hex(private_key) assert isinstance(private_key, bytes) privkey = PrivateKey(private_key) diff --git a/tests/pathfinding/fixtures/network_service.py b/tests/pathfinding/fixtures/network_service.py index 43849c47..d068b949 100644 --- a/tests/pathfinding/fixtures/network_service.py +++ b/tests/pathfinding/fixtures/network_service.py @@ -1,4 +1,5 @@ # pylint: disable=redefined-outer-name +import random from typing import Callable, Dict, Generator, List from unittest.mock import Mock, patch @@ -21,6 +22,7 @@ TokenNetworkAddress, ) from raiden_contracts.constants import CONTRACT_TOKEN_NETWORK_REGISTRY, CONTRACT_USER_DEPOSIT +from raiden_libs.utils import private_key_to_address @pytest.fixture(scope="session") @@ -327,3 +329,66 @@ def pathfinding_service_web3_mock( ) yield pathfinding_service + + +@pytest.fixture +def populate_token_network_random( + token_network_model: TokenNetwork, private_keys: List[str] +) -> None: + number_of_channels = 300 + # seed for pseudo-randomness from config constant, that changes from time to time + random.seed(number_of_channels) + + for channel_id_int in range(number_of_channels): + channel_id = ChannelID(channel_id_int) + + private_key1, private_key2 = random.sample(private_keys, 2) + address1 = private_key_to_address(private_key1) + address2 = private_key_to_address(private_key2) + settle_timeout = 15 + token_network_model.handle_channel_opened_event( + channel_id, address1, address2, settle_timeout + ) + + # deposit to channels + deposit1 = TokenAmount(random.randint(0, 1000)) + deposit2 = TokenAmount(random.randint(0, 1000)) + address1, address2 = token_network_model.channel_id_to_addresses[channel_id] + token_network_model.handle_channel_balance_update_message( + PFSCapacityUpdate( + canonical_identifier=CanonicalIdentifier( + chain_identifier=ChainID(1), + channel_identifier=channel_id, + token_network_address=TokenNetworkAddress(token_network_model.address), + ), + updating_participant=address1, + other_participant=address2, + updating_nonce=Nonce(1), + other_nonce=Nonce(1), + updating_capacity=deposit1, + other_capacity=deposit2, + reveal_timeout=2, + signature=EMPTY_SIGNATURE, + ), + updating_capacity_partner=TokenAmount(0), + other_capacity_partner=TokenAmount(0), + ) + token_network_model.handle_channel_balance_update_message( + PFSCapacityUpdate( + canonical_identifier=CanonicalIdentifier( + chain_identifier=ChainID(1), + channel_identifier=channel_id, + token_network_address=TokenNetworkAddress(token_network_model.address), + ), + updating_participant=address2, + other_participant=address1, + updating_nonce=Nonce(2), + other_nonce=Nonce(1), + updating_capacity=deposit2, + other_capacity=deposit1, + reveal_timeout=2, + signature=EMPTY_SIGNATURE, + ), + updating_capacity_partner=TokenAmount(deposit1), + other_capacity_partner=TokenAmount(deposit2), + ) diff --git a/tests/pathfinding/test_graphs.py b/tests/pathfinding/test_graphs.py index 5e9d8733..e2d6be35 100644 --- a/tests/pathfinding/test_graphs.py +++ b/tests/pathfinding/test_graphs.py @@ -1,9 +1,10 @@ +import random +import time from copy import deepcopy from typing import Dict, List import pytest -from eth_utils import decode_hex, to_checksum_address -from networkx import NetworkXNoPath +from eth_utils import decode_hex, to_canonical_address, to_checksum_address from pathfinding_service.constants import DIVERSITY_PEN_DEFAULT from pathfinding_service.model import ChannelView, TokenNetwork @@ -118,14 +119,14 @@ def test_routing_simple( } # Not connected. - with pytest.raises(NetworkXNoPath): - token_network_model.get_paths( - source=addresses[0], - target=addresses[5], - value=PaymentAmount(10), - max_paths=1, - address_to_reachability=address_to_reachability, - ) + no_paths = token_network_model.get_paths( + source=addresses[0], + target=addresses[5], + value=PaymentAmount(10), + max_paths=1, + address_to_reachability=address_to_reachability, + ) + assert [] == no_paths @pytest.mark.usefixtures("populate_token_network_case_1") @@ -137,8 +138,8 @@ def test_capacity_check( """ Test that the mediation fees are included in the capacity check """ # First get a path without mediation fees. This must return the shortest path: 4->1->0 paths = token_network_model.get_paths( - addresses[4], - addresses[0], + source=addresses[4], + target=addresses[0], value=PaymentAmount(35), max_paths=1, address_to_reachability=address_to_reachability, @@ -154,8 +155,8 @@ def test_capacity_check( # value of 35. But 35 + 1 exceeds the capacity for channel 4->1, which is # 35. So we should now get the next best route instead. paths = model_with_fees.get_paths( - addresses[4], - addresses[0], + source=addresses[4], + target=addresses[0], value=PaymentAmount(35), max_paths=1, address_to_reachability=address_to_reachability, @@ -173,8 +174,8 @@ def test_routing_result_order( ): hex_addrs = [to_checksum_address(addr) for addr in addresses] paths = token_network_model.get_paths( - addresses[0], - addresses[2], + source=addresses[0], + target=addresses[2], value=PaymentAmount(10), max_paths=5, address_to_reachability=address_to_reachability, @@ -338,3 +339,50 @@ def test_reachability_target( ) == [] ) + + +@pytest.mark.skip("Just run it locally for now") +@pytest.mark.usefixtures("populate_token_network_random") +def test_routing_benchmark(token_network_model: TokenNetwork): # pylint: disable=too-many-locals + value = PaymentAmount(100) + G = token_network_model.G + addresses_to_reachabilities = { + node: random.choice( + ( + AddressReachability.REACHABLE, + AddressReachability.UNKNOWN, + AddressReachability.UNREACHABLE, + ) + ) + for node in G.nodes + } + + times = [] + start = time.time() + for _ in range(100): + tic = time.time() + source, target = random.sample(G.nodes, 2) + paths = token_network_model.get_paths( + source=source, + target=target, + value=value, + max_paths=5, + address_to_reachability=addresses_to_reachabilities, + ) + + toc = time.time() + times.append(toc - tic) + end = time.time() + + for path_object in paths: + path = path_object["path"] + fees = path_object["estimated_fee"] + for node1, node2 in zip(path[:-1], path[1:]): + view: ChannelView = G[to_canonical_address(node1)][to_canonical_address(node2)]["view"] + print("capacity = ", view.capacity) + print("fee sum = ", fees) + print("Paths: ", paths) + print("Mean runtime: ", sum(times) / len(times)) + print("Min runtime: ", min(times)) + print("Max runtime: ", max(times)) + print("Total runtime: ", end - start)