From 8b06300553406eafde40f44c9791f4433dc9a766 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 13:12:37 +0200 Subject: [PATCH 01/14] refactor: server host filtering Cleans up the host filtering by allowing `ipv6 addresses`, fixes a bug when host filtering with `*` is configured when a request is missing the default port. In addition the API on the server is stricter and hosts filtering with invalid authorities are now rejected which wasn't the case before. --- core/src/server/host_filtering/host.rs | 190 ++++++++++++++----------- core/src/server/host_filtering/mod.rs | 2 +- tests/tests/integration_tests.rs | 14 +- 3 files changed, 118 insertions(+), 88 deletions(-) diff --git a/core/src/server/host_filtering/host.rs b/core/src/server/host_filtering/host.rs index 893dddc169..840097100e 100644 --- a/core/src/server/host_filtering/host.rs +++ b/core/src/server/host_filtering/host.rs @@ -26,18 +26,20 @@ //! Host header validation. +use std::str::FromStr; + +use hyper::http::uri::InvalidUri; + use crate::server::host_filtering::matcher::{Matcher, Pattern}; use crate::Error; -const SPLIT_PROOF: &str = "split always returns non-empty iterator."; - /// Port pattern #[derive(Clone, Hash, PartialEq, Eq, Debug)] -pub enum Port { +enum Port { /// No port specified (default port) - None, - /// Port specified as a wildcard pattern - Pattern(String), + Default, + /// Port specified as a wildcard pattern (*). + Any, /// Fixed numeric port Fixed(u16), } @@ -46,7 +48,7 @@ impl From> for Port { fn from(opt: Option) -> Self { match opt { Some(port) => Port::Fixed(port), - None => Port::None, + None => Port::Default, } } } @@ -57,103 +59,95 @@ impl From for Port { } } -/// Host type +/// Represent the http URI scheme that is returned by the HTTP host header +/// +/// +/// +/// Further information can be found: https://www.rfc-editor.org/rfc/rfc7230#section-2.7.1 #[derive(Clone, Hash, PartialEq, Eq, Debug)] -pub struct Host { +struct Authority { hostname: String, port: Port, - host_with_port: String, - matcher: Matcher, -} - -impl> From for Host { - fn from(string: T) -> Self { - Host::parse(string.as_ref()) - } } -impl Host { - /// Creates a new `Host` given hostname and port number. - pub fn new>(hostname: &str, port: T) -> Self { - let port = port.into(); - let hostname = Self::pre_process(hostname); - let host_with_port = Self::from_str(&hostname, &port); - let matcher = Matcher::new(&host_with_port); +impl FromStr for Authority { + type Err = String; - Host { hostname, port, host_with_port, matcher } - } + fn from_str(s: &str) -> Result { + let uri: hyper::Uri = s.parse().map_err(|e: InvalidUri| e.to_string())?; + let authority = uri.authority().ok_or_else(|| "HTTP Host must contain authority".to_owned())?; + let hostname = authority.host(); + let maybe_port = &authority.as_str()[hostname.len()..]; - /// Attempts to parse given string as a `Host`. - /// NOTE: This method always succeeds and falls back to sensible defaults. - pub fn parse(hostname: &str) -> Self { - let hostname = Self::pre_process(hostname); - let mut hostname = hostname.split(':'); - let host = hostname.next().expect(SPLIT_PROOF); - let port = match hostname.next() { - None => Port::None, - Some(port) => match port.parse::().ok() { - Some(num) => Port::Fixed(num), - None => Port::Pattern(port.into()), - }, + let port = match maybe_port.split(":").nth(1) { + Some("*") => Port::Any, + Some(p) => { + let port_u16 = p.parse().map_err(|e: std::num::ParseIntError| e.to_string())?; + Port::Fixed(port_u16) + } + None => Port::Default, }; - Host::new(host, port) + Ok(Self { hostname: hostname.to_string(), port }) } +} + +/// Represents a whitelisted host/authority. +/// which contains a matcher to decide whether to +/// reject or accept a request. +#[derive(Clone, Hash, PartialEq, Eq, Debug)] +pub struct AllowHost { + authority: Authority, + matcher: Matcher, +} - fn pre_process(host: &str) -> String { - // Remove possible protocol definition - let mut it = host.split("://"); - let protocol = it.next().expect(SPLIT_PROOF); - let host = match it.next() { - Some(data) => data, - None => protocol, +impl AllowHost { + fn matches(&self, other: &Authority) -> bool { + let port_match = match (&self.authority.port, &other.port) { + (Port::Any, _) => true, + (Port::Default, Port::Default) => true, + (Port::Fixed(p1), Port::Fixed(p2)) if p1 == p2 => true, + _ => false, }; - let mut it = host.split('/'); - it.next().expect(SPLIT_PROOF).to_lowercase() + port_match && self.matcher.matches(&other.hostname) } +} + +impl FromStr for AllowHost { + type Err = String; + + fn from_str(s: &str) -> Result { + let authority = Authority::from_str(s)?; + let matcher = Matcher::new(&authority.hostname); - fn from_str(hostname: &str, port: &Port) -> String { - format!( - "{}{}", - hostname, - match *port { - Port::Fixed(port) => format!(":{port}"), - Port::Pattern(ref port) => format!(":{port}"), - Port::None => "".into(), - }, - ) + Ok(Self { authority, matcher }) } } -impl Pattern for Host { +impl Pattern for AllowHost { fn matches>(&self, other: T) -> bool { self.matcher.matches(other) } } -impl std::ops::Deref for Host { - type Target = str; - - fn deref(&self) -> &Self::Target { - &self.host_with_port - } -} - /// Policy for validating the `HTTP host header`. #[derive(Debug, Clone)] pub enum AllowHosts { /// Allow all hosts (no filter). Any, /// Allow only specified hosts. - Only(Vec), + Only(Vec), } impl AllowHosts { /// Verify a host. pub fn verify(&self, value: &str) -> Result<(), Error> { + let authority = Authority::from_str(value) + .map_err(|_| Error::HttpHeaderRejected("host", format!("Invalid authority: {value}")))?; + if let AllowHosts::Only(list) = self { - if !list.iter().any(|o| o.matches(value)) { + if !list.iter().any(|o| o.matches(&authority)) { return Err(Error::HttpHeaderRejected("host", value.into())); } } @@ -164,19 +158,40 @@ impl AllowHosts { #[cfg(test)] mod tests { - use super::{AllowHosts, Host, Port}; + use super::{AllowHost, AllowHosts, Authority, Port}; + use std::str::FromStr; + + fn authority(host: &str, port: Port) -> Authority { + Authority { hostname: host.to_owned(), port } + } + + #[test] + fn should_parse_valid_authority() { + assert_eq!(Authority::from_str("http://parity.io").unwrap(), authority("parity.io", Port::Default)); + assert_eq!(Authority::from_str("https://parity.io:8443").unwrap(), authority("parity.io", Port::Fixed(8443))); + assert_eq!(Authority::from_str("chrome-extension://124.0.0.1").unwrap(), authority("124.0.0.1", Port::Default)); + assert_eq!(Authority::from_str("http://*.domain:*/somepath").unwrap(), authority("*.domain", Port::Any)); + assert_eq!(Authority::from_str("parity.io").unwrap(), authority("parity.io", Port::Default)); + assert_eq!( + Authority::from_str("http://[2001:db8:85a3:8d3:1319:8a2e:370:7348]:9933/").unwrap(), + authority("[2001:db8:85a3:8d3:1319:8a2e:370:7348]", Port::Fixed(9933)) + ); + assert_eq!( + Authority::from_str("http://[2001:db8:85a3:8d3:1319:8a2e:370:7348]/").unwrap(), + authority("[2001:db8:85a3:8d3:1319:8a2e:370:7348]", Port::Default) + ); + assert_eq!( + Authority::from_str("https://user:password@example.com/tmp/foo").unwrap(), + authority("example.com", Port::Default) + ); + } #[test] - fn should_parse_host() { - assert_eq!(Host::parse("http://parity.io"), Host::new("parity.io", None)); - assert_eq!(Host::parse("https://parity.io:8443"), Host::new("parity.io", Some(8443))); - assert_eq!(Host::parse("chrome-extension://124.0.0.1"), Host::new("124.0.0.1", None)); - assert_eq!(Host::parse("parity.io/somepath"), Host::new("parity.io", None)); - assert_eq!(Host::parse("127.0.0.1:8545/somepath"), Host::new("127.0.0.1", Some(8545))); - - let host = Host::parse("*.domain:*/somepath"); - assert_eq!(host.port, Port::Pattern("*".into())); - assert_eq!(host.hostname.as_str(), "*.domain"); + fn should_not_parse_invalid_authority() { + assert!(Authority::from_str("/foo/bar").is_err()); + assert!(Authority::from_str("user:password").is_err()); + assert!(Authority::from_str("parity.io/somepath").is_err()); + assert!(Authority::from_str("127.0.0.1:8545/somepath").is_err()); } #[test] @@ -191,17 +206,24 @@ mod tests { #[test] fn should_accept_if_on_the_list() { - assert!((AllowHosts::Only(vec!["parity.io".into()])).verify("parity.io").is_ok()); + assert!(AllowHosts::Only(vec![AllowHost::from_str("parity.io").unwrap()]).verify("parity.io").is_ok()); } #[test] fn should_accept_if_on_the_list_with_port() { - assert!((AllowHosts::Only(vec!["parity.io:443".into()])).verify("parity.io:443").is_ok()); - assert!((AllowHosts::Only(vec!["parity.io".into()])).verify("parity.io:443").is_err()); + assert!((AllowHosts::Only(vec![AllowHost::from_str("parity.io:443").unwrap()])) + .verify("parity.io:443") + .is_ok()); + assert!(AllowHosts::Only(vec![AllowHost::from_str("parity.io").unwrap()]).verify("parity.io:443").is_err()); } #[test] fn should_support_wildcards() { - assert!((AllowHosts::Only(vec!["*.web3.site:*".into()])).verify("parity.web3.site:8180").is_ok()); + assert!((AllowHosts::Only(vec![AllowHost::from_str("*.web3.site:*").unwrap()])) + .verify("parity.web3.site:8180") + .is_ok()); + assert!((AllowHosts::Only(vec![AllowHost::from_str("*.web3.site:*").unwrap()])) + .verify("parity.web3.site") + .is_ok()); } } diff --git a/core/src/server/host_filtering/mod.rs b/core/src/server/host_filtering/mod.rs index 4f0fffbccd..cc2a2b580f 100644 --- a/core/src/server/host_filtering/mod.rs +++ b/core/src/server/host_filtering/mod.rs @@ -3,4 +3,4 @@ mod host; mod matcher; -pub use host::AllowHosts; +pub use host::*; diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 91df2841ee..b510128772 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -29,6 +29,7 @@ mod helpers; +use std::str::FromStr; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::time::Duration; @@ -1021,10 +1022,14 @@ async fn http_health_api_works() { #[tokio::test] async fn ws_host_filtering_wildcard_works() { use jsonrpsee::server::*; + use std::str::FromStr; init_logger(); - let acl = AllowHosts::Only(vec!["http://localhost:*".into(), "http://127.0.0.1:*".into()]); + let acl = AllowHosts::Only(vec![ + AllowHost::from_str("http://localhost:*").unwrap(), + AllowHost::from_str("http://127.0.0.1:*").unwrap(), + ]); let server = ServerBuilder::default().set_host_filtering(acl).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); @@ -1045,7 +1050,10 @@ async fn http_host_filtering_wildcard_works() { init_logger(); - let allowed_hosts = AllowHosts::Only(vec!["http://localhost:*".into(), "http://127.0.0.1:*".into()]); + let allowed_hosts = AllowHosts::Only(vec![ + AllowHost::from_str("http://localhost:*").unwrap(), + AllowHost::from_str("http://127.0.0.1:*").unwrap(), + ]); let server = ServerBuilder::default().set_host_filtering(allowed_hosts).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); @@ -1066,7 +1074,7 @@ async fn deny_invalid_host() { init_logger(); - let allowed_hosts = AllowHosts::Only(vec!["http://example.com".into()]); + let allowed_hosts = AllowHosts::Only(vec![AllowHost::from_str("http://example.com").unwrap()]); let server = ServerBuilder::default().set_host_filtering(allowed_hosts).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); From 856c005cb97bf14013ad66866d0c6ab45a96c16d Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 13:43:02 +0200 Subject: [PATCH 02/14] fix some nits --- core/Cargo.toml | 2 ++ core/src/server/host_filtering/host.rs | 9 +++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/core/Cargo.toml b/core/Cargo.toml index cf36b80976..71b6e1c9be 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -35,6 +35,7 @@ tokio = { version = "1.16", optional = true } wasm-bindgen-futures = { version = "0.4.19", optional = true } futures-timer = { version = "3", optional = true } globset = { version = "0.4", optional = true } +http = { version = "0.2.9", optional = true } [features] default = [] @@ -49,6 +50,7 @@ server = [ "tokio/sync", "tokio/macros", "tokio/time", + "http", ] client = ["futures-util/sink", "tokio/sync"] async-client = [ diff --git a/core/src/server/host_filtering/host.rs b/core/src/server/host_filtering/host.rs index 840097100e..e45da55c6c 100644 --- a/core/src/server/host_filtering/host.rs +++ b/core/src/server/host_filtering/host.rs @@ -74,14 +74,15 @@ impl FromStr for Authority { type Err = String; fn from_str(s: &str) -> Result { - let uri: hyper::Uri = s.parse().map_err(|e: InvalidUri| e.to_string())?; + let uri: http::Uri = s.parse().map_err(|e: InvalidUri| e.to_string())?; let authority = uri.authority().ok_or_else(|| "HTTP Host must contain authority".to_owned())?; let hostname = authority.host(); let maybe_port = &authority.as_str()[hostname.len()..]; - let port = match maybe_port.split(":").nth(1) { - Some("*") => Port::Any, - Some(p) => { + // After the host segment, the authority may contain a port such as `fooo:33`, `foo:*` or `foo` + let port = match maybe_port.split_once(':') { + Some((_, "*")) => Port::Any, + Some((_, p)) => { let port_u16 = p.parse().map_err(|e: std::num::ParseIntError| e.to_string())?; Port::Fixed(port_u16) } From 19acee07cbc0d24473c5455e32296273944ab9c0 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 14:40:20 +0200 Subject: [PATCH 03/14] fix build again --- core/src/server/host_filtering/host.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/server/host_filtering/host.rs b/core/src/server/host_filtering/host.rs index e45da55c6c..528a30717d 100644 --- a/core/src/server/host_filtering/host.rs +++ b/core/src/server/host_filtering/host.rs @@ -28,10 +28,9 @@ use std::str::FromStr; -use hyper::http::uri::InvalidUri; - use crate::server::host_filtering::matcher::{Matcher, Pattern}; use crate::Error; +use http::uri::{InvalidUri, Uri}; /// Port pattern #[derive(Clone, Hash, PartialEq, Eq, Debug)] @@ -74,7 +73,7 @@ impl FromStr for Authority { type Err = String; fn from_str(s: &str) -> Result { - let uri: http::Uri = s.parse().map_err(|e: InvalidUri| e.to_string())?; + let uri: Uri = s.parse().map_err(|e: InvalidUri| e.to_string())?; let authority = uri.authority().ok_or_else(|| "HTTP Host must contain authority".to_owned())?; let hostname = authority.host(); let maybe_port = &authority.as_str()[hostname.len()..]; From e36d9029005c493bde6cc3a93e19a71297fe27a0 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 16:16:25 +0200 Subject: [PATCH 04/14] allow requests with/without default port --- core/src/server/host_filtering/host.rs | 27 +++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/core/src/server/host_filtering/host.rs b/core/src/server/host_filtering/host.rs index 528a30717d..deda40aa36 100644 --- a/core/src/server/host_filtering/host.rs +++ b/core/src/server/host_filtering/host.rs @@ -83,7 +83,12 @@ impl FromStr for Authority { Some((_, "*")) => Port::Any, Some((_, p)) => { let port_u16 = p.parse().map_err(|e: std::num::ParseIntError| e.to_string())?; - Port::Fixed(port_u16) + + // Omit default port to allow both requests with and without the default port. + match default_port(uri.scheme_str()) { + Some(p) if p == port_u16 => Port::Default, + _ => Port::Fixed(port_u16), + } } None => Port::Default, }; @@ -156,6 +161,15 @@ impl AllowHosts { } } +fn default_port(scheme: Option<&str>) -> Option { + match scheme { + Some("http") | Some("ws") => Some(80), + Some("https") | Some("wss") => Some(443), + Some("ftp") => Some(21), + _ => None, + } +} + #[cfg(test)] mod tests { use super::{AllowHost, AllowHosts, Authority, Port}; @@ -226,4 +240,15 @@ mod tests { .verify("parity.web3.site") .is_ok()); } + + #[test] + fn should_accept_with_and_without_default_port() { + assert!(AllowHosts::Only(vec![AllowHost::from_str("https://parity.io:443").unwrap()]) + .verify("https://parity.io") + .is_ok()); + + assert!(AllowHosts::Only(vec![AllowHost::from_str("https://parity.io").unwrap()]) + .verify("https://parity.io:443") + .is_ok()); + } } From 2b5754461b527e6623633b79c5d1f180de8b24b3 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 17:23:14 +0200 Subject: [PATCH 05/14] switch to `route_recognizer` for URL recognition --- core/Cargo.toml | 4 +- .../host.rs => host_filtering.rs} | 98 +++++++++---------- core/src/server/host_filtering/matcher.rs | 85 ---------------- core/src/server/host_filtering/mod.rs | 6 -- tests/tests/integration_tests.rs | 18 ++-- 5 files changed, 59 insertions(+), 152 deletions(-) rename core/src/server/{host_filtering/host.rs => host_filtering.rs} (74%) delete mode 100644 core/src/server/host_filtering/matcher.rs delete mode 100644 core/src/server/host_filtering/mod.rs diff --git a/core/Cargo.toml b/core/Cargo.toml index 71b6e1c9be..d1ac9d0d4d 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -34,7 +34,7 @@ parking_lot = { version = "0.12", optional = true } tokio = { version = "1.16", optional = true } wasm-bindgen-futures = { version = "0.4.19", optional = true } futures-timer = { version = "3", optional = true } -globset = { version = "0.4", optional = true } +route-recognizer = { version = "0.3.1", optional = true } http = { version = "0.2.9", optional = true } [features] @@ -42,7 +42,7 @@ default = [] http-helpers = ["hyper", "futures-util"] server = [ "futures-util/alloc", - "globset", + "route-recognizer", "rustc-hash/std", "parking_lot", "rand", diff --git a/core/src/server/host_filtering/host.rs b/core/src/server/host_filtering.rs similarity index 74% rename from core/src/server/host_filtering/host.rs rename to core/src/server/host_filtering.rs index deda40aa36..45d9657abf 100644 --- a/core/src/server/host_filtering/host.rs +++ b/core/src/server/host_filtering.rs @@ -24,17 +24,16 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. -//! Host header validation. +//! HTTP Host Header validation. -use std::str::FromStr; - -use crate::server::host_filtering::matcher::{Matcher, Pattern}; use crate::Error; use http::uri::{InvalidUri, Uri}; +use route_recognizer::Router; +use std::str::FromStr; /// Port pattern -#[derive(Clone, Hash, PartialEq, Eq, Debug)] -enum Port { +#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)] +pub enum Port { /// No port specified (default port) Default, /// Port specified as a wildcard pattern (*). @@ -64,7 +63,7 @@ impl From for Port { /// /// Further information can be found: https://www.rfc-editor.org/rfc/rfc7230#section-2.7.1 #[derive(Clone, Hash, PartialEq, Eq, Debug)] -struct Authority { +pub struct Authority { hostname: String, port: Port, } @@ -97,42 +96,39 @@ impl FromStr for Authority { } } -/// Represents a whitelisted host/authority. -/// which contains a matcher to decide whether to -/// reject or accept a request. -#[derive(Clone, Hash, PartialEq, Eq, Debug)] -pub struct AllowHost { - authority: Authority, - matcher: Matcher, -} - -impl AllowHost { - fn matches(&self, other: &Authority) -> bool { - let port_match = match (&self.authority.port, &other.port) { - (Port::Any, _) => true, - (Port::Default, Port::Default) => true, - (Port::Fixed(p1), Port::Fixed(p2)) if p1 == p2 => true, - _ => false, - }; - - port_match && self.matcher.matches(&other.hostname) - } -} +/// Represent the URL patterns that is whitelisted. +#[derive(Default, Debug, Clone)] +pub struct UrlPattern(Router); -impl FromStr for AllowHost { - type Err = String; +impl From for UrlPattern +where + T: IntoIterator, +{ + fn from(value: T) -> Self { + let mut router = Router::new(); - fn from_str(s: &str) -> Result { - let authority = Authority::from_str(s)?; - let matcher = Matcher::new(&authority.hostname); + for auth in value.into_iter() { + router.add(&auth.hostname, auth.port); + } - Ok(Self { authority, matcher }) + Self(router) } } -impl Pattern for AllowHost { - fn matches>(&self, other: T) -> bool { - self.matcher.matches(other) +impl UrlPattern { + fn recognize(&self, other: &Authority) -> bool { + if let Ok(p) = self.0.recognize(&other.hostname) { + let p = p.handler(); + + match (p, &other.port) { + (Port::Any, _) => true, + (Port::Default, Port::Default) => true, + (Port::Fixed(p1), Port::Fixed(p2)) if p1 == p2 => true, + _ => false, + } + } else { + false + } } } @@ -142,17 +138,17 @@ pub enum AllowHosts { /// Allow all hosts (no filter). Any, /// Allow only specified hosts. - Only(Vec), + Only(UrlPattern), } impl AllowHosts { /// Verify a host. pub fn verify(&self, value: &str) -> Result<(), Error> { - let authority = Authority::from_str(value) + let auth = Authority::from_str(value) .map_err(|_| Error::HttpHeaderRejected("host", format!("Invalid authority: {value}")))?; - if let AllowHosts::Only(list) = self { - if !list.iter().any(|o| o.matches(&authority)) { + if let AllowHosts::Only(url_pat) = self { + if !url_pat.recognize(&auth) { return Err(Error::HttpHeaderRejected("host", value.into())); } } @@ -172,7 +168,7 @@ fn default_port(scheme: Option<&str>) -> Option { #[cfg(test)] mod tests { - use super::{AllowHost, AllowHosts, Authority, Port}; + use super::{AllowHosts, Authority, Port}; use std::str::FromStr; fn authority(host: &str, port: Port) -> Authority { @@ -215,39 +211,41 @@ mod tests { #[test] fn should_reject_if_header_not_on_the_list() { - assert!((AllowHosts::Only(vec![])).verify("parity.io").is_err()); + assert!((AllowHosts::Only(vec![].into())).verify("parity.io").is_err()); } #[test] fn should_accept_if_on_the_list() { - assert!(AllowHosts::Only(vec![AllowHost::from_str("parity.io").unwrap()]).verify("parity.io").is_ok()); + assert!(AllowHosts::Only(vec![Authority::from_str("parity.io").unwrap()].into()).verify("parity.io").is_ok()); } #[test] fn should_accept_if_on_the_list_with_port() { - assert!((AllowHosts::Only(vec![AllowHost::from_str("parity.io:443").unwrap()])) + assert!((AllowHosts::Only(vec![Authority::from_str("parity.io:443").unwrap()].into())) .verify("parity.io:443") .is_ok()); - assert!(AllowHosts::Only(vec![AllowHost::from_str("parity.io").unwrap()]).verify("parity.io:443").is_err()); + assert!(AllowHosts::Only(vec![Authority::from_str("parity.io").unwrap()].into()) + .verify("parity.io:443") + .is_err()); } #[test] fn should_support_wildcards() { - assert!((AllowHosts::Only(vec![AllowHost::from_str("*.web3.site:*").unwrap()])) + assert!((AllowHosts::Only(vec![Authority::from_str("*.web3.site:*").unwrap()].into())) .verify("parity.web3.site:8180") .is_ok()); - assert!((AllowHosts::Only(vec![AllowHost::from_str("*.web3.site:*").unwrap()])) + assert!((AllowHosts::Only(vec![Authority::from_str("*.web3.site:*").unwrap()].into())) .verify("parity.web3.site") .is_ok()); } #[test] fn should_accept_with_and_without_default_port() { - assert!(AllowHosts::Only(vec![AllowHost::from_str("https://parity.io:443").unwrap()]) + assert!(AllowHosts::Only(vec![Authority::from_str("https://parity.io:443").unwrap()].into()) .verify("https://parity.io") .is_ok()); - assert!(AllowHosts::Only(vec![AllowHost::from_str("https://parity.io").unwrap()]) + assert!(AllowHosts::Only(vec![Authority::from_str("https://parity.io").unwrap()].into()) .verify("https://parity.io:443") .is_ok()); } diff --git a/core/src/server/host_filtering/matcher.rs b/core/src/server/host_filtering/matcher.rs deleted file mode 100644 index 9479070ef1..0000000000 --- a/core/src/server/host_filtering/matcher.rs +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. -// -// Permission is hereby granted, free of charge, to any -// person obtaining a copy of this software and associated -// documentation files (the "Software"), to deal in the -// Software without restriction, including without -// limitation the rights to use, copy, modify, merge, -// publish, distribute, sublicense, and/or sell copies of -// the Software, and to permit persons to whom the Software -// is furnished to do so, subject to the following -// conditions: -// -// The above copyright notice and this permission notice -// shall be included in all copies or substantial portions -// of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF -// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED -// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A -// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT -// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY -// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR -// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -// DEALINGS IN THE SOFTWARE. - -use std::{fmt, hash}; - -use globset::{GlobBuilder, GlobMatcher}; -use tracing::warn; - -/// Pattern that can be matched to string. -pub(crate) trait Pattern { - /// Returns true if given string matches the pattern. - fn matches>(&self, other: T) -> bool; -} - -#[derive(Clone)] -pub(crate) struct Matcher(Option, String); - -impl Matcher { - pub(crate) fn new(string: &str) -> Matcher { - Matcher( - GlobBuilder::new(string) - .case_insensitive(true) - .build() - .map(|g| g.compile_matcher()) - .map_err(|e| warn!("Invalid glob pattern for {}: {:?}", string, e)) - .ok(), - string.into(), - ) - } -} - -impl Pattern for Matcher { - fn matches>(&self, other: T) -> bool { - let s = other.as_ref(); - match self.0 { - Some(ref matcher) => matcher.is_match(s), - None => self.1.eq_ignore_ascii_case(s), - } - } -} - -impl fmt::Debug for Matcher { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "{:?} ({})", self.1, self.0.is_some()) - } -} - -impl hash::Hash for Matcher { - fn hash(&self, state: &mut H) - where - H: hash::Hasher, - { - self.1.hash(state) - } -} - -impl Eq for Matcher {} -impl PartialEq for Matcher { - fn eq(&self, other: &Matcher) -> bool { - self.1.eq(&other.1) - } -} diff --git a/core/src/server/host_filtering/mod.rs b/core/src/server/host_filtering/mod.rs deleted file mode 100644 index cc2a2b580f..0000000000 --- a/core/src/server/host_filtering/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -//! Host filtering. - -mod host; -mod matcher; - -pub use host::*; diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index b510128772..c7cbee14c5 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -1026,10 +1026,10 @@ async fn ws_host_filtering_wildcard_works() { init_logger(); - let acl = AllowHosts::Only(vec![ - AllowHost::from_str("http://localhost:*").unwrap(), - AllowHost::from_str("http://127.0.0.1:*").unwrap(), - ]); + let acl = AllowHosts::Only( + vec![Authority::from_str("http://localhost:*").unwrap(), Authority::from_str("http://127.0.0.1:*").unwrap()] + .into(), + ); let server = ServerBuilder::default().set_host_filtering(acl).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); @@ -1050,10 +1050,10 @@ async fn http_host_filtering_wildcard_works() { init_logger(); - let allowed_hosts = AllowHosts::Only(vec![ - AllowHost::from_str("http://localhost:*").unwrap(), - AllowHost::from_str("http://127.0.0.1:*").unwrap(), - ]); + let allowed_hosts = AllowHosts::Only( + vec![Authority::from_str("http://localhost:*").unwrap(), Authority::from_str("http://127.0.0.1:*").unwrap()] + .into(), + ); let server = ServerBuilder::default().set_host_filtering(allowed_hosts).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); @@ -1074,7 +1074,7 @@ async fn deny_invalid_host() { init_logger(); - let allowed_hosts = AllowHosts::Only(vec![AllowHost::from_str("http://example.com").unwrap()]); + let allowed_hosts = AllowHosts::Only(vec![Authority::from_str("http://example.com").unwrap()].into()); let server = ServerBuilder::default().set_host_filtering(allowed_hosts).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); From c9e5a38cec2140773c8174c8818632137264a14c Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 17:27:57 +0200 Subject: [PATCH 06/14] remove weird From impl --- core/src/server/host_filtering.rs | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/core/src/server/host_filtering.rs b/core/src/server/host_filtering.rs index 45d9657abf..37974ef39d 100644 --- a/core/src/server/host_filtering.rs +++ b/core/src/server/host_filtering.rs @@ -42,15 +42,6 @@ pub enum Port { Fixed(u16), } -impl From> for Port { - fn from(opt: Option) -> Self { - match opt { - Some(port) => Port::Fixed(port), - None => Port::Default, - } - } -} - impl From for Port { fn from(port: u16) -> Port { Port::Fixed(port) @@ -81,12 +72,12 @@ impl FromStr for Authority { let port = match maybe_port.split_once(':') { Some((_, "*")) => Port::Any, Some((_, p)) => { - let port_u16 = p.parse().map_err(|e: std::num::ParseIntError| e.to_string())?; + let port_u16: u16 = p.parse().map_err(|e: std::num::ParseIntError| e.to_string())?; // Omit default port to allow both requests with and without the default port. match default_port(uri.scheme_str()) { Some(p) if p == port_u16 => Port::Default, - _ => Port::Fixed(port_u16), + _ => port_u16.into(), } } None => Port::Default, From b349074922cd8bcd04c8dcecbe858d3082f18d41 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 17:28:39 +0200 Subject: [PATCH 07/14] Update core/src/server/host_filtering.rs --- core/src/server/host_filtering.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/server/host_filtering.rs b/core/src/server/host_filtering.rs index 37974ef39d..0fdd2d342b 100644 --- a/core/src/server/host_filtering.rs +++ b/core/src/server/host_filtering.rs @@ -1,4 +1,4 @@ -// Copyright 2019-2021 Parity Technologies (UK) Ltd. +// Copyright 2019-2023 Parity Technologies (UK) Ltd. // // Permission is hereby granted, free of charge, to any // person obtaining a copy of this software and associated From b2fa27356d3f4a0b7821a98fc4f352081b811a58 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 17:29:24 +0200 Subject: [PATCH 08/14] Update tests/tests/integration_tests.rs --- tests/tests/integration_tests.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index c7cbee14c5..d90c379a4f 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -1022,7 +1022,6 @@ async fn http_health_api_works() { #[tokio::test] async fn ws_host_filtering_wildcard_works() { use jsonrpsee::server::*; - use std::str::FromStr; init_logger(); From ab82820c7d4da5e7f23f9f3b9d33371e79fbf6ad Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 18:00:41 +0200 Subject: [PATCH 09/14] refactor host filter API --- examples/examples/cors_server.rs | 4 ++-- server/src/server.rs | 18 ++++++++++++++---- tests/tests/helpers.rs | 24 ++++++++++++++---------- tests/tests/integration_tests.rs | 27 ++++++++++++--------------- 4 files changed, 42 insertions(+), 31 deletions(-) diff --git a/examples/examples/cors_server.rs b/examples/examples/cors_server.rs index 633949d326..04a91432bf 100644 --- a/examples/examples/cors_server.rs +++ b/examples/examples/cors_server.rs @@ -28,7 +28,7 @@ //! with access control allowing requests from all hosts. use hyper::Method; -use jsonrpsee::server::{AllowHosts, RpcModule, Server}; +use jsonrpsee::server::{RpcModule, Server}; use std::net::SocketAddr; use tower_http::cors::{Any, CorsLayer}; @@ -86,7 +86,7 @@ async fn run_server() -> anyhow::Result { // and can also be used separately. // In this example, we use both features. let server = Server::builder() - .set_host_filtering(AllowHosts::Any) + .disable_host_filtering() .set_middleware(middleware) .build("127.0.0.1:0".parse::()?) .await?; diff --git a/server/src/server.rs b/server/src/server.rs index 8411b44ff6..ac699866d3 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -42,7 +42,7 @@ use futures_util::io::{BufReader, BufWriter}; use hyper::body::HttpBody; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; -use jsonrpsee_core::server::{AllowHosts, Methods}; +use jsonrpsee_core::server::{AllowHosts, Authority, Methods, UrlPattern}; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{http_helpers, Error, TEN_MB_SIZE_BYTES}; @@ -419,9 +419,19 @@ impl Builder { self } - /// Sets host filtering. - pub fn set_host_filtering(mut self, allow: AllowHosts) -> Self { - self.settings.allow_hosts = allow; + /// Enables host filtering and allow only the specified hosts. + /// + /// Default: allow all. + pub fn host_filter>(mut self, allow_only: T) -> Self { + self.settings.allow_hosts = AllowHosts::Only(UrlPattern::from(allow_only.into_iter())); + self + } + + /// Disable host filtering and allow all. + /// + /// Default: allow all. + pub fn disable_host_filtering(mut self) -> Self { + self.settings.allow_hosts = AllowHosts::Any; self } diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 4345ef029e..9f0c6fc520 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -33,10 +33,10 @@ use futures::{SinkExt, Stream, StreamExt}; use jsonrpsee::core::Error; use jsonrpsee::server::middleware::proxy_get_request::ProxyGetRequestLayer; use jsonrpsee::server::{ - AllowHosts, PendingSubscriptionSink, RpcModule, ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError, + PendingSubscriptionSink, RpcModule, ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError, }; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; -use jsonrpsee::SubscriptionCloseResponse; +use jsonrpsee::{Authority, SubscriptionCloseResponse}; use serde::Serialize; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; @@ -195,22 +195,26 @@ pub async fn server_with_sleeping_subscription(tx: futures::channel::mpsc::Sende #[allow(dead_code)] pub async fn server_with_health_api() -> (SocketAddr, ServerHandle) { - server_with_access_control(AllowHosts::Any, CorsLayer::new()).await + server_with_access_control(None, CorsLayer::new()).await } -pub async fn server_with_access_control(allowed_hosts: AllowHosts, cors: CorsLayer) -> (SocketAddr, ServerHandle) { +pub async fn server_with_access_control( + allowed_hosts: Option>, + cors: CorsLayer, +) -> (SocketAddr, ServerHandle) { let middleware = tower::ServiceBuilder::new() // Proxy `GET /health` requests to internal `system_health` method. .layer(ProxyGetRequestLayer::new("/health", "system_health").unwrap()) // Add `CORS` layer. .layer(cors); - let server = ServerBuilder::default() - .set_host_filtering(allowed_hosts) - .set_middleware(middleware) - .build("127.0.0.1:0") - .await - .unwrap(); + let mut builder = jsonrpsee::server::Server::builder(); + + if let Some(filter) = allowed_hosts { + builder = builder.host_filter(filter) + } + + let server = builder.set_middleware(middleware).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index c7cbee14c5..aed6a15493 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -915,7 +915,6 @@ async fn http_correct_content_type_required() { #[tokio::test] async fn http_cors_preflight_works() { use hyper::{Body, Client, Method, Request}; - use jsonrpsee::server::AllowHosts; init_logger(); @@ -923,7 +922,7 @@ async fn http_cors_preflight_works() { .allow_methods([Method::POST]) .allow_origin("https://foo.com".parse::().unwrap()) .allow_headers([hyper::header::CONTENT_TYPE]); - let (server_addr, _handle) = server_with_access_control(AllowHosts::Any, cors).await; + let (server_addr, _handle) = server_with_access_control(None, cors).await; let http_client = Client::new(); let uri = format!("http://{}", server_addr); @@ -1026,12 +1025,10 @@ async fn ws_host_filtering_wildcard_works() { init_logger(); - let acl = AllowHosts::Only( - vec![Authority::from_str("http://localhost:*").unwrap(), Authority::from_str("http://127.0.0.1:*").unwrap()] - .into(), - ); + let whitelist = + vec![Authority::from_str("http://localhost:*").unwrap(), Authority::from_str("http://127.0.0.1:*").unwrap()]; - let server = ServerBuilder::default().set_host_filtering(acl).build("127.0.0.1:0").await.unwrap(); + let server = ServerBuilder::default().host_filter(whitelist).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); @@ -1050,12 +1047,10 @@ async fn http_host_filtering_wildcard_works() { init_logger(); - let allowed_hosts = AllowHosts::Only( - vec![Authority::from_str("http://localhost:*").unwrap(), Authority::from_str("http://127.0.0.1:*").unwrap()] - .into(), - ); + let allowed_hosts = + vec![Authority::from_str("http://localhost:*").unwrap(), Authority::from_str("http://127.0.0.1:*").unwrap()]; - let server = ServerBuilder::default().set_host_filtering(allowed_hosts).build("127.0.0.1:0").await.unwrap(); + let server = ServerBuilder::default().host_filter(allowed_hosts).build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); @@ -1074,9 +1069,11 @@ async fn deny_invalid_host() { init_logger(); - let allowed_hosts = AllowHosts::Only(vec![Authority::from_str("http://example.com").unwrap()].into()); - - let server = ServerBuilder::default().set_host_filtering(allowed_hosts).build("127.0.0.1:0").await.unwrap(); + let server = ServerBuilder::default() + .host_filter(vec![Authority::from_str("http://example.com").unwrap()]) + .build("127.0.0.1:0") + .await + .unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); From 9a323338f47de7746b2a029ccebb9af1a1a37b37 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 19:38:21 +0200 Subject: [PATCH 10/14] address grumbles --- core/src/server/host_filtering.rs | 35 +++++++++++++++++++++---------- server/src/server.rs | 13 ++++++++---- tests/tests/helpers.rs | 6 +++--- tests/tests/integration_tests.rs | 28 ++++++++++++------------- 4 files changed, 50 insertions(+), 32 deletions(-) diff --git a/core/src/server/host_filtering.rs b/core/src/server/host_filtering.rs index 0fdd2d342b..d2c6b7188e 100644 --- a/core/src/server/host_filtering.rs +++ b/core/src/server/host_filtering.rs @@ -50,21 +50,33 @@ impl From for Port { /// Represent the http URI scheme that is returned by the HTTP host header /// -/// -/// -/// Further information can be found: https://www.rfc-editor.org/rfc/rfc7230#section-2.7.1 +/// Further information can be found: #[derive(Clone, Hash, PartialEq, Eq, Debug)] pub struct Authority { hostname: String, port: Port, } +/// Error that can happen when parsing an URI authority fails. +#[derive(Debug, thiserror::Error)] +pub enum AuthorityError { + /// Invalid URI. + #[error("{0}")] + InvalidUri(InvalidUri), + /// Invalid port. + #[error("{0}")] + InvalidPort(String), + /// The host was not found. + #[error("The host was not found")] + MissingHost, +} + impl FromStr for Authority { - type Err = String; + type Err = AuthorityError; fn from_str(s: &str) -> Result { - let uri: Uri = s.parse().map_err(|e: InvalidUri| e.to_string())?; - let authority = uri.authority().ok_or_else(|| "HTTP Host must contain authority".to_owned())?; + let uri: Uri = s.parse().map_err(|e: InvalidUri| AuthorityError::InvalidUri(e))?; + let authority = uri.authority().ok_or_else(|| AuthorityError::MissingHost)?; let hostname = authority.host(); let maybe_port = &authority.as_str()[hostname.len()..]; @@ -72,7 +84,8 @@ impl FromStr for Authority { let port = match maybe_port.split_once(':') { Some((_, "*")) => Port::Any, Some((_, p)) => { - let port_u16: u16 = p.parse().map_err(|e: std::num::ParseIntError| e.to_string())?; + let port_u16: u16 = + p.parse().map_err(|e: std::num::ParseIntError| AuthorityError::InvalidPort(e.to_string()))?; // Omit default port to allow both requests with and without the default port. match default_port(uri.scheme_str()) { @@ -89,9 +102,9 @@ impl FromStr for Authority { /// Represent the URL patterns that is whitelisted. #[derive(Default, Debug, Clone)] -pub struct UrlPattern(Router); +pub struct WhitelistedHosts(Router); -impl From for UrlPattern +impl From for WhitelistedHosts where T: IntoIterator, { @@ -106,7 +119,7 @@ where } } -impl UrlPattern { +impl WhitelistedHosts { fn recognize(&self, other: &Authority) -> bool { if let Ok(p) = self.0.recognize(&other.hostname) { let p = p.handler(); @@ -129,7 +142,7 @@ pub enum AllowHosts { /// Allow all hosts (no filter). Any, /// Allow only specified hosts. - Only(UrlPattern), + Only(WhitelistedHosts), } impl AllowHosts { diff --git a/server/src/server.rs b/server/src/server.rs index ac699866d3..a002f5bcb2 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -28,6 +28,7 @@ use std::error::Error as StdError; use std::future::Future; use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::pin::Pin; +use std::str::FromStr; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; @@ -42,7 +43,7 @@ use futures_util::io::{BufReader, BufWriter}; use hyper::body::HttpBody; use jsonrpsee_core::id_providers::RandomIntegerIdProvider; -use jsonrpsee_core::server::{AllowHosts, Authority, Methods, UrlPattern}; +use jsonrpsee_core::server::{AllowHosts, Authority, AuthorityError, Methods, WhitelistedHosts}; use jsonrpsee_core::traits::IdProvider; use jsonrpsee_core::{http_helpers, Error, TEN_MB_SIZE_BYTES}; @@ -422,9 +423,13 @@ impl Builder { /// Enables host filtering and allow only the specified hosts. /// /// Default: allow all. - pub fn host_filter>(mut self, allow_only: T) -> Self { - self.settings.allow_hosts = AllowHosts::Only(UrlPattern::from(allow_only.into_iter())); - self + pub fn host_filter, U: AsRef>( + mut self, + allow_only: T, + ) -> Result { + let allow_only: Result, _> = allow_only.into_iter().map(|a| Authority::from_str(a.as_ref())).collect(); + self.settings.allow_hosts = AllowHosts::Only(WhitelistedHosts::from(allow_only?)); + Ok(self) } /// Disable host filtering and allow all. diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 9f0c6fc520..26f9e48cf0 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -36,7 +36,7 @@ use jsonrpsee::server::{ PendingSubscriptionSink, RpcModule, ServerBuilder, ServerHandle, SubscriptionMessage, TrySendError, }; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; -use jsonrpsee::{Authority, SubscriptionCloseResponse}; +use jsonrpsee::SubscriptionCloseResponse; use serde::Serialize; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; @@ -199,7 +199,7 @@ pub async fn server_with_health_api() -> (SocketAddr, ServerHandle) { } pub async fn server_with_access_control( - allowed_hosts: Option>, + allowed_hosts: Option>, cors: CorsLayer, ) -> (SocketAddr, ServerHandle) { let middleware = tower::ServiceBuilder::new() @@ -211,7 +211,7 @@ pub async fn server_with_access_control( let mut builder = jsonrpsee::server::Server::builder(); if let Some(filter) = allowed_hosts { - builder = builder.host_filter(filter) + builder = builder.host_filter(filter).unwrap(); } let server = builder.set_middleware(middleware).build("127.0.0.1:0").await.unwrap(); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 4b8b9470eb..df6f30430d 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -29,7 +29,6 @@ mod helpers; -use std::str::FromStr; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::time::Duration; @@ -1024,10 +1023,12 @@ async fn ws_host_filtering_wildcard_works() { init_logger(); - let whitelist = - vec![Authority::from_str("http://localhost:*").unwrap(), Authority::from_str("http://127.0.0.1:*").unwrap()]; - - let server = ServerBuilder::default().host_filter(whitelist).build("127.0.0.1:0").await.unwrap(); + let server = ServerBuilder::default() + .host_filter(["http://localhost:*", "http://127.0.0.1:*"]) + .unwrap() + .build("127.0.0.1:0") + .await + .unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); @@ -1046,10 +1047,12 @@ async fn http_host_filtering_wildcard_works() { init_logger(); - let allowed_hosts = - vec![Authority::from_str("http://localhost:*").unwrap(), Authority::from_str("http://127.0.0.1:*").unwrap()]; - - let server = ServerBuilder::default().host_filter(allowed_hosts).build("127.0.0.1:0").await.unwrap(); + let server = ServerBuilder::default() + .host_filter(vec!["http://localhost:*", "http://127.0.0.1:*"]) + .unwrap() + .build("127.0.0.1:0") + .await + .unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); @@ -1068,11 +1071,8 @@ async fn deny_invalid_host() { init_logger(); - let server = ServerBuilder::default() - .host_filter(vec![Authority::from_str("http://example.com").unwrap()]) - .build("127.0.0.1:0") - .await - .unwrap(); + let server = + ServerBuilder::default().host_filter(["http://example.com"]).unwrap().build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); let addr = server.local_addr().unwrap(); module.register_method("say_hello", |_, _| "hello").unwrap(); From b9ccf12ddd9ca5b7318283dda46bd388300ccd94 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 8 Aug 2023 19:43:42 +0200 Subject: [PATCH 11/14] fix clippy --- core/src/server/host_filtering.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/server/host_filtering.rs b/core/src/server/host_filtering.rs index d2c6b7188e..2b9f1862ba 100644 --- a/core/src/server/host_filtering.rs +++ b/core/src/server/host_filtering.rs @@ -75,8 +75,8 @@ impl FromStr for Authority { type Err = AuthorityError; fn from_str(s: &str) -> Result { - let uri: Uri = s.parse().map_err(|e: InvalidUri| AuthorityError::InvalidUri(e))?; - let authority = uri.authority().ok_or_else(|| AuthorityError::MissingHost)?; + let uri: Uri = s.parse().map_err(AuthorityError::InvalidUri)?; + let authority = uri.authority().ok_or(AuthorityError::MissingHost)?; let hostname = authority.host(); let maybe_port = &authority.as_str()[hostname.len()..]; From 14f15aa54fe7838ae07473d9adcbf472b14cbc3b Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 9 Aug 2023 10:18:29 +0200 Subject: [PATCH 12/14] host filter: switch to TryFrom --- core/src/server/host_filtering.rs | 105 ++++++++++++++++++------------ server/src/server.rs | 11 ++-- tests/tests/integration_tests.rs | 2 +- 3 files changed, 72 insertions(+), 46 deletions(-) diff --git a/core/src/server/host_filtering.rs b/core/src/server/host_filtering.rs index 2b9f1862ba..0270196c14 100644 --- a/core/src/server/host_filtering.rs +++ b/core/src/server/host_filtering.rs @@ -26,10 +26,11 @@ //! HTTP Host Header validation. +use std::net::SocketAddr; + use crate::Error; use http::uri::{InvalidUri, Uri}; use route_recognizer::Router; -use std::str::FromStr; /// Port pattern #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)] @@ -57,25 +58,9 @@ pub struct Authority { port: Port, } -/// Error that can happen when parsing an URI authority fails. -#[derive(Debug, thiserror::Error)] -pub enum AuthorityError { - /// Invalid URI. - #[error("{0}")] - InvalidUri(InvalidUri), - /// Invalid port. - #[error("{0}")] - InvalidPort(String), - /// The host was not found. - #[error("The host was not found")] - MissingHost, -} - -impl FromStr for Authority { - type Err = AuthorityError; - - fn from_str(s: &str) -> Result { - let uri: Uri = s.parse().map_err(AuthorityError::InvalidUri)?; +impl Authority { + fn inner_from_str(value: &str) -> Result { + let uri: Uri = value.parse().map_err(AuthorityError::InvalidUri)?; let authority = uri.authority().ok_or(AuthorityError::MissingHost)?; let hostname = authority.host(); let maybe_port = &authority.as_str()[hostname.len()..]; @@ -100,6 +85,44 @@ impl FromStr for Authority { } } +/// Error that can happen when parsing an URI authority fails. +#[derive(Debug, thiserror::Error)] +pub enum AuthorityError { + /// Invalid URI. + #[error("{0}")] + InvalidUri(InvalidUri), + /// Invalid port. + #[error("{0}")] + InvalidPort(String), + /// The host was not found. + #[error("The host was not found")] + MissingHost, +} + +impl<'a> TryFrom<&'a str> for Authority { + type Error = AuthorityError; + + fn try_from(value: &'a str) -> Result { + Self::inner_from_str(value) + } +} + +impl TryFrom for Authority { + type Error = AuthorityError; + + fn try_from(value: String) -> Result { + Self::inner_from_str(&value) + } +} + +impl TryFrom for Authority { + type Error = AuthorityError; + + fn try_from(sockaddr: SocketAddr) -> Result { + Self::inner_from_str(&sockaddr.to_string()) + } +} + /// Represent the URL patterns that is whitelisted. #[derive(Default, Debug, Clone)] pub struct WhitelistedHosts(Router); @@ -148,7 +171,7 @@ pub enum AllowHosts { impl AllowHosts { /// Verify a host. pub fn verify(&self, value: &str) -> Result<(), Error> { - let auth = Authority::from_str(value) + let auth = Authority::try_from(value) .map_err(|_| Error::HttpHeaderRejected("host", format!("Invalid authority: {value}")))?; if let AllowHosts::Only(url_pat) = self { @@ -173,7 +196,6 @@ fn default_port(scheme: Option<&str>) -> Option { #[cfg(test)] mod tests { use super::{AllowHosts, Authority, Port}; - use std::str::FromStr; fn authority(host: &str, port: Port) -> Authority { Authority { hostname: host.to_owned(), port } @@ -181,31 +203,32 @@ mod tests { #[test] fn should_parse_valid_authority() { - assert_eq!(Authority::from_str("http://parity.io").unwrap(), authority("parity.io", Port::Default)); - assert_eq!(Authority::from_str("https://parity.io:8443").unwrap(), authority("parity.io", Port::Fixed(8443))); - assert_eq!(Authority::from_str("chrome-extension://124.0.0.1").unwrap(), authority("124.0.0.1", Port::Default)); - assert_eq!(Authority::from_str("http://*.domain:*/somepath").unwrap(), authority("*.domain", Port::Any)); - assert_eq!(Authority::from_str("parity.io").unwrap(), authority("parity.io", Port::Default)); + assert_eq!(Authority::try_from("http://parity.io").unwrap(), authority("parity.io", Port::Default)); + assert_eq!(Authority::try_from("https://parity.io:8443").unwrap(), authority("parity.io", Port::Fixed(8443))); + assert_eq!(Authority::try_from("chrome-extension://124.0.0.1").unwrap(), authority("124.0.0.1", Port::Default)); + assert_eq!(Authority::try_from("http://*.domain:*/somepath").unwrap(), authority("*.domain", Port::Any)); + assert_eq!(Authority::try_from("parity.io").unwrap(), authority("parity.io", Port::Default)); + assert_eq!(Authority::try_from("127.0.0.1:8845").unwrap(), authority("127.0.0.1", Port::Fixed(8845))); assert_eq!( - Authority::from_str("http://[2001:db8:85a3:8d3:1319:8a2e:370:7348]:9933/").unwrap(), + Authority::try_from("http://[2001:db8:85a3:8d3:1319:8a2e:370:7348]:9933/").unwrap(), authority("[2001:db8:85a3:8d3:1319:8a2e:370:7348]", Port::Fixed(9933)) ); assert_eq!( - Authority::from_str("http://[2001:db8:85a3:8d3:1319:8a2e:370:7348]/").unwrap(), + Authority::try_from("http://[2001:db8:85a3:8d3:1319:8a2e:370:7348]/").unwrap(), authority("[2001:db8:85a3:8d3:1319:8a2e:370:7348]", Port::Default) ); assert_eq!( - Authority::from_str("https://user:password@example.com/tmp/foo").unwrap(), + Authority::try_from("https://user:password@example.com/tmp/foo").unwrap(), authority("example.com", Port::Default) ); } #[test] fn should_not_parse_invalid_authority() { - assert!(Authority::from_str("/foo/bar").is_err()); - assert!(Authority::from_str("user:password").is_err()); - assert!(Authority::from_str("parity.io/somepath").is_err()); - assert!(Authority::from_str("127.0.0.1:8545/somepath").is_err()); + assert!(Authority::try_from("/foo/bar").is_err()); + assert!(Authority::try_from("user:password").is_err()); + assert!(Authority::try_from("parity.io/somepath").is_err()); + assert!(Authority::try_from("127.0.0.1:8545/somepath").is_err()); } #[test] @@ -220,36 +243,36 @@ mod tests { #[test] fn should_accept_if_on_the_list() { - assert!(AllowHosts::Only(vec![Authority::from_str("parity.io").unwrap()].into()).verify("parity.io").is_ok()); + assert!(AllowHosts::Only(vec![Authority::try_from("parity.io").unwrap()].into()).verify("parity.io").is_ok()); } #[test] fn should_accept_if_on_the_list_with_port() { - assert!((AllowHosts::Only(vec![Authority::from_str("parity.io:443").unwrap()].into())) + assert!((AllowHosts::Only(vec![Authority::try_from("parity.io:443").unwrap()].into())) .verify("parity.io:443") .is_ok()); - assert!(AllowHosts::Only(vec![Authority::from_str("parity.io").unwrap()].into()) + assert!(AllowHosts::Only(vec![Authority::try_from("parity.io").unwrap()].into()) .verify("parity.io:443") .is_err()); } #[test] fn should_support_wildcards() { - assert!((AllowHosts::Only(vec![Authority::from_str("*.web3.site:*").unwrap()].into())) + assert!((AllowHosts::Only(vec![Authority::try_from("*.web3.site:*").unwrap()].into())) .verify("parity.web3.site:8180") .is_ok()); - assert!((AllowHosts::Only(vec![Authority::from_str("*.web3.site:*").unwrap()].into())) + assert!((AllowHosts::Only(vec![Authority::try_from("*.web3.site:*").unwrap()].into())) .verify("parity.web3.site") .is_ok()); } #[test] fn should_accept_with_and_without_default_port() { - assert!(AllowHosts::Only(vec![Authority::from_str("https://parity.io:443").unwrap()].into()) + assert!(AllowHosts::Only(vec![Authority::try_from("https://parity.io:443").unwrap()].into()) .verify("https://parity.io") .is_ok()); - assert!(AllowHosts::Only(vec![Authority::from_str("https://parity.io").unwrap()].into()) + assert!(AllowHosts::Only(vec![Authority::try_from("https://parity.io").unwrap()].into()) .verify("https://parity.io:443") .is_ok()); } diff --git a/server/src/server.rs b/server/src/server.rs index a002f5bcb2..11a505d6a8 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -28,7 +28,6 @@ use std::error::Error as StdError; use std::future::Future; use std::net::{SocketAddr, TcpListener as StdTcpListener}; use std::pin::Pin; -use std::str::FromStr; use std::sync::Arc; use std::task::{Context, Poll}; use std::time::Duration; @@ -423,11 +422,15 @@ impl Builder { /// Enables host filtering and allow only the specified hosts. /// /// Default: allow all. - pub fn host_filter, U: AsRef>( + pub fn host_filter, U: TryInto>( mut self, allow_only: T, - ) -> Result { - let allow_only: Result, _> = allow_only.into_iter().map(|a| Authority::from_str(a.as_ref())).collect(); + ) -> Result + where + T: IntoIterator, + U: TryInto, + { + let allow_only: Result, _> = allow_only.into_iter().map(|a| a.try_into()).collect(); self.settings.allow_hosts = AllowHosts::Only(WhitelistedHosts::from(allow_only?)); Ok(self) } diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index df6f30430d..855a349ac4 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -1024,7 +1024,7 @@ async fn ws_host_filtering_wildcard_works() { init_logger(); let server = ServerBuilder::default() - .host_filter(["http://localhost:*", "http://127.0.0.1:*"]) + .host_filter(["http://localhost:*".to_string(), "http://127.0.0.1:*".to_string()]) .unwrap() .build("127.0.0.1:0") .await From 809809c2635510aef28d5cef1c6d35d70a938774 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 9 Aug 2023 14:10:33 +0200 Subject: [PATCH 13/14] Update server/src/server.rs --- server/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/server.rs b/server/src/server.rs index 11a505d6a8..0fd5236a94 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -421,7 +421,7 @@ impl Builder { /// Enables host filtering and allow only the specified hosts. /// - /// Default: allow all. + /// Default: no host filtering is enabled. pub fn host_filter, U: TryInto>( mut self, allow_only: T, From 500b3f33cd6ecc01af3fbc9aa51d320989ee17a5 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 9 Aug 2023 14:11:12 +0200 Subject: [PATCH 14/14] Update server/src/server.rs --- server/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/server.rs b/server/src/server.rs index 0fd5236a94..85135a531c 100644 --- a/server/src/server.rs +++ b/server/src/server.rs @@ -437,7 +437,7 @@ impl Builder { /// Disable host filtering and allow all. /// - /// Default: allow all. + /// Default: no host filtering is enabled. pub fn disable_host_filtering(mut self) -> Self { self.settings.allow_hosts = AllowHosts::Any; self