Skip to content

Commit

Permalink
Set allowed Host header values (#399)
Browse files Browse the repository at this point in the history
* Set allowed Host header values

* Error if allowed hosts list is empty

* Grammar

Co-authored-by: David <dvdplm@gmail.com>

Co-authored-by: David <dvdplm@gmail.com>
  • Loading branch information
maciejhirsz and dvdplm authored Jul 1, 2021
1 parent 7496afe commit f705e32
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 19 deletions.
4 changes: 2 additions & 2 deletions types/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ pub enum Error {
#[error("Attempted to stop server that is already stopped")]
AlreadyStopped,
/// List passed into `set_allowed_origins` was empty
#[error("Must set at least one allowed origin")]
EmptyAllowedOrigins,
#[error("Must set at least one allowed value for the {0} header")]
EmptyAllowList(&'static str),
/// Custom error.
#[error("Custom error: {0}")]
Custom(String),
Expand Down
75 changes: 58 additions & 17 deletions ws-server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,10 @@ async fn handshake(

let key = {
let req = server.receive_request().await?;
cfg.allowed_origins.verify(req.headers().origin).map(|()| req.key())
let host_check = cfg.allowed_hosts.verify("Host", Some(req.headers().host));
let origin_check = cfg.allowed_origins.verify("Origin", req.headers().origin);

host_check.and(origin_check).map(|()| req.key())
};

match key {
Expand Down Expand Up @@ -324,16 +327,16 @@ async fn background_task(
}

#[derive(Debug, Clone)]
enum AllowedOrigins {
enum AllowedValue {
Any,
OneOf(Arc<[String]>),
OneOf(Box<[String]>),
}

impl AllowedOrigins {
fn verify(&self, origin: Option<&[u8]>) -> Result<(), Error> {
if let (AllowedOrigins::OneOf(list), Some(origin)) = (self, origin) {
if !list.iter().any(|o| o.as_bytes() == origin) {
let error = format!("Origin denied: {}", String::from_utf8_lossy(origin));
impl AllowedValue {
fn verify(&self, header: &str, value: Option<&[u8]>) -> Result<(), Error> {
if let (AllowedValue::OneOf(list), Some(value)) = (self, value) {
if !list.iter().any(|o| o.as_bytes() == value) {
let error = format!("{} denied: {}", header, String::from_utf8_lossy(value));
log::warn!("{}", error);
return Err(Error::Request(error));
}
Expand All @@ -350,16 +353,19 @@ struct Settings {
max_request_body_size: u32,
/// Maximum number of incoming connections allowed.
max_connections: u64,
/// Cross-origin policy by which to accept or deny incoming requests.
allowed_origins: AllowedOrigins,
/// Policy by which to accept or deny incoming requests based on the `Origin` header.
allowed_origins: AllowedValue,
/// Policy by which to accept or deny incoming requests based on the `Host` header.
allowed_hosts: AllowedValue,
}

impl Default for Settings {
fn default() -> Self {
Self {
max_request_body_size: TEN_MB_SIZE_BYTES,
max_connections: MAX_CONNECTIONS,
allowed_origins: AllowedOrigins::Any,
allowed_origins: AllowedValue::Any,
allowed_hosts: AllowedValue::Any,
}
}
}
Expand All @@ -385,11 +391,11 @@ impl Builder {

/// Set a list of allowed origins. During the handshake, the `Origin` header will be
/// checked against the list, connections without a matching origin will be denied.
/// Values should include protocol.
/// Values should be hostnames with protocol.
///
/// ```rust
/// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default();
/// builder.set_allowed_origins(vec!["https://example.com"]);
/// builder.set_allowed_origins(["https://example.com"]);
/// ```
///
/// By default allows any `Origin`.
Expand All @@ -400,21 +406,56 @@ impl Builder {
List: IntoIterator<Item = Origin>,
Origin: Into<String>,
{
let list: Arc<_> = list.into_iter().map(Into::into).collect();
let list: Box<_> = list.into_iter().map(Into::into).collect();

if list.len() == 0 {
return Err(Error::EmptyAllowedOrigins);
return Err(Error::EmptyAllowList("Origin"));
}

self.settings.allowed_origins = AllowedOrigins::OneOf(list);
self.settings.allowed_origins = AllowedValue::OneOf(list);

Ok(self)
}

/// Restores the default behavior of allowing connections with `Origin` header
/// containing any value. This will undo any list set by [`set_allowed_origins`](Builder::set_allowed_origins).
pub fn allow_all_origins(mut self) -> Self {
self.settings.allowed_origins = AllowedOrigins::Any;
self.settings.allowed_origins = AllowedValue::Any;
self
}

/// Set a list of allowed hosts. During the handshake, the `Host` header will be
/// checked against the list. Connections without a matching host will be denied.
/// Values should be hostnames without protocol.
///
/// ```rust
/// # let mut builder = jsonrpsee_ws_server::WsServerBuilder::default();
/// builder.set_allowed_hosts(["example.com"]);
/// ```
///
/// By default allows any `Host`.
///
/// Will return an error if `list` is empty. Use [`allow_all_hosts`](Builder::allow_all_hosts) to restore the default.
pub fn set_allowed_hosts<Host, List>(mut self, list: List) -> Result<Self, Error>
where
List: IntoIterator<Item = Host>,
Host: Into<String>,
{
let list: Box<_> = list.into_iter().map(Into::into).collect();

if list.len() == 0 {
return Err(Error::EmptyAllowList("Host"));
}

self.settings.allowed_hosts = AllowedValue::OneOf(list);

Ok(self)
}

/// Restores the default behavior of allowing connections with `Host` header
/// containing any value. This will undo any list set by [`set_allowed_hosts`](Builder::set_allowed_hosts).
pub fn allow_all_hosts(mut self) -> Self {
self.settings.allowed_hosts = AllowedValue::Any;
self
}

Expand Down

0 comments on commit f705e32

Please sign in to comment.