Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a way to stop servers #386

Merged
merged 7 commits into from
Jun 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion http-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub use access_control::{
};
pub use jsonrpsee_types::{Error, TEN_MB_SIZE_BYTES};
pub use jsonrpsee_utils::server::rpc_module::RpcModule;
pub use server::{Builder as HttpServerBuilder, Server as HttpServer};
pub use server::{Builder as HttpServerBuilder, Server as HttpServer, StopHandle as HttpStopHandle};

#[cfg(test)]
mod tests;
46 changes: 44 additions & 2 deletions http-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

use crate::{response, AccessControl, TEN_MB_SIZE_BYTES};
use futures_channel::mpsc;
use futures_util::stream::StreamExt;
use futures_util::{lock::Mutex, stream::StreamExt, SinkExt};
use hyper::{
server::{conn::AddrIncoming, Builder as HyperBuilder},
service::{make_service_fn, service_fn},
Expand Down Expand Up @@ -95,12 +95,16 @@ impl Builder {
let local_addr = listener.local_addr().ok();

let listener = hyper::Server::from_tcp(listener)?;

let stop_pair = mpsc::channel(1);
Ok(Server {
listener,
local_addr,
methods: Methods::default(),
access_control: self.access_control,
max_request_body_size: self.max_request_body_size,
stop_pair,
stop_handle: Arc::new(Mutex::new(())),
})
}
}
Expand All @@ -111,6 +115,25 @@ impl Default for Builder {
}
}

/// Handle used to stop the running server.
#[derive(Debug, Clone)]
pub struct StopHandle {
stop_sender: mpsc::Sender<()>,
stop_handle: Arc<Mutex<()>>,
}

impl StopHandle {
/// Requests server to stop. Returns an error if server was already stopped.
pub async fn stop(&mut self) -> Result<(), Error> {
self.stop_sender.send(()).await.map_err(|_| Error::AlreadyStopped)
}

/// Blocks indefinitely until the server is stopped.
pub async fn wait_for_stop(&self) {
self.stop_handle.lock().await;
}
}

