Skip to content

Commit

Permalink
Implement basic graph pruning
Browse files Browse the repository at this point in the history
It showed in #520 that in bigger token networks the search was blocking
the PFS for longer intervals. The main reason for that was that the path
finding was trying many routes where participants were offline.

This can easily be improved by pruning the graph on unreachable nodes
before the routing, so that the graph on which the search is done
consists only of reachable nodes.

This can be improved further in the future by also pruning channels with
to little capacity, etc.
  • Loading branch information
palango committed Aug 20, 2019
1 parent cd2da93 commit ac83828
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 47 deletions.
27 changes: 14 additions & 13 deletions src/pathfinding_service/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from gevent.pywsgi import WSGIServer
from marshmallow import fields
from marshmallow_dataclass import add_schema
from networkx.exception import NetworkXNoPath, NodeNotFound
from web3 import Web3

import pathfinding_service.exceptions as exceptions
Expand Down Expand Up @@ -133,17 +132,17 @@ def post(self, token_network_address: str) -> Tuple[dict, int]:
value = getattr(path_req, arg)
if value is not None:
optional_args[arg] = value
try:
paths = token_network.get_paths(
source=path_req.from_,
target=path_req.to,
value=path_req.value,
address_to_reachability=self.pathfinding_service.address_to_reachability,
max_paths=path_req.max_paths,
**optional_args,
)
except (NetworkXNoPath, NodeNotFound):
# this is for assertion via the scenario player

paths = token_network.get_paths(
source=path_req.from_,
target=path_req.to,
value=path_req.value,
address_to_reachability=self.pathfinding_service.address_to_reachability,
max_paths=path_req.max_paths,
**optional_args,
)
# this is for assertion via the scenario player
if len(paths) == 0:
if self.debug_mode:
last_requests.append(
dict(
Expand All @@ -154,7 +153,9 @@ def post(self, token_network_address: str) -> Tuple[dict, int]:
)
)
raise exceptions.NoRouteFound(
from_=to_checksum_address(path_req.from_), to=to_checksum_address(path_req.to)
from_=to_checksum_address(path_req.from_),
to=to_checksum_address(path_req.to),
value=path_req.value,
)

# this is for assertion via the scenario player
Expand Down
61 changes: 45 additions & 16 deletions src/pathfinding_service/model/token_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import structlog
from eth_utils import to_checksum_address
from networkx import DiGraph
from networkx.exception import NetworkXNoPath, NodeNotFound

