Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve FeeUpdate tests #661

Merged
merged 3 commits into from
Nov 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pathfinding_service/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class InvalidCapacityUpdate(InvalidGlobalMessage):
pass


class InvalidPFSFeeUpdate(InvalidGlobalMessage):
class InvalidFeeUpdate(InvalidGlobalMessage):
pass


Expand Down
9 changes: 6 additions & 3 deletions src/pathfinding_service/model/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

import marshmallow
from eth_utils import to_checksum_address
from marshmallow.fields import NaiveDateTime
from marshmallow_dataclass import add_schema

from pathfinding_service.constants import DEFAULT_REVEAL_TIMEOUT
from pathfinding_service.exceptions import InvalidPFSFeeUpdate
from pathfinding_service.exceptions import InvalidFeeUpdate
from raiden.transfer.mediated_transfer.mediation_fee import FeeScheduleState as FeeScheduleRaiden
from raiden.utils.typing import (
Address,
Expand All @@ -22,7 +23,9 @@

@dataclass
class FeeSchedule(FeeScheduleRaiden):
timestamp: datetime = datetime(2000, 1, 1)
timestamp: datetime = field(
metadata={"marshmallow_field": NaiveDateTime()}, default=datetime(2000, 1, 1)
)

@classmethod
def from_raiden(cls, fee_schedule: FeeScheduleRaiden, timestamp: datetime) -> "FeeSchedule":
Expand Down Expand Up @@ -145,7 +148,7 @@ def update_capacity(

def set_fee_schedule(self, fee_schedule: FeeSchedule) -> None:
if self.fee_schedule_sender.timestamp >= fee_schedule.timestamp:
raise InvalidPFSFeeUpdate("Timestamp must increase between fee updates")
raise InvalidFeeUpdate("Timestamp must increase between fee updates")
if self.reverse:
self.channel.fee_schedule2 = fee_schedule
else:
Expand Down
4 changes: 2 additions & 2 deletions src/pathfinding_service/model/token_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
DIVERSITY_PEN_DEFAULT,
FEE_PEN_DEFAULT,
)
from pathfinding_service.exceptions import InvalidPFSFeeUpdate
from pathfinding_service.exceptions import InvalidFeeUpdate
from pathfinding_service.model.channel import Channel, ChannelView, FeeSchedule
from raiden.messages.path_finding_service import PFSCapacityUpdate, PFSFeeUpdate
from raiden.network.transport.matrix import AddressReachability
Expand Down Expand Up @@ -313,7 +313,7 @@ def handle_channel_fee_update(self, message: PFSFeeUpdate) -> Channel:
# We don't really care about the time, but if we accept a time far
# in the future, the client will have problems sending fee updates
# with increasing time after fixing his clock.
raise InvalidPFSFeeUpdate("Timestamp is in the future")
raise InvalidFeeUpdate("Timestamp is in the future")
channel_id = message.canonical_identifier.channel_identifier
participants = self.channel_id_to_addresses[channel_id]
other_participant = (set(participants) - {message.updating_participant}).pop()
Expand Down
46 changes: 30 additions & 16 deletions src/pathfinding_service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from pathfinding_service.database import PFSDatabase
from pathfinding_service.exceptions import (
InvalidCapacityUpdate,
InvalidFeeUpdate,
InvalidGlobalMessage,
InvalidPFSFeeUpdate,
)
from pathfinding_service.model import TokenNetwork
from pathfinding_service.model.channel import Channel
Expand Down Expand Up @@ -283,35 +283,49 @@ def defer_message_until_channel_is_open(self, message: DeferableMessage) -> None
)
self.database.insert_waiting_message(message)

def on_fee_update(self, message: PFSFeeUpdate) -> Optional[Channel]:
if message.sender != message.updating_participant:
raise InvalidPFSFeeUpdate("Invalid sender recovered from signature in PFSFeeUpdate")
def _validate_pfs_fee_update(self, message: PFSFeeUpdate) -> TokenNetwork:
# check if chain_id matches
if message.canonical_identifier.chain_identifier != self.chain_id:
raise InvalidFeeUpdate("Received Fee Update with unknown chain identifier")

