diff --git a/tonic/src/transport/channel/endpoint.rs b/tonic/src/transport/channel/endpoint.rs index 1ea4b3b6d..5d08e22f3 100644 --- a/tonic/src/transport/channel/endpoint.rs +++ b/tonic/src/transport/channel/endpoint.rs @@ -4,16 +4,19 @@ use super::Channel; use super::ClientTlsConfig; #[cfg(feature = "tls")] use crate::transport::service::TlsConnector; -use crate::transport::Error; +use crate::transport::{service::SharedExec, Error, Executor}; use bytes::Bytes; use http::{uri::Uri, HeaderValue}; use std::{ convert::{TryFrom, TryInto}, fmt, + future::Future, + pin::Pin, str::FromStr, time::Duration, }; use tower::make::MakeConnection; +// use crate::transport::E /// Channel builder. /// @@ -37,6 +40,7 @@ pub struct Endpoint { pub(crate) http2_keep_alive_while_idle: Option, pub(crate) connect_timeout: Option, pub(crate) http2_adaptive_window: Option, + pub(crate) executor: SharedExec, } impl Endpoint { @@ -263,6 +267,17 @@ impl Endpoint { } } + /// Sets the executor used to spawn async tasks. + /// + /// Uses `tokio::spawn` by default. + pub fn executor(mut self, executor: E) -> Self + where + E: Executor + Send>>> + Send + Sync + 'static, + { + self.executor = SharedExec::new(executor); + self + } + /// Create a channel from this config. pub async fn connect(&self) -> Result { let mut http = hyper::client::connect::HttpConnector::new(); @@ -396,6 +411,7 @@ impl From for Endpoint { http2_keep_alive_while_idle: None, connect_timeout: None, http2_adaptive_window: None, + executor: SharedExec::tokio(), } } } diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index af8cfeebf..eaff9744f 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -9,8 +9,9 @@ pub use endpoint::Endpoint; #[cfg(feature = "tls")] pub use tls::ClientTlsConfig; -use super::service::{Connection, DynamicServiceStream}; +use super::service::{Connection, DynamicServiceStream, SharedExec}; use crate::body::BoxBody; +use crate::transport::Executor; use bytes::Bytes; use http::{ uri::{InvalidUri, Uri}, @@ -124,10 +125,26 @@ impl Channel { pub fn balance_channel(capacity: usize) -> (Self, Sender>) where K: Hash + Eq + Send + Clone + 'static, + { + Self::balance_channel_with_executor(capacity, SharedExec::tokio()) + } + + /// Balance a list of [`Endpoint`]'s. + /// + /// This creates a [`Channel`] that will listen to a stream of change events and will add or remove provided endpoints. + /// + /// The [`Channel`] will use the given executor to spawn async tasks. + pub fn balance_channel_with_executor( + capacity: usize, + executor: E, + ) -> (Self, Sender>) + where + K: Hash + Eq + Send + Clone + 'static, + E: Executor + Send>>> + Send + Sync + 'static, { let (tx, rx) = channel(capacity); let list = DynamicServiceStream::new(rx); - (Self::balance(list, DEFAULT_BUFFER_SIZE), tx) + (Self::balance(list, DEFAULT_BUFFER_SIZE, executor), tx) } pub(crate) fn new(connector: C, endpoint: Endpoint) -> Self @@ -138,9 +155,11 @@ impl Channel { C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); + let executor = endpoint.executor.clone(); let svc = Connection::lazy(connector, endpoint); - let svc = Buffer::new(Either::A(svc), buffer_size); + let (svc, worker) = Buffer::pair(Either::A(svc), buffer_size); + executor.execute(Box::pin(worker)); Channel { svc } } @@ -153,25 +172,29 @@ impl Channel { C::Response: AsyncRead + AsyncWrite + HyperConnection + Unpin + Send + 'static, { let buffer_size = endpoint.buffer_size.unwrap_or(DEFAULT_BUFFER_SIZE); + let executor = endpoint.executor.clone(); let svc = Connection::connect(connector, endpoint) .await .map_err(super::Error::from_source)?; - let svc = Buffer::new(Either::A(svc), buffer_size); + let (svc, worker) = Buffer::pair(Either::A(svc), buffer_size); + executor.execute(Box::pin(worker)); Ok(Channel { svc }) } - pub(crate) fn balance(discover: D, buffer_size: usize) -> Self + pub(crate) fn balance(discover: D, buffer_size: usize, executor: E) -> Self where D: Discover + Unpin + Send + 'static, D::Error: Into, D::Key: Hash + Send + Clone, + E: Executor> + Send + Sync + 'static, { let svc = Balance::new(discover); let svc = BoxService::new(svc); - let svc = Buffer::new(Either::B(svc), buffer_size); + let (svc, worker) = Buffer::pair(Either::B(svc), buffer_size); + executor.execute(Box::pin(worker)); Channel { svc } } diff --git a/tonic/src/transport/mod.rs b/tonic/src/transport/mod.rs index 789e531a6..80db95605 100644 --- a/tonic/src/transport/mod.rs +++ b/tonic/src/transport/mod.rs @@ -99,10 +99,12 @@ pub use self::error::Error; #[doc(inline)] pub use self::server::{NamedService, Server}; #[doc(inline)] -pub use self::service::TimeoutExpired; +pub use self::service::grpc_timeout::TimeoutExpired; pub use self::tls::Certificate; pub use hyper::{Body, Uri}; +pub(crate) use self::service::executor::Executor; + #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub use self::channel::ClientTlsConfig; diff --git a/tonic/src/transport/service/connection.rs b/tonic/src/transport/service/connection.rs index f321f3402..3aee2681c 100644 --- a/tonic/src/transport/service/connection.rs +++ b/tonic/src/transport/service/connection.rs @@ -39,6 +39,7 @@ impl Connection { .http2_initial_connection_window_size(endpoint.init_connection_window_size) .http2_only(true) .http2_keep_alive_interval(endpoint.http2_keep_alive_interval) + .executor(endpoint.executor.clone()) .clone(); if let Some(val) = endpoint.http2_keep_alive_timeout { diff --git a/tonic/src/transport/service/executor.rs b/tonic/src/transport/service/executor.rs new file mode 100644 index 000000000..d9706b55f --- /dev/null +++ b/tonic/src/transport/service/executor.rs @@ -0,0 +1,43 @@ +use futures_core::future::BoxFuture; +use std::{future::Future, sync::Arc}; + +pub(crate) use hyper::rt::Executor; + +#[derive(Copy, Clone)] +struct TokioExec; + +impl Executor for TokioExec +where + F: Future + Send + 'static, + F::Output: Send + 'static, +{ + fn execute(&self, fut: F) { + tokio::spawn(fut); + } +} + +#[derive(Clone)] +pub(crate) struct SharedExec { + inner: Arc> + Send + Sync + 'static>, +} + +impl SharedExec { + pub(crate) fn new(exec: E) -> Self + where + E: Executor> + Send + Sync + 'static, + { + Self { + inner: Arc::new(exec), + } + } + + pub(crate) fn tokio() -> Self { + Self::new(TokioExec) + } +} + +impl Executor> for SharedExec { + fn execute(&self, fut: BoxFuture<'static, ()>) { + self.inner.execute(fut) + } +} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index da7b46cca..355aadf09 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -2,7 +2,8 @@ mod add_origin; mod connection; mod connector; mod discover; -mod grpc_timeout; +pub(crate) mod executor; +pub(crate) mod grpc_timeout; mod io; mod reconnect; mod router; @@ -14,11 +15,11 @@ pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; pub(crate) use self::connector::connector; pub(crate) use self::discover::DynamicServiceStream; +pub(crate) use self::executor::SharedExec; pub(crate) use self::grpc_timeout::GrpcTimeout; pub(crate) use self::io::ServerIo; #[cfg(feature = "tls")] pub(crate) use self::tls::{TlsAcceptor, TlsConnector}; pub(crate) use self::user_agent::UserAgent; -pub use self::grpc_timeout::TimeoutExpired; pub use self::router::Routes;