Skip to content
This repository has been archived by the owner on Oct 27, 2024. It is now read-only.

Commit

Permalink
Handle sender+receiver fees for each channel view
Browse files Browse the repository at this point in the history
  • Loading branch information
karlb committed May 21, 2019
1 parent a1282fc commit 836cbfc
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 35 deletions.
6 changes: 4 additions & 2 deletions src/pathfinding_service/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,14 +151,16 @@ def upsert_channel_view(self, channel_view: ChannelView) -> None:
"update_nonce",
):
cv_dict[key] = hex256(cv_dict[key])
cv_dict["fee_schedule"] = json.dumps(cv_dict["fee_schedule"])
cv_dict["fee_schedule_sender"] = json.dumps(cv_dict["fee_schedule_sender"])
cv_dict["fee_schedule_receiver"] = json.dumps(cv_dict["fee_schedule_receiver"])
self.upsert("channel_view", cv_dict)

def get_channel_views(self) -> Iterator[ChannelView]:
query = "SELECT * FROM channel_view"
for row in self.conn.execute(query):
cv_dict = dict(zip(row.keys(), row))
cv_dict["fee_schedule"] = json.loads(cv_dict["fee_schedule"])
cv_dict["fee_schedule_sender"] = json.loads(cv_dict["fee_schedule_sender"])
cv_dict["fee_schedule_receiver"] = json.loads(cv_dict["fee_schedule_receiver"])
yield ChannelView.Schema(strict=True).load(cv_dict)[0]

def delete_channel_views(self, channel_id: ChannelID) -> None:
Expand Down
34 changes: 19 additions & 15 deletions src/pathfinding_service/model/channel_view.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from bisect import bisect_right
from dataclasses import dataclass, field
from typing import Callable, ClassVar, List, Optional, Sequence, Type
from typing import ClassVar, List, Optional, Sequence, Type

import marshmallow
from marshmallow_dataclass import add_schema
Expand All @@ -27,7 +27,8 @@ def __init__(self, x_list: Sequence, y_list: Sequence):
self.slopes = [(y2 - y1) / (x2 - x1) for x1, x2, y1, y2 in intervals]

def __call__(self, x: float) -> float:
assert self.x_list[0] <= x <= self.x_list[-1]
if not self.x_list[0] <= x <= self.x_list[-1]:
raise ValueError("x out of bounds!")
if x == self.x_list[-1]:
return self.y_list[-1]
i = bisect_right(self.x_list, x) - 1
Expand All @@ -40,30 +41,32 @@ class FeeSchedule:
flat: FeeAmount = FeeAmount(0)
proportional: float = FeeAmount(0)
imbalance_penalty: Optional[List[List[TokenAmount]]] = None
_penalty_func: Callable = field(init=False)
_penalty_func: Interpolate = field(init=False, repr=False)

def __post_init__(self) -> None:
if self.imbalance_penalty:
assert isinstance(self.imbalance_penalty, list)
x_list, y_list = tuple(zip(*self.imbalance_penalty))
# see https://github.com/python/mypy/issues/2427 for type problem
self._penalty_func = Interpolate(x_list, y_list) # type: ignore
else:
self._penalty_func = lambda amount: 0 # type: ignore
self._penalty_func = Interpolate(x_list, y_list)

def fee(self, amount: TokenAmount, capacity: TokenAmount) -> FeeAmount:
imbalance_fee = self._penalty_func(capacity + amount) - self._penalty_func(capacity)
if self.imbalance_penalty:
# Total channel capacity - node capacity = balance (used as x-axis for the penalty)
balance = self._penalty_func.x_list[-1] - capacity
imbalance_fee = self._penalty_func(balance + amount) - self._penalty_func(balance)
else:
imbalance_fee = 0
return FeeAmount(round(self.flat + amount * self.proportional + imbalance_fee))

def reversed(self) -> "FeeSchedule":
if not self.imbalance_penalty:
return self
max_x = max(x for x, penalty in self.imbalance_penalty)
max_penalty = max(penalty for x, penalty in self.imbalance_penalty)
return FeeSchedule(
flat=self.flat,
proportional=self.proportional,
imbalance_penalty=[
[TokenAmount(max_x - x), penalty] for x, penalty in self.imbalance_penalty
[x, TokenAmount(max_penalty - penalty)] for x, penalty in self.imbalance_penalty
],
)

Expand All @@ -86,7 +89,8 @@ class ChannelView:
reveal_timeout: int = DEFAULT_REVEAL_TIMEOUT
deposit: TokenAmount = TokenAmount(0)
update_nonce: Nonce = Nonce(0)
fee_schedule: FeeSchedule = field(default_factory=FeeSchedule)
fee_schedule_sender: FeeSchedule = field(default_factory=FeeSchedule)
fee_schedule_receiver: FeeSchedule = field(default_factory=FeeSchedule)
Schema: ClassVar[Type[marshmallow.Schema]]

