Skip to content

Commit

Permalink
Changed 'nursery' for 'task_status'
Browse files Browse the repository at this point in the history
This way it is possible to implement a I/O agnostic Listener. It would  still have the 'task_status' default argument, but there would be no need for receiving it. And this way there is also no need for explicit nursery passing
  • Loading branch information
aratz-lasa committed Feb 26, 2020
1 parent 99f505d commit b53675e
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 25 deletions.
5 changes: 3 additions & 2 deletions libp2p/network/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ async def stream_handler(stream: INetStream) -> None:


class Swarm(Service, INetworkService):

self_id: ID
peerstore: IPeerStore
upgrader: TransportUpgrader
Expand Down Expand Up @@ -276,7 +275,9 @@ async def conn_handler(read_write_closer: ReadWriteCloser) -> None:
# I/O agnostic, we should change the API.
if self.listener_nursery is None:
raise SwarmException("swarm instance hasn't been run")
await listener.listen(maddr, self.listener_nursery)
await self.listener_nursery.start(
listener.listen, maddr # type: ignore
)

# Call notifiers since event occurred
await self.notify_listen(maddr)
Expand Down
2 changes: 1 addition & 1 deletion libp2p/pubsub/gossipsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ async def heartbeat(self) -> None:
await trio.sleep(self.heartbeat_interval)

def mesh_heartbeat(
self
self,
) -> Tuple[DefaultDict[ID, List[str]], DefaultDict[ID, List[str]]]:
peers_to_graft: DefaultDict[ID, List[str]] = defaultdict(list)
peers_to_prune: DefaultDict[ID, List[str]] = defaultdict(list)
Expand Down
12 changes: 6 additions & 6 deletions libp2p/tools/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def security_transport_factory(

@asynccontextmanager
async def raw_conn_factory(
nursery: trio.Nursery
nursery: trio.Nursery,
) -> AsyncIterator[Tuple[IRawConnection, IRawConnection]]:
conn_0 = None
conn_1 = None
Expand All @@ -81,7 +81,7 @@ async def tcp_stream_handler(stream: ReadWriteCloser) -> None:

tcp_transport = TCP()
listener = tcp_transport.create_listener(tcp_stream_handler)
await listener.listen(LISTEN_MADDR, nursery)
await nursery.start(listener.listen, LISTEN_MADDR)
listening_maddr = listener.get_addrs()[0]
conn_0 = await tcp_transport.dial(listening_maddr)
await event.wait()
Expand Down Expand Up @@ -351,7 +351,7 @@ async def swarm_pair_factory(

@asynccontextmanager
async def host_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[BasicHost, BasicHost]]:
async with HostFactory.create_batch_and_listen(is_secure, 2) as hosts:
await connect(hosts[0], hosts[1])
Expand All @@ -370,7 +370,7 @@ async def swarm_conn_pair_factory(

@asynccontextmanager
async def mplex_conn_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[Mplex, Mplex]]:
muxer_opt = {MPLEX_PROTOCOL_ID: Mplex}
async with swarm_conn_pair_factory(is_secure, muxer_opt=muxer_opt) as swarm_pair:
Expand All @@ -382,7 +382,7 @@ async def mplex_conn_pair_factory(

@asynccontextmanager
async def mplex_stream_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[MplexStream, MplexStream]]:
async with mplex_conn_pair_factory(is_secure) as mplex_conn_pair_info:
mplex_conn_0, mplex_conn_1 = mplex_conn_pair_info
Expand All @@ -398,7 +398,7 @@ async def mplex_stream_pair_factory(

@asynccontextmanager
async def net_stream_pair_factory(
is_secure: bool
is_secure: bool,
) -> AsyncIterator[Tuple[INetStream, INetStream]]:
protocol_id = TProtocol("/example/id/1")

Expand Down
2 changes: 1 addition & 1 deletion libp2p/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def connect(node1: IHost, node2: IHost) -> None:


def create_echo_stream_handler(
ack_prefix: str
ack_prefix: str,
) -> Callable[[INetStream], Awaitable[None]]:
async def echo_stream_handler(stream: INetStream) -> None:
while True:
Expand Down
6 changes: 4 additions & 2 deletions libp2p/transport/listener_interface.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from abc import ABC, abstractmethod
from typing import Tuple
from typing import Any, Tuple

from multiaddr import Multiaddr
import trio


class IListener(ABC):
@abstractmethod
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> bool:
async def listen(
self, maddr: Multiaddr, task_status: Any = trio.TASK_STATUS_IGNORED
) -> bool:
"""
put listener in listening mode and wait for incoming connections.
Expand Down
23 changes: 13 additions & 10 deletions libp2p/transport/tcp/tcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Awaitable, Callable, List, Sequence, Tuple
from typing import Any, Awaitable, Callable, List, Sequence, Tuple

from multiaddr import Multiaddr
import trio
Expand All @@ -23,8 +23,9 @@ def __init__(self, handler_function: THandler) -> None:
self.listeners = []
self.handler = handler_function

# TODO: Get rid of `nursery`?
async def listen(self, maddr: Multiaddr, nursery: trio.Nursery) -> None:
async def listen(
self, maddr: Multiaddr, task_status: TaskStatus[Any] = trio.TASK_STATUS_IGNORED
) -> None:
"""
put listener in listening mode and wait for incoming connections.
Expand All @@ -46,13 +47,15 @@ async def handler(stream: trio.SocketStream) -> None:
tcp_stream = TrioTCPStream(stream)
await self.handler(tcp_stream)

listeners = await nursery.start(
serve_tcp,
handler,
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"),
)
self.listeners.extend(listeners)
async with trio.open_nursery() as nursery:
listeners = await nursery.start(
serve_tcp,
handler,
int(maddr.value_for_protocol("tcp")),
maddr.value_for_protocol("ip4"),
)
task_status.started()
self.listeners.extend(listeners)

def get_addrs(self) -> Tuple[Multiaddr, ...]:
"""
Expand Down
6 changes: 3 additions & 3 deletions tests/transport/test_tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ async def handler(tcp_stream):

listener = transport.create_listener(handler)
assert len(listener.get_addrs()) == 0
await listener.listen(LISTEN_MADDR, nursery)
await nursery.start(listener.listen, LISTEN_MADDR)
assert len(listener.get_addrs()) == 1
await listener.listen(LISTEN_MADDR, nursery)
await nursery.start(listener.listen, LISTEN_MADDR)
assert len(listener.get_addrs()) == 2


Expand All @@ -41,7 +41,7 @@ async def handler(tcp_stream):
await transport.dial(Multiaddr("/ip4/127.0.0.1/tcp/1"))

listener = transport.create_listener(handler)
await listener.listen(LISTEN_MADDR, nursery)
await nursery.start(listener.listen, LISTEN_MADDR)
addrs = listener.get_addrs()
assert len(addrs) == 1
listen_addr = addrs[0]
Expand Down

0 comments on commit b53675e

Please sign in to comment.