diff --git a/Cargo.lock b/Cargo.lock index 8a4a30c38549..bd3169affed7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1596,6 +1596,15 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-bounded" +version = "0.1.0" +dependencies = [ + "futures-timer", + "futures-util", + "tokio", +] + [[package]] name = "futures-channel" version = "0.3.28" @@ -2933,6 +2942,7 @@ dependencies = [ "either", "env_logger 0.10.0", "futures", + "futures-bounded", "futures-timer", "instant", "libp2p-core", diff --git a/Cargo.toml b/Cargo.toml index 15848016e002..0cf485e803db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "interop-tests", "misc/allow-block-list", "misc/connection-limits", + "misc/futures-bounded", "misc/keygen", "misc/memory-connection-limits", "misc/metrics", @@ -69,6 +70,7 @@ resolver = "2" rust-version = "1.65.0" [workspace.dependencies] +futures-bounded = { version = "0.1.0", path = "misc/futures-bounded" } libp2p = { version = "0.52.3", path = "libp2p" } libp2p-allow-block-list = { version = "0.2.0", path = "misc/allow-block-list" } libp2p-autonat = { version = "0.11.0", path = "protocols/autonat" } diff --git a/misc/futures-bounded/CHANGELOG.md b/misc/futures-bounded/CHANGELOG.md new file mode 100644 index 000000000000..712e55433860 --- /dev/null +++ b/misc/futures-bounded/CHANGELOG.md @@ -0,0 +1,3 @@ +## 0.1.0 - unreleased + +Initial release. diff --git a/misc/futures-bounded/Cargo.toml b/misc/futures-bounded/Cargo.toml new file mode 100644 index 000000000000..b273226239c8 --- /dev/null +++ b/misc/futures-bounded/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "futures-bounded" +version = "0.1.0" +edition = "2021" +rust-version.workspace = true +license = "MIT" +repository = "https://github.com/libp2p/rust-libp2p" +keywords = ["futures", "async", "backpressure"] +categories = ["data-structures", "asynchronous"] +description = "Utilities for bounding futures in size and time." +publish = false # TEMP FIX until https://github.com/obi1kenobi/cargo-semver-checks-action/issues/53 is fixed. + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +futures-util = { version = "0.3.28" } +futures-timer = "3.0.2" + +[dev-dependencies] +tokio = { version = "1.29.1", features = ["macros", "rt"] } diff --git a/misc/futures-bounded/src/lib.rs b/misc/futures-bounded/src/lib.rs new file mode 100644 index 000000000000..e7b461dc8229 --- /dev/null +++ b/misc/futures-bounded/src/lib.rs @@ -0,0 +1,28 @@ +mod map; +mod set; + +pub use map::{FuturesMap, PushError}; +pub use set::FuturesSet; +use std::fmt; +use std::fmt::Formatter; +use std::time::Duration; + +/// A future failed to complete within the given timeout. +#[derive(Debug)] +pub struct Timeout { + limit: Duration, +} + +impl Timeout { + fn new(duration: Duration) -> Self { + Self { limit: duration } + } +} + +impl fmt::Display for Timeout { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "future failed to complete within {:?}", self.limit) + } +} + +impl std::error::Error for Timeout {} diff --git a/misc/futures-bounded/src/map.rs b/misc/futures-bounded/src/map.rs new file mode 100644 index 000000000000..cecf6070efe2 --- /dev/null +++ b/misc/futures-bounded/src/map.rs @@ -0,0 +1,268 @@ +use std::future::Future; +use std::hash::Hash; +use std::mem; +use std::pin::Pin; +use std::task::{Context, Poll, Waker}; +use std::time::Duration; + +use futures_timer::Delay; +use futures_util::future::BoxFuture; +use futures_util::stream::FuturesUnordered; +use futures_util::{FutureExt, StreamExt}; + +use crate::Timeout; + +/// Represents a map of [`Future`]s. +/// +/// Each future must finish within the specified time and the map never outgrows its capacity. +pub struct FuturesMap { + timeout: Duration, + capacity: usize, + inner: FuturesUnordered>>>, + empty_waker: Option, + full_waker: Option, +} + +/// Error of a future pushing +#[derive(PartialEq, Debug)] +pub enum PushError { + /// The length of the set is equal to the capacity + BeyondCapacity(F), + /// The set already contains the given future's ID + ReplacedFuture(F), +} + +impl FuturesMap { + pub fn new(timeout: Duration, capacity: usize) -> Self { + Self { + timeout, + capacity, + inner: Default::default(), + empty_waker: None, + full_waker: None, + } + } +} + +impl FuturesMap +where + ID: Clone + Hash + Eq + Send + Unpin + 'static, +{ + /// Push a future into the map. + /// + /// This method inserts the given future with defined `future_id` to the set. + /// If the length of the map is equal to the capacity, this method returns [PushError::BeyondCapacity], + /// that contains the passed future. In that case, the future is not inserted to the map. + /// If a future with the given `future_id` already exists, then the old future will be replaced by a new one. + /// In that case, the returned error [PushError::ReplacedFuture] contains the old future. + pub fn try_push(&mut self, future_id: ID, future: F) -> Result<(), PushError>> + where + F: Future + Send + 'static, + { + if self.inner.len() >= self.capacity { + return Err(PushError::BeyondCapacity(future.boxed())); + } + + if let Some(waker) = self.empty_waker.take() { + waker.wake(); + } + + match self.inner.iter_mut().find(|tagged| tagged.tag == future_id) { + None => { + self.inner.push(TaggedFuture { + tag: future_id, + inner: TimeoutFuture { + inner: future.boxed(), + timeout: Delay::new(self.timeout), + }, + }); + + Ok(()) + } + Some(existing) => { + let old_future = mem::replace( + &mut existing.inner, + TimeoutFuture { + inner: future.boxed(), + timeout: Delay::new(self.timeout), + }, + ); + + Err(PushError::ReplacedFuture(old_future.inner)) + } + } + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + #[allow(unknown_lints, clippy::needless_pass_by_ref_mut)] // &mut Context is idiomatic. + pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> { + if self.inner.len() < self.capacity { + return Poll::Ready(()); + } + + self.full_waker = Some(cx.waker().clone()); + + Poll::Pending + } + + pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll<(ID, Result)> { + let maybe_result = futures_util::ready!(self.inner.poll_next_unpin(cx)); + + match maybe_result { + None => { + self.empty_waker = Some(cx.waker().clone()); + Poll::Pending + } + Some((id, Ok(output))) => Poll::Ready((id, Ok(output))), + Some((id, Err(_timeout))) => Poll::Ready((id, Err(Timeout::new(self.timeout)))), + } + } +} + +struct TimeoutFuture { + inner: F, + timeout: Delay, +} + +impl Future for TimeoutFuture +where + F: Future + Unpin, +{ + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + if self.timeout.poll_unpin(cx).is_ready() { + return Poll::Ready(Err(())); + } + + self.inner.poll_unpin(cx).map(Ok) + } +} + +struct TaggedFuture { + tag: T, + inner: F, +} + +impl Future for TaggedFuture +where + T: Clone + Unpin, + F: Future + Unpin, +{ + type Output = (T, F::Output); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let output = futures_util::ready!(self.inner.poll_unpin(cx)); + + Poll::Ready((self.tag.clone(), output)) + } +} + +#[cfg(test)] +mod tests { + use std::future::{pending, poll_fn, ready}; + use std::pin::Pin; + use std::time::Instant; + + use super::*; + + #[test] + fn cannot_push_more_than_capacity_tasks() { + let mut futures = FuturesMap::new(Duration::from_secs(10), 1); + + assert!(futures.try_push("ID_1", ready(())).is_ok()); + matches!( + futures.try_push("ID_2", ready(())), + Err(PushError::BeyondCapacity(_)) + ); + } + + #[test] + fn cannot_push_the_same_id_few_times() { + let mut futures = FuturesMap::new(Duration::from_secs(10), 5); + + assert!(futures.try_push("ID", ready(())).is_ok()); + matches!( + futures.try_push("ID", ready(())), + Err(PushError::ReplacedFuture(_)) + ); + } + + #[tokio::test] + async fn futures_timeout() { + let mut futures = FuturesMap::new(Duration::from_millis(100), 1); + + let _ = futures.try_push("ID", pending::<()>()); + Delay::new(Duration::from_millis(150)).await; + let (_, result) = poll_fn(|cx| futures.poll_unpin(cx)).await; + + assert!(result.is_err()) + } + + // Each future causes a delay, `Task` only has a capacity of 1, meaning they must be processed in sequence. + // We stop after NUM_FUTURES tasks, meaning the overall execution must at least take DELAY * NUM_FUTURES. + #[tokio::test] + async fn backpressure() { + const DELAY: Duration = Duration::from_millis(100); + const NUM_FUTURES: u32 = 10; + + let start = Instant::now(); + Task::new(DELAY, NUM_FUTURES, 1).await; + let duration = start.elapsed(); + + assert!(duration >= DELAY * NUM_FUTURES); + } + + struct Task { + future: Duration, + num_futures: usize, + num_processed: usize, + inner: FuturesMap, + } + + impl Task { + fn new(future: Duration, num_futures: u32, capacity: usize) -> Self { + Self { + future, + num_futures: num_futures as usize, + num_processed: 0, + inner: FuturesMap::new(Duration::from_secs(60), capacity), + } + } + } + + impl Future for Task { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + while this.num_processed < this.num_futures { + if let Poll::Ready((_, result)) = this.inner.poll_unpin(cx) { + if result.is_err() { + panic!("Timeout is great than future delay") + } + + this.num_processed += 1; + continue; + } + + if let Poll::Ready(()) = this.inner.poll_ready_unpin(cx) { + // We push the constant future's ID to prove that user can use the same ID + // if the future was finished + let maybe_future = this.inner.try_push(1u8, Delay::new(this.future)); + assert!(maybe_future.is_ok(), "we polled for readiness"); + + continue; + } + + return Poll::Pending; + } + + Poll::Ready(()) + } + } +} diff --git a/misc/futures-bounded/src/set.rs b/misc/futures-bounded/src/set.rs new file mode 100644 index 000000000000..96140d82f9a0 --- /dev/null +++ b/misc/futures-bounded/src/set.rs @@ -0,0 +1,58 @@ +use std::future::Future; +use std::task::{ready, Context, Poll}; +use std::time::Duration; + +use futures_util::future::BoxFuture; + +use crate::{FuturesMap, PushError, Timeout}; + +/// Represents a list of [Future]s. +/// +/// Each future must finish within the specified time and the list never outgrows its capacity. +pub struct FuturesSet { + id: u32, + inner: FuturesMap, +} + +impl FuturesSet { + pub fn new(timeout: Duration, capacity: usize) -> Self { + Self { + id: 0, + inner: FuturesMap::new(timeout, capacity), + } + } +} + +impl FuturesSet { + /// Push a future into the list. + /// + /// This method adds the given future to the list. + /// If the length of the list is equal to the capacity, this method returns a error that contains the passed future. + /// In that case, the future is not added to the set. + pub fn try_push(&mut self, future: F) -> Result<(), BoxFuture> + where + F: Future + Send + 'static, + { + self.id = self.id.wrapping_add(1); + + match self.inner.try_push(self.id, future) { + Ok(()) => Ok(()), + Err(PushError::BeyondCapacity(w)) => Err(w), + Err(PushError::ReplacedFuture(_)) => unreachable!("we never reuse IDs"), + } + } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + pub fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<()> { + self.inner.poll_ready_unpin(cx) + } + + pub fn poll_unpin(&mut self, cx: &mut Context<'_>) -> Poll> { + let (_, res) = ready!(self.inner.poll_unpin(cx)); + + Poll::Ready(res) + } +} diff --git a/protocols/relay/Cargo.toml b/protocols/relay/Cargo.toml index 31f6cc16d1e4..a13ba2bb229f 100644 --- a/protocols/relay/Cargo.toml +++ b/protocols/relay/Cargo.toml @@ -16,6 +16,7 @@ bytes = "1" either = "1.9.0" futures = "0.3.28" futures-timer = "3" +futures-bounded = { workspace = true } instant = "0.1.12" libp2p-core = { workspace = true } libp2p-swarm = { workspace = true } diff --git a/protocols/relay/src/behaviour.rs b/protocols/relay/src/behaviour.rs index d8654f00caeb..f2eb89de249f 100644 --- a/protocols/relay/src/behaviour.rs +++ b/protocols/relay/src/behaviour.rs @@ -20,7 +20,7 @@ //! [`NetworkBehaviour`] to act as a circuit relay v2 **relay**. -mod handler; +pub(crate) mod handler; pub(crate) mod rate_limiter; use crate::behaviour::handler::Handler; use crate::multiaddr_ext::MultiaddrExt; diff --git a/protocols/relay/src/behaviour/handler.rs b/protocols/relay/src/behaviour/handler.rs index 9c1b8524ec3a..895228e807b8 100644 --- a/protocols/relay/src/behaviour/handler.rs +++ b/protocols/relay/src/behaviour/handler.rs @@ -20,8 +20,8 @@ use crate::behaviour::CircuitId; use crate::copy_future::CopyFuture; -use crate::proto; use crate::protocol::{inbound_hop, outbound_stop}; +use crate::{proto, HOP_PROTOCOL_NAME, STOP_PROTOCOL_NAME}; use bytes::Bytes; use either::Either; use futures::channel::oneshot::{self, Canceled}; @@ -30,21 +30,24 @@ use futures::io::AsyncWriteExt; use futures::stream::{FuturesUnordered, StreamExt}; use futures_timer::Delay; use instant::Instant; +use libp2p_core::upgrade::ReadyUpgrade; use libp2p_core::{ConnectedPoint, Multiaddr}; use libp2p_identity::PeerId; use libp2p_swarm::handler::{ ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, - ListenUpgradeError, }; use libp2p_swarm::{ - ConnectionHandler, ConnectionHandlerEvent, ConnectionId, KeepAlive, Stream, StreamUpgradeError, - SubstreamProtocol, + ConnectionHandler, ConnectionHandlerEvent, ConnectionId, KeepAlive, Stream, StreamProtocol, + StreamUpgradeError, SubstreamProtocol, }; use std::collections::VecDeque; use std::fmt; use std::task::{Context, Poll}; use std::time::Duration; +const MAX_CONCURRENT_STREAMS_PER_CONNECTION: usize = 10; +const STREAM_TIMEOUT: Duration = Duration::from_secs(60); + #[derive(Debug, Clone)] pub struct Config { pub reservation_duration: Duration, @@ -174,7 +177,7 @@ pub enum Event { dst_peer_id: PeerId, error: inbound_hop::UpgradeError, }, - /// An inbound cirucit request has been accepted. + /// An inbound circuit request has been accepted. CircuitReqAccepted { circuit_id: CircuitId, dst_peer_id: PeerId, @@ -363,7 +366,7 @@ pub struct Handler { /// Futures accepting an inbound circuit request. circuit_accept_futures: Futures>, - /// Futures deying an inbound circuit request. + /// Futures denying an inbound circuit request. circuit_deny_futures: Futures<( Option, PeerId, @@ -380,11 +383,30 @@ pub struct Handler { alive_lend_out_substreams: FuturesUnordered>, /// Futures relaying data for circuit between two peers. circuits: Futures<(CircuitId, PeerId, Result<(), std::io::Error>)>, + + pending_connect_requests: VecDeque, + + workers: futures_bounded::FuturesSet< + Either< + Result< + Either, + inbound_hop::FatalUpgradeError, + >, + Result< + Result, + outbound_stop::FatalUpgradeError, + >, + >, + >, } impl Handler { pub fn new(config: Config, endpoint: ConnectedPoint) -> Handler { Handler { + workers: futures_bounded::FuturesSet::new( + STREAM_TIMEOUT, + MAX_CONCURRENT_STREAMS_PER_CONNECTION, + ), endpoint, config, queued_events: Default::default(), @@ -396,93 +418,49 @@ impl Handler { circuits: Default::default(), active_reservation: Default::default(), keep_alive: KeepAlive::Yes, + pending_connect_requests: Default::default(), } } - fn on_fully_negotiated_inbound( - &mut self, - FullyNegotiatedInbound { - protocol: request, .. - }: FullyNegotiatedInbound< - ::InboundProtocol, - ::InboundOpenInfo, - >, - ) { - match request { - inbound_hop::Req::Reserve(inbound_reservation_req) => { - self.queued_events - .push_back(ConnectionHandlerEvent::NotifyBehaviour( - Event::ReservationReqReceived { - inbound_reservation_req, - endpoint: self.endpoint.clone(), - renewed: self.active_reservation.is_some(), - }, - )); - } - inbound_hop::Req::Connect(inbound_circuit_req) => { - self.queued_events - .push_back(ConnectionHandlerEvent::NotifyBehaviour( - Event::CircuitReqReceived { - inbound_circuit_req, - endpoint: self.endpoint.clone(), - }, - )); - } + fn on_fully_negotiated_inbound(&mut self, stream: Stream) { + if self + .workers + .try_push( + inbound_hop::handle_inbound_request( + stream, + self.config.reservation_duration, + self.config.max_circuit_duration, + self.config.max_circuit_bytes, + ) + .map(Either::Left), + ) + .is_err() + { + log::warn!("Dropping inbound stream because we are at capacity") } } - fn on_fully_negotiated_outbound( - &mut self, - FullyNegotiatedOutbound { - protocol: (dst_stream, dst_pending_data), - info: outbound_open_info, - }: FullyNegotiatedOutbound< - ::OutboundProtocol, - ::OutboundOpenInfo, - >, - ) { - let OutboundOpenInfo { - circuit_id, - inbound_circuit_req, - src_peer_id, - src_connection_id, - } = outbound_open_info; + fn on_fully_negotiated_outbound(&mut self, stream: Stream) { + let stop_command = self + .pending_connect_requests + .pop_front() + .expect("opened a stream without a pending stop command"); + let (tx, rx) = oneshot::channel(); self.alive_lend_out_substreams.push(rx); - self.queued_events - .push_back(ConnectionHandlerEvent::NotifyBehaviour( - Event::OutboundConnectNegotiated { - circuit_id, - src_peer_id, - src_connection_id, - inbound_circuit_req, - dst_handler_notifier: tx, - dst_stream, - dst_pending_data, - }, - )); - } - - fn on_listen_upgrade_error( - &mut self, - ListenUpgradeError { - error: inbound_hop::UpgradeError::Fatal(error), - .. - }: ListenUpgradeError< - ::InboundOpenInfo, - ::InboundProtocol, - >, - ) { - self.pending_error = Some(StreamUpgradeError::Apply(Either::Left(error))); + if self + .workers + .try_push(outbound_stop::connect(stream, stop_command, tx).map(Either::Right)) + .is_err() + { + log::warn!("Dropping outbound stream because we are at capacity") + } } fn on_dial_upgrade_error( &mut self, - DialUpgradeError { - info: open_info, - error, - }: DialUpgradeError< + DialUpgradeError { error, .. }: DialUpgradeError< ::OutboundOpenInfo, ::OutboundProtocol, >, @@ -502,39 +480,21 @@ impl Handler { self.pending_error = Some(StreamUpgradeError::Io(e)); return; } - StreamUpgradeError::Apply(error) => match error { - outbound_stop::UpgradeError::Fatal(error) => { - self.pending_error = Some(StreamUpgradeError::Apply(Either::Right(error))); - return; - } - outbound_stop::UpgradeError::CircuitFailed(error) => { - let status = match error { - outbound_stop::CircuitFailedReason::ResourceLimitExceeded => { - proto::Status::RESOURCE_LIMIT_EXCEEDED - } - outbound_stop::CircuitFailedReason::PermissionDenied => { - proto::Status::PERMISSION_DENIED - } - }; - (StreamUpgradeError::Apply(error), status) - } - }, + StreamUpgradeError::Apply(v) => void::unreachable(v), }; - let OutboundOpenInfo { - circuit_id, - inbound_circuit_req, - src_peer_id, - src_connection_id, - } = open_info; + let stop_command = self + .pending_connect_requests + .pop_front() + .expect("failed to open a stream without a pending stop command"); self.queued_events .push_back(ConnectionHandlerEvent::NotifyBehaviour( Event::OutboundConnectNegotiationFailed { - circuit_id, - src_peer_id, - src_connection_id, - inbound_circuit_req, + circuit_id: stop_command.circuit_id, + src_peer_id: stop_command.src_peer_id, + src_connection_id: stop_command.src_connection_id, + inbound_circuit_req: stop_command.inbound_circuit_req, status, error: non_fatal_error, }, @@ -555,20 +515,13 @@ impl ConnectionHandler for Handler { type Error = StreamUpgradeError< Either, >; - type InboundProtocol = inbound_hop::Upgrade; - type OutboundProtocol = outbound_stop::Upgrade; - type OutboundOpenInfo = OutboundOpenInfo; + type InboundProtocol = ReadyUpgrade; type InboundOpenInfo = (); + type OutboundProtocol = ReadyUpgrade; + type OutboundOpenInfo = (); fn listen_protocol(&self) -> SubstreamProtocol { - SubstreamProtocol::new( - inbound_hop::Upgrade { - reservation_duration: self.config.reservation_duration, - max_circuit_duration: self.config.max_circuit_duration, - max_circuit_bytes: self.config.max_circuit_bytes, - }, - (), - ) + SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()) } fn on_behaviour_event(&mut self, event: Self::FromBehaviour) { @@ -580,7 +533,7 @@ impl ConnectionHandler for Handler { if self .reservation_request_future .replace(ReservationRequestFuture::Accepting( - inbound_reservation_req.accept(addrs).boxed(), + inbound_reservation_req.accept(addrs).err_into().boxed(), )) .is_some() { @@ -594,7 +547,7 @@ impl ConnectionHandler for Handler { if self .reservation_request_future .replace(ReservationRequestFuture::Denying( - inbound_reservation_req.deny(status).boxed(), + inbound_reservation_req.deny(status).err_into().boxed(), )) .is_some() { @@ -607,21 +560,17 @@ impl ConnectionHandler for Handler { src_peer_id, src_connection_id, } => { + self.pending_connect_requests + .push_back(outbound_stop::PendingConnect::new( + circuit_id, + inbound_circuit_req, + src_peer_id, + src_connection_id, + &self.config, + )); self.queued_events .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new( - outbound_stop::Upgrade { - src_peer_id, - max_circuit_duration: self.config.max_circuit_duration, - max_circuit_bytes: self.config.max_circuit_bytes, - }, - OutboundOpenInfo { - circuit_id, - inbound_circuit_req, - src_peer_id, - src_connection_id, - }, - ), + protocol: SubstreamProtocol::new(ReadyUpgrade::new(STOP_PROTOCOL_NAME), ()), }); } In::DenyCircuitReq { @@ -633,6 +582,7 @@ impl ConnectionHandler for Handler { self.circuit_deny_futures.push( inbound_circuit_req .deny(status) + .err_into() .map(move |result| (circuit_id, dst_peer_id, result)) .boxed(), ); @@ -648,6 +598,7 @@ impl ConnectionHandler for Handler { self.circuit_accept_futures.push( inbound_circuit_req .accept() + .err_into() .map_ok(move |(src_stream, src_pending_data)| CircuitParts { circuit_id, src_stream, @@ -716,6 +667,66 @@ impl ConnectionHandler for Handler { } } + // Process protocol requests + match self.workers.poll_unpin(cx) { + Poll::Ready(Ok(Either::Left(Ok(Either::Left(inbound_reservation_req))))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::ReservationReqReceived { + inbound_reservation_req, + endpoint: self.endpoint.clone(), + renewed: self.active_reservation.is_some(), + }, + )); + } + Poll::Ready(Ok(Either::Left(Ok(Either::Right(inbound_circuit_req))))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::CircuitReqReceived { + inbound_circuit_req, + endpoint: self.endpoint.clone(), + }, + )); + } + Poll::Ready(Ok(Either::Right(Ok(Ok(circuit))))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundConnectNegotiated { + circuit_id: circuit.circuit_id, + src_peer_id: circuit.src_peer_id, + src_connection_id: circuit.src_connection_id, + inbound_circuit_req: circuit.inbound_circuit_req, + dst_handler_notifier: circuit.dst_handler_notifier, + dst_stream: circuit.dst_stream, + dst_pending_data: circuit.dst_pending_data, + }, + )); + } + Poll::Ready(Ok(Either::Right(Ok(Err(circuit_failed))))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundConnectNegotiationFailed { + circuit_id: circuit_failed.circuit_id, + src_peer_id: circuit_failed.src_peer_id, + src_connection_id: circuit_failed.src_connection_id, + inbound_circuit_req: circuit_failed.inbound_circuit_req, + status: circuit_failed.status, + error: circuit_failed.error, + }, + )); + } + Poll::Ready(Err(futures_bounded::Timeout { .. })) => { + return Poll::Ready(ConnectionHandlerEvent::Close(StreamUpgradeError::Timeout)); + } + Poll::Ready(Ok(Either::Left(Err(e)))) => { + return Poll::Ready(ConnectionHandlerEvent::Close(StreamUpgradeError::Apply( + Either::Left(e), + ))); + } + Poll::Ready(Ok(Either::Right(Err(e)))) => { + return Poll::Ready(ConnectionHandlerEvent::Close(StreamUpgradeError::Apply( + Either::Right(e), + ))); + } + Poll::Pending => {} + } + // Deny new circuits. if let Poll::Ready(Some((circuit_id, dst_peer_id, result))) = self.circuit_deny_futures.poll_next_unpin(cx) @@ -896,33 +907,30 @@ impl ConnectionHandler for Handler { >, ) { match event { - ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => { - self.on_fully_negotiated_inbound(fully_negotiated_inbound) - } - ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => { - self.on_fully_negotiated_outbound(fully_negotiated_outbound) + ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound { + protocol: stream, + .. + }) => { + self.on_fully_negotiated_inbound(stream); } - ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { - self.on_listen_upgrade_error(listen_upgrade_error) + ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { + protocol: stream, + .. + }) => { + self.on_fully_negotiated_outbound(stream); } ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { - self.on_dial_upgrade_error(dial_upgrade_error) + self.on_dial_upgrade_error(dial_upgrade_error); } ConnectionEvent::AddressChange(_) + | ConnectionEvent::ListenUpgradeError(_) | ConnectionEvent::LocalProtocolsChange(_) | ConnectionEvent::RemoteProtocolsChange(_) => {} } } } -pub struct OutboundOpenInfo { - circuit_id: CircuitId, - inbound_circuit_req: inbound_hop::CircuitReq, - src_peer_id: PeerId, - src_connection_id: ConnectionId, -} - -pub(crate) struct CircuitParts { +struct CircuitParts { circuit_id: CircuitId, src_stream: Stream, src_pending_data: Bytes, diff --git a/protocols/relay/src/priv_client.rs b/protocols/relay/src/priv_client.rs index 5f18a62a96dd..cecdfa52bf3a 100644 --- a/protocols/relay/src/priv_client.rs +++ b/protocols/relay/src/priv_client.rs @@ -20,7 +20,7 @@ //! [`NetworkBehaviour`] to act as a circuit relay v2 **client**. -mod handler; +pub(crate) mod handler; pub(crate) mod transport; use crate::multiaddr_ext::MultiaddrExt; @@ -163,7 +163,6 @@ impl NetworkBehaviour for Behaviour { if local_addr.is_relayed() { return Ok(Either::Right(dummy::ConnectionHandler)); } - let mut handler = Handler::new(self.local_peer_id, peer, remote_addr.clone()); if let Some(event) = self.pending_handler_commands.remove(&connection_id) { @@ -377,10 +376,10 @@ impl NetworkBehaviour for Behaviour { /// /// Internally, this uses a stream to the relay. pub struct Connection { - state: ConnectionState, + pub(crate) state: ConnectionState, } -enum ConnectionState { +pub(crate) enum ConnectionState { InboundAccepting { accept: BoxFuture<'static, Result>, }, diff --git a/protocols/relay/src/priv_client/handler.rs b/protocols/relay/src/priv_client/handler.rs index 9613d7d6b3e3..25488ac3041c 100644 --- a/protocols/relay/src/priv_client/handler.rs +++ b/protocols/relay/src/priv_client/handler.rs @@ -19,27 +19,30 @@ // DEALINGS IN THE SOFTWARE. use crate::priv_client::transport; -use crate::proto; use crate::protocol::{self, inbound_stop, outbound_hop}; +use crate::{proto, HOP_PROTOCOL_NAME, STOP_PROTOCOL_NAME}; use either::Either; use futures::channel::{mpsc, oneshot}; use futures::future::{BoxFuture, FutureExt}; use futures::sink::SinkExt; use futures::stream::{FuturesUnordered, StreamExt}; +use futures::TryFutureExt; +use futures_bounded::{PushError, Timeout}; use futures_timer::Delay; use instant::Instant; use libp2p_core::multiaddr::Protocol; +use libp2p_core::upgrade::ReadyUpgrade; use libp2p_core::Multiaddr; use libp2p_identity::PeerId; use libp2p_swarm::handler::{ ConnectionEvent, DialUpgradeError, FullyNegotiatedInbound, FullyNegotiatedOutbound, - ListenUpgradeError, }; use libp2p_swarm::{ - ConnectionHandler, ConnectionHandlerEvent, KeepAlive, StreamUpgradeError, SubstreamProtocol, + ConnectionHandler, ConnectionHandlerEvent, KeepAlive, StreamProtocol, StreamUpgradeError, + SubstreamProtocol, }; use log::debug; -use std::collections::{HashMap, VecDeque}; +use std::collections::VecDeque; use std::fmt; use std::task::{Context, Poll}; use std::time::Duration; @@ -48,6 +51,10 @@ use std::time::Duration; /// /// Circuits to be denied exceeding the limit are dropped. const MAX_NUMBER_DENYING_CIRCUIT: usize = 8; +const DENYING_CIRCUIT_TIMEOUT: Duration = Duration::from_secs(60); + +const MAX_CONCURRENT_STREAMS_PER_CONNECTION: usize = 10; +const STREAM_TIMEOUT: Duration = Duration::from_secs(60); pub enum In { Reserve { @@ -121,10 +128,21 @@ pub struct Handler { /// Queue of events to return when polled. queued_events: VecDeque< ConnectionHandlerEvent< - ::OutboundProtocol, - ::OutboundOpenInfo, - ::ToBehaviour, - ::Error, + ::OutboundProtocol, + ::OutboundOpenInfo, + ::ToBehaviour, + ::Error, + >, + >, + + wait_for_outbound_stream: VecDeque, + outbound_circuits: futures_bounded::FuturesSet< + Result< + Either< + Result, + Result, outbound_hop::CircuitFailedReason>, + >, + outbound_hop::FatalUpgradeError, >, >, @@ -140,8 +158,10 @@ pub struct Handler { /// eventually. alive_lend_out_substreams: FuturesUnordered>, - circuit_deny_futs: - HashMap>>, + open_circuit_futs: + futures_bounded::FuturesSet>, + + circuit_deny_futs: futures_bounded::FuturesMap>, /// Futures that try to send errors to the transport. /// @@ -158,163 +178,38 @@ impl Handler { remote_addr, queued_events: Default::default(), pending_error: Default::default(), + wait_for_outbound_stream: Default::default(), + outbound_circuits: futures_bounded::FuturesSet::new( + STREAM_TIMEOUT, + MAX_CONCURRENT_STREAMS_PER_CONNECTION, + ), reservation: Reservation::None, alive_lend_out_substreams: Default::default(), - circuit_deny_futs: Default::default(), + open_circuit_futs: futures_bounded::FuturesSet::new( + STREAM_TIMEOUT, + MAX_CONCURRENT_STREAMS_PER_CONNECTION, + ), + circuit_deny_futs: futures_bounded::FuturesMap::new( + DENYING_CIRCUIT_TIMEOUT, + MAX_NUMBER_DENYING_CIRCUIT, + ), send_error_futs: Default::default(), keep_alive: KeepAlive::Yes, } } - fn on_fully_negotiated_inbound( - &mut self, - FullyNegotiatedInbound { - protocol: inbound_circuit, - .. - }: FullyNegotiatedInbound< - ::InboundProtocol, - ::InboundOpenInfo, - >, - ) { - match &mut self.reservation { - Reservation::Accepted { pending_msgs, .. } - | Reservation::Renewing { pending_msgs, .. } => { - let src_peer_id = inbound_circuit.src_peer_id(); - let limit = inbound_circuit.limit(); - - let (tx, rx) = oneshot::channel(); - self.alive_lend_out_substreams.push(rx); - let connection = super::ConnectionState::new_inbound(inbound_circuit, tx); - - pending_msgs.push_back(transport::ToListenerMsg::IncomingRelayedConnection { - // stream: connection, - stream: super::Connection { state: connection }, - src_peer_id, - relay_peer_id: self.remote_peer_id, - relay_addr: self.remote_addr.clone(), - }); - - self.queued_events - .push_back(ConnectionHandlerEvent::NotifyBehaviour( - Event::InboundCircuitEstablished { src_peer_id, limit }, - )); - } - Reservation::None => { - let src_peer_id = inbound_circuit.src_peer_id(); - - if self.circuit_deny_futs.len() == MAX_NUMBER_DENYING_CIRCUIT - && !self.circuit_deny_futs.contains_key(&src_peer_id) - { - log::warn!( - "Dropping inbound circuit request to be denied from {:?} due to exceeding limit.", - src_peer_id, - ); - } else if self - .circuit_deny_futs - .insert( - src_peer_id, - inbound_circuit.deny(proto::Status::NO_RESERVATION).boxed(), - ) - .is_some() - { - log::warn!( - "Dropping existing inbound circuit request to be denied from {:?} in favor of new one.", - src_peer_id - ) - } - } - } - } - - fn on_fully_negotiated_outbound( - &mut self, - FullyNegotiatedOutbound { - protocol: output, - info, - }: FullyNegotiatedOutbound< - ::OutboundProtocol, - ::OutboundOpenInfo, - >, - ) { - match (output, info) { - // Outbound reservation - ( - outbound_hop::Output::Reservation { - renewal_timeout, - addrs, - limit, - }, - OutboundOpenInfo::Reserve { to_listener }, - ) => { - let event = self.reservation.accepted( - renewal_timeout, - addrs, - to_listener, - self.local_peer_id, - limit, - ); - - self.queued_events - .push_back(ConnectionHandlerEvent::NotifyBehaviour(event)); - } - - // Outbound circuit - ( - outbound_hop::Output::Circuit { - substream, - read_buffer, - limit, - }, - OutboundOpenInfo::Connect { send_back }, - ) => { - let (tx, rx) = oneshot::channel(); - match send_back.send(Ok(super::Connection { - state: super::ConnectionState::new_outbound(substream, read_buffer, tx), - })) { - Ok(()) => { - self.alive_lend_out_substreams.push(rx); - self.queued_events - .push_back(ConnectionHandlerEvent::NotifyBehaviour( - Event::OutboundCircuitEstablished { limit }, - )); - } - Err(_) => debug!( - "Oneshot to `client::transport::Dial` future dropped. \ - Dropping established relayed connection to {:?}.", - self.remote_peer_id, - ), - } - } - - _ => unreachable!(), - } - } - - fn on_listen_upgrade_error( - &mut self, - ListenUpgradeError { - error: inbound_stop::UpgradeError::Fatal(error), - .. - }: ListenUpgradeError< - ::InboundOpenInfo, - ::InboundProtocol, - >, - ) { - self.pending_error = Some(StreamUpgradeError::Apply(Either::Left(error))); - } - fn on_dial_upgrade_error( &mut self, - DialUpgradeError { - info: open_info, - error, - }: DialUpgradeError< + DialUpgradeError { error, .. }: DialUpgradeError< ::OutboundOpenInfo, ::OutboundProtocol, >, ) { - match open_info { - OutboundOpenInfo::Reserve { mut to_listener } => { + let outbound_info = self.wait_for_outbound_stream.pop_front().expect( + "got a stream error without a pending connection command or a reserve listener", + ); + match outbound_info { + outbound_hop::OutboundStreamInfo::Reserve(mut to_listener) => { let non_fatal_error = match error { StreamUpgradeError::Timeout => StreamUpgradeError::Timeout, StreamUpgradeError::NegotiationFailed => StreamUpgradeError::NegotiationFailed, @@ -322,19 +217,7 @@ impl Handler { self.pending_error = Some(StreamUpgradeError::Io(e)); return; } - StreamUpgradeError::Apply(error) => match error { - outbound_hop::UpgradeError::Fatal(error) => { - self.pending_error = - Some(StreamUpgradeError::Apply(Either::Right(error))); - return; - } - outbound_hop::UpgradeError::ReservationFailed(error) => { - StreamUpgradeError::Apply(error) - } - outbound_hop::UpgradeError::CircuitFailed(_) => { - unreachable!("Do not emitt `CircuitFailed` for outgoing reservation.") - } - }, + StreamUpgradeError::Apply(v) => void::unreachable(v), }; if self.pending_error.is_none() { @@ -347,11 +230,12 @@ impl Handler { .boxed(), ); } else { - // Fatal error occured, thus handler is closing as quickly as possible. + // Fatal error occurred, thus handler is closing as quickly as possible. // Transport is notified through dropping `to_listener`. } let renewal = self.reservation.failed(); + self.queued_events .push_back(ConnectionHandlerEvent::NotifyBehaviour( Event::ReservationReqFailed { @@ -360,7 +244,7 @@ impl Handler { }, )); } - OutboundOpenInfo::Connect { send_back } => { + outbound_hop::OutboundStreamInfo::CircuitConnection(cmd) => { let non_fatal_error = match error { StreamUpgradeError::Timeout => StreamUpgradeError::Timeout, StreamUpgradeError::NegotiationFailed => StreamUpgradeError::NegotiationFailed, @@ -368,22 +252,10 @@ impl Handler { self.pending_error = Some(StreamUpgradeError::Io(e)); return; } - StreamUpgradeError::Apply(error) => match error { - outbound_hop::UpgradeError::Fatal(error) => { - self.pending_error = - Some(StreamUpgradeError::Apply(Either::Right(error))); - return; - } - outbound_hop::UpgradeError::CircuitFailed(error) => { - StreamUpgradeError::Apply(error) - } - outbound_hop::UpgradeError::ReservationFailed(_) => { - unreachable!("Do not emitt `ReservationFailed` for outgoing circuit.") - } - }, + StreamUpgradeError::Apply(v) => void::unreachable(v), }; - let _ = send_back.send(Err(())); + let _ = cmd.send_back.send(Err(())); self.queued_events .push_back(ConnectionHandlerEvent::NotifyBehaviour( @@ -394,6 +266,23 @@ impl Handler { } } } + + fn insert_to_deny_futs(&mut self, circuit: inbound_stop::Circuit) { + let src_peer_id = circuit.src_peer_id(); + + match self.circuit_deny_futs.try_push( + src_peer_id, + circuit.deny(proto::Status::NO_RESERVATION), + ) { + Err(PushError::BeyondCapacity(_)) => log::warn!( + "Dropping inbound circuit request to be denied from {src_peer_id} due to exceeding limit." + ), + Err(PushError::ReplacedFuture(_)) => log::warn!( + "Dropping existing inbound circuit request to be denied from {src_peer_id} in favor of new one." + ), + Ok(()) => {} + } + } } impl ConnectionHandler for Handler { @@ -402,36 +291,37 @@ impl ConnectionHandler for Handler { type Error = StreamUpgradeError< Either, >; - type InboundProtocol = inbound_stop::Upgrade; - type OutboundProtocol = outbound_hop::Upgrade; - type OutboundOpenInfo = OutboundOpenInfo; + type InboundProtocol = ReadyUpgrade; type InboundOpenInfo = (); + type OutboundProtocol = ReadyUpgrade; + type OutboundOpenInfo = (); fn listen_protocol(&self) -> SubstreamProtocol { - SubstreamProtocol::new(inbound_stop::Upgrade {}, ()) + SubstreamProtocol::new(ReadyUpgrade::new(STOP_PROTOCOL_NAME), ()) } fn on_behaviour_event(&mut self, event: Self::FromBehaviour) { match event { In::Reserve { to_listener } => { + self.wait_for_outbound_stream + .push_back(outbound_hop::OutboundStreamInfo::Reserve(to_listener)); self.queued_events .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new( - outbound_hop::Upgrade::Reserve, - OutboundOpenInfo::Reserve { to_listener }, - ), + protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), }); } In::EstablishCircuit { send_back, dst_peer_id, } => { + self.wait_for_outbound_stream.push_back( + outbound_hop::OutboundStreamInfo::CircuitConnection( + outbound_hop::Command::new(dst_peer_id, send_back), + ), + ); self.queued_events .push_back(ConnectionHandlerEvent::OutboundSubstreamRequest { - protocol: SubstreamProtocol::new( - outbound_hop::Upgrade::Connect { dst_peer_id }, - OutboundOpenInfo::Connect { send_back }, - ), + protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), }); } } @@ -458,38 +348,132 @@ impl ConnectionHandler for Handler { return Poll::Ready(ConnectionHandlerEvent::Close(err)); } + // Inbound circuits + loop { + match self.outbound_circuits.poll_unpin(cx) { + Poll::Ready(Ok(Ok(Either::Left(Ok(outbound_hop::Reservation { + renewal_timeout, + addrs, + limit, + to_listener, + }))))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + self.reservation.accepted( + renewal_timeout, + addrs, + to_listener, + self.local_peer_id, + limit, + ), + )) + } + Poll::Ready(Ok(Ok(Either::Right(Ok(Some(outbound_hop::Circuit { limit })))))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundCircuitEstablished { limit }, + )); + } + Poll::Ready(Ok(Ok(Either::Right(Ok(None))))) => continue, + Poll::Ready(Ok(Ok(Either::Right(Err(e))))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::OutboundCircuitReqFailed { + error: StreamUpgradeError::Apply(e), + }, + )); + } + Poll::Ready(Ok(Ok(Either::Left(Err(e))))) => { + let renewal = self.reservation.failed(); + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::ReservationReqFailed { + renewal, + error: StreamUpgradeError::Apply(e), + }, + )); + } + Poll::Ready(Ok(Err(e))) => { + return Poll::Ready(ConnectionHandlerEvent::Close(StreamUpgradeError::Apply( + Either::Right(e), + ))) + } + Poll::Ready(Err(Timeout { .. })) => { + return Poll::Ready(ConnectionHandlerEvent::Close(StreamUpgradeError::Timeout)); + } + Poll::Pending => break, + } + } + // Return queued events. if let Some(event) = self.queued_events.pop_front() { return Poll::Ready(event); } - if let Poll::Ready(Some(protocol)) = self.reservation.poll(cx) { - return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { protocol }); + if let Poll::Ready(worker_res) = self.open_circuit_futs.poll_unpin(cx) { + let res = match worker_res { + Ok(r) => r, + Err(Timeout { .. }) => { + return Poll::Ready(ConnectionHandlerEvent::Close(StreamUpgradeError::Timeout)); + } + }; + + match res { + Ok(circuit) => match &mut self.reservation { + Reservation::Accepted { pending_msgs, .. } + | Reservation::Renewing { pending_msgs, .. } => { + let src_peer_id = circuit.src_peer_id(); + let limit = circuit.limit(); + + let (tx, rx) = oneshot::channel(); + self.alive_lend_out_substreams.push(rx); + let connection = super::ConnectionState::new_inbound(circuit, tx); + + pending_msgs.push_back( + transport::ToListenerMsg::IncomingRelayedConnection { + stream: super::Connection { state: connection }, + src_peer_id, + relay_peer_id: self.remote_peer_id, + relay_addr: self.remote_addr.clone(), + }, + ); + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::InboundCircuitEstablished { src_peer_id, limit }, + )); + } + Reservation::None => { + self.insert_to_deny_futs(circuit); + } + }, + Err(e) => { + return Poll::Ready(ConnectionHandlerEvent::Close(StreamUpgradeError::Apply( + Either::Left(e), + ))); + } + } + } + + if let Poll::Ready(Some(to_listener)) = self.reservation.poll(cx) { + self.wait_for_outbound_stream + .push_back(outbound_hop::OutboundStreamInfo::Reserve(to_listener)); + + return Poll::Ready(ConnectionHandlerEvent::OutboundSubstreamRequest { + protocol: SubstreamProtocol::new(ReadyUpgrade::new(HOP_PROTOCOL_NAME), ()), + }); } // Deny incoming circuit requests. - let maybe_event = - self.circuit_deny_futs - .iter_mut() - .find_map(|(src_peer_id, fut)| match fut.poll_unpin(cx) { - Poll::Ready(Ok(())) => Some(( - *src_peer_id, - Event::InboundCircuitReqDenied { - src_peer_id: *src_peer_id, - }, - )), - Poll::Ready(Err(error)) => Some(( - *src_peer_id, - Event::InboundCircuitReqDenyFailed { - src_peer_id: *src_peer_id, - error, - }, - )), - Poll::Pending => None, - }); - if let Some((src_peer_id, event)) = maybe_event { - self.circuit_deny_futs.remove(&src_peer_id); - return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour(event)); + match self.circuit_deny_futs.poll_unpin(cx) { + Poll::Ready((src_peer_id, Ok(Ok(())))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::InboundCircuitReqDenied { src_peer_id }, + )); + } + Poll::Ready((src_peer_id, Ok(Err(error)))) => { + return Poll::Ready(ConnectionHandlerEvent::NotifyBehaviour( + Event::InboundCircuitReqDenyFailed { src_peer_id, error }, + )); + } + Poll::Ready((src_peer_id, Err(Timeout { .. }))) => { + log::warn!("Dropping inbound circuit request to be denied from {:?} due to exceeding limit.", src_peer_id); + } + Poll::Pending => {} } // Send errors to transport. @@ -533,14 +517,62 @@ impl ConnectionHandler for Handler { >, ) { match event { - ConnectionEvent::FullyNegotiatedInbound(fully_negotiated_inbound) => { - self.on_fully_negotiated_inbound(fully_negotiated_inbound) + ConnectionEvent::FullyNegotiatedInbound(FullyNegotiatedInbound { + protocol: stream, + .. + }) => { + if self + .open_circuit_futs + .try_push(inbound_stop::handle_open_circuit(stream)) + .is_err() + { + log::warn!("Dropping inbound stream because we are at capacity") + } } - ConnectionEvent::FullyNegotiatedOutbound(fully_negotiated_outbound) => { - self.on_fully_negotiated_outbound(fully_negotiated_outbound) + ConnectionEvent::FullyNegotiatedOutbound(FullyNegotiatedOutbound { + protocol: stream, + .. + }) => { + let outbound_info = self.wait_for_outbound_stream.pop_front().expect( + "opened a stream without a pending connection command or a reserve listener", + ); + match outbound_info { + outbound_hop::OutboundStreamInfo::Reserve(to_listener) => { + if self + .outbound_circuits + .try_push( + outbound_hop::handle_reserve_message_response(stream, to_listener) + .map_ok(Either::Left), + ) + .is_err() + { + log::warn!("Dropping outbound stream because we are at capacity") + } + } + outbound_hop::OutboundStreamInfo::CircuitConnection(cmd) => { + let (tx, rx) = oneshot::channel(); + self.alive_lend_out_substreams.push(rx); + + if self + .outbound_circuits + .try_push( + outbound_hop::handle_connection_message_response( + stream, + self.remote_peer_id, + cmd, + tx, + ) + .map_ok(Either::Right), + ) + .is_err() + { + log::warn!("Dropping outbound stream because we are at capacity") + } + } + } } ConnectionEvent::ListenUpgradeError(listen_upgrade_error) => { - self.on_listen_upgrade_error(listen_upgrade_error) + void::unreachable(listen_upgrade_error.error) } ConnectionEvent::DialUpgradeError(dial_upgrade_error) => { self.on_dial_upgrade_error(dial_upgrade_error) @@ -648,7 +680,7 @@ impl Reservation { fn poll( &mut self, cx: &mut Context<'_>, - ) -> Poll>> { + ) -> Poll>> { self.forward_messages_to_transport_listener(cx); // Check renewal timeout if any. @@ -660,10 +692,7 @@ impl Reservation { } => match renewal_timeout.poll_unpin(cx) { Poll::Ready(()) => ( Reservation::Renewing { pending_msgs }, - Poll::Ready(Some(SubstreamProtocol::new( - outbound_hop::Upgrade::Reserve, - OutboundOpenInfo::Reserve { to_listener }, - ))), + Poll::Ready(Some(to_listener)), ), Poll::Pending => ( Reservation::Accepted { @@ -681,12 +710,3 @@ impl Reservation { poll_val } } - -pub enum OutboundOpenInfo { - Reserve { - to_listener: mpsc::Sender, - }, - Connect { - send_back: oneshot::Sender>, - }, -} diff --git a/protocols/relay/src/priv_client/transport.rs b/protocols/relay/src/priv_client/transport.rs index 45cc685aea04..41114d0cdd51 100644 --- a/protocols/relay/src/priv_client/transport.rs +++ b/protocols/relay/src/priv_client/transport.rs @@ -55,7 +55,7 @@ use thiserror::Error; /// # use libp2p_identity::PeerId; /// let actual_transport = MemoryTransport::default(); /// let (relay_transport, behaviour) = relay::client::new( -/// PeerId::random(), +/// PeerId::random() /// ); /// let mut transport = OrTransport::new(relay_transport, actual_transport); /// # let relay_id = PeerId::random(); @@ -80,7 +80,7 @@ use thiserror::Error; /// # let local_peer_id = PeerId::random(); /// let actual_transport = MemoryTransport::default(); /// let (relay_transport, behaviour) = relay::client::new( -/// local_peer_id, +/// local_peer_id /// ); /// let mut transport = OrTransport::new(relay_transport, actual_transport); /// let relay_addr = Multiaddr::empty() diff --git a/protocols/relay/src/protocol.rs b/protocols/relay/src/protocol.rs index f9b1e1ac0d18..b94151259cd8 100644 --- a/protocols/relay/src/protocol.rs +++ b/protocols/relay/src/protocol.rs @@ -31,7 +31,7 @@ pub const HOP_PROTOCOL_NAME: StreamProtocol = pub const STOP_PROTOCOL_NAME: StreamProtocol = StreamProtocol::new("/libp2p/circuit/relay/0.2.0/stop"); -const MAX_MESSAGE_SIZE: usize = 4096; +pub(crate) const MAX_MESSAGE_SIZE: usize = 4096; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Limit { diff --git a/protocols/relay/src/protocol/inbound_hop.rs b/protocols/relay/src/protocol/inbound_hop.rs index 27f2572a636b..b44d29e42ce0 100644 --- a/protocols/relay/src/protocol/inbound_hop.rs +++ b/protocols/relay/src/protocol/inbound_hop.rs @@ -18,79 +18,21 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::proto; -use crate::protocol::{HOP_PROTOCOL_NAME, MAX_MESSAGE_SIZE}; +use std::time::{Duration, SystemTime}; + use asynchronous_codec::{Framed, FramedParts}; use bytes::Bytes; -use futures::{future::BoxFuture, prelude::*}; -use instant::{Duration, SystemTime}; -use libp2p_core::{upgrade, Multiaddr}; -use libp2p_identity::PeerId; -use libp2p_swarm::{Stream, StreamProtocol}; -use std::convert::TryInto; -use std::iter; +use either::Either; +use futures::prelude::*; use thiserror::Error; -pub struct Upgrade { - pub reservation_duration: Duration, - pub max_circuit_duration: Duration, - pub max_circuit_bytes: u64, -} - -impl upgrade::UpgradeInfo for Upgrade { - type Info = StreamProtocol; - type InfoIter = iter::Once; - - fn protocol_info(&self) -> Self::InfoIter { - iter::once(HOP_PROTOCOL_NAME) - } -} - -impl upgrade::InboundUpgrade for Upgrade { - type Output = Req; - type Error = UpgradeError; - type Future = BoxFuture<'static, Result>; - - fn upgrade_inbound(self, substream: Stream, _: Self::Info) -> Self::Future { - let mut substream = Framed::new( - substream, - quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE), - ); - - async move { - let proto::HopMessage { - type_pb, - peer, - reservation: _, - limit: _, - status: _, - } = substream - .next() - .await - .ok_or(FatalUpgradeError::StreamClosed)??; - - let req = match type_pb { - proto::HopMessageType::RESERVE => Req::Reserve(ReservationReq { - substream, - reservation_duration: self.reservation_duration, - max_circuit_duration: self.max_circuit_duration, - max_circuit_bytes: self.max_circuit_bytes, - }), - proto::HopMessageType::CONNECT => { - let dst = PeerId::from_bytes(&peer.ok_or(FatalUpgradeError::MissingPeer)?.id) - .map_err(|_| FatalUpgradeError::ParsePeerId)?; - Req::Connect(CircuitReq { dst, substream }) - } - proto::HopMessageType::STATUS => { - return Err(FatalUpgradeError::UnexpectedTypeStatus.into()) - } - }; +use libp2p_core::Multiaddr; +use libp2p_identity::PeerId; +use libp2p_swarm::Stream; - Ok(req) - } - .boxed() - } -} +use crate::proto; +use crate::proto::message_v2::pb::mod_HopMessage::Type; +use crate::protocol::MAX_MESSAGE_SIZE; #[derive(Debug, Error)] pub enum UpgradeError { @@ -120,11 +62,6 @@ pub enum FatalUpgradeError { UnexpectedTypeStatus, } -pub enum Req { - Reserve(ReservationReq), - Connect(CircuitReq), -} - pub struct ReservationReq { substream: Framed>, reservation_duration: Duration, @@ -133,7 +70,7 @@ pub struct ReservationReq { } impl ReservationReq { - pub async fn accept(self, addrs: Vec) -> Result<(), UpgradeError> { + pub async fn accept(self, addrs: Vec) -> Result<(), FatalUpgradeError> { if addrs.is_empty() { log::debug!( "Accepting relay reservation without providing external addresses of local node. \ @@ -167,7 +104,7 @@ impl ReservationReq { self.send(msg).await } - pub async fn deny(self, status: proto::Status) -> Result<(), UpgradeError> { + pub async fn deny(self, status: proto::Status) -> Result<(), FatalUpgradeError> { let msg = proto::HopMessage { type_pb: proto::HopMessageType::STATUS, peer: None, @@ -179,7 +116,7 @@ impl ReservationReq { self.send(msg).await } - async fn send(mut self, msg: proto::HopMessage) -> Result<(), UpgradeError> { + async fn send(mut self, msg: proto::HopMessage) -> Result<(), FatalUpgradeError> { self.substream.send(msg).await?; self.substream.flush().await?; self.substream.close().await?; @@ -198,7 +135,7 @@ impl CircuitReq { self.dst } - pub async fn accept(mut self) -> Result<(Stream, Bytes), UpgradeError> { + pub async fn accept(mut self) -> Result<(Stream, Bytes), FatalUpgradeError> { let msg = proto::HopMessage { type_pb: proto::HopMessageType::STATUS, peer: None, @@ -223,7 +160,7 @@ impl CircuitReq { Ok((io, read_buffer.freeze())) } - pub async fn deny(mut self, status: proto::Status) -> Result<(), UpgradeError> { + pub async fn deny(mut self, status: proto::Status) -> Result<(), FatalUpgradeError> { let msg = proto::HopMessage { type_pb: proto::HopMessageType::STATUS, peer: None, @@ -242,3 +179,51 @@ impl CircuitReq { Ok(()) } } + +pub(crate) async fn handle_inbound_request( + io: Stream, + reservation_duration: Duration, + max_circuit_duration: Duration, + max_circuit_bytes: u64, +) -> Result, FatalUpgradeError> { + let mut substream = Framed::new(io, quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE)); + + let res = substream.next().await; + + if let None | Some(Err(_)) = res { + return Err(FatalUpgradeError::StreamClosed); + } + + let proto::HopMessage { + type_pb, + peer, + reservation: _, + limit: _, + status: _, + } = res.unwrap().expect("should be ok"); + + let req = match type_pb { + Type::RESERVE => Either::Left(ReservationReq { + substream, + reservation_duration, + max_circuit_duration, + max_circuit_bytes, + }), + Type::CONNECT => { + let peer_id_res = match peer { + Some(r) => PeerId::from_bytes(&r.id), + None => return Err(FatalUpgradeError::MissingPeer), + }; + + let dst = match peer_id_res { + Ok(res) => res, + Err(_) => return Err(FatalUpgradeError::ParsePeerId), + }; + + Either::Right(CircuitReq { dst, substream }) + } + Type::STATUS => return Err(FatalUpgradeError::UnexpectedTypeStatus), + }; + + Ok(req) +} diff --git a/protocols/relay/src/protocol/inbound_stop.rs b/protocols/relay/src/protocol/inbound_stop.rs index c279c8ee6015..caaeee9cc533 100644 --- a/protocols/relay/src/protocol/inbound_stop.rs +++ b/protocols/relay/src/protocol/inbound_stop.rs @@ -19,66 +19,38 @@ // DEALINGS IN THE SOFTWARE. use crate::proto; -use crate::protocol::{self, MAX_MESSAGE_SIZE, STOP_PROTOCOL_NAME}; +use crate::protocol::{self, MAX_MESSAGE_SIZE}; use asynchronous_codec::{Framed, FramedParts}; use bytes::Bytes; -use futures::{future::BoxFuture, prelude::*}; -use libp2p_core::upgrade; +use futures::prelude::*; use libp2p_identity::PeerId; -use libp2p_swarm::{Stream, StreamProtocol}; -use std::iter; +use libp2p_swarm::Stream; use thiserror::Error; -pub struct Upgrade {} - -impl upgrade::UpgradeInfo for Upgrade { - type Info = StreamProtocol; - type InfoIter = iter::Once; - - fn protocol_info(&self) -> Self::InfoIter { - iter::once(STOP_PROTOCOL_NAME) - } -} - -impl upgrade::InboundUpgrade for Upgrade { - type Output = Circuit; - type Error = UpgradeError; - type Future = BoxFuture<'static, Result>; - - fn upgrade_inbound(self, substream: Stream, _: Self::Info) -> Self::Future { - let mut substream = Framed::new( - substream, - quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE), - ); - - async move { - let proto::StopMessage { - type_pb, - peer, - limit, - status: _, - } = substream - .next() - .await - .ok_or(FatalUpgradeError::StreamClosed)??; - - match type_pb { - proto::StopMessageType::CONNECT => { - let src_peer_id = - PeerId::from_bytes(&peer.ok_or(FatalUpgradeError::MissingPeer)?.id) - .map_err(|_| FatalUpgradeError::ParsePeerId)?; - Ok(Circuit { - substream, - src_peer_id, - limit: limit.map(Into::into), - }) - } - proto::StopMessageType::STATUS => { - Err(FatalUpgradeError::UnexpectedTypeStatus.into()) - } - } +pub(crate) async fn handle_open_circuit(io: Stream) -> Result { + let mut substream = Framed::new(io, quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE)); + + let proto::StopMessage { + type_pb, + peer, + limit, + status: _, + } = substream + .next() + .await + .ok_or(FatalUpgradeError::StreamClosed)??; + + match type_pb { + proto::StopMessageType::CONNECT => { + let src_peer_id = PeerId::from_bytes(&peer.ok_or(FatalUpgradeError::MissingPeer)?.id) + .map_err(|_| FatalUpgradeError::ParsePeerId)?; + Ok(Circuit { + substream, + src_peer_id, + limit: limit.map(Into::into), + }) } - .boxed() + proto::StopMessageType::STATUS => Err(FatalUpgradeError::UnexpectedTypeStatus), } } @@ -110,22 +82,22 @@ pub enum FatalUpgradeError { UnexpectedTypeStatus, } -pub struct Circuit { +pub(crate) struct Circuit { substream: Framed>, src_peer_id: PeerId, limit: Option, } impl Circuit { - pub fn src_peer_id(&self) -> PeerId { + pub(crate) fn src_peer_id(&self) -> PeerId { self.src_peer_id } - pub fn limit(&self) -> Option { + pub(crate) fn limit(&self) -> Option { self.limit } - pub async fn accept(mut self) -> Result<(Stream, Bytes), UpgradeError> { + pub(crate) async fn accept(mut self) -> Result<(Stream, Bytes), UpgradeError> { let msg = proto::StopMessage { type_pb: proto::StopMessageType::STATUS, peer: None, @@ -149,7 +121,7 @@ impl Circuit { Ok((io, read_buffer.freeze())) } - pub async fn deny(mut self, status: proto::Status) -> Result<(), UpgradeError> { + pub(crate) async fn deny(mut self, status: proto::Status) -> Result<(), UpgradeError> { let msg = proto::StopMessage { type_pb: proto::StopMessageType::STATUS, peer: None, diff --git a/protocols/relay/src/protocol/outbound_hop.rs b/protocols/relay/src/protocol/outbound_hop.rs index bec348e87db7..adad0e23711d 100644 --- a/protocols/relay/src/protocol/outbound_hop.rs +++ b/protocols/relay/src/protocol/outbound_hop.rs @@ -18,201 +18,23 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::proto; -use crate::protocol::{Limit, HOP_PROTOCOL_NAME, MAX_MESSAGE_SIZE}; +use std::time::{Duration, SystemTime}; + use asynchronous_codec::{Framed, FramedParts}; -use bytes::Bytes; -use futures::{future::BoxFuture, prelude::*}; +use futures::channel::{mpsc, oneshot}; +use futures::prelude::*; use futures_timer::Delay; -use instant::{Duration, SystemTime}; -use libp2p_core::{upgrade, Multiaddr}; -use libp2p_identity::PeerId; -use libp2p_swarm::{Stream, StreamProtocol}; -use std::convert::TryFrom; -use std::iter; +use log::debug; use thiserror::Error; +use void::Void; -pub enum Upgrade { - Reserve, - Connect { dst_peer_id: PeerId }, -} - -impl upgrade::UpgradeInfo for Upgrade { - type Info = StreamProtocol; - type InfoIter = iter::Once; - - fn protocol_info(&self) -> Self::InfoIter { - iter::once(HOP_PROTOCOL_NAME) - } -} - -impl upgrade::OutboundUpgrade for Upgrade { - type Output = Output; - type Error = UpgradeError; - type Future = BoxFuture<'static, Result>; - - fn upgrade_outbound(self, substream: Stream, _: Self::Info) -> Self::Future { - let msg = match self { - Upgrade::Reserve => proto::HopMessage { - type_pb: proto::HopMessageType::RESERVE, - peer: None, - reservation: None, - limit: None, - status: None, - }, - Upgrade::Connect { dst_peer_id } => proto::HopMessage { - type_pb: proto::HopMessageType::CONNECT, - peer: Some(proto::Peer { - id: dst_peer_id.to_bytes(), - addrs: vec![], - }), - reservation: None, - limit: None, - status: None, - }, - }; - - let mut substream = Framed::new( - substream, - quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE), - ); - - async move { - substream.send(msg).await?; - let proto::HopMessage { - type_pb, - peer: _, - reservation, - limit, - status, - } = substream - .next() - .await - .ok_or(FatalUpgradeError::StreamClosed)??; - - match type_pb { - proto::HopMessageType::CONNECT => { - return Err(FatalUpgradeError::UnexpectedTypeConnect.into()) - } - proto::HopMessageType::RESERVE => { - return Err(FatalUpgradeError::UnexpectedTypeReserve.into()) - } - proto::HopMessageType::STATUS => {} - } - - let limit = limit.map(Into::into); - - let output = match self { - Upgrade::Reserve => { - match status - .ok_or(UpgradeError::Fatal(FatalUpgradeError::MissingStatusField))? - { - proto::Status::OK => {} - proto::Status::RESERVATION_REFUSED => { - return Err(ReservationFailedReason::Refused.into()) - } - proto::Status::RESOURCE_LIMIT_EXCEEDED => { - return Err(ReservationFailedReason::ResourceLimitExceeded.into()) - } - s => return Err(FatalUpgradeError::UnexpectedStatus(s).into()), - } - - let reservation = - reservation.ok_or(FatalUpgradeError::MissingReservationField)?; - - if reservation.addrs.is_empty() { - return Err(FatalUpgradeError::NoAddressesInReservation.into()); - } - - let addrs = reservation - .addrs - .into_iter() - .map(|b| Multiaddr::try_from(b.to_vec())) - .collect::, _>>() - .map_err(|_| FatalUpgradeError::InvalidReservationAddrs)?; - - let renewal_timeout = reservation - .expire - .checked_sub( - SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_secs(), - ) - // Renew the reservation after 3/4 of the reservation expiration timestamp. - .and_then(|duration| duration.checked_sub(duration / 4)) - .map(Duration::from_secs) - .map(Delay::new) - .ok_or(FatalUpgradeError::InvalidReservationExpiration)?; - - substream.close().await?; - - Output::Reservation { - renewal_timeout, - addrs, - limit, - } - } - Upgrade::Connect { .. } => { - match status - .ok_or(UpgradeError::Fatal(FatalUpgradeError::MissingStatusField))? - { - proto::Status::OK => {} - proto::Status::RESOURCE_LIMIT_EXCEEDED => { - return Err(CircuitFailedReason::ResourceLimitExceeded.into()) - } - proto::Status::CONNECTION_FAILED => { - return Err(CircuitFailedReason::ConnectionFailed.into()) - } - proto::Status::NO_RESERVATION => { - return Err(CircuitFailedReason::NoReservation.into()) - } - proto::Status::PERMISSION_DENIED => { - return Err(CircuitFailedReason::PermissionDenied.into()) - } - s => return Err(FatalUpgradeError::UnexpectedStatus(s).into()), - } - - let FramedParts { - io, - read_buffer, - write_buffer, - .. - } = substream.into_parts(); - assert!( - write_buffer.is_empty(), - "Expect a flushed Framed to have empty write buffer." - ); - - Output::Circuit { - substream: io, - read_buffer: read_buffer.freeze(), - limit, - } - } - }; - - Ok(output) - } - .boxed() - } -} - -#[derive(Debug, Error)] -pub enum UpgradeError { - #[error("Reservation failed")] - ReservationFailed(#[from] ReservationFailedReason), - #[error("Circuit failed")] - CircuitFailed(#[from] CircuitFailedReason), - #[error("Fatal")] - Fatal(#[from] FatalUpgradeError), -} +use libp2p_core::Multiaddr; +use libp2p_identity::PeerId; +use libp2p_swarm::Stream; -impl From for UpgradeError { - fn from(error: quick_protobuf_codec::Error) -> Self { - Self::Fatal(error.into()) - } -} +use crate::priv_client::transport; +use crate::protocol::{Limit, MAX_MESSAGE_SIZE}; +use crate::{priv_client, proto}; #[derive(Debug, Error)] pub enum CircuitFailedReason { @@ -262,15 +84,216 @@ pub enum FatalUpgradeError { UnexpectedStatus(proto::Status), } -pub enum Output { - Reservation { - renewal_timeout: Delay, - addrs: Vec, - limit: Option, - }, - Circuit { - substream: Stream, - read_buffer: Bytes, - limit: Option, - }, +pub(crate) struct Reservation { + pub(crate) renewal_timeout: Delay, + pub(crate) addrs: Vec, + pub(crate) limit: Option, + pub(crate) to_listener: mpsc::Sender, +} + +pub(crate) struct Circuit { + pub(crate) limit: Option, +} + +pub(crate) async fn handle_reserve_message_response( + protocol: Stream, + to_listener: mpsc::Sender, +) -> Result, FatalUpgradeError> { + let msg = proto::HopMessage { + type_pb: proto::HopMessageType::RESERVE, + peer: None, + reservation: None, + limit: None, + status: None, + }; + let mut substream = Framed::new(protocol, quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE)); + + substream.send(msg).await?; + + let proto::HopMessage { + type_pb, + peer: _, + reservation, + limit, + status, + } = substream + .next() + .await + .ok_or(FatalUpgradeError::StreamClosed)??; + + match type_pb { + proto::HopMessageType::CONNECT => { + return Err(FatalUpgradeError::UnexpectedTypeConnect); + } + proto::HopMessageType::RESERVE => { + return Err(FatalUpgradeError::UnexpectedTypeReserve); + } + proto::HopMessageType::STATUS => {} + } + + let limit = limit.map(Into::into); + + match status.ok_or(FatalUpgradeError::MissingStatusField)? { + proto::Status::OK => {} + proto::Status::RESERVATION_REFUSED => { + return Ok(Err(ReservationFailedReason::Refused)); + } + proto::Status::RESOURCE_LIMIT_EXCEEDED => { + return Ok(Err(ReservationFailedReason::ResourceLimitExceeded)); + } + s => return Err(FatalUpgradeError::UnexpectedStatus(s)), + } + + let reservation = reservation.ok_or(FatalUpgradeError::MissingReservationField)?; + + if reservation.addrs.is_empty() { + return Err(FatalUpgradeError::NoAddressesInReservation); + } + + let addrs = reservation + .addrs + .into_iter() + .map(|b| Multiaddr::try_from(b.to_vec())) + .collect::, _>>() + .map_err(|_| FatalUpgradeError::InvalidReservationAddrs)?; + + let renewal_timeout = reservation + .expire + .checked_sub( + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(), + ) + // Renew the reservation after 3/4 of the reservation expiration timestamp. + .and_then(|duration| duration.checked_sub(duration / 4)) + .map(Duration::from_secs) + .map(Delay::new) + .ok_or(FatalUpgradeError::InvalidReservationExpiration)?; + + substream.close().await?; + + Ok(Ok(Reservation { + renewal_timeout, + addrs, + limit, + to_listener, + })) +} + +pub(crate) async fn handle_connection_message_response( + protocol: Stream, + remote_peer_id: PeerId, + con_command: Command, + tx: oneshot::Sender, +) -> Result, CircuitFailedReason>, FatalUpgradeError> { + let msg = proto::HopMessage { + type_pb: proto::HopMessageType::CONNECT, + peer: Some(proto::Peer { + id: con_command.dst_peer_id.to_bytes(), + addrs: vec![], + }), + reservation: None, + limit: None, + status: None, + }; + + let mut substream = Framed::new(protocol, quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE)); + + if substream.send(msg).await.is_err() { + return Err(FatalUpgradeError::StreamClosed); + } + + let proto::HopMessage { + type_pb, + peer: _, + reservation: _, + limit, + status, + } = match substream.next().await { + Some(Ok(r)) => r, + _ => return Err(FatalUpgradeError::StreamClosed), + }; + + match type_pb { + proto::HopMessageType::CONNECT => { + return Err(FatalUpgradeError::UnexpectedTypeConnect); + } + proto::HopMessageType::RESERVE => { + return Err(FatalUpgradeError::UnexpectedTypeReserve); + } + proto::HopMessageType::STATUS => {} + } + + match status { + Some(proto::Status::OK) => {} + Some(proto::Status::RESOURCE_LIMIT_EXCEEDED) => { + return Ok(Err(CircuitFailedReason::ResourceLimitExceeded)); + } + Some(proto::Status::CONNECTION_FAILED) => { + return Ok(Err(CircuitFailedReason::ConnectionFailed)); + } + Some(proto::Status::NO_RESERVATION) => { + return Ok(Err(CircuitFailedReason::NoReservation)); + } + Some(proto::Status::PERMISSION_DENIED) => { + return Ok(Err(CircuitFailedReason::PermissionDenied)); + } + Some(s) => { + return Err(FatalUpgradeError::UnexpectedStatus(s)); + } + None => { + return Err(FatalUpgradeError::MissingStatusField); + } + } + + let limit = limit.map(Into::into); + + let FramedParts { + io, + read_buffer, + write_buffer, + .. + } = substream.into_parts(); + assert!( + write_buffer.is_empty(), + "Expect a flushed Framed to have empty write buffer." + ); + + match con_command.send_back.send(Ok(priv_client::Connection { + state: priv_client::ConnectionState::new_outbound(io, read_buffer.freeze(), tx), + })) { + Ok(()) => Ok(Ok(Some(Circuit { limit }))), + Err(_) => { + debug!( + "Oneshot to `client::transport::Dial` future dropped. \ + Dropping established relayed connection to {:?}.", + remote_peer_id, + ); + + Ok(Ok(None)) + } + } +} + +pub(crate) enum OutboundStreamInfo { + Reserve(mpsc::Sender), + CircuitConnection(Command), +} + +pub(crate) struct Command { + dst_peer_id: PeerId, + pub(crate) send_back: oneshot::Sender>, +} + +impl Command { + pub(crate) fn new( + dst_peer_id: PeerId, + send_back: oneshot::Sender>, + ) -> Self { + Self { + dst_peer_id, + send_back, + } + } } diff --git a/protocols/relay/src/protocol/outbound_stop.rs b/protocols/relay/src/protocol/outbound_stop.rs index 836468a86053..e45029579959 100644 --- a/protocols/relay/src/protocol/outbound_stop.rs +++ b/protocols/relay/src/protocol/outbound_stop.rs @@ -18,112 +18,23 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -use crate::proto; -use crate::protocol::{MAX_MESSAGE_SIZE, STOP_PROTOCOL_NAME}; +use std::time::Duration; + use asynchronous_codec::{Framed, FramedParts}; use bytes::Bytes; -use futures::{future::BoxFuture, prelude::*}; -use libp2p_core::upgrade; -use libp2p_identity::PeerId; -use libp2p_swarm::{Stream, StreamProtocol}; -use std::convert::TryInto; -use std::iter; -use std::time::Duration; +use futures::channel::oneshot::{self}; +use futures::prelude::*; use thiserror::Error; -pub struct Upgrade { - pub src_peer_id: PeerId, - pub max_circuit_duration: Duration, - pub max_circuit_bytes: u64, -} - -impl upgrade::UpgradeInfo for Upgrade { - type Info = StreamProtocol; - type InfoIter = iter::Once; +use libp2p_identity::PeerId; +use libp2p_swarm::{ConnectionId, Stream, StreamUpgradeError}; - fn protocol_info(&self) -> Self::InfoIter { - iter::once(STOP_PROTOCOL_NAME) - } -} - -impl upgrade::OutboundUpgrade for Upgrade { - type Output = (Stream, Bytes); - type Error = UpgradeError; - type Future = BoxFuture<'static, Result>; - - fn upgrade_outbound(self, substream: Stream, _: Self::Info) -> Self::Future { - let msg = proto::StopMessage { - type_pb: proto::StopMessageType::CONNECT, - peer: Some(proto::Peer { - id: self.src_peer_id.to_bytes(), - addrs: vec![], - }), - limit: Some(proto::Limit { - duration: Some( - self.max_circuit_duration - .as_secs() - .try_into() - .expect("`max_circuit_duration` not to exceed `u32::MAX`."), - ), - data: Some(self.max_circuit_bytes), - }), - status: None, - }; - - let mut substream = Framed::new( - substream, - quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE), - ); - - async move { - substream.send(msg).await?; - let proto::StopMessage { - type_pb, - peer: _, - limit: _, - status, - } = substream - .next() - .await - .ok_or(FatalUpgradeError::StreamClosed)??; - - match type_pb { - proto::StopMessageType::CONNECT => { - return Err(FatalUpgradeError::UnexpectedTypeConnect.into()) - } - proto::StopMessageType::STATUS => {} - } - - match status.ok_or(UpgradeError::Fatal(FatalUpgradeError::MissingStatusField))? { - proto::Status::OK => {} - proto::Status::RESOURCE_LIMIT_EXCEEDED => { - return Err(CircuitFailedReason::ResourceLimitExceeded.into()) - } - proto::Status::PERMISSION_DENIED => { - return Err(CircuitFailedReason::PermissionDenied.into()) - } - s => return Err(FatalUpgradeError::UnexpectedStatus(s).into()), - } - - let FramedParts { - io, - read_buffer, - write_buffer, - .. - } = substream.into_parts(); - assert!( - write_buffer.is_empty(), - "Expect a flushed Framed to have an empty write buffer." - ); - - Ok((io, read_buffer.freeze())) - } - .boxed() - } -} +use crate::behaviour::handler::Config; +use crate::protocol::{inbound_hop, MAX_MESSAGE_SIZE}; +use crate::{proto, CircuitId}; #[derive(Debug, Error)] -pub enum UpgradeError { +pub(crate) enum UpgradeError { #[error("Circuit failed")] CircuitFailed(#[from] CircuitFailedReason), #[error("Fatal")] @@ -161,3 +72,147 @@ pub enum FatalUpgradeError { #[error("Unexpected message status '{0:?}'")] UnexpectedStatus(proto::Status), } + +/// Attempts to _connect_ to a peer via the given stream. +pub(crate) async fn connect( + io: Stream, + stop_command: PendingConnect, + tx: oneshot::Sender<()>, +) -> Result, FatalUpgradeError> { + let msg = proto::StopMessage { + type_pb: proto::StopMessageType::CONNECT, + peer: Some(proto::Peer { + id: stop_command.src_peer_id.to_bytes(), + addrs: vec![], + }), + limit: Some(proto::Limit { + duration: Some( + stop_command + .max_circuit_duration + .as_secs() + .try_into() + .expect("`max_circuit_duration` not to exceed `u32::MAX`."), + ), + data: Some(stop_command.max_circuit_bytes), + }), + status: None, + }; + + let mut substream = Framed::new(io, quick_protobuf_codec::Codec::new(MAX_MESSAGE_SIZE)); + + if substream.send(msg).await.is_err() { + return Err(FatalUpgradeError::StreamClosed); + } + + let res = substream.next().await; + + if let None | Some(Err(_)) = res { + return Err(FatalUpgradeError::StreamClosed); + } + + let proto::StopMessage { + type_pb, + peer: _, + limit: _, + status, + } = res.unwrap().expect("should be ok"); + + match type_pb { + proto::StopMessageType::CONNECT => return Err(FatalUpgradeError::UnexpectedTypeConnect), + proto::StopMessageType::STATUS => {} + } + + match status { + Some(proto::Status::OK) => {} + Some(proto::Status::RESOURCE_LIMIT_EXCEEDED) => { + return Ok(Err(CircuitFailed { + circuit_id: stop_command.circuit_id, + src_peer_id: stop_command.src_peer_id, + src_connection_id: stop_command.src_connection_id, + inbound_circuit_req: stop_command.inbound_circuit_req, + status: proto::Status::RESOURCE_LIMIT_EXCEEDED, + error: StreamUpgradeError::Apply(CircuitFailedReason::ResourceLimitExceeded), + })) + } + Some(proto::Status::PERMISSION_DENIED) => { + return Ok(Err(CircuitFailed { + circuit_id: stop_command.circuit_id, + src_peer_id: stop_command.src_peer_id, + src_connection_id: stop_command.src_connection_id, + inbound_circuit_req: stop_command.inbound_circuit_req, + status: proto::Status::PERMISSION_DENIED, + error: StreamUpgradeError::Apply(CircuitFailedReason::PermissionDenied), + })) + } + Some(s) => return Err(FatalUpgradeError::UnexpectedStatus(s)), + None => return Err(FatalUpgradeError::MissingStatusField), + } + + let FramedParts { + io, + read_buffer, + write_buffer, + .. + } = substream.into_parts(); + assert!( + write_buffer.is_empty(), + "Expect a flushed Framed to have an empty write buffer." + ); + + Ok(Ok(Circuit { + circuit_id: stop_command.circuit_id, + src_peer_id: stop_command.src_peer_id, + src_connection_id: stop_command.src_connection_id, + inbound_circuit_req: stop_command.inbound_circuit_req, + dst_handler_notifier: tx, + dst_stream: io, + dst_pending_data: read_buffer.freeze(), + })) +} + +pub(crate) struct Circuit { + pub(crate) circuit_id: CircuitId, + pub(crate) src_peer_id: PeerId, + pub(crate) src_connection_id: ConnectionId, + pub(crate) inbound_circuit_req: inbound_hop::CircuitReq, + pub(crate) dst_handler_notifier: oneshot::Sender<()>, + pub(crate) dst_stream: Stream, + pub(crate) dst_pending_data: Bytes, +} + +pub(crate) struct CircuitFailed { + pub(crate) circuit_id: CircuitId, + pub(crate) src_peer_id: PeerId, + pub(crate) src_connection_id: ConnectionId, + pub(crate) inbound_circuit_req: inbound_hop::CircuitReq, + pub(crate) status: proto::Status, + pub(crate) error: StreamUpgradeError, +} + +pub(crate) struct PendingConnect { + pub(crate) circuit_id: CircuitId, + pub(crate) inbound_circuit_req: inbound_hop::CircuitReq, + pub(crate) src_peer_id: PeerId, + pub(crate) src_connection_id: ConnectionId, + max_circuit_duration: Duration, + max_circuit_bytes: u64, +} + +impl PendingConnect { + pub(crate) fn new( + circuit_id: CircuitId, + inbound_circuit_req: inbound_hop::CircuitReq, + src_peer_id: PeerId, + src_connection_id: ConnectionId, + config: &Config, + ) -> Self { + Self { + circuit_id, + inbound_circuit_req, + src_peer_id, + src_connection_id, + max_circuit_duration: config.max_circuit_duration, + max_circuit_bytes: config.max_circuit_bytes, + } + } +}