Skip to content

Commit

Permalink
feat(quic): Wake transport when adding a new dialer or listener (#3342)
Browse files Browse the repository at this point in the history
Wake `quic::GenTransport` if a new dialer or listener is added.
  • Loading branch information
elenaf9 authored Jan 23, 2023
1 parent 778f7a2 commit dcfa7ec
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 3 deletions.
3 changes: 3 additions & 0 deletions transports/quic/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
- Add opt-in support for the `/quic` codepoint, interpreted as QUIC version draft-29.
See [PR 3151].

- Wake the transport's task when a new dialer or listener is added. See [3342].

[PR 3151]: https://github.com/libp2p/rust-libp2p/pull/3151
[PR 3342]: https://github.com/libp2p/rust-libp2p/pull/3342

# 0.7.0-alpha

Expand Down
20 changes: 17 additions & 3 deletions transports/quic/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ pub struct GenTransport<P: Provider> {
listeners: SelectAll<Listener<P>>,
/// Dialer for each socket family if no matching listener exists.
dialer: HashMap<SocketFamily, Dialer>,
/// Waker to poll the transport again when a new dialer or listener is added.
waker: Option<Waker>,
}

impl<P: Provider> GenTransport<P> {
Expand All @@ -84,6 +86,7 @@ impl<P: Provider> GenTransport<P> {
quinn_config,
handshake_timeout,
dialer: HashMap::new(),
waker: None,
support_draft_29,
}
}
Expand All @@ -108,6 +111,10 @@ impl<P: Provider> Transport for GenTransport<P> {
)?;
self.listeners.push(listener);

if let Some(waker) = self.waker.take() {
waker.wake();
}

// Remove dialer endpoint so that the endpoint is dropped once the last
// connection that uses it is closed.
// New outbound connections will use the bidirectional (listener) endpoint.
Expand Down Expand Up @@ -163,6 +170,9 @@ impl<P: Provider> Transport for GenTransport<P> {
let dialer = match self.dialer.entry(socket_family) {
Entry::Occupied(occupied) => occupied.into_mut(),
Entry::Vacant(vacant) => {
if let Some(waker) = self.waker.take() {
waker.wake();
}
vacant.insert(Dialer::new::<P>(self.quinn_config.clone(), socket_family)?)
}
};
Expand Down Expand Up @@ -202,15 +212,19 @@ impl<P: Provider> Transport for GenTransport<P> {
errored.push(*key);
}
}

for key in errored {
// Endpoint driver of dialer crashed.
// Drop dialer and all pending dials so that the connection receiver is notified.
self.dialer.remove(&key);
}
match self.listeners.poll_next_unpin(cx) {
Poll::Ready(Some(ev)) => Poll::Ready(ev),
_ => Poll::Pending,

if let Poll::Ready(Some(ev)) = self.listeners.poll_next_unpin(cx) {
return Poll::Ready(ev);
}

self.waker = Some(cx.waker().clone());
Poll::Pending
}
}

Expand Down
114 changes: 114 additions & 0 deletions transports/quic/tests/smoke.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#![cfg(any(feature = "async-std", feature = "tokio"))]

use futures::channel::{mpsc, oneshot};
use futures::future::BoxFuture;
use futures::future::{poll_fn, Either};
use futures::stream::StreamExt;
use futures::{future, AsyncReadExt, AsyncWriteExt, FutureExt, SinkExt};
use futures_timer::Delay;
use libp2p_core::muxing::{StreamMuxerBox, StreamMuxerExt, SubstreamBox};
use libp2p_core::transport::{Boxed, OrTransport, TransportEvent};
use libp2p_core::transport::{ListenerId, TransportError};
use libp2p_core::{multiaddr::Protocol, upgrade, Multiaddr, PeerId, Transport};
use libp2p_noise as noise;
use libp2p_quic as quic;
Expand All @@ -18,6 +21,10 @@ use std::io;
use std::num::NonZeroU8;
use std::task::Poll;
use std::time::Duration;
use std::{
pin::Pin,
sync::{Arc, Mutex},
};

