diff --git a/http-server/src/lib.rs b/http-server/src/lib.rs index 155451d7e1..a19cd64598 100644 --- a/http-server/src/lib.rs +++ b/http-server/src/lib.rs @@ -42,7 +42,7 @@ pub use access_control::{ }; pub use jsonrpsee_types::{Error, TEN_MB_SIZE_BYTES}; pub use jsonrpsee_utils::server::rpc_module::RpcModule; -pub use server::{Builder as HttpServerBuilder, Server as HttpServer}; +pub use server::{Builder as HttpServerBuilder, Server as HttpServer, StopHandle as HttpStopHandle}; #[cfg(test)] mod tests; diff --git a/http-server/src/server.rs b/http-server/src/server.rs index 9d485db331..bd26b6bc31 100644 --- a/http-server/src/server.rs +++ b/http-server/src/server.rs @@ -26,7 +26,7 @@ use crate::{response, AccessControl, TEN_MB_SIZE_BYTES}; use futures_channel::mpsc; -use futures_util::stream::StreamExt; +use futures_util::{lock::Mutex, stream::StreamExt, SinkExt}; use hyper::{ server::{conn::AddrIncoming, Builder as HyperBuilder}, service::{make_service_fn, service_fn}, @@ -95,12 +95,16 @@ impl Builder { let local_addr = listener.local_addr().ok(); let listener = hyper::Server::from_tcp(listener)?; + + let stop_pair = mpsc::channel(1); Ok(Server { listener, local_addr, methods: Methods::default(), access_control: self.access_control, max_request_body_size: self.max_request_body_size, + stop_pair, + stop_handle: Arc::new(Mutex::new(())), }) } } @@ -111,6 +115,25 @@ impl Default for Builder { } } +/// Handle used to stop the running server. +#[derive(Debug, Clone)] +pub struct StopHandle { + stop_sender: mpsc::Sender<()>, + stop_handle: Arc>, +} + +impl StopHandle { + /// Requests server to stop. Returns an error if server was already stopped. + pub async fn stop(&mut self) -> Result<(), Error> { + self.stop_sender.send(()).await.map_err(|_| Error::AlreadyStopped) + } + + /// Blocks indefinitely until the server is stopped. + pub async fn wait_for_stop(&self) { + self.stop_handle.lock().await; + } +} + /// An HTTP JSON RPC server. #[derive(Debug)] pub struct Server { @@ -124,6 +147,10 @@ pub struct Server { max_request_body_size: u32, /// Access control access_control: AccessControl, + /// Pair of channels to stop the server. + stop_pair: (mpsc::Sender<()>, mpsc::Receiver<()>), + /// Stop handle that indicates whether server has been stopped. + stop_handle: Arc>, } impl Server { @@ -145,11 +172,21 @@ impl Server { self.local_addr.ok_or_else(|| Error::Custom("Local address not found".into())) } + /// Returns the handle to stop the running server. + pub fn stop_handle(&self) -> StopHandle { + StopHandle { stop_sender: self.stop_pair.0.clone(), stop_handle: self.stop_handle.clone() } + } + /// Start the server. pub async fn start(self) -> Result<(), Error> { + // Lock the stop mutex so existing stop handles can wait for server to stop. + // It will be unlocked once this function returns. + let _stop_handle = self.stop_handle.lock().await; + let methods = Arc::new(self.methods); let max_request_body_size = self.max_request_body_size; let access_control = self.access_control; + let mut stop_receiver = self.stop_pair.1; let make_service = make_service_fn(move |_| { let methods = methods.clone(); @@ -236,7 +273,12 @@ impl Server { }); let server = self.listener.serve(make_service); - server.await.map_err(Into::into) + server + .with_graceful_shutdown(async move { + stop_receiver.next().await; + }) + .await + .map_err(Into::into) } } diff --git a/http-server/src/tests.rs b/http-server/src/tests.rs index dd3d520f13..14532e04b1 100644 --- a/http-server/src/tests.rs +++ b/http-server/src/tests.rs @@ -2,15 +2,20 @@ use std::net::SocketAddr; -use crate::{HttpServerBuilder, RpcModule}; +use crate::{server::StopHandle, HttpServerBuilder, RpcModule}; use futures_util::FutureExt; use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::types::{Id, StatusCode, TestContext}; use jsonrpsee_test_utils::TimeoutFutureExt; use jsonrpsee_types::error::{CallError, Error}; use serde_json::Value as JsonValue; +use tokio::task::JoinHandle; async fn server() -> SocketAddr { + server_with_handles().await.0 +} + +async fn server_with_handles() -> (SocketAddr, JoinHandle>, StopHandle) { let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap(); let ctx = TestContext; let mut module = RpcModule::new(ctx); @@ -56,8 +61,9 @@ async fn server() -> SocketAddr { .unwrap(); server.register_module(module).unwrap(); - tokio::spawn(async move { server.start().with_default_timeout().await.unwrap() }); - addr + let stop_handle = server.stop_handle(); + let join_handle = tokio::spawn(async move { server.start().with_default_timeout().await.unwrap() }); + (addr, join_handle, stop_handle) } #[tokio::test] @@ -308,3 +314,23 @@ async fn can_register_modules() { assert_eq!(err.to_string(), expected_err.to_string()); assert_eq!(server.method_names().len(), 2); } + +#[tokio::test] +async fn stop_works() { + let _ = env_logger::try_init(); + let (_addr, join_handle, mut stop_handle) = server_with_handles().with_default_timeout().await.unwrap(); + stop_handle.stop().with_default_timeout().await.unwrap().unwrap(); + stop_handle.wait_for_stop().with_default_timeout().await.unwrap(); + + // After that we should be able to wait for task handle to finish. + // First `unwrap` is timeout, second is `JoinHandle`'s one, third is the server future result. + join_handle + .with_default_timeout() + .await + .expect("Timeout") + .expect("Join error") + .expect("Server stopped with an error"); + + // After server was stopped, attempt to stop it again should result in an error. + assert!(matches!(stop_handle.stop().with_default_timeout().await.unwrap(), Err(Error::AlreadyStopped))); +} diff --git a/types/src/error.rs b/types/src/error.rs index 84a699971e..5fe57efadf 100644 --- a/types/src/error.rs +++ b/types/src/error.rs @@ -81,6 +81,9 @@ pub enum Error { /// Configured max number of request slots exceeded. #[error("Configured max number of request slots exceeded")] MaxSlotsExceeded, + /// Attempted to stop server that is already stopped. + #[error("Attempted to stop server that is already stopped")] + AlreadyStopped, /// List passed into `set_allowed_origins` was empty #[error("Must set at least one allowed origin")] EmptyAllowedOrigins, diff --git a/ws-server/src/lib.rs b/ws-server/src/lib.rs index 3f6bef39c8..1f8d6d45e1 100644 --- a/ws-server/src/lib.rs +++ b/ws-server/src/lib.rs @@ -39,4 +39,4 @@ mod tests; pub use jsonrpsee_types::error::Error; pub use jsonrpsee_utils::server::rpc_module::{RpcModule, SubscriptionSink}; -pub use server::{Builder as WsServerBuilder, Server as WsServer}; +pub use server::{Builder as WsServerBuilder, Server as WsServer, StopHandle as WsStopHandle}; diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index 21bcb8a8cd..6a9bb99455 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -25,12 +25,18 @@ // DEALINGS IN THE SOFTWARE. use futures_channel::mpsc; -use futures_util::io::{BufReader, BufWriter}; use futures_util::stream::StreamExt; +use futures_util::{ + io::{BufReader, BufWriter}, + SinkExt, +}; use jsonrpsee_types::TEN_MB_SIZE_BYTES; use soketto::handshake::{server::Response, Server as SokettoServer}; use std::{net::SocketAddr, sync::Arc}; -use tokio::net::{TcpListener, ToSocketAddrs}; +use tokio::{ + net::{TcpListener, ToSocketAddrs}, + sync::Mutex, +}; use tokio_stream::wrappers::TcpListenerStream; use tokio_util::compat::TokioAsyncReadCompatExt; @@ -50,6 +56,10 @@ pub struct Server { methods: Methods, listener: TcpListener, cfg: Settings, + /// Pair of channels to stop the server. + stop_pair: (mpsc::Sender<()>, mpsc::Receiver<()>), + /// Stop handle that indicates whether server has been stopped. + stop_handle: Arc>, } impl Server { @@ -71,35 +81,57 @@ impl Server { self.listener.local_addr().map_err(Into::into) } + /// Returns the handle to stop the running server. + pub fn stop_handle(&self) -> StopHandle { + StopHandle { stop_sender: self.stop_pair.0.clone(), stop_handle: self.stop_handle.clone() } + } + /// Start responding to connections requests. This will block current thread until the server is stopped. pub async fn start(self) { - let mut incoming = TcpListenerStream::new(self.listener); + // Lock the stop mutex so existing stop handles can wait for server to stop. + // It will be unlocked once this function returns. + let _stop_handle = self.stop_handle.lock().await; + + let mut incoming = TcpListenerStream::new(self.listener).fuse(); let methods = self.methods; let conn_counter = Arc::new(()); let mut id = 0; - - while let Some(socket) = incoming.next().await { - if let Ok(socket) = socket { - if let Err(e) = socket.set_nodelay(true) { - log::error!("Could not set NODELAY on socket: {:?}", e); - continue; - } - - if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize { - log::warn!("Too many connections. Try again in a while"); - continue; - } - let methods = methods.clone(); - let counter = conn_counter.clone(); - let cfg = self.cfg.clone(); - - tokio::spawn(async move { - let r = background_task(socket, id, methods, cfg).await; - drop(counter); - r - }); - - id += 1; + let mut stop_receiver = self.stop_pair.1; + + loop { + futures_util::select! { + socket = incoming.next() => { + if let Some(Ok(socket)) = socket { + if let Err(e) = socket.set_nodelay(true) { + log::error!("Could not set NODELAY on socket: {:?}", e); + continue; + } + + if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize { + log::warn!("Too many connections. Try again in a while"); + continue; + } + let methods = methods.clone(); + let counter = conn_counter.clone(); + let cfg = self.cfg.clone(); + + tokio::spawn(async move { + let r = background_task(socket, id, methods, cfg).await; + drop(counter); + r + }); + + id += 1; + } else { + break; + } + }, + stop = stop_receiver.next() => { + if stop.is_some() { + break; + } + }, + complete => break, } } } @@ -296,7 +328,14 @@ impl Builder { /// Finalize the configuration of the server. Consumes the [`Builder`]. pub async fn build(self, addr: impl ToSocketAddrs) -> Result { let listener = TcpListener::bind(addr).await?; - Ok(Server { listener, methods: Methods::default(), cfg: self.settings }) + let stop_pair = mpsc::channel(1); + Ok(Server { + listener, + methods: Methods::default(), + cfg: self.settings, + stop_pair, + stop_handle: Arc::new(Mutex::new(())), + }) } } @@ -305,3 +344,25 @@ impl Default for Builder { Self { settings: Settings::default() } } } + +/// Handle that is able to stop the running server. +#[derive(Debug, Clone)] +pub struct StopHandle { + stop_sender: mpsc::Sender<()>, + stop_handle: Arc>, +} + +impl StopHandle { + /// Requests server to stop. Returns an error if server was already stopped. + /// + /// Note: This method *does not* abort spawned futures, e.g. `tokio::spawn` handlers + /// for subscriptions. It only prevents server from accepting new connections. + pub async fn stop(&mut self) -> Result<(), Error> { + self.stop_sender.send(()).await.map_err(|_| Error::AlreadyStopped) + } + + /// Blocks indefinitely until the server is stopped. + pub async fn wait_for_stop(&self) { + self.stop_handle.lock().await; + } +} diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index 897261c7ce..fdf138502b 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -1,6 +1,6 @@ #![cfg(test)] -use crate::{RpcModule, WsServerBuilder}; +use crate::{server::StopHandle, RpcModule, WsServerBuilder}; use futures_util::FutureExt; use jsonrpsee_test_utils::helpers::*; use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient}; @@ -12,6 +12,7 @@ use jsonrpsee_types::{ use serde_json::Value as JsonValue; use std::fmt; use std::net::SocketAddr; +use tokio::task::JoinHandle; /// Applications can/should provide their own error. #[derive(Debug)] @@ -26,6 +27,13 @@ impl std::error::Error for MyAppError {} /// Spawns a dummy `JSONRPC v2 WebSocket` /// It has two hardcoded methods: "say_hello" and "add" async fn server() -> SocketAddr { + server_with_handles().await.0 +} + +/// Spawns a dummy `JSONRPC v2 WebSocket` +/// It has two hardcoded methods: "say_hello" and "add" +/// Returns the address together with handles for server future and server stop. +async fn server_with_handles() -> (SocketAddr, JoinHandle<()>, StopHandle) { let mut server = WsServerBuilder::default().build("127.0.0.1:0").with_default_timeout().await.unwrap().unwrap(); let mut module = RpcModule::new(()); module @@ -64,8 +72,10 @@ async fn server() -> SocketAddr { let addr = server.local_addr().unwrap(); server.register_module(module).unwrap(); - tokio::spawn(async { server.start().await }); - addr + + let stop_handle = server.stop_handle(); + let join_handle = tokio::spawn(server.start()); + (addr, join_handle, stop_handle) } /// Run server with user provided context. @@ -114,7 +124,7 @@ async fn server_with_context() -> SocketAddr { server.register_module(rpc_module).unwrap(); let addr = server.local_addr().unwrap(); - tokio::spawn(async { server.start().await }); + tokio::spawn(server.start()); addr } @@ -305,7 +315,7 @@ async fn async_method_call_that_fails() { let req = r#"{"jsonrpc":"2.0","method":"err_async", "params":[],"id":1}"#; let response = client.send_request_text(req).await.unwrap(); - assert_eq!(response, call_execution_failed("nah".into(), Id::Num(1))); + assert_eq!(response, call_execution_failed("nah", Id::Num(1))); } #[tokio::test] @@ -442,3 +452,18 @@ async fn can_register_modules() { assert!(matches!(err, _expected_err)); assert_eq!(server.method_names().len(), 2); } + +#[tokio::test] +async fn stop_works() { + let _ = env_logger::try_init(); + let (_addr, join_handle, mut stop_handle) = server_with_handles().with_default_timeout().await.unwrap(); + stop_handle.stop().with_default_timeout().await.unwrap().unwrap(); + stop_handle.wait_for_stop().with_default_timeout().await.unwrap(); + + // After that we should be able to wait for task handle to finish. + // First `unwrap` is timeout, second is `JoinHandle`'s one. + join_handle.with_default_timeout().await.expect("Timeout").expect("Join error"); + + // After server was stopped, attempt to stop it again should result in an error. + assert!(matches!(stop_handle.stop().with_default_timeout().await.unwrap(), Err(Error::AlreadyStopped))); +}