from pathfinding_service.constants import (
DEFAULT_SETTLE_TO_REVEAL_TIMEOUT_RATIO,
Expand All @@ -30,6 +31,23 @@
log = structlog.get_logger(__name__)


def prune_graph(
graph: DiGraph, address_to_reachability: Dict[Address, AddressReachability]
) -> DiGraph:
""" Prunes the given `graph` of all channels where the participants are not reachable. """
pruned_graph = DiGraph()
for p1, p2 in graph.edges:
nodes_online = (
address_to_reachability[p1] == AddressReachability.REACHABLE
and address_to_reachability[p2] == AddressReachability.REACHABLE
)
if nodes_online:
pruned_graph.add_edge(p1, p2, view=graph[p1][p2]["view"])
pruned_graph.add_edge(p2, p1, view=graph[p2][p1]["view"])

return pruned_graph


def window(seq: Sequence, n: int = 2) -> Iterable[tuple]:
"""Returns a sliding window (of width n) over data from the iterable
s -> (s0,s1,...s[n-1]), (s1,s2,...,sn), ...
Expand Down Expand Up @@ -295,8 +313,9 @@ def edge_weight(
no_refund_weight = 1
return 1 + diversity_weight + fee_weight + no_refund_weight

def _get_single_path( # pylint: disable=too-many-arguments
def _get_single_path( # pylint: disable=too-many-arguments, too-many-locals
self,
graph: DiGraph,
source: Address,
target: Address,
value: PaymentAmount,
Expand All @@ -306,14 +325,14 @@ def _get_single_path( # pylint: disable=too-many-arguments
fee_penalty: float,
) -> Optional[Path]:
# update edge weights
for node1, node2 in self.G.edges():
edge = self.G[node1][node2]
backwards_edge = self.G[node2][node1]
for node1, node2 in graph.edges():
edge = graph[node1][node2]
backwards_edge = graph[node2][node1]
edge["weight"] = self.edge_weight(visited, edge, backwards_edge, value, fee_penalty)

# find next path
all_paths: Iterable[List[Address]] = nx.shortest_simple_paths(
G=self.G, source=source, target=target, weight="weight"
G=graph, source=source, target=target, weight="weight"
)
try:
# skip duplicates and invalid paths
Expand Down Expand Up @@ -349,25 +368,35 @@ def get_paths( # pylint: disable=too-many-arguments

log.debug(
"Finding paths for payment",
source=to_checksum_address(source),
target=to_checksum_address(target),
source=source,
target=target,
value=value,
max_paths=max_paths,
diversity_penalty=diversity_penalty,
fee_penalty=fee_penalty,
reachabilities=address_to_reachability,
)

# TODO: improve the pruning
# Currently we make a snapshot of the currently reachable nodes, so the serached graph
# becomes smaller
pruned_graph = prune_graph(graph=self.G, address_to_reachability=address_to_reachability)

while len(paths) < max_paths:
path = self._get_single_path(
source=source,
target=target,
value=value,
address_to_reachability=address_to_reachability,
visited=visited,
disallowed_paths=[p.nodes for p in paths],
fee_penalty=fee_penalty,
)
try:
path = self._get_single_path(
graph=pruned_graph,
source=source,
target=target,
value=value,
address_to_reachability=address_to_reachability,
visited=visited,
disallowed_paths=[p.nodes for p in paths],
fee_penalty=fee_penalty,
)
except (NetworkXNoPath, NodeNotFound):
return []

if path is None:
break
paths.append(path)
Expand Down
4 changes: 2 additions & 2 deletions src/raiden_libs/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Union

from coincurve import PrivateKey, PublicKey
from eth_utils import keccak, to_bytes
from eth_utils import decode_hex, keccak

from raiden.utils.typing import Address

Expand All @@ -15,7 +15,7 @@ def public_key_to_address(public_key: PublicKey) -> Address:
def private_key_to_address(private_key: Union[str, bytes]) -> Address:
""" Converts a private key to an Ethereum address. """
if isinstance(private_key, str):
private_key = to_bytes(hexstr=private_key)
private_key = decode_hex(private_key)

assert isinstance(private_key, bytes)
privkey = PrivateKey(private_key)
Expand Down
31 changes: 15 additions & 16 deletions tests/pathfinding/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from copy import deepcopy
from typing import Dict, List

import networkx
import pytest
from eth_utils import decode_hex, to_canonical_address, to_checksum_address

Expand Down Expand Up @@ -120,14 +119,14 @@ def test_routing_simple(
}

# Not connected.
with pytest.raises(NetworkXNoPath):
token_network_model.get_paths(
source=addresses[0],
target=addresses[5],
value=PaymentAmount(10),
max_paths=1,
address_to_reachability=address_to_reachability,
)
no_paths = token_network_model.get_paths(
source=addresses[0],
target=addresses[5],
value=PaymentAmount(10),
max_paths=1,
address_to_reachability=address_to_reachability,
)
assert [] == no_paths


@pytest.mark.usefixtures("populate_token_network_case_1")
Expand All @@ -139,8 +138,8 @@ def test_capacity_check(
""" Test that the mediation fees are included in the capacity check """
# First get a path without mediation fees. This must return the shortest path: 4->1->0
paths = token_network_model.get_paths(
addresses[4],
addresses[0],
source=addresses[4],
target=addresses[0],
value=PaymentAmount(35),
max_paths=1,
address_to_reachability=address_to_reachability,
Expand All @@ -156,8 +155,8 @@ def test_capacity_check(
# value of 35. But 35 + 1 exceeds the capacity for channel 4->1, which is
# 35. So we should now get the next best route instead.
paths = model_with_fees.get_paths(
addresses[4],
addresses[0],
source=addresses[4],
target=addresses[0],
value=PaymentAmount(35),
max_paths=1,
address_to_reachability=address_to_reachability,
Expand All @@ -175,8 +174,8 @@ def test_routing_result_order(
):
hex_addrs = [to_checksum_address(addr) for addr in addresses]
paths = token_network_model.get_paths(
addresses[0],
addresses[2],
source=addresses[0],
target=addresses[2],
value=PaymentAmount(10),
max_paths=5,
address_to_reachability=address_to_reachability,
Expand Down Expand Up @@ -342,7 +341,7 @@ def test_reachability_target(
)


@pytest.mark.skip('Just run it locally for now')
@pytest.mark.skip("Just run it locally for now")
@pytest.mark.usefixtures("populate_token_network_random")
def test_routing_benchmark(token_network_model: TokenNetwork): # pylint: disable=too-many-locals
value = PaymentAmount(100)
Expand Down

0 comments on commit ac83828

Please sign in to comment.