From ef2e2487af9ccc8260bfcb8a8ecf51120a482847 Mon Sep 17 00:00:00 2001 From: aratz-lasa Date: Thu, 27 Feb 2020 09:12:09 +0100 Subject: [PATCH 1/2] Implemented PING fully-featured First draft of PingService, which calculates RTT results --- libp2p/host/ping.py | 90 +++++++++++++++++++++++++++++++++----- libp2p/pubsub/gossipsub.py | 2 +- libp2p/tools/factories.py | 10 ++--- libp2p/tools/utils.py | 2 +- tests/host/test_ping.py | 31 ++++++++++++- 5 files changed, 117 insertions(+), 18 deletions(-) diff --git a/libp2p/host/ping.py b/libp2p/host/ping.py index 011024519..663aed574 100644 --- a/libp2p/host/ping.py +++ b/libp2p/host/ping.py @@ -1,7 +1,13 @@ import logging +import math +import secrets +import time +from typing import Union import trio +from libp2p.exceptions import ValidationError +from libp2p.host.host_interface import IHost from libp2p.network.stream.exceptions import StreamClosed, StreamEOF, StreamReset from libp2p.network.stream.net_stream_interface import INetStream from libp2p.peer.id import ID as PeerID @@ -14,6 +20,21 @@ logger = logging.getLogger("libp2p.host.ping") +async def handle_ping(stream: INetStream) -> None: + """``handle_ping`` responds to incoming ping requests until one side errors + or closes the ``stream``.""" + peer_id = stream.muxed_conn.peer_id + + while True: + try: + should_continue = await _handle_ping(stream, peer_id) + if not should_continue: + return + except Exception: + await stream.reset() + return + + async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool: """Return a boolean indicating if we expect more pings from the peer at ``peer_id``.""" @@ -45,16 +66,65 @@ async def _handle_ping(stream: INetStream, peer_id: PeerID) -> bool: return True -async def handle_ping(stream: INetStream) -> None: - """``handle_ping`` responds to incoming ping requests until one side errors - or closes the ``stream``.""" - peer_id = stream.muxed_conn.peer_id +class PingService: + """PingService executes pings and returns RTT in miliseconds.""" - while True: + def __init__(self, host: IHost): + self._host = host + + async def ping(self, peer_id: PeerID) -> int: + stream = await self._host.new_stream(peer_id, (ID,)) try: - should_continue = await _handle_ping(stream, peer_id) - if not should_continue: - return + rtt = await _ping(stream) + await _close_stream(stream) + return rtt except Exception: - await stream.reset() - return + await _close_stream(stream) + raise + + async def ping_loop( + self, peer_id: PeerID, ping_amount: Union[int, float] = math.inf + ) -> "PingIterator": + stream = await self._host.new_stream(peer_id, (ID,)) + ping_iterator = PingIterator(stream, ping_amount) + return ping_iterator + + +class PingIterator: + def __init__(self, stream: INetStream, ping_amount: Union[int, float]): + self._stream = stream + self._ping_limit = ping_amount + self._ping_counter = 0 + + def __aiter__(self) -> "PingIterator": + return self + + async def __anext__(self) -> int: + if self._ping_counter > self._ping_limit: + await _close_stream(self._stream) + raise StopAsyncIteration + + self._ping_counter += 1 + try: + return await _ping(self._stream) + except trio.EndOfChannel: + await _close_stream(self._stream) + raise StopAsyncIteration + + +async def _ping(stream: INetStream) -> int: + ping_bytes = secrets.token_bytes(PING_LENGTH) + before = int(time.time() * 10 ** 6) # convert float of seconds to int miliseconds + await stream.write(ping_bytes) + pong_bytes = await stream.read(PING_LENGTH) + rtt = int(time.time() * 10 ** 6) - before + if ping_bytes != pong_bytes: + raise ValidationError("Invalid PING response") + return rtt + + +async def _close_stream(stream: INetStream) -> None: + try: + await stream.close() + except Exception: + pass diff --git a/libp2p/pubsub/gossipsub.py b/libp2p/pubsub/gossipsub.py index 4d25c254c..b57501ec8 100644 --- a/libp2p/pubsub/gossipsub.py +++ b/libp2p/pubsub/gossipsub.py @@ -391,7 +391,7 @@ async def heartbeat(self) -> None: await trio.sleep(self.heartbeat_interval) def mesh_heartbeat( - self + self, ) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]: peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list) peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list) diff --git a/libp2p/tools/factories.py b/libp2p/tools/factories.py index 67e265199..c6367ac7f 100644 --- a/libp2p/tools/factories.py +++ b/libp2p/tools/factories.py @@ -67,7 +67,7 @@ def security_transport_factory( @asynccontextmanager async def raw_conn_factory( - nursery: trio.Nursery + nursery: trio.Nursery, ) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]: conn_0 = None conn_1 = None @@ -351,7 +351,7 @@ async def swarm_pair_factory( @asynccontextmanager async def host_pair_factory( - is_secure: bool + is_secure: bool, ) -> AsyncIterator[Tuple[BasicHost, BasicHost]]: async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts: await connect(hosts[0], hosts[1]) @@ -370,7 +370,7 @@ async def swarm_conn_pair_factory( @asynccontextmanager async def mplex_conn_pair_factory( - is_secure: bool + is_secure: bool, ) -> AsyncIterator[Tuple[Mplex, Mplex]]: muxer_opt = {MPLEX_PROTOCOL_ID: Mplex} async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair: @@ -382,7 +382,7 @@ async def mplex_conn_pair_factory( @asynccontextmanager async def mplex_stream_pair_factory( - is_secure: bool + is_secure: bool, ) -> AsyncIterator[Tuple[MplexStream, MplexStream]]: async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info: mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info @@ -398,7 +398,7 @@ async def mplex_stream_pair_factory( @asynccontextmanager async def net_stream_pair_factory( - is_secure: bool + is_secure: bool, ) -> AsyncIterator[Tuple[INetStream, INetStream]]: protocol_id = TProtocol("/example/id/1") diff --git a/libp2p/tools/utils.py b/libp2p/tools/utils.py index 5a262b3b6..b6e94dee3 100644 --- a/libp2p/tools/utils.py +++ b/libp2p/tools/utils.py @@ -30,7 +30,7 @@ async def connect(node1: IHost, node2: IHost) -> None: def create_echo_stream_handler( - ack_prefix: str + ack_prefix: str, ) -> Callable[[INetStream], Awaitable[None]]: async def echo_stream_handler(stream: INetStream) -> None: while True: diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index 7a0f8db51..3bf7783d7 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -3,7 +3,7 @@ import pytest import trio -from libp2p.host.ping import ID, PING_LENGTH +from libp2p.host.ping import ID, PING_LENGTH, PingService from libp2p.tools.factories import host_pair_factory @@ -36,3 +36,32 @@ async def test_ping_several(is_host_secure): # NOTE: this interval can be `0` for this test. await trio.sleep(0) await stream.close() + + +@pytest.mark.trio +async def test_ping_service_once(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): + ping_service = PingService(host_b) + rtt = await ping_service.ping(host_a.get_id()) + assert rtt < 10 ** 6 # rtt is in miliseconds + + +@pytest.mark.trio +async def test_ping_service_loop(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): + ping_service = PingService(host_b) + ping_loop = await ping_service.ping_loop( + host_a.get_id(), ping_amount=SOME_PING_COUNT + ) + async for rtt in ping_loop: + assert rtt < 10 ** 6 + + +@pytest.mark.trio +async def test_ping_service_loop_infinite(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): + ping_service = PingService(host_b) + ping_loop = await ping_service.ping_loop(host_a.get_id()) + with trio.move_on_after(1): # breaking loop after one second + async for rtt in ping_loop: + assert rtt < 10 ** 6 From 6811a508b26da65120debc110b91cca765a0d0e4 Mon Sep 17 00:00:00 2001 From: aratz-lasa Date: Tue, 3 Mar 2020 19:36:01 +0100 Subject: [PATCH 2/2] Added docstrings and ping-amount to 'ping' --- libp2p/host/ping.py | 41 ++++++++++++++++++++++------------------- tests/host/test_ping.py | 17 ++++++++++++++--- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/libp2p/host/ping.py b/libp2p/host/ping.py index 663aed574..6f8154f24 100644 --- a/libp2p/host/ping.py +++ b/libp2p/host/ping.py @@ -2,7 +2,7 @@ import math import secrets import time -from typing import Union +from typing import List, Optional import trio @@ -72,28 +72,38 @@ class PingService: def __init__(self, host: IHost): self._host = host - async def ping(self, peer_id: PeerID) -> int: + async def ping(self, peer_id: PeerID, ping_amount: int = 1) -> List[int]: + """method for PINGing 'n' times and returning the RTTs.""" + stream = await self._host.new_stream(peer_id, (ID,)) try: - rtt = await _ping(stream) - await _close_stream(stream) - return rtt + rtts = [ + await _ping(stream) for _ in range(ping_amount) + ] # todo: maybe it is better to run them concurrently? + await stream.close() + return rtts except Exception: - await _close_stream(stream) + await stream.close() raise async def ping_loop( - self, peer_id: PeerID, ping_amount: Union[int, float] = math.inf + self, peer_id: PeerID, ping_limit: Optional[int] = None ) -> "PingIterator": + """ + method for generating a PING iterator, so that some logic can be + implemented inbetween each PING. + + Every iteration returns the RTT + """ stream = await self._host.new_stream(peer_id, (ID,)) - ping_iterator = PingIterator(stream, ping_amount) + ping_iterator = PingIterator(stream, ping_limit) return ping_iterator class PingIterator: - def __init__(self, stream: INetStream, ping_amount: Union[int, float]): + def __init__(self, stream: INetStream, ping_limit: Optional[int]): self._stream = stream - self._ping_limit = ping_amount + self._ping_limit = ping_limit or math.inf self._ping_counter = 0 def __aiter__(self) -> "PingIterator": @@ -101,14 +111,14 @@ def __aiter__(self) -> "PingIterator": async def __anext__(self) -> int: if self._ping_counter > self._ping_limit: - await _close_stream(self._stream) + await self._stream.close() raise StopAsyncIteration self._ping_counter += 1 try: return await _ping(self._stream) except trio.EndOfChannel: - await _close_stream(self._stream) + await self._stream.close() raise StopAsyncIteration @@ -121,10 +131,3 @@ async def _ping(stream: INetStream) -> int: if ping_bytes != pong_bytes: raise ValidationError("Invalid PING response") return rtt - - -async def _close_stream(stream: INetStream) -> None: - try: - await stream.close() - except Exception: - pass diff --git a/tests/host/test_ping.py b/tests/host/test_ping.py index 3bf7783d7..9787a3adf 100644 --- a/tests/host/test_ping.py +++ b/tests/host/test_ping.py @@ -42,8 +42,19 @@ async def test_ping_several(is_host_secure): async def test_ping_service_once(is_host_secure): async with host_pair_factory(is_host_secure) as (host_a, host_b): ping_service = PingService(host_b) - rtt = await ping_service.ping(host_a.get_id()) - assert rtt < 10 ** 6 # rtt is in miliseconds + rtts = await ping_service.ping(host_a.get_id()) + assert len(rtts) == 1 + assert rtts[0] < 10 ** 6 # rtt is in miliseconds + + +@pytest.mark.trio +async def test_ping_service_several(is_host_secure): + async with host_pair_factory(is_host_secure) as (host_a, host_b): + ping_service = PingService(host_b) + rtts = await ping_service.ping(host_a.get_id(), ping_amount=SOME_PING_COUNT) + assert len(rtts) == SOME_PING_COUNT + for rtt in rtts: + assert rtt < 10 ** 6 # rtt is in miliseconds @pytest.mark.trio @@ -51,7 +62,7 @@ async def test_ping_service_loop(is_host_secure): async with host_pair_factory(is_host_secure) as (host_a, host_b): ping_service = PingService(host_b) ping_loop = await ping_service.ping_loop( - host_a.get_id(), ping_amount=SOME_PING_COUNT + host_a.get_id(), ping_limit=SOME_PING_COUNT ) async for rtt in ping_loop: assert rtt < 10 ** 6