/// An HTTP JSON RPC server.
#[derive(Debug)]
pub struct Server {
Expand All @@ -124,6 +147,10 @@ pub struct Server {
max_request_body_size: u32,
/// Access control
access_control: AccessControl,
/// Pair of channels to stop the server.
stop_pair: (mpsc::Sender<()>, mpsc::Receiver<()>),
/// Stop handle that indicates whether server has been stopped.
stop_handle: Arc<Mutex<()>>,
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
}

impl Server {
Expand All @@ -145,11 +172,21 @@ impl Server {
self.local_addr.ok_or_else(|| Error::Custom("Local address not found".into()))
}

/// Returns the handle to stop the running server.
pub fn stop_handle(&self) -> StopHandle {
StopHandle { stop_sender: self.stop_pair.0.clone(), stop_handle: self.stop_handle.clone() }
}

/// Start the server.
pub async fn start(self) -> Result<(), Error> {
// Lock the stop mutex so existing stop handles can wait for server to stop.
// It will be unlocked once this function returns.
let _stop_handle = self.stop_handle.lock().await;

let methods = Arc::new(self.methods);
let max_request_body_size = self.max_request_body_size;
let access_control = self.access_control;
let mut stop_receiver = self.stop_pair.1;

let make_service = make_service_fn(move |_| {
let methods = methods.clone();
Expand Down Expand Up @@ -236,7 +273,12 @@ impl Server {
});

let server = self.listener.serve(make_service);
server.await.map_err(Into::into)
server
.with_graceful_shutdown(async move {
stop_receiver.next().await;
})
.await
.map_err(Into::into)
}
}

Expand Down
32 changes: 29 additions & 3 deletions http-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@

use std::net::SocketAddr;

use crate::{HttpServerBuilder, RpcModule};
use crate::{server::StopHandle, HttpServerBuilder, RpcModule};
use futures_util::FutureExt;
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::types::{Id, StatusCode, TestContext};
use jsonrpsee_test_utils::TimeoutFutureExt;
use jsonrpsee_types::error::{CallError, Error};
use serde_json::Value as JsonValue;
use tokio::task::JoinHandle;

async fn server() -> SocketAddr {
server_with_handles().await.0
}

async fn server_with_handles() -> (SocketAddr, JoinHandle<Result<(), Error>>, StopHandle) {
let mut server = HttpServerBuilder::default().build("127.0.0.1:0".parse().unwrap()).unwrap();
let ctx = TestContext;
let mut module = RpcModule::new(ctx);
Expand Down Expand Up @@ -56,8 +61,9 @@ async fn server() -> SocketAddr {
.unwrap();

server.register_module(module).unwrap();
tokio::spawn(async move { server.start().with_default_timeout().await.unwrap() });
addr
let stop_handle = server.stop_handle();
let join_handle = tokio::spawn(async move { server.start().with_default_timeout().await.unwrap() });
(addr, join_handle, stop_handle)
}

#[tokio::test]
Expand Down Expand Up @@ -308,3 +314,23 @@ async fn can_register_modules() {
assert_eq!(err.to_string(), expected_err.to_string());
assert_eq!(server.method_names().len(), 2);
}

#[tokio::test]
async fn stop_works() {
let _ = env_logger::try_init();
let (_addr, join_handle, mut stop_handle) = server_with_handles().with_default_timeout().await.unwrap();
stop_handle.stop().with_default_timeout().await.unwrap().unwrap();
stop_handle.wait_for_stop().with_default_timeout().await.unwrap();

// After that we should be able to wait for task handle to finish.
// First `unwrap` is timeout, second is `JoinHandle`'s one, third is the server future result.
join_handle
.with_default_timeout()
.await
.expect("Timeout")
.expect("Join error")
.expect("Server stopped with an error");

// After server was stopped, attempt to stop it again should result in an error.
assert!(matches!(stop_handle.stop().with_default_timeout().await.unwrap(), Err(Error::AlreadyStopped)));
}
3 changes: 3 additions & 0 deletions types/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ pub enum Error {
/// Configured max number of request slots exceeded.
#[error("Configured max number of request slots exceeded")]
MaxSlotsExceeded,
/// Attempted to stop server that is already stopped.
#[error("Attempted to stop server that is already stopped")]
AlreadyStopped,
/// List passed into `set_allowed_origins` was empty
#[error("Must set at least one allowed origin")]
EmptyAllowedOrigins,
Expand Down
2 changes: 1 addition & 1 deletion ws-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ mod tests;

pub use jsonrpsee_types::error::Error;
pub use jsonrpsee_utils::server::rpc_module::{RpcModule, SubscriptionSink};
pub use server::{Builder as WsServerBuilder, Server as WsServer};
pub use server::{Builder as WsServerBuilder, Server as WsServer, StopHandle as WsStopHandle};
115 changes: 88 additions & 27 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,18 @@
// DEALINGS IN THE SOFTWARE.

use futures_channel::mpsc;
use futures_util::io::{BufReader, BufWriter};
use futures_util::stream::StreamExt;
use futures_util::{
io::{BufReader, BufWriter},
SinkExt,
};
use jsonrpsee_types::TEN_MB_SIZE_BYTES;
use soketto::handshake::{server::Response, Server as SokettoServer};
use std::{net::SocketAddr, sync::Arc};
use tokio::net::{TcpListener, ToSocketAddrs};
use tokio::{
net::{TcpListener, ToSocketAddrs},
sync::Mutex,
};
use tokio_stream::wrappers::TcpListenerStream;
use tokio_util::compat::TokioAsyncReadCompatExt;

Expand All @@ -50,6 +56,10 @@ pub struct Server {
methods: Methods,
listener: TcpListener,
cfg: Settings,
/// Pair of channels to stop the server.
stop_pair: (mpsc::Sender<()>, mpsc::Receiver<()>),
/// Stop handle that indicates whether server has been stopped.
stop_handle: Arc<Mutex<()>>,
}

impl Server {
Expand All @@ -71,35 +81,57 @@ impl Server {
self.listener.local_addr().map_err(Into::into)
}

/// Returns the handle to stop the running server.
pub fn stop_handle(&self) -> StopHandle {
StopHandle { stop_sender: self.stop_pair.0.clone(), stop_handle: self.stop_handle.clone() }
}

/// Start responding to connections requests. This will block current thread until the server is stopped.
pub async fn start(self) {
let mut incoming = TcpListenerStream::new(self.listener);
// Lock the stop mutex so existing stop handles can wait for server to stop.
// It will be unlocked once this function returns.
let _stop_handle = self.stop_handle.lock().await;

let mut incoming = TcpListenerStream::new(self.listener).fuse();
let methods = self.methods;
let conn_counter = Arc::new(());
let mut id = 0;

while let Some(socket) = incoming.next().await {
if let Ok(socket) = socket {
if let Err(e) = socket.set_nodelay(true) {
log::error!("Could not set NODELAY on socket: {:?}", e);
continue;
}

if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while");
continue;
}
let methods = methods.clone();
let counter = conn_counter.clone();
let cfg = self.cfg.clone();

tokio::spawn(async move {
let r = background_task(socket, id, methods, cfg).await;
drop(counter);
r
});

id += 1;
let mut stop_receiver = self.stop_pair.1;

loop {
futures_util::select! {
socket = incoming.next() => {
if let Some(Ok(socket)) = socket {
if let Err(e) = socket.set_nodelay(true) {
log::error!("Could not set NODELAY on socket: {:?}", e);
continue;
}

if Arc::strong_count(&conn_counter) > self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while");
continue;
}
let methods = methods.clone();
let counter = conn_counter.clone();
let cfg = self.cfg.clone();

tokio::spawn(async move {
let r = background_task(socket, id, methods, cfg).await;
drop(counter);
r
});

id += 1;
} else {
break;
}
},
stop = stop_receiver.next() => {
if stop.is_some() {
break;
}
},
complete => break,
}
}
}
Expand Down Expand Up @@ -296,7 +328,14 @@ impl Builder {
/// Finalize the configuration of the server. Consumes the [`Builder`].
pub async fn build(self, addr: impl ToSocketAddrs) -> Result<Server, Error> {
let listener = TcpListener::bind(addr).await?;
Ok(Server { listener, methods: Methods::default(), cfg: self.settings })
let stop_pair = mpsc::channel(1);
Ok(Server {
listener,
methods: Methods::default(),
cfg: self.settings,
stop_pair,
stop_handle: Arc::new(Mutex::new(())),
})
}
}

Expand All @@ -305,3 +344,25 @@ impl Default for Builder {
Self { settings: Settings::default() }
}
}

/// Handle that is able to stop the running server.
#[derive(Debug, Clone)]
pub struct StopHandle {
stop_sender: mpsc::Sender<()>,
stop_handle: Arc<Mutex<()>>,
}

impl StopHandle {
/// Requests server to stop. Returns an error if server was already stopped.
///
/// Note: This method *does not* abort spawned futures, e.g. `tokio::spawn` handlers
/// for subscriptions. It only prevents server from accepting new connections.
pub async fn stop(&mut self) -> Result<(), Error> {
self.stop_sender.send(()).await.map_err(|_| Error::AlreadyStopped)
}

/// Blocks indefinitely until the server is stopped.
pub async fn wait_for_stop(&self) {
self.stop_handle.lock().await;
}
}
Loading