Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph pruning #521

Merged
merged 2 commits into from
Aug 21, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions src/pathfinding_service/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from gevent.pywsgi import WSGIServer
from marshmallow import fields
from marshmallow_dataclass import add_schema
from networkx.exception import NetworkXNoPath, NodeNotFound
from web3 import Web3

import pathfinding_service.exceptions as exceptions
Expand Down Expand Up @@ -133,17 +132,17 @@ def post(self, token_network_address: str) -> Tuple[dict, int]:
value = getattr(path_req, arg)
if value is not None:
optional_args[arg] = value
try:
palango marked this conversation as resolved.
Show resolved Hide resolved
paths = token_network.get_paths(
source=path_req.from_,
target=path_req.to,
value=path_req.value,
address_to_reachability=self.pathfinding_service.address_to_reachability,
max_paths=path_req.max_paths,
**optional_args,
)
except (NetworkXNoPath, NodeNotFound):
# this is for assertion via the scenario player

paths = token_network.get_paths(
source=path_req.from_,
target=path_req.to,
value=path_req.value,
address_to_reachability=self.pathfinding_service.address_to_reachability,
max_paths=path_req.max_paths,
**optional_args,
)
# this is for assertion via the scenario player
if len(paths) == 0:
if self.debug_mode:
last_requests.append(
dict(
Expand All @@ -154,7 +153,9 @@ def post(self, token_network_address: str) -> Tuple[dict, int]:
)
)
raise exceptions.NoRouteFound(
from_=to_checksum_address(path_req.from_), to=to_checksum_address(path_req.to)
from_=to_checksum_address(path_req.from_),
to=to_checksum_address(path_req.to),
value=path_req.value,
)

# this is for assertion via the scenario player
Expand Down
61 changes: 45 additions & 16 deletions src/pathfinding_service/model/token_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import structlog
from eth_utils import to_checksum_address
from networkx import DiGraph
from networkx.exception import NetworkXNoPath, NodeNotFound

