From aa62ac821c7efed75d34e3b665c6f1e6b34df112 Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Wed, 11 Oct 2023 10:59:03 +0200 Subject: [PATCH] feat(middleware): add `HostFilterLater::disable` --- server/src/middleware/host_filter.rs | 34 ++++++++++++++++++++++++---- tests/tests/integration_tests.rs | 29 ++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/server/src/middleware/host_filter.rs b/server/src/middleware/host_filter.rs index 417582778a..1c871bb0c0 100644 --- a/server/src/middleware/host_filter.rs +++ b/server/src/middleware/host_filter.rs @@ -39,7 +39,7 @@ use tower::{Layer, Service}; /// Middleware to enable host filtering. #[derive(Debug)] -pub struct HostFilterLayer(Arc); +pub struct HostFilterLayer(Option>); impl HostFilterLayer { /// Enables host filtering and allow only the specified hosts. @@ -49,7 +49,33 @@ impl HostFilterLayer { U: TryInto, { let allow_only: Result, _> = allow_only.into_iter().map(|a| a.try_into()).collect(); - Ok(Self(Arc::new(WhitelistedHosts::from(allow_only?)))) + Ok(Self(Some(Arc::new(WhitelistedHosts::from(allow_only?))))) + } + + /// Convenience method to disable host filtering but less efficient + /// than to not enable the middleware at all. + /// + /// Because is the `tower middleware` returns a different type + /// depending on which Layers are configured it and may not compile + /// in some contexts. + /// + /// For example the following won't compile: + /// + /// ```ignore + /// use jsonrpsee_server::middleware::{ProxyGetRequestLayer, HostFilterLayer}; + /// + /// let host_filter = false; + /// + /// let middleware = if host_filter { + /// tower::ServiceBuilder::new() + /// .layer(HostFilterLayer::new(["example.com"]).unwrap()) + /// .layer(ProxyGetRequestLayer::new("/health", "system_health").unwrap()) + /// } else { + /// tower::ServiceBuilder::new() + /// }; + /// ``` + pub fn disable() -> Self { + Self(None) } } @@ -65,7 +91,7 @@ impl Layer for HostFilterLayer { #[derive(Debug)] pub struct HostFilter { inner: S, - filter: Arc, + filter: Option>, } impl Service> for HostFilter @@ -88,7 +114,7 @@ where return async { Ok(http::response::malformed()) }.boxed(); }; - if self.filter.recognize(&authority) { + if self.filter.as_ref().map_or(true, |f| f.recognize(&authority)) { Box::pin(self.inner.call(request).map_err(Into::into)) } else { tracing::debug!("Denied request: {:?}", request); diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 80235f968c..6cd2351f12 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -1078,6 +1078,35 @@ async fn deny_invalid_host() { } } +#[tokio::test] +async fn disable_host_filter_works() { + use jsonrpsee::server::*; + + init_logger(); + + let middleware = tower::ServiceBuilder::new().layer(HostFilterLayer::disable()); + + let server = Server::builder().set_middleware(middleware).build("127.0.0.1:0").await.unwrap(); + let mut module = RpcModule::new(()); + let addr = server.local_addr().unwrap(); + module.register_method("say_hello", |_, _| "hello").unwrap(); + + let _handle = server.start(module); + + // HTTP + { + let server_url = format!("http://{}", addr); + let client = HttpClientBuilder::default().build(&server_url).unwrap(); + assert!(client.request::("say_hello", rpc_params![]).await.is_ok()); + } + + // WebSocket + { + let server_url = format!("ws://{}", addr); + assert!(WsClientBuilder::default().build(&server_url).await.is_ok()); + } +} + #[tokio::test] async fn subscription_option_err_is_not_sent() { init_logger();