# check if token network exists
token_network = self.get_token_network(message.canonical_identifier.token_network_address)
if not token_network:
return None
if token_network is None:
raise InvalidFeeUpdate("Received Fee Update with unknown token network")

log.debug("Received Fee Update", message=message)
# check signature of Capacity Update
if message.sender != message.updating_participant:
raise InvalidFeeUpdate("Fee Update not signed correctly")

if (
message.canonical_identifier.channel_identifier
not in token_network.channel_id_to_addresses
):
# check if channel exists
channel_identifier = message.canonical_identifier.channel_identifier
if channel_identifier not in token_network.channel_id_to_addresses:
raise DeferMessage(message)

# check if participants fit to channel id
participants = token_network.channel_id_to_addresses[channel_identifier]
if message.updating_participant not in participants:
raise InvalidFeeUpdate("Sender of Fee Update does not match the internal channel")

# check that timestamp has no timezone
if message.timestamp.tzinfo is not None:
raise InvalidFeeUpdate("Timestamp of Fee Update should not contain timezone")

return token_network

def on_fee_update(self, message: PFSFeeUpdate) -> Optional[Channel]:
token_network = self._validate_pfs_fee_update(message)
log.debug("Received Fee Update", message=message)

return token_network.handle_channel_fee_update(message)

def _validate_pfs_capacity_update(self, message: PFSCapacityUpdate) -> TokenNetwork:
token_network_address = TokenNetworkAddress(
message.canonical_identifier.token_network_address
)

# check if chain_id matches
if message.canonical_identifier.chain_identifier != self.chain_id:
raise InvalidCapacityUpdate("Received Capacity Update with unknown chain identifier")

# check if token network exists
token_network = self.get_token_network(token_network_address)
token_network = self.get_token_network(message.canonical_identifier.token_network_address)
if token_network is None:
raise InvalidCapacityUpdate("Received Capacity Update with unknown token network")

Expand Down
29 changes: 28 additions & 1 deletion tests/libs/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
import json
import sys
import time
from datetime import timedelta
from datetime import datetime, timedelta, timezone
from unittest.mock import Mock, patch

import pytest
from eth_utils import encode_hex, to_canonical_address
from tests.pathfinding.test_fee_updates import (
PRIVATE_KEY_1,
PRIVATE_KEY_1_ADDRESS,
get_fee_update_message,
)

from monitoring_service.states import HashedBalanceProof
from raiden.messages.monitoring_service import RequestMonitoring
Expand Down Expand Up @@ -67,6 +72,28 @@ def test_deserialize_messages_invalid_sender(request_monitoring_message):
assert len(messages) == 0


def test_deserialize_checks_datetimes_in_messages():
invalid_fee_update = get_fee_update_message(
updating_participant=PRIVATE_KEY_1_ADDRESS,
privkey_signer=PRIVATE_KEY_1,
timestamp=datetime.now(timezone.utc),
)
message = MessageSerializer.serialize(invalid_fee_update)

messages = deserialize_messages(data=message, peer_address=PRIVATE_KEY_1_ADDRESS)
assert len(messages) == 0

valid_fee_update = get_fee_update_message(
updating_participant=PRIVATE_KEY_1_ADDRESS,
privkey_signer=PRIVATE_KEY_1,
timestamp=datetime.utcnow(),
)
message = MessageSerializer.serialize(valid_fee_update)

messages = deserialize_messages(data=message, peer_address=PRIVATE_KEY_1_ADDRESS)
assert len(messages) == 1


def test_deserialize_messages_valid_message(request_monitoring_message):
message = MessageSerializer.serialize(request_monitoring_message)

Expand Down
171 changes: 171 additions & 0 deletions tests/pathfinding/test_fee_updates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""
The tests in this module mock Fee Updates and call on_fee_update().