from pathfinding_service.constants import (
DEFAULT_SETTLE_TO_REVEAL_TIMEOUT_RATIO,
Expand All @@ -30,6 +31,23 @@
log = structlog.get_logger(__name__)


def prune_graph(
graph: DiGraph, address_to_reachability: Dict[Address, AddressReachability]
) -> DiGraph:
""" Prunes the given `graph` of all channels where the participants are not reachable. """
pruned_graph = DiGraph()
for p1, p2 in graph.edges:
nodes_online = (
address_to_reachability[p1] == AddressReachability.REACHABLE
and address_to_reachability[p2] == AddressReachability.REACHABLE
)
if nodes_online:
pruned_graph.add_edge(p1, p2, view=graph[p1][p2]["view"])
pruned_graph.add_edge(p2, p1, view=graph[p2][p1]["view"])

return pruned_graph


def window(seq: Sequence, n: int = 2) -> Iterable[tuple]:
"""Returns a sliding window (of width n) over data from the iterable
s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ...
Expand Down Expand Up @@ -295,8 +313,9 @@ def edge_weight(
no_refund_weight = 1
return 1 + diversity_weight + fee_weight + no_refund_weight

def _get_single_path( # pylint: disable=too-many-arguments
def _get_single_path( # pylint: disable=too-many-arguments, too-many-locals
self,
graph: DiGraph,
source: Address,
target: Address,
value: PaymentAmount,
Expand All @@ -306,14 +325,14 @@ def _get_single_path( # pylint: disable=too-many-arguments
fee_penalty: float,
) -> Optional[Path]:
# update edge weights
for node1, node2 in self.G.edges():
edge = self.G[node1][node2]
backwards_edge = self.G[node2][node1]
for node1, node2 in graph.edges():
edge = graph[node1][node2]
backwards_edge = graph[node2][node1]
edge["weight"] = self.edge_weight(visited, edge, backwards_edge, value, fee_penalty)

# find next path
all_paths: Iterable[List[Address]] = nx.shortest_simple_paths(
G=self.G, source=source, target=target, weight="weight"
G=graph, source=source, target=target, weight="weight"
)
try:
# skip duplicates and invalid paths
Expand Down Expand Up @@ -349,25 +368,35 @@ def get_paths( # pylint: disable=too-many-arguments

log.debug(
"Finding paths for payment",
source=to_checksum_address(source),
target=to_checksum_address(target),
source=source,
target=target,
value=value,
max_paths=max_paths,
diversity_penalty=diversity_penalty,
fee_penalty=fee_penalty,
reachabilities=address_to_reachability,
)

# TODO: improve the pruning
# Currently we make a snapshot of the currently reachable nodes, so the serached graph
# becomes smaller
pruned_graph = prune_graph(graph=self.G, address_to_reachability=address_to_reachability)

while len(paths) < max_paths:
path = self._get_single_path(
source=source,
target=target,
value=value,
address_to_reachability=address_to_reachability,
visited=visited,
disallowed_paths=[p.nodes for p in paths],
fee_penalty=fee_penalty,
)
try:
path = self._get_single_path(
graph=pruned_graph,
source=source,
target=target,
value=value,
address_to_reachability=address_to_reachability,
visited=visited,
disallowed_paths=[p.nodes for p in paths],
fee_penalty=fee_penalty,
)
except (NetworkXNoPath, NodeNotFound):
return []

if path is None:
break
paths.append(path)
Expand Down
4 changes: 2 additions & 2 deletions src/raiden_libs/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Union

from coincurve import PrivateKey, PublicKey
from eth_utils import keccak, to_bytes
from eth_utils import decode_hex, keccak

from raiden.utils.typing import Address

Expand All @@ -15,7 +15,7 @@ def public_key_to_address(public_key: PublicKey) -> Address:
def private_key_to_address(private_key: Union[str, bytes]) -> Address:
""" Converts a private key to an Ethereum address. """
if isinstance(private_key, str):
private_key = to_bytes(hexstr=private_key)
private_key = decode_hex(private_key)
Dominik1999 marked this conversation as resolved.
Show resolved Hide resolved

assert isinstance(private_key, bytes)
privkey = PrivateKey(private_key)
Expand Down
65 changes: 65 additions & 0 deletions tests/pathfinding/fixtures/network_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# pylint: disable=redefined-outer-name
import random
from typing import Callable, Dict, Generator, List
from unittest.mock import Mock, patch

Expand All @@ -21,6 +22,7 @@
TokenNetworkAddress,
)
from raiden_contracts.constants import CONTRACT_TOKEN_NETWORK_REGISTRY, CONTRACT_USER_DEPOSIT
from raiden_libs.utils import private_key_to_address


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -327,3 +329,66 @@ def pathfinding_service_web3_mock(
)

yield pathfinding_service


@pytest.fixture
def populate_token_network_random(
token_network_model: TokenNetwork, private_keys: List[str]
) -> None:
number_of_channels = 300
# seed for pseudo-randomness from config constant, that changes from time to time
random.seed(number_of_channels)

for channel_id_int in range(number_of_channels):
channel_id = ChannelID(channel_id_int)

private_key1, private_key2 = random.sample(private_keys, 2)
address1 = private_key_to_address(private_key1)
address2 = private_key_to_address(private_key2)
settle_timeout = 15
token_network_model.handle_channel_opened_event(
channel_id, address1, address2, settle_timeout
)

# deposit to channels
deposit1 = TokenAmount(random.randint(0, 1000))
deposit2 = TokenAmount(random.randint(0, 1000))
address1, address2 = token_network_model.channel_id_to_addresses[channel_id]
token_network_model.handle_channel_balance_update_message(
PFSCapacityUpdate(
canonical_identifier=CanonicalIdentifier(
chain_identifier=ChainID(1),
channel_identifier=channel_id,
token_network_address=TokenNetworkAddress(token_network_model.address),
),
updating_participant=address1,
other_participant=address2,
updating_nonce=Nonce(1),
other_nonce=Nonce(1),
updating_capacity=deposit1,
other_capacity=deposit2,
reveal_timeout=2,
signature=EMPTY_SIGNATURE,
),
updating_capacity_partner=TokenAmount(0),
other_capacity_partner=TokenAmount(0),
)
token_network_model.handle_channel_balance_update_message(
PFSCapacityUpdate(
canonical_identifier=CanonicalIdentifier(
chain_identifier=ChainID(1),
channel_identifier=channel_id,
token_network_address=TokenNetworkAddress(token_network_model.address),
),
updating_participant=address2,
other_participant=address1,
updating_nonce=Nonce(2),
other_nonce=Nonce(1),
updating_capacity=deposit2,
other_capacity=deposit1,
reveal_timeout=2,
signature=EMPTY_SIGNATURE,
),
updating_capacity_partner=TokenAmount(deposit1),
other_capacity_partner=TokenAmount(deposit2),
)
80 changes: 64 additions & 16 deletions tests/pathfinding/test_graphs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import random
import time
from copy import deepcopy
from typing import Dict, List

import pytest
from eth_utils import decode_hex, to_checksum_address
from networkx import NetworkXNoPath
from eth_utils import decode_hex, to_canonical_address, to_checksum_address

from pathfinding_service.constants import DIVERSITY_PEN_DEFAULT
from pathfinding_service.model import ChannelView, TokenNetwork
Expand Down Expand Up @@ -118,14 +119,14 @@ def test_routing_simple(
}

# Not connected.
with pytest.raises(NetworkXNoPath):
token_network_model.get_paths(
source=addresses[0],
target=addresses[5],
value=PaymentAmount(10),
max_paths=1,
address_to_reachability=address_to_reachability,
)
no_paths = token_network_model.get_paths(
source=addresses[0],
target=addresses[5],
value=PaymentAmount(10),
max_paths=1,
address_to_reachability=address_to_reachability,
)
assert [] == no_paths


@pytest.mark.usefixtures("populate_token_network_case_1")
Expand All @@ -137,8 +138,8 @@ def test_capacity_check(
""" Test that the mediation fees are included in the capacity check """
# First get a path without mediation fees. This must return the shortest path: 4->1->0
paths = token_network_model.get_paths(
addresses[4],
addresses[0],
source=addresses[4],
Dominik1999 marked this conversation as resolved.
Show resolved Hide resolved
target=addresses[0],
value=PaymentAmount(35),
max_paths=1,
address_to_reachability=address_to_reachability,
Expand All @@ -154,8 +155,8 @@ def test_capacity_check(
# value of 35. But 35 + 1 exceeds the capacity for channel 4->1, which is
# 35. So we should now get the next best route instead.
paths = model_with_fees.get_paths(
addresses[4],
addresses[0],
source=addresses[4],
target=addresses[0],
value=PaymentAmount(35),
max_paths=1,
address_to_reachability=address_to_reachability,
Expand All @@ -173,8 +174,8 @@ def test_routing_result_order(
):
hex_addrs = [to_checksum_address(addr) for addr in addresses]
paths = token_network_model.get_paths(
addresses[0],
addresses[2],
source=addresses[0],
target=addresses[2],
value=PaymentAmount(10),
max_paths=5,
address_to_reachability=address_to_reachability,
Expand Down Expand Up @@ -338,3 +339,50 @@ def test_reachability_target(
)
== []
)


@pytest.mark.skip("Just run it locally for now")
@pytest.mark.usefixtures("populate_token_network_random")
def test_routing_benchmark(token_network_model: TokenNetwork): # pylint: disable=too-many-locals
value = PaymentAmount(100)
G = token_network_model.G
addresses_to_reachabilities = {
node: random.choice(
(
AddressReachability.REACHABLE,
AddressReachability.UNKNOWN,
AddressReachability.UNREACHABLE,
)
)
for node in G.nodes
}

times = []
start = time.time()
for _ in range(100):
tic = time.time()
source, target = random.sample(G.nodes, 2)
paths = token_network_model.get_paths(
source=source,
target=target,
value=value,
max_paths=5,
address_to_reachability=addresses_to_reachabilities,
)

toc = time.time()
times.append(toc - tic)
end = time.time()

for path_object in paths:
path = path_object["path"]
fees = path_object["estimated_fee"]
for node1, node2 in zip(path[:-1], path[1:]):
view: ChannelView = G[to_canonical_address(node1)][to_canonical_address(node2)]["view"]
print("capacity = ", view.capacity)
print("fee sum = ", fees)
print("Paths: ", paths)
print("Mean runtime: ", sum(times) / len(times))
print("Min runtime: ", min(times))
print("Max runtime: ", max(times))
print("Total runtime: ", end - start)