Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reject overflowing connection with status code 429 #456

Merged
merged 5 commits into from
Sep 14, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions test-utils/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,33 @@ impl std::fmt::Debug for WebSocketTestClient {
}
}

#[derive(Debug)]
pub enum WebSocketTestError {
Redirect,
RejectedWithStatusCode(u16),
Soketto(SokettoError),
}

impl From<io::Error> for WebSocketTestError {
fn from(err: io::Error) -> Self {
WebSocketTestError::Soketto(SokettoError::Io(err))
}
}

impl WebSocketTestClient {
pub async fn new(url: SocketAddr) -> Result<Self, SokettoError> {
pub async fn new(url: SocketAddr) -> Result<Self, WebSocketTestError> {
let socket = TcpStream::connect(url).await?;
let mut client = handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "test-client", "/");
match client.handshake().await {
Ok(handshake::ServerResponse::Accepted { .. }) => {
let (tx, rx) = client.into_builder().finish();
Ok(Self { tx, rx })
}
Ok(handshake::ServerResponse::Redirect { .. }) => {
Err(SokettoError::Io(io::Error::new(io::ErrorKind::Other, "Redirection not supported in tests")))
Ok(handshake::ServerResponse::Redirect { .. }) => Err(WebSocketTestError::Redirect),
Ok(handshake::ServerResponse::Rejected { status_code }) => {
Err(WebSocketTestError::RejectedWithStatusCode(status_code))
}
Ok(handshake::ServerResponse::Rejected { .. }) => {
Err(SokettoError::Io(io::Error::new(io::ErrorKind::Other, "Rejected")))
}
Err(err) => Err(err),
Err(err) => Err(WebSocketTestError::Soketto(err)),
}
}

Expand Down
88 changes: 51 additions & 37 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,17 @@ impl Server {

if connections.count() >= self.cfg.max_connections as usize {
log::warn!("Too many connections. Try again in a while.");
connections.add(Box::pin(handshake(socket, Handshake::Reject { status_code: 429 })));
continue;
}

let methods = &methods;
let cfg = &self.cfg;

connections.add(Box::pin(handshake(socket, id, methods, cfg, &stop_monitor)));
connections.add(Box::pin(handshake(
socket,
Handshake::Accept { conn_id: id, methods, cfg, stop_monitor: &stop_monitor },
)));

id = id.wrapping_add(1);
}
Expand Down Expand Up @@ -139,49 +143,59 @@ impl<'a> Future for Incoming<'a> {
}
}

async fn handshake(
socket: tokio::net::TcpStream,
conn_id: ConnectionId,
methods: &Methods,
cfg: &Settings,
stop_monitor: &StopMonitor,
) -> Result<(), Error> {
enum Handshake<'a> {
Reject { status_code: u16 },
Accept { conn_id: ConnectionId, methods: &'a Methods, cfg: &'a Settings, stop_monitor: &'a StopMonitor },
}

async fn handshake(socket: tokio::net::TcpStream, details: Handshake<'_>) -> Result<(), Error> {
maciejhirsz marked this conversation as resolved.
Show resolved Hide resolved
// For each incoming background_task we perform a handshake.
let mut server = SokettoServer::new(BufReader::new(BufWriter::new(socket.compat())));

let key = {
let req = server.receive_request().await?;
let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host));
let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin);

host_check.and(origin_check).map(|()| req.key())
};

match key {
Ok(key) => {
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
}
Err(error) => {
let reject = Response::Reject { status_code: 403 };
match details {
Handshake::Reject { status_code } => {
// Forced rejection, don't need to read anything from the socket
let reject = Response::Reject { status_code };
server.send_response(&reject).await?;

return Err(error);
Ok(())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does Ok(()) vs Err(error) imply here?

I suppose we don't have any error for it in Incoming and return Ok(()) is probably fine because it's already logged once the connection limit is exceeded when Handshake::Reject is detected.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, returning 429 rejection isn't an error here, it's expected behavior (but we can still get an io error when sending the response for whatever reason).

}
}
Handshake::Accept { conn_id, methods, cfg, stop_monitor } => {
let key = {
let req = server.receive_request().await?;
let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host));
let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin);

let join_result = tokio::spawn(background_task(
server,
conn_id,
methods.clone(),
cfg.max_request_body_size,
stop_monitor.clone(),
))
.await;

match join_result {
Err(_) => Err(Error::Custom("Background task was aborted".into())),
Ok(result) => result,
host_check.and(origin_check).map(|()| req.key())
};

match key {
Ok(key) => {
let accept = Response::Accept { key, protocol: None };
server.send_response(&accept).await?;
}
Err(error) => {
let reject = Response::Reject { status_code: 403 };
server.send_response(&reject).await?;

return Err(error);
}
}

let join_result = tokio::spawn(background_task(
server,
conn_id,
methods.clone(),
cfg.max_request_body_size,
stop_monitor.clone(),
))
.await;

match join_result {
Err(_) => Err(Error::Custom("Background task was aborted".into())),
Ok(result) => result,
}
}
}
}

Expand Down
9 changes: 2 additions & 7 deletions ws-server/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use crate::{future::StopHandle, RpcModule, WsServerBuilder};
use anyhow::anyhow;
use futures_util::FutureExt;
use jsonrpsee_test_utils::helpers::*;
use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient};
use jsonrpsee_test_utils::types::{Id, TestContext, WebSocketTestClient, WebSocketTestError};
use jsonrpsee_test_utils::TimeoutFutureExt;
use serde_json::Value as JsonValue;
use std::fmt;
Expand Down Expand Up @@ -203,12 +203,7 @@ async fn can_set_max_connections() {
assert!(conn2.is_ok());
// Third connection is rejected
assert!(conn3.is_err());

let err = match conn3 {
Err(soketto::handshake::Error::Io(err)) => err,
_ => panic!("Invalid error kind; expected std::io::Error"),
};
assert_eq!(err.kind(), std::io::ErrorKind::ConnectionReset);
assert!(matches!(conn3, Err(WebSocketTestError::RejectedWithStatusCode(429))));

// Decrement connection count
drop(conn2);
Expand Down