#[cfg(feature = "tokio")]
#[tokio::test]
Expand Down Expand Up @@ -89,6 +96,113 @@ async fn ipv4_dial_ipv6() {
assert_eq!(b_connected, a_peer_id);
}

/// Tests that a [`Transport::dial`] wakes up the task previously polling [`Transport::poll`].
///
/// See https://github.com/libp2p/rust-libp2p/pull/3306 for context.
#[cfg(feature = "async-std")]
#[async_std::test]
async fn wrapped_with_delay() {
let _ = env_logger::try_init();

struct DialDelay(Arc<Mutex<Boxed<(PeerId, StreamMuxerBox)>>>);

impl Transport for DialDelay {
type Output = (PeerId, StreamMuxerBox);
type Error = std::io::Error;
type ListenerUpgrade = Pin<Box<dyn Future<Output = io::Result<Self::Output>> + Send>>;
type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;

fn listen_on(
&mut self,
addr: Multiaddr,
) -> Result<ListenerId, TransportError<Self::Error>> {
self.0.lock().unwrap().listen_on(addr)
}

fn remove_listener(&mut self, id: ListenerId) -> bool {
self.0.lock().unwrap().remove_listener(id)
}

fn address_translation(
&self,
listen: &Multiaddr,
observed: &Multiaddr,
) -> Option<Multiaddr> {
self.0.lock().unwrap().address_translation(listen, observed)
}

/// Delayed dial, i.e. calling [`Transport::dial`] on the inner [`Transport`] not within the
/// synchronous [`Transport::dial`] method, but within the [`Future`] returned by the outer
/// [`Transport::dial`].
fn dial(&mut self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
let t = self.0.clone();
Ok(async move {
// Simulate DNS lookup. Giving the `Transport::poll` the chance to return
// `Poll::Pending` and thus suspending its task, waiting for a wakeup from the dial
// on the inner transport below.
Delay::new(Duration::from_millis(100)).await;

let dial = t.lock().unwrap().dial(addr).map_err(|e| match e {
TransportError::MultiaddrNotSupported(_) => {
panic!()
}
TransportError::Other(e) => e,
})?;
dial.await
}
.boxed())
}

fn dial_as_listener(
&mut self,
addr: Multiaddr,
) -> Result<Self::Dial, TransportError<Self::Error>> {
self.0.lock().unwrap().dial_as_listener(addr)
}

fn poll(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
Pin::new(&mut *self.0.lock().unwrap()).poll(cx)
}
}

let (a_peer_id, mut a_transport) = create_default_transport::<quic::async_std::Provider>();
let (b_peer_id, mut b_transport) = {
let (id, transport) = create_default_transport::<quic::async_std::Provider>();
(id, DialDelay(Arc::new(Mutex::new(transport))).boxed())
};

// Spawn A
let a_addr = start_listening(&mut a_transport, "/ip6/::1/udp/0/quic-v1").await;
let listener = async_std::task::spawn(async move {
let (upgrade, _) = a_transport
.select_next_some()
.await
.into_incoming()
.unwrap();
let (peer_id, _) = upgrade.await.unwrap();

peer_id
});

// Spawn B
//
// Note that the dial is spawned on a different task than the transport allowing the transport
// task to poll the transport once and then suspend, waiting for the wakeup from the dial.
let dial = async_std::task::spawn({
let dial = b_transport.dial(a_addr).unwrap();
async { dial.await.unwrap().0 }
});
async_std::task::spawn(async move { b_transport.next().await });

let (a_connected, b_connected) = future::join(listener, dial).await;

assert_eq!(a_connected, b_peer_id);
assert_eq!(b_connected, a_peer_id);
}

#[cfg(feature = "async-std")]
#[async_std::test]
#[ignore] // Transport currently does not validate PeerId. Enable once we make use of PeerId validation in rustls.
Expand Down

0 comments on commit dcfa7ec

Please sign in to comment.