diff --git a/protocols/request-response/Cargo.toml b/protocols/request-response/Cargo.toml index edd17db1ad3..ed04eb1e38d 100644 --- a/protocols/request-response/Cargo.toml +++ b/protocols/request-response/Cargo.toml @@ -11,13 +11,15 @@ categories = ["network-programming", "asynchronous"] [dependencies] async-trait = "0.1" +bytes = "0.5.6" futures = "0.3.1" libp2p-core = { version = "0.22.0", path = "../../core" } libp2p-swarm = { version = "0.22.0", path = "../../swarm" } log = "0.4.11" -lru = "0.6" +minicbor = { version = "0.5", features = ["std", "derive"] } rand = "0.7" smallvec = "1.4" +unsigned-varint = { version = "0.5", features = ["std", "futures"] } wasm-timer = "0.2" [dev-dependencies] diff --git a/protocols/request-response/src/codec.rs b/protocols/request-response/src/codec.rs index da85b277d81..bbb708081dc 100644 --- a/protocols/request-response/src/codec.rs +++ b/protocols/request-response/src/codec.rs @@ -64,3 +64,4 @@ pub trait RequestResponseCodec { where T: AsyncWrite + Unpin + Send; } + diff --git a/protocols/request-response/src/handler.rs b/protocols/request-response/src/handler.rs index 392988d322c..fe374f54877 100644 --- a/protocols/request-response/src/handler.rs +++ b/protocols/request-response/src/handler.rs @@ -47,6 +47,7 @@ use smallvec::SmallVec; use std::{ collections::VecDeque, io, + sync::{atomic::{AtomicU64, Ordering}, Arc}, time::Duration, task::{Context, Poll} }; @@ -79,9 +80,10 @@ where /// Inbound upgrades waiting for the incoming request. inbound: FuturesUnordered), + ((RequestId, TCodec::Request), oneshot::Sender), oneshot::Canceled >>>, + inbound_request_id: Arc } impl RequestResponseHandler @@ -93,6 +95,7 @@ where codec: TCodec, keep_alive_timeout: Duration, substream_timeout: Duration, + inbound_request_id: Arc ) -> Self { Self { inbound_protocols, @@ -104,6 +107,7 @@ where inbound: FuturesUnordered::new(), pending_events: VecDeque::new(), pending_error: None, + inbound_request_id } } } @@ -117,6 +121,7 @@ where { /// An inbound request. Request { + request_id: RequestId, request: TCodec::Request, sender: oneshot::Sender }, @@ -130,9 +135,9 @@ where /// An outbound request failed to negotiate a mutually supported protocol. OutboundUnsupportedProtocols(RequestId), /// An inbound request timed out. - InboundTimeout, + InboundTimeout(RequestId), /// An inbound request failed to negotiate a mutually supported protocol. - InboundUnsupportedProtocols, + InboundUnsupportedProtocols(RequestId), } impl ProtocolsHandler for RequestResponseHandler @@ -145,7 +150,7 @@ where type InboundProtocol = ResponseProtocol; type OutboundProtocol = RequestProtocol; type OutboundOpenInfo = RequestId; - type InboundOpenInfo = (); + type InboundOpenInfo = RequestId; fn listen_protocol(&self) -> SubstreamProtocol { // A channel for notifying the handler when the inbound @@ -156,6 +161,8 @@ where // response is sent. let (rs_send, rs_recv) = oneshot::channel(); + let request_id = RequestId(self.inbound_request_id.fetch_add(1, Ordering::Relaxed)); + // By keeping all I/O inside the `ResponseProtocol` and thus the // inbound substream upgrade via above channels, we ensure that it // is all subject to the configured timeout without extra bookkeeping @@ -167,6 +174,7 @@ where codec: self.codec.clone(), request_sender: rq_send, response_receiver: rs_recv, + request_id }; // The handler waits for the request to come in. It then emits @@ -174,16 +182,14 @@ where // `ResponseChannel`. self.inbound.push(rq_recv.map_ok(move |rq| (rq, rs_send)).boxed()); - SubstreamProtocol::new(proto, ()).with_timeout(self.substream_timeout) + SubstreamProtocol::new(proto, request_id).with_timeout(self.substream_timeout) } fn inject_fully_negotiated_inbound( &mut self, (): (), - (): () + _: RequestId ) { - // Nothing to do, as the response has already been sent - // as part of the upgrade. } fn inject_fully_negotiated_outbound( @@ -231,13 +237,12 @@ where fn inject_listen_upgrade_error( &mut self, - (): Self::InboundOpenInfo, + info: RequestId, error: ProtocolsHandlerUpgrErr ) { match error { ProtocolsHandlerUpgrErr::Timeout => { - self.pending_events.push_back( - RequestResponseHandlerEvent::InboundTimeout); + self.pending_events.push_back(RequestResponseHandlerEvent::InboundTimeout(info)) } ProtocolsHandlerUpgrErr::Upgrade(UpgradeError::Select(NegotiationError::Failed)) => { // The local peer merely doesn't support the protocol(s) requested. @@ -246,7 +251,7 @@ where // An event is reported to permit user code to react to the fact that // the local peer does not support the requested protocol(s). self.pending_events.push_back( - RequestResponseHandlerEvent::InboundUnsupportedProtocols); + RequestResponseHandlerEvent::InboundUnsupportedProtocols(info)); } _ => { // Anything else is considered a fatal error or misbehaviour of @@ -282,12 +287,12 @@ where // Check for inbound requests. while let Poll::Ready(Some(result)) = self.inbound.poll_next_unpin(cx) { match result { - Ok((rq, rs_sender)) => { + Ok(((id, rq), rs_sender)) => { // We received an inbound request. self.keep_alive = KeepAlive::Yes; return Poll::Ready(ProtocolsHandlerEvent::Custom( RequestResponseHandlerEvent::Request { - request: rq, sender: rs_sender + request_id: id, request: rq, sender: rs_sender })) } Err(oneshot::Canceled) => { diff --git a/protocols/request-response/src/handler/protocol.rs b/protocols/request-response/src/handler/protocol.rs index bbd0b80f953..0fc2b99df9f 100644 --- a/protocols/request-response/src/handler/protocol.rs +++ b/protocols/request-response/src/handler/protocol.rs @@ -71,8 +71,10 @@ where { pub(crate) codec: TCodec, pub(crate) protocols: SmallVec<[TCodec::Protocol; 2]>, - pub(crate) request_sender: oneshot::Sender, - pub(crate) response_receiver: oneshot::Receiver + pub(crate) request_sender: oneshot::Sender<(RequestId, TCodec::Request)>, + pub(crate) response_receiver: oneshot::Receiver, + pub(crate) request_id: RequestId + } impl UpgradeInfo for ResponseProtocol @@ -99,7 +101,7 @@ where async move { let read = self.codec.read_request(&protocol, &mut io); let request = read.await?; - if let Ok(()) = self.request_sender.send(request) { + if let Ok(()) = self.request_sender.send((self.request_id, request)) { if let Ok(response) = self.response_receiver.await { let write = self.codec.write_response(&protocol, &mut io, response); write.await?; diff --git a/protocols/request-response/src/lib.rs b/protocols/request-response/src/lib.rs index e7a728b3425..31ee6af9304 100644 --- a/protocols/request-response/src/lib.rs +++ b/protocols/request-response/src/lib.rs @@ -70,13 +70,11 @@ pub mod codec; pub mod handler; - -// Disabled until #1706 is fixed: -// pub mod throttled; -// pub use throttled::Throttled; +pub mod throttled; pub use codec::{RequestResponseCodec, ProtocolName}; pub use handler::ProtocolSupport; +pub use throttled::Throttled; use futures::{ channel::oneshot, @@ -102,21 +100,25 @@ use libp2p_swarm::{ use smallvec::SmallVec; use std::{ collections::{VecDeque, HashMap}, + fmt, time::Duration, + sync::{atomic::AtomicU64, Arc}, task::{Context, Poll} }; /// An inbound request or response. #[derive(Debug)] -pub enum RequestResponseMessage { +pub enum RequestResponseMessage { /// A request message. Request { + /// The ID of this request. + request_id: RequestId, /// The request message. request: TRequest, /// The sender of the request who is awaiting a response. /// /// See [`RequestResponse::send_response`]. - channel: ResponseChannel, + channel: ResponseChannel, }, /// A response message. Response { @@ -131,13 +133,13 @@ pub enum RequestResponseMessage { /// The events emitted by a [`RequestResponse`] protocol. #[derive(Debug)] -pub enum RequestResponseEvent { +pub enum RequestResponseEvent { /// An incoming message (request or response). Message { /// The peer who sent the message. peer: PeerId, /// The incoming message. - message: RequestResponseMessage + message: RequestResponseMessage }, /// An outbound request failed. OutboundFailure { @@ -152,6 +154,8 @@ pub enum RequestResponseEvent { InboundFailure { /// The peer from whom the request was received. peer: PeerId, + /// The ID of the failed inbound request. + request_id: RequestId, /// The error that occurred. error: InboundFailure, }, @@ -188,6 +192,8 @@ pub enum InboundFailure { Timeout, /// The local peer supports none of the requested protocols. UnsupportedProtocols, + /// The connection closed before a response was delivered. + ConnectionClosed, } /// A channel for sending a response to an inbound request. @@ -195,6 +201,7 @@ pub enum InboundFailure { /// See [`RequestResponse::send_response`]. #[derive(Debug)] pub struct ResponseChannel { + request_id: RequestId, peer: PeerId, sender: oneshot::Sender, } @@ -210,14 +217,23 @@ impl ResponseChannel { pub fn is_open(&self) -> bool { !self.sender.is_canceled() } + + /// Get the ID of the inbound request waiting for a response. + pub(crate) fn request_id(&self) -> RequestId { + self.request_id + } } -/// The (local) ID of an outgoing request. -/// -/// See [`RequestResponse::send_request`]. +/// The ID of an inbound or outbound request. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub struct RequestId(u64); +impl fmt::Display for RequestId { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + /// The configuration for a `RequestResponse` protocol. #[derive(Debug, Clone)] pub struct RequestResponseConfig { @@ -259,6 +275,8 @@ where outbound_protocols: SmallVec<[TCodec::Protocol; 2]>, /// The next (local) request ID. next_request_id: RequestId, + /// The next (inbound) request ID. + next_inbound_id: Arc, /// The protocol configuration. config: RequestResponseConfig, /// The protocol codec for reading and writing requests and responses. @@ -276,7 +294,7 @@ where /// to be established. pending_requests: HashMap; 10]>>, /// Responses that have not yet been received. - pending_responses: HashMap, + pending_responses: HashMap } impl RequestResponse @@ -303,6 +321,7 @@ where inbound_protocols, outbound_protocols, next_request_id: RequestId(1), + next_inbound_id: Arc::new(AtomicU64::new(1)), config: cfg, codec, pending_events: VecDeque::new(), @@ -313,11 +332,18 @@ where } } -// Disabled until #1706 is fixed. -// /// Wrap this behaviour in [`Throttled`] to limit the number of concurrent requests per peer. -// pub fn throttled(self) -> Throttled { -// Throttled::new(self) -// } + /// Creates a `RequestResponse` which limits requests per peer. + /// + /// The behaviour is wrapped in [`Throttled`] and detects the limits + /// per peer at runtime which are then enforced. + pub fn throttled(c: TCodec, protos: I, cfg: RequestResponseConfig) -> Throttled + where + I: IntoIterator, + TCodec: Send, + TCodec::Protocol: Sync + { + Throttled::new(c, protos, cfg) + } /// Initiates sending a request. /// @@ -389,13 +415,17 @@ where /// Checks whether a peer is currently connected. pub fn is_connected(&self, peer: &PeerId) -> bool { - self.connected.contains_key(peer) + if let Some(connections) = self.connected.get(peer) { + !connections.is_empty() + } else { + false + } } /// Checks whether an outbound request initiated by /// [`RequestResponse::send_request`] is still pending, i.e. waiting /// for a response. - pub fn is_pending(&self, req_id: &RequestId) -> bool { + pub fn is_pending_outbound(&self, req_id: &RequestId) -> bool { self.pending_responses.contains_key(req_id) } @@ -413,6 +443,9 @@ where -> Option> { if let Some(connections) = self.connected.get(peer) { + if connections.is_empty() { + return Some(request) + } let ix = (request.request_id.0 as usize) % connections.len(); let conn = connections[ix].id; self.pending_responses.insert(request.request_id, (peer.clone(), conn)); @@ -441,6 +474,7 @@ where self.codec.clone(), self.config.connection_keep_alive, self.config.request_timeout, + self.next_inbound_id.clone() ) } @@ -480,27 +514,22 @@ where } } - // Any pending responses of requests sent over this connection - // must be considered failed. - let failed = self.pending_responses.iter() - .filter_map(|(r, (p, c))| - if conn == c { - Some((p.clone(), *r)) - } else { - None - }) - .collect::>(); - - for (peer, request_id) in failed { - self.pending_responses.remove(&request_id); - self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( + let pending_events = &mut self.pending_events; + + // Any pending responses of requests sent over this connection must be considered failed. + self.pending_responses.retain(|rid, (peer, cid)| { + if conn != cid { + return true + } + pending_events.push_back(NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::OutboundFailure { - peer, - request_id, + peer: peer.clone(), + request_id: *rid, error: OutboundFailure::ConnectionClosed } )); - } + false + }); } fn inject_disconnected(&mut self, peer: &PeerId) { @@ -541,12 +570,12 @@ where NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::Message { peer, message })); } - RequestResponseHandlerEvent::Request { request, sender } => { - let channel = ResponseChannel { peer: peer.clone(), sender }; - let message = RequestResponseMessage::Request { request, channel }; - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::Message { peer, message })); + RequestResponseHandlerEvent::Request { request_id, request, sender } => { + let channel = ResponseChannel { request_id, peer: peer.clone(), sender }; + let message = RequestResponseMessage::Request { request_id, request, channel }; + self.pending_events.push_back(NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::Message { peer, message } + )); } RequestResponseHandlerEvent::OutboundTimeout(request_id) => { if let Some((peer, _conn)) = self.pending_responses.remove(&request_id) { @@ -559,13 +588,14 @@ where })); } } - RequestResponseHandlerEvent::InboundTimeout => { - self.pending_events.push_back( - NetworkBehaviourAction::GenerateEvent( - RequestResponseEvent::InboundFailure { - peer, - error: InboundFailure::Timeout, - })); + RequestResponseHandlerEvent::InboundTimeout(request_id) => { + self.pending_events.push_back( + NetworkBehaviourAction::GenerateEvent( + RequestResponseEvent::InboundFailure { + peer, + request_id, + error: InboundFailure::Timeout, + })); } RequestResponseHandlerEvent::OutboundUnsupportedProtocols(request_id) => { self.pending_events.push_back( @@ -576,11 +606,12 @@ where error: OutboundFailure::UnsupportedProtocols, })); } - RequestResponseHandlerEvent::InboundUnsupportedProtocols => { + RequestResponseHandlerEvent::InboundUnsupportedProtocols(request_id) => { self.pending_events.push_back( NetworkBehaviourAction::GenerateEvent( RequestResponseEvent::InboundFailure { peer, + request_id, error: InboundFailure::UnsupportedProtocols, })); } diff --git a/protocols/request-response/src/throttled.rs b/protocols/request-response/src/throttled.rs index 990c8665484..1214ce4baff 100644 --- a/protocols/request-response/src/throttled.rs +++ b/protocols/request-response/src/throttled.rs @@ -18,124 +18,193 @@ // FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +//! Limit the number of requests peers can send to each other. +//! +//! Each peer is assigned a budget for sending and a budget for receiving +//! requests. Initially a peer assumes it has a send budget of 1. When its +//! budget has been used up its remote peer will send a credit message which +//! informs it how many more requests it can send before it needs to wait for +//! the next credit message. Credit messages which error or time out are +//! retried until they have reached the peer which is assumed once a +//! corresponding ack or a new request has been received from the peer. +//! +//! The `Throttled` behaviour wraps an existing `RequestResponse` behaviour +//! and uses a codec implementation that sends ordinary requests and responses +//! as well as a special credit message to which an ack message is expected +//! as a response. It does so by putting a small CBOR encoded header in front +//! of each message the inner codec produces. + +mod codec; + +use codec::{Codec, Message, ProtocolWrapper, Type}; use crate::handler::{RequestProtocol, RequestResponseHandler, RequestResponseHandlerEvent}; +use futures::ready; use libp2p_core::{ConnectedPoint, connection::ConnectionId, Multiaddr, PeerId}; use libp2p_swarm::{NetworkBehaviour, NetworkBehaviourAction, PollParameters}; -use lru::LruCache; use std::{collections::{HashMap, VecDeque}, task::{Context, Poll}}; -use std::{cmp::min, num::NonZeroU16}; +use std::num::NonZeroU16; use super::{ + ProtocolSupport, RequestId, RequestResponse, RequestResponseCodec, + RequestResponseConfig, RequestResponseEvent, + RequestResponseMessage, ResponseChannel }; /// A wrapper around [`RequestResponse`] which adds request limits per peer. -/// -/// Each peer is assigned a default limit of concurrent requests and -/// responses, which can be overriden per peer. -/// -/// It is not possible to send more requests than configured and receiving -/// more is reported as an error event. Since `libp2p-request-response` is -/// not its own protocol, there is no way to communicate limits to peers, -/// hence nodes must have pre-established knowledge about each other's limits. -pub struct Throttled { +pub struct Throttled +where + C: RequestResponseCodec + Send, + C::Protocol: Sync +{ /// A random id used for logging. id: u32, /// The wrapped behaviour. - behaviour: RequestResponse, - /// Limits per peer. - limits: HashMap, - /// After disconnects we keep limits around to prevent circumventing - /// them by successive reconnects. - previous: LruCache, - /// The default limit applied to all peers unless overriden. - default: Limit, + behaviour: RequestResponse>, + /// Information per peer. + peer_info: HashMap, + /// The default limit applies to all peers unless overriden. + default_limit: Limit, + /// Permanent limit overrides per peer. + limit_overrides: HashMap, /// Pending events to report in `Throttled::poll`. - events: VecDeque> + events: VecDeque>>, + /// Current outbound credit grants in flight. + credit_messages: HashMap, + /// The current credit ID. + credit_id: u64 } -/// A `Limit` of inbound and outbound requests. -#[derive(Clone, Debug)] +/// Credit information that is sent to remote peers. +#[derive(Clone, Copy, Debug)] +struct Credit { + /// A credit ID. Used to deduplicate retransmitted credit messages. + id: u64, + /// The ID of the outbound credit grant message. + request: RequestId, + /// The number of requests the remote is allowed to send. + amount: u16 +} + +/// Max. number of inbound requests that can be received. +#[derive(Clone, Copy, Debug)] struct Limit { - /// The remaining number of outbound requests that can be send. - send_budget: u16, - /// The remaining number of inbound requests that can be received. - recv_budget: u16, - /// The original limit which applies to inbound and outbound requests. - maximum: NonZeroU16 + /// The current receive limit. + max_recv: NonZeroU16, + /// The next receive limit which becomes active after + /// the current limit has been reached. + next_max: NonZeroU16 } -impl Default for Limit { - fn default() -> Self { - let maximum = NonZeroU16::new(1).expect("1 > 0"); +impl Limit { + /// Create a new limit. + fn new(max: NonZeroU16) -> Self { + // The max. limit provided will be effective after the initial request + // from a peer which is always allowed has been answered. Values greater + // than 1 would prevent sending the credit grant, leading to a stalling + // sender so we must not use `max` right away. Limit { - send_budget: maximum.get(), - recv_budget: maximum.get(), - maximum + max_recv: NonZeroU16::new(1).expect("1 > 0"), + next_max: max } } + + /// Set a new limit. + /// + /// The new limit becomes effective when all current inbound + /// requests have been processed and replied to. + fn set(&mut self, next: NonZeroU16) { + self.next_max = next + } + + /// Activate the new limit. + fn switch(&mut self) -> u16 { + self.max_recv = self.next_max; + self.max_recv.get() + } } -/// A Wrapper around [`RequestResponseEvent`]. -#[derive(Debug)] -pub enum Event { - /// A regular request-response event. - Event(RequestResponseEvent), - /// We received more inbound requests than allowed. - TooManyInboundRequests(PeerId), - /// When previously reaching the send limit of a peer, - /// this event is eventually emitted when sending is - /// allowed to resume. - ResumeSending(PeerId) +/// Budget information about a peer. +#[derive(Clone, Debug)] +struct PeerInfo { + /// Limit that applies to this peer. + limit: Limit, + /// Remaining number of outbound requests that can be sent. + send_budget: u16, + /// Remaining number of inbound requests that can be received. + recv_budget: u16, + /// The ID of the credit message that granted the current `send_budget`. + send_budget_id: Option } -impl Throttled { +impl PeerInfo { + fn new(limit: Limit) -> Self { + PeerInfo { + limit, + send_budget: 1, + recv_budget: 1, + send_budget_id: None + } + } +} + +impl Throttled +where + C: RequestResponseCodec + Send + Clone, + C::Protocol: Sync +{ + /// Create a new throttled request-response behaviour. + pub fn new(c: C, protos: I, cfg: RequestResponseConfig) -> Self + where + I: IntoIterator, + C: Send, + C::Protocol: Sync + { + let protos = protos.into_iter().map(|(p, ps)| (ProtocolWrapper::new(b"/t/1", p), ps)); + Throttled::from(RequestResponse::new(Codec::new(c, 8192), protos, cfg)) + } + /// Wrap an existing `RequestResponse` behaviour and apply send/recv limits. - pub fn new(behaviour: RequestResponse) -> Self { + pub fn from(behaviour: RequestResponse>) -> Self { Throttled { id: rand::random(), behaviour, - limits: HashMap::new(), - previous: LruCache::new(2048), - default: Limit::default(), - events: VecDeque::new() + peer_info: HashMap::new(), + default_limit: Limit::new(NonZeroU16::new(1).expect("1 > 0")), + limit_overrides: HashMap::new(), + events: VecDeque::new(), + credit_messages: HashMap::new(), + credit_id: 0 } } - /// Get the current default limit applied to all peers. - pub fn default_limit(&self) -> u16 { - self.default.maximum.get() + /// Set the global default receive limit per peer. + pub fn set_receive_limit(&mut self, limit: NonZeroU16) { + log::trace!("{:08x}: new default limit: {:?}", self.id, limit); + self.default_limit = Limit::new(limit) } - /// Override the global default limit. - /// - /// See [`Throttled::set_limit`] to override limits for individual peers. - pub fn set_default_limit(&mut self, limit: NonZeroU16) { - log::trace!("{:08x}: new default limit: {:?}", self.id, limit); - self.default = Limit { - send_budget: limit.get(), - recv_budget: limit.get(), - maximum: limit + /// Override the receive limit of a single peer. + pub fn override_receive_limit(&mut self, p: &PeerId, limit: NonZeroU16) { + log::debug!("{:08x}: override limit for {}: {:?}", self.id, p, limit); + if let Some(info) = self.peer_info.get_mut(p) { + info.limit.set(limit) } + self.limit_overrides.insert(p.clone(), Limit::new(limit)); } - /// Specify the send and receive limit for a single peer. - pub fn set_limit(&mut self, id: &PeerId, limit: NonZeroU16) { - log::trace!("{:08x}: new limit for {}: {:?}", self.id, id, limit); - self.previous.pop(id); - self.limits.insert(id.clone(), Limit { - send_budget: limit.get(), - recv_budget: limit.get(), - maximum: limit - }); + /// Remove any limit overrides for the given peer. + pub fn remove_override(&mut self, p: &PeerId) { + log::trace!("{:08x}: removing limit override for {}", self.id, p); + self.limit_overrides.remove(p); } /// Has the limit of outbound requests been reached for the given peer? - pub fn can_send(&mut self, id: &PeerId) -> bool { - self.limits.get(id).map(|l| l.send_budget > 0).unwrap_or(true) + pub fn can_send(&mut self, p: &PeerId) -> bool { + self.peer_info.get(p).map(|i| i.send_budget > 0).unwrap_or(true) } /// Send a request to a peer. @@ -143,160 +212,323 @@ impl Throttled { /// If the limit of outbound requests has been reached, the request is /// returned. Sending more outbound requests should only be attempted /// once [`Event::ResumeSending`] has been received from [`NetworkBehaviour::poll`]. - pub fn send_request(&mut self, id: &PeerId, req: C::Request) -> Result { - log::trace!("{:08x}: sending request to {}", self.id, id); - - // Getting the limit is somewhat complicated due to the connection state. - // Applications may try to send a request to a peer we have never been connected - // to, or a peer we have previously been connected to. In the first case, the - // default limit applies and in the latter, the cached limit from the previous - // connection (if still available). - let mut limit = - if let Some(limit) = self.limits.get_mut(id) { - limit + pub fn send_request(&mut self, p: &PeerId, req: C::Request) -> Result { + let info = + if let Some(info) = self.peer_info.get_mut(p) { + info } else { - let limit = self.previous.pop(id).unwrap_or_else(|| self.default.clone()); - self.limits.entry(id.clone()).or_insert(limit) + let limit = self.limit_overrides.get(p).copied().unwrap_or(self.default_limit); + self.peer_info.entry(p.clone()).or_insert(PeerInfo::new(limit)) }; - if limit.send_budget == 0 { - log::trace!("{:08x}: no budget to send request to {}", self.id, id); + if info.send_budget == 0 { + log::trace!("{:08x}: no more budget to send another request to {}", self.id, p); return Err(req) } - limit.send_budget -= 1; + info.send_budget -= 1; + + let rid = self.behaviour.send_request(p, Message::request(req)); - Ok(self.behaviour.send_request(id, req)) + log::trace! { "{:08x}: sending request {} to {} (send budget = {})", + self.id, + rid, + p, + info.send_budget + 1 + }; + + Ok(rid) } /// Answer an inbound request with a response. /// /// See [`RequestResponse::send_response`] for details. - pub fn send_response(&mut self, ch: ResponseChannel, rs: C::Response) { - if let Some(limit) = self.limits.get_mut(&ch.peer) { - limit.recv_budget += 1; - debug_assert!(limit.recv_budget <= limit.maximum.get()) + pub fn send_response(&mut self, ch: ResponseChannel>, res: C::Response) { + log::trace!("{:08x}: sending response {} to peer {}", self.id, ch.request_id(), &ch.peer); + if let Some(info) = self.peer_info.get_mut(&ch.peer) { + if info.recv_budget == 0 { // need to send more credit to the remote peer + let crd = info.limit.switch(); + info.recv_budget = info.limit.max_recv.get(); + let cid = self.next_credit_id(); + let rid = self.behaviour.send_request(&ch.peer, Message::credit(crd, cid)); + log::trace!("{:08x}: sending {} as credit {} to {}", self.id, crd, cid, ch.peer); + let credit = Credit { id: cid, request: rid, amount: crd }; + self.credit_messages.insert(ch.peer.clone(), credit); + } } - self.behaviour.send_response(ch, rs) + self.behaviour.send_response(ch, Message::response(res)) } /// Add a known peer address. /// /// See [`RequestResponse::add_address`] for details. - pub fn add_address(&mut self, id: &PeerId, ma: Multiaddr) { - self.behaviour.add_address(id, ma) + pub fn add_address(&mut self, p: &PeerId, a: Multiaddr) { + self.behaviour.add_address(p, a) } /// Remove a previously added peer address. /// /// See [`RequestResponse::remove_address`] for details. - pub fn remove_address(&mut self, id: &PeerId, ma: &Multiaddr) { - self.behaviour.remove_address(id, ma) + pub fn remove_address(&mut self, p: &PeerId, a: &Multiaddr) { + self.behaviour.remove_address(p, a) } /// Are we connected to the given peer? /// /// See [`RequestResponse::is_connected`] for details. - pub fn is_connected(&self, id: &PeerId) -> bool { - self.behaviour.is_connected(id) + pub fn is_connected(&self, p: &PeerId) -> bool { + self.behaviour.is_connected(p) } /// Are we waiting for a response to the given request? /// - /// See [`RequestResponse::is_pending`] for details. - pub fn is_pending(&self, id: &RequestId) -> bool { - self.behaviour.is_pending(id) + /// See [`RequestResponse::is_pending_outbound`] for details. + pub fn is_pending_outbound(&self, p: &RequestId) -> bool { + self.behaviour.is_pending_outbound(p) + } + + /// Create a new credit message ID. + fn next_credit_id(&mut self) -> u64 { + let n = self.credit_id; + self.credit_id += 1; + n } } +/// A Wrapper around [`RequestResponseEvent`]. +#[derive(Debug)] +pub enum Event { + /// A regular request-response event. + Event(RequestResponseEvent), + /// We received more inbound requests than allowed. + TooManyInboundRequests(PeerId), + /// When previously reaching the send limit of a peer, + /// this event is eventually emitted when sending is + /// allowed to resume. + ResumeSending(PeerId) +} + impl NetworkBehaviour for Throttled where - C: RequestResponseCodec + Send + Clone + 'static + C: RequestResponseCodec + Send + Clone + 'static, + C::Protocol: Sync { - type ProtocolsHandler = RequestResponseHandler; - type OutEvent = Event; + type ProtocolsHandler = RequestResponseHandler>; + type OutEvent = Event>; fn new_handler(&mut self) -> Self::ProtocolsHandler { self.behaviour.new_handler() } - fn addresses_of_peer(&mut self, peer: &PeerId) -> Vec { - self.behaviour.addresses_of_peer(peer) + fn addresses_of_peer(&mut self, p: &PeerId) -> Vec { + self.behaviour.addresses_of_peer(p) } fn inject_connection_established(&mut self, p: &PeerId, id: &ConnectionId, end: &ConnectedPoint) { self.behaviour.inject_connection_established(p, id, end) } - fn inject_connection_closed(&mut self, p: &PeerId, id: &ConnectionId, end: &ConnectedPoint) { - self.behaviour.inject_connection_closed(p, id, end); + fn inject_connection_closed(&mut self, peer: &PeerId, id: &ConnectionId, end: &ConnectedPoint) { + self.behaviour.inject_connection_closed(peer, id, end); + if self.is_connected(peer) { + if let Some(credit) = self.credit_messages.get_mut(peer) { + log::debug! { "{:08x}: resending credit grant {} to {} after connection closed", + self.id, + credit.id, + peer + }; + let msg = Message::credit(credit.amount, credit.id); + credit.request = self.behaviour.send_request(peer, msg) + } + } } - fn inject_connected(&mut self, peer: &PeerId) { - log::trace!("{:08x}: connected to {}", self.id, peer); - self.behaviour.inject_connected(peer); - // The limit may have been added by [`Throttled::send_request`] already. - if !self.limits.contains_key(peer) { - let limit = self.previous.pop(peer).unwrap_or_else(|| self.default.clone()); - self.limits.insert(peer.clone(), limit); + fn inject_connected(&mut self, p: &PeerId) { + log::trace!("{:08x}: connected to {}", self.id, p); + self.behaviour.inject_connected(p); + // The limit may have been added by `Throttled::send_request` already. + if !self.peer_info.contains_key(p) { + let limit = self.limit_overrides.get(p).copied().unwrap_or(self.default_limit); + self.peer_info.insert(p.clone(), PeerInfo::new(limit)); } } - fn inject_disconnected(&mut self, peer: &PeerId) { - log::trace!("{:08x}: disconnected from {}", self.id, peer); - self.behaviour.inject_disconnected(peer); - // Save the limit in case the peer reconnects soon. - if let Some(limit) = self.limits.remove(peer) { - self.previous.put(peer.clone(), limit); - } + fn inject_disconnected(&mut self, p: &PeerId) { + log::trace!("{:08x}: disconnected from {}", self.id, p); + self.peer_info.remove(p); + self.credit_messages.remove(p); + self.behaviour.inject_disconnected(p) } - fn inject_dial_failure(&mut self, peer: &PeerId) { - self.behaviour.inject_dial_failure(peer) + fn inject_dial_failure(&mut self, p: &PeerId) { + self.behaviour.inject_dial_failure(p) } - fn inject_event(&mut self, p: PeerId, i: ConnectionId, e: RequestResponseHandlerEvent) { - match e { - // Cases where an outbound request has been resolved. - | RequestResponseHandlerEvent::Response {..} - | RequestResponseHandlerEvent::OutboundTimeout (_) - | RequestResponseHandlerEvent::OutboundUnsupportedProtocols (_) => - if let Some(limit) = self.limits.get_mut(&p) { - if limit.send_budget == 0 { - log::trace!("{:08x}: sending to peer {} can resume", self.id, p); - self.events.push_back(Event::ResumeSending(p.clone())) - } - limit.send_budget = min(limit.send_budget + 1, limit.maximum.get()) + fn inject_event(&mut self, p: PeerId, i: ConnectionId, e: RequestResponseHandlerEvent>) { + self.behaviour.inject_event(p, i, e) + } + + fn poll(&mut self, cx: &mut Context<'_>, params: &mut impl PollParameters) + -> Poll>, Self::OutEvent>> + { + loop { + if let Some(ev) = self.events.pop_front() { + return Poll::Ready(NetworkBehaviourAction::GenerateEvent(ev)) + } else if self.events.capacity() > super::EMPTY_QUEUE_SHRINK_THRESHOLD { + self.events.shrink_to_fit() + } + + let event = match ready!(self.behaviour.poll(cx, params)) { + | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::Message { peer, message }) => { + let message = match message { + | RequestResponseMessage::Response { request_id, response } => + match &response.header().typ { + | Some(Type::Ack) => { + if let Some(id) = self.credit_messages.get(&peer).map(|c| c.id) { + if Some(id) == response.header().ident { + log::trace!("{:08x}: received ack {} from {}", self.id, id, peer); + self.credit_messages.remove(&peer); + } + } + continue + } + | Some(Type::Response) => { + log::trace!("{:08x}: received response {} from {}", self.id, request_id, peer); + if let Some(rs) = response.into_parts().1 { + RequestResponseMessage::Response { request_id, response: rs } + } else { + log::error! { "{:08x}: missing data for response {} from peer {}", + self.id, + request_id, + peer + } + continue + } + } + | ty => { + log::trace! { + "{:08x}: unknown message type: {:?} from {}; expected response or credit", + self.id, + ty, + peer + }; + continue + } + } + | RequestResponseMessage::Request { request_id, request, channel } => + match &request.header().typ { + | Some(Type::Credit) => { + if let Some(info) = self.peer_info.get_mut(&peer) { + let id = if let Some(n) = request.header().ident { + n + } else { + log::warn! { "{:08x}: missing credit id in message from {}", + self.id, + peer + } + continue + }; + let credit = request.header().credit.unwrap_or(0); + log::trace! { "{:08x}: received {} additional credit {} from {}", + self.id, + credit, + id, + peer + }; + if info.send_budget_id < Some(id) { + if info.send_budget == 0 && credit > 0 { + log::trace!("{:08x}: sending to peer {} can resume", self.id, peer); + self.events.push_back(Event::ResumeSending(peer.clone())) + } + info.send_budget += credit; + info.send_budget_id = Some(id) + } + self.behaviour.send_response(channel, Message::ack(id)) + } + continue + } + | Some(Type::Request) => { + if let Some(info) = self.peer_info.get_mut(&peer) { + log::trace! { "{:08x}: received request {} (recv. budget = {})", + self.id, + request_id, + info.recv_budget + }; + if info.recv_budget == 0 { + log::debug!("{:08x}: peer {} exceeds its budget", self.id, peer); + self.events.push_back(Event::TooManyInboundRequests(peer.clone())); + continue + } + info.recv_budget -= 1; + // We consider a request as proof that our credit grant has + // reached the peer. Usually, an ACK has already been + // received. + self.credit_messages.remove(&peer); + } + if let Some(rq) = request.into_parts().1 { + RequestResponseMessage::Request { request_id, request: rq, channel } + } else { + log::error! { "{:08x}: missing data for request {} from peer {}", + self.id, + request_id, + peer + } + continue + } + } + | ty => { + log::trace! { + "{:08x}: unknown message type: {:?} from {}; expected request or ack", + self.id, + ty, + peer + }; + continue + } + } + }; + let event = RequestResponseEvent::Message { peer, message }; + NetworkBehaviourAction::GenerateEvent(Event::Event(event)) } - // A new inbound request. - | RequestResponseHandlerEvent::Request {..} => - if let Some(limit) = self.limits.get_mut(&p) { - if limit.recv_budget == 0 { - log::error!("{:08x}: peer {} exceeds its budget", self.id, p); - return self.events.push_back(Event::TooManyInboundRequests(p)) + | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::OutboundFailure { + peer, + request_id, + error + }) => { + if let Some(credit) = self.credit_messages.get_mut(&peer) { + if credit.request == request_id { + log::debug! { "{:08x}: failed to send {} as credit {} to {}; retrying...", + self.id, + credit.amount, + credit.id, + peer + }; + let msg = Message::credit(credit.amount, credit.id); + credit.request = self.behaviour.send_request(&peer, msg) + } } - limit.recv_budget -= 1 + let event = RequestResponseEvent::OutboundFailure { peer, request_id, error }; + NetworkBehaviourAction::GenerateEvent(Event::Event(event)) } - // The inbound request has expired so grant more budget to receive another one. - | RequestResponseHandlerEvent::InboundTimeout => - if let Some(limit) = self.limits.get_mut(&p) { - limit.recv_budget = min(limit.recv_budget + 1, limit.maximum.get()) + | NetworkBehaviourAction::GenerateEvent(RequestResponseEvent::InboundFailure { + peer, + request_id, + error + }) => { + let event = RequestResponseEvent::InboundFailure { peer, request_id, error }; + NetworkBehaviourAction::GenerateEvent(Event::Event(event)) } - // Nothing to do here ... - | RequestResponseHandlerEvent::InboundUnsupportedProtocols => {} - } - self.behaviour.inject_event(p, i, e) - } + | NetworkBehaviourAction::DialAddress { address } => + NetworkBehaviourAction::DialAddress { address }, + | NetworkBehaviourAction::DialPeer { peer_id, condition } => + NetworkBehaviourAction::DialPeer { peer_id, condition }, + | NetworkBehaviourAction::NotifyHandler { peer_id, handler, event } => + NetworkBehaviourAction::NotifyHandler { peer_id, handler, event }, + | NetworkBehaviourAction::ReportObservedAddr { address } => + NetworkBehaviourAction::ReportObservedAddr { address } + }; - fn poll(&mut self, cx: &mut Context<'_>, p: &mut impl PollParameters) - -> Poll, Self::OutEvent>> - { - if let Some(ev) = self.events.pop_front() { - return Poll::Ready(NetworkBehaviourAction::GenerateEvent(ev)) - } else if self.events.capacity() > super::EMPTY_QUEUE_SHRINK_THRESHOLD { - self.events.shrink_to_fit() + return Poll::Ready(event) } - - self.behaviour.poll(cx, p).map(|a| a.map_out(Event::Event)) } } diff --git a/protocols/request-response/src/throttled/codec.rs b/protocols/request-response/src/throttled/codec.rs new file mode 100644 index 00000000000..580fdd3da85 --- /dev/null +++ b/protocols/request-response/src/throttled/codec.rs @@ -0,0 +1,251 @@ +// Copyright 2020 Parity Technologies (UK) Ltd. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal in the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +// DEALINGS IN THE SOFTWARE. + +use async_trait::async_trait; +use bytes::{Bytes, BytesMut}; +use futures::prelude::*; +use libp2p_core::ProtocolName; +use minicbor::{Encode, Decode}; +use std::io; +use super::RequestResponseCodec; +use unsigned_varint::{aio, io::ReadError}; + +/// A protocol header. +#[derive(Debug, Default, Clone, PartialEq, Eq, Encode, Decode)] +#[cbor(map)] +pub struct Header { + /// The type of message. + #[n(0)] pub typ: Option, + /// The number of additional requests the remote is willing to receive. + #[n(1)] pub credit: Option, + /// An identifier used for sending credit grants. + #[n(2)] pub ident: Option +} + +/// A protocol message type. +#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)] +pub enum Type { + #[n(0)] Request, + #[n(1)] Response, + #[n(2)] Credit, + #[n(3)] Ack +} + +/// A protocol message consisting of header and data. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Message { + header: Header, + data: Option +} + +impl Message { + /// Create a new message of some type. + fn new(header: Header) -> Self { + Message { header, data: None } + } + + /// Create a request message. + pub fn request(data: T) -> Self { + let mut m = Message::new(Header { typ: Some(Type::Request), .. Header::default() }); + m.data = Some(data); + m + } + + /// Create a response message. + pub fn response(data: T) -> Self { + let mut m = Message::new(Header { typ: Some(Type::Response), .. Header::default() }); + m.data = Some(data); + m + } + + /// Create a credit grant. + pub fn credit(credit: u16, ident: u64) -> Self { + Message::new(Header { typ: Some(Type::Credit), credit: Some(credit), ident: Some(ident) }) + } + + /// Create an acknowledge message. + pub fn ack(ident: u64) -> Self { + Message::new(Header { typ: Some(Type::Ack), credit: None, ident: Some(ident) }) + } + + /// Access the message header. + pub fn header(&self) -> &Header { + &self.header + } + + /// Access the message data. + pub fn data(&self) -> Option<&T> { + self.data.as_ref() + } + + /// Consume this message and return header and data. + pub fn into_parts(self) -> (Header, Option) { + (self.header, self.data) + } +} + +/// A wrapper around a `ProtocolName` impl which augments the protocol name. +/// +/// The type implements `ProtocolName` itself and creates a name for a +/// request-response protocol based on the protocol name of the wrapped type. +#[derive(Debug, Clone)] +pub struct ProtocolWrapper

