Skip to content

Commit

Permalink
Uniform API for custom headers between clients (#814)
Browse files Browse the repository at this point in the history
* ws-client: Replace `httparse::Header` with `http::HeaderMap`

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* ws-client: Make headers optional

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* http-client: Expose custom header injection

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* http-client: Adjust testing for custom headers

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* Make `http::HeaderMap` non-optional

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* http-client: Cache request headers

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* Fix doc tests

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* http-client: Use `into_iter` for headers

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* docs: Improve custom headers documentation

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* http: Use `hyper::http` instead of `http` directly

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* http-client: Adjust testing

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* Fix doc tests

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>

* client: Expose `http::HeaderMap` and `http::HeaderValue`

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>
  • Loading branch information
lexnv authored Jul 13, 2022
1 parent a26f1fb commit 0ccfbd7
Show file tree
Hide file tree
Showing 6 changed files with 132 additions and 45 deletions.
45 changes: 42 additions & 3 deletions client/http-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::time::Duration;
use crate::transport::HttpTransportClient;
use crate::types::{ErrorResponse, Id, NotificationSer, ParamsSer, RequestSer, Response};
use async_trait::async_trait;
use hyper::http::HeaderMap;
use jsonrpsee_core::client::{CertificateStore, ClientT, IdKind, RequestIdManager, Subscription, SubscriptionClientT};
use jsonrpsee_core::tracing::RpcTracing;
use jsonrpsee_core::{Error, TEN_MB_SIZE_BYTES};
Expand All @@ -39,6 +40,29 @@ use serde::de::DeserializeOwned;
use tracing_futures::Instrument;

/// Http Client Builder.
///
/// # Examples
///
/// ```no_run
///
/// use jsonrpsee_http_client::{HttpClientBuilder, HeaderMap, HeaderValue};
///
/// #[tokio::main]
/// async fn main() {
/// // Build custom headers used for every submitted request.
/// let mut headers = HeaderMap::new();
/// headers.insert("Any-Header-You-Like", HeaderValue::from_static("42"));
///
/// // Build client
/// let client = HttpClientBuilder::default()
/// .set_headers(headers)
/// .build("wss://localhost:443")
/// .unwrap();
///
/// // use client....
/// }
///
/// ```
#[derive(Debug)]
pub struct HttpClientBuilder {
max_request_body_size: u32,
Expand All @@ -47,6 +71,7 @@ pub struct HttpClientBuilder {
certificate_store: CertificateStore,
id_kind: IdKind,
max_log_length: u32,
headers: HeaderMap,
}

impl HttpClientBuilder {
Expand Down Expand Up @@ -88,11 +113,24 @@ impl HttpClientBuilder {
self
}

/// Set a custom header passed to the server with every request (default is none).
///
/// The caller is responsible for checking that the headers do not conflict or are duplicated.
pub fn set_headers(mut self, headers: HeaderMap) -> Self {
self.headers = headers;
self
}

/// Build the HTTP client with target to connect to.
pub fn build(self, target: impl AsRef<str>) -> Result<HttpClient, Error> {
let transport =
HttpTransportClient::new(target, self.max_request_body_size, self.certificate_store, self.max_log_length)
.map_err(|e| Error::Transport(e.into()))?;
let transport = HttpTransportClient::new(
target,
self.max_request_body_size,
self.certificate_store,
self.max_log_length,
self.headers,
)
.map_err(|e| Error::Transport(e.into()))?;
Ok(HttpClient {
transport,
id_manager: Arc::new(RequestIdManager::new(self.max_concurrent_requests, self.id_kind)),
Expand All @@ -110,6 +148,7 @@ impl Default for HttpClientBuilder {
certificate_store: CertificateStore::Native,
id_kind: IdKind::Number,
max_log_length: 4096,
headers: HeaderMap::new(),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions client/http-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ pub mod transport;
mod tests;

pub use client::{HttpClient, HttpClientBuilder};
pub use hyper::http::{HeaderMap, HeaderValue};
pub use jsonrpsee_types as types;
73 changes: 55 additions & 18 deletions client/http-client/src/transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// the JSON-RPC request id to a value that might have already been used.

use hyper::client::{Client, HttpConnector};
use hyper::http::{HeaderMap, HeaderValue};
use hyper::Uri;
use jsonrpsee_core::client::CertificateStore;
use jsonrpsee_core::error::GenericTransportError;
Expand Down Expand Up @@ -48,6 +49,8 @@ pub struct HttpTransportClient {
///
/// Logs bigger than this limit will be truncated.
max_log_length: u32,
/// Custom headers to pass with every request.
headers: HeaderMap,
}

impl HttpTransportClient {
Expand All @@ -57,6 +60,7 @@ impl HttpTransportClient {
max_request_body_size: u32,
cert_store: CertificateStore,
max_log_length: u32,
headers: HeaderMap,
) -> Result<Self, Error> {
let target: Uri = target.as_ref().parse().map_err(|e| Error::Url(format!("Invalid URL: {}", e)))?;
if target.port_u16().is_none() {
Expand Down Expand Up @@ -90,7 +94,20 @@ impl HttpTransportClient {
return Err(Error::Url(err.into()));
}
};
Ok(Self { target, client, max_request_body_size, max_log_length })

// Cache request headers: 2 default headers, followed by user custom headers.
// Maintain order for headers in case of duplicate keys:
// https://datatracker.ietf.org/doc/html/rfc7230#section-3.2.2
let mut cached_headers = HeaderMap::with_capacity(2 + headers.len());
cached_headers.insert(hyper::header::CONTENT_TYPE, HeaderValue::from_static(CONTENT_TYPE_JSON));
cached_headers.insert(hyper::header::ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON));
for (key, value) in headers.into_iter() {
if let Some(key) = key {
cached_headers.insert(key, value);
}
}

Ok(Self { target, client, max_request_body_size, max_log_length, headers: cached_headers })
}

async fn inner_send(&self, body: String) -> Result<hyper::Response<hyper::Body>, Error> {
Expand All @@ -100,11 +117,9 @@ impl HttpTransportClient {
return Err(Error::RequestTooLarge);
}

let req = hyper::Request::post(&self.target)
.header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.header(hyper::header::ACCEPT, hyper::header::HeaderValue::from_static(CONTENT_TYPE_JSON))
.body(From::from(body))
.expect("URI and request headers are valid; qed");
let mut req = hyper::Request::post(&self.target);
req.headers_mut().map(|headers| *headers = self.headers.clone());
let req = req.body(From::from(body)).expect("URI and request headers are valid; qed");

let response = self.client.request(req).await.map_err(|e| Error::Http(Box::new(e)))?;
if response.status().is_success() {
Expand Down Expand Up @@ -179,7 +194,7 @@ where

#[cfg(test)]
mod tests {
use super::{CertificateStore, Error, HttpTransportClient};
use super::*;

fn assert_target(
client: &HttpTransportClient,
Expand All @@ -198,37 +213,50 @@ mod tests {

#[test]
fn invalid_http_url_rejected() {
let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native, 80).unwrap_err();
let err = HttpTransportClient::new("ws://localhost:9933", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[cfg(feature = "tls")]
#[test]
fn https_works() {
let client = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80).unwrap();
let client =
HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap();
assert_target(&client, "localhost", "https", "/", 9933, 80);
}

#[cfg(not(feature = "tls"))]
#[test]
fn https_fails_without_tls_feature() {
let err = HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80).unwrap_err();
let err =
HttpTransportClient::new("https://localhost:9933", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[test]
fn faulty_port() {
let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native, 80).unwrap_err();
let err = HttpTransportClient::new("http://localhost:-43", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
let err = HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native, 80).unwrap_err();
let err =
HttpTransportClient::new("http://localhost:-99999", 80, CertificateStore::Native, 80, HeaderMap::new())
.unwrap_err();
assert!(matches!(err, Error::Url(_)));
}

#[test]
fn url_with_path_works() {
let client =
HttpTransportClient::new("http://localhost:9944/my-special-path", 1337, CertificateStore::Native, 80)
.unwrap();
let client = HttpTransportClient::new(
"http://localhost:9944/my-special-path",
1337,
CertificateStore::Native,
80,
HeaderMap::new(),
)
.unwrap();
assert_target(&client, "localhost", "http", "/my-special-path", 9944, 1337);
}

Expand All @@ -239,22 +267,31 @@ mod tests {
u32::MAX,
CertificateStore::WebPki,
80,
HeaderMap::new(),
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my?name1=value1&name2=value2", 9999, u32::MAX);
}

#[test]
fn url_with_fragment_is_ignored() {
let client =
HttpTransportClient::new("http://127.0.0.1:9944/my.htm#ignore", 999, CertificateStore::Native, 80).unwrap();
let client = HttpTransportClient::new(
"http://127.0.0.1:9944/my.htm#ignore",
999,
CertificateStore::Native,
80,
HeaderMap::new(),
)
.unwrap();
assert_target(&client, "127.0.0.1", "http", "/my.htm", 9944, 999);
}

#[tokio::test]
async fn request_limit_works() {
let eighty_bytes_limit = 80;
let client = HttpTransportClient::new("http://localhost:9933", 80, CertificateStore::WebPki, 99).unwrap();
let client =
HttpTransportClient::new("http://localhost:9933", 80, CertificateStore::WebPki, 99, HeaderMap::new())
.unwrap();
assert_eq!(client.max_request_body_size, eighty_bytes_limit);

let body = "a".repeat(81);
Expand Down
28 changes: 16 additions & 12 deletions client/transport/src/ws/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ use stream::EitherStream;
use thiserror::Error;
use tokio::net::TcpStream;

pub use http::{uri::InvalidUri, Uri};
pub use http::{uri::InvalidUri, HeaderMap, HeaderValue, Uri};
pub use soketto::handshake::client::Header;

/// Sending end of WebSocket transport.
Expand All @@ -59,33 +59,32 @@ pub struct Receiver {

/// Builder for a WebSocket transport [`Sender`] and ['Receiver`] pair.
#[derive(Debug)]
pub struct WsTransportClientBuilder<'a> {
pub struct WsTransportClientBuilder {
/// What certificate store to use
pub certificate_store: CertificateStore,
/// Timeout for the connection.
pub connection_timeout: Duration,
/// Custom headers to pass during the HTTP handshake. If `None`, no
/// custom header is passed.
pub headers: Vec<Header<'a>>,
/// Custom headers to pass during the HTTP handshake.
pub headers: http::HeaderMap,
/// Max payload size
pub max_request_body_size: u32,
/// Max number of redirections.
pub max_redirections: usize,
}

impl<'a> Default for WsTransportClientBuilder<'a> {
impl Default for WsTransportClientBuilder {
fn default() -> Self {
Self {
certificate_store: CertificateStore::Native,
max_request_body_size: TEN_MB_SIZE_BYTES,
connection_timeout: Duration::from_secs(10),
headers: Vec::new(),
headers: http::HeaderMap::new(),
max_redirections: 5,
}
}
}

impl<'a> WsTransportClientBuilder<'a> {
impl WsTransportClientBuilder {
/// Set whether to use system certificates (default is native).
pub fn certificate_store(mut self, certificate_store: CertificateStore) -> Self {
self.certificate_store = certificate_store;
Expand All @@ -107,8 +106,8 @@ impl<'a> WsTransportClientBuilder<'a> {
/// Set a custom header passed to the server during the handshake (default is none).
///
/// The caller is responsible for checking that the headers do not conflict or are duplicated.
pub fn add_header(mut self, name: &'a str, value: &'a str) -> Self {
self.headers.push(Header { name, value: value.as_bytes() });
pub fn set_headers(mut self, headers: http::HeaderMap) -> Self {
self.headers = headers;
self
}

Expand Down Expand Up @@ -240,7 +239,7 @@ impl TransportReceiverT for Receiver {
}
}

impl<'a> WsTransportClientBuilder<'a> {
impl WsTransportClientBuilder {
/// Try to establish the connection.
pub async fn build(self, uri: Uri) -> Result<(Sender, Receiver), WsHandshakeError> {
let target: Target = uri.try_into()?;
Expand Down Expand Up @@ -289,7 +288,12 @@ impl<'a> WsTransportClientBuilder<'a> {
&target.path_and_query,
);

client.set_headers(&self.headers);
let headers: Vec<_> = self
.headers
.iter()
.map(|(key, value)| Header { name: key.as_str(), value: value.as_bytes() })
.collect();
client.set_headers(&headers);

// Perform the initial handshake.
match client.handshake().await {
Expand Down
1 change: 1 addition & 0 deletions client/ws-client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ documentation = "https://docs.rs/jsonrpsee-ws-client"
jsonrpsee-types = { path = "../../types", version = "0.14.0" }
jsonrpsee-client-transport = { path = "../transport", version = "0.14.0", features = ["ws"] }
jsonrpsee-core = { path = "../../core", version = "0.14.0", features = ["async-client"] }
http = "0.2.0"

[dev-dependencies]
tracing-subscriber = { version = "0.3.3", features = ["env-filter"] }
Expand Down
Loading

0 comments on commit 0ccfbd7

Please sign in to comment.