From 4ee65c4424274ccd1367a907c1bc94cfa77aa54b Mon Sep 17 00:00:00 2001 From: Leonardo Lima Date: Wed, 2 Aug 2023 23:31:46 -0300 Subject: [PATCH 1/5] refactor(jsonrpsee-client-transport): update and turn `WsTransportClientBuilder` generic - fix(docs): typo on `WsTransportClientBuilder` doc on `Receiver` reference. - refactor: add initial fn signatures for `build_with_stream`, `try_connect_over_tcp`, and `try_connect`. - refactor: expose `EitherStream` visibility to public. - refactor: make `Sender` and `Receiver` generic over T, a data stream. - refactor: make `TransportSenderT` and `TransportReceiverT` implementations over generic `Sender` and `Receiver`, bound to `AsyncRead`, `AsyncRead`, `MaybeSend` and `'static`. - refactor: turn old `try_connect` TCP steps into `try_connect_over_tcp`. - feat: implement `build_with_stream` and `try_connect` to handle and handle the handshake for a generic data stream `T`. - feat: add new `Redirected` error variant to `WsHandshakeError`, as it should be handled by the client when using a generic data stream `T`. - TODO(@oleonardolima): Add new tests that uses a different data stream. --- client/transport/src/ws/mod.rs | 138 ++++++++++++++++++++---------- client/transport/src/ws/stream.rs | 2 +- 2 files changed, 96 insertions(+), 44 deletions(-) diff --git a/client/transport/src/ws/mod.rs b/client/transport/src/ws/mod.rs index 0296ca1228..bf973ee173 100644 --- a/client/transport/src/ws/mod.rs +++ b/client/transport/src/ws/mod.rs @@ -31,7 +31,8 @@ use std::net::SocketAddr; use std::time::Duration; use futures_util::io::{BufReader, BufWriter}; -use jsonrpsee_core::client::{CertificateStore, ReceivedMessage, TransportReceiverT, TransportSenderT}; +use futures_util::{AsyncRead, AsyncWrite}; +use jsonrpsee_core::client::{CertificateStore, MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT}; use jsonrpsee_core::TEN_MB_SIZE_BYTES; use jsonrpsee_core::{async_trait, Cow}; use soketto::connection::Error::Utf8; @@ -48,18 +49,18 @@ pub use url::Url; /// Sending end of WebSocket transport. #[derive(Debug)] -pub struct Sender { - inner: connection::Sender>>, +pub struct Sender { + inner: connection::Sender>>, max_request_size: u32, } /// Receiving end of WebSocket transport. #[derive(Debug)] -pub struct Receiver { - inner: connection::Receiver>>, +pub struct Receiver { + inner: connection::Receiver>>, } -/// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair. +/// Builder for a WebSocket transport [`Sender`] and [`Receiver`] pair. #[derive(Debug)] pub struct WsTransportClientBuilder { /// What certificate store to use @@ -190,6 +191,15 @@ pub enum WsHandshakeError { status_code: u16, }, + /// Server redirected to other location. + #[error("Connection redirected with status code: {status_code} and location: {location}")] + Redirected { + /// HTTP status code that the server returned. + status_code: u16, + /// The location URL redirected to. + location: String, + }, + /// Timeout while trying to connect. #[error("Connection timeout exceeded: {0:?}")] Timeout(Duration), @@ -215,7 +225,10 @@ pub enum WsError { } #[async_trait] -impl TransportSenderT for Sender { +impl TransportSenderT for Sender +where + T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static, +{ type Error = WsError; /// Sends out a request. Returns a `Future` that finishes when the request has been @@ -252,7 +265,10 @@ impl TransportSenderT for Sender { } #[async_trait] -impl TransportReceiverT for Receiver { +impl TransportReceiverT for Receiver +where + T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static, +{ type Error = WsError; /// Returns a `Future` resolving when the server sent us something back. @@ -276,12 +292,18 @@ impl TransportReceiverT for Receiver { impl WsTransportClientBuilder { /// Try to establish the connection. - pub async fn build(self, uri: Url) -> Result<(Sender, Receiver), WsHandshakeError> { - let target: Target = uri.try_into()?; - self.try_connect(target).await + /// + /// Uses the default connection over TCP. + pub async fn build(self, uri: Url) -> Result<(Sender, Receiver), WsHandshakeError> { + self.try_connect_over_tcp(uri).await } - async fn try_connect(self, mut target: Target) -> Result<(Sender, Receiver), WsHandshakeError> { + // Try to establish the connection over TCP. + async fn try_connect_over_tcp( + &self, + uri: Url, + ) -> Result<(Sender, Receiver), WsHandshakeError> { + let mut target: Target = uri.try_into()?; let mut err = None; // Only build TLS connector if `wss` in URL. @@ -317,37 +339,10 @@ impl WsTransportClientBuilder { } }; - let mut client = WsHandshakeClient::new( - BufReader::new(BufWriter::new(tcp_stream)), - &target.host_header, - &target.path_and_query, - ); - - let headers: Vec<_> = self - .headers - .iter() - .map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() }) - .collect(); - client.set_headers(&headers); - - // Perform the initial handshake. - match client.handshake().await { - Ok(ServerResponse::Accepted { .. }) => { - tracing::debug!("Connection established to target: {:?}", target); - let mut builder = client.into_builder(); - builder.set_max_message_size(self.max_response_size as usize); - let (sender, receiver) = builder.finish(); - return Ok(( - Sender { inner: sender, max_request_size: self.max_request_size }, - Receiver { inner: receiver }, - )); - } + match self.try_connect(&target, tcp_stream).await { + Ok(result) => return Ok(result), - Ok(ServerResponse::Rejected { status_code }) => { - tracing::debug!("Connection rejected: {:?}", status_code); - err = Some(Err(WsHandshakeError::Rejected { status_code })); - } - Ok(ServerResponse::Redirect { status_code, location }) => { + Err(WsHandshakeError::Redirected { status_code, location }) => { tracing::debug!("Redirection: status_code: {}, location: {}", status_code, location); match Url::parse(&location) { // redirection with absolute path => need to lookup. @@ -396,14 +391,71 @@ impl WsTransportClientBuilder { } }; } + Err(e) => { - err = Some(Err(e.into())); + err = Some(Err(e)); } }; } } err.unwrap_or(Err(WsHandshakeError::NoAddressFound(target.host))) } + + /// Try to establish the connection over the given data stream. + pub async fn build_with_stream( + self, + uri: Url, + data_stream: T, + ) -> Result<(Sender, Receiver), WsHandshakeError> + where + T: AsyncRead + AsyncWrite + Unpin, + { + let target: Target = uri.try_into()?; + self.try_connect(&target, data_stream).await + } + + /// Try to establish the handshake over the given data stream. + async fn try_connect( + &self, + target: &Target, + data_stream: T, + ) -> Result<(Sender, Receiver), WsHandshakeError> + where + T: AsyncRead + AsyncWrite + Unpin, + { + let mut client = WsHandshakeClient::new( + BufReader::new(BufWriter::new(data_stream)), + &target.host_header, + &target.path_and_query, + ); + + let headers: Vec<_> = + self.headers.iter().map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() }).collect(); + client.set_headers(&headers); + + // Perform the initial handshake. + match client.handshake().await { + Ok(ServerResponse::Accepted { .. }) => { + tracing::debug!("Connection established to target: {:?}", target); + let mut builder = client.into_builder(); + builder.set_max_message_size(self.max_response_size as usize); + let (sender, receiver) = builder.finish(); + Ok((Sender { inner: sender, max_request_size: self.max_request_size }, Receiver { inner: receiver })) + } + + Ok(ServerResponse::Rejected { status_code }) => { + tracing::debug!("Connection rejected: {:?}", status_code); + Err(WsHandshakeError::Rejected { status_code }) + } + + Ok(ServerResponse::Redirect { status_code, location }) => { + tracing::debug!("Redirection: status_code: {}, location: {}", status_code, location); + Err(WsHandshakeError::Redirected { status_code, location }) + } + + Err(e) => Err(e.into()), + } + } } #[cfg(feature = "__tls")] diff --git a/client/transport/src/ws/stream.rs b/client/transport/src/ws/stream.rs index 84bda9cbf8..c9d7d40b72 100644 --- a/client/transport/src/ws/stream.rs +++ b/client/transport/src/ws/stream.rs @@ -41,7 +41,7 @@ use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; #[pin_project(project = EitherStreamProj)] #[derive(Debug)] #[allow(clippy::large_enum_variant)] -pub(crate) enum EitherStream { +pub enum EitherStream { /// Unencrypted socket stream. Plain(#[pin] TcpStream), /// Encrypted socket stream. From 138067ed3940ef23af4ed413832eb7c090d7112e Mon Sep 17 00:00:00 2001 From: Leonardo Lima Date: Thu, 19 Oct 2023 22:03:45 -0300 Subject: [PATCH 2/5] refactor(jsonrpsee-ws-client): add new fns to `WsClientBuilder` - feat: add new `WsClientBuilder::build_with_transport` that builds and returns a `WsClient` with the given `Sender` and `Receiver`. - feat: add new `WsClientBuilder::build_with_stream` that uses the new `WsTransportClientBuilder::build_with_stream`, building and returning the `WsClient` with the given `data_stream` as transport layer. - refactor: update the `WsClientBuilder::build` to use the new `build_with_transport`, it helps not having duplicated code. --- client/transport/src/ws/mod.rs | 2 +- client/ws-client/src/lib.rs | 83 ++++++++++++++++++++++++---------- 2 files changed, 61 insertions(+), 24 deletions(-) diff --git a/client/transport/src/ws/mod.rs b/client/transport/src/ws/mod.rs index bf973ee173..1d5e3fe83f 100644 --- a/client/transport/src/ws/mod.rs +++ b/client/transport/src/ws/mod.rs @@ -31,7 +31,7 @@ use std::net::SocketAddr; use std::time::Duration; use futures_util::io::{BufReader, BufWriter}; -use futures_util::{AsyncRead, AsyncWrite}; +pub use futures_util::{AsyncRead, AsyncWrite}; use jsonrpsee_core::client::{CertificateStore, MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT}; use jsonrpsee_core::TEN_MB_SIZE_BYTES; use jsonrpsee_core::{async_trait, Cow}; diff --git a/client/ws-client/src/lib.rs b/client/ws-client/src/lib.rs index 5b89b7353d..cb0b81d843 100644 --- a/client/ws-client/src/lib.rs +++ b/client/ws-client/src/lib.rs @@ -45,8 +45,10 @@ pub use http::{HeaderMap, HeaderValue}; use std::time::Duration; use url::Url; -use jsonrpsee_client_transport::ws::WsTransportClientBuilder; -use jsonrpsee_core::client::{CertificateStore, ClientBuilder, IdKind}; +use jsonrpsee_client_transport::ws::{AsyncRead, AsyncWrite, WsTransportClientBuilder}; +use jsonrpsee_core::client::{ + CertificateStore, ClientBuilder, IdKind, MaybeSend, TransportReceiverT, TransportSenderT, +}; use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES}; /// Builder for [`WsClient`]. @@ -213,40 +215,26 @@ impl WsClientBuilder { self } - /// Build the client with specified URL to connect to. - /// You must provide the port number in the URL. + /// Build the [`WsClient`] with specified [`TransportSenderT`] [`TransportReceiverT`] parameters /// /// ## Panics /// /// Panics if being called outside of `tokio` runtime context. - pub async fn build(self, url: impl AsRef) -> Result { + pub async fn build_with_transport(self, sender: S, receiver: R) -> Result + where + S: TransportSenderT + Send, + R: TransportReceiverT + Send, + { let Self { - certificate_store, max_concurrent_requests, - max_request_size, - max_response_size, request_timeout, - connection_timeout, ping_interval, - headers, - max_redirections, max_buffer_capacity_per_subscription, id_kind, max_log_length, + .. } = self; - let transport_builder = WsTransportClientBuilder { - certificate_store, - connection_timeout, - headers, - max_request_size, - max_response_size, - max_redirections, - }; - - let uri = Url::parse(url.as_ref()).map_err(|e| Error::Transport(e.into()))?; - let (sender, receiver) = transport_builder.build(uri).await.map_err(|e| Error::Transport(e.into()))?; - let mut client = ClientBuilder::default() .max_buffer_capacity_per_subscription(max_buffer_capacity_per_subscription) .request_timeout(request_timeout) @@ -260,4 +248,53 @@ impl WsClientBuilder { Ok(client.build_with_tokio(sender, receiver)) } + + /// Build the [`WsClient`] with specified data stream, using [`WsTransportClientBuilder::build_with_stream`]. + /// + /// ## Panics + /// + /// Panics if being called outside of `tokio` runtime context. + pub async fn build_with_stream(self, url: impl AsRef, data_stream: T) -> Result + where + T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static, + { + let transport_builder = WsTransportClientBuilder { + certificate_store: self.certificate_store, + connection_timeout: self.connection_timeout, + headers: self.headers.clone(), + max_request_size: self.max_request_size, + max_response_size: self.max_response_size, + max_redirections: self.max_redirections, + }; + + let uri = Url::parse(url.as_ref()).map_err(|e| Error::Transport(e.into()))?; + let (sender, receiver) = + transport_builder.build_with_stream(uri, data_stream).await.map_err(|e| Error::Transport(e.into()))?; + + let ws_client = self.build_with_transport(sender, receiver).await?; + Ok(ws_client) + } + + /// Build the [`WsClient`] with specified URL to connect to, using the default + /// [`WsTransportClientBuilder::build_with_stream`], therefore with the default TCP as transport layer. + /// + /// ## Panics + /// + /// Panics if being called outside of `tokio` runtime context. + pub async fn build(self, url: impl AsRef) -> Result { + let transport_builder = WsTransportClientBuilder { + certificate_store: self.certificate_store, + connection_timeout: self.connection_timeout, + headers: self.headers.clone(), + max_request_size: self.max_request_size, + max_response_size: self.max_response_size, + max_redirections: self.max_redirections, + }; + + let uri = Url::parse(url.as_ref()).map_err(|e| Error::Transport(e.into()))?; + let (sender, receiver) = transport_builder.build(uri).await.map_err(|e| Error::Transport(e.into()))?; + + let ws_client = self.build_with_transport(sender, receiver).await?; + Ok(ws_client) + } } From 10e78a045cc07d5b39f8609f13a7ace8284e741f Mon Sep 17 00:00:00 2001 From: Leonardo Lima Date: Sat, 11 Nov 2023 16:20:24 -0300 Subject: [PATCH 3/5] refactor: re-export `EitherStream` & sort current list --- client/transport/src/ws/mod.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/client/transport/src/ws/mod.rs b/client/transport/src/ws/mod.rs index 1d5e3fe83f..dda61029eb 100644 --- a/client/transport/src/ws/mod.rs +++ b/client/transport/src/ws/mod.rs @@ -30,21 +30,21 @@ use std::io; use std::net::SocketAddr; use std::time::Duration; -use futures_util::io::{BufReader, BufWriter}; pub use futures_util::{AsyncRead, AsyncWrite}; +use futures_util::io::{BufReader, BufWriter}; +use jsonrpsee_core::{async_trait, Cow}; use jsonrpsee_core::client::{CertificateStore, MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT}; use jsonrpsee_core::TEN_MB_SIZE_BYTES; -use jsonrpsee_core::{async_trait, Cow}; +use soketto::{connection, Data, Incoming}; use soketto::connection::Error::Utf8; use soketto::data::ByteSlice125; use soketto::handshake::client::{Client as WsHandshakeClient, ServerResponse}; -use soketto::{connection, Data, Incoming}; -use stream::EitherStream; use thiserror::Error; use tokio::net::TcpStream; pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri}; pub use soketto::handshake::client::Header; +pub use stream::EitherStream; pub use url::Url; /// Sending end of WebSocket transport. From 53641d1ec73c0f4a17191cd7f32b52f646226895 Mon Sep 17 00:00:00 2001 From: Leonardo Lima Date: Tue, 14 Nov 2023 21:27:11 -0300 Subject: [PATCH 4/5] test: add integration tests and helper fns - add new helper fns to spawn a socks5 server, using `fast-socks5` - add new helper enum for `DataStream` that acts as a wrapper to `Socks5Stream`, similar to what a client would need to do. - impl AsyncRead + AsyncWrite for the helper `DataStream` enum, to make it compatible between futures::io and tokio::io. - add new tests that connects over a socks5 proxy, and use the new `WsClientBuilder::default()::build_with_stream(...) fn. --- client/transport/src/ws/mod.rs | 6 +- tests/Cargo.toml | 14 ++-- tests/tests/helpers.rs | 116 +++++++++++++++++++++++++++- tests/tests/integration_tests.rs | 126 ++++++++++++++++++++++++++++++- 4 files changed, 249 insertions(+), 13 deletions(-) diff --git a/client/transport/src/ws/mod.rs b/client/transport/src/ws/mod.rs index dda61029eb..9f60e901c2 100644 --- a/client/transport/src/ws/mod.rs +++ b/client/transport/src/ws/mod.rs @@ -30,15 +30,15 @@ use std::io; use std::net::SocketAddr; use std::time::Duration; -pub use futures_util::{AsyncRead, AsyncWrite}; use futures_util::io::{BufReader, BufWriter}; -use jsonrpsee_core::{async_trait, Cow}; +pub use futures_util::{AsyncRead, AsyncWrite}; use jsonrpsee_core::client::{CertificateStore, MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT}; use jsonrpsee_core::TEN_MB_SIZE_BYTES; -use soketto::{connection, Data, Incoming}; +use jsonrpsee_core::{async_trait, Cow}; use soketto::connection::Error::Utf8; use soketto::data::ByteSlice125; use soketto::handshake::client::{Client as WsHandshakeClient, ServerResponse}; +use soketto::{connection, Data, Incoming}; use thiserror::Error; use tokio::net::TcpStream; diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 05ddca6a75..75d14fac0a 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -10,15 +10,19 @@ publish = false [dev-dependencies] anyhow = "1" beef = { version = "0.5.1", features = ["impl_serde"] } +fast-socks5 = { version = "0.9.1" } futures = { version = "0.3.14", default-features = false, features = ["std"] } +futures-util = { version = "0.3.14", default-features = false, features = ["alloc"]} +hyper = { version = "0.14", features = ["http1", "client"] } jsonrpsee = { path = "../jsonrpsee", features = ["server", "client-core", "http-client", "ws-client", "macros"] } jsonrpsee-test-utils = { path = "../test-utils" } -tokio = { version = "1.16", features = ["full"] } -tracing = "0.1.34" serde = "1" serde_json = "1" -hyper = { version = "0.14", features = ["http1", "client"] } -tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } +tokio = { version = "1.16", features = ["full"] } tokio-stream = "0.1" -tower-http = { version = "0.4.0", features = ["full"] } +tokio-util = { version = "0.7", features = ["compat"]} tower = { version = "0.4.13", features = ["full"] } +tower-http = { version = "0.4.0", features = ["full"] } +tracing = "0.1.34" +tracing-subscriber = { version = "0.3.3", features = ["env-filter"] } +pin-project = { version = "1" } diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 8c754e4a30..1d1e02b9a6 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -29,7 +29,9 @@ use std::net::SocketAddr; use std::time::Duration; -use futures::{SinkExt, Stream, StreamExt}; +use fast_socks5::client::Socks5Stream; +use fast_socks5::server; +use futures::{AsyncRead, AsyncWrite, SinkExt, Stream, StreamExt}; use jsonrpsee::core::Error; use jsonrpsee::server::middleware::http::ProxyGetRequestLayer; use jsonrpsee::server::{ @@ -37,9 +39,12 @@ use jsonrpsee::server::{ }; use jsonrpsee::types::{ErrorObject, ErrorObjectOwned}; use jsonrpsee::SubscriptionCloseResponse; +use pin_project::pin_project; use serde::Serialize; +use tokio::net::TcpStream; use tokio::time::interval; use tokio_stream::wrappers::IntervalStream; +use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use tower_http::cors::CorsLayer; #[allow(dead_code)] @@ -249,3 +254,112 @@ pub async fn pipe_from_stream_and_drop( } } } + +#[allow(dead_code)] +pub async fn socks_server_no_auth() -> SocketAddr { + let mut config = server::Config::default(); + config.set_dns_resolve(false); + let config = std::sync::Arc::new(config); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let proxy_addr = listener.local_addr().unwrap(); + + spawn_socks_server(listener, config).await; + + proxy_addr +} + +#[allow(dead_code)] +pub async fn spawn_socks_server(listener: tokio::net::TcpListener, config: std::sync::Arc) { + let addr = listener.local_addr().unwrap(); + tokio::spawn(async move { + loop { + let (stream, _) = listener.accept().await.unwrap(); + let mut socks5_socket = server::Socks5Socket::new(stream, config.clone()); + socks5_socket.set_reply_ip(addr.ip()); + + socks5_socket.upgrade_to_socks5().await.unwrap(); + } + }); +} + +#[allow(dead_code)] +pub async fn connect_over_socks_stream(server_addr: SocketAddr) -> Socks5Stream { + let target_addr = server_addr.ip().to_string(); + let target_port = server_addr.port(); + + let socks_server = socks_server_no_auth().await; + + fast_socks5::client::Socks5Stream::connect( + socks_server, + target_addr, + target_port, + fast_socks5::client::Config::default(), + ) + .await + .unwrap() +} + +#[pin_project(project = DataStreamProj)] +#[allow(dead_code)] +pub enum DataStream { + Socks5(#[pin] Socks5Stream), +} + +impl AsyncRead for DataStream { + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> std::task::Poll> { + match self.project() { + DataStreamProj::Socks5(s) => { + let compat = s.compat(); + futures_util::pin_mut!(compat); + AsyncRead::poll_read(compat, cx, buf) + } + } + } +} + +impl AsyncWrite for DataStream { + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + match self.project() { + DataStreamProj::Socks5(s) => { + let compat = s.compat_write(); + futures_util::pin_mut!(compat); + AsyncWrite::poll_write(compat, cx, buf) + } + } + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project() { + DataStreamProj::Socks5(s) => { + let compat = s.compat_write(); + futures_util::pin_mut!(compat); + AsyncWrite::poll_flush(compat, cx) + } + } + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.project() { + DataStreamProj::Socks5(s) => { + let compat = s.compat_write(); + futures_util::pin_mut!(compat); + AsyncWrite::poll_close(compat, cx) + } + } + } +} diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 92f112a9a3..559f7518e9 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -36,8 +36,8 @@ use std::time::Duration; use futures::stream::FuturesUnordered; use futures::{channel::mpsc, StreamExt, TryStreamExt}; use helpers::{ - init_logger, pipe_from_stream_and_drop, server, server_with_cors, server_with_health_api, server_with_subscription, - server_with_subscription_and_handle, + connect_over_socks_stream, init_logger, pipe_from_stream_and_drop, server, server_with_cors, + server_with_health_api, server_with_subscription, server_with_subscription_and_handle, DataStream, }; use hyper::http::HeaderValue; use jsonrpsee::core::client::{ClientT, IdKind, Subscription, SubscriptionClientT}; @@ -75,6 +75,31 @@ async fn ws_subscription_works() { } } +#[tokio::test] +async fn ws_subscription_works_over_proxy_stream() { + init_logger(); + + let server_addr = server_with_subscription().await; + let target_url = format!("ws://{}", server_addr); + + let socks_stream = connect_over_socks_stream(server_addr).await; + let data_stream = DataStream::Socks5(socks_stream); + + let client = WsClientBuilder::default().build_with_stream(target_url, data_stream).await.unwrap(); + + let mut hello_sub: Subscription = + client.subscribe("subscribe_hello", rpc_params![], "unsubscribe_hello").await.unwrap(); + let mut foo_sub: Subscription = + client.subscribe("subscribe_foo", rpc_params![], "unsubscribe_foo").await.unwrap(); + + for _ in 0..10 { + let hello = hello_sub.next().await.unwrap().unwrap(); + let foo = foo_sub.next().await.unwrap().unwrap(); + assert_eq!(hello, "hello from subscription".to_string()); + assert_eq!(foo, 1337); + } +} + #[tokio::test] async fn ws_unsubscription_works() { init_logger(); @@ -108,6 +133,47 @@ async fn ws_unsubscription_works() { assert!(success); } +#[tokio::test] +async fn ws_unsubscription_works_over_proxy_stream() { + init_logger(); + + let server_addr = server_with_subscription().await; + let server_url = format!("ws://{}", server_addr); + + let socks_stream = connect_over_socks_stream(server_addr).await; + let data_stream = DataStream::Socks5(socks_stream); + + let client = WsClientBuilder::default() + .max_concurrent_requests(1) + .build_with_stream(&server_url, data_stream) + .await + .unwrap(); + + let mut sub: Subscription = + client.subscribe("subscribe_foo", rpc_params![], "unsubscribe_foo").await.unwrap(); + + // It's technically possible to have race-conditions between the notifications and the unsubscribe message. + // So let's wait for the first notification and then unsubscribe. + let _item = sub.next().await.unwrap().unwrap(); + + sub.unsubscribe().await.unwrap(); + + let mut success = false; + + // Wait until a slot is available, as only one concurrent call is allowed. + // Then when this finishes we know that unsubscribe call has been finished. + for _ in 0..30 { + let res: Result = client.request("say_hello", rpc_params![]).await; + if res.is_ok() { + success = true; + break; + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + + assert!(success); +} + #[tokio::test] async fn ws_subscription_with_input_works() { init_logger(); @@ -124,6 +190,27 @@ async fn ws_subscription_with_input_works() { } } +#[tokio::test] +async fn ws_subscription_with_input_works_over_proxy_stream() { + init_logger(); + + let server_addr = server_with_subscription().await; + let server_url = format!("ws://{}", server_addr); + + let socks_stream = connect_over_socks_stream(server_addr).await; + let data_stream = DataStream::Socks5(socks_stream); + + let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap(); + + let mut add_one: Subscription = + client.subscribe("subscribe_add_one", rpc_params![1], "unsubscribe_add_one").await.unwrap(); + + for i in 2..4 { + let next = add_one.next().await.unwrap().unwrap(); + assert_eq!(next, i); + } +} + #[tokio::test] async fn ws_method_call_works() { init_logger(); @@ -135,6 +222,21 @@ async fn ws_method_call_works() { assert_eq!(&response, "hello"); } +#[tokio::test] +async fn ws_method_call_works_over_proxy_stream() { + init_logger(); + + let server_addr = server().await; + let server_url = format!("ws://{}", server_addr); + + let socks_stream = connect_over_socks_stream(server_addr).await; + let data_stream = DataStream::Socks5(socks_stream); + + let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap(); + let response: String = client.request("say_hello", rpc_params![]).await.unwrap(); + assert_eq!(&response, "hello"); +} + #[tokio::test] async fn ws_method_call_str_id_works() { init_logger(); @@ -146,6 +248,22 @@ async fn ws_method_call_str_id_works() { assert_eq!(&response, "hello"); } +#[tokio::test] +async fn ws_method_call_str_id_works_over_proxy_stream() { + init_logger(); + + let server_addr = server().await; + let server_url = format!("ws://{}", server_addr); + + let socks_stream = connect_over_socks_stream(server_addr).await; + let data_stream = DataStream::Socks5(socks_stream); + + let client = + WsClientBuilder::default().id_format(IdKind::String).build_with_stream(&server_url, data_stream).await.unwrap(); + let response: String = client.request("say_hello", rpc_params![]).await.unwrap(); + assert_eq!(&response, "hello"); +} + #[tokio::test] async fn http_method_call_works() { init_logger(); @@ -256,7 +374,7 @@ async fn ws_subscription_several_clients_with_drop() { } #[tokio::test] -async fn ws_subscription_without_polling_doesnt_make_client_unuseable() { +async fn ws_subscription_without_polling_does_not_make_client_unusable() { init_logger(); let server_addr = server_with_subscription().await; @@ -273,7 +391,7 @@ async fn ws_subscription_without_polling_doesnt_make_client_unuseable() { assert!(hello_sub.next().await.unwrap().is_ok()); } - // NOTE: this is now unuseable and unregistered. + // NOTE: this is now unusable and unregistered. assert!(hello_sub.next().await.is_none()); // The client should still be useable => make sure it still works. From 6ec9a3e3bddcae19568cba9ce684d21dfb0703ee Mon Sep 17 00:00:00 2001 From: Niklas Adolfsson Date: Tue, 21 Nov 2023 17:10:21 +0100 Subject: [PATCH 5/5] address review suggestions --- client/transport/src/ws/mod.rs | 26 +++++++------- tests/tests/helpers.rs | 62 +++++++++++--------------------- tests/tests/integration_tests.rs | 10 +++--- tests/tests/metrics.rs | 2 ++ tests/tests/proc_macros.rs | 2 ++ tests/tests/rpc_module.rs | 2 ++ 6 files changed, 45 insertions(+), 59 deletions(-) diff --git a/client/transport/src/ws/mod.rs b/client/transport/src/ws/mod.rs index 9f60e901c2..1f3d21fb35 100644 --- a/client/transport/src/ws/mod.rs +++ b/client/transport/src/ws/mod.rs @@ -298,6 +298,19 @@ impl WsTransportClientBuilder { self.try_connect_over_tcp(uri).await } + /// Try to establish the connection over the given data stream. + pub async fn build_with_stream( + self, + uri: Url, + data_stream: T, + ) -> Result<(Sender, Receiver), WsHandshakeError> + where + T: AsyncRead + AsyncWrite + Unpin, + { + let target: Target = uri.try_into()?; + self.try_connect(&target, data_stream).await + } + // Try to establish the connection over TCP. async fn try_connect_over_tcp( &self, @@ -401,19 +414,6 @@ impl WsTransportClientBuilder { err.unwrap_or(Err(WsHandshakeError::NoAddressFound(target.host))) } - /// Try to establish the connection over the given data stream. - pub async fn build_with_stream( - self, - uri: Url, - data_stream: T, - ) -> Result<(Sender, Receiver), WsHandshakeError> - where - T: AsyncRead + AsyncWrite + Unpin, - { - let target: Target = uri.try_into()?; - self.try_connect(&target, data_stream).await - } - /// Try to establish the handshake over the given data stream. async fn try_connect( &self, diff --git a/tests/tests/helpers.rs b/tests/tests/helpers.rs index 1d1e02b9a6..3cb8687ab7 100644 --- a/tests/tests/helpers.rs +++ b/tests/tests/helpers.rs @@ -25,6 +25,7 @@ // DEALINGS IN THE SOFTWARE. #![cfg(test)] +#![allow(dead_code)] use std::net::SocketAddr; use std::time::Duration; @@ -47,7 +48,6 @@ use tokio_stream::wrappers::IntervalStream; use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt}; use tower_http::cors::CorsLayer; -#[allow(dead_code)] pub async fn server_with_subscription_and_handle() -> (SocketAddr, ServerHandle) { let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap(); @@ -129,7 +129,6 @@ pub async fn server_with_subscription_and_handle() -> (SocketAddr, ServerHandle) (addr, server_handle) } -#[allow(dead_code)] pub async fn server_with_subscription() -> SocketAddr { let (addr, handle) = server_with_subscription_and_handle().await; @@ -138,7 +137,6 @@ pub async fn server_with_subscription() -> SocketAddr { addr } -#[allow(dead_code)] pub async fn server() -> SocketAddr { let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap(); let mut module = RpcModule::new(()); @@ -171,7 +169,6 @@ pub async fn server() -> SocketAddr { } /// Yields one item then sleeps for an hour. -#[allow(dead_code)] pub async fn server_with_sleeping_subscription(tx: futures::channel::mpsc::Sender<()>) -> SocketAddr { let server = ServerBuilder::default().build("127.0.0.1:0").await.unwrap(); let addr = server.local_addr().unwrap(); @@ -198,7 +195,6 @@ pub async fn server_with_sleeping_subscription(tx: futures::channel::mpsc::Sende addr } -#[allow(dead_code)] pub async fn server_with_health_api() -> (SocketAddr, ServerHandle) { server_with_cors(CorsLayer::new()).await } @@ -255,7 +251,6 @@ pub async fn pipe_from_stream_and_drop( } } -#[allow(dead_code)] pub async fn socks_server_no_auth() -> SocketAddr { let mut config = server::Config::default(); config.set_dns_resolve(false); @@ -269,7 +264,6 @@ pub async fn socks_server_no_auth() -> SocketAddr { proxy_addr } -#[allow(dead_code)] pub async fn spawn_socks_server(listener: tokio::net::TcpListener, config: std::sync::Arc) { let addr = listener.local_addr().unwrap(); tokio::spawn(async move { @@ -283,7 +277,6 @@ pub async fn spawn_socks_server(listener: tokio::net::TcpListener, config: std:: }); } -#[allow(dead_code)] pub async fn connect_over_socks_stream(server_addr: SocketAddr) -> Socks5Stream { let target_addr = server_addr.ip().to_string(); let target_port = server_addr.port(); @@ -300,25 +293,24 @@ pub async fn connect_over_socks_stream(server_addr: SocketAddr) -> Socks5Stream< .unwrap() } -#[pin_project(project = DataStreamProj)] -#[allow(dead_code)] -pub enum DataStream { - Socks5(#[pin] Socks5Stream), +#[pin_project] +pub struct DataStream(#[pin] Socks5Stream); + +impl DataStream { + pub fn new(t: Socks5Stream) -> Self { + Self(t) + } } -impl AsyncRead for DataStream { +impl AsyncRead for DataStream { fn poll_read( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, buf: &mut [u8], ) -> std::task::Poll> { - match self.project() { - DataStreamProj::Socks5(s) => { - let compat = s.compat(); - futures_util::pin_mut!(compat); - AsyncRead::poll_read(compat, cx, buf) - } - } + let this = self.project().0.compat(); + futures_util::pin_mut!(this); + AsyncRead::poll_read(this, cx, buf) } } @@ -328,38 +320,26 @@ impl Async cx: &mut std::task::Context<'_>, buf: &[u8], ) -> std::task::Poll> { - match self.project() { - DataStreamProj::Socks5(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_write(compat, cx, buf) - } - } + let this = self.project().0.compat_write(); + futures_util::pin_mut!(this); + AsyncWrite::poll_write(this, cx, buf) } fn poll_flush( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - match self.project() { - DataStreamProj::Socks5(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_flush(compat, cx) - } - } + let this = self.project().0.compat_write(); + futures_util::pin_mut!(this); + AsyncWrite::poll_flush(this, cx) } fn poll_close( self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - match self.project() { - DataStreamProj::Socks5(s) => { - let compat = s.compat_write(); - futures_util::pin_mut!(compat); - AsyncWrite::poll_close(compat, cx) - } - } + let this = self.project().0.compat_write(); + futures_util::pin_mut!(this); + AsyncWrite::poll_close(this, cx) } } diff --git a/tests/tests/integration_tests.rs b/tests/tests/integration_tests.rs index 559f7518e9..6f92780a60 100644 --- a/tests/tests/integration_tests.rs +++ b/tests/tests/integration_tests.rs @@ -83,7 +83,7 @@ async fn ws_subscription_works_over_proxy_stream() { let target_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::Socks5(socks_stream); + let data_stream = DataStream::new(socks_stream); let client = WsClientBuilder::default().build_with_stream(target_url, data_stream).await.unwrap(); @@ -141,7 +141,7 @@ async fn ws_unsubscription_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::Socks5(socks_stream); + let data_stream = DataStream::new(socks_stream); let client = WsClientBuilder::default() .max_concurrent_requests(1) @@ -198,7 +198,7 @@ async fn ws_subscription_with_input_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::Socks5(socks_stream); + let data_stream = DataStream::new(socks_stream); let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap(); @@ -230,7 +230,7 @@ async fn ws_method_call_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::Socks5(socks_stream); + let data_stream = DataStream::new(socks_stream); let client = WsClientBuilder::default().build_with_stream(&server_url, data_stream).await.unwrap(); let response: String = client.request("say_hello", rpc_params![]).await.unwrap(); @@ -256,7 +256,7 @@ async fn ws_method_call_str_id_works_over_proxy_stream() { let server_url = format!("ws://{}", server_addr); let socks_stream = connect_over_socks_stream(server_addr).await; - let data_stream = DataStream::Socks5(socks_stream); + let data_stream = DataStream::new(socks_stream); let client = WsClientBuilder::default().id_format(IdKind::String).build_with_stream(&server_url, data_stream).await.unwrap(); diff --git a/tests/tests/metrics.rs b/tests/tests/metrics.rs index 787496f278..043855fc98 100644 --- a/tests/tests/metrics.rs +++ b/tests/tests/metrics.rs @@ -24,6 +24,8 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +#![cfg(test)] + mod helpers; use std::collections::HashMap; diff --git a/tests/tests/proc_macros.rs b/tests/tests/proc_macros.rs index d394f74893..2cda9d4da2 100644 --- a/tests/tests/proc_macros.rs +++ b/tests/tests/proc_macros.rs @@ -26,6 +26,8 @@ //! Example of using proc macro to generate working client and server. +#![cfg(test)] + mod helpers; use std::net::SocketAddr; diff --git a/tests/tests/rpc_module.rs b/tests/tests/rpc_module.rs index bbd8428801..58b859827a 100644 --- a/tests/tests/rpc_module.rs +++ b/tests/tests/rpc_module.rs @@ -24,6 +24,8 @@ // IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER // DEALINGS IN THE SOFTWARE. +#![cfg(test)] + mod helpers; use std::collections::{HashMap, VecDeque};