def __post_init__(self) -> None:
Expand All @@ -106,13 +110,13 @@ def update_capacity(
if reveal_timeout is not None:
self.reveal_timeout = reveal_timeout

def fee_out(self, amount: TokenAmount) -> int:
def fee_sender(self, amount: TokenAmount) -> int:
"""Return the mediation fee for this channel when transferring the given amount"""
return int(self.fee_schedule.flat + amount * self.fee_schedule.proportional)
return self.fee_schedule_sender.fee(amount, self.capacity)

def fee_in(self, amount: TokenAmount) -> int:
def fee_receiver(self, amount: TokenAmount) -> int:
"""Return the mediation fee for this channel when receiving the given amount"""
return int(self.fee_schedule.flat + amount * self.fee_schedule.proportional)
return self.fee_schedule_receiver.fee(amount, self.capacity)

def __repr__(self) -> str:
return "<ChannelView from={} to={} capacity={}>".format(
Expand Down
10 changes: 5 additions & 5 deletions src/pathfinding_service/model/token_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def __init__(
self.value = value
self.address_to_reachability = address_to_reachability
self.fees: List[FeeAmount] = [
self.G[self.nodes[i + 1]][self.nodes[i]]["view"].fee_in(self.value)
+ self.G[self.nodes[i + 1]][self.nodes[i + 2]]["view"].fee_out(self.value)
self.G[self.nodes[i]][self.nodes[i + 1]]["view"].fee_receiver(self.value)
+ self.G[self.nodes[i + 1]][self.nodes[i + 2]]["view"].fee_sender(self.value)
for i, node in enumerate(self.nodes[1:-1]) # initiator and target don't cause fees
]

Expand Down Expand Up @@ -240,8 +240,8 @@ def handle_channel_fee_update(self, message: FeeUpdate) -> None:
updating_participant=message.updating_participant,
other_participant=message.other_participant,
)
channel_view_to_partner.fee_schedule = message.fee_schedule
channel_view_from_partner.fee_schedule = message.fee_schedule.reversed()
channel_view_to_partner.fee_schedule_sender = message.fee_schedule
channel_view_from_partner.fee_schedule_receiver = message.fee_schedule.reversed()

@staticmethod
def edge_weight(
Expand All @@ -257,7 +257,7 @@ def edge_weight(
# Fees for initiator and target are included here. This promotes routes
# that are nice to the initiator's and target's capacities, but it's
# inconsistent with the estimated total fee.
fee_weight = (view.fee_out(amount) + view_from_partner.fee_in(amount)) / 1e18 * fee_penalty
fee_weight = (view.fee_sender(amount) + view.fee_receiver(amount)) / 1e18 * fee_penalty
no_refund_weight = 0
if view_from_partner.capacity < int(float(amount) * 1.1):
no_refund_weight = 1
Expand Down
3 changes: 2 additions & 1 deletion src/pathfinding_service/schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ CREATE TABLE channel_view (
reveal_timeout HEX_INT NOT NULL,
deposit HEX_INT NOT NULL,
update_nonce HEX_INT,
fee_schedule JSON,
fee_schedule_sender JSON,
fee_schedule_receiver JSON,
PRIMARY KEY (token_network_address, channel_id, participant1),
FOREIGN KEY (token_network_address)
REFERENCES token_network(address)
Expand Down
95 changes: 88 additions & 7 deletions tests/pathfinding/test_fee_schedule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from eth_utils import decode_hex

from pathfinding_service.model.channel_view import FeeSchedule
from raiden.utils.typing import FeeAmount as FA, TokenAmount as TA
from pathfinding_service.model.token_network import FeeUpdate, TokenNetwork
from raiden.network.transport.matrix.utils import AddressReachability
from raiden.transfer.identifiers import CanonicalIdentifier
from raiden.utils.typing import ChainID, FeeAmount as FA, TokenAmount as TA


def test_basic_fee():
Expand All @@ -19,9 +24,85 @@ def test_imbalance_penalty():
v_schedule = FeeSchedule(
imbalance_penalty=[[TA(0), TA(10)], [TA(50), TA(0)], [TA(100), TA(10)]]
)
assert v_schedule.fee(capacity=TA(0), amount=TA(50)) == FA(-10)
assert v_schedule.fee(capacity=TA(50), amount=TA(50)) == FA(10)
assert v_schedule.fee(capacity=TA(0), amount=TA(10)) == FA(-2)
assert v_schedule.fee(capacity=TA(10), amount=TA(10)) == FA(-2)
assert v_schedule.fee(capacity=TA(0), amount=TA(20)) == FA(-4)
assert v_schedule.fee(capacity=TA(40), amount=TA(20)) == FA(0)
assert v_schedule.fee(capacity=TA(100 - 0), amount=TA(50)) == FA(-10)
assert v_schedule.fee(capacity=TA(100 - 50), amount=TA(50)) == FA(10)
assert v_schedule.fee(capacity=TA(100 - 0), amount=TA(10)) == FA(-2)
assert v_schedule.fee(capacity=TA(100 - 10), amount=TA(10)) == FA(-2)
assert v_schedule.fee(capacity=TA(100 - 0), amount=TA(20)) == FA(-4)
assert v_schedule.fee(capacity=TA(100 - 40), amount=TA(20)) == FA(0)


class PrettyBytes(bytes):
def __repr__(self):
return "b%x" % int.from_bytes(self, byteorder="big")


def a(int_addr): # pylint: disable=invalid-name
return PrettyBytes([0] * 19 + [int_addr])


def test_fees_in_routing():
network = TokenNetwork(token_network_address=a(255))
network.address_to_reachability = {
a(1): AddressReachability.REACHABLE,
a(2): AddressReachability.REACHABLE,
a(3): AddressReachability.REACHABLE,
}
network.handle_channel_opened_event(
channel_identifier=a(100), participant1=a(1), participant2=a(2), settle_timeout=100
)
network.handle_channel_opened_event(
channel_identifier=a(101), participant1=a(2), participant2=a(3), settle_timeout=100
)
for _, _, cv in network.G.edges(data="view"):
cv.capacity = 100

# Make sure that routing works and the default fees are zero
result = network.get_paths(a(1), a(3), value=TA(10), max_paths=1)
assert len(result) == 1
assert [PrettyBytes(decode_hex(node)) for node in result[0]["path"]] == [a(1), a(2), a(3)]
assert result[0]["estimated_fee"] == 0

def set_fee(node1, node2, fee_schedule: FeeSchedule):
channel_id = network.G[node1][node2]["view"].channel_id
network.handle_channel_fee_update(
FeeUpdate(
CanonicalIdentifier(
chain_identifier=ChainID(1),
token_network_address=network.address,
channel_identifier=channel_id,
),
node1,
node2,
fee_schedule,
)
)

def estimate_fee(initator, target, value=TA(10), max_paths=1):
result = network.get_paths(initator, target, value=value, max_paths=max_paths)
return result[0]["estimated_fee"]

# Fees for the initiator are ignored
set_fee(a(1), a(2), FeeSchedule(flat=FA(1)))
assert estimate_fee(a(1), a(3)) == 0

# Node 2 demands fees for incoming transfers
set_fee(a(2), a(1), FeeSchedule(flat=FA(1)))
assert estimate_fee(a(1), a(3)) == 1

# Node 2 demands fees for outgoing transfers
set_fee(a(2), a(3), FeeSchedule(flat=FA(1)))
assert estimate_fee(a(1), a(3)) == 2

# Same fee in the opposite direction
assert estimate_fee(a(3), a(1)) == 2

# Reset fees to zero
set_fee(a(1), a(2), FeeSchedule())
set_fee(a(2), a(1), FeeSchedule())
set_fee(a(2), a(3), FeeSchedule())

# Now let's try imbalance fees
set_fee(a(2), a(3), FeeSchedule(imbalance_penalty=[[TA(0), TA(0)], [TA(200), TA(200)]]))
assert estimate_fee(a(1), a(3)) == 10
assert estimate_fee(a(3), a(1)) == -10
10 changes: 5 additions & 5 deletions tests/pathfinding/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_edge_weight(addresses):
)

# absolute fee
view.fee_schedule.flat = FeeAmount(int(0.03e18))
view.fee_schedule_sender.flat = FeeAmount(int(0.03e18))
assert (
TokenNetwork.edge_weight(
dict(), dict(view=view), dict(view=view_partner), amount=amount, fee_penalty=100
Expand All @@ -63,8 +63,8 @@ def test_edge_weight(addresses):
)

# relative fee
view.fee_schedule.flat = FeeAmount(0)
view.fee_schedule.proportional = 0.01
view.fee_schedule_sender.flat = FeeAmount(0)
view.fee_schedule_sender.proportional = 0.01
assert (
TokenNetwork.edge_weight(
dict(), dict(view=view), dict(view=view_partner), amount=amount, fee_penalty=100
Expand All @@ -89,7 +89,7 @@ def test_routing_simple(token_network_model: TokenNetwork, addresses: List[Addre
view10: ChannelView = token_network_model.G[addresses[1]][addresses[0]]["view"]

assert view01.deposit == 100
assert view01.fee_schedule.flat == 0
assert view01.fee_schedule_sender.flat == 0
assert view01.capacity == 90
assert view10.capacity == 60

Expand Down Expand Up @@ -122,7 +122,7 @@ def test_capacity_check(token_network_model: TokenNetwork, addresses: List[Addre

# New let's add mediation fees to the channel 0->1.
model_with_fees = deepcopy(token_network_model)
model_with_fees.G[addresses[1]][addresses[0]]["view"].fee_schedule.flat = 1
model_with_fees.G[addresses[1]][addresses[0]]["view"].fee_schedule_sender.flat = 1
# The transfer from 4->1 must now include 1 Token for the mediation fee
# which will be payed for the 1->0 channel in addition to the payment
# value of 35. But 35 + 1 exceeds the capacity for channel 4->1, which is
Expand Down

0 comments on commit 836cbfc

Please sign in to comment.