diff --git a/test-utils/src/types.rs b/test-utils/src/types.rs index 49f98f9fa6..c1af1c0a1d 100644 --- a/test-utils/src/types.rs +++ b/test-utils/src/types.rs @@ -88,8 +88,21 @@ impl std::fmt::Debug for WebSocketTestClient { } } +#[derive(Debug)] +pub enum WebSocketTestError { + Redirect, + RejectedWithStatusCode(u16), + Soketto(SokettoError), +} + +impl From for WebSocketTestError { + fn from(err: io::Error) -> Self { + WebSocketTestError::Soketto(SokettoError::Io(err)) + } +} + impl WebSocketTestClient { - pub async fn new(url: SocketAddr) -> Result { + pub async fn new(url: SocketAddr) -> Result { let socket = TcpStream::connect(url).await?; let mut client = handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "test-client", "/"); match client.handshake().await { @@ -97,13 +110,11 @@ impl WebSocketTestClient { let (tx, rx) = client.into_builder().finish(); Ok(Self { tx, rx }) } - Ok(handshake::ServerResponse::Redirect { .. }) => { - Err(SokettoError::Io(io::Error::new(io::ErrorKind::Other, "Redirection not supported in tests"))) + Ok(handshake::ServerResponse::Redirect { .. }) => Err(WebSocketTestError::Redirect), + Ok(handshake::ServerResponse::Rejected { status_code }) => { + Err(WebSocketTestError::RejectedWithStatusCode(status_code)) } - Ok(handshake::ServerResponse::Rejected { .. }) => { - Err(SokettoError::Io(io::Error::new(io::ErrorKind::Other, "Rejected"))) - } - Err(err) => Err(err), + Err(err) => Err(WebSocketTestError::Soketto(err)), } } diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 87ed638de4..72e7dab58c 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -281,8 +281,16 @@ async fn ws_close_pending_subscription_when_server_terminated() { // no new request should be accepted. assert!(matches!(sub2, Err(_))); + // consume final message - assert!(matches!(sub.next().await, Ok(Some(_)))); - // the already established subscription should also be closed. - assert!(matches!(sub.next().await, Ok(None))); + for _ in 0..2 { + match sub.next().await { + // All good, exit test + Ok(None) => return, + // Try again + _ => continue, + } + } + + panic!("subscription keeps sending messages after server shutdown"); } diff --git a/ws-server/src/server.rs b/ws-server/src/server.rs index a421fb35ce..502a53735d 100644 --- a/ws-server/src/server.rs +++ b/ws-server/src/server.rs @@ -35,7 +35,6 @@ use crate::types::{ }; use futures_channel::mpsc; use futures_util::io::{BufReader, BufWriter}; -// use futures_util::future::FutureExt; use futures_util::stream::StreamExt; use soketto::handshake::{server::Response, Server as SokettoServer}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; @@ -86,13 +85,17 @@ impl Server { if connections.count() >= self.cfg.max_connections as usize { log::warn!("Too many connections. Try again in a while."); + connections.add(Box::pin(handshake(socket, HandshakeResponse::Reject { status_code: 429 }))); continue; } let methods = &methods; let cfg = &self.cfg; - connections.add(Box::pin(handshake(socket, id, methods, cfg, &stop_monitor))); + connections.add(Box::pin(handshake( + socket, + HandshakeResponse::Accept { conn_id: id, methods, cfg, stop_monitor: &stop_monitor }, + ))); id = id.wrapping_add(1); } @@ -139,49 +142,64 @@ impl<'a> Future for Incoming<'a> { } } -async fn handshake( - socket: tokio::net::TcpStream, - conn_id: ConnectionId, - methods: &Methods, - cfg: &Settings, - stop_monitor: &StopMonitor, -) -> Result<(), Error> { +enum HandshakeResponse<'a> { + Reject { status_code: u16 }, + Accept { conn_id: ConnectionId, methods: &'a Methods, cfg: &'a Settings, stop_monitor: &'a StopMonitor }, +} + +async fn handshake(socket: tokio::net::TcpStream, mode: HandshakeResponse<'_>) -> Result<(), Error> { // For each incoming background_task we perform a handshake. let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat()))); - let key = { - let req = server.receive_request().await?; - let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host)); - let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin); + match mode { + HandshakeResponse::Reject { status_code } => { + // Forced rejection, don't need to read anything from the socket + let reject = Response::Reject { status_code }; + server.send_response(&reject).await?; - host_check.and(origin_check).map(|()| req.key()) - }; + let (mut sender, _) = server.into_builder().finish(); - match key { - Ok(key) => { - let accept = Response::Accept { key, protocol: None }; - server.send_response(&accept).await?; - } - Err(error) => { - let reject = Response::Reject { status_code: 403 }; - server.send_response(&reject).await?; + // Gracefully shut down the connection + sender.close().await?; - return Err(error); + Ok(()) } - } + HandshakeResponse::Accept { conn_id, methods, cfg, stop_monitor } => { + let key = { + let req = server.receive_request().await?; + let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host)); + let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin); - let join_result = tokio::spawn(background_task( - server, - conn_id, - methods.clone(), - cfg.max_request_body_size, - stop_monitor.clone(), - )) - .await; - - match join_result { - Err(_) => Err(Error::Custom("Background task was aborted".into())), - Ok(result) => result, + host_check.and(origin_check).map(|()| req.key()) + }; + + match key { + Ok(key) => { + let accept = Response::Accept { key, protocol: None }; + server.send_response(&accept).await?; + } + Err(error) => { + let reject = Response::Reject { status_code: 403 }; + server.send_response(&reject).await?; + + return Err(error); + } + } + + let join_result = tokio::spawn(background_task( + server, + conn_id, + methods.clone(), + cfg.max_request_body_size, + stop_monitor.clone(), + )) + .await; + + match join_result { + Err(_) => Err(Error::Custom("Background task was aborted".into())), + Ok(result) => result, + } + } } } diff --git a/ws-server/src/tests.rs b/ws-server/src/tests.rs index ab8213aabe..bf2f35d29f 100644 --- a/ws-server/src/tests.rs +++ b/ws-server/src/tests.rs @@ -31,7 +31,7 @@ use crate::{future::StopHandle, RpcModule, WsServerBuilder}; use anyhow::anyhow; use futures_util::FutureExt; use jsonrpsee_test_utils::helpers::*; -use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient}; +use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient, WebSocketTestError}; use jsonrpsee_test_utils::TimeoutFutureExt; use serde_json::Value as JsonValue; use std::fmt; @@ -203,12 +203,9 @@ async fn can_set_max_connections() { assert!(conn2.is_ok()); // Third connection is rejected assert!(conn3.is_err()); - - let err = match conn3 { - Err(soketto::handshake::Error::Io(err)) => err, - _ => panic!("Invalid error kind; expected std::io::Error"), - }; - assert_eq!(err.kind(), std::io::ErrorKind::ConnectionReset); + if !matches!(conn3, Err(WebSocketTestError::RejectedWithStatusCode(429))) { + panic!("Expected RejectedWithStatusCode(429), got: {:#?}", conn3); + } // Decrement connection count drop(conn2);