The Fee Updates show different correct and incorrect values to test all edge cases
"""
from datetime import datetime, timezone

import pytest
from eth_utils import decode_hex, to_canonical_address

from pathfinding_service.exceptions import InvalidFeeUpdate
from pathfinding_service.model import TokenNetwork
from pathfinding_service.service import DeferMessage, PathfindingService
from raiden.constants import EMPTY_SIGNATURE
from raiden.messages.path_finding_service import PFSFeeUpdate
from raiden.transfer.identifiers import CanonicalIdentifier
from raiden.transfer.mediated_transfer.mediation_fee import FeeScheduleState
from raiden.utils.signer import LocalSigner
from raiden.utils.typing import (
Address,
BlockTimeout,
ChainID,
ChannelID,
FeeAmount,
ProportionalFeeAmount,
TokenNetworkAddress,
)
from raiden_libs.utils import private_key_to_address

DEFAULT_TOKEN_NETWORK_ADDRESS = TokenNetworkAddress(
decode_hex("0x6e46B62a245D9EE7758B8DdCCDD1B85fF56B9Bc9")
)
PRIVATE_KEY_1 = bytes([1] * 32)
PRIVATE_KEY_1_ADDRESS = private_key_to_address(PRIVATE_KEY_1)
PRIVATE_KEY_2 = bytes([2] * 32)
PRIVATE_KEY_2_ADDRESS = private_key_to_address(PRIVATE_KEY_2)
PRIVATE_KEY_3 = bytes([3] * 32)
PRIVATE_KEY_3_ADDRESS = private_key_to_address(PRIVATE_KEY_3)
DEFAULT_CHANNEL_ID = ChannelID(0)


def setup_channel(service: PathfindingService) -> TokenNetwork:
token_network = TokenNetwork(token_network_address=DEFAULT_TOKEN_NETWORK_ADDRESS)
service.token_networks[token_network.address] = token_network

token_network.handle_channel_opened_event(
channel_identifier=DEFAULT_CHANNEL_ID,
participant1=PRIVATE_KEY_1_ADDRESS,
participant2=PRIVATE_KEY_2_ADDRESS,
settle_timeout=BlockTimeout(15),
)

# Check that the new channel has id == 0
assert set(token_network.channel_id_to_addresses[DEFAULT_CHANNEL_ID]) == {
PRIVATE_KEY_1_ADDRESS,
PRIVATE_KEY_2_ADDRESS,
}

return token_network


def get_fee_update_message( # pylint: disable=too-many-arguments
updating_participant: Address,
chain_identifier=ChainID(1),
channel_identifier=DEFAULT_CHANNEL_ID,
token_network_address: TokenNetworkAddress = DEFAULT_TOKEN_NETWORK_ADDRESS,
fee_schedule: FeeScheduleState = FeeScheduleState(
cap_fees=True, flat=FeeAmount(1), proportional=ProportionalFeeAmount(1)
),
timestamp: datetime = datetime.utcnow(),
privkey_signer: bytes = PRIVATE_KEY_1,
) -> PFSFeeUpdate:
fee_message = PFSFeeUpdate(
canonical_identifier=CanonicalIdentifier(
chain_identifier=chain_identifier,
channel_identifier=channel_identifier,
token_network_address=token_network_address,
),
updating_participant=updating_participant,
fee_schedule=fee_schedule,
timestamp=timestamp,
signature=EMPTY_SIGNATURE,
)

fee_message.sign(LocalSigner(privkey_signer))

return fee_message


def test_pfs_rejects_fee_update_with_wrong_chain_id(
pathfinding_service_web3_mock: PathfindingService,
):
setup_channel(pathfinding_service_web3_mock)

message = get_fee_update_message(
chain_identifier=ChainID(121212),
updating_participant=PRIVATE_KEY_1_ADDRESS,
privkey_signer=PRIVATE_KEY_1,
)

with pytest.raises(InvalidFeeUpdate) as exinfo:
pathfinding_service_web3_mock.on_fee_update(message)
assert "unknown chain identifier" in str(exinfo.value)


def test_pfs_rejects_capacity_update_with_wrong_token_network_address(
pathfinding_service_web3_mock: PathfindingService,
):
setup_channel(pathfinding_service_web3_mock)

message = get_fee_update_message(
token_network_address=TokenNetworkAddress(to_canonical_address("0x" + "1" * 40)),
updating_participant=PRIVATE_KEY_1_ADDRESS,
privkey_signer=PRIVATE_KEY_1,
)

with pytest.raises(InvalidFeeUpdate) as exinfo:
pathfinding_service_web3_mock.on_fee_update(message)
assert "unknown token network" in str(exinfo.value)


def test_pfs_rejects_capacity_update_with_wrong_channel_identifier(
pathfinding_service_web3_mock: PathfindingService,
):
setup_channel(pathfinding_service_web3_mock)

message = get_fee_update_message(
channel_identifier=ChannelID(35),
updating_participant=PRIVATE_KEY_1_ADDRESS,
privkey_signer=PRIVATE_KEY_1,
)

with pytest.raises(DeferMessage):
pathfinding_service_web3_mock.on_fee_update(message)


def test_pfs_rejects_fee_update_with_incorrect_signature(
pathfinding_service_web3_mock: PathfindingService,
):
setup_channel(pathfinding_service_web3_mock)

message = get_fee_update_message(
updating_participant=PRIVATE_KEY_1_ADDRESS, privkey_signer=PRIVATE_KEY_3,
)

with pytest.raises(InvalidFeeUpdate) as exinfo:
pathfinding_service_web3_mock.on_fee_update(message)
assert "Fee Update not signed correctly" in str(exinfo.value)


def test_pfs_rejects_fee_update_with_incorrect_timestamp(
pathfinding_service_web3_mock: PathfindingService,
):
setup_channel(pathfinding_service_web3_mock)

message = get_fee_update_message(
updating_participant=PRIVATE_KEY_1_ADDRESS,
privkey_signer=PRIVATE_KEY_1,
timestamp=datetime.now(tz=timezone.utc),
)

with pytest.raises(InvalidFeeUpdate) as exinfo:
pathfinding_service_web3_mock.on_fee_update(message)
assert "Fee Update should not contain timezone" in str(exinfo.value)

valid_message = get_fee_update_message(
updating_participant=PRIVATE_KEY_1_ADDRESS,
privkey_signer=PRIVATE_KEY_1,
timestamp=datetime.utcnow(),
)
pathfinding_service_web3_mock.on_fee_update(valid_message)
21 changes: 0 additions & 21 deletions tests/pathfinding/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest
from eth_utils import to_checksum_address

from pathfinding_service import exceptions
from pathfinding_service.model.token_network import PFSFeeUpdate
from pathfinding_service.service import PathfindingService
from raiden.constants import EMPTY_SIGNATURE
Expand Down Expand Up @@ -292,26 +291,6 @@ def test_update_fee(order, pathfinding_service_mock, token_network_model):
assert getattr(cv.fee_schedule_sender, key) == getattr(fee_schedule, key)


def test_invalid_fee_update(pathfinding_service_mock, token_network_model):
setup_channel(pathfinding_service_mock, token_network_model)

fee_update = PFSFeeUpdate(
canonical_identifier=CanonicalIdentifier(
chain_identifier=ChainID(1),
token_network_address=token_network_model.address,
channel_identifier=ChannelID(1),
),
updating_participant=PARTICIPANT1,
fee_schedule=FeeScheduleState(),
timestamp=datetime.utcnow(),
signature=EMPTY_SIGNATURE,
)

# bad/missing signature
with pytest.raises(exceptions.InvalidPFSFeeUpdate):
pathfinding_service_mock.on_fee_update(fee_update)


def test_unhandled_message(pathfinding_service_mock, log):
unknown_message = Processed(MessageID(123), signature=EMPTY_SIGNATURE)
unknown_message.sign(LocalSigner(PARTICIPANT1_PRIVKEY))
Expand Down