(P, Bytes); + +impl ProtocolWrapper

{ + pub fn new(prefix: &[u8], p: P) -> Self { + let mut full = BytesMut::from(prefix); + full.extend_from_slice(p.protocol_name()); + ProtocolWrapper(p, full.freeze()) + } +} + +impl

ProtocolName for ProtocolWrapper

{ + fn protocol_name(&self) -> &[u8] { + self.1.as_ref() + } +} + +/// A `RequestResponseCodec` wrapper that adds headers to the payload data. +#[derive(Debug, Clone)] +pub struct Codec { + /// The wrapped codec. + inner: C, + /// Encoding/decoding buffer. + buffer: Vec, + /// Max. header length. + max_header_len: u32 +} + +impl Codec { + /// Create a codec by wrapping an existing one. + pub fn new(c: C, max_header_len: u32) -> Self { + Codec { inner: c, buffer: Vec::new(), max_header_len } + } + + /// Read and decode a request header. + async fn read_header(&mut self, io: &mut T) -> io::Result + where + T: AsyncRead + Unpin + Send, + H: for<'a> minicbor::Decode<'a> + { + let header_len = aio::read_u32(&mut *io).await + .map_err(|e| match e { + ReadError::Io(e) => e, + other => io::Error::new(io::ErrorKind::Other, other) + })?; + if header_len > self.max_header_len { + return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large to read")) + } + self.buffer.resize(u32_to_usize(header_len), 0u8); + io.read_exact(&mut self.buffer).await?; + minicbor::decode(&self.buffer).map_err(|e| io::Error::new(io::ErrorKind::Other, e)) + } + + /// Encode and write a response header. + async fn write_header(&mut self, hdr: &H, io: &mut T) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send, + H: minicbor::Encode + { + self.buffer.clear(); + minicbor::encode(hdr, &mut self.buffer).map_err(|e| io::Error::new(io::ErrorKind::Other, e))?; + if self.buffer.len() > u32_to_usize(self.max_header_len) { + return Err(io::Error::new(io::ErrorKind::InvalidData, "header too large to write")) + } + let mut b = unsigned_varint::encode::u32_buffer(); + let header_len = unsigned_varint::encode::u32(self.buffer.len() as u32, &mut b); + io.write_all(header_len).await?; + io.write_all(&self.buffer).await + } +} + +#[async_trait] +impl RequestResponseCodec for Codec +where + C: RequestResponseCodec + Send, + C::Protocol: Sync +{ + type Protocol = ProtocolWrapper; + type Request = Message; + type Response = Message; + + async fn read_request(&mut self, p: &Self::Protocol, io: &mut T) -> io::Result + where + T: AsyncRead + Unpin + Send + { + let mut msg = Message::new(self.read_header(io).await?); + match msg.header.typ { + Some(Type::Request) => { + msg.data = Some(self.inner.read_request(&p.0, io).await?); + Ok(msg) + } + Some(Type::Credit) => Ok(msg), + Some(Type::Response) | Some(Type::Ack) | None => { + log::debug!("unexpected {:?} when expecting request or credit grant", msg.header.typ); + Err(io::ErrorKind::InvalidData.into()) + } + } + } + + async fn read_response(&mut self, p: &Self::Protocol, io: &mut T) -> io::Result + where + T: AsyncRead + Unpin + Send + { + let mut msg = Message::new(self.read_header(io).await?); + match msg.header.typ { + Some(Type::Response) => { + msg.data = Some(self.inner.read_response(&p.0, io).await?); + Ok(msg) + } + Some(Type::Ack) => Ok(msg), + Some(Type::Request) | Some(Type::Credit) | None => { + log::debug!("unexpected {:?} when expecting response or ack", msg.header.typ); + Err(io::ErrorKind::InvalidData.into()) + } + } + } + + async fn write_request(&mut self, p: &Self::Protocol, io: &mut T, r: Self::Request) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send + { + self.write_header(&r.header, io).await?; + if let Some(data) = r.data { + self.inner.write_request(&p.0, io, data).await? + } + Ok(()) + } + + async fn write_response(&mut self, p: &Self::Protocol, io: &mut T, r: Self::Response) -> io::Result<()> + where + T: AsyncWrite + Unpin + Send + { + self.write_header(&r.header, io).await?; + if let Some(data) = r.data { + self.inner.write_response(&p.0, io, data).await? + } + Ok(()) + } +} + +#[cfg(any(target_pointer_width = "64", target_pointer_width = "32"))] +fn u32_to_usize(n: u32) -> usize { + n as usize +} diff --git a/protocols/request-response/tests/ping.rs b/protocols/request-response/tests/ping.rs index 11a8601d5f2..6735fd2e8e3 100644 --- a/protocols/request-response/tests/ping.rs +++ b/protocols/request-response/tests/ping.rs @@ -36,7 +36,7 @@ use libp2p_tcp::TcpConfig; use futures::{prelude::*, channel::mpsc}; use rand::{self, Rng}; use std::{io, iter}; -// use std::{collections::HashSet, num::NonZeroU16}; // Disabled until #1706 is fixed. +use std::{collections::HashSet, num::NonZeroU16}; /// Exercises a simple ping protocol. #[test] @@ -73,7 +73,7 @@ fn ping_protocol() { match swarm1.next().await { RequestResponseEvent::Message { peer, - message: RequestResponseMessage::Request { request, channel } + message: RequestResponseMessage::Request { request, channel, .. } } => { assert_eq!(&request, &expected_ping); assert_eq!(&peer, &peer2_id); @@ -117,202 +117,101 @@ fn ping_protocol() { let () = async_std::task::block_on(peer2); } -// Disabled until #1706 is fixed. -///// Like `ping_protocol`, but throttling concurrent requests. -//#[test] -//fn ping_protocol_throttled() { -// let ping = Ping("ping".to_string().into_bytes()); -// let pong = Pong("pong".to_string().into_bytes()); -// -// let protocols = iter::once((PingProtocol(), ProtocolSupport::Full)); -// let cfg = RequestResponseConfig::default(); -// -// let (peer1_id, trans) = mk_transport(); -// let ping_proto1 = RequestResponse::new(PingCodec(), protocols.clone(), cfg.clone()).throttled(); -// let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id.clone()); -// -// let (peer2_id, trans) = mk_transport(); -// let ping_proto2 = RequestResponse::new(PingCodec(), protocols, cfg).throttled(); -// let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id.clone()); -// -// let (mut tx, mut rx) = mpsc::channel::(1); -// -// let addr = "/ip4/127.0.0.1/tcp/0".parse().unwrap(); -// Swarm::listen_on(&mut swarm1, addr).unwrap(); -// -// let expected_ping = ping.clone(); -// let expected_pong = pong.clone(); -// -// let limit: u16 = rand::thread_rng().gen_range(1, 10); -// swarm1.set_default_limit(NonZeroU16::new(limit).unwrap()); -// swarm2.set_default_limit(NonZeroU16::new(limit).unwrap()); -// -// let peer1 = async move { -// while let Some(_) = swarm1.next().now_or_never() {} -// -// let l = Swarm::listeners(&swarm1).next().unwrap(); -// tx.send(l.clone()).await.unwrap(); -// -// loop { -// match swarm1.next().await { -// throttled::Event::Event(RequestResponseEvent::Message { -// peer, -// message: RequestResponseMessage::Request { request, channel } -// }) => { -// assert_eq!(&request, &expected_ping); -// assert_eq!(&peer, &peer2_id); -// swarm1.send_response(channel, pong.clone()); -// }, -// e => panic!("Peer1: Unexpected event: {:?}", e) -// } -// } -// }; -// -// let num_pings: u8 = rand::thread_rng().gen_range(1, 100); -// -// let peer2 = async move { -// let mut count = 0; -// let addr = rx.next().await.unwrap(); -// swarm2.add_address(&peer1_id, addr.clone()); -// let mut blocked = false; -// let mut req_ids = HashSet::new(); -// -// loop { -// if !blocked { -// while let Some(id) = swarm2.send_request(&peer1_id, ping.clone()).ok() { -// req_ids.insert(id); -// } -// blocked = true; -// } -// match swarm2.next().await { -// throttled::Event::ResumeSending(peer) => { -// assert_eq!(peer, peer1_id); -// blocked = false -// } -// throttled::Event::Event(RequestResponseEvent::Message { -// peer, -// message: RequestResponseMessage::Response { request_id, response } -// }) => { -// count += 1; -// assert_eq!(&response, &expected_pong); -// assert_eq!(&peer, &peer1_id); -// assert!(req_ids.remove(&request_id)); -// if count >= num_pings { -// break -// } -// } -// e => panic!("Peer2: Unexpected event: {:?}", e) -// } -// } -// }; -// -// async_std::task::spawn(Box::pin(peer1)); -// let () = async_std::task::block_on(peer2); -//} -// -//#[test] -//fn ping_protocol_limit_violation() { -// let ping = Ping("ping".to_string().into_bytes()); -// let pong = Pong("pong".to_string().into_bytes()); -// -// let protocols = iter::once((PingProtocol(), ProtocolSupport::Full)); -// let cfg = RequestResponseConfig::default(); -// -// let (peer1_id, trans) = mk_transport(); -// let ping_proto1 = RequestResponse::new(PingCodec(), protocols.clone(), cfg.clone()).throttled(); -// let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id.clone()); -// -// let (peer2_id, trans) = mk_transport(); -// let ping_proto2 = RequestResponse::new(PingCodec(), protocols, cfg).throttled(); -// let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id.clone()); -// -// let (mut tx, mut rx) = mpsc::channel::(1); -// -// let addr = "/ip4/127.0.0.1/tcp/0".parse().unwrap(); -// Swarm::listen_on(&mut swarm1, addr).unwrap(); -// -// let expected_ping = ping.clone(); -// let expected_pong = pong.clone(); -// -// swarm2.set_default_limit(NonZeroU16::new(2).unwrap()); -// -// let peer1 = async move { -// while let Some(_) = swarm1.next().now_or_never() {} -// -// let l = Swarm::listeners(&swarm1).next().unwrap(); -// tx.send(l.clone()).await.unwrap(); -// -// let mut pending_responses = Vec::new(); -// -// loop { -// match swarm1.next().await { -// throttled::Event::Event(RequestResponseEvent::Message { -// peer, -// message: RequestResponseMessage::Request { request, channel } -// }) => { -// assert_eq!(&request, &expected_ping); -// assert_eq!(&peer, &peer2_id); -// pending_responses.push((channel, pong.clone())); -// }, -// throttled::Event::TooManyInboundRequests(p) => { -// assert_eq!(p, peer2_id); -// break -// } -// e => panic!("Peer1: Unexpected event: {:?}", e) -// } -// if pending_responses.len() >= 2 { -// for (channel, pong) in pending_responses.drain(..) { -// swarm1.send_response(channel, pong) -// } -// } -// } -// }; -// -// let num_pings: u8 = rand::thread_rng().gen_range(1, 100); -// -// let peer2 = async move { -// let mut count = 0; -// let addr = rx.next().await.unwrap(); -// swarm2.add_address(&peer1_id, addr.clone()); -// let mut blocked = false; -// let mut req_ids = HashSet::new(); -// -// loop { -// if !blocked { -// while let Some(id) = swarm2.send_request(&peer1_id, ping.clone()).ok() { -// req_ids.insert(id); -// } -// blocked = true; -// } -// match swarm2.next().await { -// throttled::Event::ResumeSending(peer) => { -// assert_eq!(peer, peer1_id); -// blocked = false -// } -// throttled::Event::Event(RequestResponseEvent::Message { -// peer, -// message: RequestResponseMessage::Response { request_id, response } -// }) => { -// count += 1; -// assert_eq!(&response, &expected_pong); -// assert_eq!(&peer, &peer1_id); -// assert!(req_ids.remove(&request_id)); -// if count >= num_pings { -// break -// } -// } -// throttled::Event::Event(RequestResponseEvent::OutboundFailure { error, .. }) => { -// assert!(matches!(error, OutboundFailure::ConnectionClosed)); -// break -// } -// e => panic!("Peer2: Unexpected event: {:?}", e) -// } -// } -// }; -// -// async_std::task::spawn(Box::pin(peer1)); -// let () = async_std::task::block_on(peer2); -//} +#[test] +fn ping_protocol_throttled() { + let ping = Ping("ping".to_string().into_bytes()); + let pong = Pong("pong".to_string().into_bytes()); + + let protocols = iter::once((PingProtocol(), ProtocolSupport::Full)); + let cfg = RequestResponseConfig::default(); + + let (peer1_id, trans) = mk_transport(); + let ping_proto1 = RequestResponse::throttled(PingCodec(), protocols.clone(), cfg.clone()); + let mut swarm1 = Swarm::new(trans, ping_proto1, peer1_id.clone()); + + let (peer2_id, trans) = mk_transport(); + let ping_proto2 = RequestResponse::throttled(PingCodec(), protocols, cfg); + let mut swarm2 = Swarm::new(trans, ping_proto2, peer2_id.clone()); + + let (mut tx, mut rx) = mpsc::channel::(1); + + let addr = "/ip4/127.0.0.1/tcp/0".parse().unwrap(); + Swarm::listen_on(&mut swarm1, addr).unwrap(); + + let expected_ping = ping.clone(); + let expected_pong = pong.clone(); + + let limit1: u16 = rand::thread_rng().gen_range(1, 10); + let limit2: u16 = rand::thread_rng().gen_range(1, 10); + swarm1.set_receive_limit(NonZeroU16::new(limit1).unwrap()); + swarm2.set_receive_limit(NonZeroU16::new(limit2).unwrap()); + + let peer1 = async move { + while let Some(_) = swarm1.next().now_or_never() {} + + let l = Swarm::listeners(&swarm1).next().unwrap(); + tx.send(l.clone()).await.unwrap(); + for i in 1.. { + match swarm1.next().await { + throttled::Event::Event(RequestResponseEvent::Message { + peer, + message: RequestResponseMessage::Request { request, channel, .. }, + }) => { + assert_eq!(&request, &expected_ping); + assert_eq!(&peer, &peer2_id); + swarm1.send_response(channel, pong.clone()); + }, + e => panic!("Peer1: Unexpected event: {:?}", e) + } + if i % 31 == 0 { + let lim = rand::thread_rng().gen_range(1, 17); + swarm1.override_receive_limit(&peer2_id, NonZeroU16::new(lim).unwrap()); + } + } + }; + + let num_pings: u16 = rand::thread_rng().gen_range(100, 1000); + + let peer2 = async move { + let mut count = 0; + let addr = rx.next().await.unwrap(); + swarm2.add_address(&peer1_id, addr.clone()); + + let mut blocked = false; + let mut req_ids = HashSet::new(); + + loop { + if !blocked { + while let Some(id) = swarm2.send_request(&peer1_id, ping.clone()).ok() { + req_ids.insert(id); + } + blocked = true; + } + match swarm2.next().await { + throttled::Event::ResumeSending(peer) => { + assert_eq!(peer, peer1_id); + blocked = false + } + throttled::Event::Event(RequestResponseEvent::Message { + peer, + message: RequestResponseMessage::Response { request_id, response } + }) => { + count += 1; + assert_eq!(&response, &expected_pong); + assert_eq!(&peer, &peer1_id); + assert!(req_ids.remove(&request_id)); + if count >= num_pings { + break + } + } + e => panic!("Peer2: Unexpected event: {:?}", e) + } + } + }; + + async_std::task::spawn(Box::pin(peer1)); + let () = async_std::task::block_on(peer2); +} fn mk_transport() -> (PeerId, Boxed<(PeerId, StreamMuxerBox), io::Error>) { let id_keys = identity::Keypair::generate_ed25519();