From 5c79bf25eab08b762a16dbbf14011d350e6e4a69 Mon Sep 17 00:00:00 2001 From: Stanimal Date: Sun, 27 Jun 2021 12:36:26 +0400 Subject: [PATCH] Remove type alias impl trait unstable feature requirement from DHT Replaces the `type_alias_impl_trait` feature requirement with boxed Futures on all services. Removes/simplifies some redundant trait bounds. SAF handler task had to be implemented to be a less concurrent because of the additional trait bounds required for boxing. --- base_layer/p2p/src/initialization.rs | 25 ++-- .../service_framework/src/context/handles.rs | 4 +- comms/Cargo.toml | 2 +- comms/dht/examples/memory_net/utilities.rs | 29 ++-- comms/dht/src/dedup.rs | 13 +- comms/dht/src/dht.rs | 10 +- comms/dht/src/inbound/decryption.rs | 46 ++++--- comms/dht/src/inbound/deserialize.rs | 16 ++- .../dht/src/inbound/dht_handler/middleware.rs | 27 ++-- comms/dht/src/inbound/metrics.rs | 22 ++- comms/dht/src/inbound/validate.rs | 55 ++++---- comms/dht/src/lib.rs | 4 - comms/dht/src/logging_middleware.rs | 21 ++- comms/dht/src/outbound/broadcast.rs | 43 +++--- comms/dht/src/outbound/serialize.rs | 126 +++++++++--------- comms/dht/src/store_forward/forward.rs | 13 +- .../store_forward/saf_handler/middleware.rs | 33 ++--- .../dht/src/store_forward/saf_handler/task.rs | 124 +++++++++-------- comms/dht/src/store_forward/store.rs | 43 +++--- comms/dht/src/test_utils/mod.rs | 7 + comms/dht/tests/dht.rs | 28 ++-- comms/src/pipeline/inbound.rs | 2 +- comms/src/pipeline/outbound.rs | 2 +- comms/src/protocol/extensions.rs | 2 +- comms/src/protocol/messaging/extension.rs | 14 +- 25 files changed, 364 insertions(+), 347 deletions(-) diff --git a/base_layer/p2p/src/initialization.rs b/base_layer/p2p/src/initialization.rs index 9949e7f6e7..38a4c1bcf7 100644 --- a/base_layer/p2p/src/initialization.rs +++ b/base_layer/p2p/src/initialization.rs @@ -339,15 +339,19 @@ where .with_peer_storage(peer_database, Some(file_lock)) .build()?; + let peer_manager = comms.peer_manager(); + let connectivity = comms.connectivity(); + let node_identity = comms.node_identity(); + let shutdown_signal = comms.shutdown_signal(); // Create outbound channel let (outbound_tx, outbound_rx) = mpsc::channel(config.outbound_buffer_size); let dht = DhtBuilder::new( - comms.node_identity(), - comms.peer_manager(), + node_identity.clone(), + peer_manager, outbound_tx, - comms.connectivity(), - comms.shutdown_signal(), + connectivity, + shutdown_signal, ) .with_config(config.dht.clone()) .build() @@ -356,10 +360,7 @@ where let dht_outbound_layer = dht.outbound_middleware_layer(); // DHT RPC service is only available for communication nodes - if comms - .node_identity() - .has_peer_features(PeerFeatures::COMMUNICATION_NODE) - { + if node_identity.has_peer_features(PeerFeatures::COMMUNICATION_NODE) { comms = comms.add_rpc_server(RpcServer::new().add_service(dht.rpc_service())); } @@ -542,7 +543,9 @@ impl ServiceInitializer for P2pInitializer { let (comms, dht) = configure_comms_and_dht(builder, &config, connector).await?; let peers = Self::try_parse_seed_peers(&config.peer_seeds)?; - add_all_peers(&comms.peer_manager(), &comms.node_identity(), peers).await?; + let peer_manager = comms.peer_manager(); + let node_identity = comms.node_identity(); + add_all_peers(&peer_manager, &node_identity, peers).await?; let peers = Self::try_resolve_dns_seeds( config.dns_seeds_name_server, @@ -550,10 +553,10 @@ impl ServiceInitializer for P2pInitializer { config.dns_seeds_use_dnssec, ) .await?; - add_all_peers(&comms.peer_manager(), &comms.node_identity(), peers).await?; + add_all_peers(&peer_manager, &node_identity, peers).await?; context.register_handle(comms.connectivity()); - context.register_handle(comms.peer_manager()); + context.register_handle(peer_manager); context.register_handle(comms); context.register_handle(dht); diff --git a/base_layer/service_framework/src/context/handles.rs b/base_layer/service_framework/src/context/handles.rs index bd3cd6b30e..9661d8bd13 100644 --- a/base_layer/service_framework/src/context/handles.rs +++ b/base_layer/service_framework/src/context/handles.rs @@ -67,7 +67,7 @@ impl ServiceInitializerContext { /// Insert a service handle with the given name pub fn register_handle(&self, handle: H) - where H: Any + Send + Sync { + where H: Any + Send { self.inner.register(handle); } @@ -160,7 +160,7 @@ impl ServiceHandles { /// Register a handle pub fn register(&self, handle: H) - where H: Any + Send + Sync { + where H: Any + Send { acquire_lock!(self.handles).insert(TypeId::of::(), Box::new(handle)); } diff --git a/comms/Cargo.toml b/comms/Cargo.toml index 7bcb0fd5b2..a265c3e505 100644 --- a/comms/Cargo.toml +++ b/comms/Cargo.toml @@ -54,7 +54,7 @@ tokio-macros = "0.2.3" tempfile = "3.1.0" [build-dependencies] -tari_common = { version = "^0.8", path="../common"} +tari_common = { version = "^0.8", path="../common", features = ["build"]} [features] avx2 = ["tari_crypto/avx2"] diff --git a/comms/dht/examples/memory_net/utilities.rs b/comms/dht/examples/memory_net/utilities.rs index e098c5e05c..67ba28af6d 100644 --- a/comms/dht/examples/memory_net/utilities.rs +++ b/comms/dht/examples/memory_net/utilities.rs @@ -917,26 +917,23 @@ async fn setup_comms_dht( .unwrap(); let dht_outbound_layer = dht.outbound_middleware_layer(); + let pipeline = pipeline::Builder::new() + .outbound_buffer_size(10) + .with_outbound_pipeline(outbound_rx, |sink| { + ServiceBuilder::new().layer(dht_outbound_layer).service(sink) + }) + .max_concurrent_inbound_tasks(10) + .with_inbound_pipeline( + ServiceBuilder::new() + .layer(dht.inbound_middleware_layer()) + .service(SinkService::new(inbound_tx)), + ) + .build(); let (messaging_events_tx, _) = broadcast::channel(100); - let comms = comms .add_rpc_server(RpcServer::new().add_service(dht.rpc_service())) - .add_protocol_extension(MessagingProtocolExtension::new( - messaging_events_tx.clone(), - pipeline::Builder::new() - .outbound_buffer_size(10) - .with_outbound_pipeline(outbound_rx, |sink| { - ServiceBuilder::new().layer(dht_outbound_layer).service(sink) - }) - .max_concurrent_inbound_tasks(10) - .with_inbound_pipeline( - ServiceBuilder::new() - .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(inbound_tx)), - ) - .build(), - )) + .add_protocol_extension(MessagingProtocolExtension::new(messaging_events_tx.clone(), pipeline)) .spawn_with_transport(MemoryTransport) .await .unwrap(); diff --git a/comms/dht/src/dedup.rs b/comms/dht/src/dedup.rs index 19fd678585..278ae4a485 100644 --- a/comms/dht/src/dedup.rs +++ b/comms/dht/src/dedup.rs @@ -22,7 +22,7 @@ use crate::{actor::DhtRequester, inbound::DhtInboundMessage}; use digest::Input; -use futures::{task::Context, Future}; +use futures::{future::BoxFuture, task::Context}; use log::*; use std::task::Poll; use tari_comms::{pipeline::PipelineError, types::Challenge}; @@ -55,13 +55,14 @@ impl DedupMiddleware { } impl Service for DedupMiddleware -where S: Service + Clone +where + S: Service + Clone + Send + 'static, + S::Future: Send, { type Error = PipelineError; + type Future = BoxFuture<'static, Result>; type Response = (); - type Future = impl Future>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } @@ -69,7 +70,7 @@ where S: Service + Clon fn call(&mut self, message: DhtInboundMessage) -> Self::Future { let next_service = self.next_service.clone(); let mut dht_requester = self.dht_requester.clone(); - async move { + Box::pin(async move { let hash = hash_inbound_message(&message); trace!( target: LOG_TARGET, @@ -96,7 +97,7 @@ where S: Service + Clon message.dht_header.message_tag ); next_service.oneshot(message).await - } + }) } } diff --git a/comms/dht/src/dht.rs b/comms/dht/src/dht.rs index 1bc7d5d53c..fcd746b7b3 100644 --- a/comms/dht/src/dht.rs +++ b/comms/dht/src/dht.rs @@ -278,9 +278,8 @@ impl Dht { InboundMessage, Response = (), Error = PipelineError, - Future = impl Future> + Send, - > + Clone - + Send, + Future = impl Future>, + > + Clone, > where S: Service + Clone + Send + Sync + 'static, @@ -341,9 +340,8 @@ impl Dht { DhtOutboundRequest, Response = (), Error = PipelineError, - Future = impl Future> + Send, - > + Clone - + Send, + Future = impl Future>, + > + Clone, > where S: Service + Clone + Send + 'static, diff --git a/comms/dht/src/inbound/decryption.rs b/comms/dht/src/inbound/decryption.rs index e605227480..9e3a6bbd3e 100644 --- a/comms/dht/src/inbound/decryption.rs +++ b/comms/dht/src/inbound/decryption.rs @@ -27,7 +27,7 @@ use crate::{ proto::envelope::OriginMac, DhtConfig, }; -use futures::{task::Context, Future}; +use futures::{future::BoxFuture, task::Context}; use log::*; use prost::Message; use std::{sync::Arc, task::Poll, time::Duration}; @@ -123,25 +123,26 @@ impl DecryptionService { } impl Service for DecryptionService -where S: Service + Clone +where + S: Service + Clone + Send + 'static, + S::Future: Send, { type Error = PipelineError; + type Future = BoxFuture<'static, Result>; type Response = (); - type Future = impl Future>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, msg: DhtInboundMessage) -> Self::Future { - Self::handle_message( + Box::pin(Self::handle_message( self.inner.clone(), Arc::clone(&self.node_identity), self.connectivity.clone(), self.config.ban_duration, msg, - ) + )) } } @@ -416,10 +417,13 @@ mod test { #[test] fn decrypt_inbound_success() { - let result = Mutex::new(None); - let service = service_fn(|msg: DecryptedDhtMessage| { - *result.lock().unwrap() = Some(msg); - future::ready(Result::<(), PipelineError>::Ok(())) + let result = Arc::new(Mutex::new(None)); + let service = service_fn({ + let result = result.clone(); + move |msg: DecryptedDhtMessage| { + *result.lock().unwrap() = Some(msg); + future::ready(Result::<(), PipelineError>::Ok(())) + } }); let node_identity = make_node_identity(); let (connectivity, _) = create_connectivity_mock(); @@ -441,10 +445,13 @@ mod test { #[test] fn decrypt_inbound_fail() { - let result = Mutex::new(None); - let service = service_fn(|msg: DecryptedDhtMessage| { - *result.lock().unwrap() = Some(msg); - future::ready(Result::<(), PipelineError>::Ok(())) + let result = Arc::new(Mutex::new(None)); + let service = service_fn({ + let result = result.clone(); + move |msg: DecryptedDhtMessage| { + *result.lock().unwrap() = Some(msg); + future::ready(Result::<(), PipelineError>::Ok(())) + } }); let node_identity = make_node_identity(); let (connectivity, _) = create_connectivity_mock(); @@ -466,10 +473,13 @@ mod test { async fn decrypt_inbound_fail_destination() { let (connectivity, mock) = create_connectivity_mock(); mock.spawn(); - let result = Mutex::new(None); - let service = service_fn(|msg: DecryptedDhtMessage| { - *result.lock().unwrap() = Some(msg); - future::ready(Result::<(), PipelineError>::Ok(())) + let result = Arc::new(Mutex::new(None)); + let service = service_fn({ + let result = result.clone(); + move |msg: DecryptedDhtMessage| { + *result.lock().unwrap() = Some(msg); + future::ready(Result::<(), PipelineError>::Ok(())) + } }); let node_identity = make_node_identity(); let mut service = DecryptionService::new(Default::default(), node_identity.clone(), connectivity, service); diff --git a/comms/dht/src/inbound/deserialize.rs b/comms/dht/src/inbound/deserialize.rs index 72cf21ad22..b28a057cb5 100644 --- a/comms/dht/src/inbound/deserialize.rs +++ b/comms/dht/src/inbound/deserialize.rs @@ -21,7 +21,7 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{inbound::DhtInboundMessage, proto::envelope::DhtEnvelope}; -use futures::{task::Context, Future}; +use futures::{future::BoxFuture, task::Context}; use log::*; use prost::Message; use std::{convert::TryInto, sync::Arc, task::Poll}; @@ -51,13 +51,14 @@ impl DhtDeserializeMiddleware { } impl Service for DhtDeserializeMiddleware -where S: Service + Clone + 'static +where + S: Service + Clone + Send + 'static, + S::Future: Send, { type Error = PipelineError; + type Future = BoxFuture<'static, Result>; type Response = (); - type Future = impl Future>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } @@ -65,7 +66,7 @@ where S: Service + Clon fn call(&mut self, message: InboundMessage) -> Self::Future { let next_service = self.next_service.clone(); let peer_manager = self.peer_manager.clone(); - async move { + Box::pin(async move { trace!(target: LOG_TARGET, "Deserializing InboundMessage {}", message.tag); let InboundMessage { @@ -92,6 +93,7 @@ where S: Service + Clon inbound_msg.dht_header.message_tag ); + let next_service = next_service.ready_oneshot().await?; next_service.oneshot(inbound_msg).await }, Err(err) => { @@ -99,7 +101,7 @@ where S: Service + Clon Err(err.into()) }, } - } + }) } } @@ -127,6 +129,7 @@ mod test { use crate::{ envelope::DhtMessageFlags, test_utils::{ + assert_send_static_service, build_peer_manager, make_comms_inbound_message, make_dht_envelope, @@ -144,6 +147,7 @@ mod test { peer_manager.add_peer(node_identity.to_peer()).await.unwrap(); let mut deserialize = DeserializeLayer::new(peer_manager).layer(spy.to_service::()); + assert_send_static_service(&deserialize); let dht_envelope = make_dht_envelope( &node_identity, diff --git a/comms/dht/src/inbound/dht_handler/middleware.rs b/comms/dht/src/inbound/dht_handler/middleware.rs index 8288e36050..6accf65513 100644 --- a/comms/dht/src/inbound/dht_handler/middleware.rs +++ b/comms/dht/src/inbound/dht_handler/middleware.rs @@ -22,7 +22,7 @@ use super::task::ProcessDhtMessage; use crate::{discovery::DhtDiscoveryRequester, inbound::DecryptedDhtMessage, outbound::OutboundMessageRequester}; -use futures::{task::Context, Future}; +use futures::{future::BoxFuture, task::Context}; use std::{sync::Arc, task::Poll}; use tari_comms::{ peer_manager::{NodeIdentity, PeerManager}, @@ -60,26 +60,29 @@ impl DhtHandlerMiddleware { } impl Service for DhtHandlerMiddleware -where S: Service + Clone +where + S: Service + Clone + Send + 'static, + S::Future: Send, { type Error = PipelineError; + type Future = BoxFuture<'static, Result>; type Response = (); - type Future = impl Future>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.next_service.poll_ready(cx) } fn call(&mut self, message: DecryptedDhtMessage) -> Self::Future { - ProcessDhtMessage::new( - self.next_service.clone(), - Arc::clone(&self.peer_manager), - self.outbound_service.clone(), - Arc::clone(&self.node_identity), - self.discovery_requester.clone(), - message, + Box::pin( + ProcessDhtMessage::new( + self.next_service.clone(), + Arc::clone(&self.peer_manager), + self.outbound_service.clone(), + Arc::clone(&self.node_identity), + self.discovery_requester.clone(), + message, + ) + .run(), ) - .run() } } diff --git a/comms/dht/src/inbound/metrics.rs b/comms/dht/src/inbound/metrics.rs index 88f85ca336..d028708f73 100644 --- a/comms/dht/src/inbound/metrics.rs +++ b/comms/dht/src/inbound/metrics.rs @@ -21,11 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::connectivity::MetricsCollectorHandle; -use futures::{task::Context, Future}; +use futures::task::Context; use log::*; use std::task::Poll; -use tari_comms::{message::InboundMessage, pipeline::PipelineError}; -use tower::{layer::Layer, Service, ServiceExt}; +use tari_comms::message::InboundMessage; +use tower::{layer::Layer, Service}; const LOG_TARGET: &str = "comms::dht::metrics"; @@ -45,19 +45,17 @@ impl Metrics { } impl Service for Metrics -where S: Service + Clone + 'static +where S: Service + Clone + 'static { - type Error = PipelineError; - type Response = (); + type Error = S::Error; + type Future = S::Future; + type Response = S::Response; - type Future = impl Future>; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.next_service.poll_ready(cx) } fn call(&mut self, message: InboundMessage) -> Self::Future { - let next_service = self.next_service.clone(); if !self .metric_collector .write_metric_message_received(message.source_peer.clone()) @@ -65,7 +63,7 @@ where S: Service + Clone + debug!(target: LOG_TARGET, "Unable to write metric"); } - next_service.oneshot(message) + self.next_service.call(message) } } diff --git a/comms/dht/src/inbound/validate.rs b/comms/dht/src/inbound/validate.rs index 112d8a3179..39a1a42c97 100644 --- a/comms/dht/src/inbound/validate.rs +++ b/comms/dht/src/inbound/validate.rs @@ -21,11 +21,11 @@ // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. use crate::{inbound::DhtInboundMessage, proto::envelope::Network}; -use futures::{task::Context, Future}; +use futures::{future, future::Either, task::Context}; use log::*; use std::task::Poll; use tari_comms::pipeline::PipelineError; -use tower::{layer::Layer, Service, ServiceExt}; +use tower::{layer::Layer, util::Oneshot, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::validate"; @@ -49,41 +49,37 @@ impl ValidateMiddleware { } impl Service for ValidateMiddleware -where S: Service + Clone + 'static +where S: Service + Clone + Send + 'static { type Error = PipelineError; + type Future = Either, future::Ready>>; type Response = (); - type Future = impl Future>; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.next_service.poll_ready(cx) } fn call(&mut self, message: DhtInboundMessage) -> Self::Future { let next_service = self.next_service.clone(); let target_network = self.target_network; - async move { - if message.dht_header.network == target_network && message.dht_header.is_valid() { - trace!( - target: LOG_TARGET, - "Passing message {} to next service (Trace: {})", - message.tag, - message.dht_header.message_tag - ); - next_service.oneshot(message).await?; - } else { - debug!( - target: LOG_TARGET, - "Message is for another network (want = {:?} got = {:?}) or message header is invalid. Discarding \ - the message (Trace: {}).", - target_network, - message.dht_header.network, - message.dht_header.message_tag - ); - } - - Ok(()) + if message.dht_header.network == target_network && message.dht_header.is_valid() { + trace!( + target: LOG_TARGET, + "Passing message {} to next service (Trace: {})", + message.tag, + message.dht_header.message_tag + ); + Either::Left(next_service.oneshot(message)) + } else { + debug!( + target: LOG_TARGET, + "Message is for another network (want = {:?} got = {:?}) or message header is invalid. Discarding the \ + message (Trace: {}).", + target_network, + message.dht_header.network, + message.dht_header.message_tag + ); + Either::Right(future::ready(Ok(()))) } } } @@ -111,7 +107,7 @@ mod test { use super::*; use crate::{ envelope::DhtMessageFlags, - test_utils::{make_dht_inbound_message, make_node_identity, service_spy}, + test_utils::{assert_send_static_service, make_dht_inbound_message, make_node_identity, service_spy}, }; use tari_test_utils::panic_context; use tokio::runtime::Runtime; @@ -122,6 +118,7 @@ mod test { let spy = service_spy(); let mut validate = ValidateLayer::new(Network::LocalTest).layer(spy.to_service::()); + assert_send_static_service(&validate); panic_context!(cx); diff --git a/comms/dht/src/lib.rs b/comms/dht/src/lib.rs index 7044785833..cab2f8ab6f 100644 --- a/comms/dht/src/lib.rs +++ b/comms/dht/src/lib.rs @@ -111,10 +111,6 @@ //! ``` #![recursion_limit = "256"] -// Details: https://doc.rust-lang.org/beta/unstable-book/language-features/type-alias-impl-trait.html -#![allow(incomplete_features)] -#![feature(type_alias_impl_trait)] -#![feature(min_type_alias_impl_trait)] #[macro_use] extern crate diesel; #[macro_use] diff --git a/comms/dht/src/logging_middleware.rs b/comms/dht/src/logging_middleware.rs index 7fa2b31d0e..edcfc73980 100644 --- a/comms/dht/src/logging_middleware.rs +++ b/comms/dht/src/logging_middleware.rs @@ -20,11 +20,10 @@ // WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE // USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -use futures::{task::Context, Future, TryFutureExt}; +use futures::task::Context; use log::*; use std::{borrow::Cow, fmt::Display, marker::PhantomData, task::Poll}; -use tari_comms::pipeline::PipelineError; -use tower::{layer::Layer, Service, ServiceExt}; +use tower::{layer::Layer, Service}; const LOG_TARGET: &str = "comms::middleware::message_logging"; @@ -46,7 +45,6 @@ impl<'a, R> MessageLoggingLayer<'a, R> { impl<'a, S, R> Layer for MessageLoggingLayer<'a, R> where S: Service, - S::Error: Into + Send + Sync + 'static, R: Display, { type Service = MessageLoggingService<'a, S>; @@ -73,22 +71,19 @@ impl<'a, S> MessageLoggingService<'a, S> { impl Service for MessageLoggingService<'_, S> where - S: Service + Clone, - S::Error: Into + Send + Sync + 'static, + S: Service, R: Display, { - type Error = PipelineError; + type Error = S::Error; + type Future = S::Future; type Response = S::Response; - type Future = impl Future>; - - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) } fn call(&mut self, msg: R) -> Self::Future { trace!(target: LOG_TARGET, "{}{}", self.prefix_msg, msg); - let mut inner = self.inner.clone(); - async move { inner.ready_and().and_then(|s| s.call(msg)).await.map_err(Into::into) } + self.inner.call(msg) } } diff --git a/comms/dht/src/outbound/broadcast.rs b/comms/dht/src/outbound/broadcast.rs index 7125e12227..3df9fb2919 100644 --- a/comms/dht/src/outbound/broadcast.rs +++ b/comms/dht/src/outbound/broadcast.rs @@ -41,9 +41,9 @@ use digest::Digest; use futures::{ channel::oneshot, future, + future::BoxFuture, stream::{self, StreamExt}, task::Context, - Future, }; use log::*; use rand::rngs::OsRng; @@ -109,7 +109,7 @@ impl Layer for BroadcastLayer { /// the worker task. #[derive(Clone)] pub struct BroadcastMiddleware { - next: S, + next_service: S, dht_requester: DhtRequester, dht_discovery_requester: DhtDiscoveryRequester, node_identity: Arc, @@ -127,7 +127,7 @@ impl BroadcastMiddleware { message_validity_window: chrono::Duration, ) -> Self { Self { - next: service, + next_service: service, dht_requester, dht_discovery_requester, node_identity, @@ -138,28 +138,31 @@ impl BroadcastMiddleware { } impl Service for BroadcastMiddleware -where S: Service + Clone +where + S: Service + Clone + Send + 'static, + S::Future: Send, { type Error = PipelineError; + type Future = BoxFuture<'static, Result>; type Response = (); - type Future = impl Future>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, msg: DhtOutboundRequest) -> Self::Future { - BroadcastTask::new( - self.next.clone(), - Arc::clone(&self.node_identity), - self.dht_requester.clone(), - self.dht_discovery_requester.clone(), - self.target_network, - msg, - self.message_validity_window, + Box::pin( + BroadcastTask::new( + self.next_service.clone(), + Arc::clone(&self.node_identity), + self.dht_requester.clone(), + self.dht_discovery_requester.clone(), + self.target_network, + msg, + self.message_validity_window, + ) + .handle(), ) - .handle() } } @@ -523,7 +526,14 @@ mod test { use super::*; use crate::{ outbound::SendMessageParams, - test_utils::{create_dht_actor_mock, create_dht_discovery_mock, make_peer, service_spy, DhtDiscoveryMockState}, + test_utils::{ + assert_send_static_service, + create_dht_actor_mock, + create_dht_discovery_mock, + make_peer, + service_spy, + DhtDiscoveryMockState, + }, }; use futures::channel::oneshot; use rand::rngs::OsRng; @@ -585,6 +595,7 @@ mod test { Network::LocalTest, chrono::Duration::seconds(10800), ); + assert_send_static_service(&service); let (reply_tx, _reply_rx) = oneshot::channel(); service diff --git a/comms/dht/src/outbound/serialize.rs b/comms/dht/src/outbound/serialize.rs index 7de8dd5487..9734556f82 100644 --- a/comms/dht/src/outbound/serialize.rs +++ b/comms/dht/src/outbound/serialize.rs @@ -25,7 +25,7 @@ use crate::{ outbound::message::DhtOutboundMessage, proto::envelope::{DhtEnvelope, DhtHeader}, }; -use futures::{task::Context, Future}; +use futures::task::Context; use log::*; use std::task::Poll; use tari_comms::{ @@ -34,7 +34,7 @@ use tari_comms::{ Bytes, }; use tari_utilities::ByteArray; -use tower::{layer::Layer, Service, ServiceExt}; +use tower::{layer::Layer, util::Oneshot, Service, ServiceExt}; const LOG_TARGET: &str = "comms::dht::serialize"; @@ -50,71 +50,68 @@ impl SerializeMiddleware { } impl Service for SerializeMiddleware -where S: Service + Clone + 'static +where + S: Service + Clone + Send, + S::Future: Send, { type Error = PipelineError; + type Future = Oneshot; type Response = (); - type Future = impl Future>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, message: DhtOutboundMessage) -> Self::Future { let next_service = self.inner.clone(); - async move { - let DhtOutboundMessage { - tag, - destination_node_id, - custom_header, - body, - ephemeral_public_key, - destination, - dht_message_type, - network, - dht_flags, - origin_mac, - reply, - expires, - .. - } = message; - trace!( - target: LOG_TARGET, - "Serializing outbound message {:?} for peer `{}`", - message.tag, - destination_node_id.short_str() - ); - let dht_header = custom_header.map(DhtHeader::from).unwrap_or_else(|| DhtHeader { - version: DHT_ENVELOPE_HEADER_VERSION, - origin_mac: origin_mac.map(|b| b.to_vec()).unwrap_or_else(Vec::new), - ephemeral_public_key: ephemeral_public_key.map(|e| e.to_vec()).unwrap_or_else(Vec::new), - message_type: dht_message_type as i32, - network: network as i32, - flags: dht_flags.bits(), - destination: Some(destination.into()), - message_tag: tag.as_value(), - expires, - }); - let envelope = DhtEnvelope::new(dht_header, body); - - let body = Bytes::from(envelope.to_encoded_bytes()); - - trace!( - target: LOG_TARGET, - "Serialized outbound message {} for peer `{}`. Passing onto next service", - tag, - destination_node_id.short_str() - ); - next_service - .oneshot(OutboundMessage { - tag, - peer_node_id: destination_node_id, - reply, - body, - }) - .await - } + let DhtOutboundMessage { + tag, + destination_node_id, + custom_header, + body, + ephemeral_public_key, + destination, + dht_message_type, + network, + dht_flags, + origin_mac, + reply, + expires, + .. + } = message; + trace!( + target: LOG_TARGET, + "Serializing outbound message {:?} for peer `{}`", + message.tag, + destination_node_id.short_str() + ); + let dht_header = custom_header.map(DhtHeader::from).unwrap_or_else(|| DhtHeader { + version: DHT_ENVELOPE_HEADER_VERSION, + origin_mac: origin_mac.map(|b| b.to_vec()).unwrap_or_else(Vec::new), + ephemeral_public_key: ephemeral_public_key.map(|e| e.to_vec()).unwrap_or_else(Vec::new), + message_type: dht_message_type as i32, + network: network as i32, + flags: dht_flags.bits(), + destination: Some(destination.into()), + message_tag: tag.as_value(), + expires, + }); + let envelope = DhtEnvelope::new(dht_header, body); + + let body = Bytes::from(envelope.to_encoded_bytes()); + + trace!( + target: LOG_TARGET, + "Serialized outbound message {} for peer `{}`. Passing onto next service", + tag, + destination_node_id.short_str() + ); + next_service.oneshot(OutboundMessage { + tag, + peer_node_id: destination_node_id, + reply, + body, + }) } } @@ -138,24 +135,21 @@ impl Layer for SerializeLayer { #[cfg(test)] mod test { use super::*; - use crate::test_utils::{create_outbound_message, service_spy}; - use futures::executor::block_on; + use crate::test_utils::{assert_send_static_service, create_outbound_message, service_spy}; use prost::Message; use tari_comms::peer_manager::NodeId; - use tari_test_utils::panic_context; - #[test] - fn serialize() { + #[tokio_macros::test_basic] + async fn serialize() { let spy = service_spy(); let mut serialize = SerializeLayer.layer(spy.to_service::()); - panic_context!(cx); - - assert!(serialize.poll_ready(&mut cx).is_ready()); let body = b"A"; let msg = create_outbound_message(body); - block_on(serialize.call(msg)).unwrap(); + assert_send_static_service(&serialize); + let service = serialize.ready_and().await.unwrap(); + service.call(msg).await.unwrap(); let mut msg = spy.pop_request().unwrap(); let dht_envelope = DhtEnvelope::decode(&mut msg.body).unwrap(); assert_eq!(dht_envelope.body, b"A".to_vec()); diff --git a/comms/dht/src/store_forward/forward.rs b/comms/dht/src/store_forward/forward.rs index f3f52b7801..607dfe0fd1 100644 --- a/comms/dht/src/store_forward/forward.rs +++ b/comms/dht/src/store_forward/forward.rs @@ -26,7 +26,7 @@ use crate::{ outbound::{OutboundMessageRequester, SendMessageParams}, store_forward::error::StoreAndForwardError, }; -use futures::{task::Context, Future}; +use futures::{future::BoxFuture, task::Context}; use log::*; use std::task::Poll; use tari_comms::{peer_manager::Peer, pipeline::PipelineError}; @@ -84,13 +84,14 @@ impl ForwardMiddleware { } impl Service for ForwardMiddleware -where S: Service + Clone + 'static +where + S: Service + Clone + Send + 'static, + S::Future: Send, { type Error = PipelineError; + type Future = BoxFuture<'static, Result>; type Response = (); - type Future = impl Future>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } @@ -99,7 +100,7 @@ where S: Service + Cl let next_service = self.next_service.clone(); let outbound_service = self.outbound_service.clone(); let is_enabled = self.is_enabled; - async move { + Box::pin(async move { if !is_enabled { trace!( target: LOG_TARGET, @@ -118,7 +119,7 @@ where S: Service + Cl ); let forwarder = Forwarder::new(next_service, outbound_service); forwarder.handle(message).await - } + }) } } diff --git a/comms/dht/src/store_forward/saf_handler/middleware.rs b/comms/dht/src/store_forward/saf_handler/middleware.rs index 0d46855ff8..578fc1dcbc 100644 --- a/comms/dht/src/store_forward/saf_handler/middleware.rs +++ b/comms/dht/src/store_forward/saf_handler/middleware.rs @@ -28,7 +28,7 @@ use crate::{ outbound::OutboundMessageRequester, store_forward::StoreAndForwardRequester, }; -use futures::{channel::mpsc, task::Context, Future}; +use futures::{channel::mpsc, future::BoxFuture, task::Context}; use std::{sync::Arc, task::Poll}; use tari_comms::{ peer_manager::{NodeIdentity, PeerManager}, @@ -75,29 +75,32 @@ impl MessageHandlerMiddleware { } impl Service for MessageHandlerMiddleware -where S: Service + Clone + Sync + Send +where + S: Service + Clone + Send + 'static, + S::Future: Send, { type Error = PipelineError; + type Future = BoxFuture<'static, Result>; type Response = (); - type Future = impl Future>; - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.next_service.poll_ready(cx) } fn call(&mut self, message: DecryptedDhtMessage) -> Self::Future { - MessageHandlerTask::new( - self.config.clone(), - self.next_service.clone(), - self.saf_requester.clone(), - self.dht_requester.clone(), - Arc::clone(&self.peer_manager), - self.outbound_service.clone(), - Arc::clone(&self.node_identity), - message, - self.saf_response_signal_sender.clone(), + Box::pin( + MessageHandlerTask::new( + self.config.clone(), + self.next_service.clone(), + self.saf_requester.clone(), + self.dht_requester.clone(), + Arc::clone(&self.peer_manager), + self.outbound_service.clone(), + Arc::clone(&self.node_identity), + message, + self.saf_response_signal_sender.clone(), + ) + .run(), ) - .run() } } diff --git a/comms/dht/src/store_forward/saf_handler/task.rs b/comms/dht/src/store_forward/saf_handler/task.rs index 910eada97f..ecab6b9862 100644 --- a/comms/dht/src/store_forward/saf_handler/task.rs +++ b/comms/dht/src/store_forward/saf_handler/task.rs @@ -39,7 +39,7 @@ use crate::{ store_forward::{error::StoreAndForwardError, service::FetchStoredMessageQuery, StoreAndForwardRequester}, }; use digest::Digest; -use futures::{channel::mpsc, future, stream, Future, SinkExt, StreamExt}; +use futures::{channel::mpsc, future, stream, SinkExt, StreamExt}; use log::*; use prost::Message; use std::{convert::TryInto, sync::Arc}; @@ -267,14 +267,15 @@ where S: Service message_tag ); - let tasks = response - .messages - .into_iter() - // Map to futures which process the stored message - .map(|msg| self.process_incoming_stored_message(Arc::clone(&source_peer), msg)); + let mut results = Vec::with_capacity(response.messages.len()); + for msg in response.messages { + let result = self + .process_incoming_stored_message(Arc::clone(&source_peer), msg) + .await; + results.push(result); + } - let successful_msgs_iter = future::join_all(tasks) - .await + let successful_msgs_iter = results .into_iter() .map(|result| { match &result { @@ -352,71 +353,68 @@ where S: Service Ok(()) } - fn process_incoming_stored_message( - &self, + async fn process_incoming_stored_message( + &mut self, source_peer: Arc, message: ProtoStoredMessage, - ) -> impl Future> { - let node_identity = Arc::clone(&self.node_identity); - let peer_manager = Arc::clone(&self.peer_manager); - let config = self.config.clone(); - let mut dht_requester = self.dht_requester.clone(); - - async move { - if message.dht_header.is_none() { - return Err(StoreAndForwardError::DhtHeaderNotProvided); - } + ) -> Result { + let node_identity = &self.node_identity; + let peer_manager = &self.peer_manager; + let config = &self.config; - let dht_header: DhtMessageHeader = message - .dht_header - .expect("previously checked") - .try_into() - .map_err(StoreAndForwardError::DhtMessageError)?; + if message.dht_header.is_none() { + return Err(StoreAndForwardError::DhtHeaderNotProvided); + } - if !dht_header.is_valid() { - return Err(StoreAndForwardError::InvalidDhtHeader); - } - let message_type = dht_header.message_type; + let dht_header: DhtMessageHeader = message + .dht_header + .expect("previously checked") + .try_into() + .map_err(StoreAndForwardError::DhtMessageError)?; - if message_type.is_dht_message() { - if !message_type.is_dht_discovery() { - debug!( - target: LOG_TARGET, - "Discarding {} message from peer '{}'", - message_type, - source_peer.node_id.short_str() - ); - return Err(StoreAndForwardError::InvalidDhtMessageType); - } - if dht_header.destination.is_unknown() { - debug!( - target: LOG_TARGET, - "Discarding anonymous discovery message from peer '{}'", - source_peer.node_id.short_str() - ); - return Err(StoreAndForwardError::InvalidDhtMessageType); - } + if !dht_header.is_valid() { + return Err(StoreAndForwardError::InvalidDhtHeader); + } + let message_type = dht_header.message_type; + + if message_type.is_dht_message() { + if !message_type.is_dht_discovery() { + debug!( + target: LOG_TARGET, + "Discarding {} message from peer '{}'", + message_type, + source_peer.node_id.short_str() + ); + return Err(StoreAndForwardError::InvalidDhtMessageType); + } + if dht_header.destination.is_unknown() { + debug!( + target: LOG_TARGET, + "Discarding anonymous discovery message from peer '{}'", + source_peer.node_id.short_str() + ); + return Err(StoreAndForwardError::InvalidDhtMessageType); } + } - // Check that the destination is either undisclosed, for us or for our network region - Self::check_destination(&config, &peer_manager, &node_identity, &dht_header).await?; - // Check that the message has not already been received. - Self::check_duplicate(&mut dht_requester, &message.body).await?; + // Check that the destination is either undisclosed, for us or for our network region + Self::check_destination(&config, &peer_manager, &node_identity, &dht_header).await?; + // Check that the message has not already been received. + Self::check_duplicate(&mut self.dht_requester, &message.body).await?; - // Attempt to decrypt the message (if applicable), and deserialize it - let (authenticated_pk, decrypted_body) = - Self::authenticate_and_decrypt_if_required(&node_identity, &dht_header, &message.body)?; + // Attempt to decrypt the message (if applicable), and deserialize it + let (authenticated_pk, decrypted_body) = + Self::authenticate_and_decrypt_if_required(&node_identity, &dht_header, &message.body)?; - let mut inbound_msg = - DhtInboundMessage::new(MessageTag::new(), dht_header, Arc::clone(&source_peer), message.body); - inbound_msg.is_saf_message = true; + let mut inbound_msg = + DhtInboundMessage::new(MessageTag::new(), dht_header, Arc::clone(&source_peer), message.body); + inbound_msg.is_saf_message = true; - Ok(DecryptedDhtMessage::succeeded( - decrypted_body, - authenticated_pk, - inbound_msg, - )) - } + Ok(DecryptedDhtMessage::succeeded( + decrypted_body, + authenticated_pk, + inbound_msg, + )) } async fn check_duplicate(dht_requester: &mut DhtRequester, body: &[u8]) -> Result<(), StoreAndForwardError> { diff --git a/comms/dht/src/store_forward/store.rs b/comms/dht/src/store_forward/store.rs index 115329bf83..32144df2a7 100644 --- a/comms/dht/src/store_forward/store.rs +++ b/comms/dht/src/store_forward/store.rs @@ -31,7 +31,7 @@ use crate::{ }, DhtConfig, }; -use futures::{task::Context, Future}; +use futures::{future::BoxFuture, task::Context}; use log::*; use std::{sync::Arc, task::Poll}; use tari_comms::{ @@ -109,26 +109,29 @@ impl StoreMiddleware { } impl Service for StoreMiddleware -where S: Service + Clone + 'static +where + S: Service + Clone + Send + Sync + 'static, + S::Future: Send, { type Error = PipelineError; + type Future = BoxFuture<'static, Result>; type Response = (); - type Future = impl Future>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { Poll::Ready(Ok(())) } fn call(&mut self, msg: DecryptedDhtMessage) -> Self::Future { - StoreTask::new( - self.next_service.clone(), - self.config.clone(), - Arc::clone(&self.peer_manager), - Arc::clone(&self.node_identity), - self.saf_requester.clone(), + Box::pin( + StoreTask::new( + self.next_service.clone(), + self.config.clone(), + Arc::clone(&self.peer_manager), + Arc::clone(&self.node_identity), + self.saf_requester.clone(), + ) + .handle(msg), ) - .handle(msg) } } @@ -142,7 +145,9 @@ struct StoreTask { saf_requester: StoreAndForwardRequester, } -impl StoreTask { +impl StoreTask +where S: Service + Send + Sync +{ pub fn new( next_service: S, config: DhtConfig, @@ -159,11 +164,7 @@ impl StoreTask { saf_requester, } } -} -impl StoreTask -where S: Service -{ /// Determine if this is a message we should store for our peers and, if so, store it. /// /// The criteria for storing a message is: @@ -181,8 +182,8 @@ where S: Service message.tag, message.dht_header.message_tag ); - self.next_service.oneshot(message).await?; - return Ok(()); + let service = self.next_service.ready_oneshot().await?; + return service.oneshot(message).await; } message.set_saf_stored(false); @@ -198,9 +199,9 @@ where S: Service message.tag, message.dht_header.message_tag ); - self.next_service.oneshot(message).await?; - Ok(()) + let service = self.next_service.ready_oneshot().await?; + return service.oneshot(message).await; } async fn get_storage_priority(&self, message: &DecryptedDhtMessage) -> SafResult> { @@ -436,6 +437,7 @@ mod test { use crate::{ envelope::{DhtMessageFlags, NodeDestination}, test_utils::{ + assert_send_static_service, build_peer_manager, create_store_and_forward_mock, make_dht_inbound_message, @@ -458,6 +460,7 @@ mod test { let node_identity = make_node_identity(); let mut service = StoreLayer::new(Default::default(), peer_manager, node_identity, requester) .layer(spy.to_service::()); + assert_send_static_service(&service); let inbound_msg = make_dht_inbound_message(&make_node_identity(), b"".to_vec(), DhtMessageFlags::empty(), false); diff --git a/comms/dht/src/test_utils/mod.rs b/comms/dht/src/test_utils/mod.rs index c103b2aa9a..d3d1dc402d 100644 --- a/comms/dht/src/test_utils/mod.rs +++ b/comms/dht/src/test_utils/mod.rs @@ -51,3 +51,10 @@ pub use service::service_spy; mod store_and_forward_mock; pub use store_and_forward_mock::{create_store_and_forward_mock, StoreAndForwardMockState}; + +pub fn assert_send_static_service(_: &S) +where + S: tower::Service + Send + 'static, + S::Future: Send, +{ +} diff --git a/comms/dht/tests/dht.rs b/comms/dht/tests/dht.rs index 04807a7d49..2681f2a1ab 100644 --- a/comms/dht/tests/dht.rs +++ b/comms/dht/tests/dht.rs @@ -178,24 +178,22 @@ async fn setup_comms_dht( } let dht_outbound_layer = dht.outbound_middleware_layer(); + let pipeline = pipeline::Builder::new() + .outbound_buffer_size(10) + .with_outbound_pipeline(outbound_rx, |sink| { + ServiceBuilder::new().layer(dht_outbound_layer).service(sink) + }) + .max_concurrent_inbound_tasks(10) + .with_inbound_pipeline( + ServiceBuilder::new() + .layer(dht.inbound_middleware_layer()) + .service(SinkService::new(inbound_tx)), + ) + .build(); let (event_tx, _) = broadcast::channel(100); let comms = comms - .add_protocol_extension(MessagingProtocolExtension::new( - event_tx.clone(), - pipeline::Builder::new() - .outbound_buffer_size(10) - .with_outbound_pipeline(outbound_rx, |sink| { - ServiceBuilder::new().layer(dht_outbound_layer).service(sink) - }) - .max_concurrent_inbound_tasks(10) - .with_inbound_pipeline( - ServiceBuilder::new() - .layer(dht.inbound_middleware_layer()) - .service(SinkService::new(inbound_tx)), - ) - .build(), - )) + .add_protocol_extension(MessagingProtocolExtension::new(event_tx.clone(), pipeline)) .spawn_with_transport(MemoryTransport) .await .unwrap(); diff --git a/comms/src/pipeline/inbound.rs b/comms/src/pipeline/inbound.rs index 7cc2b5239b..c2035cf9f0 100644 --- a/comms/src/pipeline/inbound.rs +++ b/comms/src/pipeline/inbound.rs @@ -42,7 +42,7 @@ pub struct Inbound { impl Inbound where - TStream: Stream + FusedStream + Unpin + Send + 'static, + TStream: Stream + FusedStream + Unpin, TStream::Item: Send + 'static, TSvc: Service + Clone + Send + 'static, TSvc::Error: Display + Send, diff --git a/comms/src/pipeline/outbound.rs b/comms/src/pipeline/outbound.rs index 979f102489..c860166ad0 100644 --- a/comms/src/pipeline/outbound.rs +++ b/comms/src/pipeline/outbound.rs @@ -44,7 +44,7 @@ pub struct Outbound { impl Outbound where - TStream: Stream + FusedStream + Unpin + Send + 'static, + TStream: Stream + FusedStream + Unpin, TStream::Item: Send + 'static, TPipeline: Service + Clone + Send + 'static, TPipeline::Error: Display + Send, diff --git a/comms/src/protocol/extensions.rs b/comms/src/protocol/extensions.rs index e29df49af7..fdb6369e97 100644 --- a/comms/src/protocol/extensions.rs +++ b/comms/src/protocol/extensions.rs @@ -31,7 +31,7 @@ use tari_shutdown::ShutdownSignal; pub type ProtocolExtensionError = anyhow::Error; -pub trait ProtocolExtension: Send + Sync { +pub trait ProtocolExtension: Send { // TODO: The Box is easier to do for now at the cost of ProtocolExtension being less generic. fn install(self: Box, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError>; } diff --git a/comms/src/protocol/messaging/extension.rs b/comms/src/protocol/messaging/extension.rs index 692b1034e2..58b8a67248 100644 --- a/comms/src/protocol/messaging/extension.rs +++ b/comms/src/protocol/messaging/extension.rs @@ -51,13 +51,13 @@ impl MessagingProtocolExtension ProtocolExtension for MessagingProtocolExtension where - TOutPipe: Service + Clone + Send + Sync + 'static, - TOutPipe::Error: fmt::Display + Send + Sync, - TOutPipe::Future: Send + Sync + 'static, - TInPipe: Service + Clone + Send + Sync + 'static, - TInPipe::Error: fmt::Display + Send + Sync, - TInPipe::Future: Send + Sync + 'static, - TOutReq: Send + Sync + 'static, + TOutPipe: Service + Clone + Send + 'static, + TOutPipe::Error: fmt::Display + Send, + TOutPipe::Future: Send + 'static, + TInPipe: Service + Clone + Send + 'static, + TInPipe::Error: fmt::Display + Send, + TInPipe::Future: Send + 'static, + TOutReq: Send + 'static, { fn install(self: Box, context: &mut ProtocolExtensionContext) -> Result<(), ProtocolExtensionError> { let (proto_tx, proto_rx) = mpsc::channel(consts::MESSAGING_PROTOCOL_EVENTS_BUFFER_SIZE);