Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: update and turn WsTransportClientBuilder generic #1168

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 95 additions & 43 deletions client/transport/src/ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,35 +31,36 @@ use std::net::SocketAddr;
use std::time::Duration;

use futures_util::io::{BufReader, BufWriter};
use jsonrpsee_core::client::{CertificateStore, ReceivedMessage, TransportReceiverT, TransportSenderT};
pub use futures_util::{AsyncRead, AsyncWrite};
use jsonrpsee_core::client::{CertificateStore, MaybeSend, ReceivedMessage, TransportReceiverT, TransportSenderT};
use jsonrpsee_core::TEN_MB_SIZE_BYTES;
use jsonrpsee_core::{async_trait, Cow};
use soketto::connection::Error::Utf8;
use soketto::data::ByteSlice125;
use soketto::handshake::client::{Client as WsHandshakeClient, ServerResponse};
use soketto::{connection, Data, Incoming};
use stream::EitherStream;
use thiserror::Error;
use tokio::net::TcpStream;

pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri};
pub use soketto::handshake::client::Header;
pub use stream::EitherStream;
pub use url::Url;

/// Sending end of WebSocket transport.
#[derive(Debug)]
pub struct Sender {
inner: connection::Sender<BufReader<BufWriter<EitherStream>>>,
pub struct Sender<T> {
inner: connection::Sender<BufReader<BufWriter<T>>>,
max_request_size: u32,
}

/// Receiving end of WebSocket transport.
#[derive(Debug)]
pub struct Receiver {
inner: connection::Receiver<BufReader<BufWriter<EitherStream>>>,
pub struct Receiver<T> {
inner: connection::Receiver<BufReader<BufWriter<T>>>,
}

/// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair.
/// Builder for a WebSocket transport [`Sender`] and [`Receiver`] pair.
#[derive(Debug)]
pub struct WsTransportClientBuilder {
/// What certificate store to use
Expand Down Expand Up @@ -190,6 +191,15 @@ pub enum WsHandshakeError {
status_code: u16,
},

/// Server redirected to other location.
#[error("Connection redirected with status code: {status_code} and location: {location}")]
Redirected {
/// HTTP status code that the server returned.
status_code: u16,
/// The location URL redirected to.
location: String,
},

/// Timeout while trying to connect.
#[error("Connection timeout exceeded: {0:?}")]
Timeout(Duration),
Expand All @@ -215,7 +225,10 @@ pub enum WsError {
}

#[async_trait]
impl TransportSenderT for Sender {
impl<T> TransportSenderT for Sender<T>
where
T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static,
{
type Error = WsError;

/// Sends out a request. Returns a `Future` that finishes when the request has been
Expand Down Expand Up @@ -252,7 +265,10 @@ impl TransportSenderT for Sender {
}

#[async_trait]
impl TransportReceiverT for Receiver {
impl<T> TransportReceiverT for Receiver<T>
where
T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static,
{
type Error = WsError;

/// Returns a `Future` resolving when the server sent us something back.
Expand All @@ -276,12 +292,31 @@ impl TransportReceiverT for Receiver {

impl WsTransportClientBuilder {
/// Try to establish the connection.
pub async fn build(self, uri: Url) -> Result<(Sender, Receiver), WsHandshakeError> {
///
/// Uses the default connection over TCP.
pub async fn build(self, uri: Url) -> Result<(Sender<EitherStream>, Receiver<EitherStream>), WsHandshakeError> {
niklasad1 marked this conversation as resolved.
Show resolved Hide resolved
self.try_connect_over_tcp(uri).await
}

/// Try to establish the connection over the given data stream.
pub async fn build_with_stream<T>(
self,
uri: Url,
data_stream: T,
) -> Result<(Sender<T>, Receiver<T>), WsHandshakeError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
let target: Target = uri.try_into()?;
self.try_connect(target).await
self.try_connect(&target, data_stream).await
}

async fn try_connect(self, mut target: Target) -> Result<(Sender, Receiver), WsHandshakeError> {
// Try to establish the connection over TCP.
async fn try_connect_over_tcp(
&self,
uri: Url,
) -> Result<(Sender<EitherStream>, Receiver<EitherStream>), WsHandshakeError> {
let mut target: Target = uri.try_into()?;
let mut err = None;

// Only build TLS connector if `wss` in URL.
Expand Down Expand Up @@ -317,37 +352,10 @@ impl WsTransportClientBuilder {
}
};

let mut client = WsHandshakeClient::new(
BufReader::new(BufWriter::new(tcp_stream)),
&target.host_header,
&target.path_and_query,
);

let headers: Vec<_> = self
.headers
.iter()
.map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() })
.collect();
client.set_headers(&headers);

// Perform the initial handshake.
match client.handshake().await {
Ok(ServerResponse::Accepted { .. }) => {
tracing::debug!("Connection established to target: {:?}", target);
let mut builder = client.into_builder();
builder.set_max_message_size(self.max_response_size as usize);
let (sender, receiver) = builder.finish();
return Ok((
Sender { inner: sender, max_request_size: self.max_request_size },
Receiver { inner: receiver },
));
}
match self.try_connect(&target, tcp_stream).await {
Ok(result) => return Ok(result),

Ok(ServerResponse::Rejected { status_code }) => {
tracing::debug!("Connection rejected: {:?}", status_code);
err = Some(Err(WsHandshakeError::Rejected { status_code }));
}
Ok(ServerResponse::Redirect { status_code, location }) => {
Err(WsHandshakeError::Redirected { status_code, location }) => {
tracing::debug!("Redirection: status_code: {}, location: {}", status_code, location);
match Url::parse(&location) {
// redirection with absolute path => need to lookup.
Expand Down Expand Up @@ -396,14 +404,58 @@ impl WsTransportClientBuilder {
}
};
}

Err(e) => {
err = Some(Err(e.into()));
err = Some(Err(e));
}
};
}
}
err.unwrap_or(Err(WsHandshakeError::NoAddressFound(target.host)))
}

/// Try to establish the handshake over the given data stream.
async fn try_connect<T>(
&self,
target: &Target,
data_stream: T,
) -> Result<(Sender<T>, Receiver<T>), WsHandshakeError>
where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut client = WsHandshakeClient::new(
BufReader::new(BufWriter::new(data_stream)),
&target.host_header,
&target.path_and_query,
);

let headers: Vec<_> =
self.headers.iter().map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() }).collect();
client.set_headers(&headers);

// Perform the initial handshake.
match client.handshake().await {
Ok(ServerResponse::Accepted { .. }) => {
tracing::debug!("Connection established to target: {:?}", target);
let mut builder = client.into_builder();
builder.set_max_message_size(self.max_response_size as usize);
let (sender, receiver) = builder.finish();
Ok((Sender { inner: sender, max_request_size: self.max_request_size }, Receiver { inner: receiver }))
}

Ok(ServerResponse::Rejected { status_code }) => {
tracing::debug!("Connection rejected: {:?}", status_code);
Err(WsHandshakeError::Rejected { status_code })
}

Ok(ServerResponse::Redirect { status_code, location }) => {
tracing::debug!("Redirection: status_code: {}, location: {}", status_code, location);
Err(WsHandshakeError::Redirected { status_code, location })
}

Err(e) => Err(e.into()),
}
}
}

#[cfg(feature = "__tls")]
Expand Down
2 changes: 1 addition & 1 deletion client/transport/src/ws/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
#[pin_project(project = EitherStreamProj)]
#[derive(Debug)]
#[allow(clippy::large_enum_variant)]
pub(crate) enum EitherStream {
pub enum EitherStream {
/// Unencrypted socket stream.
Plain(#[pin] TcpStream),
/// Encrypted socket stream.
Expand Down
83 changes: 60 additions & 23 deletions client/ws-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ pub use http::{HeaderMap, HeaderValue};
use std::time::Duration;
use url::Url;

use jsonrpsee_client_transport::ws::WsTransportClientBuilder;
use jsonrpsee_core::client::{CertificateStore, ClientBuilder, IdKind};
use jsonrpsee_client_transport::ws::{AsyncRead, AsyncWrite, WsTransportClientBuilder};
use jsonrpsee_core::client::{
CertificateStore, ClientBuilder, IdKind, MaybeSend, TransportReceiverT, TransportSenderT,
};
use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};

/// Builder for [`WsClient`].
Expand Down Expand Up @@ -213,40 +215,26 @@ impl WsClientBuilder {
self
}

/// Build the client with specified URL to connect to.
/// You must provide the port number in the URL.
/// Build the [`WsClient`] with specified [`TransportSenderT`] [`TransportReceiverT`] parameters
///
/// ## Panics
///
/// Panics if being called outside of `tokio` runtime context.
pub async fn build(self, url: impl AsRef<str>) -> Result<WsClient, Error> {
pub async fn build_with_transport<S, R>(self, sender: S, receiver: R) -> Result<WsClient, Error>
where
S: TransportSenderT + Send,
R: TransportReceiverT + Send,
{
let Self {
certificate_store,
max_concurrent_requests,
max_request_size,
max_response_size,
request_timeout,
connection_timeout,
ping_interval,
headers,
max_redirections,
max_buffer_capacity_per_subscription,
id_kind,
max_log_length,
..
} = self;

let transport_builder = WsTransportClientBuilder {
certificate_store,
connection_timeout,
headers,
max_request_size,
max_response_size,
max_redirections,
};

let uri = Url::parse(url.as_ref()).map_err(|e| Error::Transport(e.into()))?;
let (sender, receiver) = transport_builder.build(uri).await.map_err(|e| Error::Transport(e.into()))?;

let mut client = ClientBuilder::default()
.max_buffer_capacity_per_subscription(max_buffer_capacity_per_subscription)
.request_timeout(request_timeout)
Expand All @@ -260,4 +248,53 @@ impl WsClientBuilder {

Ok(client.build_with_tokio(sender, receiver))
}

/// Build the [`WsClient`] with specified data stream, using [`WsTransportClientBuilder::build_with_stream`].
///
/// ## Panics
///
/// Panics if being called outside of `tokio` runtime context.
pub async fn build_with_stream<T>(self, url: impl AsRef<str>, data_stream: T) -> Result<WsClient, Error>
where
T: AsyncRead + AsyncWrite + Unpin + MaybeSend + 'static,
{
let transport_builder = WsTransportClientBuilder {
certificate_store: self.certificate_store,
connection_timeout: self.connection_timeout,
headers: self.headers.clone(),
max_request_size: self.max_request_size,
max_response_size: self.max_response_size,
max_redirections: self.max_redirections,
};

let uri = Url::parse(url.as_ref()).map_err(|e| Error::Transport(e.into()))?;
let (sender, receiver) =
transport_builder.build_with_stream(uri, data_stream).await.map_err(|e| Error::Transport(e.into()))?;

let ws_client = self.build_with_transport(sender, receiver).await?;
Ok(ws_client)
}

/// Build the [`WsClient`] with specified URL to connect to, using the default
/// [`WsTransportClientBuilder::build_with_stream`], therefore with the default TCP as transport layer.
///
/// ## Panics
///
/// Panics if being called outside of `tokio` runtime context.
pub async fn build(self, url: impl AsRef<str>) -> Result<WsClient, Error> {
let transport_builder = WsTransportClientBuilder {
certificate_store: self.certificate_store,
connection_timeout: self.connection_timeout,
headers: self.headers.clone(),
max_request_size: self.max_request_size,
max_response_size: self.max_response_size,
max_redirections: self.max_redirections,
};

let uri = Url::parse(url.as_ref()).map_err(|e| Error::Transport(e.into()))?;
let (sender, receiver) = transport_builder.build(uri).await.map_err(|e| Error::Transport(e.into()))?;

let ws_client = self.build_with_transport(sender, receiver).await?;
Ok(ws_client)
}
}
14 changes: 9 additions & 5 deletions tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@ publish = false
[dev-dependencies]
anyhow = "1"
beef = { version = "0.5.1", features = ["impl_serde"] }
fast-socks5 = { version = "0.9.1" }
futures = { version = "0.3.14", default-features = false, features = ["std"] }
futures-util = { version = "0.3.14", default-features = false, features = ["alloc"]}
hyper = { version = "0.14", features = ["http1", "client"] }
jsonrpsee = { path = "../jsonrpsee", features = ["server", "client-core", "http-client", "ws-client", "macros"] }
jsonrpsee-test-utils = { path = "../test-utils" }
tokio = { version = "1.16", features = ["full"] }
tracing = "0.1.34"
serde = "1"
serde_json = "1"
hyper = { version = "0.14", features = ["http1", "client"] }
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
tokio = { version = "1.16", features = ["full"] }
tokio-stream = "0.1"
tower-http = { version = "0.4.0", features = ["full"] }
tokio-util = { version = "0.7", features = ["compat"]}
tower = { version = "0.4.13", features = ["full"] }
tower-http = { version = "0.4.0", features = ["full"] }
tracing = "0.1.34"
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
pin-project = { version = "1" }
Loading
Loading