From 0a632a4cdf2e5f7520e40eae857656e86cc5fed7 Mon Sep 17 00:00:00 2001 From: Niklas Date: Fri, 23 Oct 2020 10:03:24 +0200 Subject: [PATCH 1/2] chore: add `rustfmt.toml` for formatting --- rustfmt.toml | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 rustfmt.toml diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000000..c699603f59 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,4 @@ +hard_tabs = true +max_width = 120 +use_small_heuristics = "Max" +edition = "2018" From dcf891fa87bf32abe81809d3066073a5920988cc Mon Sep 17 00:00:00 2001 From: Niklas Date: Fri, 23 Oct 2020 10:04:26 +0200 Subject: [PATCH 2/2] style: `cargo fmt --all` with new config --- examples/http.rs | 36 +- examples/subscription.rs | 55 +- examples/ws.rs | 36 +- proc-macros/src/api_def.rs | 277 +++--- proc-macros/src/lib.rs | 819 ++++++++-------- src/client/http/client.rs | 309 +++--- src/client/http/raw.rs | 479 +++++---- src/client/http/transport.rs | 443 ++++----- src/client/mod.rs | 4 +- src/client/ws/client.rs | 611 ++++++------ src/client/ws/raw.rs | 1182 +++++++++++------------ src/client/ws/stream.rs | 112 +-- src/client/ws/transport.rs | 476 +++++---- src/common/error.rs | 278 +++--- src/common/id.rs | 68 +- src/common/mod.rs | 2 +- src/common/params.rs | 161 ++- src/common/request.rs | 507 +++++----- src/common/response.rs | 363 +++---- src/common/version.rs | 58 +- src/http/raw/batch.rs | 562 +++++------ src/http/raw/batches.rs | 675 ++++++------- src/http/raw/core.rs | 780 +++++++-------- src/http/raw/mod.rs | 4 +- src/http/raw/notification.rs | 37 +- src/http/raw/params.rs | 190 ++-- src/http/raw/tests.rs | 93 +- src/http/raw/typed_rp.rs | 52 +- src/http/server.rs | 778 +++++++-------- src/http/server_utils/access_control.rs | 255 +++-- src/http/server_utils/cors.rs | 1034 ++++++++++---------- src/http/server_utils/hosts.rs | 355 ++++--- src/http/server_utils/matcher.rs | 64 +- src/http/server_utils/utils.rs | 2 +- src/http/tests.rs | 176 ++-- src/http/transport/background.rs | 388 ++++---- src/http/transport/mod.rs | 401 ++++---- src/http/transport/response.rs | 89 +- src/ws/mod.rs | 4 +- src/ws/raw/batch.rs | 562 +++++------ src/ws/raw/batches.rs | 675 ++++++------- src/ws/raw/core.rs | 780 +++++++-------- src/ws/raw/mod.rs | 4 +- src/ws/raw/notification.rs | 37 +- src/ws/raw/params.rs | 190 ++-- src/ws/raw/tests.rs | 100 +- src/ws/raw/typed_rp.rs | 52 +- src/ws/server.rs | 837 ++++++++-------- src/ws/tests.rs | 431 ++++----- src/ws/transport.rs | 801 ++++++++------- test-utils/src/helpers.rs | 80 +- test-utils/src/types.rs | 82 +- 52 files changed, 7953 insertions(+), 8893 deletions(-) diff --git a/examples/http.rs b/examples/http.rs index 4b091a7dd3..4a6ab445ed 100644 --- a/examples/http.rs +++ b/examples/http.rs @@ -35,29 +35,29 @@ const SERVER_URI: &str = "http://localhost:9933"; #[async_std::main] async fn main() -> Result<(), Box> { - env_logger::init(); + env_logger::init(); - let (server_started_tx, server_started_rx) = oneshot::channel::<()>(); - let _server = task::spawn(async move { - run_server(server_started_tx, SOCK_ADDR).await; - }); + let (server_started_tx, server_started_rx) = oneshot::channel::<()>(); + let _server = task::spawn(async move { + run_server(server_started_tx, SOCK_ADDR).await; + }); - server_started_rx.await?; + server_started_rx.await?; - let client = HttpClient::new(SERVER_URI); - let response: Result = client.request("say_hello", Params::None).await; - println!("r: {:?}", response); + let client = HttpClient::new(SERVER_URI); + let response: Result = client.request("say_hello", Params::None).await; + println!("r: {:?}", response); - Ok(()) + Ok(()) } async fn run_server(server_started_tx: Sender<()>, url: &str) { - let server = HttpServer::new(url).await.unwrap(); - let mut say_hello = server.register_method("say_hello".to_string()).unwrap(); - - server_started_tx.send(()).unwrap(); - loop { - let r = say_hello.next().await; - r.respond(Ok(JsonValue::String("lo".to_owned()))).await; - } + let server = HttpServer::new(url).await.unwrap(); + let mut say_hello = server.register_method("say_hello".to_string()).unwrap(); + + server_started_tx.send(()).unwrap(); + loop { + let r = say_hello.next().await; + r.respond(Ok(JsonValue::String("lo".to_owned()))).await; + } } diff --git a/examples/subscription.rs b/examples/subscription.rs index 87a1e24e1c..0bb00609ac 100644 --- a/examples/subscription.rs +++ b/examples/subscription.rs @@ -36,43 +36,36 @@ const NUM_SUBSCRIPTION_RESPONSES: usize = 10; #[async_std::main] async fn main() -> Result<(), Box> { - env_logger::init(); + env_logger::init(); - let (server_started_tx, server_started_rx) = oneshot::channel::<()>(); - let _server = task::spawn(async move { - run_server(server_started_tx, SOCK_ADDR).await; - }); + let (server_started_tx, server_started_rx) = oneshot::channel::<()>(); + let _server = task::spawn(async move { + run_server(server_started_tx, SOCK_ADDR).await; + }); - server_started_rx.await?; - let client = WsClient::new(SERVER_URI).await?; - let mut subscribe_hello: WsSubscription = client - .subscribe("subscribe_hello", Params::None, "unsubscribe_hello") - .await?; + server_started_rx.await?; + let client = WsClient::new(SERVER_URI).await?; + let mut subscribe_hello: WsSubscription = + client.subscribe("subscribe_hello", Params::None, "unsubscribe_hello").await?; - let mut i = 0; - while i <= NUM_SUBSCRIPTION_RESPONSES { - let r = subscribe_hello.next().await; - log::debug!("received {:?}", r); - i += 1; - } + let mut i = 0; + while i <= NUM_SUBSCRIPTION_RESPONSES { + let r = subscribe_hello.next().await; + log::debug!("received {:?}", r); + i += 1; + } - Ok(()) + Ok(()) } async fn run_server(server_started_tx: Sender<()>, url: &str) { - let server = WsServer::new(url).await.unwrap(); - let mut subscription = server - .register_subscription( - "subscribe_hello".to_string(), - "unsubscribe_hello".to_string(), - ) - .unwrap(); + let server = WsServer::new(url).await.unwrap(); + let mut subscription = + server.register_subscription("subscribe_hello".to_string(), "unsubscribe_hello".to_string()).unwrap(); - server_started_tx.send(()).unwrap(); - loop { - subscription - .send(JsonValue::String("hello my friend".to_owned())) - .await; - std::thread::sleep(std::time::Duration::from_secs(1)); - } + server_started_tx.send(()).unwrap(); + loop { + subscription.send(JsonValue::String("hello my friend".to_owned())).await; + std::thread::sleep(std::time::Duration::from_secs(1)); + } } diff --git a/examples/ws.rs b/examples/ws.rs index 930431d00e..5899205dca 100644 --- a/examples/ws.rs +++ b/examples/ws.rs @@ -35,28 +35,28 @@ const SERVER_URI: &str = "ws://localhost:9944"; #[async_std::main] async fn main() -> Result<(), Box> { - env_logger::init(); + env_logger::init(); - let (server_started_tx, server_started_rx) = oneshot::channel::<()>(); - let _server = task::spawn(async move { - run_server(server_started_tx, SOCK_ADDR).await; - }); + let (server_started_tx, server_started_rx) = oneshot::channel::<()>(); + let _server = task::spawn(async move { + run_server(server_started_tx, SOCK_ADDR).await; + }); - server_started_rx.await?; - let client = WsClient::new(SERVER_URI).await?; - let response: JsonValue = client.request("say_hello", Params::None).await?; - println!("r: {:?}", response); + server_started_rx.await?; + let client = WsClient::new(SERVER_URI).await?; + let response: JsonValue = client.request("say_hello", Params::None).await?; + println!("r: {:?}", response); - Ok(()) + Ok(()) } async fn run_server(server_started_tx: Sender<()>, url: &str) { - let server = WsServer::new(url).await.unwrap(); - let mut say_hello = server.register_method("say_hello".to_string()).unwrap(); - - server_started_tx.send(()).unwrap(); - loop { - let r = say_hello.next().await; - r.respond(Ok(JsonValue::String("lo".to_owned()))).await; - } + let server = WsServer::new(url).await.unwrap(); + let mut say_hello = server.register_method("say_hello".to_string()).unwrap(); + + server_started_tx.send(()).unwrap(); + loop { + let r = say_hello.next().await; + r.respond(Ok(JsonValue::String("lo".to_owned()))).await; + } } diff --git a/proc-macros/src/api_def.rs b/proc-macros/src/api_def.rs index 5065cdb3cc..577ec7098a 100644 --- a/proc-macros/src/api_def.rs +++ b/proc-macros/src/api_def.rs @@ -34,202 +34,187 @@ use syn::spanned::Spanned as _; /// Represents the entire content of the procedural macro. #[derive(Debug)] pub struct ApiDefinitions { - pub apis: Vec, + pub apis: Vec, } /// A single API defined by the user. #[derive(Debug)] pub struct ApiDefinition { - /// Visibility of the definition (e.g. `pub`, `pub(crate)`, ...). - pub visibility: syn::Visibility, - /// Name of the API. For example `System`. - pub name: syn::Ident, - /// Optional generics for the API name. - pub generics: syn::Generics, - /// List of RPC functions defined for this API. - pub definitions: Vec, + /// Visibility of the definition (e.g. `pub`, `pub(crate)`, ...). + pub visibility: syn::Visibility, + /// Name of the API. For example `System`. + pub name: syn::Ident, + /// Optional generics for the API name. + pub generics: syn::Generics, + /// List of RPC functions defined for this API. + pub definitions: Vec, } /// A single JSON-RPC method definition. #[derive(Debug)] pub struct ApiMethod { - /// Signature of the method. - pub signature: syn::Signature, - /// Attributes on the method. - pub attributes: ApiMethodAttrs, + /// Signature of the method. + pub signature: syn::Signature, + /// Attributes on the method. + pub attributes: ApiMethodAttrs, } /// List of attributes applied to a method. #[derive(Debug, Default)] pub struct ApiMethodAttrs { - /// Name of the RPC method, if specified. - pub method: Option, - /// Whether the params are by-position (ie. a JSON array) or by-name (ie. a JSON object). - pub positional_params: bool, + /// Name of the RPC method, if specified. + pub method: Option, + /// Whether the params are by-position (ie. a JSON array) or by-name (ie. a JSON object). + pub positional_params: bool, } impl ApiMethod { - /// Returns true if this method has a `()` return type. - /// - /// This is used to determine whether this should be a notification or a method call. - pub fn is_void_ret_type(&self) -> bool { - let ret_ty = match &self.signature.output { - syn::ReturnType::Default => return true, - syn::ReturnType::Type(_, ty) => ty, - }; - - let tuple_ret_ty = match &**ret_ty { - syn::Type::Tuple(tuple) => tuple, - _ => return false, - }; - - tuple_ret_ty.elems.is_empty() - } + /// Returns true if this method has a `()` return type. + /// + /// This is used to determine whether this should be a notification or a method call. + pub fn is_void_ret_type(&self) -> bool { + let ret_ty = match &self.signature.output { + syn::ReturnType::Default => return true, + syn::ReturnType::Type(_, ty) => ty, + }; + + let tuple_ret_ty = match &**ret_ty { + syn::Type::Tuple(tuple) => tuple, + _ => return false, + }; + + tuple_ret_ty.elems.is_empty() + } } /// Implementation detail of `ApiDefinition`. /// Parses one single block of function definitions. #[derive(Debug)] struct ApiMethods { - definitions: Vec, + definitions: Vec, } /// Implementation detail of `ApiMethodAttrs`. /// Parses a single attribute. enum ApiMethodAttr { - Method(syn::LitStr), - PositionalParams, + Method(syn::LitStr), + PositionalParams, } impl syn::parse::Parse for ApiDefinitions { - fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { - let mut out = ApiDefinitions { apis: Vec::new() }; + fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { + let mut out = ApiDefinitions { apis: Vec::new() }; - while !input.is_empty() { - out.apis.push(input.parse()?); - } + while !input.is_empty() { + out.apis.push(input.parse()?); + } - Ok(out) - } + Ok(out) + } } impl syn::parse::Parse for ApiDefinition { - fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { - let visibility = input.parse()?; - let name = input.parse()?; - let generics = input.parse()?; - let group: proc_macro2::Group = input.parse()?; - assert_eq!(group.delimiter(), proc_macro2::Delimiter::Brace); - let defs: ApiMethods = syn::parse2(group.stream())?; - - Ok(ApiDefinition { - visibility, - name, - generics, - definitions: defs.definitions, - }) - } + fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { + let visibility = input.parse()?; + let name = input.parse()?; + let generics = input.parse()?; + let group: proc_macro2::Group = input.parse()?; + assert_eq!(group.delimiter(), proc_macro2::Delimiter::Brace); + let defs: ApiMethods = syn::parse2(group.stream())?; + + Ok(ApiDefinition { visibility, name, generics, definitions: defs.definitions }) + } } impl syn::parse::Parse for ApiMethod { - fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { - let item: syn::TraitItemMethod = input.parse()?; - if item.default.is_some() { - return Err(syn::Error::new(item.default.span(), - "It is forbidden to provide a default implementation for methods in the API definition")); - } - - let mut attributes = ApiMethodAttrs::default(); - for attribute in &item.attrs { - if attribute.path.is_ident("rpc") { - let attrs = attribute.parse_args()?; - attributes.try_merge(attrs)?; - } else { - // TODO: do we copy the attributes somewhere in the output? - } - } - - Ok(ApiMethod { - signature: item.sig, - attributes, - }) - } + fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { + let item: syn::TraitItemMethod = input.parse()?; + if item.default.is_some() { + return Err(syn::Error::new( + item.default.span(), + "It is forbidden to provide a default implementation for methods in the API definition", + )); + } + + let mut attributes = ApiMethodAttrs::default(); + for attribute in &item.attrs { + if attribute.path.is_ident("rpc") { + let attrs = attribute.parse_args()?; + attributes.try_merge(attrs)?; + } else { + // TODO: do we copy the attributes somewhere in the output? + } + } + + Ok(ApiMethod { signature: item.sig, attributes }) + } } impl ApiMethodAttrs { - /// Tries to merge another `ApiMethodAttrs` within this one. Returns an error if there is an - /// overlap in the attributes. - // TODO: span - fn try_merge(&mut self, other: ApiMethodAttrs) -> syn::parse::Result<()> { - if let Some(method) = other.method { - if self.method.is_some() { - // TODO: return Err(()) - } - self.method = Some(method); - } - - if other.positional_params { - self.positional_params = true; - } - - Ok(()) - } + /// Tries to merge another `ApiMethodAttrs` within this one. Returns an error if there is an + /// overlap in the attributes. + // TODO: span + fn try_merge(&mut self, other: ApiMethodAttrs) -> syn::parse::Result<()> { + if let Some(method) = other.method { + if self.method.is_some() { + // TODO: return Err(()) + } + self.method = Some(method); + } + + if other.positional_params { + self.positional_params = true; + } + + Ok(()) + } } impl syn::parse::Parse for ApiMethodAttrs { - fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { - let mut out = ApiMethodAttrs::default(); - - let list = input - .parse_terminated::<_, syn::token::Comma>(|input| input.parse::())?; - for attr in list { - match attr { - ApiMethodAttr::Method(method) => { - if out.method.is_some() { - return Err(syn::Error::new( - method.span(), - "Duplicate method attribute found", - )); - } - out.method = Some(method.value()); - } - ApiMethodAttr::PositionalParams => out.positional_params = true, - } - } - Ok(out) - } + fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { + let mut out = ApiMethodAttrs::default(); + + let list = input.parse_terminated::<_, syn::token::Comma>(|input| input.parse::())?; + for attr in list { + match attr { + ApiMethodAttr::Method(method) => { + if out.method.is_some() { + return Err(syn::Error::new(method.span(), "Duplicate method attribute found")); + } + out.method = Some(method.value()); + } + ApiMethodAttr::PositionalParams => out.positional_params = true, + } + } + Ok(out) + } } impl syn::parse::Parse for ApiMethodAttr { - fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { - let attr: syn::Ident = input.parse()?; - - if attr == "method" { - let _: syn::token::Eq = input.parse()?; - let val = input.parse()?; - Ok(ApiMethodAttr::Method(val)) - } else if attr == "positional_params" { - Ok(ApiMethodAttr::PositionalParams) - } else { - Err(syn::Error::new( - attr.span(), - &format!("Unknown attribute: {}", attr.to_string()), - )) - } - } + fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { + let attr: syn::Ident = input.parse()?; + + if attr == "method" { + let _: syn::token::Eq = input.parse()?; + let val = input.parse()?; + Ok(ApiMethodAttr::Method(val)) + } else if attr == "positional_params" { + Ok(ApiMethodAttr::PositionalParams) + } else { + Err(syn::Error::new(attr.span(), &format!("Unknown attribute: {}", attr.to_string()))) + } + } } impl syn::parse::Parse for ApiMethods { - fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { - let mut out = ApiMethods { - definitions: Vec::new(), - }; - - while !input.is_empty() { - let method: ApiMethod = input.parse()?; - out.definitions.push(method); - } - - Ok(out) - } + fn parse(input: syn::parse::ParseStream) -> syn::parse::Result { + let mut out = ApiMethods { definitions: Vec::new() }; + + while !input.is_empty() { + let method: ApiMethod = input.parse()?; + out.definitions.push(method); + } + + Ok(out) + } } diff --git a/proc-macros/src/lib.rs b/proc-macros/src/lib.rs index 883ed28b40..b6cd001ee2 100644 --- a/proc-macros/src/lib.rs +++ b/proc-macros/src/lib.rs @@ -69,420 +69,387 @@ mod api_def; /// #[proc_macro] pub fn rpc_api(input_token_stream: TokenStream) -> TokenStream { - // Start by parsing the input into what we expect. - let defs: api_def::ApiDefinitions = match syn::parse(input_token_stream) { - Ok(d) => d, - Err(err) => return err.to_compile_error().into(), - }; - - let mut out = Vec::with_capacity(defs.apis.len()); - for api in defs.apis { - match build_api(api) { - Ok(a) => out.push(a), - Err(err) => return err.to_compile_error().into(), - }; - } - - TokenStream::from(quote! { - #(#out)* - }) + // Start by parsing the input into what we expect. + let defs: api_def::ApiDefinitions = match syn::parse(input_token_stream) { + Ok(d) => d, + Err(err) => return err.to_compile_error().into(), + }; + + let mut out = Vec::with_capacity(defs.apis.len()); + for api in defs.apis { + match build_api(api) { + Ok(a) => out.push(a), + Err(err) => return err.to_compile_error().into(), + }; + } + + TokenStream::from(quote! { + #(#out)* + }) } /// Generates the macro output token stream corresponding to a single API. fn build_api(api: api_def::ApiDefinition) -> Result { - let enum_name = &api.name; - - // TODO: make sure there's no conflict here - let mut tweaked_generics = api.generics.clone(); - tweaked_generics.params.insert( - 0, - From::from(syn::LifetimeDef::new( - syn::parse_str::("'a").unwrap(), - )), - ); - tweaked_generics - .params - .push(From::from(syn::TypeParam::from( - syn::parse_str::("R").unwrap(), - ))); - tweaked_generics - .params - .push(From::from(syn::TypeParam::from( - syn::parse_str::("I").unwrap(), - ))); - let (impl_generics, ty_generics, where_clause) = tweaked_generics.split_for_impl(); - let generics = api - .generics - .params - .iter() - .filter_map(|gp| { - if let syn::GenericParam::Type(tp) = gp { - Some(tp.ident.clone()) - } else { - None - } - }) - .collect::>(); - - let visibility = &api.visibility; - - let mut variants = Vec::new(); - let mut tmp_variants = Vec::new(); - for function in &api.definitions { - let function_is_notification = function.is_void_ret_type(); - let variant_name = snake_case_to_camel_case(&function.signature.ident); - let ret = match &function.signature.output { - syn::ReturnType::Default => quote! {()}, - syn::ReturnType::Type(_, ty) => quote_spanned!(ty.span()=> #ty), - }; - - let mut params_list = Vec::new(); - for input in function.signature.inputs.iter() { - let (ty, pat_span, param_variant_name) = match input { - syn::FnArg::Receiver(_) => { - return Err(syn::Error::new( - input.span(), - "Having `self` is not allowed in RPC queries definitions", - )); - } - syn::FnArg::Typed(syn::PatType { ty, pat, .. }) => { - (ty, pat.span(), param_variant_name(&pat)?) - } - }; - - params_list.push(quote_spanned!(pat_span=> #param_variant_name: #ty)); - } - - if !function_is_notification { - if params_list.is_empty() { - tmp_variants.push(quote_spanned!(function.signature.ident.span()=> #variant_name)); - } else { - tmp_variants.push(quote_spanned!(function.signature.ident.span()=> - #variant_name { - #(#params_list,)* - } - )); - } - } - - if function_is_notification { - variants.push(quote_spanned!(function.signature.ident.span()=> - #variant_name { - #(#params_list,)* - } - )); - } else { - variants.push(quote_spanned!(function.signature.ident.span()=> - #variant_name { - respond: jsonrpsee::raw::server::TypedResponder<'a, R, I, #ret>, - #(#params_list,)* - } - )); - } - } - - let next_request = { - let mut notifications_blocks = Vec::new(); - let mut function_blocks = Vec::new(); - let mut tmp_to_rq = Vec::new(); - - struct GenericParams { - generics: HashSet, - types: HashSet, - } - impl<'ast> syn::visit::Visit<'ast> for GenericParams { - fn visit_ident(&mut self, ident: &'ast syn::Ident) { - if self.generics.contains(ident) { - self.types.insert(ident.clone()); - } - } - } - - let mut generic_params = GenericParams { - generics, - types: HashSet::new(), - }; - - for function in &api.definitions { - let function_is_notification = function.is_void_ret_type(); - let variant_name = snake_case_to_camel_case(&function.signature.ident); - let rpc_method_name = function - .attributes - .method - .clone() - .unwrap_or_else(|| function.signature.ident.to_string()); - - let mut params_builders = Vec::new(); - let mut params_names_list = Vec::new(); - - for input in function.signature.inputs.iter() { - let (ty, param_variant_name, rpc_param_name) = match input { - syn::FnArg::Receiver(_) => { - return Err(syn::Error::new( - input.span(), - "Having `self` is not allowed in RPC queries definitions", - )); - } - syn::FnArg::Typed(syn::PatType { ty, pat, attrs, .. }) => { - (ty, param_variant_name(&pat)?, rpc_param_name(&pat, &attrs)?) - } - }; - - syn::visit::visit_type(&mut generic_params, &ty); - - params_names_list - .push(quote_spanned!(function.signature.span()=> #param_variant_name)); - if !function_is_notification { - params_builders.push(quote_spanned!(function.signature.span()=> - let #param_variant_name: #ty = { - match request.params().get(#rpc_param_name) { - Ok(v) => v, - Err(_) => { - // TODO: message - request.respond(Err(jsonrpsee::common::Error::invalid_params(#rpc_param_name))); - continue; - } - } - }; - )); - } else { - params_builders.push(quote_spanned!(function.signature.span()=> - let #param_variant_name: #ty = { - match request.params().get(#rpc_param_name) { - Ok(v) => v, - Err(_) => { - // TODO: log this? - continue; - } - } - }; - )); - } - } - - if function_is_notification { - notifications_blocks.push(quote_spanned!(function.signature.span()=> - if method == #rpc_method_name { - let request = n; - #(#params_builders)* - return Ok(#enum_name::#variant_name { #(#params_names_list),* }); - } - )); - } else { - function_blocks.push(quote_spanned!(function.signature.span()=> - if request_outcome.is_none() && method == #rpc_method_name { - let request = server.request_by_id(&request_id).unwrap(); - #(#params_builders)* - request_outcome = Some(Tmp::#variant_name { #(#params_names_list),* }); - } - )); - - tmp_to_rq.push(quote_spanned!(function.signature.span()=> - Some(Tmp::#variant_name { #(#params_names_list),* }) => { - let request = server.request_by_id(&request_id).unwrap(); - let respond = jsonrpsee::raw::server::TypedResponder::from(request); - return Ok(#enum_name::#variant_name { respond #(, #params_names_list)* }); - }, - )); - } - } - - let params_tys = generic_params.types.iter(); - - let tmp_generics = if generic_params.types.is_empty() { - quote!() - } else { - quote_spanned!(api.name.span()=> - <#(#params_tys,)*> - ) - }; - - let on_request = quote_spanned!(api.name.span()=> { - #[allow(unused)] // The enum might be empty - enum Tmp #tmp_generics { - #(#tmp_variants,)* - } - - let request_id = r.id(); - let method = r.method().to_owned(); - - let mut request_outcome: Option = None; - - #(#function_blocks)* - - match request_outcome { - #(#tmp_to_rq)* - None => server.request_by_id(&request_id).unwrap().respond(Err(jsonrpsee::common::Error::method_not_found())), - } - }); - - let on_notification = quote_spanned!(api.name.span()=> { - let method = n.method().to_owned(); - #(#notifications_blocks)* - // TODO: we received an unknown notification; log this? - }); - - let params_tys = generic_params.types.iter(); - - quote_spanned!(api.name.span()=> - #visibility async fn next_request(server: &'a mut jsonrpsee::raw::RawServer) -> core::result::Result<#enum_name #ty_generics, std::io::Error> - where - R: jsonrpsee::transport::TransportServer, - I: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync - #(, #params_tys: jsonrpsee::common::DeserializeOwned)* - { - loop { - match server.next_event().await { - jsonrpsee::raw::RawServerEvent::Notification(n) => #on_notification, - jsonrpsee::raw::RawServerEvent::SubscriptionsClosed(_) => unimplemented!(), // TODO: - jsonrpsee::raw::RawServerEvent::SubscriptionsReady(_) => unimplemented!(), // TODO: - jsonrpsee::raw::RawServerEvent::Request(r) => #on_request, - } - } - } - ) - }; - - let client_impl_block = build_client_impl(&api)?; - let debug_variants = build_debug_variants(&api)?; - - Ok(quote_spanned!(api.name.span()=> - #visibility enum #enum_name #tweaked_generics { - #(#variants),* - } - - impl #impl_generics #enum_name #ty_generics #where_clause { - #next_request - } - - #client_impl_block - - impl #impl_generics std::fmt::Debug for #enum_name #ty_generics #where_clause { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - #(#debug_variants,)* - } - } - } - )) + let enum_name = &api.name; + + // TODO: make sure there's no conflict here + let mut tweaked_generics = api.generics.clone(); + tweaked_generics + .params + .insert(0, From::from(syn::LifetimeDef::new(syn::parse_str::("'a").unwrap()))); + tweaked_generics.params.push(From::from(syn::TypeParam::from(syn::parse_str::("R").unwrap()))); + tweaked_generics.params.push(From::from(syn::TypeParam::from(syn::parse_str::("I").unwrap()))); + let (impl_generics, ty_generics, where_clause) = tweaked_generics.split_for_impl(); + let generics = api + .generics + .params + .iter() + .filter_map(|gp| if let syn::GenericParam::Type(tp) = gp { Some(tp.ident.clone()) } else { None }) + .collect::>(); + + let visibility = &api.visibility; + + let mut variants = Vec::new(); + let mut tmp_variants = Vec::new(); + for function in &api.definitions { + let function_is_notification = function.is_void_ret_type(); + let variant_name = snake_case_to_camel_case(&function.signature.ident); + let ret = match &function.signature.output { + syn::ReturnType::Default => quote! {()}, + syn::ReturnType::Type(_, ty) => quote_spanned!(ty.span()=> #ty), + }; + + let mut params_list = Vec::new(); + for input in function.signature.inputs.iter() { + let (ty, pat_span, param_variant_name) = match input { + syn::FnArg::Receiver(_) => { + return Err(syn::Error::new( + input.span(), + "Having `self` is not allowed in RPC queries definitions", + )); + } + syn::FnArg::Typed(syn::PatType { ty, pat, .. }) => (ty, pat.span(), param_variant_name(&pat)?), + }; + + params_list.push(quote_spanned!(pat_span=> #param_variant_name: #ty)); + } + + if !function_is_notification { + if params_list.is_empty() { + tmp_variants.push(quote_spanned!(function.signature.ident.span()=> #variant_name)); + } else { + tmp_variants.push(quote_spanned!(function.signature.ident.span()=> + #variant_name { + #(#params_list,)* + } + )); + } + } + + if function_is_notification { + variants.push(quote_spanned!(function.signature.ident.span()=> + #variant_name { + #(#params_list,)* + } + )); + } else { + variants.push(quote_spanned!(function.signature.ident.span()=> + #variant_name { + respond: jsonrpsee::raw::server::TypedResponder<'a, R, I, #ret>, + #(#params_list,)* + } + )); + } + } + + let next_request = { + let mut notifications_blocks = Vec::new(); + let mut function_blocks = Vec::new(); + let mut tmp_to_rq = Vec::new(); + + struct GenericParams { + generics: HashSet, + types: HashSet, + } + impl<'ast> syn::visit::Visit<'ast> for GenericParams { + fn visit_ident(&mut self, ident: &'ast syn::Ident) { + if self.generics.contains(ident) { + self.types.insert(ident.clone()); + } + } + } + + let mut generic_params = GenericParams { generics, types: HashSet::new() }; + + for function in &api.definitions { + let function_is_notification = function.is_void_ret_type(); + let variant_name = snake_case_to_camel_case(&function.signature.ident); + let rpc_method_name = + function.attributes.method.clone().unwrap_or_else(|| function.signature.ident.to_string()); + + let mut params_builders = Vec::new(); + let mut params_names_list = Vec::new(); + + for input in function.signature.inputs.iter() { + let (ty, param_variant_name, rpc_param_name) = match input { + syn::FnArg::Receiver(_) => { + return Err(syn::Error::new( + input.span(), + "Having `self` is not allowed in RPC queries definitions", + )); + } + syn::FnArg::Typed(syn::PatType { ty, pat, attrs, .. }) => { + (ty, param_variant_name(&pat)?, rpc_param_name(&pat, &attrs)?) + } + }; + + syn::visit::visit_type(&mut generic_params, &ty); + + params_names_list.push(quote_spanned!(function.signature.span()=> #param_variant_name)); + if !function_is_notification { + params_builders.push(quote_spanned!(function.signature.span()=> + let #param_variant_name: #ty = { + match request.params().get(#rpc_param_name) { + Ok(v) => v, + Err(_) => { + // TODO: message + request.respond(Err(jsonrpsee::common::Error::invalid_params(#rpc_param_name))); + continue; + } + } + }; + )); + } else { + params_builders.push(quote_spanned!(function.signature.span()=> + let #param_variant_name: #ty = { + match request.params().get(#rpc_param_name) { + Ok(v) => v, + Err(_) => { + // TODO: log this? + continue; + } + } + }; + )); + } + } + + if function_is_notification { + notifications_blocks.push(quote_spanned!(function.signature.span()=> + if method == #rpc_method_name { + let request = n; + #(#params_builders)* + return Ok(#enum_name::#variant_name { #(#params_names_list),* }); + } + )); + } else { + function_blocks.push(quote_spanned!(function.signature.span()=> + if request_outcome.is_none() && method == #rpc_method_name { + let request = server.request_by_id(&request_id).unwrap(); + #(#params_builders)* + request_outcome = Some(Tmp::#variant_name { #(#params_names_list),* }); + } + )); + + tmp_to_rq.push(quote_spanned!(function.signature.span()=> + Some(Tmp::#variant_name { #(#params_names_list),* }) => { + let request = server.request_by_id(&request_id).unwrap(); + let respond = jsonrpsee::raw::server::TypedResponder::from(request); + return Ok(#enum_name::#variant_name { respond #(, #params_names_list)* }); + }, + )); + } + } + + let params_tys = generic_params.types.iter(); + + let tmp_generics = if generic_params.types.is_empty() { + quote!() + } else { + quote_spanned!(api.name.span()=> + <#(#params_tys,)*> + ) + }; + + let on_request = quote_spanned!(api.name.span()=> { + #[allow(unused)] // The enum might be empty + enum Tmp #tmp_generics { + #(#tmp_variants,)* + } + + let request_id = r.id(); + let method = r.method().to_owned(); + + let mut request_outcome: Option = None; + + #(#function_blocks)* + + match request_outcome { + #(#tmp_to_rq)* + None => server.request_by_id(&request_id).unwrap().respond(Err(jsonrpsee::common::Error::method_not_found())), + } + }); + + let on_notification = quote_spanned!(api.name.span()=> { + let method = n.method().to_owned(); + #(#notifications_blocks)* + // TODO: we received an unknown notification; log this? + }); + + let params_tys = generic_params.types.iter(); + + quote_spanned!(api.name.span()=> + #visibility async fn next_request(server: &'a mut jsonrpsee::raw::RawServer) -> core::result::Result<#enum_name #ty_generics, std::io::Error> + where + R: jsonrpsee::transport::TransportServer, + I: Clone + PartialEq + Eq + std::hash::Hash + Send + Sync + #(, #params_tys: jsonrpsee::common::DeserializeOwned)* + { + loop { + match server.next_event().await { + jsonrpsee::raw::RawServerEvent::Notification(n) => #on_notification, + jsonrpsee::raw::RawServerEvent::SubscriptionsClosed(_) => unimplemented!(), // TODO: + jsonrpsee::raw::RawServerEvent::SubscriptionsReady(_) => unimplemented!(), // TODO: + jsonrpsee::raw::RawServerEvent::Request(r) => #on_request, + } + } + } + ) + }; + + let client_impl_block = build_client_impl(&api)?; + let debug_variants = build_debug_variants(&api)?; + + Ok(quote_spanned!(api.name.span()=> + #visibility enum #enum_name #tweaked_generics { + #(#variants),* + } + + impl #impl_generics #enum_name #ty_generics #where_clause { + #next_request + } + + #client_impl_block + + impl #impl_generics std::fmt::Debug for #enum_name #ty_generics #where_clause { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + #(#debug_variants,)* + } + } + } + )) } /// Builds the impl block that allow performing outbound JSON-RPC queries. /// /// Generates the `impl { }` block containing functions that perform RPC client calls. fn build_client_impl(api: &api_def::ApiDefinition) -> Result { - let enum_name = &api.name; - - let (impl_generics_org, _, where_clause_org) = api.generics.split_for_impl(); - let lifetimes_org = api.generics.lifetimes(); - let type_params_org = api.generics.type_params(); - let const_params_org = api.generics.const_params(); - - let client_functions = build_client_functions(&api)?; - - Ok(quote_spanned!(api.name.span()=> - // TODO: order between type_params and const_params is undecided - impl #impl_generics_org #enum_name<'static #(, #lifetimes_org)* #(, #type_params_org)* #(, #const_params_org)*, (), ()> - #where_clause_org - { - #(#client_functions)* - } - )) + let enum_name = &api.name; + + let (impl_generics_org, _, where_clause_org) = api.generics.split_for_impl(); + let lifetimes_org = api.generics.lifetimes(); + let type_params_org = api.generics.type_params(); + let const_params_org = api.generics.const_params(); + + let client_functions = build_client_functions(&api)?; + + Ok(quote_spanned!(api.name.span()=> + // TODO: order between type_params and const_params is undecided + impl #impl_generics_org #enum_name<'static #(, #lifetimes_org)* #(, #type_params_org)* #(, #const_params_org)*, (), ()> + #where_clause_org + { + #(#client_functions)* + } + )) } /// Builds the functions that allow performing outbound JSON-RPC queries. /// /// Generates a list of functions that perform RPC client calls. -fn build_client_functions( - api: &api_def::ApiDefinition, -) -> Result, syn::Error> { - let visibility = &api.visibility; - - let mut client_functions = Vec::new(); - for function in &api.definitions { - let f_name = &function.signature.ident; - let ret_ty = match function.signature.output { - syn::ReturnType::Default => quote!(()), - syn::ReturnType::Type(_, ref ty) => quote_spanned!(ty.span()=> #ty), - }; - let rpc_method_name = function - .attributes - .method - .clone() - .unwrap_or_else(|| function.signature.ident.to_string()); - - let mut params_list = Vec::new(); - let mut params_to_json = Vec::new(); - let mut params_to_array = Vec::new(); - let mut params_tys = Vec::new(); - - for (param_index, input) in function.signature.inputs.iter().enumerate() { - let (ty, pat_span, rpc_param_name) = match input { - syn::FnArg::Receiver(_) => { - return Err(syn::Error::new( - input.span(), - "Having `self` is not allowed in RPC queries definitions", - )); - } - syn::FnArg::Typed(syn::PatType { ty, pat, attrs, .. }) => { - (ty, pat.span(), rpc_param_name(&pat, &attrs)?) - } - }; - - let generated_param_name = syn::Ident::new( - &format!("param{}", param_index), - proc_macro2::Span::call_site(), - ); - - params_tys.push(ty); - params_list.push(quote_spanned!(pat_span=> #generated_param_name: impl Into<#ty>)); - params_to_json.push(quote_spanned!(pat_span=> - map.insert( - #rpc_param_name.to_string(), - jsonrpsee::common::to_value(#generated_param_name.into()).unwrap() // TODO: don't unwrap - ); - )); - params_to_array.push(quote_spanned!(pat_span=> - jsonrpsee::common::to_value(#generated_param_name.into()).unwrap() // TODO: don't unwrap - )); - } - - let params_building = if params_list.is_empty() { - quote! {jsonrpsee::common::Params::None} - } else if function.attributes.positional_params { - quote_spanned!(function.signature.span()=> - jsonrpsee::common::Params::Array(vec![ - #(#params_to_array),* - ]) - ) - } else { - let params_list_len = params_list.len(); - quote_spanned!(function.signature.span()=> - jsonrpsee::common::Params::Map({ - let mut map = jsonrpsee::common::JsonMap::with_capacity(#params_list_len); - #(#params_to_json)* - map - }) - ) - }; - - let is_notification = function.is_void_ret_type(); - let function_body = if is_notification { - quote_spanned!(function.signature.span()=> - client.send_notification(#rpc_method_name, #params_building).await - .map_err(jsonrpsee::raw::client::RawClientError::Inner)?; - Ok(()) - ) - } else { - quote_spanned!(function.signature.span()=> - let rq_id = client.start_request(#rpc_method_name, #params_building).await - .map_err(jsonrpsee::raw::client::RawClientError::Inner)?; - let data = client.request_by_id(rq_id).unwrap().await?; // TODO: don't unwrap? - Ok(jsonrpsee::common::from_value(data).unwrap()) // TODO: don't unwrap - ) - }; - - client_functions.push(quote_spanned!(function.signature.span()=> +fn build_client_functions(api: &api_def::ApiDefinition) -> Result, syn::Error> { + let visibility = &api.visibility; + + let mut client_functions = Vec::new(); + for function in &api.definitions { + let f_name = &function.signature.ident; + let ret_ty = match function.signature.output { + syn::ReturnType::Default => quote!(()), + syn::ReturnType::Type(_, ref ty) => quote_spanned!(ty.span()=> #ty), + }; + let rpc_method_name = + function.attributes.method.clone().unwrap_or_else(|| function.signature.ident.to_string()); + + let mut params_list = Vec::new(); + let mut params_to_json = Vec::new(); + let mut params_to_array = Vec::new(); + let mut params_tys = Vec::new(); + + for (param_index, input) in function.signature.inputs.iter().enumerate() { + let (ty, pat_span, rpc_param_name) = match input { + syn::FnArg::Receiver(_) => { + return Err(syn::Error::new( + input.span(), + "Having `self` is not allowed in RPC queries definitions", + )); + } + syn::FnArg::Typed(syn::PatType { ty, pat, attrs, .. }) => { + (ty, pat.span(), rpc_param_name(&pat, &attrs)?) + } + }; + + let generated_param_name = + syn::Ident::new(&format!("param{}", param_index), proc_macro2::Span::call_site()); + + params_tys.push(ty); + params_list.push(quote_spanned!(pat_span=> #generated_param_name: impl Into<#ty>)); + params_to_json.push(quote_spanned!(pat_span=> + map.insert( + #rpc_param_name.to_string(), + jsonrpsee::common::to_value(#generated_param_name.into()).unwrap() // TODO: don't unwrap + ); + )); + params_to_array.push(quote_spanned!(pat_span=> + jsonrpsee::common::to_value(#generated_param_name.into()).unwrap() // TODO: don't unwrap + )); + } + + let params_building = if params_list.is_empty() { + quote! {jsonrpsee::common::Params::None} + } else if function.attributes.positional_params { + quote_spanned!(function.signature.span()=> + jsonrpsee::common::Params::Array(vec![ + #(#params_to_array),* + ]) + ) + } else { + let params_list_len = params_list.len(); + quote_spanned!(function.signature.span()=> + jsonrpsee::common::Params::Map({ + let mut map = jsonrpsee::common::JsonMap::with_capacity(#params_list_len); + #(#params_to_json)* + map + }) + ) + }; + + let is_notification = function.is_void_ret_type(); + let function_body = if is_notification { + quote_spanned!(function.signature.span()=> + client.send_notification(#rpc_method_name, #params_building).await + .map_err(jsonrpsee::raw::client::RawClientError::Inner)?; + Ok(()) + ) + } else { + quote_spanned!(function.signature.span()=> + let rq_id = client.start_request(#rpc_method_name, #params_building).await + .map_err(jsonrpsee::raw::client::RawClientError::Inner)?; + let data = client.request_by_id(rq_id).unwrap().await?; // TODO: don't unwrap? + Ok(jsonrpsee::common::from_value(data).unwrap()) // TODO: don't unwrap + ) + }; + + client_functions.push(quote_spanned!(function.signature.span()=> // TODO: what if there's a conflict between `client` and a param name? #visibility async fn #f_name(client: &mut jsonrpsee::raw::RawClient #(, #params_list)*) -> core::result::Result<#ret_ty, jsonrpsee::raw::client::RawClientError<::Error>> @@ -493,48 +460,46 @@ fn build_client_functions( #function_body } )); - } + } - Ok(client_functions) + Ok(client_functions) } // TODO: better docs -fn build_debug_variants( - api: &api_def::ApiDefinition, -) -> Result, syn::Error> { - let enum_name = &api.name; - let mut debug_variants = Vec::new(); - for function in &api.definitions { - let variant_name = snake_case_to_camel_case(&function.signature.ident); - debug_variants.push(quote_spanned!(function.signature.ident.span()=> - #enum_name::#variant_name { /* TODO: params */ .. } => { - f.debug_struct(stringify!(#enum_name))/* TODO: params */.finish() - } - )); - } - Ok(debug_variants) +fn build_debug_variants(api: &api_def::ApiDefinition) -> Result, syn::Error> { + let enum_name = &api.name; + let mut debug_variants = Vec::new(); + for function in &api.definitions { + let variant_name = snake_case_to_camel_case(&function.signature.ident); + debug_variants.push(quote_spanned!(function.signature.ident.span()=> + #enum_name::#variant_name { /* TODO: params */ .. } => { + f.debug_struct(stringify!(#enum_name))/* TODO: params */.finish() + } + )); + } + Ok(debug_variants) } /// Turns a snake case function name into an UpperCamelCase name suitable to be an enum variant. fn snake_case_to_camel_case(snake_case: &syn::Ident) -> syn::Ident { - syn::Ident::new(&snake_case.to_string().to_pascal_case(), snake_case.span()) + syn::Ident::new(&snake_case.to_string().to_pascal_case(), snake_case.span()) } /// Determine the name of the variant in the enum based on the pattern of the function parameter. fn param_variant_name(pat: &syn::Pat) -> syn::parse::Result<&syn::Ident> { - match pat { - // TODO: check other fields of the `PatIdent` - syn::Pat::Ident(ident) => Ok(&ident.ident), - _ => unimplemented!(), - } + match pat { + // TODO: check other fields of the `PatIdent` + syn::Pat::Ident(ident) => Ok(&ident.ident), + _ => unimplemented!(), + } } /// Determine the name of the parameter based on the pattern. fn rpc_param_name(pat: &syn::Pat, _attrs: &[syn::Attribute]) -> syn::parse::Result { - // TODO: look in attributes if the user specified a param name - match pat { - // TODO: check other fields of the `PatIdent` - syn::Pat::Ident(ident) => Ok(ident.ident.to_string()), - _ => unimplemented!(), - } + // TODO: look in attributes if the user specified a param name + match pat { + // TODO: check other fields of the `PatIdent` + syn::Pat::Ident(ident) => Ok(ident.ident.to_string()), + _ => unimplemented!(), + } } diff --git a/src/client/http/client.rs b/src/client/http/client.rs index 3ec4166304..fda32bbce4 100644 --- a/src/client/http/client.rs +++ b/src/client/http/client.rs @@ -12,185 +12,164 @@ use futures::{channel::mpsc, channel::oneshot, future::Either, pin_mut, prelude: /// The communication is performed via a `mpsc` channel where the `Client` acts as simple frontend /// and just passes requests along to the backend (worker thread) pub struct Client { - backend: mpsc::Sender, + backend: mpsc::Sender, } /// Message that the [`Client`] can send to the background task. enum FrontToBack { - /// Send a one-shot notification to the server. The server doesn't give back any feedback. - Notification { - /// Method for the notification. - method: String, - /// Parameters to send to the server. - params: common::Params, - }, - - /// Send a request to the server. - StartRequest { - /// Method for the request. - method: String, - /// Parameters of the request. - params: common::Params, - /// One-shot channel where to send back the outcome of that request. - send_back: oneshot::Sender>, - }, + /// Send a one-shot notification to the server. The server doesn't give back any feedback. + Notification { + /// Method for the notification. + method: String, + /// Parameters to send to the server. + params: common::Params, + }, + + /// Send a request to the server. + StartRequest { + /// Method for the request. + method: String, + /// Parameters of the request. + params: common::Params, + /// One-shot channel where to send back the outcome of that request. + send_back: oneshot::Sender>, + }, } /// Error produced by [`Client::request`] and [`Client::subscribe`]. #[derive(Debug, thiserror::Error)] pub enum RequestError { - /// Networking error or error on the low-level protocol layer (e.g. missing field, - /// invalid ID, etc.). - #[error("Networking or low-level protocol error: {0}")] - TransportError(#[source] Box), - /// RawServer responded to our request with an error. - #[error("Server responded to our request with an error: {0:?}")] - Request(#[source] common::Error), - /// Failed to parse the data that the server sent back to us. - #[error("Parse error: {0}")] - ParseError(#[source] common::ParseError), + /// Networking error or error on the low-level protocol layer (e.g. missing field, + /// invalid ID, etc.). + #[error("Networking or low-level protocol error: {0}")] + TransportError(#[source] Box), + /// RawServer responded to our request with an error. + #[error("Server responded to our request with an error: {0:?}")] + Request(#[source] common::Error), + /// Failed to parse the data that the server sent back to us. + #[error("Parse error: {0}")] + ParseError(#[source] common::ParseError), } impl Client { - /// Create a client to connect to the server at address `endpoint` - pub fn new(endpoint: &str) -> Self { - let client = RawClient::new(HttpTransportClient::new(endpoint)); - - let (to_back, from_front) = mpsc::channel(16); - async_std::task::spawn(async move { - background_task(client, from_front).await; - }); - - Self { backend: to_back } - } - - /// Send a notification to the server. - pub async fn notification( - &self, - method: impl Into, - params: impl Into, - ) { - let method = method.into(); - let params = params.into(); - log::debug!(target: "jsonrpsee-http-client", "transmitting notification: method={:?}, params={:?}", method, params); - - // TODO: do we care if the channel is just temporarly full or closed in this context? - let _ = self - .backend - .clone() - .send(FrontToBack::Notification { method, params }) - .await; - } - - /// Perform a request towards the server. - pub async fn request( - &self, - method: impl Into, - params: impl Into, - ) -> Result - where - Ret: common::DeserializeOwned, - { - let method = method.into(); - let params = params.into(); - log::debug!(target: "jsonrpsee-http-client", "transmitting request: method={:?}, params={:?}", method, params); - let (send_back_tx, send_back_rx) = oneshot::channel(); - - // TODO: do we care if the channel is just temporarly full or closed in this context? - if let Err(e) = self - .backend - .clone() - .send(FrontToBack::StartRequest { - method, - params, - send_back: send_back_tx, - }) - .await - { - log::debug!(target: "jsonrpsee-http-client", "failed to send request to background task={:?}", e); - } - - let json_value = match send_back_rx.await { - Ok(Ok(v)) => { - log::debug!(target: "jsonrpsee-http-client", "response={:?}", v); - v - } - Ok(Err(err)) => return Err(err), - Err(_) => { - let err = io::Error::new(io::ErrorKind::Other, "background task closed"); - return Err(RequestError::TransportError(Box::new(err))); - } - }; - - common::from_value(json_value).map_err(RequestError::ParseError) - } + /// Create a client to connect to the server at address `endpoint` + pub fn new(endpoint: &str) -> Self { + let client = RawClient::new(HttpTransportClient::new(endpoint)); + + let (to_back, from_front) = mpsc::channel(16); + async_std::task::spawn(async move { + background_task(client, from_front).await; + }); + + Self { backend: to_back } + } + + /// Send a notification to the server. + pub async fn notification(&self, method: impl Into, params: impl Into) { + let method = method.into(); + let params = params.into(); + log::debug!(target: "jsonrpsee-http-client", "transmitting notification: method={:?}, params={:?}", method, params); + + // TODO: do we care if the channel is just temporarly full or closed in this context? + let _ = self.backend.clone().send(FrontToBack::Notification { method, params }).await; + } + + /// Perform a request towards the server. + pub async fn request( + &self, + method: impl Into, + params: impl Into, + ) -> Result + where + Ret: common::DeserializeOwned, + { + let method = method.into(); + let params = params.into(); + log::debug!(target: "jsonrpsee-http-client", "transmitting request: method={:?}, params={:?}", method, params); + let (send_back_tx, send_back_rx) = oneshot::channel(); + + // TODO: do we care if the channel is just temporarly full or closed in this context? + if let Err(e) = + self.backend.clone().send(FrontToBack::StartRequest { method, params, send_back: send_back_tx }).await + { + log::debug!(target: "jsonrpsee-http-client", "failed to send request to background task={:?}", e); + } + + let json_value = match send_back_rx.await { + Ok(Ok(v)) => { + log::debug!(target: "jsonrpsee-http-client", "response={:?}", v); + v + } + Ok(Err(err)) => return Err(err), + Err(_) => { + let err = io::Error::new(io::ErrorKind::Other, "background task closed"); + return Err(RequestError::TransportError(Box::new(err))); + } + }; + + common::from_value(json_value).map_err(RequestError::ParseError) + } } /// Function being run in the background that processes messages from the frontend. async fn background_task(mut client: RawClient, mut from_front: mpsc::Receiver) { - log::debug!(target: "jsonrpsee-http-client", "background thread started"); - - // List of requests that the server must answer. - let mut ongoing_requests: HashMap>> = - HashMap::new(); - - loop { - // We need to do a little transformation in order to destroy the borrow to `client` - // and `from_front`. - let outcome = { - let next_message = from_front.next(); - let next_event = client.next_event(); - pin_mut!(next_message); - pin_mut!(next_event); - match future::select(next_message, next_event).await { - Either::Left((v, _)) => Either::Left(v), - Either::Right((v, _)) => Either::Right(v), - } - }; - - match outcome { - // If the channel is closed, then the `Client` has been destroyed and we - // stop this task. - Either::Left(None) => { - log::debug!(target: "jsonrpsee-http-client", "background thread terminated"); - if !ongoing_requests.is_empty() { - log::warn!(target: "jsonrpsee-http-client", "client was dropped with {} pending requests", ongoing_requests.len()); - } - return; - } - - // User called `notification` on the front-end. - Either::Left(Some(FrontToBack::Notification { method, params })) => { - let _ = client.send_notification(method, params).await; - } - - // User called `request` on the front-end. - Either::Left(Some(FrontToBack::StartRequest { - method, - params, - send_back, - })) => match client.start_request(method, params).await { - Ok(id) => { - log::debug!(target: "jsonrpsee-http-client", "background thread; inserting ingoing request={:?}", id); - ongoing_requests.insert(id, send_back); - } - Err(err) => { - let _ = send_back.send(Err(RequestError::TransportError(Box::new(err)))); - } - }, - - // Received a response to a request from the server. - Either::Right(Ok(RawClientEvent::Response { request_id, result })) => { - let _ = ongoing_requests - .remove(&request_id) - .unwrap() - .send(result.map_err(RequestError::Request)); - } - - Either::Right(Err(e)) => { - // TODO: https://github.com/paritytech/jsonrpsee/issues/67 - log::error!("Client Error: {:?}", e); - } - } - } + log::debug!(target: "jsonrpsee-http-client", "background thread started"); + + // List of requests that the server must answer. + let mut ongoing_requests: HashMap>> = HashMap::new(); + + loop { + // We need to do a little transformation in order to destroy the borrow to `client` + // and `from_front`. + let outcome = { + let next_message = from_front.next(); + let next_event = client.next_event(); + pin_mut!(next_message); + pin_mut!(next_event); + match future::select(next_message, next_event).await { + Either::Left((v, _)) => Either::Left(v), + Either::Right((v, _)) => Either::Right(v), + } + }; + + match outcome { + // If the channel is closed, then the `Client` has been destroyed and we + // stop this task. + Either::Left(None) => { + log::debug!(target: "jsonrpsee-http-client", "background thread terminated"); + if !ongoing_requests.is_empty() { + log::warn!(target: "jsonrpsee-http-client", "client was dropped with {} pending requests", ongoing_requests.len()); + } + return; + } + + // User called `notification` on the front-end. + Either::Left(Some(FrontToBack::Notification { method, params })) => { + let _ = client.send_notification(method, params).await; + } + + // User called `request` on the front-end. + Either::Left(Some(FrontToBack::StartRequest { method, params, send_back })) => { + match client.start_request(method, params).await { + Ok(id) => { + log::debug!(target: "jsonrpsee-http-client", "background thread; inserting ingoing request={:?}", id); + ongoing_requests.insert(id, send_back); + } + Err(err) => { + let _ = send_back.send(Err(RequestError::TransportError(Box::new(err)))); + } + } + } + + // Received a response to a request from the server. + Either::Right(Ok(RawClientEvent::Response { request_id, result })) => { + let _ = ongoing_requests.remove(&request_id).unwrap().send(result.map_err(RequestError::Request)); + } + + Either::Right(Err(e)) => { + // TODO: https://github.com/paritytech/jsonrpsee/issues/67 + log::error!("Client Error: {:?}", e); + } + } + } } diff --git a/src/client/http/raw.rs b/src/client/http/raw.rs index 1c2eec1d5d..f9e89f7e04 100644 --- a/src/client/http/raw.rs +++ b/src/client/http/raw.rs @@ -76,28 +76,28 @@ use hashbrown::HashSet; /// /// See [the module root documentation](crate::client) for more information. pub struct RawClient { - /// Inner raw client. - inner: HttpTransportClient, - - /// Id to assign to the next request. We always assign linearly-increasing numeric keys. - next_request_id: RawClientRequestId, - - /// List of requests and subscription requests that have been sent out and that are waiting - /// for a response. - /// - // NOTE: `fnv - fowler-Noll-Vo hash function`, more efficient for smaller hash keys. - requests: HashSet, - - /// Queue of pending events to return from [`RawClient::next_event`]. - // TODO: call shrink_to from time to time; see https://github.com/rust-lang/rust/issues/56431 - events_queue: VecDeque, - - /// Maximum allowed size of [`RawClient::events_queue`]. - /// - /// If this size is reached, elements can still be pushed to the queue if they are critical, - /// but will be discarded if they are not. - // TODO: make this configurable? note: if this is configurable, it should always be >= 1 - events_queue_max_size: usize, + /// Inner raw client. + inner: HttpTransportClient, + + /// Id to assign to the next request. We always assign linearly-increasing numeric keys. + next_request_id: RawClientRequestId, + + /// List of requests and subscription requests that have been sent out and that are waiting + /// for a response. + /// + // NOTE: `fnv - fowler-Noll-Vo hash function`, more efficient for smaller hash keys. + requests: HashSet, + + /// Queue of pending events to return from [`RawClient::next_event`]. + // TODO: call shrink_to from time to time; see https://github.com/rust-lang/rust/issues/56431 + events_queue: VecDeque, + + /// Maximum allowed size of [`RawClient::events_queue`]. + /// + /// If this size is reached, elements can still be pushed to the queue if they are critical, + /// but will be discarded if they are not. + // TODO: make this configurable? note: if this is configurable, it should always be >= 1 + events_queue_max_size: usize, } /// Unique identifier of a request within a [`RawClient`]. @@ -107,249 +107,234 @@ pub struct RawClientRequestId(u64); /// Event returned by [`RawClient::next_event`]. #[derive(Debug)] pub enum RawClientEvent { - /// A request has received a response. - Response { - /// Identifier of the request. Can be matched with the value that [`RawClient::start_request`] - /// has returned. - request_id: RawClientRequestId, - /// The response itself. - result: Result, - }, + /// A request has received a response. + Response { + /// Identifier of the request. Can be matched with the value that [`RawClient::start_request`] + /// has returned. + request_id: RawClientRequestId, + /// The response itself. + result: Result, + }, } /// Error that can happen during a request. #[derive(Debug)] pub enum RawClientError { - /// Error in the raw client. - Inner(RequestError), - /// RawServer returned an error for our request. - RequestError(common::Error), - /// RawServer has sent back a response containing an unknown request ID. - UnknownRequestId, - /// RawServer has sent back a response containing a null request ID. - NullRequestId, + /// Error in the raw client. + Inner(RequestError), + /// RawServer returned an error for our request. + RequestError(common::Error), + /// RawServer has sent back a response containing an unknown request ID. + UnknownRequestId, + /// RawServer has sent back a response containing a null request ID. + NullRequestId, } impl RawClient { - /// Initializes a new `RawClient` using the given raw client as backend. - pub fn new(inner: HttpTransportClient) -> Self { - RawClient { - inner, - next_request_id: RawClientRequestId(0), - requests: HashSet::default(), - events_queue: VecDeque::with_capacity(16), - events_queue_max_size: 64, - } - } + /// Initializes a new `RawClient` using the given raw client as backend. + pub fn new(inner: HttpTransportClient) -> Self { + RawClient { + inner, + next_request_id: RawClientRequestId(0), + requests: HashSet::default(), + events_queue: VecDeque::with_capacity(16), + events_queue_max_size: 64, + } + } } impl RawClient { - /// Sends a notification to the server. The notification doesn't need any response. - /// - /// This asynchronous function finishes when the notification has finished being sent. - pub async fn send_notification( - &mut self, - method: impl Into, - params: impl Into, - ) -> Result<(), RequestError> { - let request = common::Request::Single(common::Call::Notification(common::Notification { - jsonrpc: common::Version::V2, - method: method.into(), - params: params.into(), - })); - - self.inner.send_request(request).await?; - Ok(()) - } - - /// Starts a request. - /// - /// This asynchronous function finishes when the request has been sent to the server. The - /// request is added to the [`RawClient`]. You must then call [`next_event`](RawClient::next_event) - /// until you get a response. - pub async fn start_request( - &mut self, - method: impl Into, - params: impl Into, - ) -> Result { - loop { - let id = self.next_request_id; - self.next_request_id.0 = self.next_request_id.0.wrapping_add(1); - - if self.requests.contains(&id) { - continue; - } else { - self.requests.insert(id); - } - - let request = common::Request::Single(common::Call::MethodCall(common::MethodCall { - jsonrpc: common::Version::V2, - method: method.into(), - params: params.into(), - id: common::Id::Num(id.0), - })); - - log::debug!(target: "jsonrpsee-http-raw-client", "request={:?}", request); - // Note that in case of an error, we "lose" the request id (as in, it will never be - // used). This isn't a problem, however. - self.inner.send_request(request).await?; - - break Ok(id); - } - } - - /// Waits until the client receives a message from the server. - /// - /// If this function returns an `Err`, it indicates a connectivity issue with the server or a - /// low-level protocol error, and not a request that has failed to be answered. - pub async fn next_event(&mut self) -> Result { - loop { - if let Some(event) = self.events_queue.pop_front() { - return Ok(event); - } - - self.event_step().await?; - } - } - - /// Returns a `Future` that resolves when the server sends back a response for the given - /// request. - /// - /// Returns `None` if the request identifier is invalid, or if the request is a subscription. - /// - /// > **Note**: While this function is waiting, all the other responses and pubsub events - /// > returned by the server will be buffered up to a certain limit. Once this - /// > limit is reached, server notifications will be discarded. If you want to be - /// > sure to catch all notifications, use [`next_event`](RawClient::next_event) - /// > instead. - pub fn request_by_id<'a>( - &'a mut self, - rq_id: RawClientRequestId, - ) -> Option> + 'a> { - // First, let's check whether the request ID is valid. - if !self.requests.contains(&rq_id) { - return None; - } - - Some(async move { - let mut events_queue_loopkup = 0; - - loop { - while events_queue_loopkup < self.events_queue.len() { - match &self.events_queue[events_queue_loopkup] { - RawClientEvent::Response { request_id, .. } if *request_id == rq_id => { - return match self.events_queue.remove(events_queue_loopkup) { - Some(RawClientEvent::Response { result, .. }) => { - result.map_err(RawClientError::RequestError) - } - _ => unreachable!(), - } - } - _ => {} - } - - events_queue_loopkup += 1; - } - - self.event_step().await?; - } - }) - } - - /// Waits for one server message and processes it by updating the state of `self`. - /// - /// If the events queue is full (see [`RawClient::events_queue_max_size`]), then responses to - /// requests will still be pushed to the queue, but notifications will be discarded. - /// - /// Check the content of [`events_queue`](RawClient::events_queue) afterwards for events to - /// dispatch to the user. - async fn event_step(&mut self) -> Result<(), RawClientError> { - let result = self - .inner - .next_response() - .await - .map_err(RawClientError::Inner)?; - - match result { - common::Response::Single(rp) => self.process_response(rp)?, - common::Response::Batch(rps) => { - for rp in rps { - // TODO: if an error happens, we throw away the entire batch - self.process_response(rp)?; - } - } - // Server MUST NOT reply to a Notification. - common::Response::Notif(_notif) => unreachable!(), - } - - Ok(()) - } - - /// Processes the response obtained from the server. Updates the internal state of `self` to - /// account for it. - /// - /// Regards all `response IDs` that is not a number as error because only numbers are used as - /// `id` in this library even though that `JSONRPC 2.0` allows String and Null as valid IDs. - fn process_response(&mut self, response: common::Output) -> Result<(), RawClientError> { - let request_id = match response.id() { - common::Id::Num(n) => RawClientRequestId(*n), - common::Id::Str(s) => { - log::warn!("Server responded with an invalid request id: {:?}", s); - return Err(RawClientError::UnknownRequestId); - } - common::Id::Null => { - log::warn!("Server responded with a null request id"); - return Err(RawClientError::NullRequestId); - } - }; - - // Find the request that this answered. - if self.requests.remove(&request_id) { - self.events_queue.push_back(RawClientEvent::Response { - result: response.into(), - request_id, - }); - } else { - log::warn!( - "Server responsed with an invalid request id: {:?}", - request_id - ); - return Err(RawClientError::UnknownRequestId); - } - - Ok(()) - } + /// Sends a notification to the server. The notification doesn't need any response. + /// + /// This asynchronous function finishes when the notification has finished being sent. + pub async fn send_notification( + &mut self, + method: impl Into, + params: impl Into, + ) -> Result<(), RequestError> { + let request = common::Request::Single(common::Call::Notification(common::Notification { + jsonrpc: common::Version::V2, + method: method.into(), + params: params.into(), + })); + + self.inner.send_request(request).await?; + Ok(()) + } + + /// Starts a request. + /// + /// This asynchronous function finishes when the request has been sent to the server. The + /// request is added to the [`RawClient`]. You must then call [`next_event`](RawClient::next_event) + /// until you get a response. + pub async fn start_request( + &mut self, + method: impl Into, + params: impl Into, + ) -> Result { + loop { + let id = self.next_request_id; + self.next_request_id.0 = self.next_request_id.0.wrapping_add(1); + + if self.requests.contains(&id) { + continue; + } else { + self.requests.insert(id); + } + + let request = common::Request::Single(common::Call::MethodCall(common::MethodCall { + jsonrpc: common::Version::V2, + method: method.into(), + params: params.into(), + id: common::Id::Num(id.0), + })); + + log::debug!(target: "jsonrpsee-http-raw-client", "request={:?}", request); + // Note that in case of an error, we "lose" the request id (as in, it will never be + // used). This isn't a problem, however. + self.inner.send_request(request).await?; + + break Ok(id); + } + } + + /// Waits until the client receives a message from the server. + /// + /// If this function returns an `Err`, it indicates a connectivity issue with the server or a + /// low-level protocol error, and not a request that has failed to be answered. + pub async fn next_event(&mut self) -> Result { + loop { + if let Some(event) = self.events_queue.pop_front() { + return Ok(event); + } + + self.event_step().await?; + } + } + + /// Returns a `Future` that resolves when the server sends back a response for the given + /// request. + /// + /// Returns `None` if the request identifier is invalid, or if the request is a subscription. + /// + /// > **Note**: While this function is waiting, all the other responses and pubsub events + /// > returned by the server will be buffered up to a certain limit. Once this + /// > limit is reached, server notifications will be discarded. If you want to be + /// > sure to catch all notifications, use [`next_event`](RawClient::next_event) + /// > instead. + pub fn request_by_id<'a>( + &'a mut self, + rq_id: RawClientRequestId, + ) -> Option> + 'a> { + // First, let's check whether the request ID is valid. + if !self.requests.contains(&rq_id) { + return None; + } + + Some(async move { + let mut events_queue_loopkup = 0; + + loop { + while events_queue_loopkup < self.events_queue.len() { + match &self.events_queue[events_queue_loopkup] { + RawClientEvent::Response { request_id, .. } if *request_id == rq_id => { + return match self.events_queue.remove(events_queue_loopkup) { + Some(RawClientEvent::Response { result, .. }) => { + result.map_err(RawClientError::RequestError) + } + _ => unreachable!(), + } + } + _ => {} + } + + events_queue_loopkup += 1; + } + + self.event_step().await?; + } + }) + } + + /// Waits for one server message and processes it by updating the state of `self`. + /// + /// If the events queue is full (see [`RawClient::events_queue_max_size`]), then responses to + /// requests will still be pushed to the queue, but notifications will be discarded. + /// + /// Check the content of [`events_queue`](RawClient::events_queue) afterwards for events to + /// dispatch to the user. + async fn event_step(&mut self) -> Result<(), RawClientError> { + let result = self.inner.next_response().await.map_err(RawClientError::Inner)?; + + match result { + common::Response::Single(rp) => self.process_response(rp)?, + common::Response::Batch(rps) => { + for rp in rps { + // TODO: if an error happens, we throw away the entire batch + self.process_response(rp)?; + } + } + // Server MUST NOT reply to a Notification. + common::Response::Notif(_notif) => unreachable!(), + } + + Ok(()) + } + + /// Processes the response obtained from the server. Updates the internal state of `self` to + /// account for it. + /// + /// Regards all `response IDs` that is not a number as error because only numbers are used as + /// `id` in this library even though that `JSONRPC 2.0` allows String and Null as valid IDs. + fn process_response(&mut self, response: common::Output) -> Result<(), RawClientError> { + let request_id = match response.id() { + common::Id::Num(n) => RawClientRequestId(*n), + common::Id::Str(s) => { + log::warn!("Server responded with an invalid request id: {:?}", s); + return Err(RawClientError::UnknownRequestId); + } + common::Id::Null => { + log::warn!("Server responded with a null request id"); + return Err(RawClientError::NullRequestId); + } + }; + + // Find the request that this answered. + if self.requests.remove(&request_id) { + self.events_queue.push_back(RawClientEvent::Response { result: response.into(), request_id }); + } else { + log::warn!("Server responsed with an invalid request id: {:?}", request_id); + return Err(RawClientError::UnknownRequestId); + } + + Ok(()) + } } impl fmt::Debug for RawClient { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("RawClient") - .field("inner", &self.inner) - .field("pending_requests", &self.requests) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("RawClient").field("inner", &self.inner).field("pending_requests", &self.requests).finish() + } } impl std::error::Error for RawClientError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - RawClientError::Inner(err) => Some(err), - RawClientError::RequestError(ref err) => Some(err), - RawClientError::UnknownRequestId => None, - RawClientError::NullRequestId => None, - } - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + RawClientError::Inner(err) => Some(err), + RawClientError::RequestError(ref err) => Some(err), + RawClientError::UnknownRequestId => None, + RawClientError::NullRequestId => None, + } + } } impl fmt::Display for RawClientError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - RawClientError::Inner(err) => write!(f, "Error in the raw client: {}", err), - RawClientError::RequestError(ref err) => write!(f, "Server returned error: {}", err), - RawClientError::UnknownRequestId => { - write!(f, "Server responded with an unknown request ID") - } - RawClientError::NullRequestId => write!(f, "Server responded with a null request ID"), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + RawClientError::Inner(err) => write!(f, "Error in the raw client: {}", err), + RawClientError::RequestError(ref err) => write!(f, "Server returned error: {}", err), + RawClientError::UnknownRequestId => write!(f, "Server responded with an unknown request ID"), + RawClientError::NullRequestId => write!(f, "Server responded with a null request ID"), + } + } } diff --git a/src/client/http/transport.rs b/src/client/http/transport.rs index ecccf11f7a..b431150240 100644 --- a/src/client/http/transport.rs +++ b/src/client/http/transport.rs @@ -23,261 +23,238 @@ use thiserror::Error; /// Implementation of a raw client for HTTP requests. pub struct HttpTransportClient { - /// Sender that sends requests to the background task. - requests_tx: mpsc::Sender, - /// URL of the server to connect to. - url: String, - /// Responses receiver. - responses_rx: mpsc::UnboundedReceiver, hyper::Error>>, - /// Responses transmitter - responses_tx: mpsc::UnboundedSender, hyper::Error>>, + /// Sender that sends requests to the background task. + requests_tx: mpsc::Sender, + /// URL of the server to connect to. + url: String, + /// Responses receiver. + responses_rx: mpsc::UnboundedReceiver, hyper::Error>>, + /// Responses transmitter + responses_tx: mpsc::UnboundedSender, hyper::Error>>, } /// Message transmitted from the foreground task to the background. struct FrontToBack { - /// Request that the background task should perform. - request: hyper::Request, - /// Channel to send back to the response. - send_back: mpsc::UnboundedSender, hyper::Error>>, + /// Request that the background task should perform. + request: hyper::Request, + /// Channel to send back to the response. + send_back: mpsc::UnboundedSender, hyper::Error>>, } impl HttpTransportClient { - /// Initializes a new HTTP client. - // TODO: better type for target - pub fn new(target: &str) -> Self { - let (requests_tx, requests_rx) = mpsc::channel::(4); - - // Because hyper can only be polled through tokio, we spawn it in a background thread. - thread::Builder::new() - .name("jsonrpsee-hyper-client".to_string()) - .spawn(move || { - let client = hyper::Client::new(); - background_thread(requests_rx, move |rq| { - // cloning Hyper client = cloning references - let client = client.clone(); - async move { - let _ = rq - .send_back - .unbounded_send(client.request(rq.request).await); - } - }) - }) - .unwrap(); - - let (responses_tx, responses_rx) = futures::channel::mpsc::unbounded(); - HttpTransportClient { - requests_tx, - url: target.to_owned(), - responses_tx, - responses_rx, - } - } - - /// Send request to the target - pub fn send_request<'s>( - &'s mut self, - request: common::Request, - ) -> Pin> + Send + 's>> { - let mut requests_tx = self.requests_tx.clone(); - - let request = common::to_vec(&request).map(|body| { - hyper::Request::post(&self.url) - .header( - hyper::header::CONTENT_TYPE, - hyper::header::HeaderValue::from_static("application/json"), - ) - .body(From::from(body)) - .expect("Uri and request headers are valid; qed") // TODO: not necessarily true for URL here - }); - - Box::pin(async move { - let message = FrontToBack { - request: request.map_err(RequestError::Serialization)?, - send_back: self.responses_tx.clone(), - }; - - if requests_tx.send(message).await.is_err() { - log::error!("JSONRPC http client background thread has shut down"); - return Err(RequestError::Http(Box::new(io::Error::new( - io::ErrorKind::Other, - "background thread is down".to_string(), - )))); - } - - Ok(()) - }) - } - - /// Waits for the next response. - pub fn next_response<'s>( - &'s mut self, - ) -> Pin> + Send + 's>> { - Box::pin(async move { - let hyper_response = match self.responses_rx.next().await { - Some(Ok(r)) => r, - Some(Err(err)) => return Err(RequestError::Http(Box::new(err))), - None => { - log::error!("JSONRPC http client background thread has shut down"); - return Err(RequestError::Http(Box::new(io::Error::new( - io::ErrorKind::Other, - "background thread is down".to_string(), - )))); - } - }; - - if !hyper_response.status().is_success() { - return Err(RequestError::RequestFailure { - status_code: hyper_response.status().into(), - }); - } - - // Note that we don't check the Content-Type of the request. This is deemed - // unnecessary, as a parsing error while happen anyway. - - // TODO: enforce a maximum size here - let body = hyper::body::to_bytes(hyper_response.into_body()) - .await - .map_err(|err| RequestError::Http(Box::new(err)))?; - - // TODO: use Response::from_json - let as_json: common::Response = - common::from_slice(&body).map_err(RequestError::ParseError)?; - Ok(as_json) - }) - } + /// Initializes a new HTTP client. + // TODO: better type for target + pub fn new(target: &str) -> Self { + let (requests_tx, requests_rx) = mpsc::channel::(4); + + // Because hyper can only be polled through tokio, we spawn it in a background thread. + thread::Builder::new() + .name("jsonrpsee-hyper-client".to_string()) + .spawn(move || { + let client = hyper::Client::new(); + background_thread(requests_rx, move |rq| { + // cloning Hyper client = cloning references + let client = client.clone(); + async move { + let _ = rq.send_back.unbounded_send(client.request(rq.request).await); + } + }) + }) + .unwrap(); + + let (responses_tx, responses_rx) = futures::channel::mpsc::unbounded(); + HttpTransportClient { requests_tx, url: target.to_owned(), responses_tx, responses_rx } + } + + /// Send request to the target + pub fn send_request<'s>( + &'s mut self, + request: common::Request, + ) -> Pin> + Send + 's>> { + let mut requests_tx = self.requests_tx.clone(); + + let request = common::to_vec(&request).map(|body| { + hyper::Request::post(&self.url) + .header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static("application/json")) + .body(From::from(body)) + .expect("Uri and request headers are valid; qed") // TODO: not necessarily true for URL here + }); + + Box::pin(async move { + let message = FrontToBack { + request: request.map_err(RequestError::Serialization)?, + send_back: self.responses_tx.clone(), + }; + + if requests_tx.send(message).await.is_err() { + log::error!("JSONRPC http client background thread has shut down"); + return Err(RequestError::Http(Box::new(io::Error::new( + io::ErrorKind::Other, + "background thread is down".to_string(), + )))); + } + + Ok(()) + }) + } + + /// Waits for the next response. + pub fn next_response<'s>( + &'s mut self, + ) -> Pin> + Send + 's>> { + Box::pin(async move { + let hyper_response = match self.responses_rx.next().await { + Some(Ok(r)) => r, + Some(Err(err)) => return Err(RequestError::Http(Box::new(err))), + None => { + log::error!("JSONRPC http client background thread has shut down"); + return Err(RequestError::Http(Box::new(io::Error::new( + io::ErrorKind::Other, + "background thread is down".to_string(), + )))); + } + }; + + if !hyper_response.status().is_success() { + return Err(RequestError::RequestFailure { status_code: hyper_response.status().into() }); + } + + // Note that we don't check the Content-Type of the request. This is deemed + // unnecessary, as a parsing error while happen anyway. + + // TODO: enforce a maximum size here + let body = hyper::body::to_bytes(hyper_response.into_body()) + .await + .map_err(|err| RequestError::Http(Box::new(err)))?; + + // TODO: use Response::from_json + let as_json: common::Response = common::from_slice(&body).map_err(RequestError::ParseError)?; + Ok(as_json) + }) + } } impl fmt::Debug for HttpTransportClient { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_tuple("HttpTransportClient").finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("HttpTransportClient").finish() + } } /// Error that can happen during a request. #[derive(Debug, Error)] pub enum RequestError { - /// Error while serializing the request. - // TODO: can that happen? - #[error("error while serializing the request")] - Serialization(#[source] serde_json::error::Error), - - /// Response given by the server failed to decode as UTF-8. - #[error("response body is not UTF-8")] - Utf8(#[source] std::string::FromUtf8Error), - - /// Error during the HTTP request, including networking errors and HTTP protocol errors. - #[error("error while performing the HTTP request")] - Http(Box), - - /// Server returned a non-success status code. - #[error("server returned an error status code: {:?}", status_code)] - RequestFailure { - /// Status code returned by the server. - status_code: u16, - }, - - /// Failed to parse the JSON returned by the server into a JSON-RPC response. - #[error("error while parsing the response body")] - ParseError(#[source] serde_json::error::Error), + /// Error while serializing the request. + // TODO: can that happen? + #[error("error while serializing the request")] + Serialization(#[source] serde_json::error::Error), + + /// Response given by the server failed to decode as UTF-8. + #[error("response body is not UTF-8")] + Utf8(#[source] std::string::FromUtf8Error), + + /// Error during the HTTP request, including networking errors and HTTP protocol errors. + #[error("error while performing the HTTP request")] + Http(Box), + + /// Server returned a non-success status code. + #[error("server returned an error status code: {:?}", status_code)] + RequestFailure { + /// Status code returned by the server. + status_code: u16, + }, + + /// Failed to parse the JSON returned by the server into a JSON-RPC response. + #[error("error while parsing the response body")] + ParseError(#[source] serde_json::error::Error), } /// Function that runs in a background thread. fn background_thread>( - mut requests_rx: mpsc::Receiver, - process_request: impl Fn(T) -> ProcessRequest, + mut requests_rx: mpsc::Receiver, + process_request: impl Fn(T) -> ProcessRequest, ) { - let mut runtime = match tokio::runtime::Builder::new() - .basic_scheduler() - .enable_all() - .build() - { - Ok(r) => r, - Err(err) => { - // Ideally, we would try to initialize the tokio runtime in the main thread then move - // it here. That however isn't possible. If we fail to initialize the runtime, the only - // thing we can do is print an error and shut down the background thread. - // Initialization failures should be almost non-existant anyway, so this isn't a big - // deal. - log::error!("Failed to initialize tokio runtime: {:?}", err); - return; - } - }; - - // Running until the channel has been closed, and all requests have been completed. - runtime.block_on(async move { - // Collection of futures that process ongoing requests. - let mut pending_requests = stream::FuturesUnordered::new(); - - loop { - let request = loop { - if !pending_requests.is_empty() { - let event = future::select(requests_rx.next(), pending_requests.next()).await; - if let future::Either::Left((rq, _)) = event { - break rq; - } - // else: one of the elements of `pending_requests` is finished, but we don't care - } else { - break requests_rx.next().await; - } - }; - - match request { - // We received a request from the foreground. - Some(rq) => pending_requests.push(process_request(rq)), - // The channel with the foreground has closed. - None => break, - } - } - - // Before returning, complete all pending requests. - while pending_requests.next().await.is_some() {} - }); + let mut runtime = match tokio::runtime::Builder::new().basic_scheduler().enable_all().build() { + Ok(r) => r, + Err(err) => { + // Ideally, we would try to initialize the tokio runtime in the main thread then move + // it here. That however isn't possible. If we fail to initialize the runtime, the only + // thing we can do is print an error and shut down the background thread. + // Initialization failures should be almost non-existant anyway, so this isn't a big + // deal. + log::error!("Failed to initialize tokio runtime: {:?}", err); + return; + } + }; + + // Running until the channel has been closed, and all requests have been completed. + runtime.block_on(async move { + // Collection of futures that process ongoing requests. + let mut pending_requests = stream::FuturesUnordered::new(); + + loop { + let request = loop { + if !pending_requests.is_empty() { + let event = future::select(requests_rx.next(), pending_requests.next()).await; + if let future::Either::Left((rq, _)) = event { + break rq; + } + // else: one of the elements of `pending_requests` is finished, but we don't care + } else { + break requests_rx.next().await; + } + }; + + match request { + // We received a request from the foreground. + Some(rq) => pending_requests.push(process_request(rq)), + // The channel with the foreground has closed. + None => break, + } + } + + // Before returning, complete all pending requests. + while pending_requests.next().await.is_some() {} + }); } #[cfg(test)] mod tests { - use super::*; - use futures::channel::oneshot; - - #[test] - fn background_thread_is_able_to_complete_requests() { - // start background thread that returns square(passed_value) after signal - // from 'main' thread is received - let (mut requests_tx, requests_rx) = mpsc::channel(4); - let background_thread = thread::spawn(move || { - background_thread( - requests_rx, - move |(send_when, send_back, value): ( - oneshot::Receiver<()>, - oneshot::Sender, - u32, - )| async move { - send_when.await.unwrap(); - send_back.send(value * value).unwrap(); - }, - ) - }); - - // send two requests - there'll be two simultaneous active requests, waiting for - // main thread' signals - let mut pool = futures::executor::LocalPool::new(); - let (send_when_tx1, send_when_rx1) = oneshot::channel(); - let (send_when_tx2, send_when_rx2) = oneshot::channel(); - let (send_back_tx1, send_back_rx1) = oneshot::channel(); - let (send_back_tx2, send_back_rx2) = oneshot::channel(); - pool.run_until(requests_tx.send((send_when_rx1, send_back_tx1, 32))) - .unwrap(); - pool.run_until(requests_tx.send((send_when_rx2, send_back_tx2, 1024))) - .unwrap(); - - // send both signals and wait for responses - send_when_tx1.send(()).unwrap(); - send_when_tx2.send(()).unwrap(); - assert_eq!(pool.run_until(send_back_rx1), Ok(32 * 32)); - assert_eq!(pool.run_until(send_back_rx2), Ok(1024 * 1024)); - - // drop requests sender, asking background thread to exit gently - drop(requests_tx); - background_thread.join().unwrap(); - } + use super::*; + use futures::channel::oneshot; + + #[test] + fn background_thread_is_able_to_complete_requests() { + // start background thread that returns square(passed_value) after signal + // from 'main' thread is received + let (mut requests_tx, requests_rx) = mpsc::channel(4); + let background_thread = thread::spawn(move || { + background_thread( + requests_rx, + move |(send_when, send_back, value): (oneshot::Receiver<()>, oneshot::Sender, u32)| async move { + send_when.await.unwrap(); + send_back.send(value * value).unwrap(); + }, + ) + }); + + // send two requests - there'll be two simultaneous active requests, waiting for + // main thread' signals + let mut pool = futures::executor::LocalPool::new(); + let (send_when_tx1, send_when_rx1) = oneshot::channel(); + let (send_when_tx2, send_when_rx2) = oneshot::channel(); + let (send_back_tx1, send_back_rx1) = oneshot::channel(); + let (send_back_tx2, send_back_rx2) = oneshot::channel(); + pool.run_until(requests_tx.send((send_when_rx1, send_back_tx1, 32))).unwrap(); + pool.run_until(requests_tx.send((send_when_rx2, send_back_tx2, 1024))).unwrap(); + + // send both signals and wait for responses + send_when_tx1.send(()).unwrap(); + send_when_tx2.send(()).unwrap(); + assert_eq!(pool.run_until(send_back_rx1), Ok(32 * 32)); + assert_eq!(pool.run_until(send_back_rx2), Ok(1024 * 1024)); + + // drop requests sender, asking background thread to exit gently + drop(requests_tx); + background_thread.join().unwrap(); + } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 481bea6c29..8d6a2c95b8 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -9,6 +9,4 @@ mod ws; #[cfg(feature = "http")] pub use http::{Client as HttpClient, HttpTransportClient, RawClient as HttpRawClient}; #[cfg(feature = "ws")] -pub use ws::{ - Client as WsClient, RawClient as RawWsClient, Subscription as WsSubscription, WsTransportClient, -}; +pub use ws::{Client as WsClient, RawClient as RawWsClient, Subscription as WsSubscription, WsTransportClient}; diff --git a/src/client/ws/client.rs b/src/client/ws/client.rs index ad8164f8e3..ca9b4daabe 100644 --- a/src/client/ws/client.rs +++ b/src/client/ws/client.rs @@ -28,10 +28,10 @@ use crate::client::ws::{RawClient, RawClientEvent, RawClientRequestId, WsTranspo use crate::common::{self, JsonValue}; use futures::{ - channel::{mpsc, oneshot}, - future::Either, - pin_mut, - prelude::*, + channel::{mpsc, oneshot}, + future::Either, + pin_mut, + prelude::*, }; use std::{collections::HashMap, error, io, marker::PhantomData}; @@ -42,383 +42,324 @@ use std::{collections::HashMap, error, io, marker::PhantomData}; /// > [`RawClient`] struct instead. #[derive(Clone)] pub struct Client { - /// Channel to send requests to the background task. - to_back: mpsc::Sender, + /// Channel to send requests to the background task. + to_back: mpsc::Sender, } /// Active subscription on a [`Client`]. pub struct Subscription { - /// Channel to send requests to the background task. - to_back: mpsc::Sender, - /// Channel from which we receive notifications from the server, as undecoded `JsonValue`s. - notifs_rx: mpsc::Receiver, - /// Marker in order to pin the `Notif` parameter. - marker: PhantomData>, + /// Channel to send requests to the background task. + to_back: mpsc::Sender, + /// Channel from which we receive notifications from the server, as undecoded `JsonValue`s. + notifs_rx: mpsc::Receiver, + /// Marker in order to pin the `Notif` parameter. + marker: PhantomData>, } /// Error produced by [`Client::request`] and [`Client::subscribe`]. #[derive(Debug, thiserror::Error)] pub enum RequestError { - /// Networking error or error on the low-level protocol layer (e.g. missing field, - /// invalid ID, etc.). - #[error("Networking or low-level protocol error: {0}")] - TransportError(#[source] Box), - /// RawServer responded to our request with an error. - #[error("Server responded to our request with an error: {0:?}")] - Request(#[source] common::Error), - /// Failed to parse the data that the server sent back to us. - #[error("Parse error: {0}")] - ParseError(#[source] common::ParseError), + /// Networking error or error on the low-level protocol layer (e.g. missing field, + /// invalid ID, etc.). + #[error("Networking or low-level protocol error: {0}")] + TransportError(#[source] Box), + /// RawServer responded to our request with an error. + #[error("Server responded to our request with an error: {0:?}")] + Request(#[source] common::Error), + /// Failed to parse the data that the server sent back to us. + #[error("Parse error: {0}")] + ParseError(#[source] common::ParseError), } /// Message that the [`Client`] can send to the background task. enum FrontToBack { - /// Send a one-shot notification to the server. The server doesn't give back any feedback. - Notification { - /// Method for the notification. - method: String, - /// Parameters to send to the server. - params: common::Params, - }, + /// Send a one-shot notification to the server. The server doesn't give back any feedback. + Notification { + /// Method for the notification. + method: String, + /// Parameters to send to the server. + params: common::Params, + }, - /// Send a request to the server. - StartRequest { - /// Method for the request. - method: String, - /// Parameters of the request. - params: common::Params, - /// One-shot channel where to send back the outcome of that request. - send_back: oneshot::Sender>, - }, + /// Send a request to the server. + StartRequest { + /// Method for the request. + method: String, + /// Parameters of the request. + params: common::Params, + /// One-shot channel where to send back the outcome of that request. + send_back: oneshot::Sender>, + }, - /// Send a subscription request to the server. - Subscribe { - /// Method for the subscription request. - subscribe_method: String, - /// Parameters to send for the subscription. - params: common::Params, - /// Method to use to later unsubscription. Used if the channel unexpectedly closes. - unsubscribe_method: String, - /// When we get a response from the server about that subscription, we send the result on - /// this channel. If the subscription succeeds, we return a `Receiver` that will receive - /// notifications. - send_back: oneshot::Sender, RequestError>>, - }, + /// Send a subscription request to the server. + Subscribe { + /// Method for the subscription request. + subscribe_method: String, + /// Parameters to send for the subscription. + params: common::Params, + /// Method to use to later unsubscription. Used if the channel unexpectedly closes. + unsubscribe_method: String, + /// When we get a response from the server about that subscription, we send the result on + /// this channel. If the subscription succeeds, we return a `Receiver` that will receive + /// notifications. + send_back: oneshot::Sender, RequestError>>, + }, - /// When a request or subscription channel is closed, we send this message to the background - /// task in order for it to garbage collect closed requests and subscriptions. - /// - /// While this means that closing a request or a subscription is a `O(n)` operation, it is - /// expected that the volume of requests and subscriptions is low enough that this isn't - /// a problem in practice. - ChannelClosed, + /// When a request or subscription channel is closed, we send this message to the background + /// task in order for it to garbage collect closed requests and subscriptions. + /// + /// While this means that closing a request or a subscription is a `O(n)` operation, it is + /// expected that the volume of requests and subscriptions is low enough that this isn't + /// a problem in practice. + ChannelClosed, } impl Client { - /// Initializes a new WebSocket client - /// - /// Failes when the URI is invalid i.e, doesn't start with `ws://` or `wss://` - pub async fn new(target: &str) -> Result> { - let client = RawClient::new(WsTransportClient::new(target).await?); - let (to_back, from_front) = mpsc::channel(16); - async_std::task::spawn(async move { - background_task(client, from_front).await; - }); - Ok(Client { to_back }) - } + /// Initializes a new WebSocket client + /// + /// Failes when the URI is invalid i.e, doesn't start with `ws://` or `wss://` + pub async fn new(target: &str) -> Result> { + let client = RawClient::new(WsTransportClient::new(target).await?); + let (to_back, from_front) = mpsc::channel(16); + async_std::task::spawn(async move { + background_task(client, from_front).await; + }); + Ok(Client { to_back }) + } - /// Send a notification to the server. - pub async fn notification( - &self, - method: impl Into, - params: impl Into, - ) { - let method = method.into(); - let params = params.into(); - log::debug!( - "[frontend]: client send notification: method={:?}, params={:?}", - method, - params - ); - let _ = self - .to_back - .clone() - .send(FrontToBack::Notification { method, params }) - .await; - } + /// Send a notification to the server. + pub async fn notification(&self, method: impl Into, params: impl Into) { + let method = method.into(); + let params = params.into(); + log::debug!("[frontend]: client send notification: method={:?}, params={:?}", method, params); + let _ = self.to_back.clone().send(FrontToBack::Notification { method, params }).await; + } - /// Perform a request towards the server. - pub async fn request( - &self, - method: impl Into, - params: impl Into, - ) -> Result - where - Ret: common::DeserializeOwned, - { - let method = method.into(); - let params = params.into(); - log::debug!( - "[frontend]: send request: method={:?}, params={:?}", - method, - params - ); - let (send_back_tx, send_back_rx) = oneshot::channel(); - let _ = self - .to_back - .clone() - .send(FrontToBack::StartRequest { - method, - params, - send_back: send_back_tx, - }) - .await; + /// Perform a request towards the server. + pub async fn request( + &self, + method: impl Into, + params: impl Into, + ) -> Result + where + Ret: common::DeserializeOwned, + { + let method = method.into(); + let params = params.into(); + log::debug!("[frontend]: send request: method={:?}, params={:?}", method, params); + let (send_back_tx, send_back_rx) = oneshot::channel(); + let _ = self.to_back.clone().send(FrontToBack::StartRequest { method, params, send_back: send_back_tx }).await; - // TODO: send a `ChannelClosed` message if we close the channel unexpectedly + // TODO: send a `ChannelClosed` message if we close the channel unexpectedly - let json_value = match send_back_rx.await { - Ok(Ok(v)) => v, - Ok(Err(err)) => return Err(err), - Err(_) => { - let err = io::Error::new(io::ErrorKind::Other, "background task closed"); - return Err(RequestError::TransportError(Box::new(err))); - } - }; + let json_value = match send_back_rx.await { + Ok(Ok(v)) => v, + Ok(Err(err)) => return Err(err), + Err(_) => { + let err = io::Error::new(io::ErrorKind::Other, "background task closed"); + return Err(RequestError::TransportError(Box::new(err))); + } + }; - common::from_value(json_value).map_err(RequestError::ParseError) - } + common::from_value(json_value).map_err(RequestError::ParseError) + } - /// Send a subscription request to the server. - /// - /// The `subscribe_method` and `params` are used to ask for the subscription towards the - /// server. The `unsubscribe_method` is used to close the subscription. - pub async fn subscribe( - &self, - subscribe_method: impl Into, - params: impl Into, - unsubscribe_method: impl Into, - ) -> Result, RequestError> { - let (send_back_tx, send_back_rx) = oneshot::channel(); - let _ = self - .to_back - .clone() - .send(FrontToBack::Subscribe { - subscribe_method: subscribe_method.into(), - unsubscribe_method: unsubscribe_method.into(), - params: params.into(), - send_back: send_back_tx, - }) - .await; + /// Send a subscription request to the server. + /// + /// The `subscribe_method` and `params` are used to ask for the subscription towards the + /// server. The `unsubscribe_method` is used to close the subscription. + pub async fn subscribe( + &self, + subscribe_method: impl Into, + params: impl Into, + unsubscribe_method: impl Into, + ) -> Result, RequestError> { + let (send_back_tx, send_back_rx) = oneshot::channel(); + let _ = self + .to_back + .clone() + .send(FrontToBack::Subscribe { + subscribe_method: subscribe_method.into(), + unsubscribe_method: unsubscribe_method.into(), + params: params.into(), + send_back: send_back_tx, + }) + .await; - let notifs_rx = match send_back_rx.await { - Ok(Ok(v)) => v, - Ok(Err(err)) => return Err(err), - Err(_) => { - let err = io::Error::new(io::ErrorKind::Other, "background task closed"); - return Err(RequestError::TransportError(Box::new(err))); - } - }; + let notifs_rx = match send_back_rx.await { + Ok(Ok(v)) => v, + Ok(Err(err)) => return Err(err), + Err(_) => { + let err = io::Error::new(io::ErrorKind::Other, "background task closed"); + return Err(RequestError::TransportError(Box::new(err))); + } + }; - Ok(Subscription { - to_back: self.to_back.clone(), - notifs_rx, - marker: PhantomData, - }) - } + Ok(Subscription { to_back: self.to_back.clone(), notifs_rx, marker: PhantomData }) + } } impl Subscription where - Notif: common::DeserializeOwned, + Notif: common::DeserializeOwned, { - /// Returns the next notification sent from the server. - /// - /// Ignores any malformed packet. - pub async fn next(&mut self) -> Notif { - loop { - match self.notifs_rx.next().await { - Some(n) => { - if let Ok(parsed) = common::from_value(n) { - return parsed; - } - } - None => futures::pending!(), - } - } - } + /// Returns the next notification sent from the server. + /// + /// Ignores any malformed packet. + pub async fn next(&mut self) -> Notif { + loop { + match self.notifs_rx.next().await { + Some(n) => { + if let Ok(parsed) = common::from_value(n) { + return parsed; + } + } + None => futures::pending!(), + } + } + } } impl Drop for Subscription { - fn drop(&mut self) { - // We can't actually guarantee that this goes through. If the background task is busy, then - // the channel's buffer will be full, and our unsubscription request will never make it. - // However, when a notification arrives, the background task will realize that the channel - // to the `Subscription` has been closed, and will perform the unsubscribe. - let _ = self.to_back.send(FrontToBack::ChannelClosed).now_or_never(); - } + fn drop(&mut self) { + // We can't actually guarantee that this goes through. If the background task is busy, then + // the channel's buffer will be full, and our unsubscription request will never make it. + // However, when a notification arrives, the background task will realize that the channel + // to the `Subscription` has been closed, and will perform the unsubscribe. + let _ = self.to_back.send(FrontToBack::ChannelClosed).now_or_never(); + } } /// Function being run in the background that processes messages from the frontend. async fn background_task(mut client: RawClient, mut from_front: mpsc::Receiver) { - // List of subscription requests that have been sent to the server, with the method name to - // unsubscribe. - let mut pending_subscriptions: HashMap, _)> = - HashMap::new(); - // List of subscription that are active on the server, with the method name to unsubscribe. - let mut active_subscriptions: HashMap< - RawClientRequestId, - (mpsc::Sender, _), - > = HashMap::new(); - // List of requests that the server must answer. - let mut ongoing_requests: HashMap>> = - HashMap::new(); + // List of subscription requests that have been sent to the server, with the method name to + // unsubscribe. + let mut pending_subscriptions: HashMap, _)> = HashMap::new(); + // List of subscription that are active on the server, with the method name to unsubscribe. + let mut active_subscriptions: HashMap, _)> = HashMap::new(); + // List of requests that the server must answer. + let mut ongoing_requests: HashMap>> = HashMap::new(); - loop { - // We need to do a little transformation in order to destroy the borrow to `client` - // and `from_front`. - let outcome = { - let next_message = from_front.next(); - let next_event = client.next_event(); - pin_mut!(next_message); - pin_mut!(next_event); - match future::select(next_message, next_event).await { - Either::Left((v, _)) => Either::Left(v), - Either::Right((v, _)) => Either::Right(v), - } - }; + loop { + // We need to do a little transformation in order to destroy the borrow to `client` + // and `from_front`. + let outcome = { + let next_message = from_front.next(); + let next_event = client.next_event(); + pin_mut!(next_message); + pin_mut!(next_event); + match future::select(next_message, next_event).await { + Either::Left((v, _)) => Either::Left(v), + Either::Right((v, _)) => Either::Right(v), + } + }; - match outcome { - // If the channel is closed, then the `Client` has been destroyed and we - // stop this task. - Either::Left(None) => { - log::trace!("[backend]: client terminated"); - return; - } + match outcome { + // If the channel is closed, then the `Client` has been destroyed and we + // stop this task. + Either::Left(None) => { + log::trace!("[backend]: client terminated"); + return; + } - // User called `notification` on the front-end. - Either::Left(Some(FrontToBack::Notification { method, params })) => { - log::trace!("[backend]: client send notification"); - let _ = client.send_notification(method, params).await; - } + // User called `notification` on the front-end. + Either::Left(Some(FrontToBack::Notification { method, params })) => { + log::trace!("[backend]: client send notification"); + let _ = client.send_notification(method, params).await; + } - // User called `request` on the front-end. - Either::Left(Some(FrontToBack::StartRequest { - method, - params, - send_back, - })) => { - log::trace!("[backend]: client prepare to send request={:?}", method); - match client.start_request(method, params).await { - Ok(id) => { - ongoing_requests.insert(id, send_back); - } - Err(err) => { - log::warn!("[backend]: client send request failed: {:?}", err); - let _ = send_back.send(Err(RequestError::TransportError(Box::new(err)))); - } - } - } - // User called `subscribe` on the front-end. - Either::Left(Some(FrontToBack::Subscribe { - subscribe_method, - unsubscribe_method, - params, - send_back, - })) => { - log::trace!("[backend]: client prepare to start subscription, subscribe_method={:?} unsubscribe_method:{:?}", subscribe_method, unsubscribe_method); - match client.start_subscription(subscribe_method, params).await { - Ok(id) => { - pending_subscriptions.insert(id, (send_back, unsubscribe_method)); - } - Err(err) => { - log::warn!("[backend]: client start subscription failed: {:?}", err); - let _ = send_back.send(Err(RequestError::TransportError(Box::new(err)))); - } - } - } - Either::Left(Some(FrontToBack::ChannelClosed)) => { - // TODO: there's no way to cancel pending subscriptions and requests, otherwise - // we should clean them up as well - while let Some(rq_id) = active_subscriptions - .iter() - .find(|(_, (v, _))| v.is_closed()) - .map(|(k, _)| *k) - { - let (_, unsubscribe) = active_subscriptions.remove(&rq_id).unwrap(); - client - .subscription_by_id(rq_id) - .unwrap() - .into_active() - .unwrap() - .close(unsubscribe) - .await - .unwrap(); - } - } + // User called `request` on the front-end. + Either::Left(Some(FrontToBack::StartRequest { method, params, send_back })) => { + log::trace!("[backend]: client prepare to send request={:?}", method); + match client.start_request(method, params).await { + Ok(id) => { + ongoing_requests.insert(id, send_back); + } + Err(err) => { + log::warn!("[backend]: client send request failed: {:?}", err); + let _ = send_back.send(Err(RequestError::TransportError(Box::new(err)))); + } + } + } + // User called `subscribe` on the front-end. + Either::Left(Some(FrontToBack::Subscribe { subscribe_method, unsubscribe_method, params, send_back })) => { + log::trace!( + "[backend]: client prepare to start subscription, subscribe_method={:?} unsubscribe_method:{:?}", + subscribe_method, + unsubscribe_method + ); + match client.start_subscription(subscribe_method, params).await { + Ok(id) => { + pending_subscriptions.insert(id, (send_back, unsubscribe_method)); + } + Err(err) => { + log::warn!("[backend]: client start subscription failed: {:?}", err); + let _ = send_back.send(Err(RequestError::TransportError(Box::new(err)))); + } + } + } + Either::Left(Some(FrontToBack::ChannelClosed)) => { + // TODO: there's no way to cancel pending subscriptions and requests, otherwise + // we should clean them up as well + while let Some(rq_id) = active_subscriptions.iter().find(|(_, (v, _))| v.is_closed()).map(|(k, _)| *k) { + let (_, unsubscribe) = active_subscriptions.remove(&rq_id).unwrap(); + client.subscription_by_id(rq_id).unwrap().into_active().unwrap().close(unsubscribe).await.unwrap(); + } + } - // Received a response to a request from the server. - Either::Right(Ok(RawClientEvent::Response { request_id, result })) => { - log::trace!( - "[backend] client received response to req={:?}, result={:?}", - request_id, - result - ); - let _ = ongoing_requests - .remove(&request_id) - .unwrap() - .send(result.map_err(RequestError::Request)); - } + // Received a response to a request from the server. + Either::Right(Ok(RawClientEvent::Response { request_id, result })) => { + log::trace!("[backend] client received response to req={:?}, result={:?}", request_id, result); + let _ = ongoing_requests.remove(&request_id).unwrap().send(result.map_err(RequestError::Request)); + } - // Receive a response from the server about a subscription. - Either::Right(Ok(RawClientEvent::SubscriptionResponse { request_id, result })) => { - log::trace!( - "[backend]: client received response to subscription: {:?}", - result - ); - let (send_back, unsubscribe) = pending_subscriptions.remove(&request_id).unwrap(); - if let Err(err) = result { - let _ = send_back.send(Err(RequestError::Request(err))); - } else { - // TODO: what's a good limit here? way more tricky than it looks - let (notifs_tx, notifs_rx) = mpsc::channel(4); - if send_back.send(Ok(notifs_rx)).is_ok() { - active_subscriptions.insert(request_id, (notifs_tx, unsubscribe)); - } else { - client - .subscription_by_id(request_id) - .unwrap() - .into_active() - .unwrap() - .close(unsubscribe) - .await - .unwrap(); - } - } - } + // Receive a response from the server about a subscription. + Either::Right(Ok(RawClientEvent::SubscriptionResponse { request_id, result })) => { + log::trace!("[backend]: client received response to subscription: {:?}", result); + let (send_back, unsubscribe) = pending_subscriptions.remove(&request_id).unwrap(); + if let Err(err) = result { + let _ = send_back.send(Err(RequestError::Request(err))); + } else { + // TODO: what's a good limit here? way more tricky than it looks + let (notifs_tx, notifs_rx) = mpsc::channel(4); + if send_back.send(Ok(notifs_rx)).is_ok() { + active_subscriptions.insert(request_id, (notifs_tx, unsubscribe)); + } else { + client + .subscription_by_id(request_id) + .unwrap() + .into_active() + .unwrap() + .close(unsubscribe) + .await + .unwrap(); + } + } + } - Either::Right(Ok(RawClientEvent::SubscriptionNotif { request_id, result })) => { - // TODO: unsubscribe if channel is closed - let (notifs_tx, _) = active_subscriptions.get_mut(&request_id).unwrap(); - if notifs_tx.send(result).await.is_err() { - let (_, unsubscribe) = active_subscriptions.remove(&request_id).unwrap(); - client - .subscription_by_id(request_id) - .unwrap() - .into_active() - .unwrap() - .close(unsubscribe) - .await - .unwrap(); - } - } + Either::Right(Ok(RawClientEvent::SubscriptionNotif { request_id, result })) => { + // TODO: unsubscribe if channel is closed + let (notifs_tx, _) = active_subscriptions.get_mut(&request_id).unwrap(); + if notifs_tx.send(result).await.is_err() { + let (_, unsubscribe) = active_subscriptions.remove(&request_id).unwrap(); + client + .subscription_by_id(request_id) + .unwrap() + .into_active() + .unwrap() + .close(unsubscribe) + .await + .unwrap(); + } + } - // Request for the server to unsubscribe us has succeeded. - Either::Right(Ok(RawClientEvent::Unsubscribed { request_id: _ })) => {} + // Request for the server to unsubscribe us has succeeded. + Either::Right(Ok(RawClientEvent::Unsubscribed { request_id: _ })) => {} - Either::Right(Err(e)) => { - // TODO: https://github.com/paritytech/jsonrpsee/issues/67 - log::error!("Client Error: {:?}", e); - } - } - } + Either::Right(Err(e)) => { + // TODO: https://github.com/paritytech/jsonrpsee/issues/67 + log::error!("Client Error: {:?}", e); + } + } + } } diff --git a/src/client/ws/raw.rs b/src/client/ws/raw.rs index 1a0c86154a..7d38a84252 100644 --- a/src/client/ws/raw.rs +++ b/src/client/ws/raw.rs @@ -76,56 +76,56 @@ use hashbrown::{hash_map::Entry, HashMap}; /// /// See [the module root documentation](crate::client) for more information. pub struct RawClient { - /// Inner raw client. - inner: WsTransportClient, - - /// Id to assign to the next request. We always assign linearly-increasing numeric keys. - next_request_id: RawClientRequestId, - - /// List of requests and subscription requests that have been sent out and that are waiting - /// for a response. - requests: HashMap, - - /// List of active subscriptions by ID (ID is chosen by the server). Note that this doesn't - /// cover subscription requests that have been sent out but not answered yet, as these are in - /// the [`requests`](RawClient::requests) field. - /// - /// The value of this hash map is only ever used for external API purposes and not for - /// communication with the server. - /// - /// Since the keys are decided by the server, we use a regular HashMap and its - /// hash-collision-resistant algorithm. - subscriptions: HashMap, - - /// Queue of pending events to return from [`RawClient::next_event`]. - // TODO: call shrink_to from time to time; see https://github.com/rust-lang/rust/issues/56431 - events_queue: VecDeque, - - /// Maximum allowed size of [`RawClient::events_queue`]. - /// - /// If this size is reached, elements can still be pushed to the queue if they are critical, - /// but will be discarded if they are not. - // TODO: make this configurable? note: if this is configurable, it should always be >= 1 - events_queue_max_size: usize, + /// Inner raw client. + inner: WsTransportClient, + + /// Id to assign to the next request. We always assign linearly-increasing numeric keys. + next_request_id: RawClientRequestId, + + /// List of requests and subscription requests that have been sent out and that are waiting + /// for a response. + requests: HashMap, + + /// List of active subscriptions by ID (ID is chosen by the server). Note that this doesn't + /// cover subscription requests that have been sent out but not answered yet, as these are in + /// the [`requests`](RawClient::requests) field. + /// + /// The value of this hash map is only ever used for external API purposes and not for + /// communication with the server. + /// + /// Since the keys are decided by the server, we use a regular HashMap and its + /// hash-collision-resistant algorithm. + subscriptions: HashMap, + + /// Queue of pending events to return from [`RawClient::next_event`]. + // TODO: call shrink_to from time to time; see https://github.com/rust-lang/rust/issues/56431 + events_queue: VecDeque, + + /// Maximum allowed size of [`RawClient::events_queue`]. + /// + /// If this size is reached, elements can still be pushed to the queue if they are critical, + /// but will be discarded if they are not. + // TODO: make this configurable? note: if this is configurable, it should always be >= 1 + events_queue_max_size: usize, } /// Type of request that has been sent out and that is waiting for a response. #[derive(Debug, PartialEq, Eq)] enum Request { - /// A single request expecting a response. - Request, - /// A potential subscription. As a response, we expect a single subscription id. - PendingSubscription, - /// The request is stale and was originally used to open a subscription. The subscription ID - /// decided by the server is contained as parameter. - ActiveSubscription { - sub_id: String, - /// We sent a subscription closing message to the server. - closing: bool, - }, - /// Unsubscribing from an active subscription. The request corresponding to the active - /// subscription to unsubscribe from is contained as parameter. - Unsubscribe(RawClientRequestId), + /// A single request expecting a response. + Request, + /// A potential subscription. As a response, we expect a single subscription id. + PendingSubscription, + /// The request is stale and was originally used to open a subscription. The subscription ID + /// decided by the server is contained as parameter. + ActiveSubscription { + sub_id: String, + /// We sent a subscription closing message to the server. + closing: bool, + }, + /// Unsubscribing from an active subscription. The request corresponding to the active + /// subscription to unsubscribe from is contained as parameter. + Unsubscribe(RawClientRequestId), } /// Unique identifier of a request within a [`RawClient`]. @@ -135,635 +135,577 @@ pub struct RawClientRequestId(u64); /// Event returned by [`RawClient::next_event`]. #[derive(Debug)] pub enum RawClientEvent { - /// A request has received a response. - Response { - /// Identifier of the request. Can be matched with the value that [`RawClient::start_request`] - /// has returned. - request_id: RawClientRequestId, - /// The response itself. - result: Result, - }, - - /// A subscription request has received a response. - SubscriptionResponse { - /// Identifier of the request. Can be matched with the value that - /// [`RawClient::start_subscription`] has returned. - request_id: RawClientRequestId, - /// On success, we are now actively subscribed. - /// [`SubscriptionNotif`](RawClientEvent::SubscriptionNotif) events will now be generated. - result: Result<(), common::Error>, - }, - - /// Notification about something we are subscribed to. - SubscriptionNotif { - /// Identifier of the request. Can be matched with the value that - /// [`RawClient::start_subscription`] has returned. - request_id: RawClientRequestId, - /// Opaque data that the server wants to communicate to us. - result: common::JsonValue, - }, - - /// Finished closing a subscription. - Unsubscribed { - /// Subscription that has been closed. - request_id: RawClientRequestId, - }, + /// A request has received a response. + Response { + /// Identifier of the request. Can be matched with the value that [`RawClient::start_request`] + /// has returned. + request_id: RawClientRequestId, + /// The response itself. + result: Result, + }, + + /// A subscription request has received a response. + SubscriptionResponse { + /// Identifier of the request. Can be matched with the value that + /// [`RawClient::start_subscription`] has returned. + request_id: RawClientRequestId, + /// On success, we are now actively subscribed. + /// [`SubscriptionNotif`](RawClientEvent::SubscriptionNotif) events will now be generated. + result: Result<(), common::Error>, + }, + + /// Notification about something we are subscribed to. + SubscriptionNotif { + /// Identifier of the request. Can be matched with the value that + /// [`RawClient::start_subscription`] has returned. + request_id: RawClientRequestId, + /// Opaque data that the server wants to communicate to us. + result: common::JsonValue, + }, + + /// Finished closing a subscription. + Unsubscribed { + /// Subscription that has been closed. + request_id: RawClientRequestId, + }, } /// Access to a subscription within a [`RawClient`]. #[derive(Debug)] pub enum RawClientSubscription<'a> { - /// The server hasn't accepted our subscription request yet. - Pending(RawClientPendingSubscription<'a>), - /// The server has accepted our subscription request. We might receive notifications for it. - Active(RawClientActiveSubscription<'a>), + /// The server hasn't accepted our subscription request yet. + Pending(RawClientPendingSubscription<'a>), + /// The server has accepted our subscription request. We might receive notifications for it. + Active(RawClientActiveSubscription<'a>), } /// Access to a subscription within a [`RawClient`]. #[derive(Debug)] pub struct RawClientPendingSubscription<'a> { - /// Reference to the [`RawClient`]. - client: &'a mut RawClient, - /// Identifier of the subscription within the [`RawClient`]. - id: RawClientRequestId, + /// Reference to the [`RawClient`]. + client: &'a mut RawClient, + /// Identifier of the subscription within the [`RawClient`]. + id: RawClientRequestId, } /// Access to a subscription within a [`RawClient`]. #[derive(Debug)] pub struct RawClientActiveSubscription<'a> { - /// Reference to the [`RawClient`]. - client: &'a mut RawClient, - /// Identifier of the subscription within the [`RawClient`]. - id: RawClientRequestId, + /// Reference to the [`RawClient`]. + client: &'a mut RawClient, + /// Identifier of the subscription within the [`RawClient`]. + id: RawClientRequestId, } /// Error that can happen during a request. #[derive(Debug)] pub enum RawClientError { - /// Error in the raw client. - Inner(WsConnectError), - /// RawServer returned an error for our request. - RequestError(common::Error), - /// RawServer has sent back a subscription ID that has already been used by an earlier - /// subscription. - DuplicateSubscriptionId, - /// Failed to parse subscription ID send by server. - /// - /// On a successful subscription, the server is expected to send back a single number or - /// string representing the ID of the subscription. This error happens if the server returns - /// something else than a number or string. - SubscriptionIdParseError, - /// RawServer has sent back a response containing an unknown request ID. - UnknownRequestId, - /// RawServer has sent back a response containing a null request ID. - NullRequestId, - /// RawServer has sent back a notification using an unknown subscription ID. - UnknownSubscriptionId, + /// Error in the raw client. + Inner(WsConnectError), + /// RawServer returned an error for our request. + RequestError(common::Error), + /// RawServer has sent back a subscription ID that has already been used by an earlier + /// subscription. + DuplicateSubscriptionId, + /// Failed to parse subscription ID send by server. + /// + /// On a successful subscription, the server is expected to send back a single number or + /// string representing the ID of the subscription. This error happens if the server returns + /// something else than a number or string. + SubscriptionIdParseError, + /// RawServer has sent back a response containing an unknown request ID. + UnknownRequestId, + /// RawServer has sent back a response containing a null request ID. + NullRequestId, + /// RawServer has sent back a notification using an unknown subscription ID. + UnknownSubscriptionId, } /// Error that can happen when attempting to close a subscription. #[derive(Debug)] pub enum CloseError { - /// Error in the raw client. - TransportClient(WsConnectError), + /// Error in the raw client. + TransportClient(WsConnectError), - /// We are already trying to close this subscription. - AlreadyClosing, + /// We are already trying to close this subscription. + AlreadyClosing, } impl RawClient { - /// Initializes a new `RawClient` using the given raw client as backend. - pub fn new(inner: WsTransportClient) -> Self { - RawClient { - inner, - next_request_id: RawClientRequestId(0), - requests: HashMap::default(), - subscriptions: HashMap::default(), - events_queue: VecDeque::with_capacity(16), - events_queue_max_size: 64, - } - } + /// Initializes a new `RawClient` using the given raw client as backend. + pub fn new(inner: WsTransportClient) -> Self { + RawClient { + inner, + next_request_id: RawClientRequestId(0), + requests: HashMap::default(), + subscriptions: HashMap::default(), + events_queue: VecDeque::with_capacity(16), + events_queue_max_size: 64, + } + } } impl RawClient { - /// Sends a notification to the server. The notification doesn't need any response. - /// - /// This asynchronous function finishes when the notification has finished being sent. - pub async fn send_notification( - &mut self, - method: impl Into, - params: impl Into, - ) -> Result<(), WsConnectError> { - let request = common::Request::Single(common::Call::Notification(common::Notification { - jsonrpc: common::Version::V2, - method: method.into(), - params: params.into(), - })); - - self.inner.send_request(request).await?; - Ok(()) - } - - /// Starts a request. - /// - /// This asynchronous function finishes when the request has been sent to the server. The - /// request is added to the [`RawClient`]. You must then call [`next_event`](RawClient::next_event) - /// until you get a response. - pub async fn start_request( - &mut self, - method: impl Into, - params: impl Into, - ) -> Result { - self.start_impl(method, params, Request::Request).await - } - - /// Starts a request. - /// - /// This asynchronous function finishes when the request has been sent to the server. The - /// request is added to the [`RawClient`]. You must then call [`next_event`](RawClient::next_event) - /// until you get a response. - pub async fn start_subscription( - &mut self, - method: impl Into, - params: impl Into, - ) -> Result { - self.start_impl(method, params, Request::PendingSubscription) - .await - } - - /// Inner implementation for starting either a request or a subscription. - async fn start_impl( - &mut self, - method: impl Into, - params: impl Into, - ty: Request, - ) -> Result { - loop { - let id = self.next_request_id; - self.next_request_id.0 = self.next_request_id.0.wrapping_add(1); - - let entry = match self.requests.entry(id) { - Entry::Occupied(_) => continue, - Entry::Vacant(e) => e, - }; - - let request = common::Request::Single(common::Call::MethodCall(common::MethodCall { - jsonrpc: common::Version::V2, - method: method.into(), - params: params.into(), - id: common::Id::Num(id.0), - })); - - // Note that in case of an error, we "lose" the request id (as in, it will never be - // used). This isn't a problem, however. - self.inner.send_request(request).await?; - - entry.insert(ty); - break Ok(id); - } - } - - /// Waits until the client receives a message from the server. - /// - /// If this function returns an `Err`, it indicates a connectivity issue with the server or a - /// low-level protocol error, and not a request that has failed to be answered. - pub async fn next_event(&mut self) -> Result { - loop { - if let Some(event) = self.events_queue.pop_front() { - return Ok(event); - } - - self.event_step().await?; - } - } - - /// Returns a `Future` that resolves when the server sends back a response for the given - /// request. - /// - /// Returns `None` if the request identifier is invalid, or if the request is a subscription. - /// - /// > **Note**: While this function is waiting, all the other responses and pubsub events - /// > returned by the server will be buffered up to a certain limit. Once this - /// > limit is reached, server notifications will be discarded. If you want to be - /// > sure to catch all notifications, use [`next_event`](RawClient::next_event) - /// > instead. - pub fn request_by_id<'a>( - &'a mut self, - rq_id: RawClientRequestId, - ) -> Option> + 'a> { - // First, let's check whether the request ID is valid. - if let Some(rq) = self.requests.get(&rq_id) { - if *rq != Request::Request { - return None; - } - } else { - return None; - } - - Some(async move { - let mut events_queue_loopkup = 0; - - loop { - while events_queue_loopkup < self.events_queue.len() { - match &self.events_queue[events_queue_loopkup] { - RawClientEvent::Response { request_id, .. } if *request_id == rq_id => { - return match self.events_queue.remove(events_queue_loopkup) { - Some(RawClientEvent::Response { result, .. }) => { - result.map_err(RawClientError::RequestError) - } - _ => unreachable!(), - } - } - _ => {} - } - - events_queue_loopkup += 1; - } - - self.event_step().await?; - } - }) - } - - /// Returns a [`RawClientSubscription`] object representing a certain active or pending - /// subscription. - /// - /// Returns `None` if the identifier is invalid, or if it is not a subscription. - pub fn subscription_by_id( - &mut self, - rq_id: RawClientRequestId, - ) -> Option { - match self.requests.get(&rq_id)? { - Request::PendingSubscription => { - debug_assert!(!self.subscriptions.values().any(|i| *i == rq_id)); - Some(RawClientSubscription::Pending( - RawClientPendingSubscription { - client: self, - id: rq_id, - }, - )) - } - - Request::ActiveSubscription { sub_id, .. } => { - debug_assert_eq!(self.subscriptions.get(sub_id), Some(&rq_id)); - Some(RawClientSubscription::Active(RawClientActiveSubscription { - client: self, - id: rq_id, - })) - } - - _ => None, - } - } - - /// Waits for one server message and processes it by updating the state of `self`. - /// - /// If the events queue is full (see [`RawClient::events_queue_max_size`]), then responses to - /// requests will still be pushed to the queue, but notifications will be discarded. - /// - /// Check the content of [`events_queue`](RawClient::events_queue) afterwards for events to - /// dispatch to the user. - async fn event_step(&mut self) -> Result<(), RawClientError> { - let result = self - .inner - .next_response() - .await - .map_err(RawClientError::Inner)?; - - match result { - common::Response::Single(rp) => self.process_response(rp)?, - common::Response::Batch(rps) => { - for rp in rps { - // TODO: if an error happens, we throw away the entire batch - self.process_response(rp)?; - } - } - common::Response::Notif(notif) => { - let sub_id = notif.params.subscription.into_string(); - if let Some(request_id) = self.subscriptions.get(&sub_id) { - if self.events_queue.len() < self.events_queue_max_size { - self.events_queue - .push_back(RawClientEvent::SubscriptionNotif { - request_id: *request_id, - result: notif.params.result, - }); - } - } else { - log::warn!( - "Server sent subscription notif with an invalid id: {:?}", - sub_id - ); - return Err(RawClientError::UnknownSubscriptionId); - } - } - } - - Ok(()) - } - - /// Processes the response obtained from the server. Updates the internal state of `self` to - /// account for it. - fn process_response(&mut self, response: common::Output) -> Result<(), RawClientError> { - log::debug!(target: "ws-client-raw", "received response: {:?}", response); - let request_id = match response.id() { - common::Id::Num(n) => RawClientRequestId(*n), - common::Id::Str(s) => { - log::warn!("Server responded with an invalid request id: {:?}", s); - return Err(RawClientError::UnknownRequestId); - } - common::Id::Null => { - log::warn!("Server responded with a null request id"); - return Err(RawClientError::NullRequestId); - } - }; - - // Find the request that this answered. - match self.requests.remove(&request_id) { - Some(Request::Request) => { - self.events_queue.push_back(RawClientEvent::Response { - result: response.into(), - request_id, - }); - } - - Some(Request::PendingSubscription) => { - let response = match Result::from(response) { - Ok(r) => r, - Err(err) => { - self.events_queue - .push_back(RawClientEvent::SubscriptionResponse { - result: Err(err), - request_id, - }); - return Ok(()); - } - }; - - let sub_id = match common::from_value::(response) { - Ok(id) => id.into_string(), - Err(err) => { - log::warn!("Failed to parse string subscription id: {:?}", err); - return Err(RawClientError::SubscriptionIdParseError); - } - }; - - match self.subscriptions.entry(sub_id.clone()) { - Entry::Vacant(e) => e.insert(request_id), - Entry::Occupied(e) => { - log::warn!("Duplicate subscription id sent by server: {:?}", e.key()); - return Err(RawClientError::DuplicateSubscriptionId); - } - }; - - self.requests.insert( - request_id, - Request::ActiveSubscription { - sub_id, - closing: false, - }, - ); - self.events_queue - .push_back(RawClientEvent::SubscriptionResponse { - result: Ok(()), - request_id, - }); - } - - Some(Request::Unsubscribe(active_sub_rq_id)) => { - match self.requests.remove(&active_sub_rq_id) { - Some(Request::ActiveSubscription { sub_id, .. }) => { - if self.subscriptions.remove(&sub_id).is_some() { - self.events_queue.push_back(RawClientEvent::Unsubscribed { - request_id: active_sub_rq_id, - }); - } else { - debug_assert!(false); - } - } - _ => debug_assert!(false), - } - } - - Some(v @ Request::ActiveSubscription { .. }) => { - self.requests.insert(request_id, v); - log::warn!( - "Server responsed with an invalid request id: {:?}", - request_id - ); - return Err(RawClientError::UnknownRequestId); - } - - None => { - log::warn!( - "Server responsed with an invalid request id: {:?}", - request_id - ); - return Err(RawClientError::UnknownRequestId); - } - }; - - Ok(()) - } + /// Sends a notification to the server. The notification doesn't need any response. + /// + /// This asynchronous function finishes when the notification has finished being sent. + pub async fn send_notification( + &mut self, + method: impl Into, + params: impl Into, + ) -> Result<(), WsConnectError> { + let request = common::Request::Single(common::Call::Notification(common::Notification { + jsonrpc: common::Version::V2, + method: method.into(), + params: params.into(), + })); + + self.inner.send_request(request).await?; + Ok(()) + } + + /// Starts a request. + /// + /// This asynchronous function finishes when the request has been sent to the server. The + /// request is added to the [`RawClient`]. You must then call [`next_event`](RawClient::next_event) + /// until you get a response. + pub async fn start_request( + &mut self, + method: impl Into, + params: impl Into, + ) -> Result { + self.start_impl(method, params, Request::Request).await + } + + /// Starts a request. + /// + /// This asynchronous function finishes when the request has been sent to the server. The + /// request is added to the [`RawClient`]. You must then call [`next_event`](RawClient::next_event) + /// until you get a response. + pub async fn start_subscription( + &mut self, + method: impl Into, + params: impl Into, + ) -> Result { + self.start_impl(method, params, Request::PendingSubscription).await + } + + /// Inner implementation for starting either a request or a subscription. + async fn start_impl( + &mut self, + method: impl Into, + params: impl Into, + ty: Request, + ) -> Result { + loop { + let id = self.next_request_id; + self.next_request_id.0 = self.next_request_id.0.wrapping_add(1); + + let entry = match self.requests.entry(id) { + Entry::Occupied(_) => continue, + Entry::Vacant(e) => e, + }; + + let request = common::Request::Single(common::Call::MethodCall(common::MethodCall { + jsonrpc: common::Version::V2, + method: method.into(), + params: params.into(), + id: common::Id::Num(id.0), + })); + + // Note that in case of an error, we "lose" the request id (as in, it will never be + // used). This isn't a problem, however. + self.inner.send_request(request).await?; + + entry.insert(ty); + break Ok(id); + } + } + + /// Waits until the client receives a message from the server. + /// + /// If this function returns an `Err`, it indicates a connectivity issue with the server or a + /// low-level protocol error, and not a request that has failed to be answered. + pub async fn next_event(&mut self) -> Result { + loop { + if let Some(event) = self.events_queue.pop_front() { + return Ok(event); + } + + self.event_step().await?; + } + } + + /// Returns a `Future` that resolves when the server sends back a response for the given + /// request. + /// + /// Returns `None` if the request identifier is invalid, or if the request is a subscription. + /// + /// > **Note**: While this function is waiting, all the other responses and pubsub events + /// > returned by the server will be buffered up to a certain limit. Once this + /// > limit is reached, server notifications will be discarded. If you want to be + /// > sure to catch all notifications, use [`next_event`](RawClient::next_event) + /// > instead. + pub fn request_by_id<'a>( + &'a mut self, + rq_id: RawClientRequestId, + ) -> Option> + 'a> { + // First, let's check whether the request ID is valid. + if let Some(rq) = self.requests.get(&rq_id) { + if *rq != Request::Request { + return None; + } + } else { + return None; + } + + Some(async move { + let mut events_queue_loopkup = 0; + + loop { + while events_queue_loopkup < self.events_queue.len() { + match &self.events_queue[events_queue_loopkup] { + RawClientEvent::Response { request_id, .. } if *request_id == rq_id => { + return match self.events_queue.remove(events_queue_loopkup) { + Some(RawClientEvent::Response { result, .. }) => { + result.map_err(RawClientError::RequestError) + } + _ => unreachable!(), + } + } + _ => {} + } + + events_queue_loopkup += 1; + } + + self.event_step().await?; + } + }) + } + + /// Returns a [`RawClientSubscription`] object representing a certain active or pending + /// subscription. + /// + /// Returns `None` if the identifier is invalid, or if it is not a subscription. + pub fn subscription_by_id(&mut self, rq_id: RawClientRequestId) -> Option { + match self.requests.get(&rq_id)? { + Request::PendingSubscription => { + debug_assert!(!self.subscriptions.values().any(|i| *i == rq_id)); + Some(RawClientSubscription::Pending(RawClientPendingSubscription { client: self, id: rq_id })) + } + + Request::ActiveSubscription { sub_id, .. } => { + debug_assert_eq!(self.subscriptions.get(sub_id), Some(&rq_id)); + Some(RawClientSubscription::Active(RawClientActiveSubscription { client: self, id: rq_id })) + } + + _ => None, + } + } + + /// Waits for one server message and processes it by updating the state of `self`. + /// + /// If the events queue is full (see [`RawClient::events_queue_max_size`]), then responses to + /// requests will still be pushed to the queue, but notifications will be discarded. + /// + /// Check the content of [`events_queue`](RawClient::events_queue) afterwards for events to + /// dispatch to the user. + async fn event_step(&mut self) -> Result<(), RawClientError> { + let result = self.inner.next_response().await.map_err(RawClientError::Inner)?; + + match result { + common::Response::Single(rp) => self.process_response(rp)?, + common::Response::Batch(rps) => { + for rp in rps { + // TODO: if an error happens, we throw away the entire batch + self.process_response(rp)?; + } + } + common::Response::Notif(notif) => { + let sub_id = notif.params.subscription.into_string(); + if let Some(request_id) = self.subscriptions.get(&sub_id) { + if self.events_queue.len() < self.events_queue_max_size { + self.events_queue.push_back(RawClientEvent::SubscriptionNotif { + request_id: *request_id, + result: notif.params.result, + }); + } + } else { + log::warn!("Server sent subscription notif with an invalid id: {:?}", sub_id); + return Err(RawClientError::UnknownSubscriptionId); + } + } + } + + Ok(()) + } + + /// Processes the response obtained from the server. Updates the internal state of `self` to + /// account for it. + fn process_response(&mut self, response: common::Output) -> Result<(), RawClientError> { + log::debug!(target: "ws-client-raw", "received response: {:?}", response); + let request_id = match response.id() { + common::Id::Num(n) => RawClientRequestId(*n), + common::Id::Str(s) => { + log::warn!("Server responded with an invalid request id: {:?}", s); + return Err(RawClientError::UnknownRequestId); + } + common::Id::Null => { + log::warn!("Server responded with a null request id"); + return Err(RawClientError::NullRequestId); + } + }; + + // Find the request that this answered. + match self.requests.remove(&request_id) { + Some(Request::Request) => { + self.events_queue.push_back(RawClientEvent::Response { result: response.into(), request_id }); + } + + Some(Request::PendingSubscription) => { + let response = match Result::from(response) { + Ok(r) => r, + Err(err) => { + self.events_queue + .push_back(RawClientEvent::SubscriptionResponse { result: Err(err), request_id }); + return Ok(()); + } + }; + + let sub_id = match common::from_value::(response) { + Ok(id) => id.into_string(), + Err(err) => { + log::warn!("Failed to parse string subscription id: {:?}", err); + return Err(RawClientError::SubscriptionIdParseError); + } + }; + + match self.subscriptions.entry(sub_id.clone()) { + Entry::Vacant(e) => e.insert(request_id), + Entry::Occupied(e) => { + log::warn!("Duplicate subscription id sent by server: {:?}", e.key()); + return Err(RawClientError::DuplicateSubscriptionId); + } + }; + + self.requests.insert(request_id, Request::ActiveSubscription { sub_id, closing: false }); + self.events_queue.push_back(RawClientEvent::SubscriptionResponse { result: Ok(()), request_id }); + } + + Some(Request::Unsubscribe(active_sub_rq_id)) => match self.requests.remove(&active_sub_rq_id) { + Some(Request::ActiveSubscription { sub_id, .. }) => { + if self.subscriptions.remove(&sub_id).is_some() { + self.events_queue.push_back(RawClientEvent::Unsubscribed { request_id: active_sub_rq_id }); + } else { + debug_assert!(false); + } + } + _ => debug_assert!(false), + }, + + Some(v @ Request::ActiveSubscription { .. }) => { + self.requests.insert(request_id, v); + log::warn!("Server responsed with an invalid request id: {:?}", request_id); + return Err(RawClientError::UnknownRequestId); + } + + None => { + log::warn!("Server responsed with an invalid request id: {:?}", request_id); + return Err(RawClientError::UnknownRequestId); + } + }; + + Ok(()) + } } impl fmt::Debug for RawClient { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("RawClient") - .field("inner", &self.inner) - .field("pending_requests", &self.requests.keys()) - .field("active_subscriptions", &self.subscriptions.keys()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("RawClient") + .field("inner", &self.inner) + .field("pending_requests", &self.requests.keys()) + .field("active_subscriptions", &self.subscriptions.keys()) + .finish() + } } impl<'a> RawClientSubscription<'a> { - /// Returns true if the subscription is active. That is, if the server has accepted our - /// subscription request and might generate events. - pub fn is_active(&self) -> bool { - match self { - RawClientSubscription::Pending(_) => false, - RawClientSubscription::Active(_) => true, - } - } - - /// If this subscription is active, returns the [`RawClientActiveSubscription`]. - pub fn into_active(self) -> Option> { - match self { - RawClientSubscription::Pending(_) => None, - RawClientSubscription::Active(s) => Some(s), - } - } + /// Returns true if the subscription is active. That is, if the server has accepted our + /// subscription request and might generate events. + pub fn is_active(&self) -> bool { + match self { + RawClientSubscription::Pending(_) => false, + RawClientSubscription::Active(_) => true, + } + } + + /// If this subscription is active, returns the [`RawClientActiveSubscription`]. + pub fn into_active(self) -> Option> { + match self { + RawClientSubscription::Pending(_) => None, + RawClientSubscription::Active(s) => Some(s), + } + } } impl<'a> RawClientPendingSubscription<'a> { - // TODO: since this is the only method, maybe we could replace `RawClientPendingSubscription` - // with an `impl Future` once the `impl Trait` feature is stabilized - /// Wait until the server sends back an answer to this subscription request. - /// - /// > **Note**: While this function is waiting, all the other responses and pubsub events - /// > returned by the server will be buffered up to a certain limit. Once this - /// > limit is reached, server notifications will be discarded. If you want to be - /// > sure to catch all notifications, use [`next_event`](RawClient::next_event) - /// > instead. - pub async fn wait(self) -> Result, RawClientError> { - let mut events_queue_loopkup = 0; - - loop { - while events_queue_loopkup < self.client.events_queue.len() { - match &self.client.events_queue[events_queue_loopkup] { - RawClientEvent::SubscriptionResponse { request_id, .. } - if *request_id == self.id => - { - return match self.client.events_queue.remove(events_queue_loopkup) { - Some(RawClientEvent::SubscriptionResponse { - result: Ok(()), .. - }) => Ok(RawClientActiveSubscription { - client: self.client, - id: self.id, - }), - Some(RawClientEvent::SubscriptionResponse { - result: Err(err), .. - }) => Err(RawClientError::RequestError(err)), - _ => unreachable!(), - } - } - _ => {} - } - - events_queue_loopkup += 1; - } - - self.client.event_step().await?; - } - } + // TODO: since this is the only method, maybe we could replace `RawClientPendingSubscription` + // with an `impl Future` once the `impl Trait` feature is stabilized + /// Wait until the server sends back an answer to this subscription request. + /// + /// > **Note**: While this function is waiting, all the other responses and pubsub events + /// > returned by the server will be buffered up to a certain limit. Once this + /// > limit is reached, server notifications will be discarded. If you want to be + /// > sure to catch all notifications, use [`next_event`](RawClient::next_event) + /// > instead. + pub async fn wait(self) -> Result, RawClientError> { + let mut events_queue_loopkup = 0; + + loop { + while events_queue_loopkup < self.client.events_queue.len() { + match &self.client.events_queue[events_queue_loopkup] { + RawClientEvent::SubscriptionResponse { request_id, .. } if *request_id == self.id => { + return match self.client.events_queue.remove(events_queue_loopkup) { + Some(RawClientEvent::SubscriptionResponse { result: Ok(()), .. }) => { + Ok(RawClientActiveSubscription { client: self.client, id: self.id }) + } + Some(RawClientEvent::SubscriptionResponse { result: Err(err), .. }) => { + Err(RawClientError::RequestError(err)) + } + _ => unreachable!(), + } + } + _ => {} + } + + events_queue_loopkup += 1; + } + + self.client.event_step().await?; + } + } } impl<'a> RawClientActiveSubscription<'a> { - /// Returns a `Future` that resolves when the server sends back a notification for this - /// subscription. - /// - /// > **Note**: While this function is waiting, all the other responses and pubsub events - /// > returned by the server will be buffered up to a certain limit. Once this - /// > limit is reached, server notifications will be discarded. If you want to be - /// > sure to catch all notifications, use [`next_event`](RawClient::next_event) - /// > instead. - pub async fn next_notification(&mut self) -> Result { - let mut events_queue_loopkup = 0; - - loop { - while events_queue_loopkup < self.client.events_queue.len() { - match &self.client.events_queue[events_queue_loopkup] { - RawClientEvent::SubscriptionNotif { request_id, .. } - if *request_id == self.id => - { - return match self.client.events_queue.remove(events_queue_loopkup) { - Some(RawClientEvent::SubscriptionNotif { result, .. }) => Ok(result), - _ => unreachable!(), - } - } - _ => {} - } - - events_queue_loopkup += 1; - } - - self.client.event_step().await?; - } - } - - /// Returns `true` if we called [`close`](RawClientActiveSubscription::close) earlier on this - /// subscription and we are waiting for the server to respond to our close request. - pub fn is_closing(&self) -> bool { - match self.client.requests.get(&self.id) { - Some(Request::ActiveSubscription { closing, .. }) => *closing, - _ => panic!(), - } - } - - /// Starts closing an open subscription by performing an RPC call with the given method name. - /// - /// Calling this method multiple times with the same subscription will yield an error. - /// - /// Note that, for convenience, we will consider the subscription closed even the server - /// returns an error to the unsubscription request. - pub async fn close(&mut self, method_name: impl Into) -> Result<(), CloseError> { - let sub_id = match self.client.requests.get(&self.id) { - Some(Request::ActiveSubscription { sub_id, closing }) => { - if *closing { - return Err(CloseError::AlreadyClosing); - } - sub_id.clone() - } - _ => panic!(), - }; - - let params = common::Params::Array(vec![sub_id.clone().into()]); - self.client - .start_impl(method_name, params, Request::Unsubscribe(self.id)) - .await - .map_err(CloseError::TransportClient)?; - - match self.client.requests.get_mut(&self.id) { - Some(Request::ActiveSubscription { closing, .. }) => { - debug_assert!(!*closing); - *closing = true; - } - _ => panic!(), - }; - - Ok(()) - } + /// Returns a `Future` that resolves when the server sends back a notification for this + /// subscription. + /// + /// > **Note**: While this function is waiting, all the other responses and pubsub events + /// > returned by the server will be buffered up to a certain limit. Once this + /// > limit is reached, server notifications will be discarded. If you want to be + /// > sure to catch all notifications, use [`next_event`](RawClient::next_event) + /// > instead. + pub async fn next_notification(&mut self) -> Result { + let mut events_queue_loopkup = 0; + + loop { + while events_queue_loopkup < self.client.events_queue.len() { + match &self.client.events_queue[events_queue_loopkup] { + RawClientEvent::SubscriptionNotif { request_id, .. } if *request_id == self.id => { + return match self.client.events_queue.remove(events_queue_loopkup) { + Some(RawClientEvent::SubscriptionNotif { result, .. }) => Ok(result), + _ => unreachable!(), + } + } + _ => {} + } + + events_queue_loopkup += 1; + } + + self.client.event_step().await?; + } + } + + /// Returns `true` if we called [`close`](RawClientActiveSubscription::close) earlier on this + /// subscription and we are waiting for the server to respond to our close request. + pub fn is_closing(&self) -> bool { + match self.client.requests.get(&self.id) { + Some(Request::ActiveSubscription { closing, .. }) => *closing, + _ => panic!(), + } + } + + /// Starts closing an open subscription by performing an RPC call with the given method name. + /// + /// Calling this method multiple times with the same subscription will yield an error. + /// + /// Note that, for convenience, we will consider the subscription closed even the server + /// returns an error to the unsubscription request. + pub async fn close(&mut self, method_name: impl Into) -> Result<(), CloseError> { + let sub_id = match self.client.requests.get(&self.id) { + Some(Request::ActiveSubscription { sub_id, closing }) => { + if *closing { + return Err(CloseError::AlreadyClosing); + } + sub_id.clone() + } + _ => panic!(), + }; + + let params = common::Params::Array(vec![sub_id.clone().into()]); + self.client + .start_impl(method_name, params, Request::Unsubscribe(self.id)) + .await + .map_err(CloseError::TransportClient)?; + + match self.client.requests.get_mut(&self.id) { + Some(Request::ActiveSubscription { closing, .. }) => { + debug_assert!(!*closing); + *closing = true; + } + _ => panic!(), + }; + + Ok(()) + } } impl std::error::Error for RawClientError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - RawClientError::Inner(ref err) => Some(err), - RawClientError::RequestError(ref err) => Some(err), - RawClientError::DuplicateSubscriptionId => None, - RawClientError::SubscriptionIdParseError => None, - RawClientError::UnknownRequestId => None, - RawClientError::NullRequestId => None, - RawClientError::UnknownSubscriptionId => None, - } - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + RawClientError::Inner(ref err) => Some(err), + RawClientError::RequestError(ref err) => Some(err), + RawClientError::DuplicateSubscriptionId => None, + RawClientError::SubscriptionIdParseError => None, + RawClientError::UnknownRequestId => None, + RawClientError::NullRequestId => None, + RawClientError::UnknownSubscriptionId => None, + } + } } impl fmt::Display for RawClientError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - RawClientError::Inner(ref err) => write!(f, "Error in the raw client: {}", err), - RawClientError::RequestError(ref err) => write!(f, "Server returned error: {}", err), - RawClientError::DuplicateSubscriptionId => write!( - f, - "Server has responded with a subscription ID that's already in use" - ), - RawClientError::SubscriptionIdParseError => write!(f, "Subscription ID parse error"), - RawClientError::UnknownRequestId => { - write!(f, "Server responded with an unknown request ID") - } - RawClientError::NullRequestId => write!(f, "Server responded with a null request ID"), - RawClientError::UnknownSubscriptionId => { - write!(f, "Server responded with an unknown subscription ID") - } - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + RawClientError::Inner(ref err) => write!(f, "Error in the raw client: {}", err), + RawClientError::RequestError(ref err) => write!(f, "Server returned error: {}", err), + RawClientError::DuplicateSubscriptionId => { + write!(f, "Server has responded with a subscription ID that's already in use") + } + RawClientError::SubscriptionIdParseError => write!(f, "Subscription ID parse error"), + RawClientError::UnknownRequestId => write!(f, "Server responded with an unknown request ID"), + RawClientError::NullRequestId => write!(f, "Server responded with a null request ID"), + RawClientError::UnknownSubscriptionId => write!(f, "Server responded with an unknown subscription ID"), + } + } } impl std::error::Error for CloseError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - match self { - CloseError::TransportClient(err) => Some(err), - CloseError::AlreadyClosing => None, - } - } + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + CloseError::TransportClient(err) => Some(err), + CloseError::AlreadyClosing => None, + } + } } impl fmt::Display for CloseError { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - CloseError::TransportClient(err) => fmt::Display::fmt(err, f), - CloseError::AlreadyClosing => write!(f, "Subscription already being closed"), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + CloseError::TransportClient(err) => fmt::Display::fmt(err, f), + CloseError::AlreadyClosing => write!(f, "Subscription already being closed"), + } + } } diff --git a/src/client/ws/stream.rs b/src/client/ws/stream.rs index 342937fe9c..b004267b29 100644 --- a/src/client/ws/stream.rs +++ b/src/client/ws/stream.rs @@ -27,8 +27,8 @@ //! Convenience wrapper for a stream (AsyncRead + AsyncWrite) which can either be plain TCP or TLS. use futures::{ - io::{IoSlice, IoSliceMut}, - prelude::*, + io::{IoSlice, IoSliceMut}, + prelude::*, }; use pin_project::pin_project; use std::{io::Error as IoError, pin::Pin, task::Context, task::Poll}; @@ -36,78 +36,66 @@ use std::{io::Error as IoError, pin::Pin, task::Context, task::Poll}; #[pin_project(project = EitherStreamProj)] #[derive(Debug, Copy, Clone)] pub enum EitherStream { - /// Unencrypted socket stream. - Plain(#[pin] S), - /// Encrypted socket stream. - Tls(#[pin] T), + /// Unencrypted socket stream. + Plain(#[pin] S), + /// Encrypted socket stream. + Tls(#[pin] T), } impl AsyncRead for EitherStream where - S: AsyncRead, - T: AsyncRead, + S: AsyncRead, + T: AsyncRead, { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut [u8], - ) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => AsyncRead::poll_read(s, cx, buf), - EitherStreamProj::Tls(t) => AsyncRead::poll_read(t, cx, buf), - } - } + fn poll_read(self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8]) -> Poll> { + match self.project() { + EitherStreamProj::Plain(s) => AsyncRead::poll_read(s, cx, buf), + EitherStreamProj::Tls(t) => AsyncRead::poll_read(t, cx, buf), + } + } - fn poll_read_vectored( - self: Pin<&mut Self>, - cx: &mut Context, - bufs: &mut [IoSliceMut], - ) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => AsyncRead::poll_read_vectored(s, cx, bufs), - EitherStreamProj::Tls(t) => AsyncRead::poll_read_vectored(t, cx, bufs), - } - } + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context, + bufs: &mut [IoSliceMut], + ) -> Poll> { + match self.project() { + EitherStreamProj::Plain(s) => AsyncRead::poll_read_vectored(s, cx, bufs), + EitherStreamProj::Tls(t) => AsyncRead::poll_read_vectored(t, cx, bufs), + } + } } impl AsyncWrite for EitherStream where - S: AsyncWrite, - T: AsyncWrite, + S: AsyncWrite, + T: AsyncWrite, { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => AsyncWrite::poll_write(s, cx, buf), - EitherStreamProj::Tls(t) => AsyncWrite::poll_write(t, cx, buf), - } - } + fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll> { + match self.project() { + EitherStreamProj::Plain(s) => AsyncWrite::poll_write(s, cx, buf), + EitherStreamProj::Tls(t) => AsyncWrite::poll_write(t, cx, buf), + } + } - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context, - bufs: &[IoSlice], - ) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => AsyncWrite::poll_write_vectored(s, cx, bufs), - EitherStreamProj::Tls(t) => AsyncWrite::poll_write_vectored(t, cx, bufs), - } - } + fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice]) -> Poll> { + match self.project() { + EitherStreamProj::Plain(s) => AsyncWrite::poll_write_vectored(s, cx, bufs), + EitherStreamProj::Tls(t) => AsyncWrite::poll_write_vectored(t, cx, bufs), + } + } - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => AsyncWrite::poll_flush(s, cx), - EitherStreamProj::Tls(t) => AsyncWrite::poll_flush(t, cx), - } - } + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.project() { + EitherStreamProj::Plain(s) => AsyncWrite::poll_flush(s, cx), + EitherStreamProj::Tls(t) => AsyncWrite::poll_flush(t, cx), + } + } - fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - match self.project() { - EitherStreamProj::Plain(s) => AsyncWrite::poll_close(s, cx), - EitherStreamProj::Tls(t) => AsyncWrite::poll_close(t, cx), - } - } + fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.project() { + EitherStreamProj::Plain(s) => AsyncWrite::poll_close(s, cx), + EitherStreamProj::Tls(t) => AsyncWrite::poll_close(t, cx), + } + } } diff --git a/src/client/ws/transport.rs b/src/client/ws/transport.rs index 32df33c485..918f4df5d4 100644 --- a/src/client/ws/transport.rs +++ b/src/client/ws/transport.rs @@ -38,300 +38,288 @@ type TlsOrPlain = crate::client::ws::stream::EitherStream, - /// Receiving half of a TCP/IP connection wrapped around a WebSocket decoder. - receiver: connection::Receiver, + /// Sending half of a TCP/IP connection wrapped around a WebSocket encoder. + sender: connection::Sender, + /// Receiving half of a TCP/IP connection wrapped around a WebSocket decoder. + receiver: connection::Receiver, } /// Builder for a [`WsTransportClient`]. pub struct WsTransportClientBuilder<'a> { - /// IP address to try to connect to. - target: SocketAddr, - /// Host to send during the WS handshake. - host: Cow<'a, str>, - /// DNS host name. - dns_name: Cow<'a, str>, - /// Stream mode, either plain TCP or TLS. - mode: Mode, - /// Url to send during the HTTP handshake. - url: Cow<'a, str>, - /// Timeout for the connection. - timeout: Duration, - /// `Origin` header to pass during the HTTP handshake. If `None`, no - /// `Origin` header is passed. - origin: Option>, + /// IP address to try to connect to. + target: SocketAddr, + /// Host to send during the WS handshake. + host: Cow<'a, str>, + /// DNS host name. + dns_name: Cow<'a, str>, + /// Stream mode, either plain TCP or TLS. + mode: Mode, + /// Url to send during the HTTP handshake. + url: Cow<'a, str>, + /// Timeout for the connection. + timeout: Duration, + /// `Origin` header to pass during the HTTP handshake. If `None`, no + /// `Origin` header is passed. + origin: Option>, } /// Stream mode, either plain TCP or TLS. #[derive(Clone, Copy, Debug)] pub enum Mode { - /// Plain mode (`ws://` URL). - Plain, - /// TLS mode (`wss://` URL). - Tls, + /// Plain mode (`ws://` URL). + Plain, + /// TLS mode (`wss://` URL). + Tls, } /// Error that can happen during the initial handshake. #[derive(Debug, Error)] pub enum WsNewError { - /// Error when opening the TCP socket. - #[error("Error when opening the TCP socket: {}", 0)] - Io(io::Error), - - /// Error in the WebSocket handshake. - #[error("Error in the WebSocket handshake: {}", 0)] - Handshake(#[source] soketto::handshake::Error), - - /// Invalid DNS name error for TLS - #[error("Invalid DNS name: {}", 0)] - InvalidDNSName(#[source] webpki::InvalidDNSNameError), - - /// RawServer rejected our handshake. - #[error("Server returned an error status code: {}", status_code)] - Rejected { - /// HTTP status code that the server returned. - status_code: u16, - }, - - /// Timeout while trying to connect. - #[error("Timeout when trying to connect")] - Timeout, + /// Error when opening the TCP socket. + #[error("Error when opening the TCP socket: {}", 0)] + Io(io::Error), + + /// Error in the WebSocket handshake. + #[error("Error in the WebSocket handshake: {}", 0)] + Handshake(#[source] soketto::handshake::Error), + + /// Invalid DNS name error for TLS + #[error("Invalid DNS name: {}", 0)] + InvalidDNSName(#[source] webpki::InvalidDNSNameError), + + /// RawServer rejected our handshake. + #[error("Server returned an error status code: {}", status_code)] + Rejected { + /// HTTP status code that the server returned. + status_code: u16, + }, + + /// Timeout while trying to connect. + #[error("Timeout when trying to connect")] + Timeout, } /// Error that can happen during the initial handshake. #[derive(Debug, Error)] pub enum WsNewDnsError { - /// Invalid URL. - #[error("Invalid url: {}", 0)] - Url(Cow<'static, str>), - - /// Error when trying to connect. - /// - /// If multiple IP addresses are attempted, only the last error is returned, similar to how - /// [`std::net::TcpStream::connect`] behaves. - #[error("Error when trying to connect: {}", 0)] - Connect(WsNewError), - - /// Failed to resolve IP addresses for this hostname. - #[error("Failed to resolve IP addresses for this hostname: {}", 0)] - ResolutionFailed(io::Error), - - /// Couldn't find any IP address for this hostname. - #[error("Couldn't find any IP address for this hostname")] - NoAddressFound, + /// Invalid URL. + #[error("Invalid url: {}", 0)] + Url(Cow<'static, str>), + + /// Error when trying to connect. + /// + /// If multiple IP addresses are attempted, only the last error is returned, similar to how + /// [`std::net::TcpStream::connect`] behaves. + #[error("Error when trying to connect: {}", 0)] + Connect(WsNewError), + + /// Failed to resolve IP addresses for this hostname. + #[error("Failed to resolve IP addresses for this hostname: {}", 0)] + ResolutionFailed(io::Error), + + /// Couldn't find any IP address for this hostname. + #[error("Couldn't find any IP address for this hostname")] + NoAddressFound, } /// Error that can happen during a request. #[derive(Debug, Error)] pub enum WsConnectError { - /// Error while serializing the request. - // TODO: can that happen? - #[error("error while serializing the request")] - Serialization(#[source] serde_json::error::Error), - - /// Error in the WebSocket connection. - #[error("error in the WebSocket connection")] - Ws(#[source] soketto::connection::Error), - - /// Failed to parse the JSON returned by the server into a JSON-RPC response. - #[error("error while parsing the response body")] - ParseError(#[source] serde_json::error::Error), + /// Error while serializing the request. + // TODO: can that happen? + #[error("error while serializing the request")] + Serialization(#[source] serde_json::error::Error), + + /// Error in the WebSocket connection. + #[error("error in the WebSocket connection")] + Ws(#[source] soketto::connection::Error), + + /// Failed to parse the JSON returned by the server into a JSON-RPC response. + #[error("error while parsing the response body")] + ParseError(#[source] serde_json::error::Error), } impl WsTransportClient { - /// Creates a new [`WsTransportClientBuilder`] containing the given address and hostname. - pub fn builder<'a>( - target: SocketAddr, - host: impl Into>, - dns_name: impl Into>, - mode: Mode, - ) -> WsTransportClientBuilder<'a> { - WsTransportClientBuilder { - target, - host: host.into(), - dns_name: dns_name.into(), - mode, - url: From::from("/"), - timeout: Duration::from_secs(10), - origin: None, - } - } - - /// Initializes a new WS client from a URL. - pub async fn new(target: &str) -> Result { - let url = url::Url::parse(target) - .map_err(|e| WsNewDnsError::Url(format!("Invalid URL: {}", e).into()))?; - let mode = match url.scheme() { - "ws" => Mode::Plain, - "wss" => Mode::Tls, - _ => { - return Err(WsNewDnsError::Url( - "URL scheme not supported, expects 'ws' or 'wss'".into(), - )) - } - }; - let host = url - .host_str() - .ok_or(WsNewDnsError::Url("No host in URL".into()))?; - let target = match url.port_or_known_default() { - Some(port) => format!("{}:{}", host, port), - None => host.to_string(), - }; - - let mut error = None; - - for url in target - .to_socket_addrs() - .await - .map_err(WsNewDnsError::ResolutionFailed)? - { - match Self::builder(url, &target, host, mode).build().await { - Ok(ws_raw_client) => return Ok(ws_raw_client), - Err(err) => error = Some(err), - } - } - - if let Some(error) = error { - Err(WsNewDnsError::Connect(error)) - } else { - Err(WsNewDnsError::NoAddressFound) - } - } + /// Creates a new [`WsTransportClientBuilder`] containing the given address and hostname. + pub fn builder<'a>( + target: SocketAddr, + host: impl Into>, + dns_name: impl Into>, + mode: Mode, + ) -> WsTransportClientBuilder<'a> { + WsTransportClientBuilder { + target, + host: host.into(), + dns_name: dns_name.into(), + mode, + url: From::from("/"), + timeout: Duration::from_secs(10), + origin: None, + } + } + + /// Initializes a new WS client from a URL. + pub async fn new(target: &str) -> Result { + let url = url::Url::parse(target).map_err(|e| WsNewDnsError::Url(format!("Invalid URL: {}", e).into()))?; + let mode = match url.scheme() { + "ws" => Mode::Plain, + "wss" => Mode::Tls, + _ => return Err(WsNewDnsError::Url("URL scheme not supported, expects 'ws' or 'wss'".into())), + }; + let host = url.host_str().ok_or(WsNewDnsError::Url("No host in URL".into()))?; + let target = match url.port_or_known_default() { + Some(port) => format!("{}:{}", host, port), + None => host.to_string(), + }; + + let mut error = None; + + for url in target.to_socket_addrs().await.map_err(WsNewDnsError::ResolutionFailed)? { + match Self::builder(url, &target, host, mode).build().await { + Ok(ws_raw_client) => return Ok(ws_raw_client), + Err(err) => error = Some(err), + } + } + + if let Some(error) = error { + Err(WsNewDnsError::Connect(error)) + } else { + Err(WsNewDnsError::NoAddressFound) + } + } } // former transport client impl, impl WsTransportClient { - /// Sends out out a request. Returns a `Future` that finishes when the request has been - /// successfully sent. - pub fn send_request<'a>( - &'a mut self, - request: common::Request, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - log::debug!("send request: {:?}", request); - let request = common::to_string(&request).map_err(WsConnectError::Serialization)?; - self.sender.send_text(request).await?; - self.sender.flush().await?; - Ok(()) - }) - } - - /// Returns a `Future` resolving when the server sent us something back. - pub fn next_response<'a>( - &'a mut self, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - let mut message = Vec::new(); - self.receiver.receive_data(&mut message).await?; - let response = common::from_slice(&message).map_err(WsConnectError::ParseError)?; - log::debug!("received response: {:?}", response); - Ok(response) - }) - } + /// Sends out out a request. Returns a `Future` that finishes when the request has been + /// successfully sent. + pub fn send_request<'a>( + &'a mut self, + request: common::Request, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + log::debug!("send request: {:?}", request); + let request = common::to_string(&request).map_err(WsConnectError::Serialization)?; + self.sender.send_text(request).await?; + self.sender.flush().await?; + Ok(()) + }) + } + + /// Returns a `Future` resolving when the server sent us something back. + pub fn next_response<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let mut message = Vec::new(); + self.receiver.receive_data(&mut message).await?; + let response = common::from_slice(&message).map_err(WsConnectError::ParseError)?; + log::debug!("received response: {:?}", response); + Ok(response) + }) + } } impl fmt::Debug for WsTransportClient { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_tuple("WsTransportClient").finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("WsTransportClient").finish() + } } impl<'a> WsTransportClientBuilder<'a> { - /// Sets the URL to pass during the HTTP handshake. - /// - /// The default URL is `/`. - pub fn with_url(mut self, url: impl Into>) -> Self { - self.url = url.into(); - self - } - - /// Sets the `Origin` header to pass during the HTTP handshake. - /// - /// By default, no `Origin` header is sent. - pub fn with_origin_header(mut self, origin: impl Into>) -> Self { - self.origin = Some(origin.into()); - self - } - - /// Sets the timeout to use when establishing the TCP connection. - /// - /// The default timeout is 10 seconds. - // TODO: design decision: should the timeout not be handled by the user? - pub fn with_timeout(mut self, timeout: Duration) -> Self { - self.timeout = timeout; - self - } - - /// Try establish the connection. - pub async fn build(self) -> Result { - // Try establish the TCP connection. - let tcp_stream = { - let socket = TcpStream::connect(self.target); - let timeout = async_std::task::sleep(self.timeout); - futures::pin_mut!(socket, timeout); - match future::select(socket, timeout).await { - future::Either::Left((socket, _)) => match self.mode { - Mode::Plain => TlsOrPlain::Plain(socket?), - Mode::Tls => { - let connector = async_tls::TlsConnector::default(); - let dns_name = webpki::DNSNameRef::try_from_ascii_str(&self.dns_name)?; - let tls_stream = connector.connect(&dns_name.to_owned(), socket?).await?; - TlsOrPlain::Tls(tls_stream) - } - }, - future::Either::Right((_, _)) => return Err(WsNewError::Timeout), - } - }; - - // Configure a WebSockets client on top. - let mut client = WsRawClient::new(tcp_stream, &self.host, &self.url); - if let Some(origin) = self.origin.as_ref() { - client.set_origin(origin); - } - - // Perform the initial handshake. - match client.handshake().await? { - ServerResponse::Accepted { .. } => {} - ServerResponse::Rejected { status_code } - | ServerResponse::Redirect { status_code, .. } => { - // TODO: HTTP redirects also lead here - return Err(WsNewError::Rejected { status_code }); - } - } - - // If the handshake succeeded, return. - let (sender, receiver) = client.into_builder().finish(); - Ok(WsTransportClient { sender, receiver }) - } + /// Sets the URL to pass during the HTTP handshake. + /// + /// The default URL is `/`. + pub fn with_url(mut self, url: impl Into>) -> Self { + self.url = url.into(); + self + } + + /// Sets the `Origin` header to pass during the HTTP handshake. + /// + /// By default, no `Origin` header is sent. + pub fn with_origin_header(mut self, origin: impl Into>) -> Self { + self.origin = Some(origin.into()); + self + } + + /// Sets the timeout to use when establishing the TCP connection. + /// + /// The default timeout is 10 seconds. + // TODO: design decision: should the timeout not be handled by the user? + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + /// Try establish the connection. + pub async fn build(self) -> Result { + // Try establish the TCP connection. + let tcp_stream = { + let socket = TcpStream::connect(self.target); + let timeout = async_std::task::sleep(self.timeout); + futures::pin_mut!(socket, timeout); + match future::select(socket, timeout).await { + future::Either::Left((socket, _)) => match self.mode { + Mode::Plain => TlsOrPlain::Plain(socket?), + Mode::Tls => { + let connector = async_tls::TlsConnector::default(); + let dns_name = webpki::DNSNameRef::try_from_ascii_str(&self.dns_name)?; + let tls_stream = connector.connect(&dns_name.to_owned(), socket?).await?; + TlsOrPlain::Tls(tls_stream) + } + }, + future::Either::Right((_, _)) => return Err(WsNewError::Timeout), + } + }; + + // Configure a WebSockets client on top. + let mut client = WsRawClient::new(tcp_stream, &self.host, &self.url); + if let Some(origin) = self.origin.as_ref() { + client.set_origin(origin); + } + + // Perform the initial handshake. + match client.handshake().await? { + ServerResponse::Accepted { .. } => {} + ServerResponse::Rejected { status_code } | ServerResponse::Redirect { status_code, .. } => { + // TODO: HTTP redirects also lead here + return Err(WsNewError::Rejected { status_code }); + } + } + + // If the handshake succeeded, return. + let (sender, receiver) = client.into_builder().finish(); + Ok(WsTransportClient { sender, receiver }) + } } impl From for WsNewError { - fn from(err: io::Error) -> WsNewError { - WsNewError::Io(err) - } + fn from(err: io::Error) -> WsNewError { + WsNewError::Io(err) + } } impl From for WsNewError { - fn from(err: webpki::InvalidDNSNameError) -> WsNewError { - WsNewError::InvalidDNSName(err) - } + fn from(err: webpki::InvalidDNSNameError) -> WsNewError { + WsNewError::InvalidDNSName(err) + } } impl From for WsNewError { - fn from(err: soketto::handshake::Error) -> WsNewError { - WsNewError::Handshake(err) - } + fn from(err: soketto::handshake::Error) -> WsNewError { + WsNewError::Handshake(err) + } } impl From for WsNewDnsError { - fn from(err: WsNewError) -> WsNewDnsError { - WsNewDnsError::Connect(err) - } + fn from(err: WsNewError) -> WsNewDnsError { + WsNewDnsError::Connect(err) + } } impl From for WsConnectError { - fn from(err: soketto::connection::Error) -> Self { - WsConnectError::Ws(err) - } + fn from(err: soketto::connection::Error) -> Self { + WsConnectError::Ws(err) + } } diff --git a/src/common/error.rs b/src/common/error.rs index 07ea72ff7a..184435a325 100644 --- a/src/common/error.rs +++ b/src/common/error.rs @@ -27,9 +27,9 @@ use super::JsonValue; use alloc::{ - borrow::ToOwned as _, - format, - string::{String, ToString as _}, + borrow::ToOwned as _, + format, + string::{String, ToString as _}, }; use core::fmt; use serde::de::Deserializer; @@ -39,173 +39,165 @@ use serde::{Deserialize, Serialize}; /// JSONRPC error code #[derive(Debug, PartialEq, Clone)] pub enum ErrorCode { - /// Invalid JSON was received by the server. - /// An error occurred on the server while parsing the JSON text. - ParseError, - /// The JSON sent is not a valid Request object. - InvalidRequest, - /// The method does not exist / is not available. - MethodNotFound, - /// Invalid method parameter(s). - InvalidParams, - /// Internal JSON-RPC error. - InternalError, - /// Reserved for implementation-defined server-errors. - ServerError(i64), - /// Error returned by called method - MethodError(i64), + /// Invalid JSON was received by the server. + /// An error occurred on the server while parsing the JSON text. + ParseError, + /// The JSON sent is not a valid Request object. + InvalidRequest, + /// The method does not exist / is not available. + MethodNotFound, + /// Invalid method parameter(s). + InvalidParams, + /// Internal JSON-RPC error. + InternalError, + /// Reserved for implementation-defined server-errors. + ServerError(i64), + /// Error returned by called method + MethodError(i64), } impl ErrorCode { - /// Returns integer code value - pub fn code(&self) -> i64 { - match *self { - ErrorCode::ParseError => -32700, - ErrorCode::InvalidRequest => -32600, - ErrorCode::MethodNotFound => -32601, - ErrorCode::InvalidParams => -32602, - ErrorCode::InternalError => -32603, - ErrorCode::ServerError(code) => code, - ErrorCode::MethodError(code) => code, - } - } - - /// Returns human-readable description - pub fn description(&self) -> String { - let desc = match *self { - ErrorCode::ParseError => "Parse error", - ErrorCode::InvalidRequest => "Invalid request", - ErrorCode::MethodNotFound => "Method not found", - ErrorCode::InvalidParams => "Invalid params", - ErrorCode::InternalError => "Internal error", - ErrorCode::ServerError(_) => "Server error", - ErrorCode::MethodError(_) => "Method error", - }; - desc.to_string() - } + /// Returns integer code value + pub fn code(&self) -> i64 { + match *self { + ErrorCode::ParseError => -32700, + ErrorCode::InvalidRequest => -32600, + ErrorCode::MethodNotFound => -32601, + ErrorCode::InvalidParams => -32602, + ErrorCode::InternalError => -32603, + ErrorCode::ServerError(code) => code, + ErrorCode::MethodError(code) => code, + } + } + + /// Returns human-readable description + pub fn description(&self) -> String { + let desc = match *self { + ErrorCode::ParseError => "Parse error", + ErrorCode::InvalidRequest => "Invalid request", + ErrorCode::MethodNotFound => "Method not found", + ErrorCode::InvalidParams => "Invalid params", + ErrorCode::InternalError => "Internal error", + ErrorCode::ServerError(_) => "Server error", + ErrorCode::MethodError(_) => "Method error", + }; + desc.to_string() + } } impl From for ErrorCode { - fn from(code: i64) -> Self { - match code { - -32700 => ErrorCode::ParseError, - -32600 => ErrorCode::InvalidRequest, - -32601 => ErrorCode::MethodNotFound, - -32602 => ErrorCode::InvalidParams, - -32603 => ErrorCode::InternalError, - -32099..=-32000 => ErrorCode::ServerError(code), - code => ErrorCode::MethodError(code), - } - } + fn from(code: i64) -> Self { + match code { + -32700 => ErrorCode::ParseError, + -32600 => ErrorCode::InvalidRequest, + -32601 => ErrorCode::MethodNotFound, + -32602 => ErrorCode::InvalidParams, + -32603 => ErrorCode::InternalError, + -32099..=-32000 => ErrorCode::ServerError(code), + code => ErrorCode::MethodError(code), + } + } } impl<'a> serde::Deserialize<'a> for ErrorCode { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'a>, - { - let code: i64 = serde::Deserialize::deserialize(deserializer)?; - Ok(ErrorCode::from(code)) - } + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'a>, + { + let code: i64 = serde::Deserialize::deserialize(deserializer)?; + Ok(ErrorCode::from(code)) + } } impl serde::Serialize for ErrorCode { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_i64(self.code()) - } + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_i64(self.code()) + } } /// Error object as defined in Spec #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct Error { - /// Code - pub code: ErrorCode, - /// Message - pub message: String, - /// Optional data - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, + /// Code + pub code: ErrorCode, + /// Message + pub message: String, + /// Optional data + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, } impl Error { - /// Wraps given `ErrorCode` - pub fn new(code: ErrorCode) -> Self { - Error { - message: code.description(), - code, - data: None, - } - } - - /// Creates new `ParseError` - pub fn parse_error() -> Self { - Self::new(ErrorCode::ParseError) - } - - /// Creates new `InvalidRequest` - pub fn invalid_request() -> Self { - Self::new(ErrorCode::InvalidRequest) - } - - /// Creates new `MethodNotFound` - pub fn method_not_found() -> Self { - Self::new(ErrorCode::MethodNotFound) - } - - /// Creates new `InvalidParams` - pub fn invalid_params(message: M) -> Self - where - M: Into, - { - Error { - code: ErrorCode::InvalidParams, - message: message.into(), - data: None, - } - } - - /// Creates `InvalidParams` for given parameter, with details. - pub fn invalid_params_with_details(message: M, details: T) -> Error - where - M: Into, - T: fmt::Debug, - { - Error { - code: ErrorCode::InvalidParams, - message: format!("Invalid parameters: {}", message.into()), - data: Some(JsonValue::String(format!("{:?}", details))), - } - } - - /// Creates new `InternalError` - pub fn internal_error() -> Self { - Self::new(ErrorCode::InternalError) - } - - /// Creates new `InvalidRequest` with invalid version description - pub fn invalid_version() -> Self { - Error { - code: ErrorCode::InvalidRequest, - message: "Unsupported JSON-RPC protocol version".to_owned(), - data: None, - } - } + /// Wraps given `ErrorCode` + pub fn new(code: ErrorCode) -> Self { + Error { message: code.description(), code, data: None } + } + + /// Creates new `ParseError` + pub fn parse_error() -> Self { + Self::new(ErrorCode::ParseError) + } + + /// Creates new `InvalidRequest` + pub fn invalid_request() -> Self { + Self::new(ErrorCode::InvalidRequest) + } + + /// Creates new `MethodNotFound` + pub fn method_not_found() -> Self { + Self::new(ErrorCode::MethodNotFound) + } + + /// Creates new `InvalidParams` + pub fn invalid_params(message: M) -> Self + where + M: Into, + { + Error { code: ErrorCode::InvalidParams, message: message.into(), data: None } + } + + /// Creates `InvalidParams` for given parameter, with details. + pub fn invalid_params_with_details(message: M, details: T) -> Error + where + M: Into, + T: fmt::Debug, + { + Error { + code: ErrorCode::InvalidParams, + message: format!("Invalid parameters: {}", message.into()), + data: Some(JsonValue::String(format!("{:?}", details))), + } + } + + /// Creates new `InternalError` + pub fn internal_error() -> Self { + Self::new(ErrorCode::InternalError) + } + + /// Creates new `InvalidRequest` with invalid version description + pub fn invalid_version() -> Self { + Error { + code: ErrorCode::InvalidRequest, + message: "Unsupported JSON-RPC protocol version".to_owned(), + data: None, + } + } } impl From for Error { - fn from(code: ErrorCode) -> Error { - Error::new(code) - } + fn from(code: ErrorCode) -> Error { + Error::new(code) + } } impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}: {}", self.code.description(), self.message) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}: {}", self.code.description(), self.message) + } } impl std::error::Error for Error {} diff --git a/src/common/id.rs b/src/common/id.rs index bcbd61622f..4ead40d488 100644 --- a/src/common/id.rs +++ b/src/common/id.rs @@ -32,52 +32,42 @@ use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Id { - /// No id (notification) - Null, - /// Numeric id - Num(u64), - /// String id - Str(String), + /// No id (notification) + Null, + /// Numeric id + Num(u64), + /// String id + Str(String), } #[cfg(test)] mod tests { - use super::*; - use serde_json; + use super::*; + use serde_json; - #[test] - fn id_deserialization() { - let s = r#""2""#; - let deserialized: Id = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Id::Str("2".into())); + #[test] + fn id_deserialization() { + let s = r#""2""#; + let deserialized: Id = serde_json::from_str(s).unwrap(); + assert_eq!(deserialized, Id::Str("2".into())); - let s = r#"2"#; - let deserialized: Id = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Id::Num(2)); + let s = r#"2"#; + let deserialized: Id = serde_json::from_str(s).unwrap(); + assert_eq!(deserialized, Id::Num(2)); - let s = r#""2x""#; - let deserialized: Id = serde_json::from_str(s).unwrap(); - assert_eq!(deserialized, Id::Str("2x".to_owned())); + let s = r#""2x""#; + let deserialized: Id = serde_json::from_str(s).unwrap(); + assert_eq!(deserialized, Id::Str("2x".to_owned())); - let s = r#"[null, 0, 2, "3"]"#; - let deserialized: Vec = serde_json::from_str(s).unwrap(); - assert_eq!( - deserialized, - vec![Id::Null, Id::Num(0), Id::Num(2), Id::Str("3".into())] - ); - } + let s = r#"[null, 0, 2, "3"]"#; + let deserialized: Vec = serde_json::from_str(s).unwrap(); + assert_eq!(deserialized, vec![Id::Null, Id::Num(0), Id::Num(2), Id::Str("3".into())]); + } - #[test] - fn id_serialization() { - let d = vec![ - Id::Null, - Id::Num(0), - Id::Num(2), - Id::Num(3), - Id::Str("3".to_owned()), - Id::Str("test".to_owned()), - ]; - let serialized = serde_json::to_string(&d).unwrap(); - assert_eq!(serialized, r#"[null,0,2,3,"3","test"]"#); - } + #[test] + fn id_serialization() { + let d = vec![Id::Null, Id::Num(0), Id::Num(2), Id::Num(3), Id::Str("3".to_owned()), Id::Str("test".to_owned())]; + let serialized = serde_json::to_string(&d).unwrap(); + assert_eq!(serialized, r#"[null,0,2,3,"3","test"]"#); + } } diff --git a/src/common/mod.rs b/src/common/mod.rs index ba74b6dea4..0fb9b03600 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -48,6 +48,6 @@ pub use self::id::Id; pub use self::params::Params; pub use self::request::{Call, MethodCall, Notification, Request}; pub use self::response::{ - Failure, Output, Response, SubscriptionId, SubscriptionNotif, SubscriptionNotifParams, Success, + Failure, Output, Response, SubscriptionId, SubscriptionNotif, SubscriptionNotifParams, Success, }; pub use self::version::Version; diff --git a/src/common/params.rs b/src/common/params.rs index 10a0c56394..354b2f0c85 100644 --- a/src/common/params.rs +++ b/src/common/params.rs @@ -37,107 +37,98 @@ use super::{Error, JsonValue}; #[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Params { - /// No parameters - None, - /// Array of values - Array(Vec), - /// Map of values - Map(serde_json::Map), + /// No parameters + None, + /// Array of values + Array(Vec), + /// Map of values + Map(serde_json::Map), } impl Params { - /// Parse incoming `Params` into expected common. - pub fn parse(self) -> Result - where - D: DeserializeOwned, - { - let value: JsonValue = self.into(); - from_value(value).map_err(|e| Error::invalid_params(format!("Invalid params: {}.", e))) - } + /// Parse incoming `Params` into expected common. + pub fn parse(self) -> Result + where + D: DeserializeOwned, + { + let value: JsonValue = self.into(); + from_value(value).map_err(|e| Error::invalid_params(format!("Invalid params: {}.", e))) + } - /// Check for no params, returns Err if any params - pub fn expect_no_params(self) -> Result<(), Error> { - match self { - Params::None => Ok(()), - Params::Array(ref v) if v.is_empty() => Ok(()), - p => Err(Error::invalid_params_with_details( - "No parameters were expected", - p, - )), - } - } + /// Check for no params, returns Err if any params + pub fn expect_no_params(self) -> Result<(), Error> { + match self { + Params::None => Ok(()), + Params::Array(ref v) if v.is_empty() => Ok(()), + p => Err(Error::invalid_params_with_details("No parameters were expected", p)), + } + } } impl From for JsonValue { - fn from(params: Params) -> JsonValue { - match params { - Params::Array(vec) => JsonValue::Array(vec), - Params::Map(map) => JsonValue::Object(map), - Params::None => JsonValue::Null, - } - } + fn from(params: Params) -> JsonValue { + match params { + Params::Array(vec) => JsonValue::Array(vec), + Params::Map(map) => JsonValue::Object(map), + Params::None => JsonValue::Null, + } + } } #[cfg(test)] mod tests { - use super::Params; - use crate::common::{Error, ErrorCode, JsonValue}; - use serde_json; + use super::Params; + use crate::common::{Error, ErrorCode, JsonValue}; + use serde_json; - #[test] - fn params_deserialization() { - let s = r#"[null, true, -1, 4, 2.3, "hello", [0], {"key": "value"}, []]"#; - let deserialized: Params = serde_json::from_str(s).unwrap(); + #[test] + fn params_deserialization() { + let s = r#"[null, true, -1, 4, 2.3, "hello", [0], {"key": "value"}, []]"#; + let deserialized: Params = serde_json::from_str(s).unwrap(); - let mut map = serde_json::Map::new(); - map.insert("key".to_string(), JsonValue::String("value".to_string())); + let mut map = serde_json::Map::new(); + map.insert("key".to_string(), JsonValue::String("value".to_string())); - assert_eq!( - Params::Array(vec![ - JsonValue::Null, - JsonValue::Bool(true), - JsonValue::from(-1), - JsonValue::from(4), - JsonValue::from(2.3), - JsonValue::String("hello".to_string()), - JsonValue::Array(vec![JsonValue::from(0)]), - JsonValue::Object(map), - JsonValue::Array(vec![]), - ]), - deserialized - ); - } + assert_eq!( + Params::Array(vec![ + JsonValue::Null, + JsonValue::Bool(true), + JsonValue::from(-1), + JsonValue::from(4), + JsonValue::from(2.3), + JsonValue::String("hello".to_string()), + JsonValue::Array(vec![JsonValue::from(0)]), + JsonValue::Object(map), + JsonValue::Array(vec![]), + ]), + deserialized + ); + } - #[test] - fn should_return_meaningful_error_when_deserialization_fails() { - // given - let s = r#"[1, true]"#; - let params = || serde_json::from_str::(s).unwrap(); + #[test] + fn should_return_meaningful_error_when_deserialization_fails() { + // given + let s = r#"[1, true]"#; + let params = || serde_json::from_str::(s).unwrap(); - // when - let v1: Result<(Option, String), Error> = params().parse(); - let v2: Result<(u8, bool, String), Error> = params().parse(); - let err1 = v1.unwrap_err(); - let err2 = v2.unwrap_err(); + // when + let v1: Result<(Option, String), Error> = params().parse(); + let v2: Result<(u8, bool, String), Error> = params().parse(); + let err1 = v1.unwrap_err(); + let err2 = v2.unwrap_err(); - // then - assert_eq!(err1.code, ErrorCode::InvalidParams); - assert_eq!( - err1.message, - "Invalid params: invalid type: boolean `true`, expected a string." - ); - assert_eq!(err1.data, None); - assert_eq!(err2.code, ErrorCode::InvalidParams); - assert_eq!( - err2.message, - "Invalid params: invalid length 2, expected a tuple of size 3." - ); - assert_eq!(err2.data, None); - } + // then + assert_eq!(err1.code, ErrorCode::InvalidParams); + assert_eq!(err1.message, "Invalid params: invalid type: boolean `true`, expected a string."); + assert_eq!(err1.data, None); + assert_eq!(err2.code, ErrorCode::InvalidParams); + assert_eq!(err2.message, "Invalid params: invalid length 2, expected a tuple of size 3."); + assert_eq!(err2.data, None); + } - #[test] - fn single_param_parsed_as_tuple() { - let params: (u64,) = Params::Array(vec![JsonValue::from(1)]).parse().unwrap(); - assert_eq!(params, (1,)); - } + #[test] + fn single_param_parsed_as_tuple() { + let params: (u64,) = Params::Array(vec![JsonValue::from(1)]).parse().unwrap(); + assert_eq!(params, (1,)); + } } diff --git a/src/common/request.rs b/src/common/request.rs index 2e0237389f..5595fa4a63 100644 --- a/src/common/request.rs +++ b/src/common/request.rs @@ -33,68 +33,68 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct MethodCall { - /// A String specifying the version of the JSON-RPC protocol. - pub jsonrpc: Version, - /// A String containing the name of the method to be invoked. - pub method: String, - /// A Structured value that holds the parameter values to be used - /// during the invocation of the method. This member MAY be omitted. - #[serde(default = "default_params")] - pub params: Params, - /// An identifier established by the Client that MUST contain a String, - /// Number, or NULL value if included. If it is not included it is assumed - /// to be a notification. - pub id: Id, + /// A String specifying the version of the JSON-RPC protocol. + pub jsonrpc: Version, + /// A String containing the name of the method to be invoked. + pub method: String, + /// A Structured value that holds the parameter values to be used + /// during the invocation of the method. This member MAY be omitted. + #[serde(default = "default_params")] + pub params: Params, + /// An identifier established by the Client that MUST contain a String, + /// Number, or NULL value if included. If it is not included it is assumed + /// to be a notification. + pub id: Id, } /// Represents jsonrpc request which is a notification. #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[serde(deny_unknown_fields)] pub struct Notification { - /// A String specifying the version of the JSON-RPC protocol. - pub jsonrpc: Version, - /// A String containing the name of the method to be invoked. - pub method: String, - /// A Structured value that holds the parameter values to be used - /// during the invocation of the method. This member MAY be omitted. - #[serde(default = "default_params")] - pub params: Params, + /// A String specifying the version of the JSON-RPC protocol. + pub jsonrpc: Version, + /// A String containing the name of the method to be invoked. + pub method: String, + /// A Structured value that holds the parameter values to be used + /// during the invocation of the method. This member MAY be omitted. + #[serde(default = "default_params")] + pub params: Params, } /// Represents single jsonrpc call. #[derive(Clone, Debug, PartialEq, Deserialize, Serialize)] #[serde(untagged)] pub enum Call { - /// Call method - MethodCall(MethodCall), - /// Fire notification - Notification(Notification), - /// Invalid call - Invalid { - /// Call id (if known) - #[serde(default = "default_id")] - id: Id, - }, + /// Call method + MethodCall(MethodCall), + /// Fire notification + Notification(Notification), + /// Invalid call + Invalid { + /// Call id (if known) + #[serde(default = "default_id")] + id: Id, + }, } fn default_params() -> Params { - Params::None + Params::None } fn default_id() -> Id { - Id::Null + Id::Null } impl From for Call { - fn from(mc: MethodCall) -> Self { - Call::MethodCall(mc) - } + fn from(mc: MethodCall) -> Self { + Call::MethodCall(mc) + } } impl From for Call { - fn from(n: Notification) -> Self { - Call::Notification(n) - } + fn from(n: Notification) -> Self { + Call::Notification(n) + } } /// Represents jsonrpc request. @@ -102,230 +102,217 @@ impl From for Call { #[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Request { - /// Single request (call) - Single(Call), - /// Batch of requests (calls) - Batch(Vec), + /// Single request (call) + Single(Call), + /// Batch of requests (calls) + Batch(Vec), } #[cfg(test)] mod tests { - use super::*; - use serde_json::Value; - - #[test] - fn method_call_serialize() { - use serde_json; - use serde_json::Value; - - let m = MethodCall { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]), - id: Id::Num(1), - }; - - let serialized = serde_json::to_string(&m).unwrap(); - assert_eq!( - serialized, - r#"{"jsonrpc":"2.0","method":"update","params":[1,2],"id":1}"# - ); - } - - #[test] - fn notification_serialize() { - use serde_json; - use serde_json::Value; - - let n = Notification { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]), - }; - - let serialized = serde_json::to_string(&n).unwrap(); - assert_eq!( - serialized, - r#"{"jsonrpc":"2.0","method":"update","params":[1,2]}"# - ); - } - - #[test] - fn call_serialize() { - use serde_json; - use serde_json::Value; - - let n = Call::Notification(Notification { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]), - }); - - let serialized = serde_json::to_string(&n).unwrap(); - assert_eq!( - serialized, - r#"{"jsonrpc":"2.0","method":"update","params":[1]}"# - ); - } - - #[test] - fn request_serialize_batch() { - use serde_json; - - let batch = Request::Batch(vec![ - Call::MethodCall(MethodCall { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]), - id: Id::Num(1), - }), - Call::Notification(Notification { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]), - }), - ]); - - let serialized = serde_json::to_string(&batch).unwrap(); - assert_eq!( - serialized, - r#"[{"jsonrpc":"2.0","method":"update","params":[1,2],"id":1},{"jsonrpc":"2.0","method":"update","params":[1]}]"# - ); - } - - #[test] - fn notification_deserialize() { - use serde_json; - use serde_json::Value; - - let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1,2]}"#; - let deserialized: Notification = serde_json::from_str(s).unwrap(); - - assert_eq!( - deserialized, - Notification { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]) - } - ); - - let s = r#"{"jsonrpc": "2.0", "method": "foobar"}"#; - let deserialized: Notification = serde_json::from_str(s).unwrap(); - - assert_eq!( - deserialized, - Notification { - jsonrpc: Version::V2, - method: "foobar".to_owned(), - params: Params::None, - } - ); - - let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1,2], "id": 1}"#; - let deserialized: Result = serde_json::from_str(s); - assert!(deserialized.is_err()); - } - - #[test] - fn call_deserialize() { - use serde_json; - - let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1]}"#; - let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!( - deserialized, - Call::Notification(Notification { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]) - }) - ); - - let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1], "id": 1}"#; - let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!( - deserialized, - Call::MethodCall(MethodCall { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]), - id: Id::Num(1) - }) - ); - - let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [], "id": 1}"#; - let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!( - deserialized, - Call::MethodCall(MethodCall { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![]), - id: Id::Num(1) - }) - ); - - let s = r#"{"jsonrpc": "2.0", "method": "update", "params": null, "id": 1}"#; - let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!( - deserialized, - Call::MethodCall(MethodCall { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::None, - id: Id::Num(1) - }) - ); - - let s = r#"{"jsonrpc": "2.0", "method": "update", "id": 1}"#; - let deserialized: Call = serde_json::from_str(s).unwrap(); - assert_eq!( - deserialized, - Call::MethodCall(MethodCall { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::None, - id: Id::Num(1) - }) - ); - } - - #[test] - fn request_deserialize_batch() { - use serde_json; - - let s = r#"[{}, {"jsonrpc": "2.0", "method": "update", "params": [1,2], "id": 1},{"jsonrpc": "2.0", "method": "update", "params": [1]}]"#; - let deserialized: Request = serde_json::from_str(s).unwrap(); - assert_eq!( - deserialized, - Request::Batch(vec![ - Call::Invalid { id: Id::Null }, - Call::MethodCall(MethodCall { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]), - id: Id::Num(1) - }), - Call::Notification(Notification { - jsonrpc: Version::V2, - method: "update".to_owned(), - params: Params::Array(vec![Value::from(1)]) - }) - ]) - ) - } - - #[test] - fn request_invalid_returns_id() { - use serde_json; - - let s = r#"{"id":120,"method":"my_method","params":["foo", "bar"],"extra_field":[]}"#; - let deserialized: Request = serde_json::from_str(s).unwrap(); - - match deserialized { - Request::Single(Call::Invalid { id: Id::Num(120) }) => {} - _ => panic!("Request wrongly deserialized: {:?}", deserialized), - } - } + use super::*; + use serde_json::Value; + + #[test] + fn method_call_serialize() { + use serde_json; + use serde_json::Value; + + let m = MethodCall { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1), Value::from(2)]), + id: Id::Num(1), + }; + + let serialized = serde_json::to_string(&m).unwrap(); + assert_eq!(serialized, r#"{"jsonrpc":"2.0","method":"update","params":[1,2],"id":1}"#); + } + + #[test] + fn notification_serialize() { + use serde_json; + use serde_json::Value; + + let n = Notification { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1), Value::from(2)]), + }; + + let serialized = serde_json::to_string(&n).unwrap(); + assert_eq!(serialized, r#"{"jsonrpc":"2.0","method":"update","params":[1,2]}"#); + } + + #[test] + fn call_serialize() { + use serde_json; + use serde_json::Value; + + let n = Call::Notification(Notification { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1)]), + }); + + let serialized = serde_json::to_string(&n).unwrap(); + assert_eq!(serialized, r#"{"jsonrpc":"2.0","method":"update","params":[1]}"#); + } + + #[test] + fn request_serialize_batch() { + use serde_json; + + let batch = Request::Batch(vec![ + Call::MethodCall(MethodCall { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1), Value::from(2)]), + id: Id::Num(1), + }), + Call::Notification(Notification { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1)]), + }), + ]); + + let serialized = serde_json::to_string(&batch).unwrap(); + assert_eq!( + serialized, + r#"[{"jsonrpc":"2.0","method":"update","params":[1,2],"id":1},{"jsonrpc":"2.0","method":"update","params":[1]}]"# + ); + } + + #[test] + fn notification_deserialize() { + use serde_json; + use serde_json::Value; + + let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1,2]}"#; + let deserialized: Notification = serde_json::from_str(s).unwrap(); + + assert_eq!( + deserialized, + Notification { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1), Value::from(2)]) + } + ); + + let s = r#"{"jsonrpc": "2.0", "method": "foobar"}"#; + let deserialized: Notification = serde_json::from_str(s).unwrap(); + + assert_eq!( + deserialized, + Notification { jsonrpc: Version::V2, method: "foobar".to_owned(), params: Params::None } + ); + + let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1,2], "id": 1}"#; + let deserialized: Result = serde_json::from_str(s); + assert!(deserialized.is_err()); + } + + #[test] + fn call_deserialize() { + use serde_json; + + let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1]}"#; + let deserialized: Call = serde_json::from_str(s).unwrap(); + assert_eq!( + deserialized, + Call::Notification(Notification { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1)]) + }) + ); + + let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [1], "id": 1}"#; + let deserialized: Call = serde_json::from_str(s).unwrap(); + assert_eq!( + deserialized, + Call::MethodCall(MethodCall { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1)]), + id: Id::Num(1) + }) + ); + + let s = r#"{"jsonrpc": "2.0", "method": "update", "params": [], "id": 1}"#; + let deserialized: Call = serde_json::from_str(s).unwrap(); + assert_eq!( + deserialized, + Call::MethodCall(MethodCall { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![]), + id: Id::Num(1) + }) + ); + + let s = r#"{"jsonrpc": "2.0", "method": "update", "params": null, "id": 1}"#; + let deserialized: Call = serde_json::from_str(s).unwrap(); + assert_eq!( + deserialized, + Call::MethodCall(MethodCall { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::None, + id: Id::Num(1) + }) + ); + + let s = r#"{"jsonrpc": "2.0", "method": "update", "id": 1}"#; + let deserialized: Call = serde_json::from_str(s).unwrap(); + assert_eq!( + deserialized, + Call::MethodCall(MethodCall { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::None, + id: Id::Num(1) + }) + ); + } + + #[test] + fn request_deserialize_batch() { + use serde_json; + + let s = r#"[{}, {"jsonrpc": "2.0", "method": "update", "params": [1,2], "id": 1},{"jsonrpc": "2.0", "method": "update", "params": [1]}]"#; + let deserialized: Request = serde_json::from_str(s).unwrap(); + assert_eq!( + deserialized, + Request::Batch(vec![ + Call::Invalid { id: Id::Null }, + Call::MethodCall(MethodCall { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1), Value::from(2)]), + id: Id::Num(1) + }), + Call::Notification(Notification { + jsonrpc: Version::V2, + method: "update".to_owned(), + params: Params::Array(vec![Value::from(1)]) + }) + ]) + ) + } + + #[test] + fn request_invalid_returns_id() { + use serde_json; + + let s = r#"{"id":120,"method":"my_method","params":["foo", "bar"],"extra_field":[]}"#; + let deserialized: Request = serde_json::from_str(s).unwrap(); + + match deserialized { + Request::Single(Call::Invalid { id: Id::Num(120) }) => {} + _ => panic!("Request wrongly deserialized: {:?}", deserialized), + } + } } diff --git a/src/common/response.rs b/src/common/response.rs index e2faa77b54..5c35e685f0 100644 --- a/src/common/response.rs +++ b/src/common/response.rs @@ -27,9 +27,9 @@ use super::{Error, Id, JsonValue, Version}; use alloc::{ - string::{String, ToString as _}, - vec, - vec::Vec, + string::{String, ToString as _}, + vec, + vec::Vec, }; use serde::{Deserialize, Serialize}; @@ -38,36 +38,36 @@ use serde::{Deserialize, Serialize}; #[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Response { - /// Single response - Single(Output), - /// Response to batch request (batch of responses) - Batch(Vec), - /// Notification to an active subscription. - Notif(SubscriptionNotif), + /// Single response + Single(Output), + /// Response to batch request (batch of responses) + Batch(Vec), + /// Notification to an active subscription. + Notif(SubscriptionNotif), } /// Successful response #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct Success { - /// Protocol version - pub jsonrpc: Version, - /// Result - pub result: JsonValue, - /// Correlation id - pub id: Id, + /// Protocol version + pub jsonrpc: Version, + /// Result + pub result: JsonValue, + /// Correlation id + pub id: Id, } /// Unsuccessful response #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct Failure { - /// Protocol version - pub jsonrpc: Version, - /// Error - pub error: Error, - /// Correlation id - pub id: Id, + /// Protocol version + pub jsonrpc: Version, + /// Error + pub error: Error, + /// Correlation id + pub id: Id, } /// Represents output - failure or success @@ -75,32 +75,32 @@ pub struct Failure { #[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Output { - /// Success - Success(Success), - /// Failure - Failure(Failure), + /// Success + Success(Success), + /// Failure + Failure(Failure), } /// Server notification about something the client is subscribed to. #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct SubscriptionNotif { - /// Protocol version - pub jsonrpc: Version, - /// A String containing the name of the method that was used for the subscription. - pub method: String, - /// Parameters of the notification. - pub params: SubscriptionNotifParams, + /// Protocol version + pub jsonrpc: Version, + /// A String containing the name of the method that was used for the subscription. + pub method: String, + /// Parameters of the notification. + pub params: SubscriptionNotifParams, } /// Field of a [`SubscriptionNotif`]. #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct SubscriptionNotifParams { - /// Subscription id, as communicated during the subscription. - pub subscription: SubscriptionId, - /// Actual data that the server wants to communicate to us. - pub result: JsonValue, + /// Subscription id, as communicated during the subscription. + pub subscription: SubscriptionId, + /// Actual data that the server wants to communicate to us. + pub result: JsonValue, } /// Id of a subscription, communicated by the server. @@ -108,212 +108,169 @@ pub struct SubscriptionNotifParams { #[serde(deny_unknown_fields)] #[serde(untagged)] pub enum SubscriptionId { - /// Numeric id - Num(u64), - /// String id - Str(String), + /// Numeric id + Num(u64), + /// String id + Str(String), } impl Output { - /// Creates new output given `Result`, `Id` and `Version`. - pub fn from(result: Result, id: Id, jsonrpc: Version) -> Self { - match result { - Ok(result) => Output::Success(Success { - id, - jsonrpc, - result, - }), - Err(error) => Output::Failure(Failure { id, jsonrpc, error }), - } - } - - /// Get the jsonrpc protocol version. - pub fn version(&self) -> Version { - match *self { - Output::Success(ref s) => s.jsonrpc, - Output::Failure(ref f) => f.jsonrpc, - } - } - - /// Get the correlation id. - pub fn id(&self) -> &Id { - match *self { - Output::Success(ref s) => &s.id, - Output::Failure(ref f) => &f.id, - } - } + /// Creates new output given `Result`, `Id` and `Version`. + pub fn from(result: Result, id: Id, jsonrpc: Version) -> Self { + match result { + Ok(result) => Output::Success(Success { id, jsonrpc, result }), + Err(error) => Output::Failure(Failure { id, jsonrpc, error }), + } + } + + /// Get the jsonrpc protocol version. + pub fn version(&self) -> Version { + match *self { + Output::Success(ref s) => s.jsonrpc, + Output::Failure(ref f) => f.jsonrpc, + } + } + + /// Get the correlation id. + pub fn id(&self) -> &Id { + match *self { + Output::Success(ref s) => &s.id, + Output::Failure(ref f) => &f.id, + } + } } impl From for Result { - /// Convert into a result. Will be `Ok` if it is a `Success` and `Err` if `Failure`. - fn from(output: Output) -> Result { - match output { - Output::Success(s) => Ok(s.result), - Output::Failure(f) => Err(f.error), - } - } + /// Convert into a result. Will be `Ok` if it is a `Success` and `Err` if `Failure`. + fn from(output: Output) -> Result { + match output { + Output::Success(s) => Ok(s.result), + Output::Failure(f) => Err(f.error), + } + } } impl Response { - /// Creates new `Response` with given error and `Version` - pub fn from(error: impl Into, jsonrpc: Version) -> Self { - Failure { - id: Id::Null, - jsonrpc, - error: error.into(), - } - .into() - } - - /// Deserialize `Response` from given JSON string. - /// - /// This method will handle an empty string as empty batch response. - pub fn from_json(s: &str) -> Result { - if s.is_empty() { - Ok(Response::Batch(vec![])) - } else { - serde_json::from_str(s) - } - } + /// Creates new `Response` with given error and `Version` + pub fn from(error: impl Into, jsonrpc: Version) -> Self { + Failure { id: Id::Null, jsonrpc, error: error.into() }.into() + } + + /// Deserialize `Response` from given JSON string. + /// + /// This method will handle an empty string as empty batch response. + pub fn from_json(s: &str) -> Result { + if s.is_empty() { + Ok(Response::Batch(vec![])) + } else { + serde_json::from_str(s) + } + } } impl From for Response { - fn from(failure: Failure) -> Self { - Response::Single(Output::Failure(failure)) - } + fn from(failure: Failure) -> Self { + Response::Single(Output::Failure(failure)) + } } impl From for Response { - fn from(success: Success) -> Self { - Response::Single(Output::Success(success)) - } + fn from(success: Success) -> Self { + Response::Single(Output::Success(success)) + } } impl SubscriptionId { - /// Turns the subcription ID into a string. - pub fn into_string(self) -> String { - match self { - SubscriptionId::Num(n) => n.to_string(), - SubscriptionId::Str(s) => s, - } - } + /// Turns the subcription ID into a string. + pub fn into_string(self) -> String { + match self { + SubscriptionId::Num(n) => n.to_string(), + SubscriptionId::Str(s) => s, + } + } } #[test] fn success_output_serialize() { - use serde_json; - use serde_json::Value; + use serde_json; + use serde_json::Value; - let so = Output::Success(Success { - jsonrpc: Version::V2, - result: Value::from(1), - id: Id::Num(1), - }); + let so = Output::Success(Success { jsonrpc: Version::V2, result: Value::from(1), id: Id::Num(1) }); - let serialized = serde_json::to_string(&so).unwrap(); - assert_eq!(serialized, r#"{"jsonrpc":"2.0","result":1,"id":1}"#); + let serialized = serde_json::to_string(&so).unwrap(); + assert_eq!(serialized, r#"{"jsonrpc":"2.0","result":1,"id":1}"#); } #[test] fn success_output_deserialize() { - use serde_json; - use serde_json::Value; - - let dso = r#"{"jsonrpc":"2.0","result":1,"id":1}"#; - - let deserialized: Output = serde_json::from_str(dso).unwrap(); - assert_eq!( - deserialized, - Output::Success(Success { - jsonrpc: Version::V2, - result: Value::from(1), - id: Id::Num(1) - }) - ); + use serde_json; + use serde_json::Value; + + let dso = r#"{"jsonrpc":"2.0","result":1,"id":1}"#; + + let deserialized: Output = serde_json::from_str(dso).unwrap(); + assert_eq!(deserialized, Output::Success(Success { jsonrpc: Version::V2, result: Value::from(1), id: Id::Num(1) })); } #[test] fn failure_output_serialize() { - use serde_json; - - let fo = Output::Failure(Failure { - jsonrpc: Version::V2, - error: Error::parse_error(), - id: Id::Num(1), - }); - - let serialized = serde_json::to_string(&fo).unwrap(); - assert_eq!( - serialized, - r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}"# - ); + use serde_json; + + let fo = Output::Failure(Failure { jsonrpc: Version::V2, error: Error::parse_error(), id: Id::Num(1) }); + + let serialized = serde_json::to_string(&fo).unwrap(); + assert_eq!(serialized, r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}"#); } #[test] fn failure_output_deserialize() { - use serde_json; - - let dfo = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}"#; - - let deserialized: Output = serde_json::from_str(dfo).unwrap(); - assert_eq!( - deserialized, - Output::Failure(Failure { - jsonrpc: Version::V2, - error: Error::parse_error(), - id: Id::Num(1) - }) - ); + use serde_json; + + let dfo = r#"{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}"#; + + let deserialized: Output = serde_json::from_str(dfo).unwrap(); + assert_eq!( + deserialized, + Output::Failure(Failure { jsonrpc: Version::V2, error: Error::parse_error(), id: Id::Num(1) }) + ); } #[test] fn single_response_deserialize() { - use serde_json; - use serde_json::Value; - - let dsr = r#"{"jsonrpc":"2.0","result":1,"id":1}"#; - - let deserialized: Response = serde_json::from_str(dsr).unwrap(); - assert_eq!( - deserialized, - Response::Single(Output::Success(Success { - jsonrpc: Version::V2, - result: Value::from(1), - id: Id::Num(1) - })) - ); + use serde_json; + use serde_json::Value; + + let dsr = r#"{"jsonrpc":"2.0","result":1,"id":1}"#; + + let deserialized: Response = serde_json::from_str(dsr).unwrap(); + assert_eq!( + deserialized, + Response::Single(Output::Success(Success { jsonrpc: Version::V2, result: Value::from(1), id: Id::Num(1) })) + ); } #[test] fn batch_response_deserialize() { - use serde_json; - use serde_json::Value; - - let dbr = r#"[{"jsonrpc":"2.0","result":1,"id":1},{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}]"#; - - let deserialized: Response = serde_json::from_str(dbr).unwrap(); - assert_eq!( - deserialized, - Response::Batch(vec![ - Output::Success(Success { - jsonrpc: Version::V2, - result: Value::from(1), - id: Id::Num(1) - }), - Output::Failure(Failure { - jsonrpc: Version::V2, - error: Error::parse_error(), - id: Id::Num(1) - }) - ]) - ); + use serde_json; + use serde_json::Value; + + let dbr = r#"[{"jsonrpc":"2.0","result":1,"id":1},{"jsonrpc":"2.0","error":{"code":-32700,"message":"Parse error"},"id":1}]"#; + + let deserialized: Response = serde_json::from_str(dbr).unwrap(); + assert_eq!( + deserialized, + Response::Batch(vec![ + Output::Success(Success { jsonrpc: Version::V2, result: Value::from(1), id: Id::Num(1) }), + Output::Failure(Failure { jsonrpc: Version::V2, error: Error::parse_error(), id: Id::Num(1) }) + ]) + ); } #[test] fn handle_incorrect_responses() { - use serde_json; + use serde_json; - let dsr = r#" + let dsr = r#" { "id": 2, "jsonrpc": "2.0", @@ -325,24 +282,18 @@ fn handle_incorrect_responses() { } }"#; - let deserialized: Result = serde_json::from_str(dsr); - assert!( - deserialized.is_err(), - "Expected error when deserializing invalid payload." - ); + let deserialized: Result = serde_json::from_str(dsr); + assert!(deserialized.is_err(), "Expected error when deserializing invalid payload."); } #[test] fn should_parse_empty_response_as_batch() { - use serde_json; + use serde_json; - let dsr = r#""#; + let dsr = r#""#; - let deserialized1: Result = serde_json::from_str(dsr); - let deserialized2: Result = Response::from_json(dsr); - assert!( - deserialized1.is_err(), - "Empty string is not valid JSON, so we should get an error." - ); - assert_eq!(deserialized2.unwrap(), Response::Batch(vec![])); + let deserialized1: Result = serde_json::from_str(dsr); + let deserialized2: Result = Response::from_json(dsr); + assert!(deserialized1.is_err(), "Empty string is not valid JSON, so we should get an error."); + assert_eq!(deserialized2.unwrap(), Response::Batch(vec![])); } diff --git a/src/common/version.rs b/src/common/version.rs index e815c8f937..6a95e2f6d9 100644 --- a/src/common/version.rs +++ b/src/common/version.rs @@ -31,46 +31,46 @@ use serde::{Deserialize, Deserializer, Serialize, Serializer}; /// Protocol version. #[derive(Debug, PartialEq, Clone, Copy, Hash, Eq)] pub enum Version { - /// JSONRPC 2.0 - V2, + /// JSONRPC 2.0 + V2, } impl Serialize for Version { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match *self { - Version::V2 => serializer.serialize_str("2.0"), - } - } + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + match *self { + Version::V2 => serializer.serialize_str("2.0"), + } + } } impl<'a> Deserialize<'a> for Version { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'a>, - { - deserializer.deserialize_identifier(VersionVisitor) - } + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'a>, + { + deserializer.deserialize_identifier(VersionVisitor) + } } struct VersionVisitor; impl<'a> Visitor<'a> for VersionVisitor { - type Value = Version; + type Value = Version; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string") - } + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string") + } - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match value { - "2.0" => Ok(Version::V2), - _ => Err(de::Error::custom("invalid version")), - } - } + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + match value { + "2.0" => Ok(Version::V2), + _ => Err(de::Error::custom("invalid version")), + } + } } diff --git a/src/http/raw/batch.rs b/src/http/raw/batch.rs index 8c31beb92f..2ac5eb6a98 100644 --- a/src/http/raw/batch.rs +++ b/src/http/raw/batch.rs @@ -52,333 +52,303 @@ use smallvec::SmallVec; /// destroy the [`BatchState`]. /// pub struct BatchState { - /// List of elements to present to the user. - to_yield: SmallVec<[ToYield; 1]>, + /// List of elements to present to the user. + to_yield: SmallVec<[ToYield; 1]>, - /// List of requests to be answered. When a request is answered, we replace it with `None` so - /// that indices don't change. - requests: SmallVec<[Option; 1]>, + /// List of requests to be answered. When a request is answered, we replace it with `None` so + /// that indices don't change. + requests: SmallVec<[Option; 1]>, - /// List of pending responses. - responses: SmallVec<[common::Output; 1]>, + /// List of pending responses. + responses: SmallVec<[common::Output; 1]>, - /// True if the original request was a batch. We need to keep track of this because we need to - /// respond differently depending on whether we have a single request or a batch with one - /// request. - is_batch: bool, + /// True if the original request was a batch. We need to keep track of this because we need to + /// respond differently depending on whether we have a single request or a batch with one + /// request. + is_batch: bool, } /// Element remaining to be yielded to the user. #[derive(Debug)] enum ToYield { - Notification(common::Notification), - Request(common::MethodCall), + Notification(common::Notification), + Request(common::MethodCall), } /// Event generated by the [`next`](BatchState::next) function. #[derive(Debug)] pub enum BatchInc<'a> { - /// Request is a notification. - Notification(Notification), - /// Request is a method call. - Request(BatchElem<'a>), + /// Request is a notification. + Notification(Notification), + /// Request is a method call. + Request(BatchElem<'a>), } /// References to a request within the batch that must be answered. pub struct BatchElem<'a> { - /// Index within the `BatchState::requests` list. - index: usize, - /// Reference to the actual element. Must always be `Some` for the lifetime of this object. - /// We hold a `&mut Option` rather than a `&mut common::MethodCall` so - /// that we can put `None` in it. - elem: &'a mut Option, - /// Reference to the `BatchState::responses` list so that we can push a response. - responses: &'a mut SmallVec<[common::Output; 1]>, + /// Index within the `BatchState::requests` list. + index: usize, + /// Reference to the actual element. Must always be `Some` for the lifetime of this object. + /// We hold a `&mut Option` rather than a `&mut common::MethodCall` so + /// that we can put `None` in it. + elem: &'a mut Option, + /// Reference to the `BatchState::responses` list so that we can push a response. + responses: &'a mut SmallVec<[common::Output; 1]>, } impl BatchState { - /// Creates a `BatchState` that will manage the given request. - pub fn from_request(raw_request_body: common::Request) -> BatchState { - match raw_request_body { - common::Request::Single(rq) => BatchState::from_iter(iter::once(rq), false), - common::Request::Batch(requests) => BatchState::from_iter(requests.into_iter(), true), - } - } - - /// Internal implementation of [`from_request`](BatchState::from_request). Generic over the - /// iterator. - fn from_iter( - calls_list: impl ExactSizeIterator, - is_batch: bool, - ) -> BatchState { - debug_assert!(!(!is_batch && calls_list.len() >= 2)); - - let mut to_yield = SmallVec::with_capacity(calls_list.len()); - let mut responses = SmallVec::with_capacity(calls_list.len()); - let mut num_requests = 0; - - for call in calls_list { - match call { - common::Call::MethodCall(call) => { - to_yield.push(ToYield::Request(call)); - num_requests += 1; - } - common::Call::Notification(n) => { - to_yield.push(ToYield::Notification(n)); - } - common::Call::Invalid { id } => { - let err = Err(common::Error::invalid_request()); - let out = common::Output::from(err, id, common::Version::V2); - responses.push(out); - } - } - } - - BatchState { - to_yield, - requests: SmallVec::with_capacity(num_requests), - responses, - is_batch, - } - } - - /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) - /// by its id. - /// - /// Note that previous notifications don't have an ID and can't be accessed with this method. - /// - /// Returns `None` if the request ID is invalid or if the request has already been answered in - /// the past. - pub fn request_by_id(&mut self, id: usize) -> Option { - if let Some(elem) = self.requests.get_mut(id) { - if elem.is_none() { - return None; - } - Some(BatchElem { - elem, - index: id, - responses: &mut self.responses, - }) - } else { - None - } - } - - /// Extracts the next request from the batch. Returns `None` if the batch is empty. - pub fn next(&mut self) -> Option { - if self.to_yield.is_empty() { - return None; - } - - match self.to_yield.remove(0) { - ToYield::Notification(n) => Some(BatchInc::Notification(From::from(n))), - ToYield::Request(n) => { - let request_id = self.requests.len(); - self.requests.push(Some(n)); - Some(BatchInc::Request(BatchElem { - index: request_id, - elem: &mut self.requests[request_id], - responses: &mut self.responses, - })) - } - } - } - - /// Returns true if this batch is ready to send out its response. - pub fn is_ready_to_respond(&self) -> bool { - self.to_yield.is_empty() && self.requests.iter().all(|r| r.is_none()) - } - - /// Turns this batch into a response to send out to the client. - /// - /// Returns `Ok(None)` if there is actually nothing to send to the client, such as when the - /// client has only sent notifications. - pub fn into_response(mut self) -> Result, Self> { - if !self.is_ready_to_respond() { - return Err(self); - } - - let raw_response = if self.is_batch { - let list: Vec<_> = self.responses.drain(..).collect(); - if list.is_empty() { - None - } else { - Some(common::Response::Batch(list)) - } - } else { - debug_assert!(self.responses.len() <= 1); - if self.responses.is_empty() { - None - } else { - Some(common::Response::Single(self.responses.remove(0))) - } - }; - - Ok(raw_response) - } + /// Creates a `BatchState` that will manage the given request. + pub fn from_request(raw_request_body: common::Request) -> BatchState { + match raw_request_body { + common::Request::Single(rq) => BatchState::from_iter(iter::once(rq), false), + common::Request::Batch(requests) => BatchState::from_iter(requests.into_iter(), true), + } + } + + /// Internal implementation of [`from_request`](BatchState::from_request). Generic over the + /// iterator. + fn from_iter(calls_list: impl ExactSizeIterator, is_batch: bool) -> BatchState { + debug_assert!(!(!is_batch && calls_list.len() >= 2)); + + let mut to_yield = SmallVec::with_capacity(calls_list.len()); + let mut responses = SmallVec::with_capacity(calls_list.len()); + let mut num_requests = 0; + + for call in calls_list { + match call { + common::Call::MethodCall(call) => { + to_yield.push(ToYield::Request(call)); + num_requests += 1; + } + common::Call::Notification(n) => { + to_yield.push(ToYield::Notification(n)); + } + common::Call::Invalid { id } => { + let err = Err(common::Error::invalid_request()); + let out = common::Output::from(err, id, common::Version::V2); + responses.push(out); + } + } + } + + BatchState { to_yield, requests: SmallVec::with_capacity(num_requests), responses, is_batch } + } + + /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) + /// by its id. + /// + /// Note that previous notifications don't have an ID and can't be accessed with this method. + /// + /// Returns `None` if the request ID is invalid or if the request has already been answered in + /// the past. + pub fn request_by_id(&mut self, id: usize) -> Option { + if let Some(elem) = self.requests.get_mut(id) { + if elem.is_none() { + return None; + } + Some(BatchElem { elem, index: id, responses: &mut self.responses }) + } else { + None + } + } + + /// Extracts the next request from the batch. Returns `None` if the batch is empty. + pub fn next(&mut self) -> Option { + if self.to_yield.is_empty() { + return None; + } + + match self.to_yield.remove(0) { + ToYield::Notification(n) => Some(BatchInc::Notification(From::from(n))), + ToYield::Request(n) => { + let request_id = self.requests.len(); + self.requests.push(Some(n)); + Some(BatchInc::Request(BatchElem { + index: request_id, + elem: &mut self.requests[request_id], + responses: &mut self.responses, + })) + } + } + } + + /// Returns true if this batch is ready to send out its response. + pub fn is_ready_to_respond(&self) -> bool { + self.to_yield.is_empty() && self.requests.iter().all(|r| r.is_none()) + } + + /// Turns this batch into a response to send out to the client. + /// + /// Returns `Ok(None)` if there is actually nothing to send to the client, such as when the + /// client has only sent notifications. + pub fn into_response(mut self) -> Result, Self> { + if !self.is_ready_to_respond() { + return Err(self); + } + + let raw_response = if self.is_batch { + let list: Vec<_> = self.responses.drain(..).collect(); + if list.is_empty() { + None + } else { + Some(common::Response::Batch(list)) + } + } else { + debug_assert!(self.responses.len() <= 1); + if self.responses.is_empty() { + None + } else { + Some(common::Response::Single(self.responses.remove(0))) + } + }; + + Ok(raw_response) + } } impl fmt::Debug for BatchState { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_list() - .entries(self.to_yield.iter()) - .entries(self.requests.iter().filter(|r| r.is_some())) - .entries(self.responses.iter()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_list() + .entries(self.to_yield.iter()) + .entries(self.requests.iter().filter(|r| r.is_some())) + .entries(self.responses.iter()) + .finish() + } } impl<'a> BatchElem<'a> { - /// Returns the id of the request within the [`BatchState`]. - /// - /// > **Note**: This is NOT the request id that the client passed. - pub fn id(&self) -> usize { - self.index - } - - /// Returns the id that the client sent out. - pub fn request_id(&self) -> &common::Id { - let request = self - .elem - .as_ref() - .expect("elem must be Some for the lifetime of the object; qed"); - &request.id - } - - /// Returns the method of this request. - pub fn method(&self) -> &str { - let request = self - .elem - .as_ref() - .expect("elem must be Some for the lifetime of the object; qed"); - &request.method - } - - /// Returns the parameters of the request, as a `common::Params`. - pub fn params(&self) -> Params { - let request = self - .elem - .as_ref() - .expect("elem must be Some for the lifetime of the object; qed"); - Params::from(&request.params) - } - - /// Responds to the request. This destroys the request object, meaning you can no longer - /// retrieve it with [`request_by_id`](BatchState::request_by_id) later anymore. - pub fn set_response(self, response: Result) { - let request = self - .elem - .take() - .expect("elem must be Some for the lifetime of the object; qed"); - let response = common::Output::from(response, request.id, common::Version::V2); - self.responses.push(response); - } + /// Returns the id of the request within the [`BatchState`]. + /// + /// > **Note**: This is NOT the request id that the client passed. + pub fn id(&self) -> usize { + self.index + } + + /// Returns the id that the client sent out. + pub fn request_id(&self) -> &common::Id { + let request = self.elem.as_ref().expect("elem must be Some for the lifetime of the object; qed"); + &request.id + } + + /// Returns the method of this request. + pub fn method(&self) -> &str { + let request = self.elem.as_ref().expect("elem must be Some for the lifetime of the object; qed"); + &request.method + } + + /// Returns the parameters of the request, as a `common::Params`. + pub fn params(&self) -> Params { + let request = self.elem.as_ref().expect("elem must be Some for the lifetime of the object; qed"); + Params::from(&request.params) + } + + /// Responds to the request. This destroys the request object, meaning you can no longer + /// retrieve it with [`request_by_id`](BatchState::request_by_id) later anymore. + pub fn set_response(self, response: Result) { + let request = self.elem.take().expect("elem must be Some for the lifetime of the object; qed"); + let response = common::Output::from(response, request.id, common::Version::V2); + self.responses.push(response); + } } impl<'a> fmt::Debug for BatchElem<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("BatchElem") - .field("method", &self.method()) - .field("params", &self.params()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BatchElem").field("method", &self.method()).field("params", &self.params()).finish() + } } #[cfg(test)] mod tests { - use super::{BatchInc, BatchState}; - use crate::{common, http::HttpRawNotification}; - - #[test] - fn basic_notification() { - let notif = common::Notification { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::None, - }; - - let mut state = { - let rq = common::Request::Single(common::Call::Notification(notif.clone())); - BatchState::from_request(rq) - }; - - assert!(!state.is_ready_to_respond()); - match state.next() { - Some(BatchInc::Notification(ref n)) if n == &HttpRawNotification::from(notif) => {} - _ => panic!(), - } - assert!(state.is_ready_to_respond()); - assert!(state.next().is_none()); - match state.into_response() { - Ok(None) => {} - _ => panic!(), - } - } - - #[test] - fn basic_request() { - let call = common::MethodCall { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::Map(serde_json::from_str("{\"test\":\"foo\"}").unwrap()), - id: common::Id::Num(123), - }; - - let mut state = { - let rq = common::Request::Single(common::Call::MethodCall(call.clone())); - BatchState::from_request(rq) - }; - - assert!(!state.is_ready_to_respond()); - let rq_id = match state.next() { - Some(BatchInc::Request(rq)) => { - assert_eq!(rq.method(), "foo"); - assert_eq!( - { - let v: String = rq.params().get("test").unwrap(); - v - }, - "foo" - ); - assert_eq!(rq.request_id(), &common::Id::Num(123)); - rq.id() - } - _ => panic!(), - }; - - assert!(state.next().is_none()); - assert!(!state.is_ready_to_respond()); - assert!(state.next().is_none()); - - assert_eq!(state.request_by_id(rq_id).unwrap().method(), "foo"); - state - .request_by_id(rq_id) - .unwrap() - .set_response(Err(common::Error::method_not_found())); - - assert!(state.is_ready_to_respond()); - assert!(state.next().is_none()); - - match state.into_response() { - Ok(Some(common::Response::Single(common::Output::Failure(f)))) => { - assert_eq!(f.id, common::Id::Num(123)); - } - _ => panic!(), - } - } - - #[test] - fn empty_batch() { - let mut state = { - let rq = common::Request::Batch(Vec::new()); - BatchState::from_request(rq) - }; - - assert!(state.is_ready_to_respond()); - assert!(state.next().is_none()); - match state.into_response() { - Ok(None) => {} - _ => panic!(), - } - } + use super::{BatchInc, BatchState}; + use crate::{common, http::HttpRawNotification}; + + #[test] + fn basic_notification() { + let notif = common::Notification { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::None, + }; + + let mut state = { + let rq = common::Request::Single(common::Call::Notification(notif.clone())); + BatchState::from_request(rq) + }; + + assert!(!state.is_ready_to_respond()); + match state.next() { + Some(BatchInc::Notification(ref n)) if n == &HttpRawNotification::from(notif) => {} + _ => panic!(), + } + assert!(state.is_ready_to_respond()); + assert!(state.next().is_none()); + match state.into_response() { + Ok(None) => {} + _ => panic!(), + } + } + + #[test] + fn basic_request() { + let call = common::MethodCall { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::Map(serde_json::from_str("{\"test\":\"foo\"}").unwrap()), + id: common::Id::Num(123), + }; + + let mut state = { + let rq = common::Request::Single(common::Call::MethodCall(call.clone())); + BatchState::from_request(rq) + }; + + assert!(!state.is_ready_to_respond()); + let rq_id = match state.next() { + Some(BatchInc::Request(rq)) => { + assert_eq!(rq.method(), "foo"); + assert_eq!( + { + let v: String = rq.params().get("test").unwrap(); + v + }, + "foo" + ); + assert_eq!(rq.request_id(), &common::Id::Num(123)); + rq.id() + } + _ => panic!(), + }; + + assert!(state.next().is_none()); + assert!(!state.is_ready_to_respond()); + assert!(state.next().is_none()); + + assert_eq!(state.request_by_id(rq_id).unwrap().method(), "foo"); + state.request_by_id(rq_id).unwrap().set_response(Err(common::Error::method_not_found())); + + assert!(state.is_ready_to_respond()); + assert!(state.next().is_none()); + + match state.into_response() { + Ok(Some(common::Response::Single(common::Output::Failure(f)))) => { + assert_eq!(f.id, common::Id::Num(123)); + } + _ => panic!(), + } + } + + #[test] + fn empty_batch() { + let mut state = { + let rq = common::Request::Batch(Vec::new()); + BatchState::from_request(rq) + }; + + assert!(state.is_ready_to_respond()); + assert!(state.next().is_none()); + match state.into_response() { + Ok(None) => {} + _ => panic!(), + } + } } diff --git a/src/http/raw/batches.rs b/src/http/raw/batches.rs index 875b20b516..f6fac03153 100644 --- a/src/http/raw/batches.rs +++ b/src/http/raw/batches.rs @@ -48,398 +48,365 @@ use hashbrown::{hash_map::Entry, HashMap}; /// [`request_by_id`](BatchesState::request_by_id). /// pub struct BatchesState { - /// Identifier of the next batch to add to `batches`. - next_batch_id: u64, - - /// For each batch, the individual batch's state and the user parameter. - /// - /// The identifier is lineraly increasing and is never leaked on the wire or outside of this - /// module. Therefore there is no risk of hash collision. - batches: HashMap, + /// Identifier of the next batch to add to `batches`. + next_batch_id: u64, + + /// For each batch, the individual batch's state and the user parameter. + /// + /// The identifier is lineraly increasing and is never leaked on the wire or outside of this + /// module. Therefore there is no risk of hash collision. + batches: HashMap, } /// Event generated by [`next_event`](BatchesState::next_event). #[derive(Debug)] pub enum BatchesEvent<'a, T> { - /// A notification has been extracted from a batch. - Notification { - /// Notification in question. - notification: Notification, - /// User parameter passed when calling [`inject`](BatchesState::inject). - user_param: &'a mut T, - }, - - /// A request has been extracted from a batch. - Request(BatchesElem<'a, T>), - - /// A batch has gotten all its requests answered and a response is ready to be sent out. - ReadyToSend { - /// Response to send out to the JSON-RPC client. - response: common::Response, - /// User parameter passed when calling [`inject`](BatchesState::inject). - user_param: T, - }, + /// A notification has been extracted from a batch. + Notification { + /// Notification in question. + notification: Notification, + /// User parameter passed when calling [`inject`](BatchesState::inject). + user_param: &'a mut T, + }, + + /// A request has been extracted from a batch. + Request(BatchesElem<'a, T>), + + /// A batch has gotten all its requests answered and a response is ready to be sent out. + ReadyToSend { + /// Response to send out to the JSON-RPC client. + response: common::Response, + /// User parameter passed when calling [`inject`](BatchesState::inject). + user_param: T, + }, } /// Request within the batches. pub struct BatchesElem<'a, T> { - /// Id of the batch that contains this element. - batch_id: u64, - /// Inner reference to a request within a batch. - inner: batch::BatchElem<'a>, - /// User parameter passed when calling `inject`. - user_param: &'a mut T, + /// Id of the batch that contains this element. + batch_id: u64, + /// Inner reference to a request within a batch. + inner: batch::BatchElem<'a>, + /// User parameter passed when calling `inject`. + user_param: &'a mut T, } /// Identifier of a request within a [`BatchesState`]. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct BatchesElemId { - /// Id of the batch within `BatchesState::batches`. - outer: u64, - /// Id of the request within the batch. - inner: usize, + /// Id of the batch within `BatchesState::batches`. + outer: u64, + /// Id of the request within the batch. + inner: usize, } /// Minimal capacity for the `batches` container. const BATCHES_MIN_CAPACITY: usize = 256; impl BatchesState { - /// Creates a new empty `BatchesState`. - pub fn new() -> BatchesState { - BatchesState { - next_batch_id: 0, - batches: HashMap::with_capacity_and_hasher(BATCHES_MIN_CAPACITY, Default::default()), - } - } - - /// Processes one step from a batch and returns an event. Returns `None` if there is nothing - /// to do. After you call `inject`, then this will return `Some` at least once. - pub fn next_event(&mut self) -> Option> { - // Note that this function has a complexity of `O(n)`, as we iterate over every single - // batch every single time. This is however the most straight-forward way to implement it, - // and while better strategies might yield better complexities, it might not actually yield - // better performances in real-world situations. More brainstorming and benchmarking could - // get helpful here. - - // Because of long-standing Rust lifetime issues - // (https://github.com/rust-lang/rust/issues/51526), we can't do this in an elegant way. - // If you're reading this code, know that it took several iterations and that I hated my - // life while trying to figure out how to make the compiler happy. - for batch_id in self.batches.keys().cloned().collect::>() { - enum WhatCanWeDo { - Nothing, - ReadyToRespond, - Notification(Notification), - Request(usize), - } - - let what_can_we_do = { - let (batch, _) = self - .batches - .get_mut(&batch_id) - .expect("all keys are valid; qed"); - let is_ready_to_respond = batch.is_ready_to_respond(); - match batch.next() { - None if is_ready_to_respond => WhatCanWeDo::ReadyToRespond, - None => WhatCanWeDo::Nothing, - Some(batch::BatchInc::Notification(n)) => WhatCanWeDo::Notification(n), - Some(batch::BatchInc::Request(inner)) => WhatCanWeDo::Request(inner.id()), - } - }; - - match what_can_we_do { - WhatCanWeDo::Nothing => {} - WhatCanWeDo::ReadyToRespond => { - let (batch, user_param) = self - .batches - .remove(&batch_id) - .expect("key was grabbed from self.batches; qed"); - let response = batch - .into_response() - .unwrap_or_else(|_| panic!("is_ready_to_respond returned true; qed")); - if let Some(response) = response { - return Some(BatchesEvent::ReadyToSend { - response, - user_param, - }); - } - } - WhatCanWeDo::Notification(notification) => { - return Some(BatchesEvent::Notification { - notification, - user_param: &mut self.batches.get_mut(&batch_id).unwrap().1, - }); - } - WhatCanWeDo::Request(id) => { - let (batch, user_param) = self.batches.get_mut(&batch_id).unwrap(); - return Some(BatchesEvent::Request(BatchesElem { - batch_id, - inner: batch.request_by_id(id).unwrap(), - user_param, - })); - } - } - } - - None - } - - /// Injects a newly-received batch into the list. You must then call - /// [`next_event`](BatchesState::next_event) in order to process it. - pub fn inject(&mut self, request: common::Request, user_param: T) { - let batch = batch::BatchState::from_request(request); - - loop { - let id = self.next_batch_id; - self.next_batch_id = self.next_batch_id.wrapping_add(1); - - // We shrink `self.batches` from time to time so that it doesn't grow too much. - if id % 256 == 0 { - self.batches.shrink_to_fit(); - // TODO: self.batches.shrink_to(BATCHES_MIN_CAPACITY); - // ^ see https://github.com/rust-lang/rust/issues/56431 - } - - match self.batches.entry(id) { - Entry::Occupied(_) => continue, - Entry::Vacant(e) => { - e.insert((batch, user_param)); - break; - } - } - } - } - - /// Returns a list of all user data associated to active batches. - pub fn batches<'a>(&'a mut self) -> impl Iterator + 'a { - self.batches.values_mut().map(|(_, user_data)| user_data) - } - - /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) - /// by its id. - /// - /// Note that previous notifications don't have an ID and can't be accessed with this method. - /// - /// Returns `None` if the request ID is invalid or if the request has already been answered in - /// the past. - pub fn request_by_id(&mut self, id: BatchesElemId) -> Option> { - if let Some((batch, user_param)) = self.batches.get_mut(&id.outer) { - Some(BatchesElem { - batch_id: id.outer, - inner: batch.request_by_id(id.inner)?, - user_param, - }) - } else { - None - } - } + /// Creates a new empty `BatchesState`. + pub fn new() -> BatchesState { + BatchesState { + next_batch_id: 0, + batches: HashMap::with_capacity_and_hasher(BATCHES_MIN_CAPACITY, Default::default()), + } + } + + /// Processes one step from a batch and returns an event. Returns `None` if there is nothing + /// to do. After you call `inject`, then this will return `Some` at least once. + pub fn next_event(&mut self) -> Option> { + // Note that this function has a complexity of `O(n)`, as we iterate over every single + // batch every single time. This is however the most straight-forward way to implement it, + // and while better strategies might yield better complexities, it might not actually yield + // better performances in real-world situations. More brainstorming and benchmarking could + // get helpful here. + + // Because of long-standing Rust lifetime issues + // (https://github.com/rust-lang/rust/issues/51526), we can't do this in an elegant way. + // If you're reading this code, know that it took several iterations and that I hated my + // life while trying to figure out how to make the compiler happy. + for batch_id in self.batches.keys().cloned().collect::>() { + enum WhatCanWeDo { + Nothing, + ReadyToRespond, + Notification(Notification), + Request(usize), + } + + let what_can_we_do = { + let (batch, _) = self.batches.get_mut(&batch_id).expect("all keys are valid; qed"); + let is_ready_to_respond = batch.is_ready_to_respond(); + match batch.next() { + None if is_ready_to_respond => WhatCanWeDo::ReadyToRespond, + None => WhatCanWeDo::Nothing, + Some(batch::BatchInc::Notification(n)) => WhatCanWeDo::Notification(n), + Some(batch::BatchInc::Request(inner)) => WhatCanWeDo::Request(inner.id()), + } + }; + + match what_can_we_do { + WhatCanWeDo::Nothing => {} + WhatCanWeDo::ReadyToRespond => { + let (batch, user_param) = + self.batches.remove(&batch_id).expect("key was grabbed from self.batches; qed"); + let response = + batch.into_response().unwrap_or_else(|_| panic!("is_ready_to_respond returned true; qed")); + if let Some(response) = response { + return Some(BatchesEvent::ReadyToSend { response, user_param }); + } + } + WhatCanWeDo::Notification(notification) => { + return Some(BatchesEvent::Notification { + notification, + user_param: &mut self.batches.get_mut(&batch_id).unwrap().1, + }); + } + WhatCanWeDo::Request(id) => { + let (batch, user_param) = self.batches.get_mut(&batch_id).unwrap(); + return Some(BatchesEvent::Request(BatchesElem { + batch_id, + inner: batch.request_by_id(id).unwrap(), + user_param, + })); + } + } + } + + None + } + + /// Injects a newly-received batch into the list. You must then call + /// [`next_event`](BatchesState::next_event) in order to process it. + pub fn inject(&mut self, request: common::Request, user_param: T) { + let batch = batch::BatchState::from_request(request); + + loop { + let id = self.next_batch_id; + self.next_batch_id = self.next_batch_id.wrapping_add(1); + + // We shrink `self.batches` from time to time so that it doesn't grow too much. + if id % 256 == 0 { + self.batches.shrink_to_fit(); + // TODO: self.batches.shrink_to(BATCHES_MIN_CAPACITY); + // ^ see https://github.com/rust-lang/rust/issues/56431 + } + + match self.batches.entry(id) { + Entry::Occupied(_) => continue, + Entry::Vacant(e) => { + e.insert((batch, user_param)); + break; + } + } + } + } + + /// Returns a list of all user data associated to active batches. + pub fn batches<'a>(&'a mut self) -> impl Iterator + 'a { + self.batches.values_mut().map(|(_, user_data)| user_data) + } + + /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) + /// by its id. + /// + /// Note that previous notifications don't have an ID and can't be accessed with this method. + /// + /// Returns `None` if the request ID is invalid or if the request has already been answered in + /// the past. + pub fn request_by_id(&mut self, id: BatchesElemId) -> Option> { + if let Some((batch, user_param)) = self.batches.get_mut(&id.outer) { + Some(BatchesElem { batch_id: id.outer, inner: batch.request_by_id(id.inner)?, user_param }) + } else { + None + } + } } impl Default for BatchesState { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl fmt::Debug for BatchesState where - T: fmt::Debug, + T: fmt::Debug, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_list().entries(self.batches.values()).finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_list().entries(self.batches.values()).finish() + } } impl<'a, T> BatchesElem<'a, T> { - /// Returns the id of the request within the [`BatchesState`]. - /// - /// > **Note**: This is NOT the request id that the client passed. - pub fn id(&self) -> BatchesElemId { - BatchesElemId { - outer: self.batch_id, - inner: self.inner.id(), - } - } - - /// Returns the user parameter passed when calling [`inject`](BatchesState::inject). - pub fn user_param(&mut self) -> &mut T { - &mut self.user_param - } - - /// Returns the id that the client sent out. - pub fn request_id(&self) -> &common::Id { - self.inner.request_id() - } - - /// Returns the method of this request. - pub fn method(&self) -> &str { - self.inner.method() - } - - /// Returns the parameters of the request, as a `common::Params`. - pub fn params(&self) -> Params { - self.inner.params() - } - - /// Responds to the request. This destroys the request object, meaning you can no longer - /// retrieve it with [`request_by_id`](BatchesState::request_by_id) later anymore. - /// - /// A [`ReadyToSend`](BatchesEvent::ReadyToSend) event containing this response might be - /// generated the next time you call [`next_event`](BatchesState::next_event). - pub fn set_response(self, response: Result) { - self.inner.set_response(response) - } + /// Returns the id of the request within the [`BatchesState`]. + /// + /// > **Note**: This is NOT the request id that the client passed. + pub fn id(&self) -> BatchesElemId { + BatchesElemId { outer: self.batch_id, inner: self.inner.id() } + } + + /// Returns the user parameter passed when calling [`inject`](BatchesState::inject). + pub fn user_param(&mut self) -> &mut T { + &mut self.user_param + } + + /// Returns the id that the client sent out. + pub fn request_id(&self) -> &common::Id { + self.inner.request_id() + } + + /// Returns the method of this request. + pub fn method(&self) -> &str { + self.inner.method() + } + + /// Returns the parameters of the request, as a `common::Params`. + pub fn params(&self) -> Params { + self.inner.params() + } + + /// Responds to the request. This destroys the request object, meaning you can no longer + /// retrieve it with [`request_by_id`](BatchesState::request_by_id) later anymore. + /// + /// A [`ReadyToSend`](BatchesEvent::ReadyToSend) event containing this response might be + /// generated the next time you call [`next_event`](BatchesState::next_event). + pub fn set_response(self, response: Result) { + self.inner.set_response(response) + } } impl<'a, T> fmt::Debug for BatchesElem<'a, T> where - T: fmt::Debug, + T: fmt::Debug, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("BatchesElem") - .field("id", &self.id()) - .field("user_param", &self.user_param) - .field("request_id", &self.request_id()) - .field("method", &self.method()) - .field("params", &self.params()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("BatchesElem") + .field("id", &self.id()) + .field("user_param", &self.user_param) + .field("request_id", &self.request_id()) + .field("method", &self.method()) + .field("params", &self.params()) + .finish() + } } #[cfg(test)] mod tests { - use super::{BatchesEvent, BatchesState}; - use crate::{common, http::HttpRawNotification}; - - #[test] - fn basic_notification() { - let notif = common::Notification { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::None, - }; - - let mut state = BatchesState::new(); - assert!(state.next_event().is_none()); - state.inject( - common::Request::Single(common::Call::Notification(notif.clone())), - (), - ); - match state.next_event() { - Some(BatchesEvent::Notification { - ref notification, .. - }) if *notification == HttpRawNotification::from(notif) => {} - _ => panic!(), - } - assert!(state.next_event().is_none()); - } - - #[test] - fn basic_request() { - let call = common::MethodCall { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::Map(serde_json::from_str("{\"test\":\"foo\"}").unwrap()), - id: common::Id::Num(123), - }; - - let mut state = BatchesState::new(); - assert!(state.next_event().is_none()); - state.inject( - common::Request::Single(common::Call::MethodCall(call)), - 8889, - ); - - let rq_id = match state.next_event() { - Some(BatchesEvent::Request(rq)) => { - assert_eq!(rq.method(), "foo"); - assert_eq!( - { - let v: String = rq.params().get("test").unwrap(); - v - }, - "foo" - ); - assert_eq!(rq.request_id(), &common::Id::Num(123)); - rq.id() - } - _ => panic!(), - }; - - assert!(state.next_event().is_none()); - - assert_eq!(state.request_by_id(rq_id).unwrap().method(), "foo"); - state - .request_by_id(rq_id) - .unwrap() - .set_response(Err(common::Error::method_not_found())); - assert!(state.request_by_id(rq_id).is_none()); - - match state.next_event() { - Some(BatchesEvent::ReadyToSend { - response, - user_param, - }) => { - assert_eq!(user_param, 8889); - match response { - common::Response::Single(common::Output::Failure(f)) => { - assert_eq!(f.id, common::Id::Num(123)); - } - _ => panic!(), - } - } - _ => panic!(), - }; - } - - #[test] - fn empty_batch() { - let mut state = BatchesState::new(); - assert!(state.next_event().is_none()); - state.inject(common::Request::Batch(Vec::new()), ()); - assert!(state.next_event().is_none()); - } - - #[test] - fn batch_of_notifs() { - let notif1 = common::Notification { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::None, - }; - - let notif2 = common::Notification { - jsonrpc: common::Version::V2, - method: "bar".to_string(), - params: common::Params::None, - }; - - let mut state = BatchesState::new(); - assert!(state.next_event().is_none()); - state.inject( - common::Request::Batch(vec![ - common::Call::Notification(notif1.clone()), - common::Call::Notification(notif2.clone()), - ]), - 2, - ); - - match state.next_event() { - Some(BatchesEvent::Notification { - ref notification, - ref user_param, - }) if *notification == HttpRawNotification::from(notif1) && **user_param == 2 => {} - _ => panic!(), - } - - match state.next_event() { - Some(BatchesEvent::Notification { - ref notification, - ref user_param, - }) if *notification == HttpRawNotification::from(notif2) && **user_param == 2 => {} - _ => panic!(), - } - - assert!(state.next_event().is_none()); - } + use super::{BatchesEvent, BatchesState}; + use crate::{common, http::HttpRawNotification}; + + #[test] + fn basic_notification() { + let notif = common::Notification { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::None, + }; + + let mut state = BatchesState::new(); + assert!(state.next_event().is_none()); + state.inject(common::Request::Single(common::Call::Notification(notif.clone())), ()); + match state.next_event() { + Some(BatchesEvent::Notification { ref notification, .. }) + if *notification == HttpRawNotification::from(notif) => {} + _ => panic!(), + } + assert!(state.next_event().is_none()); + } + + #[test] + fn basic_request() { + let call = common::MethodCall { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::Map(serde_json::from_str("{\"test\":\"foo\"}").unwrap()), + id: common::Id::Num(123), + }; + + let mut state = BatchesState::new(); + assert!(state.next_event().is_none()); + state.inject(common::Request::Single(common::Call::MethodCall(call)), 8889); + + let rq_id = match state.next_event() { + Some(BatchesEvent::Request(rq)) => { + assert_eq!(rq.method(), "foo"); + assert_eq!( + { + let v: String = rq.params().get("test").unwrap(); + v + }, + "foo" + ); + assert_eq!(rq.request_id(), &common::Id::Num(123)); + rq.id() + } + _ => panic!(), + }; + + assert!(state.next_event().is_none()); + + assert_eq!(state.request_by_id(rq_id).unwrap().method(), "foo"); + state.request_by_id(rq_id).unwrap().set_response(Err(common::Error::method_not_found())); + assert!(state.request_by_id(rq_id).is_none()); + + match state.next_event() { + Some(BatchesEvent::ReadyToSend { response, user_param }) => { + assert_eq!(user_param, 8889); + match response { + common::Response::Single(common::Output::Failure(f)) => { + assert_eq!(f.id, common::Id::Num(123)); + } + _ => panic!(), + } + } + _ => panic!(), + }; + } + + #[test] + fn empty_batch() { + let mut state = BatchesState::new(); + assert!(state.next_event().is_none()); + state.inject(common::Request::Batch(Vec::new()), ()); + assert!(state.next_event().is_none()); + } + + #[test] + fn batch_of_notifs() { + let notif1 = common::Notification { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::None, + }; + + let notif2 = common::Notification { + jsonrpc: common::Version::V2, + method: "bar".to_string(), + params: common::Params::None, + }; + + let mut state = BatchesState::new(); + assert!(state.next_event().is_none()); + state.inject( + common::Request::Batch(vec![ + common::Call::Notification(notif1.clone()), + common::Call::Notification(notif2.clone()), + ]), + 2, + ); + + match state.next_event() { + Some(BatchesEvent::Notification { ref notification, ref user_param }) + if *notification == HttpRawNotification::from(notif1) && **user_param == 2 => {} + _ => panic!(), + } + + match state.next_event() { + Some(BatchesEvent::Notification { ref notification, ref user_param }) + if *notification == HttpRawNotification::from(notif2) && **user_param == 2 => {} + _ => panic!(), + } + + assert!(state.next_event().is_none()); + } } diff --git a/src/http/raw/core.rs b/src/http/raw/core.rs index 8f129d7789..bd8ba4a010 100644 --- a/src/http/raw/core.rs +++ b/src/http/raw/core.rs @@ -38,35 +38,35 @@ pub type RequestId = u64; /// /// See the module-level documentation for more information. pub struct RawServer { - /// Internal "raw" server. - raw: HttpTransportServer, - - /// List of requests that are in the progress of being answered. Each batch is associated with - /// the raw request ID, or with `None` if this raw request has been closed. - /// - /// See the documentation of [`BatchesState`][batches::BatchesState] for more information. - batches: batches::BatchesState>, - - /// List of active subscriptions. - /// The identifier is chosen randomly and uniformy distributed. It is never decided by the - /// client. There is therefore no risk of hash collision attack. - subscriptions: HashMap<[u8; 32], SubscriptionState, fnv::FnvBuildHasher>, - - /// For each raw request ID (i.e. client connection), the number of active subscriptions - /// that are using it. - /// - /// If this reaches 0, we can tell the raw server to close the request. - /// - /// Because we don't have any information about `I`, we have to use a collision-resistant - /// hashing algorithm. This incurs a performance cost that is theoretically avoidable (if `I` - /// is always local), but that should be negligible in practice. - num_subscriptions: HashMap, + /// Internal "raw" server. + raw: HttpTransportServer, + + /// List of requests that are in the progress of being answered. Each batch is associated with + /// the raw request ID, or with `None` if this raw request has been closed. + /// + /// See the documentation of [`BatchesState`][batches::BatchesState] for more information. + batches: batches::BatchesState>, + + /// List of active subscriptions. + /// The identifier is chosen randomly and uniformy distributed. It is never decided by the + /// client. There is therefore no risk of hash collision attack. + subscriptions: HashMap<[u8; 32], SubscriptionState, fnv::FnvBuildHasher>, + + /// For each raw request ID (i.e. client connection), the number of active subscriptions + /// that are using it. + /// + /// If this reaches 0, we can tell the raw server to close the request. + /// + /// Because we don't have any information about `I`, we have to use a collision-resistant + /// hashing algorithm. This incurs a performance cost that is theoretically avoidable (if `I` + /// is always local), but that should be negligible in practice. + num_subscriptions: HashMap, } /// Identifier of a request within a `RawServer`. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct RawServerRequestId { - inner: batches::BatchesElemId, + inner: batches::BatchesElemId, } /// Identifier of a subscription within a [`RawServer`](crate::server::RawServer). @@ -79,32 +79,32 @@ pub struct RawServerSubscriptionId([u8; 32]); /// > be dropped. #[derive(Debug)] pub enum RawServerEvent<'a> { - /// Request is a notification. - Notification(Notification), + /// Request is a notification. + Notification(Notification), - /// Request is a method call. - Request(RawServerRequest<'a>), + /// Request is a method call. + Request(RawServerRequest<'a>), - /// Subscriptions are now ready. - SubscriptionsReady(SubscriptionsReadyIter), + /// Subscriptions are now ready. + SubscriptionsReady(SubscriptionsReadyIter), - /// Subscriptions have been closed because the client closed the connection. - SubscriptionsClosed(SubscriptionsClosedIter), + /// Subscriptions have been closed because the client closed the connection. + SubscriptionsClosed(SubscriptionsClosedIter), } /// Request received by a [`RawServer`](crate::raw::RawServer). pub struct RawServerRequest<'a> { - /// Reference to the request within `self.batches`. - inner: batches::BatchesElem<'a, Option>, + /// Reference to the request within `self.batches`. + inner: batches::BatchesElem<'a, Option>, - /// Reference to the corresponding field in `RawServer`. - raw: &'a mut HttpTransportServer, + /// Reference to the corresponding field in `RawServer`. + raw: &'a mut HttpTransportServer, - /// Pending subscriptions. - subscriptions: &'a mut HashMap<[u8; 32], SubscriptionState, fnv::FnvBuildHasher>, + /// Pending subscriptions. + subscriptions: &'a mut HashMap<[u8; 32], SubscriptionState, fnv::FnvBuildHasher>, - /// Reference to the corresponding field in `RawServer`. - num_subscriptions: &'a mut HashMap, + /// Reference to the corresponding field in `RawServer`. + num_subscriptions: &'a mut HashMap, } /// Active subscription of a client towards a server. @@ -112,17 +112,17 @@ pub struct RawServerRequest<'a> { /// > **Note**: Holds a borrow of the `RawServer`. Therefore, must be dropped before the `RawServer` can /// > be dropped. pub struct ServerSubscription<'a> { - server: &'a mut RawServer, - id: [u8; 32], + server: &'a mut RawServer, + id: [u8; 32], } /// Error that can happen when calling `into_subscription`. #[derive(Debug)] pub enum IntoSubscriptionErr { - /// Underlying server doesn't support subscriptions. - NotSupported, - /// Request has already been closed by the client. - Closed, + /// Underlying server doesn't support subscriptions. + NotSupported, + /// Request has already been closed by the client. + Closed, } /// Iterator for the list of subscriptions that are now ready. @@ -136,410 +136,372 @@ pub struct SubscriptionsClosedIter(vec::IntoIter); /// Internal structure. Information about a subscription. #[derive(Debug)] struct SubscriptionState { - /// Identifier of the connection in the raw server. - raw_id: I, - /// Method that triggered the subscription. Must be sent to the client at each notification. - method: String, - /// If true, the subscription shouldn't accept any notification push because the confirmation - /// hasn't been sent to the client yet. Once this has switched to `false`, it can never be - /// switched to `true` ever again. - pending: bool, + /// Identifier of the connection in the raw server. + raw_id: I, + /// Method that triggered the subscription. Must be sent to the client at each notification. + method: String, + /// If true, the subscription shouldn't accept any notification push because the confirmation + /// hasn't been sent to the client yet. Once this has switched to `false`, it can never be + /// switched to `true` ever again. + pending: bool, } impl RawServer { - /// Starts a [`RawServer`](crate::raw::RawServer) using the given raw server internally. - pub fn new(raw: HttpTransportServer) -> RawServer { - RawServer { - raw, - batches: batches::BatchesState::new(), - subscriptions: HashMap::with_capacity_and_hasher(8, Default::default()), - num_subscriptions: HashMap::with_capacity_and_hasher(8, Default::default()), - } - } + /// Starts a [`RawServer`](crate::raw::RawServer) using the given raw server internally. + pub fn new(raw: HttpTransportServer) -> RawServer { + RawServer { + raw, + batches: batches::BatchesState::new(), + subscriptions: HashMap::with_capacity_and_hasher(8, Default::default()), + num_subscriptions: HashMap::with_capacity_and_hasher(8, Default::default()), + } + } } impl RawServer { - /// Returns a `Future` resolving to the next event that this server generates. - pub async fn next_event<'a>(&'a mut self) -> RawServerEvent<'a> { - let request_id = loop { - match self.batches.next_event() { - None => {} - Some(batches::BatchesEvent::Notification { notification, .. }) => { - return RawServerEvent::Notification(notification) - } - Some(batches::BatchesEvent::Request(inner)) => { - break RawServerRequestId { inner: inner.id() }; - } - Some(batches::BatchesEvent::ReadyToSend { - response, - user_param: Some(raw_request_id), - }) => { - // If we have any active subscription, we only use `send` to not close the - // client request. - if self.num_subscriptions.contains_key(&raw_request_id) { - debug_assert!(self.raw.supports_resuming(&raw_request_id).unwrap_or(false)); - let _ = self.raw.send(&raw_request_id, &response).await; - // TODO: that's O(n) - let mut ready = Vec::new(); // TODO: with_capacity - for (sub_id, sub) in self.subscriptions.iter_mut() { - if sub.raw_id == raw_request_id { - ready.push(RawServerSubscriptionId(sub_id.clone())); - sub.pending = false; - } - } - debug_assert!(!ready.is_empty()); // TODO: assert that capacity == len - return RawServerEvent::SubscriptionsReady(SubscriptionsReadyIter( - ready.into_iter(), - )); - } else { - let _ = self.raw.finish(&raw_request_id, Some(&response)).await; - } - continue; - } - Some(batches::BatchesEvent::ReadyToSend { - response: _, - user_param: None, - }) => { - // This situation happens if the connection has been closed by the client. - // Clients who close their connection. - continue; - } - }; - - match self.raw.next_request().await { - TransportServerEvent::Request { id, request } => { - self.batches.inject(request, Some(id)) - } - TransportServerEvent::Closed(raw_id) => { - // The client has a closed their connection. We eliminate all traces of the - // raw request ID from our state. - // TODO: this has an O(n) complexity; make sure that this is not attackable - for ud in self.batches.batches() { - if ud.as_ref() == Some(&raw_id) { - *ud = None; - } - } - - // Additionally, active subscriptions that were using this connection are - // closed. - if let Some(_) = self.num_subscriptions.remove(&raw_id) { - let ids = self - .subscriptions - .iter() - .filter(|(_, v)| v.raw_id == raw_id) - .map(|(k, _)| RawServerSubscriptionId(*k)) - .collect::>(); - for id in &ids { - let _ = self.subscriptions.remove(&id.0); - } - return RawServerEvent::SubscriptionsClosed(SubscriptionsClosedIter( - ids.into_iter(), - )); - } - } - }; - }; - - RawServerEvent::Request(self.request_by_id(&request_id).unwrap()) - } - - /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) - /// by its id. - /// - /// Note that previous notifications don't have an ID and can't be accessed with this method. - /// - /// Returns `None` if the request ID is invalid or if the request has already been answered in - /// the past. - pub fn request_by_id<'a>( - &'a mut self, - id: &RawServerRequestId, - ) -> Option> { - Some(RawServerRequest { - inner: self.batches.request_by_id(id.inner)?, - raw: &mut self.raw, - subscriptions: &mut self.subscriptions, - num_subscriptions: &mut self.num_subscriptions, - }) - } - - /// Returns a subscription previously returned by - /// [`into_subscription`](crate::raw::server::RawServerRequest::into_subscription). - pub fn subscription_by_id( - &mut self, - id: RawServerSubscriptionId, - ) -> Option { - if self.subscriptions.contains_key(&id.0) { - Some(ServerSubscription { - server: self, - id: id.0, - }) - } else { - None - } - } + /// Returns a `Future` resolving to the next event that this server generates. + pub async fn next_event<'a>(&'a mut self) -> RawServerEvent<'a> { + let request_id = loop { + match self.batches.next_event() { + None => {} + Some(batches::BatchesEvent::Notification { notification, .. }) => { + return RawServerEvent::Notification(notification) + } + Some(batches::BatchesEvent::Request(inner)) => { + break RawServerRequestId { inner: inner.id() }; + } + Some(batches::BatchesEvent::ReadyToSend { response, user_param: Some(raw_request_id) }) => { + // If we have any active subscription, we only use `send` to not close the + // client request. + if self.num_subscriptions.contains_key(&raw_request_id) { + debug_assert!(self.raw.supports_resuming(&raw_request_id).unwrap_or(false)); + let _ = self.raw.send(&raw_request_id, &response).await; + // TODO: that's O(n) + let mut ready = Vec::new(); // TODO: with_capacity + for (sub_id, sub) in self.subscriptions.iter_mut() { + if sub.raw_id == raw_request_id { + ready.push(RawServerSubscriptionId(sub_id.clone())); + sub.pending = false; + } + } + debug_assert!(!ready.is_empty()); // TODO: assert that capacity == len + return RawServerEvent::SubscriptionsReady(SubscriptionsReadyIter(ready.into_iter())); + } else { + let _ = self.raw.finish(&raw_request_id, Some(&response)).await; + } + continue; + } + Some(batches::BatchesEvent::ReadyToSend { response: _, user_param: None }) => { + // This situation happens if the connection has been closed by the client. + // Clients who close their connection. + continue; + } + }; + + match self.raw.next_request().await { + TransportServerEvent::Request { id, request } => self.batches.inject(request, Some(id)), + TransportServerEvent::Closed(raw_id) => { + // The client has a closed their connection. We eliminate all traces of the + // raw request ID from our state. + // TODO: this has an O(n) complexity; make sure that this is not attackable + for ud in self.batches.batches() { + if ud.as_ref() == Some(&raw_id) { + *ud = None; + } + } + + // Additionally, active subscriptions that were using this connection are + // closed. + if let Some(_) = self.num_subscriptions.remove(&raw_id) { + let ids = self + .subscriptions + .iter() + .filter(|(_, v)| v.raw_id == raw_id) + .map(|(k, _)| RawServerSubscriptionId(*k)) + .collect::>(); + for id in &ids { + let _ = self.subscriptions.remove(&id.0); + } + return RawServerEvent::SubscriptionsClosed(SubscriptionsClosedIter(ids.into_iter())); + } + } + }; + }; + + RawServerEvent::Request(self.request_by_id(&request_id).unwrap()) + } + + /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) + /// by its id. + /// + /// Note that previous notifications don't have an ID and can't be accessed with this method. + /// + /// Returns `None` if the request ID is invalid or if the request has already been answered in + /// the past. + pub fn request_by_id<'a>(&'a mut self, id: &RawServerRequestId) -> Option> { + Some(RawServerRequest { + inner: self.batches.request_by_id(id.inner)?, + raw: &mut self.raw, + subscriptions: &mut self.subscriptions, + num_subscriptions: &mut self.num_subscriptions, + }) + } + + /// Returns a subscription previously returned by + /// [`into_subscription`](crate::raw::server::RawServerRequest::into_subscription). + pub fn subscription_by_id(&mut self, id: RawServerSubscriptionId) -> Option { + if self.subscriptions.contains_key(&id.0) { + Some(ServerSubscription { server: self, id: id.0 }) + } else { + None + } + } } impl From for RawServer { - fn from(inner: HttpTransportServer) -> Self { - RawServer::new(inner) - } + fn from(inner: HttpTransportServer) -> Self { + RawServer::new(inner) + } } impl<'a> RawServerRequest<'a> { - /// Returns the id of the request. - /// - /// If this request object is dropped, you can retreive it again later by calling - /// [`request_by_id`](crate::raw::RawServer::request_by_id). - pub fn id(&self) -> RawServerRequestId { - RawServerRequestId { - inner: self.inner.id(), - } - } - - /// Returns the id that the client sent out. - // TODO: can return None, which is wrong - pub fn request_id(&self) -> &common::Id { - self.inner.request_id() - } - - /// Returns the method of this request. - pub fn method(&self) -> &str { - self.inner.method() - } - - /// Returns the parameters of the request, as a `common::Params`. - pub fn params(&self) -> Params { - self.inner.params() - } + /// Returns the id of the request. + /// + /// If this request object is dropped, you can retreive it again later by calling + /// [`request_by_id`](crate::raw::RawServer::request_by_id). + pub fn id(&self) -> RawServerRequestId { + RawServerRequestId { inner: self.inner.id() } + } + + /// Returns the id that the client sent out. + // TODO: can return None, which is wrong + pub fn request_id(&self) -> &common::Id { + self.inner.request_id() + } + + /// Returns the method of this request. + pub fn method(&self) -> &str { + self.inner.method() + } + + /// Returns the parameters of the request, as a `common::Params`. + pub fn params(&self) -> Params { + self.inner.params() + } } impl<'a> RawServerRequest<'a> { - /// Send back a response. - /// - /// If this request is part of a batch: - /// - /// - If all requests of the batch have been responded to, then the response is actively - /// sent out. - /// - Otherwise, this response is buffered. - /// - /// > **Note**: This method is implemented in a way that doesn't wait for long to send the - /// > response. While calling this method will block your entire server, it - /// > should only block it for a short amount of time. See also [the equivalent - /// > method](crate::transport::TransportServer::finish) on the - /// > [`TransportServer`](crate::transport::TransportServer) trait. - /// - pub fn respond(self, response: Result) { - self.inner.set_response(response); - //unimplemented!(); - // TODO: actually send out response? - } - - /// Sends back a response similar to `respond`, then returns a [`RawServerSubscriptionId`] object - /// that allows you to push more data on the corresponding connection. - /// - /// The [`RawServerSubscriptionId`] corresponds to the identifier that has been sent back to the - /// client. If the client refers to this subscription id, you can turn it into a - /// [`RawServerSubscriptionId`] using - /// [`from_wire_message`](RawServerSubscriptionId::from_wire_message). - /// - /// After the request has been turned into a subscription, the subscription might be in - /// "pending mode". Pushing notifications on that subscription will return an error. This - /// mechanism is necessary because the subscription request might be part of a batch, and all - /// the requests of that batch have to be processed before informing the client of the start - /// of the subscription. - /// - /// Returns an error and doesn't do anything if the underlying server doesn't support - /// subscriptions, or if the connection has already been closed by the client. - /// - /// > **Note**: Because of borrowing issues, we return a [`RawServerSubscriptionId`] rather than - /// > a [`ServerSubscription`]. You will have to call - /// > [`subscription_by_id`](RawServer::subscription_by_id) in order to manipulate the - /// > subscription. - // TODO: solve the note - pub fn into_subscription(mut self) -> Result { - let raw_request_id = match self.inner.user_param().clone() { - Some(id) => id, - None => return Err(IntoSubscriptionErr::Closed), - }; - - if !self.raw.supports_resuming(&raw_request_id).unwrap_or(false) { - return Err(IntoSubscriptionErr::NotSupported); - } - - loop { - let new_subscr_id: [u8; 32] = rand::random(); - - match self.subscriptions.entry(new_subscr_id) { - Entry::Vacant(e) => e.insert(SubscriptionState { - raw_id: raw_request_id.clone(), - method: self.inner.method().to_owned(), - pending: true, - }), - // Continue looping if we accidentally chose an existing ID. - Entry::Occupied(_) => continue, - }; - - self.num_subscriptions - .entry(raw_request_id) - .and_modify(|e| { - *e = NonZeroUsize::new(e.get() + 1) - .expect("we add 1 to an existing non-zero value; qed"); - }) - .or_insert_with(|| NonZeroUsize::new(1).expect("1 != 0")); - - let subscr_id_string = bs58::encode(&new_subscr_id).into_string(); - self.inner.set_response(Ok(subscr_id_string.into())); - break Ok(RawServerSubscriptionId(new_subscr_id)); - } - } + /// Send back a response. + /// + /// If this request is part of a batch: + /// + /// - If all requests of the batch have been responded to, then the response is actively + /// sent out. + /// - Otherwise, this response is buffered. + /// + /// > **Note**: This method is implemented in a way that doesn't wait for long to send the + /// > response. While calling this method will block your entire server, it + /// > should only block it for a short amount of time. See also [the equivalent + /// > method](crate::transport::TransportServer::finish) on the + /// > [`TransportServer`](crate::transport::TransportServer) trait. + /// + pub fn respond(self, response: Result) { + self.inner.set_response(response); + //unimplemented!(); + // TODO: actually send out response? + } + + /// Sends back a response similar to `respond`, then returns a [`RawServerSubscriptionId`] object + /// that allows you to push more data on the corresponding connection. + /// + /// The [`RawServerSubscriptionId`] corresponds to the identifier that has been sent back to the + /// client. If the client refers to this subscription id, you can turn it into a + /// [`RawServerSubscriptionId`] using + /// [`from_wire_message`](RawServerSubscriptionId::from_wire_message). + /// + /// After the request has been turned into a subscription, the subscription might be in + /// "pending mode". Pushing notifications on that subscription will return an error. This + /// mechanism is necessary because the subscription request might be part of a batch, and all + /// the requests of that batch have to be processed before informing the client of the start + /// of the subscription. + /// + /// Returns an error and doesn't do anything if the underlying server doesn't support + /// subscriptions, or if the connection has already been closed by the client. + /// + /// > **Note**: Because of borrowing issues, we return a [`RawServerSubscriptionId`] rather than + /// > a [`ServerSubscription`]. You will have to call + /// > [`subscription_by_id`](RawServer::subscription_by_id) in order to manipulate the + /// > subscription. + // TODO: solve the note + pub fn into_subscription(mut self) -> Result { + let raw_request_id = match self.inner.user_param().clone() { + Some(id) => id, + None => return Err(IntoSubscriptionErr::Closed), + }; + + if !self.raw.supports_resuming(&raw_request_id).unwrap_or(false) { + return Err(IntoSubscriptionErr::NotSupported); + } + + loop { + let new_subscr_id: [u8; 32] = rand::random(); + + match self.subscriptions.entry(new_subscr_id) { + Entry::Vacant(e) => e.insert(SubscriptionState { + raw_id: raw_request_id.clone(), + method: self.inner.method().to_owned(), + pending: true, + }), + // Continue looping if we accidentally chose an existing ID. + Entry::Occupied(_) => continue, + }; + + self.num_subscriptions + .entry(raw_request_id) + .and_modify(|e| { + *e = NonZeroUsize::new(e.get() + 1).expect("we add 1 to an existing non-zero value; qed"); + }) + .or_insert_with(|| NonZeroUsize::new(1).expect("1 != 0")); + + let subscr_id_string = bs58::encode(&new_subscr_id).into_string(); + self.inner.set_response(Ok(subscr_id_string.into())); + break Ok(RawServerSubscriptionId(new_subscr_id)); + } + } } impl<'a> fmt::Debug for RawServerRequest<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RawServerRequest") - .field("request_id", &self.request_id()) - .field("method", &self.method()) - .field("params", &self.params()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RawServerRequest") + .field("request_id", &self.request_id()) + .field("method", &self.method()) + .field("params", &self.params()) + .finish() + } } impl RawServerSubscriptionId { - /// When the client sends a unsubscribe message containing a subscription ID, this function can - /// be used to parse it into a [`RawServerSubscriptionId`]. - pub fn from_wire_message(params: &JsonValue) -> Result { - let string = match params { - JsonValue::String(s) => s, - _ => return Err(()), - }; - - let decoded = bs58::decode(&string).into_vec().map_err(|_| ())?; - if decoded.len() > 32 { - return Err(()); - } - - let mut out = [0; 32]; - out[(32 - decoded.len())..].copy_from_slice(&decoded); - // TODO: write a test to check that encoding/decoding match - Ok(RawServerSubscriptionId(out)) - } + /// When the client sends a unsubscribe message containing a subscription ID, this function can + /// be used to parse it into a [`RawServerSubscriptionId`]. + pub fn from_wire_message(params: &JsonValue) -> Result { + let string = match params { + JsonValue::String(s) => s, + _ => return Err(()), + }; + + let decoded = bs58::decode(&string).into_vec().map_err(|_| ())?; + if decoded.len() > 32 { + return Err(()); + } + + let mut out = [0; 32]; + out[(32 - decoded.len())..].copy_from_slice(&decoded); + // TODO: write a test to check that encoding/decoding match + Ok(RawServerSubscriptionId(out)) + } } impl<'a> ServerSubscription<'a> { - /// Returns the id of the subscription. - /// - /// If this subscription object is dropped, you can retreive it again later by calling - /// [`subscription_by_id`](crate::raw::RawServer::subscription_by_id). - pub fn id(&self) -> RawServerSubscriptionId { - RawServerSubscriptionId(self.id) - } - - /// Pushes a notification. - /// - // TODO: refactor to progate the error. - pub async fn push(self, message: impl Into) { - let subscription_state = self.server.subscriptions.get(&self.id).unwrap(); - if subscription_state.pending { - return; // TODO: notify user with error - } - - let output = common::SubscriptionNotif { - jsonrpc: common::Version::V2, - method: subscription_state.method.clone(), - params: common::SubscriptionNotifParams { - subscription: common::SubscriptionId::Str(bs58::encode(&self.id).into_string()), - result: message.into(), - }, - }; - let response = common::Response::Notif(output); - - let _ = self - .server - .raw - .send(&subscription_state.raw_id, &response) - .await; // TODO: error handling? - } - - /// Destroys the subscription object. - /// - /// This does not send any message back to the client. Instead, this function is supposed to - /// be used in reaction to the client requesting to be unsubscribed. - /// - /// If this was the last active subscription, also closes the connection ("raw request") with - /// the client. - pub async fn close(self) { - let subscription_state = self.server.subscriptions.remove(&self.id).unwrap(); - - // Check if we're the last subscription on this connection. - // Remove entry from `num_subscriptions` if so. - let is_last_sub = match self - .server - .num_subscriptions - .entry(subscription_state.raw_id.clone()) - { - Entry::Vacant(_) => unreachable!(), - Entry::Occupied(ref mut e) if e.get().get() >= 2 => { - let e = e.get_mut(); - *e = NonZeroUsize::new(e.get() - 1).expect("e is >= 2; qed"); - false - } - Entry::Occupied(e) => { - e.remove(); - true - } - }; - - // If the subscription is pending, we have yet to send something back on that connection - // and thus shouldn't close it. - // When the response is sent back later, the code will realize that `num_subscriptions` - // is zero/empty and call `finish`. - if is_last_sub && !subscription_state.pending { - let _ = self - .server - .raw - .finish(&subscription_state.raw_id, None) - .await; - } - } + /// Returns the id of the subscription. + /// + /// If this subscription object is dropped, you can retreive it again later by calling + /// [`subscription_by_id`](crate::raw::RawServer::subscription_by_id). + pub fn id(&self) -> RawServerSubscriptionId { + RawServerSubscriptionId(self.id) + } + + /// Pushes a notification. + /// + // TODO: refactor to progate the error. + pub async fn push(self, message: impl Into) { + let subscription_state = self.server.subscriptions.get(&self.id).unwrap(); + if subscription_state.pending { + return; // TODO: notify user with error + } + + let output = common::SubscriptionNotif { + jsonrpc: common::Version::V2, + method: subscription_state.method.clone(), + params: common::SubscriptionNotifParams { + subscription: common::SubscriptionId::Str(bs58::encode(&self.id).into_string()), + result: message.into(), + }, + }; + let response = common::Response::Notif(output); + + let _ = self.server.raw.send(&subscription_state.raw_id, &response).await; // TODO: error handling? + } + + /// Destroys the subscription object. + /// + /// This does not send any message back to the client. Instead, this function is supposed to + /// be used in reaction to the client requesting to be unsubscribed. + /// + /// If this was the last active subscription, also closes the connection ("raw request") with + /// the client. + pub async fn close(self) { + let subscription_state = self.server.subscriptions.remove(&self.id).unwrap(); + + // Check if we're the last subscription on this connection. + // Remove entry from `num_subscriptions` if so. + let is_last_sub = match self.server.num_subscriptions.entry(subscription_state.raw_id.clone()) { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(ref mut e) if e.get().get() >= 2 => { + let e = e.get_mut(); + *e = NonZeroUsize::new(e.get() - 1).expect("e is >= 2; qed"); + false + } + Entry::Occupied(e) => { + e.remove(); + true + } + }; + + // If the subscription is pending, we have yet to send something back on that connection + // and thus shouldn't close it. + // When the response is sent back later, the code will realize that `num_subscriptions` + // is zero/empty and call `finish`. + if is_last_sub && !subscription_state.pending { + let _ = self.server.raw.finish(&subscription_state.raw_id, None).await; + } + } } impl fmt::Display for IntoSubscriptionErr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - IntoSubscriptionErr::NotSupported => { - write!(f, "Underlying server doesn't support subscriptions") - } - IntoSubscriptionErr::Closed => write!(f, "Request is already closed"), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + IntoSubscriptionErr::NotSupported => write!(f, "Underlying server doesn't support subscriptions"), + IntoSubscriptionErr::Closed => write!(f, "Request is already closed"), + } + } } impl std::error::Error for IntoSubscriptionErr {} impl Iterator for SubscriptionsReadyIter { - type Item = RawServerSubscriptionId; + type Item = RawServerSubscriptionId; - fn next(&mut self) -> Option { - self.0.next() - } + fn next(&mut self) -> Option { + self.0.next() + } - fn size_hint(&self) -> (usize, Option) { - self.0.size_hint() - } + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } } impl ExactSizeIterator for SubscriptionsReadyIter {} impl Iterator for SubscriptionsClosedIter { - type Item = RawServerSubscriptionId; + type Item = RawServerSubscriptionId; - fn next(&mut self) -> Option { - self.0.next() - } + fn next(&mut self) -> Option { + self.0.next() + } - fn size_hint(&self) -> (usize, Option) { - self.0.size_hint() - } + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } } impl ExactSizeIterator for SubscriptionsClosedIter {} diff --git a/src/http/raw/mod.rs b/src/http/raw/mod.rs index 52239bd03b..03109f8c86 100644 --- a/src/http/raw/mod.rs +++ b/src/http/raw/mod.rs @@ -8,9 +8,7 @@ mod typed_rp; #[cfg(test)] mod tests; -pub use self::core::{ - RawServer, RawServerEvent, RawServerRequest, RawServerRequestId, RawServerSubscriptionId, -}; +pub use self::core::{RawServer, RawServerEvent, RawServerRequest, RawServerRequestId, RawServerSubscriptionId}; pub use self::notification::Notification; pub use self::params::{Iter as ParamsIter, ParamKey as ParamsKey, Params}; pub use self::typed_rp::TypedResponder; diff --git a/src/http/raw/notification.rs b/src/http/raw/notification.rs index 1172608ac5..c8b39ac975 100644 --- a/src/http/raw/notification.rs +++ b/src/http/raw/notification.rs @@ -35,34 +35,31 @@ use core::fmt; pub struct Notification(common::Notification); impl From for Notification { - fn from(notif: common::Notification) -> Notification { - Notification(notif) - } + fn from(notif: common::Notification) -> Notification { + Notification(notif) + } } impl From for common::Notification { - fn from(notif: Notification) -> common::Notification { - notif.0 - } + fn from(notif: Notification) -> common::Notification { + notif.0 + } } impl Notification { - /// Returns the method of this notification. - pub fn method(&self) -> &str { - &self.0.method - } + /// Returns the method of this notification. + pub fn method(&self) -> &str { + &self.0.method + } - /// Returns the parameters of the notification. - pub fn params(&self) -> Params { - Params::from(&self.0.params) - } + /// Returns the parameters of the notification. + pub fn params(&self) -> Params { + Params::from(&self.0.params) + } } impl fmt::Debug for Notification { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Notification") - .field("method", &self.method()) - .field("params", &self.params()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Notification").field("method", &self.method()).field("params", &self.params()).finish() + } } diff --git a/src/http/raw/params.rs b/src/http/raw/params.rs index fa91b460c4..90a3a7e782 100644 --- a/src/http/raw/params.rs +++ b/src/http/raw/params.rs @@ -32,144 +32,144 @@ use core::fmt; /// Access to the parameters of a request. #[derive(Copy, Clone)] pub struct Params<'a> { - /// Raw parameters of the request. - params: &'a common::Params, + /// Raw parameters of the request. + params: &'a common::Params, } /// Key referring to a potential parameter of a request. pub enum ParamKey<'a> { - /// String key. Only valid when the parameters list is a map. - String(&'a str), - /// Integer key. Only valid when the parameters list is an array. - Index(usize), + /// String key. Only valid when the parameters list is a map. + String(&'a str), + /// Integer key. Only valid when the parameters list is an array. + Index(usize), } impl<'a> Params<'a> { - /// Wraps around a `&common::Params` and provides utility functions for the user. - pub(crate) fn from(params: &'a common::Params) -> Params<'a> { - Params { params } - } - - /// Returns a parameter of the request by name and decodes it. - /// - /// Returns an error if the parameter doesn't exist or is of the wrong type. - pub fn get<'k, T>(self, param: impl Into>) -> Result - where - T: serde::de::DeserializeOwned, - { - let val = self.get_raw(param).ok_or(())?; - serde_json::from_value(val.clone()).map_err(|_| ()) - } - - /// Returns a parameter of the request by name. - pub fn get_raw<'k>(self, param: impl Into>) -> Option<&'a common::JsonValue> { - match (self.params, param.into()) { - (common::Params::None, _) => None, - (common::Params::Map(map), ParamKey::String(key)) => map.get(key), - (common::Params::Map(_), ParamKey::Index(_)) => None, - (common::Params::Array(_), ParamKey::String(_)) => None, - (common::Params::Array(array), ParamKey::Index(index)) => { - if index < array.len() { - Some(&array[index]) - } else { - None - } - } - } - } + /// Wraps around a `&common::Params` and provides utility functions for the user. + pub(crate) fn from(params: &'a common::Params) -> Params<'a> { + Params { params } + } + + /// Returns a parameter of the request by name and decodes it. + /// + /// Returns an error if the parameter doesn't exist or is of the wrong type. + pub fn get<'k, T>(self, param: impl Into>) -> Result + where + T: serde::de::DeserializeOwned, + { + let val = self.get_raw(param).ok_or(())?; + serde_json::from_value(val.clone()).map_err(|_| ()) + } + + /// Returns a parameter of the request by name. + pub fn get_raw<'k>(self, param: impl Into>) -> Option<&'a common::JsonValue> { + match (self.params, param.into()) { + (common::Params::None, _) => None, + (common::Params::Map(map), ParamKey::String(key)) => map.get(key), + (common::Params::Map(_), ParamKey::Index(_)) => None, + (common::Params::Array(_), ParamKey::String(_)) => None, + (common::Params::Array(array), ParamKey::Index(index)) => { + if index < array.len() { + Some(&array[index]) + } else { + None + } + } + } + } } impl<'a> IntoIterator for Params<'a> { - type Item = (ParamKey<'a>, &'a common::JsonValue); - type IntoIter = Iter<'a>; - - fn into_iter(self) -> Self::IntoIter { - Iter(match self.params { - common::Params::None => IterInner::Empty, - common::Params::Array(arr) => IterInner::Array(arr.iter()), - common::Params::Map(map) => IterInner::Map(map.iter()), - }) - } + type Item = (ParamKey<'a>, &'a common::JsonValue); + type IntoIter = Iter<'a>; + + fn into_iter(self) -> Self::IntoIter { + Iter(match self.params { + common::Params::None => IterInner::Empty, + common::Params::Array(arr) => IterInner::Array(arr.iter()), + common::Params::Map(map) => IterInner::Map(map.iter()), + }) + } } impl<'a> fmt::Debug for Params<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_map().entries(self.into_iter()).finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_map().entries(self.into_iter()).finish() + } } impl<'a> AsRef for Params<'a> { - fn as_ref(&self) -> &common::Params { - self.params - } + fn as_ref(&self) -> &common::Params { + self.params + } } impl<'a> From> for &'a common::Params { - fn from(params: Params<'a>) -> &'a common::Params { - params.params - } + fn from(params: Params<'a>) -> &'a common::Params { + params.params + } } impl<'a> From<&'a str> for ParamKey<'a> { - fn from(s: &'a str) -> Self { - ParamKey::String(s) - } + fn from(s: &'a str) -> Self { + ParamKey::String(s) + } } impl<'a> From<&'a String> for ParamKey<'a> { - fn from(s: &'a String) -> Self { - ParamKey::String(&s[..]) - } + fn from(s: &'a String) -> Self { + ParamKey::String(&s[..]) + } } impl<'a> From for ParamKey<'a> { - fn from(i: usize) -> Self { - ParamKey::Index(i) - } + fn from(i: usize) -> Self { + ParamKey::Index(i) + } } impl<'a> fmt::Debug for ParamKey<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ParamKey::String(s) => fmt::Debug::fmt(s, f), - ParamKey::Index(s) => fmt::Debug::fmt(s, f), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ParamKey::String(s) => fmt::Debug::fmt(s, f), + ParamKey::Index(s) => fmt::Debug::fmt(s, f), + } + } } /// Iterator to all the parameters of a request. pub struct Iter<'a>(IterInner<'a>); enum IterInner<'a> { - Empty, - Map(serde_json::map::Iter<'a>), - Array(std::slice::Iter<'a, serde_json::Value>), + Empty, + Map(serde_json::map::Iter<'a>), + Array(std::slice::Iter<'a, serde_json::Value>), } impl<'a> Iterator for Iter<'a> { - type Item = (ParamKey<'a>, &'a common::JsonValue); - - fn next(&mut self) -> Option { - match &mut self.0 { - IterInner::Empty => None, - IterInner::Map(iter) => iter.next().map(|(k, v)| (ParamKey::String(&k[..]), v)), - IterInner::Array(iter) => iter.next().map(|v| (ParamKey::String(""), v)), - } - } - - fn size_hint(&self) -> (usize, Option) { - match &self.0 { - IterInner::Empty => (0, Some(0)), - IterInner::Map(iter) => iter.size_hint(), - IterInner::Array(iter) => iter.size_hint(), - } - } + type Item = (ParamKey<'a>, &'a common::JsonValue); + + fn next(&mut self) -> Option { + match &mut self.0 { + IterInner::Empty => None, + IterInner::Map(iter) => iter.next().map(|(k, v)| (ParamKey::String(&k[..]), v)), + IterInner::Array(iter) => iter.next().map(|v| (ParamKey::String(""), v)), + } + } + + fn size_hint(&self) -> (usize, Option) { + match &self.0 { + IterInner::Empty => (0, Some(0)), + IterInner::Map(iter) => iter.size_hint(), + IterInner::Array(iter) => iter.size_hint(), + } + } } impl<'a> ExactSizeIterator for Iter<'a> {} impl<'a> fmt::Debug for Iter<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("ParamsIter").finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("ParamsIter").finish() + } } diff --git a/src/http/raw/tests.rs b/src/http/raw/tests.rs index d84d7fe9ed..c366b58913 100644 --- a/src/http/raw/tests.rs +++ b/src/http/raw/tests.rs @@ -32,63 +32,58 @@ use crate::http::{HttpRawServer, HttpRawServerEvent, HttpTransportServer}; use serde_json::Value; async fn connection_context() -> (HttpTransportClient, HttpRawServer) { - let server = HttpTransportServer::new(&"127.0.0.1:0".parse().unwrap()) - .await - .unwrap(); - let uri = format!("http://{}", server.local_addr()); - let client = HttpTransportClient::new(&uri); - (client, server.into()) + let server = HttpTransportServer::new(&"127.0.0.1:0".parse().unwrap()).await.unwrap(); + let uri = format!("http://{}", server.local_addr()); + let client = HttpTransportClient::new(&uri); + (client, server.into()) } #[tokio::test] async fn request_work() { - let (mut client, mut server) = connection_context().await; - tokio::spawn(async move { - let call = Call::MethodCall(MethodCall { - jsonrpc: Version::V2, - method: "hello_world".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]), - id: common::Id::Num(3), - }); - client.send_request(Request::Single(call)).await.unwrap(); - }); + let (mut client, mut server) = connection_context().await; + tokio::spawn(async move { + let call = Call::MethodCall(MethodCall { + jsonrpc: Version::V2, + method: "hello_world".to_owned(), + params: Params::Array(vec![Value::from(1), Value::from(2)]), + id: common::Id::Num(3), + }); + client.send_request(Request::Single(call)).await.unwrap(); + }); - match server.next_event().await { - HttpRawServerEvent::Request(r) => { - assert_eq!(r.method(), "hello_world"); - let p1: i32 = r.params().get(0).unwrap(); - let p2: i32 = r.params().get(1).unwrap(); - assert_eq!(p1, 1); - assert_eq!(p2, 2); - assert_eq!(r.request_id(), &common::Id::Num(3)); - } - e @ _ => panic!("Invalid server event: {:?} expected Request", e), - } + match server.next_event().await { + HttpRawServerEvent::Request(r) => { + assert_eq!(r.method(), "hello_world"); + let p1: i32 = r.params().get(0).unwrap(); + let p2: i32 = r.params().get(1).unwrap(); + assert_eq!(p1, 1); + assert_eq!(p2, 2); + assert_eq!(r.request_id(), &common::Id::Num(3)); + } + e @ _ => panic!("Invalid server event: {:?} expected Request", e), + } } #[tokio::test] async fn notification_work() { - let (mut client, mut server) = connection_context().await; - tokio::spawn(async move { - let n = Notification { - jsonrpc: Version::V2, - method: "hello_world".to_owned(), - params: Params::Array(vec![Value::from("lo"), Value::from(2)]), - }; - client - .send_request(Request::Single(Call::Notification(n))) - .await - .unwrap(); - }); + let (mut client, mut server) = connection_context().await; + tokio::spawn(async move { + let n = Notification { + jsonrpc: Version::V2, + method: "hello_world".to_owned(), + params: Params::Array(vec![Value::from("lo"), Value::from(2)]), + }; + client.send_request(Request::Single(Call::Notification(n))).await.unwrap(); + }); - match server.next_event().await { - HttpRawServerEvent::Notification(r) => { - assert_eq!(r.method(), "hello_world"); - let p1: String = r.params().get(0).unwrap(); - let p2: i32 = r.params().get(1).unwrap(); - assert_eq!(p1, "lo"); - assert_eq!(p2, 2); - } - e @ _ => panic!("Invalid server event: {:?} expected Notification", e), - } + match server.next_event().await { + HttpRawServerEvent::Notification(r) => { + assert_eq!(r.method(), "hello_world"); + let p1: String = r.params().get(0).unwrap(); + let p2: i32 = r.params().get(1).unwrap(); + assert_eq!(p1, "lo"); + assert_eq!(p2, 2); + } + e @ _ => panic!("Invalid server event: {:?} expected Notification", e), + } } diff --git a/src/http/raw/typed_rp.rs b/src/http/raw/typed_rp.rs index e877258a11..b7b8b1f2d7 100644 --- a/src/http/raw/typed_rp.rs +++ b/src/http/raw/typed_rp.rs @@ -29,43 +29,39 @@ use core::marker::PhantomData; /// Allows responding to a server request in a more elegant and strongly-typed fashion. pub struct TypedResponder<'a, T> { - /// The request to answer. - rq: RawServerRequest<'a>, - /// Marker that pins the type of the response. - response_ty: PhantomData, + /// The request to answer. + rq: RawServerRequest<'a>, + /// Marker that pins the type of the response. + response_ty: PhantomData, } impl<'a, T> From> for TypedResponder<'a, T> { - fn from(rq: RawServerRequest<'a>) -> TypedResponder<'a, T> { - TypedResponder { - rq, - response_ty: PhantomData, - } - } + fn from(rq: RawServerRequest<'a>) -> TypedResponder<'a, T> { + TypedResponder { rq, response_ty: PhantomData } + } } impl<'a, T> TypedResponder<'a, T> where - T: serde::Serialize, + T: serde::Serialize, { - /// Returns a successful response. - pub fn ok(self, response: impl Into) { - self.respond(Ok(response)) - } + /// Returns a successful response. + pub fn ok(self, response: impl Into) { + self.respond(Ok(response)) + } - /// Returns an erroneous response. - pub fn err(self, err: crate::common::Error) { - self.respond(Err::(err)) - } + /// Returns an erroneous response. + pub fn err(self, err: crate::common::Error) { + self.respond(Err::(err)) + } - /// Returns a response. - pub fn respond(self, response: Result, crate::common::Error>) { - let response = match response { - Ok(v) => crate::common::to_value(v.into()) - .map_err(|_| crate::common::Error::internal_error()), - Err(err) => Err(err), - }; + /// Returns a response. + pub fn respond(self, response: Result, crate::common::Error>) { + let response = match response { + Ok(v) => crate::common::to_value(v.into()).map_err(|_| crate::common::Error::internal_error()), + Err(err) => Err(err), + }; - self.rq.respond(response) - } + self.rq.respond(response) + } } diff --git a/src/http/server.rs b/src/http/server.rs index 8ada1d2a2e..82e9132d7f 100644 --- a/src/http/server.rs +++ b/src/http/server.rs @@ -31,10 +31,10 @@ use crate::http::transport::HttpTransportServer; use futures::{channel::mpsc, future::Either, pin_mut, prelude::*}; use parking_lot::Mutex; use std::{ - collections::{HashMap, HashSet}, - error, - net::SocketAddr, - sync::{atomic, Arc}, + collections::{HashMap, HashSet}, + error, + net::SocketAddr, + sync::{atomic, Arc}, }; /// Server that can be cloned. @@ -44,449 +44,407 @@ use std::{ /// > [`RawServer`] struct instead. #[derive(Clone)] pub struct Server { - /// Local socket address of the transport server. - local_addr: SocketAddr, - /// Channel to send requests to the background task. - to_back: mpsc::UnboundedSender, - /// List of methods (for RPC queries, subscriptions, and unsubscriptions) that have been - /// registered. Serves no purpose except to check for duplicates. - registered_methods: Arc>>, - /// Next unique ID used when registering a subscription. - next_subscription_unique_id: Arc, + /// Local socket address of the transport server. + local_addr: SocketAddr, + /// Channel to send requests to the background task. + to_back: mpsc::UnboundedSender, + /// List of methods (for RPC queries, subscriptions, and unsubscriptions) that have been + /// registered. Serves no purpose except to check for duplicates. + registered_methods: Arc>>, + /// Next unique ID used when registering a subscription. + next_subscription_unique_id: Arc, } /// Notification method that's been registered. pub struct RegisteredNotification { - /// Receives notifications that the client sent to us. - queries_rx: mpsc::Receiver, + /// Receives notifications that the client sent to us. + queries_rx: mpsc::Receiver, } /// Method that's been registered. pub struct RegisteredMethod { - /// Clone of [`Server::to_back`]. - to_back: mpsc::UnboundedSender, - /// Receives requests that the client sent to us. - queries_rx: mpsc::Receiver<(RawServerRequestId, common::Params)>, + /// Clone of [`Server::to_back`]. + to_back: mpsc::UnboundedSender, + /// Receives requests that the client sent to us. + queries_rx: mpsc::Receiver<(RawServerRequestId, common::Params)>, } /// Pub-sub subscription that's been registered. // TODO: unregister on drop pub struct RegisteredSubscription { - /// Clone of [`Server::to_back`]. - to_back: mpsc::UnboundedSender, - /// Value passed to [`FrontToBack::RegisterSubscription::unique_id`]. - unique_id: usize, + /// Clone of [`Server::to_back`]. + to_back: mpsc::UnboundedSender, + /// Value passed to [`FrontToBack::RegisterSubscription::unique_id`]. + unique_id: usize, } /// Active request that needs to be answered. pub struct IncomingRequest { - /// Clone of [`Server::to_back`]. - to_back: mpsc::UnboundedSender, - /// Identifier of the request towards the server. - request_id: RawServerRequestId, - /// Parameters of the request. - params: common::Params, + /// Clone of [`Server::to_back`]. + to_back: mpsc::UnboundedSender, + /// Identifier of the request towards the server. + request_id: RawServerRequestId, + /// Parameters of the request. + params: common::Params, } /// Message that the [`Server`] can send to the background task. enum FrontToBack { - /// Registers a notifications endpoint. - RegisterNotifications { - /// Name of the method. - name: String, - /// Where to send incoming notifications. - handler: mpsc::Sender, - /// See the documentation of [`Server::register_notifications`]. - allow_losses: bool, - }, - - /// Registers a method. The server will then handle requests using this method. - RegisterMethod { - /// Name of the method. - name: String, - /// Where to send requests. - handler: mpsc::Sender<(RawServerRequestId, common::Params)>, - }, - - /// Send a response to a request that a client made. - AnswerRequest { - /// Request to answer. - request_id: RawServerRequestId, - /// Response to send back. - answer: Result, - }, - - /// Registers a subscription. The server will then handle subscription requests of that - /// method. - RegisterSubscription { - /// Unique identifier decided by the front-end in order to identify this registered - /// subscription. - unique_id: usize, - /// Name of the method that registers the subscription. - subscribe_method: String, - /// Name of the method that unregisters the subscription. - unsubscribe_method: String, - }, - - /// Send out a notification to all the clients registered to a subscription. - SendOutNotif { - /// The value that was passed in [`FrontToBack::RegisterSubscription::unique_id`] earlier. - unique_id: usize, - /// Notification to send to the subscribed clients. - notification: JsonValue, - }, + /// Registers a notifications endpoint. + RegisterNotifications { + /// Name of the method. + name: String, + /// Where to send incoming notifications. + handler: mpsc::Sender, + /// See the documentation of [`Server::register_notifications`]. + allow_losses: bool, + }, + + /// Registers a method. The server will then handle requests using this method. + RegisterMethod { + /// Name of the method. + name: String, + /// Where to send requests. + handler: mpsc::Sender<(RawServerRequestId, common::Params)>, + }, + + /// Send a response to a request that a client made. + AnswerRequest { + /// Request to answer. + request_id: RawServerRequestId, + /// Response to send back. + answer: Result, + }, + + /// Registers a subscription. The server will then handle subscription requests of that + /// method. + RegisterSubscription { + /// Unique identifier decided by the front-end in order to identify this registered + /// subscription. + unique_id: usize, + /// Name of the method that registers the subscription. + subscribe_method: String, + /// Name of the method that unregisters the subscription. + unsubscribe_method: String, + }, + + /// Send out a notification to all the clients registered to a subscription. + SendOutNotif { + /// The value that was passed in [`FrontToBack::RegisterSubscription::unique_id`] earlier. + unique_id: usize, + /// Notification to send to the subscribed clients. + notification: JsonValue, + }, } impl Server { - /// Initializes a new server based upon this raw server. - pub async fn new(url: &str) -> Result> { - let sockaddr = url.parse()?; - let transport_server = HttpTransportServer::new(&sockaddr).await?; - let local_addr = *transport_server.local_addr(); - - // We use an unbounded channel because the only exchanged messages concern registering - // methods. The volume of messages is therefore very low and it doesn't make sense to have - // a backpressure mechanism. - // TODO: that's not true anymore ^ - let (to_back, from_front) = mpsc::unbounded(); - - async_std::task::spawn(async move { - background_task(transport_server.into(), from_front).await; - }); - - Ok(Server { - local_addr, - to_back, - registered_methods: Arc::new(Mutex::new(Default::default())), - next_subscription_unique_id: Arc::new(atomic::AtomicUsize::new(0)), - }) - } - - /// Local socket address of the transport server. - pub fn local_addr(&self) -> &SocketAddr { - &self.local_addr - } - - /// Registers a notification method name towards the server. - /// - /// Clients will then be able to call this method. - /// The returned object allows you to process incoming notifications. - /// - /// If `allow_losses` is true, then the server is allowed to drop notifications if the - /// notifications handler (i.e. the code that uses [`RegisteredNotifications`]) is too slow - /// to process notifications. - /// - /// Returns an error if the method name was already registered. - pub fn register_notification( - &self, - method_name: String, - allow_losses: bool, - ) -> Result { - if !self.registered_methods.lock().insert(method_name.clone()) { - return Err(()); - } - - let (tx, rx) = mpsc::channel(32); - - let _ = self - .to_back - .unbounded_send(FrontToBack::RegisterNotifications { - name: method_name, - handler: tx, - allow_losses, - }); - - Ok(RegisteredNotification { queries_rx: rx }) - } - - /// Registers a method towards the server. - /// - /// Clients will then be able to call this method. - /// The returned object allows you to handle incoming requests. - /// - /// Contrary to [`register_notifications`](Server::register_notifications), there is no - /// `allow_losses` parameter here. If the handler is too slow to process requests, then the - /// server automatically returns an "internal error" to the client. - /// - /// Returns an error if the method name was already registered. - pub fn register_method(&self, method_name: String) -> Result { - if !self.registered_methods.lock().insert(method_name.clone()) { - return Err(()); - } - - let (tx, rx) = mpsc::channel(32); - - let _ = self.to_back.unbounded_send(FrontToBack::RegisterMethod { - name: method_name, - handler: tx, - }); - - Ok(RegisteredMethod { - to_back: self.to_back.clone(), - queries_rx: rx, - }) - } - - /// Registers a subscription towards the server. - /// - /// Clients will then be able to call this method. - /// The returned object allows you to send out notifications. - /// - /// Returns an error if one of the method names was already registered. - pub fn register_subscription( - &self, - subscribe_method_name: String, - unsubscribe_method_name: String, - ) -> Result { - { - let mut registered_methods = self.registered_methods.lock(); - if !registered_methods.insert(subscribe_method_name.clone()) { - return Err(()); - } - if !registered_methods.insert(unsubscribe_method_name.clone()) { - registered_methods.remove(&subscribe_method_name); - return Err(()); - } - } - - let unique_id = self - .next_subscription_unique_id - .fetch_add(1, atomic::Ordering::Relaxed); - - let _ = self - .to_back - .unbounded_send(FrontToBack::RegisterSubscription { - unique_id, - subscribe_method: subscribe_method_name, - unsubscribe_method: unsubscribe_method_name, - }); - - Ok(RegisteredSubscription { - to_back: self.to_back.clone(), - unique_id, - }) - } + /// Initializes a new server based upon this raw server. + pub async fn new(url: &str) -> Result> { + let sockaddr = url.parse()?; + let transport_server = HttpTransportServer::new(&sockaddr).await?; + let local_addr = *transport_server.local_addr(); + + // We use an unbounded channel because the only exchanged messages concern registering + // methods. The volume of messages is therefore very low and it doesn't make sense to have + // a backpressure mechanism. + // TODO: that's not true anymore ^ + let (to_back, from_front) = mpsc::unbounded(); + + async_std::task::spawn(async move { + background_task(transport_server.into(), from_front).await; + }); + + Ok(Server { + local_addr, + to_back, + registered_methods: Arc::new(Mutex::new(Default::default())), + next_subscription_unique_id: Arc::new(atomic::AtomicUsize::new(0)), + }) + } + + /// Local socket address of the transport server. + pub fn local_addr(&self) -> &SocketAddr { + &self.local_addr + } + + /// Registers a notification method name towards the server. + /// + /// Clients will then be able to call this method. + /// The returned object allows you to process incoming notifications. + /// + /// If `allow_losses` is true, then the server is allowed to drop notifications if the + /// notifications handler (i.e. the code that uses [`RegisteredNotifications`]) is too slow + /// to process notifications. + /// + /// Returns an error if the method name was already registered. + pub fn register_notification(&self, method_name: String, allow_losses: bool) -> Result { + if !self.registered_methods.lock().insert(method_name.clone()) { + return Err(()); + } + + let (tx, rx) = mpsc::channel(32); + + let _ = self.to_back.unbounded_send(FrontToBack::RegisterNotifications { + name: method_name, + handler: tx, + allow_losses, + }); + + Ok(RegisteredNotification { queries_rx: rx }) + } + + /// Registers a method towards the server. + /// + /// Clients will then be able to call this method. + /// The returned object allows you to handle incoming requests. + /// + /// Contrary to [`register_notifications`](Server::register_notifications), there is no + /// `allow_losses` parameter here. If the handler is too slow to process requests, then the + /// server automatically returns an "internal error" to the client. + /// + /// Returns an error if the method name was already registered. + pub fn register_method(&self, method_name: String) -> Result { + if !self.registered_methods.lock().insert(method_name.clone()) { + return Err(()); + } + + let (tx, rx) = mpsc::channel(32); + + let _ = self.to_back.unbounded_send(FrontToBack::RegisterMethod { name: method_name, handler: tx }); + + Ok(RegisteredMethod { to_back: self.to_back.clone(), queries_rx: rx }) + } + + /// Registers a subscription towards the server. + /// + /// Clients will then be able to call this method. + /// The returned object allows you to send out notifications. + /// + /// Returns an error if one of the method names was already registered. + pub fn register_subscription( + &self, + subscribe_method_name: String, + unsubscribe_method_name: String, + ) -> Result { + { + let mut registered_methods = self.registered_methods.lock(); + if !registered_methods.insert(subscribe_method_name.clone()) { + return Err(()); + } + if !registered_methods.insert(unsubscribe_method_name.clone()) { + registered_methods.remove(&subscribe_method_name); + return Err(()); + } + } + + let unique_id = self.next_subscription_unique_id.fetch_add(1, atomic::Ordering::Relaxed); + + let _ = self.to_back.unbounded_send(FrontToBack::RegisterSubscription { + unique_id, + subscribe_method: subscribe_method_name, + unsubscribe_method: unsubscribe_method_name, + }); + + Ok(RegisteredSubscription { to_back: self.to_back.clone(), unique_id }) + } } impl RegisteredNotification { - /// Returns the next notification. - pub async fn next(&mut self) -> common::Params { - loop { - match self.queries_rx.next().await { - Some(v) => break v, - None => futures::pending!(), - } - } - } + /// Returns the next notification. + pub async fn next(&mut self) -> common::Params { + loop { + match self.queries_rx.next().await { + Some(v) => break v, + None => futures::pending!(), + } + } + } } impl RegisteredMethod { - /// Returns the next request. - pub async fn next(&mut self) -> IncomingRequest { - let (request_id, params) = loop { - match self.queries_rx.next().await { - Some(v) => break v, - None => futures::pending!(), - } - }; - IncomingRequest { - to_back: self.to_back.clone(), - request_id, - params, - } - } + /// Returns the next request. + pub async fn next(&mut self) -> IncomingRequest { + let (request_id, params) = loop { + match self.queries_rx.next().await { + Some(v) => break v, + None => futures::pending!(), + } + }; + IncomingRequest { to_back: self.to_back.clone(), request_id, params } + } } impl RegisteredSubscription { - /// Sends out a value to all the registered clients. - pub async fn send(&mut self, value: JsonValue) { - let _ = self.to_back.send(FrontToBack::SendOutNotif { - unique_id: self.unique_id, - notification: value, - }); - } + /// Sends out a value to all the registered clients. + pub async fn send(&mut self, value: JsonValue) { + let _ = self.to_back.send(FrontToBack::SendOutNotif { unique_id: self.unique_id, notification: value }); + } } impl IncomingRequest { - /// Returns the parameters of the request. - pub fn params(&self) -> &common::Params { - &self.params - } - - /// Respond to the request. - pub async fn respond(mut self, response: impl Into>) { - let _ = self - .to_back - .send(FrontToBack::AnswerRequest { - request_id: self.request_id, - answer: response.into(), - }) - .await; - } + /// Returns the parameters of the request. + pub fn params(&self) -> &common::Params { + &self.params + } + + /// Respond to the request. + pub async fn respond(mut self, response: impl Into>) { + let _ = self + .to_back + .send(FrontToBack::AnswerRequest { request_id: self.request_id, answer: response.into() }) + .await; + } } /// Function being run in the background that processes messages from the frontend. -async fn background_task( - mut server: RawServer, - mut from_front: mpsc::UnboundedReceiver, -) { - // List of notifications methods that the user has registered, and the channels to dispatch - // incoming notifications. - let mut registered_notifications: HashMap, bool)> = HashMap::new(); - // List of methods that the user has registered, and the channels to dispatch incoming - // requests. - let mut registered_methods: HashMap> = HashMap::new(); - // For each registered subscription, a subscribe method linked to a unique identifier for - // that subscription. - let mut subscribe_methods: HashMap = HashMap::new(); - // For each registered subscription, an unsubscribe method linked to a unique identifier for - // that subscription. - let mut unsubscribe_methods: HashMap = HashMap::new(); - // For each registered subscription, a list of clients that are registered towards us. - let mut subscribed_clients: HashMap> = HashMap::new(); - // Reversed mapping of `subscribed_clients`. Must always be in sync. - let mut active_subscriptions: HashMap = HashMap::new(); - - loop { - // We need to do a little transformation in order to destroy the borrow to `client` - // and `from_front`. - let outcome = { - let next_message = from_front.next(); - let next_event = server.next_event(); - pin_mut!(next_message); - pin_mut!(next_event); - match future::select(next_message, next_event).await { - Either::Left((v, _)) => Either::Left(v), - Either::Right((v, _)) => Either::Right(v), - } - }; - - match outcome { - Either::Left(None) => return, - Either::Left(Some(FrontToBack::AnswerRequest { request_id, answer })) => { - server.request_by_id(&request_id).unwrap().respond(answer); - } - Either::Left(Some(FrontToBack::RegisterNotifications { - name, - handler, - allow_losses, - })) => { - registered_notifications.insert(name, (handler, allow_losses)); - } - Either::Left(Some(FrontToBack::RegisterMethod { name, handler })) => { - registered_methods.insert(name, handler); - } - Either::Left(Some(FrontToBack::RegisterSubscription { - unique_id, - subscribe_method, - unsubscribe_method, - })) => { - debug_assert_ne!(subscribe_method, unsubscribe_method); - debug_assert!(!subscribe_methods.contains_key(&subscribe_method)); - debug_assert!(!subscribe_methods.contains_key(&unsubscribe_method)); - debug_assert!(!unsubscribe_methods.contains_key(&subscribe_method)); - debug_assert!(!unsubscribe_methods.contains_key(&unsubscribe_method)); - debug_assert!(!registered_methods.contains_key(&subscribe_method)); - debug_assert!(!registered_methods.contains_key(&unsubscribe_method)); - debug_assert!(!registered_notifications.contains_key(&subscribe_method)); - debug_assert!(!registered_notifications.contains_key(&unsubscribe_method)); - debug_assert!(!subscribed_clients.contains_key(&unique_id)); - subscribe_methods.insert(subscribe_method, unique_id); - unsubscribe_methods.insert(unsubscribe_method, unique_id); - subscribed_clients.insert(unique_id, Vec::new()); - } - Either::Left(Some(FrontToBack::SendOutNotif { - unique_id, - notification, - })) => { - debug_assert!(subscribed_clients.contains_key(&unique_id)); - if let Some(clients) = subscribed_clients.get(&unique_id) { - for client in clients { - debug_assert_eq!(active_subscriptions.get(client), Some(&unique_id)); - debug_assert!(server.subscription_by_id(*client).is_some()); - if let Some(sub) = server.subscription_by_id(*client) { - sub.push(notification.clone()).await; - } - } - } - } - Either::Right(RawServerEvent::Notification(notification)) => { - log::debug!("server received notification: {:?}", notification); - if let Some((handler, allow_losses)) = - registered_notifications.get_mut(notification.method()) - { - let params: &common::Params = notification.params().into(); - // Note: we just ignore errors. It doesn't make sense logically speaking to - // unregister the notification here. - if *allow_losses { - let _ = handler.send(params.clone()).now_or_never(); - } else { - let _ = handler.send(params.clone()).await; - } - } - } - Either::Right(RawServerEvent::Request(request)) => { - log::debug!("server received request: {:?}", request); - if let Some(handler) = registered_methods.get_mut(request.method()) { - let params: &common::Params = request.params().into(); - match handler.send((request.id(), params.clone())).now_or_never() { - Some(Ok(())) => {} - Some(Err(_)) | None => { - request.respond(Err(From::from(common::ErrorCode::ServerError(0)))); - } - } - } else if let Some(sub_unique_id) = subscribe_methods.get(request.method()) { - if let Ok(sub_id) = request.into_subscription() { - debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); - if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { - debug_assert!(clients.iter().all(|c| *c != sub_id)); - clients.push(sub_id); - } - - debug_assert!(!active_subscriptions.contains_key(&sub_id)); - active_subscriptions.insert(sub_id, *sub_unique_id); - } - } else if let Some(sub_unique_id) = unsubscribe_methods.get(request.method()) { - if let Ok(sub_id) = RawServerSubscriptionId::from_wire_message(&JsonValue::Null) - { - // FIXME: from request params - debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); - if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { - // TODO: we don't actually check whether the unsubscribe comes from the right - // clients, but since this the ID is randomly-generated, it should be - // fine - if let Some(client_pos) = clients.iter().position(|c| *c == sub_id) { - clients.remove(client_pos); - } - - if let Some(s_u_id) = active_subscriptions.remove(&sub_id) { - debug_assert_eq!(s_u_id, *sub_unique_id); - } - } - } - } else { - // TODO: we assert that the request is valid because the parsing succeeded but - // not registered. - request.respond(Err(From::from(common::ErrorCode::MethodNotFound))); - } - } - Either::Right(RawServerEvent::SubscriptionsReady(_)) => { - // We don't really care whether subscriptions are now ready. - } - Either::Right(RawServerEvent::SubscriptionsClosed(iter)) => { - // Remove all the subscriptions from `active_subscriptions` and - // `subscribed_clients`. - for sub_id in iter { - debug_assert!(active_subscriptions.contains_key(&sub_id)); - if let Some(unique_id) = active_subscriptions.remove(&sub_id) { - debug_assert!(subscribed_clients.contains_key(&unique_id)); - if let Some(clients) = subscribed_clients.get_mut(&unique_id) { - assert_eq!(clients.iter().filter(|c| **c == sub_id).count(), 1); - clients.retain(|c| *c != sub_id); - } - } - } - } - } - } +async fn background_task(mut server: RawServer, mut from_front: mpsc::UnboundedReceiver) { + // List of notifications methods that the user has registered, and the channels to dispatch + // incoming notifications. + let mut registered_notifications: HashMap, bool)> = HashMap::new(); + // List of methods that the user has registered, and the channels to dispatch incoming + // requests. + let mut registered_methods: HashMap> = HashMap::new(); + // For each registered subscription, a subscribe method linked to a unique identifier for + // that subscription. + let mut subscribe_methods: HashMap = HashMap::new(); + // For each registered subscription, an unsubscribe method linked to a unique identifier for + // that subscription. + let mut unsubscribe_methods: HashMap = HashMap::new(); + // For each registered subscription, a list of clients that are registered towards us. + let mut subscribed_clients: HashMap> = HashMap::new(); + // Reversed mapping of `subscribed_clients`. Must always be in sync. + let mut active_subscriptions: HashMap = HashMap::new(); + + loop { + // We need to do a little transformation in order to destroy the borrow to `client` + // and `from_front`. + let outcome = { + let next_message = from_front.next(); + let next_event = server.next_event(); + pin_mut!(next_message); + pin_mut!(next_event); + match future::select(next_message, next_event).await { + Either::Left((v, _)) => Either::Left(v), + Either::Right((v, _)) => Either::Right(v), + } + }; + + match outcome { + Either::Left(None) => return, + Either::Left(Some(FrontToBack::AnswerRequest { request_id, answer })) => { + server.request_by_id(&request_id).unwrap().respond(answer); + } + Either::Left(Some(FrontToBack::RegisterNotifications { name, handler, allow_losses })) => { + registered_notifications.insert(name, (handler, allow_losses)); + } + Either::Left(Some(FrontToBack::RegisterMethod { name, handler })) => { + registered_methods.insert(name, handler); + } + Either::Left(Some(FrontToBack::RegisterSubscription { + unique_id, + subscribe_method, + unsubscribe_method, + })) => { + debug_assert_ne!(subscribe_method, unsubscribe_method); + debug_assert!(!subscribe_methods.contains_key(&subscribe_method)); + debug_assert!(!subscribe_methods.contains_key(&unsubscribe_method)); + debug_assert!(!unsubscribe_methods.contains_key(&subscribe_method)); + debug_assert!(!unsubscribe_methods.contains_key(&unsubscribe_method)); + debug_assert!(!registered_methods.contains_key(&subscribe_method)); + debug_assert!(!registered_methods.contains_key(&unsubscribe_method)); + debug_assert!(!registered_notifications.contains_key(&subscribe_method)); + debug_assert!(!registered_notifications.contains_key(&unsubscribe_method)); + debug_assert!(!subscribed_clients.contains_key(&unique_id)); + subscribe_methods.insert(subscribe_method, unique_id); + unsubscribe_methods.insert(unsubscribe_method, unique_id); + subscribed_clients.insert(unique_id, Vec::new()); + } + Either::Left(Some(FrontToBack::SendOutNotif { unique_id, notification })) => { + debug_assert!(subscribed_clients.contains_key(&unique_id)); + if let Some(clients) = subscribed_clients.get(&unique_id) { + for client in clients { + debug_assert_eq!(active_subscriptions.get(client), Some(&unique_id)); + debug_assert!(server.subscription_by_id(*client).is_some()); + if let Some(sub) = server.subscription_by_id(*client) { + sub.push(notification.clone()).await; + } + } + } + } + Either::Right(RawServerEvent::Notification(notification)) => { + log::debug!("server received notification: {:?}", notification); + if let Some((handler, allow_losses)) = registered_notifications.get_mut(notification.method()) { + let params: &common::Params = notification.params().into(); + // Note: we just ignore errors. It doesn't make sense logically speaking to + // unregister the notification here. + if *allow_losses { + let _ = handler.send(params.clone()).now_or_never(); + } else { + let _ = handler.send(params.clone()).await; + } + } + } + Either::Right(RawServerEvent::Request(request)) => { + log::debug!("server received request: {:?}", request); + if let Some(handler) = registered_methods.get_mut(request.method()) { + let params: &common::Params = request.params().into(); + match handler.send((request.id(), params.clone())).now_or_never() { + Some(Ok(())) => {} + Some(Err(_)) | None => { + request.respond(Err(From::from(common::ErrorCode::ServerError(0)))); + } + } + } else if let Some(sub_unique_id) = subscribe_methods.get(request.method()) { + if let Ok(sub_id) = request.into_subscription() { + debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); + if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { + debug_assert!(clients.iter().all(|c| *c != sub_id)); + clients.push(sub_id); + } + + debug_assert!(!active_subscriptions.contains_key(&sub_id)); + active_subscriptions.insert(sub_id, *sub_unique_id); + } + } else if let Some(sub_unique_id) = unsubscribe_methods.get(request.method()) { + if let Ok(sub_id) = RawServerSubscriptionId::from_wire_message(&JsonValue::Null) { + // FIXME: from request params + debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); + if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { + // TODO: we don't actually check whether the unsubscribe comes from the right + // clients, but since this the ID is randomly-generated, it should be + // fine + if let Some(client_pos) = clients.iter().position(|c| *c == sub_id) { + clients.remove(client_pos); + } + + if let Some(s_u_id) = active_subscriptions.remove(&sub_id) { + debug_assert_eq!(s_u_id, *sub_unique_id); + } + } + } + } else { + // TODO: we assert that the request is valid because the parsing succeeded but + // not registered. + request.respond(Err(From::from(common::ErrorCode::MethodNotFound))); + } + } + Either::Right(RawServerEvent::SubscriptionsReady(_)) => { + // We don't really care whether subscriptions are now ready. + } + Either::Right(RawServerEvent::SubscriptionsClosed(iter)) => { + // Remove all the subscriptions from `active_subscriptions` and + // `subscribed_clients`. + for sub_id in iter { + debug_assert!(active_subscriptions.contains_key(&sub_id)); + if let Some(unique_id) = active_subscriptions.remove(&sub_id) { + debug_assert!(subscribed_clients.contains_key(&unique_id)); + if let Some(clients) = subscribed_clients.get_mut(&unique_id) { + assert_eq!(clients.iter().filter(|c| **c == sub_id).count(), 1); + clients.retain(|c| *c != sub_id); + } + } + } + } + } + } } diff --git a/src/http/server_utils/access_control.rs b/src/http/server_utils/access_control.rs index 0b360b955d..780f93c163 100644 --- a/src/http/server_utils/access_control.rs +++ b/src/http/server_utils/access_control.rs @@ -34,147 +34,142 @@ use hyper::{self, header}; /// Define access on control on http layer #[derive(Clone)] pub struct AccessControl { - allow_hosts: AllowHosts, - cors_allow_origin: Option>, - cors_max_age: Option, - cors_allow_headers: AccessControlAllowHeaders, - continue_on_invalid_cors: bool, + allow_hosts: AllowHosts, + cors_allow_origin: Option>, + cors_max_age: Option, + cors_allow_headers: AccessControlAllowHeaders, + continue_on_invalid_cors: bool, } impl AccessControl { - /// Validate incoming request by http HOST - pub fn deny_host(&self, request: &hyper::Request) -> bool { - !hosts::is_host_valid(utils::read_header(request, "host"), &self.allow_hosts) - } - - /// Validate incoming request by CORS origin - pub fn deny_cors_origin(&self, request: &hyper::Request) -> bool { - let header = cors::get_cors_allow_origin( - utils::read_header(request, "origin"), - utils::read_header(request, "host"), - &self.cors_allow_origin, - ) - .map(|origin| { - use self::cors::AccessControlAllowOrigin::*; - match origin { - Value(ref val) => header::HeaderValue::from_str(val) - .unwrap_or_else(|_| header::HeaderValue::from_static("null")), - Null => header::HeaderValue::from_static("null"), - Any => header::HeaderValue::from_static("*"), - } - }); - header == cors::AllowCors::Invalid && !self.continue_on_invalid_cors - } - - /// Validate incoming request by CORS header - pub fn deny_cors_header(&self, request: &hyper::Request) -> bool { - let headers = request.headers().keys().map(|name| name.as_str()); - let requested_headers = request - .headers() - .get_all("access-control-request-headers") - .iter() - .filter_map(|val| val.to_str().ok()) - .flat_map(|val| val.split(", ")) - .flat_map(|val| val.split(',')); - - let header = cors::get_cors_allow_headers( - headers, - requested_headers, - &self.cors_allow_headers, - |name| { - header::HeaderValue::from_str(name) - .unwrap_or_else(|_| header::HeaderValue::from_static("unknown")) - }, - ); - header == cors::AllowCors::Invalid && !self.continue_on_invalid_cors - } + /// Validate incoming request by http HOST + pub fn deny_host(&self, request: &hyper::Request) -> bool { + !hosts::is_host_valid(utils::read_header(request, "host"), &self.allow_hosts) + } + + /// Validate incoming request by CORS origin + pub fn deny_cors_origin(&self, request: &hyper::Request) -> bool { + let header = cors::get_cors_allow_origin( + utils::read_header(request, "origin"), + utils::read_header(request, "host"), + &self.cors_allow_origin, + ) + .map(|origin| { + use self::cors::AccessControlAllowOrigin::*; + match origin { + Value(ref val) => { + header::HeaderValue::from_str(val).unwrap_or_else(|_| header::HeaderValue::from_static("null")) + } + Null => header::HeaderValue::from_static("null"), + Any => header::HeaderValue::from_static("*"), + } + }); + header == cors::AllowCors::Invalid && !self.continue_on_invalid_cors + } + + /// Validate incoming request by CORS header + pub fn deny_cors_header(&self, request: &hyper::Request) -> bool { + let headers = request.headers().keys().map(|name| name.as_str()); + let requested_headers = request + .headers() + .get_all("access-control-request-headers") + .iter() + .filter_map(|val| val.to_str().ok()) + .flat_map(|val| val.split(", ")) + .flat_map(|val| val.split(',')); + + let header = cors::get_cors_allow_headers(headers, requested_headers, &self.cors_allow_headers, |name| { + header::HeaderValue::from_str(name).unwrap_or_else(|_| header::HeaderValue::from_static("unknown")) + }); + header == cors::AllowCors::Invalid && !self.continue_on_invalid_cors + } } impl Default for AccessControl { - fn default() -> Self { - Self { - allow_hosts: AllowHosts::Any, - cors_allow_origin: None, - cors_max_age: None, - cors_allow_headers: AccessControlAllowHeaders::Any, - continue_on_invalid_cors: false, - } - } + fn default() -> Self { + Self { + allow_hosts: AllowHosts::Any, + cors_allow_origin: None, + cors_max_age: None, + cors_allow_headers: AccessControlAllowHeaders::Any, + continue_on_invalid_cors: false, + } + } } /// Convenience builder pattern pub struct AccessControlBuilder { - allow_hosts: AllowHosts, - cors_allow_origin: Option>, - cors_max_age: Option, - cors_allow_headers: AccessControlAllowHeaders, - continue_on_invalid_cors: bool, + allow_hosts: AllowHosts, + cors_allow_origin: Option>, + cors_max_age: Option, + cors_allow_headers: AccessControlAllowHeaders, + continue_on_invalid_cors: bool, } impl AccessControlBuilder { - pub fn new() -> Self { - AccessControlBuilder { - allow_hosts: AllowHosts::Any, - cors_allow_origin: None, - cors_max_age: None, - cors_allow_headers: AccessControlAllowHeaders::Any, - continue_on_invalid_cors: false, - } - } - - pub fn allow_host(mut self, host: Host) -> Self { - let allow_hosts = match self.allow_hosts { - AllowHosts::Any => vec![host], - AllowHosts::Only(mut allow_hosts) => { - allow_hosts.push(host); - allow_hosts - } - }; - self.allow_hosts = AllowHosts::Only(allow_hosts); - self - } - - pub fn cors_allow_origin(mut self, allow_origin: AccessControlAllowOrigin) -> Self { - let cors_allow_origin = match self.cors_allow_origin { - Some(mut cors_allow_origin) => { - cors_allow_origin.push(allow_origin); - cors_allow_origin - } - None => vec![allow_origin], - }; - self.cors_allow_origin = Some(cors_allow_origin); - self - } - - pub fn cors_max_age(mut self, max_age: u32) -> Self { - self.cors_max_age = Some(max_age); - self - } - - pub fn cors_allow_header(mut self, header: String) -> Self { - let allow_headers = match self.cors_allow_headers { - AccessControlAllowHeaders::Any => vec![header], - AccessControlAllowHeaders::Only(mut allow_headers) => { - allow_headers.push(header); - allow_headers - } - }; - self.cors_allow_headers = AccessControlAllowHeaders::Only(allow_headers); - self - } - - pub fn continue_on_invalid_cors(mut self, continue_on_invalid_cors: bool) -> Self { - self.continue_on_invalid_cors = continue_on_invalid_cors; - self - } - - pub fn build(self) -> AccessControl { - AccessControl { - allow_hosts: self.allow_hosts, - cors_allow_origin: self.cors_allow_origin, - cors_max_age: self.cors_max_age, - cors_allow_headers: self.cors_allow_headers, - continue_on_invalid_cors: self.continue_on_invalid_cors, - } - } + pub fn new() -> Self { + AccessControlBuilder { + allow_hosts: AllowHosts::Any, + cors_allow_origin: None, + cors_max_age: None, + cors_allow_headers: AccessControlAllowHeaders::Any, + continue_on_invalid_cors: false, + } + } + + pub fn allow_host(mut self, host: Host) -> Self { + let allow_hosts = match self.allow_hosts { + AllowHosts::Any => vec![host], + AllowHosts::Only(mut allow_hosts) => { + allow_hosts.push(host); + allow_hosts + } + }; + self.allow_hosts = AllowHosts::Only(allow_hosts); + self + } + + pub fn cors_allow_origin(mut self, allow_origin: AccessControlAllowOrigin) -> Self { + let cors_allow_origin = match self.cors_allow_origin { + Some(mut cors_allow_origin) => { + cors_allow_origin.push(allow_origin); + cors_allow_origin + } + None => vec![allow_origin], + }; + self.cors_allow_origin = Some(cors_allow_origin); + self + } + + pub fn cors_max_age(mut self, max_age: u32) -> Self { + self.cors_max_age = Some(max_age); + self + } + + pub fn cors_allow_header(mut self, header: String) -> Self { + let allow_headers = match self.cors_allow_headers { + AccessControlAllowHeaders::Any => vec![header], + AccessControlAllowHeaders::Only(mut allow_headers) => { + allow_headers.push(header); + allow_headers + } + }; + self.cors_allow_headers = AccessControlAllowHeaders::Only(allow_headers); + self + } + + pub fn continue_on_invalid_cors(mut self, continue_on_invalid_cors: bool) -> Self { + self.continue_on_invalid_cors = continue_on_invalid_cors; + self + } + + pub fn build(self) -> AccessControl { + AccessControl { + allow_hosts: self.allow_hosts, + cors_allow_origin: self.cors_allow_origin, + cors_max_age: self.cors_max_age, + cors_allow_headers: self.cors_allow_headers, + continue_on_invalid_cors: self.continue_on_invalid_cors, + } + } } diff --git a/src/http/server_utils/cors.rs b/src/http/server_utils/cors.rs index 72522112b0..28255a6fb5 100644 --- a/src/http/server_utils/cors.rs +++ b/src/http/server_utils/cors.rs @@ -36,607 +36,557 @@ use unicase::Ascii; /// Origin Protocol #[derive(Clone, Hash, Debug, PartialEq, Eq)] pub enum OriginProtocol { - /// Http protocol - Http, - /// Https protocol - Https, - /// Custom protocol - Custom(String), + /// Http protocol + Http, + /// Https protocol + Https, + /// Custom protocol + Custom(String), } /// Request Origin #[derive(Clone, PartialEq, Eq, Debug, Hash)] pub struct Origin { - protocol: OriginProtocol, - host: Host, - as_string: String, - matcher: Matcher, + protocol: OriginProtocol, + host: Host, + as_string: String, + matcher: Matcher, } impl> From for Origin { - fn from(string: T) -> Self { - Origin::parse(string.as_ref()) - } + fn from(string: T) -> Self { + Origin::parse(string.as_ref()) + } } impl Origin { - fn with_host(protocol: OriginProtocol, host: Host) -> Self { - let string = Self::to_string(&protocol, &host); - let matcher = Matcher::new(&string); - - Origin { - protocol, - host, - as_string: string, - matcher, - } - } - - /// Creates new origin given protocol, hostname and port parts. - /// Pre-processes input data if necessary. - pub fn new>(protocol: OriginProtocol, host: &str, port: T) -> Self { - Self::with_host(protocol, Host::new(host, port)) - } - - /// Attempts to parse given string as a `Origin`. - /// NOTE: This method always succeeds and falls back to sensible defaults. - pub fn parse(data: &str) -> Self { - let mut it = data.split("://"); - let proto = it.next().expect("split always returns non-empty iterator."); - let hostname = it.next(); - - let (proto, hostname) = match hostname { - None => (None, proto), - Some(hostname) => (Some(proto), hostname), - }; - - let proto = proto.map(str::to_lowercase); - let hostname = Host::parse(hostname); - - let protocol = match proto { - None => OriginProtocol::Http, - Some(ref p) if p == "http" => OriginProtocol::Http, - Some(ref p) if p == "https" => OriginProtocol::Https, - Some(other) => OriginProtocol::Custom(other), - }; - - Origin::with_host(protocol, hostname) - } - - fn to_string(protocol: &OriginProtocol, host: &Host) -> String { - format!( - "{}://{}", - match *protocol { - OriginProtocol::Http => "http", - OriginProtocol::Https => "https", - OriginProtocol::Custom(ref protocol) => protocol, - }, - &**host, - ) - } + fn with_host(protocol: OriginProtocol, host: Host) -> Self { + let string = Self::to_string(&protocol, &host); + let matcher = Matcher::new(&string); + + Origin { protocol, host, as_string: string, matcher } + } + + /// Creates new origin given protocol, hostname and port parts. + /// Pre-processes input data if necessary. + pub fn new>(protocol: OriginProtocol, host: &str, port: T) -> Self { + Self::with_host(protocol, Host::new(host, port)) + } + + /// Attempts to parse given string as a `Origin`. + /// NOTE: This method always succeeds and falls back to sensible defaults. + pub fn parse(data: &str) -> Self { + let mut it = data.split("://"); + let proto = it.next().expect("split always returns non-empty iterator."); + let hostname = it.next(); + + let (proto, hostname) = match hostname { + None => (None, proto), + Some(hostname) => (Some(proto), hostname), + }; + + let proto = proto.map(str::to_lowercase); + let hostname = Host::parse(hostname); + + let protocol = match proto { + None => OriginProtocol::Http, + Some(ref p) if p == "http" => OriginProtocol::Http, + Some(ref p) if p == "https" => OriginProtocol::Https, + Some(other) => OriginProtocol::Custom(other), + }; + + Origin::with_host(protocol, hostname) + } + + fn to_string(protocol: &OriginProtocol, host: &Host) -> String { + format!( + "{}://{}", + match *protocol { + OriginProtocol::Http => "http", + OriginProtocol::Https => "https", + OriginProtocol::Custom(ref protocol) => protocol, + }, + &**host, + ) + } } impl Pattern for Origin { - fn matches>(&self, other: T) -> bool { - self.matcher.matches(other) - } + fn matches>(&self, other: T) -> bool { + self.matcher.matches(other) + } } impl ops::Deref for Origin { - type Target = str; - fn deref(&self) -> &Self::Target { - &self.as_string - } + type Target = str; + fn deref(&self) -> &Self::Target { + &self.as_string + } } /// Origins allowed to access #[derive(Debug, Clone, PartialEq, Eq)] pub enum AccessControlAllowOrigin { - /// Specific hostname - Value(Origin), - /// null-origin (file:///, sandboxed iframe) - Null, - /// Any non-null origin - Any, + /// Specific hostname + Value(Origin), + /// null-origin (file:///, sandboxed iframe) + Null, + /// Any non-null origin + Any, } impl fmt::Display for AccessControlAllowOrigin { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!( - f, - "{}", - match *self { - AccessControlAllowOrigin::Any => "*", - AccessControlAllowOrigin::Null => "null", - AccessControlAllowOrigin::Value(ref val) => val, - } - ) - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}", + match *self { + AccessControlAllowOrigin::Any => "*", + AccessControlAllowOrigin::Null => "null", + AccessControlAllowOrigin::Value(ref val) => val, + } + ) + } } impl> From for AccessControlAllowOrigin { - fn from(s: T) -> AccessControlAllowOrigin { - match s.into().as_str() { - "all" | "*" | "any" => AccessControlAllowOrigin::Any, - "null" => AccessControlAllowOrigin::Null, - origin => AccessControlAllowOrigin::Value(origin.into()), - } - } + fn from(s: T) -> AccessControlAllowOrigin { + match s.into().as_str() { + "all" | "*" | "any" => AccessControlAllowOrigin::Any, + "null" => AccessControlAllowOrigin::Null, + origin => AccessControlAllowOrigin::Value(origin.into()), + } + } } /// Headers allowed to access #[derive(Debug, Clone, PartialEq)] pub enum AccessControlAllowHeaders { - /// Specific headers - Only(Vec), - /// Any header - Any, + /// Specific headers + Only(Vec), + /// Any header + Any, } /// CORS response headers #[derive(Debug, Clone, PartialEq, Eq)] pub enum AllowCors { - /// CORS header was not required. Origin is not present in the request. - NotRequired, - /// CORS header is not returned, Origin is not allowed to access the resource. - Invalid, - /// CORS header to include in the response. Origin is allowed to access the resource. - Ok(T), + /// CORS header was not required. Origin is not present in the request. + NotRequired, + /// CORS header is not returned, Origin is not allowed to access the resource. + Invalid, + /// CORS header to include in the response. Origin is allowed to access the resource. + Ok(T), } impl AllowCors { - /// Maps `Ok` variant of `AllowCors`. - pub fn map(self, f: F) -> AllowCors - where - F: FnOnce(T) -> O, - { - use self::AllowCors::*; - - match self { - NotRequired => NotRequired, - Invalid => Invalid, - Ok(val) => Ok(f(val)), - } - } + /// Maps `Ok` variant of `AllowCors`. + pub fn map(self, f: F) -> AllowCors + where + F: FnOnce(T) -> O, + { + use self::AllowCors::*; + + match self { + NotRequired => NotRequired, + Invalid => Invalid, + Ok(val) => Ok(f(val)), + } + } } impl Into> for AllowCors { - fn into(self) -> Option { - use self::AllowCors::*; - - match self { - NotRequired | Invalid => None, - Ok(header) => Some(header), - } - } + fn into(self) -> Option { + use self::AllowCors::*; + + match self { + NotRequired | Invalid => None, + Ok(header) => Some(header), + } + } } /// Returns correct CORS header (if any) given list of allowed origins and current origin. pub fn get_cors_allow_origin( - origin: Option<&str>, - host: Option<&str>, - allowed: &Option>, + origin: Option<&str>, + host: Option<&str>, + allowed: &Option>, ) -> AllowCors { - match origin { - None => AllowCors::NotRequired, - Some(ref origin) => { - if let Some(host) = host { - // Request initiated from the same server. - if origin.ends_with(host) { - // Additional check - let origin = Origin::parse(origin); - if &*origin.host == host { - return AllowCors::NotRequired; - } - } - } - - match allowed.as_ref() { - None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null), - None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))), - Some(ref allowed) if *origin == "null" => allowed - .iter() - .find(|cors| **cors == AccessControlAllowOrigin::Null) - .cloned() - .map(AllowCors::Ok) - .unwrap_or(AllowCors::Invalid), - Some(ref allowed) => allowed - .iter() - .find(|cors| match **cors { - AccessControlAllowOrigin::Any => true, - AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true, - _ => false, - }) - .map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin))) - .map(AllowCors::Ok) - .unwrap_or(AllowCors::Invalid), - } - } - } + match origin { + None => AllowCors::NotRequired, + Some(ref origin) => { + if let Some(host) = host { + // Request initiated from the same server. + if origin.ends_with(host) { + // Additional check + let origin = Origin::parse(origin); + if &*origin.host == host { + return AllowCors::NotRequired; + } + } + } + + match allowed.as_ref() { + None if *origin == "null" => AllowCors::Ok(AccessControlAllowOrigin::Null), + None => AllowCors::Ok(AccessControlAllowOrigin::Value(Origin::parse(origin))), + Some(ref allowed) if *origin == "null" => allowed + .iter() + .find(|cors| **cors == AccessControlAllowOrigin::Null) + .cloned() + .map(AllowCors::Ok) + .unwrap_or(AllowCors::Invalid), + Some(ref allowed) => allowed + .iter() + .find(|cors| match **cors { + AccessControlAllowOrigin::Any => true, + AccessControlAllowOrigin::Value(ref val) if val.matches(origin) => true, + _ => false, + }) + .map(|_| AccessControlAllowOrigin::Value(Origin::parse(origin))) + .map(AllowCors::Ok) + .unwrap_or(AllowCors::Invalid), + } + } + } } /// Validates if the `AccessControlAllowedHeaders` in the request are allowed. pub fn get_cors_allow_headers, O, F: Fn(T) -> O>( - mut headers: impl Iterator, - requested_headers: impl Iterator, - cors_allow_headers: &AccessControlAllowHeaders, - to_result: F, + mut headers: impl Iterator, + requested_headers: impl Iterator, + cors_allow_headers: &AccessControlAllowHeaders, + to_result: F, ) -> AllowCors> { - // Check if the header fields which were sent in the request are allowed - if let AccessControlAllowHeaders::Only(only) = cors_allow_headers { - let are_all_allowed = headers.all(|header| { - let name = &Ascii::new(header.as_ref()); - only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name) - }); - - if !are_all_allowed { - return AllowCors::Invalid; - } - } - - // Check if `AccessControlRequestHeaders` contains fields which were allowed - let (filtered, headers) = match cors_allow_headers { - AccessControlAllowHeaders::Any => { - let headers = requested_headers.map(to_result).collect(); - (false, headers) - } - AccessControlAllowHeaders::Only(only) => { - let mut filtered = false; - let headers: Vec<_> = requested_headers - .filter(|header| { - let name = &Ascii::new(header.as_ref()); - filtered = true; - only.iter().any(|h| Ascii::new(&*h) == name) - || ALWAYS_ALLOWED_HEADERS.contains(name) - }) - .map(to_result) - .collect(); - - (filtered, headers) - } - }; - - if headers.is_empty() { - if filtered { - AllowCors::Invalid - } else { - AllowCors::NotRequired - } - } else { - AllowCors::Ok(headers) - } + // Check if the header fields which were sent in the request are allowed + if let AccessControlAllowHeaders::Only(only) = cors_allow_headers { + let are_all_allowed = headers.all(|header| { + let name = &Ascii::new(header.as_ref()); + only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name) + }); + + if !are_all_allowed { + return AllowCors::Invalid; + } + } + + // Check if `AccessControlRequestHeaders` contains fields which were allowed + let (filtered, headers) = match cors_allow_headers { + AccessControlAllowHeaders::Any => { + let headers = requested_headers.map(to_result).collect(); + (false, headers) + } + AccessControlAllowHeaders::Only(only) => { + let mut filtered = false; + let headers: Vec<_> = requested_headers + .filter(|header| { + let name = &Ascii::new(header.as_ref()); + filtered = true; + only.iter().any(|h| Ascii::new(&*h) == name) || ALWAYS_ALLOWED_HEADERS.contains(name) + }) + .map(to_result) + .collect(); + + (filtered, headers) + } + }; + + if headers.is_empty() { + if filtered { + AllowCors::Invalid + } else { + AllowCors::NotRequired + } + } else { + AllowCors::Ok(headers) + } } lazy_static! { - /// Returns headers which are always allowed. - static ref ALWAYS_ALLOWED_HEADERS: HashSet> = { - let mut hs = HashSet::new(); - hs.insert(Ascii::new("Accept")); - hs.insert(Ascii::new("Accept-Language")); - hs.insert(Ascii::new("Access-Control-Allow-Origin")); - hs.insert(Ascii::new("Access-Control-Request-Headers")); - hs.insert(Ascii::new("Content-Language")); - hs.insert(Ascii::new("Content-Type")); - hs.insert(Ascii::new("Host")); - hs.insert(Ascii::new("Origin")); - hs.insert(Ascii::new("Content-Length")); - hs.insert(Ascii::new("Connection")); - hs.insert(Ascii::new("User-Agent")); - hs - }; + /// Returns headers which are always allowed. + static ref ALWAYS_ALLOWED_HEADERS: HashSet> = { + let mut hs = HashSet::new(); + hs.insert(Ascii::new("Accept")); + hs.insert(Ascii::new("Accept-Language")); + hs.insert(Ascii::new("Access-Control-Allow-Origin")); + hs.insert(Ascii::new("Access-Control-Request-Headers")); + hs.insert(Ascii::new("Content-Language")); + hs.insert(Ascii::new("Content-Type")); + hs.insert(Ascii::new("Host")); + hs.insert(Ascii::new("Origin")); + hs.insert(Ascii::new("Content-Length")); + hs.insert(Ascii::new("Connection")); + hs.insert(Ascii::new("User-Agent")); + hs + }; } #[cfg(test)] mod tests { - use std::iter; - - use super::*; - use crate::http::server_utils::hosts::Host; - - #[test] - fn should_parse_origin() { - use self::OriginProtocol::*; - - assert_eq!( - Origin::parse("http://parity.io"), - Origin::new(Http, "parity.io", None) - ); - assert_eq!( - Origin::parse("https://parity.io:8443"), - Origin::new(Https, "parity.io", Some(8443)) - ); - assert_eq!( - Origin::parse("chrome-extension://124.0.0.1"), - Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None) - ); - assert_eq!( - Origin::parse("parity.io/somepath"), - Origin::new(Http, "parity.io", None) - ); - assert_eq!( - Origin::parse("127.0.0.1:8545/somepath"), - Origin::new(Http, "127.0.0.1", Some(8545)) - ); - } - - #[test] - fn should_not_allow_partially_matching_origin() { - // given - let origin1 = Origin::parse("http://subdomain.somedomain.io"); - let origin2 = Origin::parse("http://somedomain.io:8080"); - let host = Host::parse("http://somedomain.io"); - - let origin1 = Some(&*origin1); - let origin2 = Some(&*origin2); - let host = Some(&*host); - - // when - let res1 = get_cors_allow_origin(origin1, host, &Some(vec![])); - let res2 = get_cors_allow_origin(origin2, host, &Some(vec![])); - - // then - assert_eq!(res1, AllowCors::Invalid); - assert_eq!(res2, AllowCors::Invalid); - } - - #[test] - fn should_allow_origins_that_matches_hosts() { - // given - let origin = Origin::parse("http://127.0.0.1:8080"); - let host = Host::parse("http://127.0.0.1:8080"); - - let origin = Some(&*origin); - let host = Some(&*host); - - // when - let res = get_cors_allow_origin(origin, host, &None); - - // then - assert_eq!(res, AllowCors::NotRequired); - } - - #[test] - fn should_return_none_when_there_are_no_cors_domains_and_no_origin() { - // given - let origin = None; - let host = None; - - // when - let res = get_cors_allow_origin(origin, host, &None); - - // then - assert_eq!(res, AllowCors::NotRequired); - } - - #[test] - fn should_return_domain_when_all_are_allowed() { - // given - let origin = Some("parity.io"); - let host = None; - - // when - let res = get_cors_allow_origin(origin, host, &None); - - // then - assert_eq!(res, AllowCors::Ok("parity.io".into())); - } - - #[test] - fn should_return_none_for_empty_origin() { - // given - let origin = None; - let host = None; - - // when - let res = get_cors_allow_origin( - origin, - host, - &Some(vec![AccessControlAllowOrigin::Value( - "http://ethereum.org".into(), - )]), - ); - - // then - assert_eq!(res, AllowCors::NotRequired); - } - - #[test] - fn should_return_none_for_empty_list() { - // given - let origin = None; - let host = None; - - // when - let res = get_cors_allow_origin(origin, host, &Some(Vec::new())); - - // then - assert_eq!(res, AllowCors::NotRequired); - } - - #[test] - fn should_return_none_for_not_matching_origin() { - // given - let origin = Some("http://parity.io".into()); - let host = None; - - // when - let res = get_cors_allow_origin( - origin, - host, - &Some(vec![AccessControlAllowOrigin::Value( - "http://ethereum.org".into(), - )]), - ); - - // then - assert_eq!(res, AllowCors::Invalid); - } - - #[test] - fn should_return_specific_origin_if_we_allow_any() { - // given - let origin = Some("http://parity.io".into()); - let host = None; - - // when - let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any])); - - // then - assert_eq!( - res, - AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into())) - ); - } - - #[test] - fn should_return_none_if_origin_is_not_defined() { - // given - let origin = None; - let host = None; - - // when - let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null])); - - // then - assert_eq!(res, AllowCors::NotRequired); - } - - #[test] - fn should_return_null_if_origin_is_null() { - // given - let origin = Some("null".into()); - let host = None; - - // when - let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null])); - - // then - assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null)); - } - - #[test] - fn should_return_specific_origin_if_there_is_a_match() { - // given - let origin = Some("http://parity.io".into()); - let host = None; - - // when - let res = get_cors_allow_origin( - origin, - host, - &Some(vec![ - AccessControlAllowOrigin::Value("http://ethereum.org".into()), - AccessControlAllowOrigin::Value("http://parity.io".into()), - ]), - ); - - // then - assert_eq!( - res, - AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into())) - ); - } - - #[test] - fn should_support_wildcards() { - // given - let origin1 = Some("http://parity.io".into()); - let origin2 = Some("http://parity.iot".into()); - let origin3 = Some("chrome-extension://test".into()); - let host = None; - let allowed = Some(vec![ - AccessControlAllowOrigin::Value("http://*.io".into()), - AccessControlAllowOrigin::Value("chrome-extension://*".into()), - ]); - - // when - let res1 = get_cors_allow_origin(origin1, host, &allowed); - let res2 = get_cors_allow_origin(origin2, host, &allowed); - let res3 = get_cors_allow_origin(origin3, host, &allowed); - - // then - assert_eq!( - res1, - AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into())) - ); - assert_eq!(res2, AllowCors::Invalid); - assert_eq!( - res3, - AllowCors::Ok(AccessControlAllowOrigin::Value( - "chrome-extension://test".into() - )) - ); - } - - #[test] - fn should_return_invalid_if_header_not_allowed() { - // given - let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]); - let headers = vec!["Access-Control-Request-Headers"]; - let requested = vec!["x-not-allowed"]; - - // when - let res = get_cors_allow_headers( - headers.iter(), - requested.iter(), - &cors_allow_headers.into(), - |x| x, - ); - - // then - assert_eq!(res, AllowCors::Invalid); - } - - #[test] - fn should_return_valid_if_header_allowed() { - // given - let allowed = vec!["x-allowed".to_owned()]; - let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone()); - let headers = vec!["Access-Control-Request-Headers"]; - let requested = vec!["x-allowed"]; - - // when - let res = get_cors_allow_headers( - headers.iter(), - requested.iter(), - &cors_allow_headers.into(), - |x| (*x).to_owned(), - ); - - // then - let allowed = vec!["x-allowed".to_owned()]; - assert_eq!(res, AllowCors::Ok(allowed)); - } - - #[test] - fn should_return_no_allowed_headers_if_none_in_request() { - // given - let allowed = vec!["x-allowed".to_owned()]; - let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone()); - let headers: Vec = vec![]; - - // when - let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers, |x| x); - - // then - assert_eq!(res, AllowCors::NotRequired); - } - - #[test] - fn should_return_not_required_if_any_header_allowed() { - // given - let cors_allow_headers = AccessControlAllowHeaders::Any; - let headers: Vec = vec![]; - - // when - let res = get_cors_allow_headers( - headers.iter(), - iter::empty(), - &cors_allow_headers.into(), - |x| x, - ); - - // then - assert_eq!(res, AllowCors::NotRequired); - } + use std::iter; + + use super::*; + use crate::http::server_utils::hosts::Host; + + #[test] + fn should_parse_origin() { + use self::OriginProtocol::*; + + assert_eq!(Origin::parse("http://parity.io"), Origin::new(Http, "parity.io", None)); + assert_eq!(Origin::parse("https://parity.io:8443"), Origin::new(Https, "parity.io", Some(8443))); + assert_eq!( + Origin::parse("chrome-extension://124.0.0.1"), + Origin::new(Custom("chrome-extension".into()), "124.0.0.1", None) + ); + assert_eq!(Origin::parse("parity.io/somepath"), Origin::new(Http, "parity.io", None)); + assert_eq!(Origin::parse("127.0.0.1:8545/somepath"), Origin::new(Http, "127.0.0.1", Some(8545))); + } + + #[test] + fn should_not_allow_partially_matching_origin() { + // given + let origin1 = Origin::parse("http://subdomain.somedomain.io"); + let origin2 = Origin::parse("http://somedomain.io:8080"); + let host = Host::parse("http://somedomain.io"); + + let origin1 = Some(&*origin1); + let origin2 = Some(&*origin2); + let host = Some(&*host); + + // when + let res1 = get_cors_allow_origin(origin1, host, &Some(vec![])); + let res2 = get_cors_allow_origin(origin2, host, &Some(vec![])); + + // then + assert_eq!(res1, AllowCors::Invalid); + assert_eq!(res2, AllowCors::Invalid); + } + + #[test] + fn should_allow_origins_that_matches_hosts() { + // given + let origin = Origin::parse("http://127.0.0.1:8080"); + let host = Host::parse("http://127.0.0.1:8080"); + + let origin = Some(&*origin); + let host = Some(&*host); + + // when + let res = get_cors_allow_origin(origin, host, &None); + + // then + assert_eq!(res, AllowCors::NotRequired); + } + + #[test] + fn should_return_none_when_there_are_no_cors_domains_and_no_origin() { + // given + let origin = None; + let host = None; + + // when + let res = get_cors_allow_origin(origin, host, &None); + + // then + assert_eq!(res, AllowCors::NotRequired); + } + + #[test] + fn should_return_domain_when_all_are_allowed() { + // given + let origin = Some("parity.io"); + let host = None; + + // when + let res = get_cors_allow_origin(origin, host, &None); + + // then + assert_eq!(res, AllowCors::Ok("parity.io".into())); + } + + #[test] + fn should_return_none_for_empty_origin() { + // given + let origin = None; + let host = None; + + // when + let res = get_cors_allow_origin( + origin, + host, + &Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]), + ); + + // then + assert_eq!(res, AllowCors::NotRequired); + } + + #[test] + fn should_return_none_for_empty_list() { + // given + let origin = None; + let host = None; + + // when + let res = get_cors_allow_origin(origin, host, &Some(Vec::new())); + + // then + assert_eq!(res, AllowCors::NotRequired); + } + + #[test] + fn should_return_none_for_not_matching_origin() { + // given + let origin = Some("http://parity.io".into()); + let host = None; + + // when + let res = get_cors_allow_origin( + origin, + host, + &Some(vec![AccessControlAllowOrigin::Value("http://ethereum.org".into())]), + ); + + // then + assert_eq!(res, AllowCors::Invalid); + } + + #[test] + fn should_return_specific_origin_if_we_allow_any() { + // given + let origin = Some("http://parity.io".into()); + let host = None; + + // when + let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Any])); + + // then + assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))); + } + + #[test] + fn should_return_none_if_origin_is_not_defined() { + // given + let origin = None; + let host = None; + + // when + let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null])); + + // then + assert_eq!(res, AllowCors::NotRequired); + } + + #[test] + fn should_return_null_if_origin_is_null() { + // given + let origin = Some("null".into()); + let host = None; + + // when + let res = get_cors_allow_origin(origin, host, &Some(vec![AccessControlAllowOrigin::Null])); + + // then + assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Null)); + } + + #[test] + fn should_return_specific_origin_if_there_is_a_match() { + // given + let origin = Some("http://parity.io".into()); + let host = None; + + // when + let res = get_cors_allow_origin( + origin, + host, + &Some(vec![ + AccessControlAllowOrigin::Value("http://ethereum.org".into()), + AccessControlAllowOrigin::Value("http://parity.io".into()), + ]), + ); + + // then + assert_eq!(res, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))); + } + + #[test] + fn should_support_wildcards() { + // given + let origin1 = Some("http://parity.io".into()); + let origin2 = Some("http://parity.iot".into()); + let origin3 = Some("chrome-extension://test".into()); + let host = None; + let allowed = Some(vec![ + AccessControlAllowOrigin::Value("http://*.io".into()), + AccessControlAllowOrigin::Value("chrome-extension://*".into()), + ]); + + // when + let res1 = get_cors_allow_origin(origin1, host, &allowed); + let res2 = get_cors_allow_origin(origin2, host, &allowed); + let res3 = get_cors_allow_origin(origin3, host, &allowed); + + // then + assert_eq!(res1, AllowCors::Ok(AccessControlAllowOrigin::Value("http://parity.io".into()))); + assert_eq!(res2, AllowCors::Invalid); + assert_eq!(res3, AllowCors::Ok(AccessControlAllowOrigin::Value("chrome-extension://test".into()))); + } + + #[test] + fn should_return_invalid_if_header_not_allowed() { + // given + let cors_allow_headers = AccessControlAllowHeaders::Only(vec!["x-allowed".to_owned()]); + let headers = vec!["Access-Control-Request-Headers"]; + let requested = vec!["x-not-allowed"]; + + // when + let res = get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| x); + + // then + assert_eq!(res, AllowCors::Invalid); + } + + #[test] + fn should_return_valid_if_header_allowed() { + // given + let allowed = vec!["x-allowed".to_owned()]; + let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone()); + let headers = vec!["Access-Control-Request-Headers"]; + let requested = vec!["x-allowed"]; + + // when + let res = + get_cors_allow_headers(headers.iter(), requested.iter(), &cors_allow_headers.into(), |x| (*x).to_owned()); + + // then + let allowed = vec!["x-allowed".to_owned()]; + assert_eq!(res, AllowCors::Ok(allowed)); + } + + #[test] + fn should_return_no_allowed_headers_if_none_in_request() { + // given + let allowed = vec!["x-allowed".to_owned()]; + let cors_allow_headers = AccessControlAllowHeaders::Only(allowed.clone()); + let headers: Vec = vec![]; + + // when + let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers, |x| x); + + // then + assert_eq!(res, AllowCors::NotRequired); + } + + #[test] + fn should_return_not_required_if_any_header_allowed() { + // given + let cors_allow_headers = AccessControlAllowHeaders::Any; + let headers: Vec = vec![]; + + // when + let res = get_cors_allow_headers(headers.iter(), iter::empty(), &cors_allow_headers.into(), |x| x); + + // then + assert_eq!(res, AllowCors::NotRequired); + } } diff --git a/src/http/server_utils/hosts.rs b/src/http/server_utils/hosts.rs index 9f34bb1be5..c91a9fd881 100644 --- a/src/http/server_utils/hosts.rs +++ b/src/http/server_utils/hosts.rs @@ -35,247 +35,218 @@ const SPLIT_PROOF: &str = "split always returns non-empty iterator."; /// Port pattern #[derive(Clone, Hash, PartialEq, Eq, Debug)] pub enum Port { - /// No port specified (default port) - None, - /// Port specified as a wildcard pattern - Pattern(String), - /// Fixed numeric port - Fixed(u16), + /// No port specified (default port) + None, + /// Port specified as a wildcard pattern + Pattern(String), + /// Fixed numeric port + Fixed(u16), } impl From> for Port { - fn from(opt: Option) -> Self { - match opt { - Some(port) => Port::Fixed(port), - None => Port::None, - } - } + fn from(opt: Option) -> Self { + match opt { + Some(port) => Port::Fixed(port), + None => Port::None, + } + } } impl From for Port { - fn from(port: u16) -> Port { - Port::Fixed(port) - } + fn from(port: u16) -> Port { + Port::Fixed(port) + } } /// Host type #[derive(Clone, Hash, PartialEq, Eq, Debug)] pub struct Host { - hostname: String, - port: Port, - as_string: String, - matcher: Matcher, + hostname: String, + port: Port, + as_string: String, + matcher: Matcher, } impl> From for Host { - fn from(string: T) -> Self { - Host::parse(string.as_ref()) - } + fn from(string: T) -> Self { + Host::parse(string.as_ref()) + } } impl Host { - /// Creates a new `Host` given hostname and port number. - pub fn new>(hostname: &str, port: T) -> Self { - let port = port.into(); - let hostname = Self::pre_process(hostname); - let string = Self::to_string(&hostname, &port); - let matcher = Matcher::new(&string); - - Host { - hostname, - port, - as_string: string, - matcher, - } - } - - /// Attempts to parse given string as a `Host`. - /// NOTE: This method always succeeds and falls back to sensible defaults. - pub fn parse(hostname: &str) -> Self { - let hostname = Self::pre_process(hostname); - let mut hostname = hostname.split(':'); - let host = hostname.next().expect(SPLIT_PROOF); - let port = match hostname.next() { - None => Port::None, - Some(port) => match port.parse::().ok() { - Some(num) => Port::Fixed(num), - None => Port::Pattern(port.into()), - }, - }; - - Host::new(host, port) - } - - fn pre_process(host: &str) -> String { - // Remove possible protocol definition - let mut it = host.split("://"); - let protocol = it.next().expect(SPLIT_PROOF); - let host = match it.next() { - Some(data) => data, - None => protocol, - }; - - let mut it = host.split('/'); - it.next().expect(SPLIT_PROOF).to_lowercase() - } - - fn to_string(hostname: &str, port: &Port) -> String { - format!( - "{}{}", - hostname, - match *port { - Port::Fixed(port) => format!(":{}", port), - Port::Pattern(ref port) => format!(":{}", port), - Port::None => "".into(), - }, - ) - } + /// Creates a new `Host` given hostname and port number. + pub fn new>(hostname: &str, port: T) -> Self { + let port = port.into(); + let hostname = Self::pre_process(hostname); + let string = Self::to_string(&hostname, &port); + let matcher = Matcher::new(&string); + + Host { hostname, port, as_string: string, matcher } + } + + /// Attempts to parse given string as a `Host`. + /// NOTE: This method always succeeds and falls back to sensible defaults. + pub fn parse(hostname: &str) -> Self { + let hostname = Self::pre_process(hostname); + let mut hostname = hostname.split(':'); + let host = hostname.next().expect(SPLIT_PROOF); + let port = match hostname.next() { + None => Port::None, + Some(port) => match port.parse::().ok() { + Some(num) => Port::Fixed(num), + None => Port::Pattern(port.into()), + }, + }; + + Host::new(host, port) + } + + fn pre_process(host: &str) -> String { + // Remove possible protocol definition + let mut it = host.split("://"); + let protocol = it.next().expect(SPLIT_PROOF); + let host = match it.next() { + Some(data) => data, + None => protocol, + }; + + let mut it = host.split('/'); + it.next().expect(SPLIT_PROOF).to_lowercase() + } + + fn to_string(hostname: &str, port: &Port) -> String { + format!( + "{}{}", + hostname, + match *port { + Port::Fixed(port) => format!(":{}", port), + Port::Pattern(ref port) => format!(":{}", port), + Port::None => "".into(), + }, + ) + } } impl Pattern for Host { - fn matches>(&self, other: T) -> bool { - self.matcher.matches(other) - } + fn matches>(&self, other: T) -> bool { + self.matcher.matches(other) + } } impl ::std::ops::Deref for Host { - type Target = str; - fn deref(&self) -> &Self::Target { - &self.as_string - } + type Target = str; + fn deref(&self) -> &Self::Target { + &self.as_string + } } /// Specifies if domains should be validated. #[derive(Clone, Debug, PartialEq, Eq)] pub enum DomainsValidation { - /// Allow only domains on the list. - AllowOnly(Vec), - /// Disable domains validation completely. - Disabled, + /// Allow only domains on the list. + AllowOnly(Vec), + /// Disable domains validation completely. + Disabled, } impl Into>> for DomainsValidation { - fn into(self) -> Option> { - use self::DomainsValidation::*; - match self { - AllowOnly(list) => Some(list), - Disabled => None, - } - } + fn into(self) -> Option> { + use self::DomainsValidation::*; + match self { + AllowOnly(list) => Some(list), + Disabled => None, + } + } } impl From>> for DomainsValidation { - fn from(other: Option>) -> Self { - match other { - Some(list) => DomainsValidation::AllowOnly(list), - None => DomainsValidation::Disabled, - } - } + fn from(other: Option>) -> Self { + match other { + Some(list) => DomainsValidation::AllowOnly(list), + None => DomainsValidation::Disabled, + } + } } /// Returns `true` when `Host` header is whitelisted in `allow_hosts`. pub fn is_host_valid(host: Option<&str>, allow_hosts: &AllowHosts) -> bool { - match host { - None => false, - Some(ref host) => match allow_hosts { - AllowHosts::Any => true, - AllowHosts::Only(allow_hosts) => allow_hosts.iter().any(|h| h.matches(host)), - }, - } + match host { + None => false, + Some(ref host) => match allow_hosts { + AllowHosts::Any => true, + AllowHosts::Only(allow_hosts) => allow_hosts.iter().any(|h| h.matches(host)), + }, + } } /// Updates given list of hosts with the address. pub fn update(hosts: Option>, address: &SocketAddr) -> Option> { - hosts.map(|current_hosts| { - let mut new_hosts = current_hosts.into_iter().collect::>(); - let address = address.to_string(); - new_hosts.insert(address.clone().into()); - new_hosts.insert(address.replace("127.0.0.1", "localhost").into()); - new_hosts.into_iter().collect() - }) + hosts.map(|current_hosts| { + let mut new_hosts = current_hosts.into_iter().collect::>(); + let address = address.to_string(); + new_hosts.insert(address.clone().into()); + new_hosts.insert(address.replace("127.0.0.1", "localhost").into()); + new_hosts.into_iter().collect() + }) } /// Allowed hosts for http header 'host' #[derive(Clone)] pub enum AllowHosts { - /// Allow requests from any host - Any, - /// Allow only a selection of specific hosts - Only(Vec), + /// Allow requests from any host + Any, + /// Allow only a selection of specific hosts + Only(Vec), } #[cfg(test)] mod tests { - use super::{is_host_valid, AllowHosts, Host}; - - #[test] - fn should_parse_host() { - assert_eq!( - Host::parse("http://parity.io"), - Host::new("parity.io", None) - ); - assert_eq!( - Host::parse("https://parity.io:8443"), - Host::new("parity.io", Some(8443)) - ); - assert_eq!( - Host::parse("chrome-extension://124.0.0.1"), - Host::new("124.0.0.1", None) - ); - assert_eq!( - Host::parse("parity.io/somepath"), - Host::new("parity.io", None) - ); - assert_eq!( - Host::parse("127.0.0.1:8545/somepath"), - Host::new("127.0.0.1", Some(8545)) - ); - } - - #[test] - fn should_reject_when_there_is_no_header() { - let valid = is_host_valid(None, &AllowHosts::Any); - assert_eq!(valid, false); - let valid = is_host_valid(None, &AllowHosts::Only(vec![])); - assert_eq!(valid, false); - } - - #[test] - fn should_reject_when_validation_is_disabled() { - let valid = is_host_valid(Some("any"), &AllowHosts::Any); - assert_eq!(valid, true); - } - - #[test] - fn should_reject_if_header_not_on_the_list() { - let valid = is_host_valid(Some("parity.io"), &AllowHosts::Only(vec![])); - assert_eq!(valid, false); - } - - #[test] - fn should_accept_if_on_the_list() { - let valid = is_host_valid( - Some("parity.io"), - &AllowHosts::Only(vec!["parity.io".into()]), - ); - assert_eq!(valid, true); - } - - #[test] - fn should_accept_if_on_the_list_with_port() { - let valid = is_host_valid( - Some("parity.io:443"), - &AllowHosts::Only(vec!["parity.io:443".into()]), - ); - assert_eq!(valid, true); - } - - #[test] - fn should_support_wildcards() { - let valid = is_host_valid( - Some("parity.web3.site:8180"), - &AllowHosts::Only(vec!["*.web3.site:*".into()]), - ); - assert_eq!(valid, true); - } + use super::{is_host_valid, AllowHosts, Host}; + + #[test] + fn should_parse_host() { + assert_eq!(Host::parse("http://parity.io"), Host::new("parity.io", None)); + assert_eq!(Host::parse("https://parity.io:8443"), Host::new("parity.io", Some(8443))); + assert_eq!(Host::parse("chrome-extension://124.0.0.1"), Host::new("124.0.0.1", None)); + assert_eq!(Host::parse("parity.io/somepath"), Host::new("parity.io", None)); + assert_eq!(Host::parse("127.0.0.1:8545/somepath"), Host::new("127.0.0.1", Some(8545))); + } + + #[test] + fn should_reject_when_there_is_no_header() { + let valid = is_host_valid(None, &AllowHosts::Any); + assert_eq!(valid, false); + let valid = is_host_valid(None, &AllowHosts::Only(vec![])); + assert_eq!(valid, false); + } + + #[test] + fn should_reject_when_validation_is_disabled() { + let valid = is_host_valid(Some("any"), &AllowHosts::Any); + assert_eq!(valid, true); + } + + #[test] + fn should_reject_if_header_not_on_the_list() { + let valid = is_host_valid(Some("parity.io"), &AllowHosts::Only(vec![])); + assert_eq!(valid, false); + } + + #[test] + fn should_accept_if_on_the_list() { + let valid = is_host_valid(Some("parity.io"), &AllowHosts::Only(vec!["parity.io".into()])); + assert_eq!(valid, true); + } + + #[test] + fn should_accept_if_on_the_list_with_port() { + let valid = is_host_valid(Some("parity.io:443"), &AllowHosts::Only(vec!["parity.io:443".into()])); + assert_eq!(valid, true); + } + + #[test] + fn should_support_wildcards() { + let valid = is_host_valid(Some("parity.web3.site:8180"), &AllowHosts::Only(vec!["*.web3.site:*".into()])); + assert_eq!(valid, true); + } } diff --git a/src/http/server_utils/matcher.rs b/src/http/server_utils/matcher.rs index 5bb594bd3d..85c30f6941 100644 --- a/src/http/server_utils/matcher.rs +++ b/src/http/server_utils/matcher.rs @@ -30,54 +30,54 @@ use std::{fmt, hash}; /// Pattern that can be matched to string. pub trait Pattern { - /// Returns true if given string matches the pattern. - fn matches>(&self, other: T) -> bool; + /// Returns true if given string matches the pattern. + fn matches>(&self, other: T) -> bool; } #[derive(Clone)] pub struct Matcher(Option, String); impl Matcher { - pub fn new(string: &str) -> Matcher { - Matcher( - GlobBuilder::new(string) - .case_insensitive(true) - .build() - .map(|g| g.compile_matcher()) - .map_err(|e| warn!("Invalid glob pattern for {}: {:?}", string, e)) - .ok(), - string.into(), - ) - } + pub fn new(string: &str) -> Matcher { + Matcher( + GlobBuilder::new(string) + .case_insensitive(true) + .build() + .map(|g| g.compile_matcher()) + .map_err(|e| warn!("Invalid glob pattern for {}: {:?}", string, e)) + .ok(), + string.into(), + ) + } } impl Pattern for Matcher { - fn matches>(&self, other: T) -> bool { - let s = other.as_ref(); - match self.0 { - Some(ref matcher) => matcher.is_match(s), - None => self.1.eq_ignore_ascii_case(s), - } - } + fn matches>(&self, other: T) -> bool { + let s = other.as_ref(); + match self.0 { + Some(ref matcher) => matcher.is_match(s), + None => self.1.eq_ignore_ascii_case(s), + } + } } impl fmt::Debug for Matcher { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - write!(fmt, "{:?} ({})", self.1, self.0.is_some()) - } + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "{:?} ({})", self.1, self.0.is_some()) + } } impl hash::Hash for Matcher { - fn hash(&self, state: &mut H) - where - H: hash::Hasher, - { - self.1.hash(state) - } + fn hash(&self, state: &mut H) + where + H: hash::Hasher, + { + self.1.hash(state) + } } impl Eq for Matcher {} impl PartialEq for Matcher { - fn eq(&self, other: &Matcher) -> bool { - self.1.eq(&other.1) - } + fn eq(&self, other: &Matcher) -> bool { + self.1.eq(&other.1) + } } diff --git a/src/http/server_utils/utils.rs b/src/http/server_utils/utils.rs index 822b00b718..ee5dea513e 100644 --- a/src/http/server_utils/utils.rs +++ b/src/http/server_utils/utils.rs @@ -28,5 +28,5 @@ /// Extracts string value of a single header in request. pub fn read_header<'a>(req: &'a hyper::Request, header_name: &str) -> Option<&'a str> { - req.headers().get(header_name).and_then(|v| v.to_str().ok()) + req.headers().get(header_name).and_then(|v| v.to_str().ok()) } diff --git a/src/http/tests.rs b/src/http/tests.rs index 40acd6d218..80e536acf3 100644 --- a/src/http/tests.rs +++ b/src/http/tests.rs @@ -10,121 +10,103 @@ use jsonrpsee_test_utils::types::{Id, StatusCode}; use std::net::SocketAddr; async fn server(server_started_tx: Sender) { - let server = HttpServer::new("127.0.0.1:0").await.unwrap(); - let mut hello = server.register_method("say_hello".to_owned()).unwrap(); - let mut add = server.register_method("add".to_owned()).unwrap(); - let mut notif = server - .register_notification("notif".to_owned(), false) - .unwrap(); - server_started_tx.send(*server.local_addr()).unwrap(); - - loop { - let hello_fut = async { - let handle = hello.next().await; - log::debug!("server respond to hello"); - handle - .respond(Ok(JsonValue::String("hello".to_owned()))) - .await; - } - .fuse(); - - let add_fut = async { - let handle = add.next().await; - let params: Vec = handle.params().clone().parse().unwrap(); - let sum: u64 = params.iter().sum(); - handle.respond(Ok(JsonValue::Number(sum.into()))).await; - } - .fuse(); - - let notif_fut = async { - let params = notif.next().await; - println!("received notification: say_hello params[{:?}]", params); - } - .fuse(); - - pin_mut!(hello_fut, add_fut, notif_fut); - select! { - say_hello = hello_fut => (), - add = add_fut => (), - notif = notif_fut => (), - complete => (), - }; - } + let server = HttpServer::new("127.0.0.1:0").await.unwrap(); + let mut hello = server.register_method("say_hello".to_owned()).unwrap(); + let mut add = server.register_method("add".to_owned()).unwrap(); + let mut notif = server.register_notification("notif".to_owned(), false).unwrap(); + server_started_tx.send(*server.local_addr()).unwrap(); + + loop { + let hello_fut = async { + let handle = hello.next().await; + log::debug!("server respond to hello"); + handle.respond(Ok(JsonValue::String("hello".to_owned()))).await; + } + .fuse(); + + let add_fut = async { + let handle = add.next().await; + let params: Vec = handle.params().clone().parse().unwrap(); + let sum: u64 = params.iter().sum(); + handle.respond(Ok(JsonValue::Number(sum.into()))).await; + } + .fuse(); + + let notif_fut = async { + let params = notif.next().await; + println!("received notification: say_hello params[{:?}]", params); + } + .fuse(); + + pin_mut!(hello_fut, add_fut, notif_fut); + select! { + say_hello = hello_fut => (), + add = add_fut => (), + notif = notif_fut => (), + complete => (), + }; + } } #[tokio::test] async fn single_method_call_works() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - let uri = to_http_uri(server_addr); - - for i in 0..10 { - let req = format!(r#"{{"jsonrpc":"2.0","method":"say_hello","id":{}}}"#, i); - let response = http_request(req.into(), uri.clone()).await.unwrap(); - assert_eq!(response.status, StatusCode::OK); - assert_eq!( - response.body, - ok_response(JsonValue::String("hello".to_owned()), Id::Num(i)) - ); - } + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + let uri = to_http_uri(server_addr); + + for i in 0..10 { + let req = format!(r#"{{"jsonrpc":"2.0","method":"say_hello","id":{}}}"#, i); + let response = http_request(req.into(), uri.clone()).await.unwrap(); + assert_eq!(response.status, StatusCode::OK); + assert_eq!(response.body, ok_response(JsonValue::String("hello".to_owned()), Id::Num(i))); + } } #[tokio::test] async fn single_method_call_with_params() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - - let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#; - let response = http_request(req.into(), to_http_uri(server_addr)) - .await - .unwrap(); - assert_eq!(response.status, StatusCode::OK); - assert_eq!( - response.body, - ok_response(JsonValue::Number(3.into()), Id::Num(1)) - ); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#; + let response = http_request(req.into(), to_http_uri(server_addr)).await.unwrap(); + assert_eq!(response.status, StatusCode::OK); + assert_eq!(response.body, ok_response(JsonValue::Number(3.into()), Id::Num(1))); } #[tokio::test] async fn should_return_method_not_found() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - - let req = r#"{"jsonrpc":"2.0","method":"bar","id":"foo"}"#; - let response = http_request(req.into(), to_http_uri(server_addr)) - .await - .unwrap(); - assert_eq!(response.status, StatusCode::OK); - assert_eq!(response.body, method_not_found(Id::Str("foo".into()))); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"bar","id":"foo"}"#; + let response = http_request(req.into(), to_http_uri(server_addr)).await.unwrap(); + assert_eq!(response.status, StatusCode::OK); + assert_eq!(response.body, method_not_found(Id::Str("foo".into()))); } #[tokio::test] async fn invalid_json_id_missing_value() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - - let req = r#"{"jsonrpc":"2.0","method":"say_hello","id"}"#; - let response = http_request(req.into(), to_http_uri(server_addr)) - .await - .unwrap(); - // If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), it MUST be Null. - assert_eq!(response.body, parse_error(Id::Null)); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"say_hello","id"}"#; + let response = http_request(req.into(), to_http_uri(server_addr)).await.unwrap(); + // If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), it MUST be Null. + assert_eq!(response.body, parse_error(Id::Null)); } #[tokio::test] async fn invalid_request_object() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - - let req = r#"{"jsonrpc":"2.0","method":"bar","id":1,"is_not_request_object":1}"#; - let response = http_request(req.into(), to_http_uri(server_addr)) - .await - .unwrap(); - assert_eq!(response.status, StatusCode::OK); - assert_eq!(response.body, invalid_request(Id::Num(1))); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"bar","id":1,"is_not_request_object":1}"#; + let response = http_request(req.into(), to_http_uri(server_addr)).await.unwrap(); + assert_eq!(response.status, StatusCode::OK); + assert_eq!(response.body, invalid_request(Id::Num(1))); } diff --git a/src/http/transport/background.rs b/src/http/transport/background.rs index b8c0a14927..d26bc0be3d 100644 --- a/src/http/transport/background.rs +++ b/src/http/transport/background.rs @@ -34,95 +34,84 @@ use std::{error, io, net::SocketAddr, thread}; /// Background thread that serves HTTP requests. pub(super) struct BackgroundHttp { - /// Receiver for requests coming from the background thread. - rx: stream::Fuse>, + /// Receiver for requests coming from the background thread. + rx: stream::Fuse>, } /// Request generated from the background thread. pub(super) struct Request { - /// Sender for the body of the response to send on the network. - pub send_back: oneshot::Sender>, - /// The JSON body that was sent by the client. - pub request: common::Request, + /// Sender for the body of the response to send on the network. + pub send_back: oneshot::Sender>, + /// The JSON body that was sent by the client. + pub request: common::Request, } impl BackgroundHttp { - /// Tries to create an HTTP server listening on the given address and start a background - /// thread. - /// - /// In addition to `Self`, also returns the local address the server ends up listening on, - /// which might be different than the one passed as parameter. - pub async fn bind( - addr: &SocketAddr, - ) -> Result<(BackgroundHttp, SocketAddr), Box> { - Self::bind_with_acl(addr, AccessControl::default()).await - } + /// Tries to create an HTTP server listening on the given address and start a background + /// thread. + /// + /// In addition to `Self`, also returns the local address the server ends up listening on, + /// which might be different than the one passed as parameter. + pub async fn bind(addr: &SocketAddr) -> Result<(BackgroundHttp, SocketAddr), Box> { + Self::bind_with_acl(addr, AccessControl::default()).await + } - pub async fn bind_with_acl( - addr: &SocketAddr, - access_control: AccessControl, - ) -> Result<(BackgroundHttp, SocketAddr), Box> { - let (tx, rx) = mpsc::channel(4); + pub async fn bind_with_acl( + addr: &SocketAddr, + access_control: AccessControl, + ) -> Result<(BackgroundHttp, SocketAddr), Box> { + let (tx, rx) = mpsc::channel(4); - let make_service = make_service_fn(move |_| { - let tx = tx.clone(); - let access_control = access_control.clone(); - async move { - Ok::<_, Error>(service_fn(move |req| { - let mut tx = tx.clone(); - let access_control = access_control.clone(); - async move { Ok::<_, Error>(process_request(req, &mut tx, &access_control).await) } - })) - } - }); + let make_service = make_service_fn(move |_| { + let tx = tx.clone(); + let access_control = access_control.clone(); + async move { + Ok::<_, Error>(service_fn(move |req| { + let mut tx = tx.clone(); + let access_control = access_control.clone(); + async move { Ok::<_, Error>(process_request(req, &mut tx, &access_control).await) } + })) + } + }); - let (addr_tx, addr_rx) = oneshot::channel(); - let addr = *addr; + let (addr_tx, addr_rx) = oneshot::channel(); + let addr = *addr; - // Because hyper can only be polled through tokio, we spawn it in a background thread. - thread::Builder::new() - .name("jsonrpsee-hyper-server".to_string()) - .spawn(move || { - let mut runtime = match tokio::runtime::Builder::new() - .basic_scheduler() - .enable_all() - .build() - { - Ok(r) => r, - Err(err) => { - log::error!( - "Failed to initialize tokio runtime in HTTP JSON-RPC server: {}", - err - ); - return; - } - }; + // Because hyper can only be polled through tokio, we spawn it in a background thread. + thread::Builder::new().name("jsonrpsee-hyper-server".to_string()).spawn(move || { + let mut runtime = match tokio::runtime::Builder::new().basic_scheduler().enable_all().build() { + Ok(r) => r, + Err(err) => { + log::error!("Failed to initialize tokio runtime in HTTP JSON-RPC server: {}", err); + return; + } + }; - runtime.block_on(async move { - match hyper::Server::try_bind(&addr) { - Ok(builder) => { - let server = builder.serve(make_service); - let _ = addr_tx.send(Ok(server.local_addr())); - if let Err(err) = server.await { - log::error!("HTTP JSON-RPC server closed with an error: {}", err); - } - } - Err(err) => { - log::error!("Failed to bind to address {}: {}", addr, err); - let _ = addr_tx.send(Err(err)); - } - }; - }); - })?; + runtime.block_on(async move { + match hyper::Server::try_bind(&addr) { + Ok(builder) => { + let server = builder.serve(make_service); + let _ = addr_tx.send(Ok(server.local_addr())); + if let Err(err) = server.await { + log::error!("HTTP JSON-RPC server closed with an error: {}", err); + } + } + Err(err) => { + log::error!("Failed to bind to address {}: {}", addr, err); + let _ = addr_tx.send(Err(err)); + } + }; + }); + })?; - let local_addr = addr_rx.await??; - Ok((BackgroundHttp { rx: rx.fuse() }, local_addr)) - } + let local_addr = addr_rx.await??; + Ok((BackgroundHttp { rx: rx.fuse() }, local_addr)) + } - /// Returns the next request, or an error if the background thread has unexpectedly closed. - pub async fn next(&mut self) -> Result { - self.rx.next().await.ok_or(()) - } + /// Returns the next request, or an error if the background thread has unexpectedly closed. + pub async fn next(&mut self) -> Result { + self.rx.next().await.ok_or(()) + } } /// Process an HTTP request and sends back a response. @@ -132,164 +121,153 @@ impl BackgroundHttp { /// In order to process JSON-RPC requests, it has access to `fg_process_tx`. Objects sent on this /// channel will be dispatched to the user. async fn process_request( - request: hyper::Request, - fg_process_tx: &mut mpsc::Sender, - access_control: &AccessControl, + request: hyper::Request, + fg_process_tx: &mut mpsc::Sender, + access_control: &AccessControl, ) -> hyper::Response { - log::debug!(target: "jsonrpc-http-transport-server", "Recevied request={:?}", request); + log::debug!(target: "jsonrpc-http-transport-server", "Recevied request={:?}", request); - // Process access control - if access_control.deny_host(&request) { - return response::host_not_allowed(); - } - if access_control.deny_cors_origin(&request) { - return response::invalid_allow_origin(); - } - if access_control.deny_cors_header(&request) { - return response::invalid_allow_headers(); - } + // Process access control + if access_control.deny_host(&request) { + return response::host_not_allowed(); + } + if access_control.deny_cors_origin(&request) { + return response::invalid_allow_origin(); + } + if access_control.deny_cors_header(&request) { + return response::invalid_allow_headers(); + } - /* - // Read metadata - let metadata = self.jsonrpc_handler.extractor.read_metadata(&request); - */ + /* + // Read metadata + let metadata = self.jsonrpc_handler.extractor.read_metadata(&request); + */ - // Proceed - match *request.method() { - // Validate the ContentType header - // to prevent Cross-Origin XHRs with text/plain - hyper::Method::POST if is_json(request.headers().get("content-type")) => { - let uri = //if self.rest_api != RestApi::Disabled { + // Proceed + match *request.method() { + // Validate the ContentType header + // to prevent Cross-Origin XHRs with text/plain + hyper::Method::POST if is_json(request.headers().get("content-type")) => { + let uri = //if self.rest_api != RestApi::Disabled { Some(request.uri().clone()) /*} else { None }*/; - let json_body = match body_to_request(request.into_body()).await { - Ok(b) => b, - Err(e) => match (e.kind(), e.into_inner()) { - (io::ErrorKind::InvalidData, _) => return response::parse_error(), - (io::ErrorKind::UnexpectedEof, _) => return response::parse_error(), - (_, Some(inner)) => return response::internal_error(inner.to_string()), - (kind, None) => return response::internal_error(format!("{:?}", kind)), - }, - }; + let json_body = match body_to_request(request.into_body()).await { + Ok(b) => b, + Err(e) => match (e.kind(), e.into_inner()) { + (io::ErrorKind::InvalidData, _) => return response::parse_error(), + (io::ErrorKind::UnexpectedEof, _) => return response::parse_error(), + (_, Some(inner)) => return response::internal_error(inner.to_string()), + (kind, None) => return response::internal_error(format!("{:?}", kind)), + }, + }; - log::debug!(target: "http-server transport", "received request={:?}", json_body); + log::debug!(target: "http-server transport", "received request={:?}", json_body); - let (tx, rx) = oneshot::channel(); - let user_facing_rq = Request { - send_back: tx, - request: json_body, - }; - if fg_process_tx.send(user_facing_rq).await.is_err() { - return response::internal_error("JSON requests processing channel has shut down"); - } - match rx.await { - Ok(response) => response, - Err(_) => { - return response::internal_error("JSON request send back channel has shut down") - } - } - } - /*Method::POST if /*self.rest_api == RestApi::Unsecure &&*/ request.uri().path().split('/').count() > 2 => { - RpcHandlerState::ProcessRest { - metadata, - uri: request.uri().clone(), - } - } - // Just return error for unsupported content type - Method::POST => response::unsupported_content_type(), - // Don't validate content type on options - Method::OPTIONS => response::empty(), - // Respond to health API request if there is one configured. - Method::GET if self.health_api.as_ref().map(|x| &*x.0) == Some(request.uri().path()) => { - RpcHandlerState::ProcessHealth { - metadata, - method: self - .health_api - .as_ref() - .map(|x| x.1.clone()) - .expect("Health api is defined since the URI matched."), - } - }*/ - // Disallow other methods. - _ => response::method_not_allowed(), - } + let (tx, rx) = oneshot::channel(); + let user_facing_rq = Request { send_back: tx, request: json_body }; + if fg_process_tx.send(user_facing_rq).await.is_err() { + return response::internal_error("JSON requests processing channel has shut down"); + } + match rx.await { + Ok(response) => response, + Err(_) => return response::internal_error("JSON request send back channel has shut down"), + } + } + /*Method::POST if /*self.rest_api == RestApi::Unsecure &&*/ request.uri().path().split('/').count() > 2 => { + RpcHandlerState::ProcessRest { + metadata, + uri: request.uri().clone(), + } + } + // Just return error for unsupported content type + Method::POST => response::unsupported_content_type(), + // Don't validate content type on options + Method::OPTIONS => response::empty(), + // Respond to health API request if there is one configured. + Method::GET if self.health_api.as_ref().map(|x| &*x.0) == Some(request.uri().path()) => { + RpcHandlerState::ProcessHealth { + metadata, + method: self + .health_api + .as_ref() + .map(|x| x.1.clone()) + .expect("Health api is defined since the URI matched."), + } + }*/ + // Disallow other methods. + _ => response::method_not_allowed(), + } } /// Returns true if the `content_type` header indicates a valid JSON message. fn is_json(content_type: Option<&hyper::header::HeaderValue>) -> bool { - match content_type.and_then(|val| val.to_str().ok()) { - Some(ref content) - if content.eq_ignore_ascii_case("application/json") - || content.eq_ignore_ascii_case("application/json; charset=utf-8") - || content.eq_ignore_ascii_case("application/json;charset=utf-8") => - { - true - } - _ => false, - } + match content_type.and_then(|val| val.to_str().ok()) { + Some(ref content) + if content.eq_ignore_ascii_case("application/json") + || content.eq_ignore_ascii_case("application/json; charset=utf-8") + || content.eq_ignore_ascii_case("application/json;charset=utf-8") => + { + true + } + _ => false, + } } /// Converts a `hyper` body into a structured JSON object. /// /// Enforces a size limit on the body. async fn body_to_request(mut body: hyper::Body) -> Result { - let mut json_body = Vec::new(); - while let Some(chunk) = body.next().await { - let chunk = match chunk { - Ok(c) => c, - Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err.to_string())), // TODO: - }; - json_body.extend_from_slice(&chunk); - if json_body.len() >= 16384 { - // TODO: some limit - return Err(io::Error::new(io::ErrorKind::Other, "request too large")); - } - } + let mut json_body = Vec::new(); + while let Some(chunk) = body.next().await { + let chunk = match chunk { + Ok(c) => c, + Err(err) => return Err(io::Error::new(io::ErrorKind::Other, err.to_string())), // TODO: + }; + json_body.extend_from_slice(&chunk); + if json_body.len() >= 16384 { + // TODO: some limit + return Err(io::Error::new(io::ErrorKind::Other, "request too large")); + } + } - Ok(serde_json::from_slice(&json_body)?) + Ok(serde_json::from_slice(&json_body)?) } #[cfg(test)] mod tests { - use super::body_to_request; + use super::body_to_request; - #[test] - fn body_to_request_works() { - let s = r#"[{"a":"hello"}]"#; - let expected: super::common::Request = serde_json::from_str(s).unwrap(); - let req = futures::executor::block_on(async move { - let body = hyper::Body::from(s); - body_to_request(body).await.unwrap() - }); - assert_eq!(req, expected); - } + #[test] + fn body_to_request_works() { + let s = r#"[{"a":"hello"}]"#; + let expected: super::common::Request = serde_json::from_str(s).unwrap(); + let req = futures::executor::block_on(async move { + let body = hyper::Body::from(s); + body_to_request(body).await.unwrap() + }); + assert_eq!(req, expected); + } - #[test] - fn body_to_request_size_limit_json() { - let huge_body = serde_json::to_vec( - &(0..32768) - .map(|_| serde_json::Value::from("test")) - .collect::>(), - ) - .unwrap(); + #[test] + fn body_to_request_size_limit_json() { + let huge_body = + serde_json::to_vec(&(0..32768).map(|_| serde_json::Value::from("test")).collect::>()).unwrap(); - futures::executor::block_on(async move { - let body = hyper::Body::from(huge_body); - assert!(body_to_request(body).await.is_err()); - }); - } + futures::executor::block_on(async move { + let body = hyper::Body::from(huge_body); + assert!(body_to_request(body).await.is_err()); + }); + } - #[test] - fn body_to_request_size_limit_garbage() { - let huge_body = (0..100_000) - .map(|_| rand::random::()) - .collect::>(); - futures::executor::block_on(async move { - let body = hyper::Body::from(huge_body); - assert!(body_to_request(body).await.is_err()); - }); - } + #[test] + fn body_to_request_size_limit_garbage() { + let huge_body = (0..100_000).map(|_| rand::random::()).collect::>(); + futures::executor::block_on(async move { + let body = hyper::Body::from(huge_body); + assert!(body_to_request(body).await.is_err()); + }); + } } diff --git a/src/http/transport/mod.rs b/src/http/transport/mod.rs index 9768005432..de1c906a33 100644 --- a/src/http/transport/mod.rs +++ b/src/http/transport/mod.rs @@ -39,233 +39,212 @@ pub type RequestId = u64; /// Event that the [`TransportServer`] can generate. #[derive(Debug, PartialEq)] pub enum TransportServerEvent { - /// A new request has arrived on the wire. - /// - /// This generates a new "request object" within the state of the [`TransportServer`] that is - /// identified through the returned `id`. You can then use the other methods of the - /// [`TransportServer`] trait in order to manipulate that request. - Request { - /// Identifier of the request within the state of the [`TransportServer`]. - id: T, - /// Body of the request. - request: common::Request, - }, - - /// A request has been cancelled, most likely because the client has closed the connection. - /// - /// The corresponding request is no longer valid to manipulate. - Closed(T), + /// A new request has arrived on the wire. + /// + /// This generates a new "request object" within the state of the [`TransportServer`] that is + /// identified through the returned `id`. You can then use the other methods of the + /// [`TransportServer`] trait in order to manipulate that request. + Request { + /// Identifier of the request within the state of the [`TransportServer`]. + id: T, + /// Body of the request. + request: common::Request, + }, + + /// A request has been cancelled, most likely because the client has closed the connection. + /// + /// The corresponding request is no longer valid to manipulate. + Closed(T), } /// Implementation of the [`TransportServer`](crate::transport::TransportServer) trait for HTTP. pub struct HttpTransportServer { - /// Background thread that processes HTTP requests. - background_thread: background::BackgroundHttp, + /// Background thread that processes HTTP requests. + background_thread: background::BackgroundHttp, - /// Local address of the server. - local_addr: SocketAddr, + /// Local address of the server. + local_addr: SocketAddr, - /// Next identifier to use when inserting an element in `requests`. - next_request_id: u64, + /// Next identifier to use when inserting an element in `requests`. + next_request_id: u64, - /// The identifier is lineraly increasing and is never leaked on the wire or outside of this - /// module. Therefore there is no risk of hash collision and using a `FnvHashMap` is safe. - requests: FnvHashMap>>, + /// The identifier is lineraly increasing and is never leaked on the wire or outside of this + /// module. Therefore there is no risk of hash collision and using a `FnvHashMap` is safe. + requests: FnvHashMap>>, } impl HttpTransportServer { - /// Tries to start an HTTP server that listens on the given address. - /// - /// Returns an error if we fail to start listening, which generally happens if the port is - /// already occupied. - // - // > Note: This function is `async` despite not performing any asynchronous operation. Normally - // > starting to listen on a port is an asynchronous operation, but the hyper library - // > hides this to us. In order to be future-proof, this function is async, so that we - // > might switch out to a different library later without breaking the API. - pub async fn new( - addr: &SocketAddr, - ) -> Result> { - let (background_thread, local_addr) = background::BackgroundHttp::bind(addr).await?; - - log::debug!(target: "jsonrpc-http-server", "Starting jsonrpc http server at address={:?}, local_addr={:?}", addr, local_addr); - - Ok(HttpTransportServer { - background_thread, - local_addr, - requests: Default::default(), - next_request_id: 0, - }) - } - - /// Tries to start an HTTP server that listens on the given address with an access control list. - pub async fn bind_with_acl( - addr: &SocketAddr, - access_control: AccessControl, - ) -> Result> { - let (background_thread, local_addr) = - background::BackgroundHttp::bind_with_acl(addr, access_control).await?; - - Ok(HttpTransportServer { - background_thread, - local_addr, - requests: Default::default(), - next_request_id: 0, - }) - } - - /// Returns the address we are actually listening on, which might be different from the one - /// passed as parameter. - pub fn local_addr(&self) -> &SocketAddr { - &self.local_addr - } + /// Tries to start an HTTP server that listens on the given address. + /// + /// Returns an error if we fail to start listening, which generally happens if the port is + /// already occupied. + // + // > Note: This function is `async` despite not performing any asynchronous operation. Normally + // > starting to listen on a port is an asynchronous operation, but the hyper library + // > hides this to us. In order to be future-proof, this function is async, so that we + // > might switch out to a different library later without breaking the API. + pub async fn new(addr: &SocketAddr) -> Result> { + let (background_thread, local_addr) = background::BackgroundHttp::bind(addr).await?; + + log::debug!(target: "jsonrpc-http-server", "Starting jsonrpc http server at address={:?}, local_addr={:?}", addr, local_addr); + + Ok(HttpTransportServer { background_thread, local_addr, requests: Default::default(), next_request_id: 0 }) + } + + /// Tries to start an HTTP server that listens on the given address with an access control list. + pub async fn bind_with_acl( + addr: &SocketAddr, + access_control: AccessControl, + ) -> Result> { + let (background_thread, local_addr) = background::BackgroundHttp::bind_with_acl(addr, access_control).await?; + + Ok(HttpTransportServer { background_thread, local_addr, requests: Default::default(), next_request_id: 0 }) + } + + /// Returns the address we are actually listening on, which might be different from the one + /// passed as parameter. + pub fn local_addr(&self) -> &SocketAddr { + &self.local_addr + } } // former `TransportServer trait impl` impl HttpTransportServer { - /// Returns the next event that the raw server wants to notify us. - pub fn next_request<'a>( - &'a mut self, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - let request = match self.background_thread.next().await { - Ok(r) => r, - Err(_) => loop { - log::debug!("http transport server inf loop?!"); - futures::pending!() - }, - }; - - let request_id = { - let id = self.next_request_id; - self.next_request_id = match self.next_request_id.checked_add(1) { - Some(i) => i, - None => { - log::error!("Overflow in HttpTransportServer request ID assignment"); - loop { - futures::pending!() - } - } - }; - id - }; - - self.requests.insert(request_id, request.send_back); - - // Every 128 requests, we call `shrink_to_fit` on the list for a general cleanup. - if request_id % 128 == 0 { - self.requests.shrink_to_fit(); - } - - let request = TransportServerEvent::Request { - id: request_id, - request: request.request, - }; - - log::debug!(target: "jsonrpc-http-transport-server", "received request: {:?}", request); - request - }) - } - - /// Sends back a response and destroys the request. - /// - /// You can pass `None` in order to destroy the request object without sending back anything. - /// - /// The implementation blindly sends back the response and doesn't check whether there is any - /// correspondance with the request in terms of logic. For example, `respond` will accept - /// sending back a batch of six responses even if the original request was a single - /// notification. - /// - /// > **Note**: While this method returns a `Future` that must be driven to completion, - /// > implementations must be aware that the entire requests processing logic is - /// > blocked for as long as this `Future` is pending. As an example, you shouldn't - /// > use this `Future` to send back a TCP message, because if the remote is - /// > unresponsive and the buffers full, the `Future` would then wait for a long time. - /// - pub fn finish<'a>( - &'a mut self, - request_id: &'a RequestId, - response: Option<&'a common::Response>, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - let send_back = match self.requests.remove(request_id) { - Some(rq) => rq, - None => return Err(()), - }; - - let response = match response.map(|r| serde_json::to_vec(r)) { - Some(Ok(bytes)) => hyper::Response::builder() - .status(hyper::StatusCode::OK) - .header( - "Content-Type", - hyper::header::HeaderValue::from_static("application/json; charset=utf-8"), - ) - .body(hyper::Body::from(bytes)) - .expect("Unable to parse response body for type conversion"), - Some(Err(_)) => panic!(), // TODO: no - None => { - // TODO: is that a good idea? should the param really be an Option? - hyper::Response::builder() - .status(hyper::StatusCode::NO_CONTENT) - .body(hyper::Body::empty()) - .expect("Unable to parse response body for type conversion") - } - }; - - if send_back.send(response).is_err() { - log::error!("Couldn't send back JSON-RPC response, as background task has crashed"); - } - - Ok(()) - }) - } - - /// Returns true if this implementation supports sending back data on this request without - /// closing it. - /// - /// Returns an error if the request id is invalid. - /// > **Note**: Not supported by HTTP - // - // TODO: this method is useless remove or create abstraction. - pub fn supports_resuming(&self, id: &u64) -> Result { - if self.requests.contains_key(id) { - Ok(false) - } else { - Err(()) - } - } - - /// Sends back some data on the request and keeps the request alive. - /// - /// You can continue sending data on that same request later. - /// - /// Returns an error if the request identifier is incorrect, or if the implementation doesn't - /// support that operation (see [`supports_resuming`](TransportServer::supports_resuming)). - /// - /// > **Note**: Not supported by HTTP. - // - // TODO: this method is useless remove or create abstraction. - pub fn send<'a>( - &'a mut self, - _: &'a RequestId, - _: &'a common::Response, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { Err(()) }) - } + /// Returns the next event that the raw server wants to notify us. + pub fn next_request<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let request = match self.background_thread.next().await { + Ok(r) => r, + Err(_) => loop { + log::debug!("http transport server inf loop?!"); + futures::pending!() + }, + }; + + let request_id = { + let id = self.next_request_id; + self.next_request_id = match self.next_request_id.checked_add(1) { + Some(i) => i, + None => { + log::error!("Overflow in HttpTransportServer request ID assignment"); + loop { + futures::pending!() + } + } + }; + id + }; + + self.requests.insert(request_id, request.send_back); + + // Every 128 requests, we call `shrink_to_fit` on the list for a general cleanup. + if request_id % 128 == 0 { + self.requests.shrink_to_fit(); + } + + let request = TransportServerEvent::Request { id: request_id, request: request.request }; + + log::debug!(target: "jsonrpc-http-transport-server", "received request: {:?}", request); + request + }) + } + + /// Sends back a response and destroys the request. + /// + /// You can pass `None` in order to destroy the request object without sending back anything. + /// + /// The implementation blindly sends back the response and doesn't check whether there is any + /// correspondance with the request in terms of logic. For example, `respond` will accept + /// sending back a batch of six responses even if the original request was a single + /// notification. + /// + /// > **Note**: While this method returns a `Future` that must be driven to completion, + /// > implementations must be aware that the entire requests processing logic is + /// > blocked for as long as this `Future` is pending. As an example, you shouldn't + /// > use this `Future` to send back a TCP message, because if the remote is + /// > unresponsive and the buffers full, the `Future` would then wait for a long time. + /// + pub fn finish<'a>( + &'a mut self, + request_id: &'a RequestId, + response: Option<&'a common::Response>, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + let send_back = match self.requests.remove(request_id) { + Some(rq) => rq, + None => return Err(()), + }; + + let response = match response.map(|r| serde_json::to_vec(r)) { + Some(Ok(bytes)) => hyper::Response::builder() + .status(hyper::StatusCode::OK) + .header("Content-Type", hyper::header::HeaderValue::from_static("application/json; charset=utf-8")) + .body(hyper::Body::from(bytes)) + .expect("Unable to parse response body for type conversion"), + Some(Err(_)) => panic!(), // TODO: no + None => { + // TODO: is that a good idea? should the param really be an Option? + hyper::Response::builder() + .status(hyper::StatusCode::NO_CONTENT) + .body(hyper::Body::empty()) + .expect("Unable to parse response body for type conversion") + } + }; + + if send_back.send(response).is_err() { + log::error!("Couldn't send back JSON-RPC response, as background task has crashed"); + } + + Ok(()) + }) + } + + /// Returns true if this implementation supports sending back data on this request without + /// closing it. + /// + /// Returns an error if the request id is invalid. + /// > **Note**: Not supported by HTTP + // + // TODO: this method is useless remove or create abstraction. + pub fn supports_resuming(&self, id: &u64) -> Result { + if self.requests.contains_key(id) { + Ok(false) + } else { + Err(()) + } + } + + /// Sends back some data on the request and keeps the request alive. + /// + /// You can continue sending data on that same request later. + /// + /// Returns an error if the request identifier is incorrect, or if the implementation doesn't + /// support that operation (see [`supports_resuming`](TransportServer::supports_resuming)). + /// + /// > **Note**: Not supported by HTTP. + // + // TODO: this method is useless remove or create abstraction. + pub fn send<'a>( + &'a mut self, + _: &'a RequestId, + _: &'a common::Response, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { Err(()) }) + } } #[cfg(test)] mod tests { - use super::HttpTransportServer; - - #[test] - fn error_if_port_occupied() { - futures::executor::block_on(async move { - let addr = "127.0.0.1:0".parse().unwrap(); - let server1 = HttpTransportServer::new(&addr).await.unwrap(); - assert!(HttpTransportServer::new(server1.local_addr()) - .await - .is_err()); - }); - } + use super::HttpTransportServer; + + #[test] + fn error_if_port_occupied() { + futures::executor::block_on(async move { + let addr = "127.0.0.1:0".parse().unwrap(); + let server1 = HttpTransportServer::new(&addr).await.unwrap(); + assert!(HttpTransportServer::new(server1.local_addr()).await.is_err()); + }); + } } diff --git a/src/http/transport/response.rs b/src/http/transport/response.rs index 5ca7c15b5f..112e83e62e 100644 --- a/src/http/transport/response.rs +++ b/src/http/transport/response.rs @@ -28,68 +28,58 @@ /// Create a response for plaintext internal error. pub fn internal_error>(msg: T) -> hyper::Response { - from_template( - hyper::StatusCode::INTERNAL_SERVER_ERROR, - format!("Internal Server Error: {}", msg.into()), - ) + from_template(hyper::StatusCode::INTERNAL_SERVER_ERROR, format!("Internal Server Error: {}", msg.into())) } /// Create a json response for service unavailable. pub fn service_unavailable>(msg: T) -> hyper::Response { - hyper::Response::builder() - .status(hyper::StatusCode::SERVICE_UNAVAILABLE) - .header( - "Content-Type", - hyper::header::HeaderValue::from_static("application/json; charset=utf-8"), - ) - .body(hyper::Body::from(msg.into())) - .expect("Unable to parse response body for type conversion") + hyper::Response::builder() + .status(hyper::StatusCode::SERVICE_UNAVAILABLE) + .header("Content-Type", hyper::header::HeaderValue::from_static("application/json; charset=utf-8")) + .body(hyper::Body::from(msg.into())) + .expect("Unable to parse response body for type conversion") } /// Create a response for not allowed hosts. pub fn host_not_allowed() -> hyper::Response { - from_template( - hyper::StatusCode::FORBIDDEN, - "Provided Host header is not whitelisted.\n".to_owned(), - ) + from_template(hyper::StatusCode::FORBIDDEN, "Provided Host header is not whitelisted.\n".to_owned()) } /// Create a response for unsupported content type. pub fn unsupported_content_type() -> hyper::Response { - from_template( - hyper::StatusCode::UNSUPPORTED_MEDIA_TYPE, - "Supplied content type is not allowed. Content-Type: application/json is required\n" - .to_owned(), - ) + from_template( + hyper::StatusCode::UNSUPPORTED_MEDIA_TYPE, + "Supplied content type is not allowed. Content-Type: application/json is required\n".to_owned(), + ) } /// Create a response for invalid JSON in request pub fn parse_error() -> hyper::Response { - hyper::Response::builder() - .status(hyper::StatusCode::OK) - .header("Content-type", "application/json") - .body(hyper::Body::from( - serde_json::to_string(&crate::common::Output::Failure(crate::common::Failure { - jsonrpc: crate::common::Version::V2, - error: crate::common::Error::parse_error(), - id: crate::common::Id::Null, - })) - .expect("Unable to serialize parse error"), - )) - .expect("Unable to parse response body for type conversion") + hyper::Response::builder() + .status(hyper::StatusCode::OK) + .header("Content-type", "application/json") + .body(hyper::Body::from( + serde_json::to_string(&crate::common::Output::Failure(crate::common::Failure { + jsonrpc: crate::common::Version::V2, + error: crate::common::Error::parse_error(), + id: crate::common::Id::Null, + })) + .expect("Unable to serialize parse error"), + )) + .expect("Unable to parse response body for type conversion") } /// Create a response for disallowed method used. pub fn method_not_allowed() -> hyper::Response { - from_template( - hyper::StatusCode::METHOD_NOT_ALLOWED, - "Used HTTP Method is not allowed. POST or OPTIONS is required\n".to_owned(), - ) + from_template( + hyper::StatusCode::METHOD_NOT_ALLOWED, + "Used HTTP Method is not allowed. POST or OPTIONS is required\n".to_owned(), + ) } /// CORS invalid pub fn invalid_allow_origin() -> hyper::Response { - from_template( + from_template( hyper::StatusCode::FORBIDDEN, "Origin of the request is not whitelisted. CORS headers would not be sent and any side-effects were cancelled as well.\n".to_owned(), ) @@ -97,7 +87,7 @@ pub fn invalid_allow_origin() -> hyper::Response { /// CORS header invalid pub fn invalid_allow_headers() -> hyper::Response { - from_template( + from_template( hyper::StatusCode::FORBIDDEN, "Requested headers are not allowed for CORS. CORS headers would not be sent and any side-effects were cancelled as well.\n".to_owned(), ) @@ -105,24 +95,21 @@ pub fn invalid_allow_headers() -> hyper::Response { /// Create a response for bad request pub fn bad_request>(msg: S) -> hyper::Response { - from_template(hyper::StatusCode::BAD_REQUEST, msg.into()) + from_template(hyper::StatusCode::BAD_REQUEST, msg.into()) } /// Create a response for too large (413) pub fn too_large>(msg: S) -> hyper::Response { - from_template(hyper::StatusCode::PAYLOAD_TOO_LARGE, msg.into()) + from_template(hyper::StatusCode::PAYLOAD_TOO_LARGE, msg.into()) } /// Create a response for a template. fn from_template(status: hyper::StatusCode, body: String) -> hyper::Response { - hyper::Response::builder() - .status(status) - .header( - "content-type", - hyper::header::HeaderValue::from_static("text/plain; charset=utf-8"), - ) - .body(hyper::Body::from(body)) - // Parsing `StatusCode` and `HeaderValue` is infalliable but - // parsing body content is not. - .expect("Unable to parse response body for type conversion") + hyper::Response::builder() + .status(status) + .header("content-type", hyper::header::HeaderValue::from_static("text/plain; charset=utf-8")) + .body(hyper::Body::from(body)) + // Parsing `StatusCode` and `HeaderValue` is infalliable but + // parsing body content is not. + .expect("Unable to parse response body for type conversion") } diff --git a/src/ws/mod.rs b/src/ws/mod.rs index 5e9d9def48..42ed9c1588 100644 --- a/src/ws/mod.rs +++ b/src/ws/mod.rs @@ -32,8 +32,8 @@ mod transport; mod tests; pub use raw::{ - Notification as WsRawNotification, RawServer as RawWsServer, - RawServerEvent as RawWsServerEvent, TypedResponder as WsTypedResponder, + Notification as WsRawNotification, RawServer as RawWsServer, RawServerEvent as RawWsServerEvent, + TypedResponder as WsTypedResponder, }; pub use server::{RegisteredMethod, RegisteredNotification, Server as WsServer}; pub use transport::WsTransportServer; diff --git a/src/ws/raw/batch.rs b/src/ws/raw/batch.rs index 062ecd58b5..b42f094aae 100644 --- a/src/ws/raw/batch.rs +++ b/src/ws/raw/batch.rs @@ -52,333 +52,303 @@ use smallvec::SmallVec; /// destroy the [`BatchState`]. /// pub struct BatchState { - /// List of elements to present to the user. - to_yield: SmallVec<[ToYield; 1]>, + /// List of elements to present to the user. + to_yield: SmallVec<[ToYield; 1]>, - /// List of requests to be answered. When a request is answered, we replace it with `None` so - /// that indices don't change. - requests: SmallVec<[Option; 1]>, + /// List of requests to be answered. When a request is answered, we replace it with `None` so + /// that indices don't change. + requests: SmallVec<[Option; 1]>, - /// List of pending responses. - responses: SmallVec<[common::Output; 1]>, + /// List of pending responses. + responses: SmallVec<[common::Output; 1]>, - /// True if the original request was a batch. We need to keep track of this because we need to - /// respond differently depending on whether we have a single request or a batch with one - /// request. - is_batch: bool, + /// True if the original request was a batch. We need to keep track of this because we need to + /// respond differently depending on whether we have a single request or a batch with one + /// request. + is_batch: bool, } /// Element remaining to be yielded to the user. #[derive(Debug)] enum ToYield { - Notification(common::Notification), - Request(common::MethodCall), + Notification(common::Notification), + Request(common::MethodCall), } /// Event generated by the [`next`](BatchState::next) function. #[derive(Debug)] pub enum BatchInc<'a> { - /// Request is a notification. - Notification(Notification), - /// Request is a method call. - Request(BatchElem<'a>), + /// Request is a notification. + Notification(Notification), + /// Request is a method call. + Request(BatchElem<'a>), } /// References to a request within the batch that must be answered. pub struct BatchElem<'a> { - /// Index within the `BatchState::requests` list. - index: usize, - /// Reference to the actual element. Must always be `Some` for the lifetime of this object. - /// We hold a `&mut Option` rather than a `&mut common::MethodCall` so - /// that we can put `None` in it. - elem: &'a mut Option, - /// Reference to the `BatchState::responses` list so that we can push a response. - responses: &'a mut SmallVec<[common::Output; 1]>, + /// Index within the `BatchState::requests` list. + index: usize, + /// Reference to the actual element. Must always be `Some` for the lifetime of this object. + /// We hold a `&mut Option` rather than a `&mut common::MethodCall` so + /// that we can put `None` in it. + elem: &'a mut Option, + /// Reference to the `BatchState::responses` list so that we can push a response. + responses: &'a mut SmallVec<[common::Output; 1]>, } impl BatchState { - /// Creates a `BatchState` that will manage the given request. - pub fn from_request(raw_request_body: common::Request) -> BatchState { - match raw_request_body { - common::Request::Single(rq) => BatchState::from_iter(iter::once(rq), false), - common::Request::Batch(requests) => BatchState::from_iter(requests.into_iter(), true), - } - } - - /// Internal implementation of [`from_request`](BatchState::from_request). Generic over the - /// iterator. - fn from_iter( - calls_list: impl ExactSizeIterator, - is_batch: bool, - ) -> BatchState { - debug_assert!(!(!is_batch && calls_list.len() >= 2)); - - let mut to_yield = SmallVec::with_capacity(calls_list.len()); - let mut responses = SmallVec::with_capacity(calls_list.len()); - let mut num_requests = 0; - - for call in calls_list { - match call { - common::Call::MethodCall(call) => { - to_yield.push(ToYield::Request(call)); - num_requests += 1; - } - common::Call::Notification(n) => { - to_yield.push(ToYield::Notification(n)); - } - common::Call::Invalid { id } => { - let err = Err(common::Error::invalid_request()); - let out = common::Output::from(err, id, common::Version::V2); - responses.push(out); - } - } - } - - BatchState { - to_yield, - requests: SmallVec::with_capacity(num_requests), - responses, - is_batch, - } - } - - /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) - /// by its id. - /// - /// Note that previous notifications don't have an ID and can't be accessed with this method. - /// - /// Returns `None` if the request ID is invalid or if the request has already been answered in - /// the past. - pub fn request_by_id(&mut self, id: usize) -> Option { - if let Some(elem) = self.requests.get_mut(id) { - if elem.is_none() { - return None; - } - Some(BatchElem { - elem, - index: id, - responses: &mut self.responses, - }) - } else { - None - } - } - - /// Extracts the next request from the batch. Returns `None` if the batch is empty. - pub fn next(&mut self) -> Option { - if self.to_yield.is_empty() { - return None; - } - - match self.to_yield.remove(0) { - ToYield::Notification(n) => Some(BatchInc::Notification(From::from(n))), - ToYield::Request(n) => { - let request_id = self.requests.len(); - self.requests.push(Some(n)); - Some(BatchInc::Request(BatchElem { - index: request_id, - elem: &mut self.requests[request_id], - responses: &mut self.responses, - })) - } - } - } - - /// Returns true if this batch is ready to send out its response. - pub fn is_ready_to_respond(&self) -> bool { - self.to_yield.is_empty() && self.requests.iter().all(|r| r.is_none()) - } - - /// Turns this batch into a response to send out to the client. - /// - /// Returns `Ok(None)` if there is actually nothing to send to the client, such as when the - /// client has only sent notifications. - pub fn into_response(mut self) -> Result, Self> { - if !self.is_ready_to_respond() { - return Err(self); - } - - let raw_response = if self.is_batch { - let list: Vec<_> = self.responses.drain(..).collect(); - if list.is_empty() { - None - } else { - Some(common::Response::Batch(list)) - } - } else { - debug_assert!(self.responses.len() <= 1); - if self.responses.is_empty() { - None - } else { - Some(common::Response::Single(self.responses.remove(0))) - } - }; - - Ok(raw_response) - } + /// Creates a `BatchState` that will manage the given request. + pub fn from_request(raw_request_body: common::Request) -> BatchState { + match raw_request_body { + common::Request::Single(rq) => BatchState::from_iter(iter::once(rq), false), + common::Request::Batch(requests) => BatchState::from_iter(requests.into_iter(), true), + } + } + + /// Internal implementation of [`from_request`](BatchState::from_request). Generic over the + /// iterator. + fn from_iter(calls_list: impl ExactSizeIterator, is_batch: bool) -> BatchState { + debug_assert!(!(!is_batch && calls_list.len() >= 2)); + + let mut to_yield = SmallVec::with_capacity(calls_list.len()); + let mut responses = SmallVec::with_capacity(calls_list.len()); + let mut num_requests = 0; + + for call in calls_list { + match call { + common::Call::MethodCall(call) => { + to_yield.push(ToYield::Request(call)); + num_requests += 1; + } + common::Call::Notification(n) => { + to_yield.push(ToYield::Notification(n)); + } + common::Call::Invalid { id } => { + let err = Err(common::Error::invalid_request()); + let out = common::Output::from(err, id, common::Version::V2); + responses.push(out); + } + } + } + + BatchState { to_yield, requests: SmallVec::with_capacity(num_requests), responses, is_batch } + } + + /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) + /// by its id. + /// + /// Note that previous notifications don't have an ID and can't be accessed with this method. + /// + /// Returns `None` if the request ID is invalid or if the request has already been answered in + /// the past. + pub fn request_by_id(&mut self, id: usize) -> Option { + if let Some(elem) = self.requests.get_mut(id) { + if elem.is_none() { + return None; + } + Some(BatchElem { elem, index: id, responses: &mut self.responses }) + } else { + None + } + } + + /// Extracts the next request from the batch. Returns `None` if the batch is empty. + pub fn next(&mut self) -> Option { + if self.to_yield.is_empty() { + return None; + } + + match self.to_yield.remove(0) { + ToYield::Notification(n) => Some(BatchInc::Notification(From::from(n))), + ToYield::Request(n) => { + let request_id = self.requests.len(); + self.requests.push(Some(n)); + Some(BatchInc::Request(BatchElem { + index: request_id, + elem: &mut self.requests[request_id], + responses: &mut self.responses, + })) + } + } + } + + /// Returns true if this batch is ready to send out its response. + pub fn is_ready_to_respond(&self) -> bool { + self.to_yield.is_empty() && self.requests.iter().all(|r| r.is_none()) + } + + /// Turns this batch into a response to send out to the client. + /// + /// Returns `Ok(None)` if there is actually nothing to send to the client, such as when the + /// client has only sent notifications. + pub fn into_response(mut self) -> Result, Self> { + if !self.is_ready_to_respond() { + return Err(self); + } + + let raw_response = if self.is_batch { + let list: Vec<_> = self.responses.drain(..).collect(); + if list.is_empty() { + None + } else { + Some(common::Response::Batch(list)) + } + } else { + debug_assert!(self.responses.len() <= 1); + if self.responses.is_empty() { + None + } else { + Some(common::Response::Single(self.responses.remove(0))) + } + }; + + Ok(raw_response) + } } impl fmt::Debug for BatchState { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_list() - .entries(self.to_yield.iter()) - .entries(self.requests.iter().filter(|r| r.is_some())) - .entries(self.responses.iter()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_list() + .entries(self.to_yield.iter()) + .entries(self.requests.iter().filter(|r| r.is_some())) + .entries(self.responses.iter()) + .finish() + } } impl<'a> BatchElem<'a> { - /// Returns the id of the request within the [`BatchState`]. - /// - /// > **Note**: This is NOT the request id that the client passed. - pub fn id(&self) -> usize { - self.index - } - - /// Returns the id that the client sent out. - pub fn request_id(&self) -> &common::Id { - let request = self - .elem - .as_ref() - .expect("elem must be Some for the lifetime of the object; qed"); - &request.id - } - - /// Returns the method of this request. - pub fn method(&self) -> &str { - let request = self - .elem - .as_ref() - .expect("elem must be Some for the lifetime of the object; qed"); - &request.method - } - - /// Returns the parameters of the request, as a `common::Params`. - pub fn params(&self) -> Params { - let request = self - .elem - .as_ref() - .expect("elem must be Some for the lifetime of the object; qed"); - Params::from(&request.params) - } - - /// Responds to the request. This destroys the request object, meaning you can no longer - /// retrieve it with [`request_by_id`](BatchState::request_by_id) later anymore. - pub fn set_response(self, response: Result) { - let request = self - .elem - .take() - .expect("elem must be Some for the lifetime of the object; qed"); - let response = common::Output::from(response, request.id, common::Version::V2); - self.responses.push(response); - } + /// Returns the id of the request within the [`BatchState`]. + /// + /// > **Note**: This is NOT the request id that the client passed. + pub fn id(&self) -> usize { + self.index + } + + /// Returns the id that the client sent out. + pub fn request_id(&self) -> &common::Id { + let request = self.elem.as_ref().expect("elem must be Some for the lifetime of the object; qed"); + &request.id + } + + /// Returns the method of this request. + pub fn method(&self) -> &str { + let request = self.elem.as_ref().expect("elem must be Some for the lifetime of the object; qed"); + &request.method + } + + /// Returns the parameters of the request, as a `common::Params`. + pub fn params(&self) -> Params { + let request = self.elem.as_ref().expect("elem must be Some for the lifetime of the object; qed"); + Params::from(&request.params) + } + + /// Responds to the request. This destroys the request object, meaning you can no longer + /// retrieve it with [`request_by_id`](BatchState::request_by_id) later anymore. + pub fn set_response(self, response: Result) { + let request = self.elem.take().expect("elem must be Some for the lifetime of the object; qed"); + let response = common::Output::from(response, request.id, common::Version::V2); + self.responses.push(response); + } } impl<'a> fmt::Debug for BatchElem<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("BatchElem") - .field("method", &self.method()) - .field("params", &self.params()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BatchElem").field("method", &self.method()).field("params", &self.params()).finish() + } } #[cfg(test)] mod tests { - use super::{BatchInc, BatchState}; - use crate::{common, ws::WsRawNotification}; - - #[test] - fn basic_notification() { - let notif = common::Notification { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::None, - }; - - let mut state = { - let rq = common::Request::Single(common::Call::Notification(notif.clone())); - BatchState::from_request(rq) - }; - - assert!(!state.is_ready_to_respond()); - match state.next() { - Some(BatchInc::Notification(ref n)) if n == &WsRawNotification::from(notif) => {} - _ => panic!(), - } - assert!(state.is_ready_to_respond()); - assert!(state.next().is_none()); - match state.into_response() { - Ok(None) => {} - _ => panic!(), - } - } - - #[test] - fn basic_request() { - let call = common::MethodCall { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::Map(serde_json::from_str("{\"test\":\"foo\"}").unwrap()), - id: common::Id::Num(123), - }; - - let mut state = { - let rq = common::Request::Single(common::Call::MethodCall(call.clone())); - BatchState::from_request(rq) - }; - - assert!(!state.is_ready_to_respond()); - let rq_id = match state.next() { - Some(BatchInc::Request(rq)) => { - assert_eq!(rq.method(), "foo"); - assert_eq!( - { - let v: String = rq.params().get("test").unwrap(); - v - }, - "foo" - ); - assert_eq!(rq.request_id(), &common::Id::Num(123)); - rq.id() - } - _ => panic!(), - }; - - assert!(state.next().is_none()); - assert!(!state.is_ready_to_respond()); - assert!(state.next().is_none()); - - assert_eq!(state.request_by_id(rq_id).unwrap().method(), "foo"); - state - .request_by_id(rq_id) - .unwrap() - .set_response(Err(common::Error::method_not_found())); - - assert!(state.is_ready_to_respond()); - assert!(state.next().is_none()); - - match state.into_response() { - Ok(Some(common::Response::Single(common::Output::Failure(f)))) => { - assert_eq!(f.id, common::Id::Num(123)); - } - _ => panic!(), - } - } - - #[test] - fn empty_batch() { - let mut state = { - let rq = common::Request::Batch(Vec::new()); - BatchState::from_request(rq) - }; - - assert!(state.is_ready_to_respond()); - assert!(state.next().is_none()); - match state.into_response() { - Ok(None) => {} - _ => panic!(), - } - } + use super::{BatchInc, BatchState}; + use crate::{common, ws::WsRawNotification}; + + #[test] + fn basic_notification() { + let notif = common::Notification { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::None, + }; + + let mut state = { + let rq = common::Request::Single(common::Call::Notification(notif.clone())); + BatchState::from_request(rq) + }; + + assert!(!state.is_ready_to_respond()); + match state.next() { + Some(BatchInc::Notification(ref n)) if n == &WsRawNotification::from(notif) => {} + _ => panic!(), + } + assert!(state.is_ready_to_respond()); + assert!(state.next().is_none()); + match state.into_response() { + Ok(None) => {} + _ => panic!(), + } + } + + #[test] + fn basic_request() { + let call = common::MethodCall { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::Map(serde_json::from_str("{\"test\":\"foo\"}").unwrap()), + id: common::Id::Num(123), + }; + + let mut state = { + let rq = common::Request::Single(common::Call::MethodCall(call.clone())); + BatchState::from_request(rq) + }; + + assert!(!state.is_ready_to_respond()); + let rq_id = match state.next() { + Some(BatchInc::Request(rq)) => { + assert_eq!(rq.method(), "foo"); + assert_eq!( + { + let v: String = rq.params().get("test").unwrap(); + v + }, + "foo" + ); + assert_eq!(rq.request_id(), &common::Id::Num(123)); + rq.id() + } + _ => panic!(), + }; + + assert!(state.next().is_none()); + assert!(!state.is_ready_to_respond()); + assert!(state.next().is_none()); + + assert_eq!(state.request_by_id(rq_id).unwrap().method(), "foo"); + state.request_by_id(rq_id).unwrap().set_response(Err(common::Error::method_not_found())); + + assert!(state.is_ready_to_respond()); + assert!(state.next().is_none()); + + match state.into_response() { + Ok(Some(common::Response::Single(common::Output::Failure(f)))) => { + assert_eq!(f.id, common::Id::Num(123)); + } + _ => panic!(), + } + } + + #[test] + fn empty_batch() { + let mut state = { + let rq = common::Request::Batch(Vec::new()); + BatchState::from_request(rq) + }; + + assert!(state.is_ready_to_respond()); + assert!(state.next().is_none()); + match state.into_response() { + Ok(None) => {} + _ => panic!(), + } + } } diff --git a/src/ws/raw/batches.rs b/src/ws/raw/batches.rs index 3132a17423..e346852f0e 100644 --- a/src/ws/raw/batches.rs +++ b/src/ws/raw/batches.rs @@ -48,398 +48,365 @@ use hashbrown::{hash_map::Entry, HashMap}; /// [`request_by_id`](BatchesState::request_by_id). /// pub struct BatchesState { - /// Identifier of the next batch to add to `batches`. - next_batch_id: u64, - - /// For each batch, the individual batch's state and the user parameter. - /// - /// The identifier is lineraly increasing and is never leaked on the wire or outside of this - /// module. Therefore there is no risk of hash collision. - batches: HashMap, + /// Identifier of the next batch to add to `batches`. + next_batch_id: u64, + + /// For each batch, the individual batch's state and the user parameter. + /// + /// The identifier is lineraly increasing and is never leaked on the wire or outside of this + /// module. Therefore there is no risk of hash collision. + batches: HashMap, } /// Event generated by [`next_event`](BatchesState::next_event). #[derive(Debug)] pub enum BatchesEvent<'a, T> { - /// A notification has been extracted from a batch. - Notification { - /// Notification in question. - notification: Notification, - /// User parameter passed when calling [`inject`](BatchesState::inject). - user_param: &'a mut T, - }, - - /// A request has been extracted from a batch. - Request(BatchesElem<'a, T>), - - /// A batch has gotten all its requests answered and a response is ready to be sent out. - ReadyToSend { - /// Response to send out to the JSON-RPC client. - response: common::Response, - /// User parameter passed when calling [`inject`](BatchesState::inject). - user_param: T, - }, + /// A notification has been extracted from a batch. + Notification { + /// Notification in question. + notification: Notification, + /// User parameter passed when calling [`inject`](BatchesState::inject). + user_param: &'a mut T, + }, + + /// A request has been extracted from a batch. + Request(BatchesElem<'a, T>), + + /// A batch has gotten all its requests answered and a response is ready to be sent out. + ReadyToSend { + /// Response to send out to the JSON-RPC client. + response: common::Response, + /// User parameter passed when calling [`inject`](BatchesState::inject). + user_param: T, + }, } /// Request within the batches. pub struct BatchesElem<'a, T> { - /// Id of the batch that contains this element. - batch_id: u64, - /// Inner reference to a request within a batch. - inner: batch::BatchElem<'a>, - /// User parameter passed when calling `inject`. - user_param: &'a mut T, + /// Id of the batch that contains this element. + batch_id: u64, + /// Inner reference to a request within a batch. + inner: batch::BatchElem<'a>, + /// User parameter passed when calling `inject`. + user_param: &'a mut T, } /// Identifier of a request within a [`BatchesState`]. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct BatchesElemId { - /// Id of the batch within `BatchesState::batches`. - outer: u64, - /// Id of the request within the batch. - inner: usize, + /// Id of the batch within `BatchesState::batches`. + outer: u64, + /// Id of the request within the batch. + inner: usize, } /// Minimal capacity for the `batches` container. const BATCHES_MIN_CAPACITY: usize = 256; impl BatchesState { - /// Creates a new empty `BatchesState`. - pub fn new() -> BatchesState { - BatchesState { - next_batch_id: 0, - batches: HashMap::with_capacity_and_hasher(BATCHES_MIN_CAPACITY, Default::default()), - } - } - - /// Processes one step from a batch and returns an event. Returns `None` if there is nothing - /// to do. After you call `inject`, then this will return `Some` at least once. - pub fn next_event(&mut self) -> Option> { - // Note that this function has a complexity of `O(n)`, as we iterate over every single - // batch every single time. This is however the most straight-forward way to implement it, - // and while better strategies might yield better complexities, it might not actually yield - // better performances in real-world situations. More brainstorming and benchmarking could - // get helpful here. - - // Because of long-standing Rust lifetime issues - // (https://github.com/rust-lang/rust/issues/51526), we can't do this in an elegant way. - // If you're reading this code, know that it took several iterations and that I hated my - // life while trying to figure out how to make the compiler happy. - for batch_id in self.batches.keys().cloned().collect::>() { - enum WhatCanWeDo { - Nothing, - ReadyToRespond, - Notification(Notification), - Request(usize), - } - - let what_can_we_do = { - let (batch, _) = self - .batches - .get_mut(&batch_id) - .expect("all keys are valid; qed"); - let is_ready_to_respond = batch.is_ready_to_respond(); - match batch.next() { - None if is_ready_to_respond => WhatCanWeDo::ReadyToRespond, - None => WhatCanWeDo::Nothing, - Some(batch::BatchInc::Notification(n)) => WhatCanWeDo::Notification(n), - Some(batch::BatchInc::Request(inner)) => WhatCanWeDo::Request(inner.id()), - } - }; - - match what_can_we_do { - WhatCanWeDo::Nothing => {} - WhatCanWeDo::ReadyToRespond => { - let (batch, user_param) = self - .batches - .remove(&batch_id) - .expect("key was grabbed from self.batches; qed"); - let response = batch - .into_response() - .unwrap_or_else(|_| panic!("is_ready_to_respond returned true; qed")); - if let Some(response) = response { - return Some(BatchesEvent::ReadyToSend { - response, - user_param, - }); - } - } - WhatCanWeDo::Notification(notification) => { - return Some(BatchesEvent::Notification { - notification, - user_param: &mut self.batches.get_mut(&batch_id).unwrap().1, - }); - } - WhatCanWeDo::Request(id) => { - let (batch, user_param) = self.batches.get_mut(&batch_id).unwrap(); - return Some(BatchesEvent::Request(BatchesElem { - batch_id, - inner: batch.request_by_id(id).unwrap(), - user_param, - })); - } - } - } - - None - } - - /// Injects a newly-received batch into the list. You must then call - /// [`next_event`](BatchesState::next_event) in order to process it. - pub fn inject(&mut self, request: common::Request, user_param: T) { - let batch = batch::BatchState::from_request(request); - - loop { - let id = self.next_batch_id; - self.next_batch_id = self.next_batch_id.wrapping_add(1); - - // We shrink `self.batches` from time to time so that it doesn't grow too much. - if id % 256 == 0 { - self.batches.shrink_to_fit(); - // TODO: self.batches.shrink_to(BATCHES_MIN_CAPACITY); - // ^ see https://github.com/rust-lang/rust/issues/56431 - } - - match self.batches.entry(id) { - Entry::Occupied(_) => continue, - Entry::Vacant(e) => { - e.insert((batch, user_param)); - break; - } - } - } - } - - /// Returns a list of all user data associated to active batches. - pub fn batches<'a>(&'a mut self) -> impl Iterator + 'a { - self.batches.values_mut().map(|(_, user_data)| user_data) - } - - /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) - /// by its id. - /// - /// Note that previous notifications don't have an ID and can't be accessed with this method. - /// - /// Returns `None` if the request ID is invalid or if the request has already been answered in - /// the past. - pub fn request_by_id(&mut self, id: BatchesElemId) -> Option> { - if let Some((batch, user_param)) = self.batches.get_mut(&id.outer) { - Some(BatchesElem { - batch_id: id.outer, - inner: batch.request_by_id(id.inner)?, - user_param, - }) - } else { - None - } - } + /// Creates a new empty `BatchesState`. + pub fn new() -> BatchesState { + BatchesState { + next_batch_id: 0, + batches: HashMap::with_capacity_and_hasher(BATCHES_MIN_CAPACITY, Default::default()), + } + } + + /// Processes one step from a batch and returns an event. Returns `None` if there is nothing + /// to do. After you call `inject`, then this will return `Some` at least once. + pub fn next_event(&mut self) -> Option> { + // Note that this function has a complexity of `O(n)`, as we iterate over every single + // batch every single time. This is however the most straight-forward way to implement it, + // and while better strategies might yield better complexities, it might not actually yield + // better performances in real-world situations. More brainstorming and benchmarking could + // get helpful here. + + // Because of long-standing Rust lifetime issues + // (https://github.com/rust-lang/rust/issues/51526), we can't do this in an elegant way. + // If you're reading this code, know that it took several iterations and that I hated my + // life while trying to figure out how to make the compiler happy. + for batch_id in self.batches.keys().cloned().collect::>() { + enum WhatCanWeDo { + Nothing, + ReadyToRespond, + Notification(Notification), + Request(usize), + } + + let what_can_we_do = { + let (batch, _) = self.batches.get_mut(&batch_id).expect("all keys are valid; qed"); + let is_ready_to_respond = batch.is_ready_to_respond(); + match batch.next() { + None if is_ready_to_respond => WhatCanWeDo::ReadyToRespond, + None => WhatCanWeDo::Nothing, + Some(batch::BatchInc::Notification(n)) => WhatCanWeDo::Notification(n), + Some(batch::BatchInc::Request(inner)) => WhatCanWeDo::Request(inner.id()), + } + }; + + match what_can_we_do { + WhatCanWeDo::Nothing => {} + WhatCanWeDo::ReadyToRespond => { + let (batch, user_param) = + self.batches.remove(&batch_id).expect("key was grabbed from self.batches; qed"); + let response = + batch.into_response().unwrap_or_else(|_| panic!("is_ready_to_respond returned true; qed")); + if let Some(response) = response { + return Some(BatchesEvent::ReadyToSend { response, user_param }); + } + } + WhatCanWeDo::Notification(notification) => { + return Some(BatchesEvent::Notification { + notification, + user_param: &mut self.batches.get_mut(&batch_id).unwrap().1, + }); + } + WhatCanWeDo::Request(id) => { + let (batch, user_param) = self.batches.get_mut(&batch_id).unwrap(); + return Some(BatchesEvent::Request(BatchesElem { + batch_id, + inner: batch.request_by_id(id).unwrap(), + user_param, + })); + } + } + } + + None + } + + /// Injects a newly-received batch into the list. You must then call + /// [`next_event`](BatchesState::next_event) in order to process it. + pub fn inject(&mut self, request: common::Request, user_param: T) { + let batch = batch::BatchState::from_request(request); + + loop { + let id = self.next_batch_id; + self.next_batch_id = self.next_batch_id.wrapping_add(1); + + // We shrink `self.batches` from time to time so that it doesn't grow too much. + if id % 256 == 0 { + self.batches.shrink_to_fit(); + // TODO: self.batches.shrink_to(BATCHES_MIN_CAPACITY); + // ^ see https://github.com/rust-lang/rust/issues/56431 + } + + match self.batches.entry(id) { + Entry::Occupied(_) => continue, + Entry::Vacant(e) => { + e.insert((batch, user_param)); + break; + } + } + } + } + + /// Returns a list of all user data associated to active batches. + pub fn batches<'a>(&'a mut self) -> impl Iterator + 'a { + self.batches.values_mut().map(|(_, user_data)| user_data) + } + + /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) + /// by its id. + /// + /// Note that previous notifications don't have an ID and can't be accessed with this method. + /// + /// Returns `None` if the request ID is invalid or if the request has already been answered in + /// the past. + pub fn request_by_id(&mut self, id: BatchesElemId) -> Option> { + if let Some((batch, user_param)) = self.batches.get_mut(&id.outer) { + Some(BatchesElem { batch_id: id.outer, inner: batch.request_by_id(id.inner)?, user_param }) + } else { + None + } + } } impl Default for BatchesState { - fn default() -> Self { - Self::new() - } + fn default() -> Self { + Self::new() + } } impl fmt::Debug for BatchesState where - T: fmt::Debug, + T: fmt::Debug, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_list().entries(self.batches.values()).finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_list().entries(self.batches.values()).finish() + } } impl<'a, T> BatchesElem<'a, T> { - /// Returns the id of the request within the [`BatchesState`]. - /// - /// > **Note**: This is NOT the request id that the client passed. - pub fn id(&self) -> BatchesElemId { - BatchesElemId { - outer: self.batch_id, - inner: self.inner.id(), - } - } - - /// Returns the user parameter passed when calling [`inject`](BatchesState::inject). - pub fn user_param(&mut self) -> &mut T { - &mut self.user_param - } - - /// Returns the id that the client sent out. - pub fn request_id(&self) -> &common::Id { - self.inner.request_id() - } - - /// Returns the method of this request. - pub fn method(&self) -> &str { - self.inner.method() - } - - /// Returns the parameters of the request, as a `common::Params`. - pub fn params(&self) -> Params { - self.inner.params() - } - - /// Responds to the request. This destroys the request object, meaning you can no longer - /// retrieve it with [`request_by_id`](BatchesState::request_by_id) later anymore. - /// - /// A [`ReadyToSend`](BatchesEvent::ReadyToSend) event containing this response might be - /// generated the next time you call [`next_event`](BatchesState::next_event). - pub fn set_response(self, response: Result) { - self.inner.set_response(response) - } + /// Returns the id of the request within the [`BatchesState`]. + /// + /// > **Note**: This is NOT the request id that the client passed. + pub fn id(&self) -> BatchesElemId { + BatchesElemId { outer: self.batch_id, inner: self.inner.id() } + } + + /// Returns the user parameter passed when calling [`inject`](BatchesState::inject). + pub fn user_param(&mut self) -> &mut T { + &mut self.user_param + } + + /// Returns the id that the client sent out. + pub fn request_id(&self) -> &common::Id { + self.inner.request_id() + } + + /// Returns the method of this request. + pub fn method(&self) -> &str { + self.inner.method() + } + + /// Returns the parameters of the request, as a `common::Params`. + pub fn params(&self) -> Params { + self.inner.params() + } + + /// Responds to the request. This destroys the request object, meaning you can no longer + /// retrieve it with [`request_by_id`](BatchesState::request_by_id) later anymore. + /// + /// A [`ReadyToSend`](BatchesEvent::ReadyToSend) event containing this response might be + /// generated the next time you call [`next_event`](BatchesState::next_event). + pub fn set_response(self, response: Result) { + self.inner.set_response(response) + } } impl<'a, T> fmt::Debug for BatchesElem<'a, T> where - T: fmt::Debug, + T: fmt::Debug, { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("BatchesElem") - .field("id", &self.id()) - .field("user_param", &self.user_param) - .field("request_id", &self.request_id()) - .field("method", &self.method()) - .field("params", &self.params()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("BatchesElem") + .field("id", &self.id()) + .field("user_param", &self.user_param) + .field("request_id", &self.request_id()) + .field("method", &self.method()) + .field("params", &self.params()) + .finish() + } } #[cfg(test)] mod tests { - use super::{BatchesEvent, BatchesState}; - use crate::{common, ws::WsRawNotification}; - - #[test] - fn basic_notification() { - let notif = common::Notification { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::None, - }; - - let mut state = BatchesState::new(); - assert!(state.next_event().is_none()); - state.inject( - common::Request::Single(common::Call::Notification(notif.clone())), - (), - ); - match state.next_event() { - Some(BatchesEvent::Notification { - ref notification, .. - }) if *notification == WsRawNotification::from(notif) => {} - _ => panic!(), - } - assert!(state.next_event().is_none()); - } - - #[test] - fn basic_request() { - let call = common::MethodCall { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::Map(serde_json::from_str("{\"test\":\"foo\"}").unwrap()), - id: common::Id::Num(123), - }; - - let mut state = BatchesState::new(); - assert!(state.next_event().is_none()); - state.inject( - common::Request::Single(common::Call::MethodCall(call)), - 8889, - ); - - let rq_id = match state.next_event() { - Some(BatchesEvent::Request(rq)) => { - assert_eq!(rq.method(), "foo"); - assert_eq!( - { - let v: String = rq.params().get("test").unwrap(); - v - }, - "foo" - ); - assert_eq!(rq.request_id(), &common::Id::Num(123)); - rq.id() - } - _ => panic!(), - }; - - assert!(state.next_event().is_none()); - - assert_eq!(state.request_by_id(rq_id).unwrap().method(), "foo"); - state - .request_by_id(rq_id) - .unwrap() - .set_response(Err(common::Error::method_not_found())); - assert!(state.request_by_id(rq_id).is_none()); - - match state.next_event() { - Some(BatchesEvent::ReadyToSend { - response, - user_param, - }) => { - assert_eq!(user_param, 8889); - match response { - common::Response::Single(common::Output::Failure(f)) => { - assert_eq!(f.id, common::Id::Num(123)); - } - _ => panic!(), - } - } - _ => panic!(), - }; - } - - #[test] - fn empty_batch() { - let mut state = BatchesState::new(); - assert!(state.next_event().is_none()); - state.inject(common::Request::Batch(Vec::new()), ()); - assert!(state.next_event().is_none()); - } - - #[test] - fn batch_of_notifs() { - let notif1 = common::Notification { - jsonrpc: common::Version::V2, - method: "foo".to_string(), - params: common::Params::None, - }; - - let notif2 = common::Notification { - jsonrpc: common::Version::V2, - method: "bar".to_string(), - params: common::Params::None, - }; - - let mut state = BatchesState::new(); - assert!(state.next_event().is_none()); - state.inject( - common::Request::Batch(vec![ - common::Call::Notification(notif1.clone()), - common::Call::Notification(notif2.clone()), - ]), - 2, - ); - - match state.next_event() { - Some(BatchesEvent::Notification { - ref notification, - ref user_param, - }) if *notification == WsRawNotification::from(notif1) && **user_param == 2 => {} - _ => panic!(), - } - - match state.next_event() { - Some(BatchesEvent::Notification { - ref notification, - ref user_param, - }) if *notification == WsRawNotification::from(notif2) && **user_param == 2 => {} - _ => panic!(), - } - - assert!(state.next_event().is_none()); - } + use super::{BatchesEvent, BatchesState}; + use crate::{common, ws::WsRawNotification}; + + #[test] + fn basic_notification() { + let notif = common::Notification { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::None, + }; + + let mut state = BatchesState::new(); + assert!(state.next_event().is_none()); + state.inject(common::Request::Single(common::Call::Notification(notif.clone())), ()); + match state.next_event() { + Some(BatchesEvent::Notification { ref notification, .. }) + if *notification == WsRawNotification::from(notif) => {} + _ => panic!(), + } + assert!(state.next_event().is_none()); + } + + #[test] + fn basic_request() { + let call = common::MethodCall { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::Map(serde_json::from_str("{\"test\":\"foo\"}").unwrap()), + id: common::Id::Num(123), + }; + + let mut state = BatchesState::new(); + assert!(state.next_event().is_none()); + state.inject(common::Request::Single(common::Call::MethodCall(call)), 8889); + + let rq_id = match state.next_event() { + Some(BatchesEvent::Request(rq)) => { + assert_eq!(rq.method(), "foo"); + assert_eq!( + { + let v: String = rq.params().get("test").unwrap(); + v + }, + "foo" + ); + assert_eq!(rq.request_id(), &common::Id::Num(123)); + rq.id() + } + _ => panic!(), + }; + + assert!(state.next_event().is_none()); + + assert_eq!(state.request_by_id(rq_id).unwrap().method(), "foo"); + state.request_by_id(rq_id).unwrap().set_response(Err(common::Error::method_not_found())); + assert!(state.request_by_id(rq_id).is_none()); + + match state.next_event() { + Some(BatchesEvent::ReadyToSend { response, user_param }) => { + assert_eq!(user_param, 8889); + match response { + common::Response::Single(common::Output::Failure(f)) => { + assert_eq!(f.id, common::Id::Num(123)); + } + _ => panic!(), + } + } + _ => panic!(), + }; + } + + #[test] + fn empty_batch() { + let mut state = BatchesState::new(); + assert!(state.next_event().is_none()); + state.inject(common::Request::Batch(Vec::new()), ()); + assert!(state.next_event().is_none()); + } + + #[test] + fn batch_of_notifs() { + let notif1 = common::Notification { + jsonrpc: common::Version::V2, + method: "foo".to_string(), + params: common::Params::None, + }; + + let notif2 = common::Notification { + jsonrpc: common::Version::V2, + method: "bar".to_string(), + params: common::Params::None, + }; + + let mut state = BatchesState::new(); + assert!(state.next_event().is_none()); + state.inject( + common::Request::Batch(vec![ + common::Call::Notification(notif1.clone()), + common::Call::Notification(notif2.clone()), + ]), + 2, + ); + + match state.next_event() { + Some(BatchesEvent::Notification { ref notification, ref user_param }) + if *notification == WsRawNotification::from(notif1) && **user_param == 2 => {} + _ => panic!(), + } + + match state.next_event() { + Some(BatchesEvent::Notification { ref notification, ref user_param }) + if *notification == WsRawNotification::from(notif2) && **user_param == 2 => {} + _ => panic!(), + } + + assert!(state.next_event().is_none()); + } } diff --git a/src/ws/raw/core.rs b/src/ws/raw/core.rs index 60b263b763..701471763a 100644 --- a/src/ws/raw/core.rs +++ b/src/ws/raw/core.rs @@ -36,35 +36,35 @@ use hashbrown::{hash_map::Entry, HashMap}; /// /// See the module-level documentation for more information. pub struct RawServer { - /// Internal "raw" server. - raw: WsTransportServer, - - /// List of requests that are in the progress of being answered. Each batch is associated with - /// the raw request ID, or with `None` if this raw request has been closed. - /// - /// See the documentation of [`BatchesState`][batches::BatchesState] for more information. - batches: batches::BatchesState>, - - /// List of active subscriptions. - /// The identifier is chosen randomly and uniformy distributed. It is never decided by the - /// client. There is therefore no risk of hash collision attack. - subscriptions: HashMap<[u8; 32], SubscriptionState, fnv::FnvBuildHasher>, - - /// For each raw request ID (i.e. client connection), the number of active subscriptions - /// that are using it. - /// - /// If this reaches 0, we can tell the raw server to close the request. - /// - /// Because we don't have any information about `I`, we have to use a collision-resistant - /// hashing algorithm. This incurs a performance cost that is theoretically avoidable (if `I` - /// is always local), but that should be negligible in practice. - num_subscriptions: HashMap, + /// Internal "raw" server. + raw: WsTransportServer, + + /// List of requests that are in the progress of being answered. Each batch is associated with + /// the raw request ID, or with `None` if this raw request has been closed. + /// + /// See the documentation of [`BatchesState`][batches::BatchesState] for more information. + batches: batches::BatchesState>, + + /// List of active subscriptions. + /// The identifier is chosen randomly and uniformy distributed. It is never decided by the + /// client. There is therefore no risk of hash collision attack. + subscriptions: HashMap<[u8; 32], SubscriptionState, fnv::FnvBuildHasher>, + + /// For each raw request ID (i.e. client connection), the number of active subscriptions + /// that are using it. + /// + /// If this reaches 0, we can tell the raw server to close the request. + /// + /// Because we don't have any information about `I`, we have to use a collision-resistant + /// hashing algorithm. This incurs a performance cost that is theoretically avoidable (if `I` + /// is always local), but that should be negligible in practice. + num_subscriptions: HashMap, } /// Identifier of a request within a `RawServer`. #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub struct RawServerRequestId { - inner: batches::BatchesElemId, + inner: batches::BatchesElemId, } /// Identifier of a subscription within a [`RawServer`](crate::server::RawServer). @@ -77,32 +77,32 @@ pub struct RawServerSubscriptionId([u8; 32]); /// > be dropped. #[derive(Debug)] pub enum RawServerEvent<'a> { - /// Request is a notification. - Notification(Notification), + /// Request is a notification. + Notification(Notification), - /// Request is a method call. - Request(RawServerRequest<'a>), + /// Request is a method call. + Request(RawServerRequest<'a>), - /// Subscriptions are now ready. - SubscriptionsReady(SubscriptionsReadyIter), + /// Subscriptions are now ready. + SubscriptionsReady(SubscriptionsReadyIter), - /// Subscriptions have been closed because the client closed the connection. - SubscriptionsClosed(SubscriptionsClosedIter), + /// Subscriptions have been closed because the client closed the connection. + SubscriptionsClosed(SubscriptionsClosedIter), } /// Request received by a [`RawServer`](crate::raw::RawServer). pub struct RawServerRequest<'a> { - /// Reference to the request within `self.batches`. - inner: batches::BatchesElem<'a, Option>, + /// Reference to the request within `self.batches`. + inner: batches::BatchesElem<'a, Option>, - /// Reference to the corresponding field in `RawServer`. - raw: &'a mut WsTransportServer, + /// Reference to the corresponding field in `RawServer`. + raw: &'a mut WsTransportServer, - /// Pending subscriptions. - subscriptions: &'a mut HashMap<[u8; 32], SubscriptionState, fnv::FnvBuildHasher>, + /// Pending subscriptions. + subscriptions: &'a mut HashMap<[u8; 32], SubscriptionState, fnv::FnvBuildHasher>, - /// Reference to the corresponding field in `RawServer`. - num_subscriptions: &'a mut HashMap, + /// Reference to the corresponding field in `RawServer`. + num_subscriptions: &'a mut HashMap, } /// Active subscription of a client towards a server. @@ -110,17 +110,17 @@ pub struct RawServerRequest<'a> { /// > **Note**: Holds a borrow of the `RawServer`. Therefore, must be dropped before the `RawServer` can /// > be dropped. pub struct ServerSubscription<'a> { - server: &'a mut RawServer, - id: [u8; 32], + server: &'a mut RawServer, + id: [u8; 32], } /// Error that can happen when calling `into_subscription`. #[derive(Debug)] pub enum IntoSubscriptionErr { - /// Underlying server doesn't support subscriptions. - NotSupported, - /// Request has already been closed by the client. - Closed, + /// Underlying server doesn't support subscriptions. + NotSupported, + /// Request has already been closed by the client. + Closed, } /// Iterator for the list of subscriptions that are now ready. @@ -134,410 +134,372 @@ pub struct SubscriptionsClosedIter(vec::IntoIter); /// Internal structure. Information about a subscription. #[derive(Debug)] struct SubscriptionState { - /// Identifier of the connection in the raw server. - raw_id: I, - /// Method that triggered the subscription. Must be sent to the client at each notification. - method: String, - /// If true, the subscription shouldn't accept any notification push because the confirmation - /// hasn't been sent to the client yet. Once this has switched to `false`, it can never be - /// switched to `true` ever again. - pending: bool, + /// Identifier of the connection in the raw server. + raw_id: I, + /// Method that triggered the subscription. Must be sent to the client at each notification. + method: String, + /// If true, the subscription shouldn't accept any notification push because the confirmation + /// hasn't been sent to the client yet. Once this has switched to `false`, it can never be + /// switched to `true` ever again. + pending: bool, } impl RawServer { - /// Starts a [`RawServer`](crate::raw::RawServer) using the given raw server internally. - pub fn new(raw: WsTransportServer) -> RawServer { - RawServer { - raw, - batches: batches::BatchesState::new(), - subscriptions: HashMap::with_capacity_and_hasher(8, Default::default()), - num_subscriptions: HashMap::with_capacity_and_hasher(8, Default::default()), - } - } + /// Starts a [`RawServer`](crate::raw::RawServer) using the given raw server internally. + pub fn new(raw: WsTransportServer) -> RawServer { + RawServer { + raw, + batches: batches::BatchesState::new(), + subscriptions: HashMap::with_capacity_and_hasher(8, Default::default()), + num_subscriptions: HashMap::with_capacity_and_hasher(8, Default::default()), + } + } } impl RawServer { - /// Returns a `Future` resolving to the next event that this server generates. - pub async fn next_event<'a>(&'a mut self) -> RawServerEvent<'a> { - let request_id = loop { - match self.batches.next_event() { - None => {} - Some(batches::BatchesEvent::Notification { notification, .. }) => { - return RawServerEvent::Notification(notification) - } - Some(batches::BatchesEvent::Request(inner)) => { - break RawServerRequestId { inner: inner.id() }; - } - Some(batches::BatchesEvent::ReadyToSend { - response, - user_param: Some(raw_request_id), - }) => { - // If we have any active subscription, we only use `send` to not close the - // client request. - if self.num_subscriptions.contains_key(&raw_request_id) { - debug_assert!(self.raw.supports_resuming(&raw_request_id).unwrap_or(false)); - let _ = self.raw.send(&raw_request_id, &response).await; - // TODO: that's O(n) - let mut ready = Vec::new(); // TODO: with_capacity - for (sub_id, sub) in self.subscriptions.iter_mut() { - if sub.raw_id == raw_request_id { - ready.push(RawServerSubscriptionId(sub_id.clone())); - sub.pending = false; - } - } - debug_assert!(!ready.is_empty()); // TODO: assert that capacity == len - return RawServerEvent::SubscriptionsReady(SubscriptionsReadyIter( - ready.into_iter(), - )); - } else { - let _ = self.raw.finish(&raw_request_id, Some(&response)).await; - } - continue; - } - Some(batches::BatchesEvent::ReadyToSend { - response: _, - user_param: None, - }) => { - // This situation happens if the connection has been closed by the client. - // Clients who close their connection. - continue; - } - }; - - match self.raw.next_request().await { - TransportServerEvent::Request { id, request } => { - self.batches.inject(request, Some(id)) - } - TransportServerEvent::Closed(raw_id) => { - // The client has a closed their connection. We eliminate all traces of the - // raw request ID from our state. - // TODO: this has an O(n) complexity; make sure that this is not attackable - for ud in self.batches.batches() { - if ud.as_ref() == Some(&raw_id) { - *ud = None; - } - } - - // Additionally, active subscriptions that were using this connection are - // closed. - if let Some(_) = self.num_subscriptions.remove(&raw_id) { - let ids = self - .subscriptions - .iter() - .filter(|(_, v)| v.raw_id == raw_id) - .map(|(k, _)| RawServerSubscriptionId(*k)) - .collect::>(); - for id in &ids { - let _ = self.subscriptions.remove(&id.0); - } - return RawServerEvent::SubscriptionsClosed(SubscriptionsClosedIter( - ids.into_iter(), - )); - } - } - }; - }; - - RawServerEvent::Request(self.request_by_id(&request_id).unwrap()) - } - - /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) - /// by its id. - /// - /// Note that previous notifications don't have an ID and can't be accessed with this method. - /// - /// Returns `None` if the request ID is invalid or if the request has already been answered in - /// the past. - pub fn request_by_id<'a>( - &'a mut self, - id: &RawServerRequestId, - ) -> Option> { - Some(RawServerRequest { - inner: self.batches.request_by_id(id.inner)?, - raw: &mut self.raw, - subscriptions: &mut self.subscriptions, - num_subscriptions: &mut self.num_subscriptions, - }) - } - - /// Returns a subscription previously returned by - /// [`into_subscription`](crate::raw::server::RawServerRequest::into_subscription). - pub fn subscription_by_id( - &mut self, - id: RawServerSubscriptionId, - ) -> Option { - if self.subscriptions.contains_key(&id.0) { - Some(ServerSubscription { - server: self, - id: id.0, - }) - } else { - None - } - } + /// Returns a `Future` resolving to the next event that this server generates. + pub async fn next_event<'a>(&'a mut self) -> RawServerEvent<'a> { + let request_id = loop { + match self.batches.next_event() { + None => {} + Some(batches::BatchesEvent::Notification { notification, .. }) => { + return RawServerEvent::Notification(notification) + } + Some(batches::BatchesEvent::Request(inner)) => { + break RawServerRequestId { inner: inner.id() }; + } + Some(batches::BatchesEvent::ReadyToSend { response, user_param: Some(raw_request_id) }) => { + // If we have any active subscription, we only use `send` to not close the + // client request. + if self.num_subscriptions.contains_key(&raw_request_id) { + debug_assert!(self.raw.supports_resuming(&raw_request_id).unwrap_or(false)); + let _ = self.raw.send(&raw_request_id, &response).await; + // TODO: that's O(n) + let mut ready = Vec::new(); // TODO: with_capacity + for (sub_id, sub) in self.subscriptions.iter_mut() { + if sub.raw_id == raw_request_id { + ready.push(RawServerSubscriptionId(sub_id.clone())); + sub.pending = false; + } + } + debug_assert!(!ready.is_empty()); // TODO: assert that capacity == len + return RawServerEvent::SubscriptionsReady(SubscriptionsReadyIter(ready.into_iter())); + } else { + let _ = self.raw.finish(&raw_request_id, Some(&response)).await; + } + continue; + } + Some(batches::BatchesEvent::ReadyToSend { response: _, user_param: None }) => { + // This situation happens if the connection has been closed by the client. + // Clients who close their connection. + continue; + } + }; + + match self.raw.next_request().await { + TransportServerEvent::Request { id, request } => self.batches.inject(request, Some(id)), + TransportServerEvent::Closed(raw_id) => { + // The client has a closed their connection. We eliminate all traces of the + // raw request ID from our state. + // TODO: this has an O(n) complexity; make sure that this is not attackable + for ud in self.batches.batches() { + if ud.as_ref() == Some(&raw_id) { + *ud = None; + } + } + + // Additionally, active subscriptions that were using this connection are + // closed. + if let Some(_) = self.num_subscriptions.remove(&raw_id) { + let ids = self + .subscriptions + .iter() + .filter(|(_, v)| v.raw_id == raw_id) + .map(|(k, _)| RawServerSubscriptionId(*k)) + .collect::>(); + for id in &ids { + let _ = self.subscriptions.remove(&id.0); + } + return RawServerEvent::SubscriptionsClosed(SubscriptionsClosedIter(ids.into_iter())); + } + } + }; + }; + + RawServerEvent::Request(self.request_by_id(&request_id).unwrap()) + } + + /// Returns a request previously returned by [`next_event`](crate::raw::RawServer::next_event) + /// by its id. + /// + /// Note that previous notifications don't have an ID and can't be accessed with this method. + /// + /// Returns `None` if the request ID is invalid or if the request has already been answered in + /// the past. + pub fn request_by_id<'a>(&'a mut self, id: &RawServerRequestId) -> Option> { + Some(RawServerRequest { + inner: self.batches.request_by_id(id.inner)?, + raw: &mut self.raw, + subscriptions: &mut self.subscriptions, + num_subscriptions: &mut self.num_subscriptions, + }) + } + + /// Returns a subscription previously returned by + /// [`into_subscription`](crate::raw::server::RawServerRequest::into_subscription). + pub fn subscription_by_id(&mut self, id: RawServerSubscriptionId) -> Option { + if self.subscriptions.contains_key(&id.0) { + Some(ServerSubscription { server: self, id: id.0 }) + } else { + None + } + } } impl From for RawServer { - fn from(inner: WsTransportServer) -> Self { - RawServer::new(inner) - } + fn from(inner: WsTransportServer) -> Self { + RawServer::new(inner) + } } impl<'a> RawServerRequest<'a> { - /// Returns the id of the request. - /// - /// If this request object is dropped, you can retreive it again later by calling - /// [`request_by_id`](crate::raw::RawServer::request_by_id). - pub fn id(&self) -> RawServerRequestId { - RawServerRequestId { - inner: self.inner.id(), - } - } - - /// Returns the id that the client sent out. - // TODO: can return None, which is wrong - pub fn request_id(&self) -> &common::Id { - self.inner.request_id() - } - - /// Returns the method of this request. - pub fn method(&self) -> &str { - self.inner.method() - } - - /// Returns the parameters of the request, as a `common::Params`. - pub fn params(&self) -> Params { - self.inner.params() - } + /// Returns the id of the request. + /// + /// If this request object is dropped, you can retreive it again later by calling + /// [`request_by_id`](crate::raw::RawServer::request_by_id). + pub fn id(&self) -> RawServerRequestId { + RawServerRequestId { inner: self.inner.id() } + } + + /// Returns the id that the client sent out. + // TODO: can return None, which is wrong + pub fn request_id(&self) -> &common::Id { + self.inner.request_id() + } + + /// Returns the method of this request. + pub fn method(&self) -> &str { + self.inner.method() + } + + /// Returns the parameters of the request, as a `common::Params`. + pub fn params(&self) -> Params { + self.inner.params() + } } impl<'a> RawServerRequest<'a> { - /// Send back a response. - /// - /// If this request is part of a batch: - /// - /// - If all requests of the batch have been responded to, then the response is actively - /// sent out. - /// - Otherwise, this response is buffered. - /// - /// > **Note**: This method is implemented in a way that doesn't wait for long to send the - /// > response. While calling this method will block your entire server, it - /// > should only block it for a short amount of time. See also [the equivalent - /// > method](crate::transport::TransportServer::finish) on the - /// > [`TransportServer`](crate::transport::TransportServer) trait. - /// - pub fn respond(self, response: Result) { - self.inner.set_response(response); - //unimplemented!(); - // TODO: actually send out response? - } - - /// Sends back a response similar to `respond`, then returns a [`RawServerSubscriptionId`] object - /// that allows you to push more data on the corresponding connection. - /// - /// The [`RawServerSubscriptionId`] corresponds to the identifier that has been sent back to the - /// client. If the client refers to this subscription id, you can turn it into a - /// [`RawServerSubscriptionId`] using - /// [`from_wire_message`](RawServerSubscriptionId::from_wire_message). - /// - /// After the request has been turned into a subscription, the subscription might be in - /// "pending mode". Pushing notifications on that subscription will return an error. This - /// mechanism is necessary because the subscription request might be part of a batch, and all - /// the requests of that batch have to be processed before informing the client of the start - /// of the subscription. - /// - /// Returns an error and doesn't do anything if the underlying server doesn't support - /// subscriptions, or if the connection has already been closed by the client. - /// - /// > **Note**: Because of borrowing issues, we return a [`RawServerSubscriptionId`] rather than - /// > a [`ServerSubscription`]. You will have to call - /// > [`subscription_by_id`](RawServer::subscription_by_id) in order to manipulate the - /// > subscription. - // TODO: solve the note - pub fn into_subscription(mut self) -> Result { - let raw_request_id = match self.inner.user_param().clone() { - Some(id) => id, - None => return Err(IntoSubscriptionErr::Closed), - }; - - if !self.raw.supports_resuming(&raw_request_id).unwrap_or(false) { - return Err(IntoSubscriptionErr::NotSupported); - } - - loop { - let new_subscr_id: [u8; 32] = rand::random(); - - match self.subscriptions.entry(new_subscr_id) { - Entry::Vacant(e) => e.insert(SubscriptionState { - raw_id: raw_request_id.clone(), - method: self.inner.method().to_owned(), - pending: true, - }), - // Continue looping if we accidentally chose an existing ID. - Entry::Occupied(_) => continue, - }; - - self.num_subscriptions - .entry(raw_request_id) - .and_modify(|e| { - *e = NonZeroUsize::new(e.get() + 1) - .expect("we add 1 to an existing non-zero value; qed"); - }) - .or_insert_with(|| NonZeroUsize::new(1).expect("1 != 0")); - - let subscr_id_string = bs58::encode(&new_subscr_id).into_string(); - self.inner.set_response(Ok(subscr_id_string.into())); - break Ok(RawServerSubscriptionId(new_subscr_id)); - } - } + /// Send back a response. + /// + /// If this request is part of a batch: + /// + /// - If all requests of the batch have been responded to, then the response is actively + /// sent out. + /// - Otherwise, this response is buffered. + /// + /// > **Note**: This method is implemented in a way that doesn't wait for long to send the + /// > response. While calling this method will block your entire server, it + /// > should only block it for a short amount of time. See also [the equivalent + /// > method](crate::transport::TransportServer::finish) on the + /// > [`TransportServer`](crate::transport::TransportServer) trait. + /// + pub fn respond(self, response: Result) { + self.inner.set_response(response); + //unimplemented!(); + // TODO: actually send out response? + } + + /// Sends back a response similar to `respond`, then returns a [`RawServerSubscriptionId`] object + /// that allows you to push more data on the corresponding connection. + /// + /// The [`RawServerSubscriptionId`] corresponds to the identifier that has been sent back to the + /// client. If the client refers to this subscription id, you can turn it into a + /// [`RawServerSubscriptionId`] using + /// [`from_wire_message`](RawServerSubscriptionId::from_wire_message). + /// + /// After the request has been turned into a subscription, the subscription might be in + /// "pending mode". Pushing notifications on that subscription will return an error. This + /// mechanism is necessary because the subscription request might be part of a batch, and all + /// the requests of that batch have to be processed before informing the client of the start + /// of the subscription. + /// + /// Returns an error and doesn't do anything if the underlying server doesn't support + /// subscriptions, or if the connection has already been closed by the client. + /// + /// > **Note**: Because of borrowing issues, we return a [`RawServerSubscriptionId`] rather than + /// > a [`ServerSubscription`]. You will have to call + /// > [`subscription_by_id`](RawServer::subscription_by_id) in order to manipulate the + /// > subscription. + // TODO: solve the note + pub fn into_subscription(mut self) -> Result { + let raw_request_id = match self.inner.user_param().clone() { + Some(id) => id, + None => return Err(IntoSubscriptionErr::Closed), + }; + + if !self.raw.supports_resuming(&raw_request_id).unwrap_or(false) { + return Err(IntoSubscriptionErr::NotSupported); + } + + loop { + let new_subscr_id: [u8; 32] = rand::random(); + + match self.subscriptions.entry(new_subscr_id) { + Entry::Vacant(e) => e.insert(SubscriptionState { + raw_id: raw_request_id.clone(), + method: self.inner.method().to_owned(), + pending: true, + }), + // Continue looping if we accidentally chose an existing ID. + Entry::Occupied(_) => continue, + }; + + self.num_subscriptions + .entry(raw_request_id) + .and_modify(|e| { + *e = NonZeroUsize::new(e.get() + 1).expect("we add 1 to an existing non-zero value; qed"); + }) + .or_insert_with(|| NonZeroUsize::new(1).expect("1 != 0")); + + let subscr_id_string = bs58::encode(&new_subscr_id).into_string(); + self.inner.set_response(Ok(subscr_id_string.into())); + break Ok(RawServerSubscriptionId(new_subscr_id)); + } + } } impl<'a> fmt::Debug for RawServerRequest<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("RawServerRequest") - .field("request_id", &self.request_id()) - .field("method", &self.method()) - .field("params", &self.params()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("RawServerRequest") + .field("request_id", &self.request_id()) + .field("method", &self.method()) + .field("params", &self.params()) + .finish() + } } impl RawServerSubscriptionId { - /// When the client sends a unsubscribe message containing a subscription ID, this function can - /// be used to parse it into a [`RawServerSubscriptionId`]. - pub fn from_wire_message(params: &JsonValue) -> Result { - let string = match params { - JsonValue::String(s) => s, - _ => return Err(()), - }; - - let decoded = bs58::decode(&string).into_vec().map_err(|_| ())?; - if decoded.len() > 32 { - return Err(()); - } - - let mut out = [0; 32]; - out[(32 - decoded.len())..].copy_from_slice(&decoded); - // TODO: write a test to check that encoding/decoding match - Ok(RawServerSubscriptionId(out)) - } + /// When the client sends a unsubscribe message containing a subscription ID, this function can + /// be used to parse it into a [`RawServerSubscriptionId`]. + pub fn from_wire_message(params: &JsonValue) -> Result { + let string = match params { + JsonValue::String(s) => s, + _ => return Err(()), + }; + + let decoded = bs58::decode(&string).into_vec().map_err(|_| ())?; + if decoded.len() > 32 { + return Err(()); + } + + let mut out = [0; 32]; + out[(32 - decoded.len())..].copy_from_slice(&decoded); + // TODO: write a test to check that encoding/decoding match + Ok(RawServerSubscriptionId(out)) + } } impl<'a> ServerSubscription<'a> { - /// Returns the id of the subscription. - /// - /// If this subscription object is dropped, you can retreive it again later by calling - /// [`subscription_by_id`](crate::raw::RawServer::subscription_by_id). - pub fn id(&self) -> RawServerSubscriptionId { - RawServerSubscriptionId(self.id) - } - - /// Pushes a notification. - /// - // TODO: refactor to progate the error. - pub async fn push(self, message: impl Into) { - let subscription_state = self.server.subscriptions.get(&self.id).unwrap(); - if subscription_state.pending { - return; // TODO: notify user with error - } - - let output = common::SubscriptionNotif { - jsonrpc: common::Version::V2, - method: subscription_state.method.clone(), - params: common::SubscriptionNotifParams { - subscription: common::SubscriptionId::Str(bs58::encode(&self.id).into_string()), - result: message.into(), - }, - }; - let response = common::Response::Notif(output); - - let _ = self - .server - .raw - .send(&subscription_state.raw_id, &response) - .await; // TODO: error handling? - } - - /// Destroys the subscription object. - /// - /// This does not send any message back to the client. Instead, this function is supposed to - /// be used in reaction to the client requesting to be unsubscribed. - /// - /// If this was the last active subscription, also closes the connection ("raw request") with - /// the client. - pub async fn close(self) { - let subscription_state = self.server.subscriptions.remove(&self.id).unwrap(); - - // Check if we're the last subscription on this connection. - // Remove entry from `num_subscriptions` if so. - let is_last_sub = match self - .server - .num_subscriptions - .entry(subscription_state.raw_id.clone()) - { - Entry::Vacant(_) => unreachable!(), - Entry::Occupied(ref mut e) if e.get().get() >= 2 => { - let e = e.get_mut(); - *e = NonZeroUsize::new(e.get() - 1).expect("e is >= 2; qed"); - false - } - Entry::Occupied(e) => { - e.remove(); - true - } - }; - - // If the subscription is pending, we have yet to send something back on that connection - // and thus shouldn't close it. - // When the response is sent back later, the code will realize that `num_subscriptions` - // is zero/empty and call `finish`. - if is_last_sub && !subscription_state.pending { - let _ = self - .server - .raw - .finish(&subscription_state.raw_id, None) - .await; - } - } + /// Returns the id of the subscription. + /// + /// If this subscription object is dropped, you can retreive it again later by calling + /// [`subscription_by_id`](crate::raw::RawServer::subscription_by_id). + pub fn id(&self) -> RawServerSubscriptionId { + RawServerSubscriptionId(self.id) + } + + /// Pushes a notification. + /// + // TODO: refactor to progate the error. + pub async fn push(self, message: impl Into) { + let subscription_state = self.server.subscriptions.get(&self.id).unwrap(); + if subscription_state.pending { + return; // TODO: notify user with error + } + + let output = common::SubscriptionNotif { + jsonrpc: common::Version::V2, + method: subscription_state.method.clone(), + params: common::SubscriptionNotifParams { + subscription: common::SubscriptionId::Str(bs58::encode(&self.id).into_string()), + result: message.into(), + }, + }; + let response = common::Response::Notif(output); + + let _ = self.server.raw.send(&subscription_state.raw_id, &response).await; // TODO: error handling? + } + + /// Destroys the subscription object. + /// + /// This does not send any message back to the client. Instead, this function is supposed to + /// be used in reaction to the client requesting to be unsubscribed. + /// + /// If this was the last active subscription, also closes the connection ("raw request") with + /// the client. + pub async fn close(self) { + let subscription_state = self.server.subscriptions.remove(&self.id).unwrap(); + + // Check if we're the last subscription on this connection. + // Remove entry from `num_subscriptions` if so. + let is_last_sub = match self.server.num_subscriptions.entry(subscription_state.raw_id.clone()) { + Entry::Vacant(_) => unreachable!(), + Entry::Occupied(ref mut e) if e.get().get() >= 2 => { + let e = e.get_mut(); + *e = NonZeroUsize::new(e.get() - 1).expect("e is >= 2; qed"); + false + } + Entry::Occupied(e) => { + e.remove(); + true + } + }; + + // If the subscription is pending, we have yet to send something back on that connection + // and thus shouldn't close it. + // When the response is sent back later, the code will realize that `num_subscriptions` + // is zero/empty and call `finish`. + if is_last_sub && !subscription_state.pending { + let _ = self.server.raw.finish(&subscription_state.raw_id, None).await; + } + } } impl fmt::Display for IntoSubscriptionErr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - IntoSubscriptionErr::NotSupported => { - write!(f, "Underlying server doesn't support subscriptions") - } - IntoSubscriptionErr::Closed => write!(f, "Request is already closed"), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + IntoSubscriptionErr::NotSupported => write!(f, "Underlying server doesn't support subscriptions"), + IntoSubscriptionErr::Closed => write!(f, "Request is already closed"), + } + } } impl std::error::Error for IntoSubscriptionErr {} impl Iterator for SubscriptionsReadyIter { - type Item = RawServerSubscriptionId; + type Item = RawServerSubscriptionId; - fn next(&mut self) -> Option { - self.0.next() - } + fn next(&mut self) -> Option { + self.0.next() + } - fn size_hint(&self) -> (usize, Option) { - self.0.size_hint() - } + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } } impl ExactSizeIterator for SubscriptionsReadyIter {} impl Iterator for SubscriptionsClosedIter { - type Item = RawServerSubscriptionId; + type Item = RawServerSubscriptionId; - fn next(&mut self) -> Option { - self.0.next() - } + fn next(&mut self) -> Option { + self.0.next() + } - fn size_hint(&self) -> (usize, Option) { - self.0.size_hint() - } + fn size_hint(&self) -> (usize, Option) { + self.0.size_hint() + } } impl ExactSizeIterator for SubscriptionsClosedIter {} diff --git a/src/ws/raw/mod.rs b/src/ws/raw/mod.rs index 52239bd03b..03109f8c86 100644 --- a/src/ws/raw/mod.rs +++ b/src/ws/raw/mod.rs @@ -8,9 +8,7 @@ mod typed_rp; #[cfg(test)] mod tests; -pub use self::core::{ - RawServer, RawServerEvent, RawServerRequest, RawServerRequestId, RawServerSubscriptionId, -}; +pub use self::core::{RawServer, RawServerEvent, RawServerRequest, RawServerRequestId, RawServerSubscriptionId}; pub use self::notification::Notification; pub use self::params::{Iter as ParamsIter, ParamKey as ParamsKey, Params}; pub use self::typed_rp::TypedResponder; diff --git a/src/ws/raw/notification.rs b/src/ws/raw/notification.rs index f92d83b791..5ed5474f6a 100644 --- a/src/ws/raw/notification.rs +++ b/src/ws/raw/notification.rs @@ -35,34 +35,31 @@ use core::fmt; pub struct Notification(common::Notification); impl From for Notification { - fn from(notif: common::Notification) -> Notification { - Notification(notif) - } + fn from(notif: common::Notification) -> Notification { + Notification(notif) + } } impl From for common::Notification { - fn from(notif: Notification) -> common::Notification { - notif.0 - } + fn from(notif: Notification) -> common::Notification { + notif.0 + } } impl Notification { - /// Returns the method of this notification. - pub fn method(&self) -> &str { - &self.0.method - } + /// Returns the method of this notification. + pub fn method(&self) -> &str { + &self.0.method + } - /// Returns the parameters of the notification. - pub fn params(&self) -> Params { - Params::from(&self.0.params) - } + /// Returns the parameters of the notification. + pub fn params(&self) -> Params { + Params::from(&self.0.params) + } } impl fmt::Debug for Notification { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Notification") - .field("method", &self.method()) - .field("params", &self.params()) - .finish() - } + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Notification").field("method", &self.method()).field("params", &self.params()).finish() + } } diff --git a/src/ws/raw/params.rs b/src/ws/raw/params.rs index fa91b460c4..90a3a7e782 100644 --- a/src/ws/raw/params.rs +++ b/src/ws/raw/params.rs @@ -32,144 +32,144 @@ use core::fmt; /// Access to the parameters of a request. #[derive(Copy, Clone)] pub struct Params<'a> { - /// Raw parameters of the request. - params: &'a common::Params, + /// Raw parameters of the request. + params: &'a common::Params, } /// Key referring to a potential parameter of a request. pub enum ParamKey<'a> { - /// String key. Only valid when the parameters list is a map. - String(&'a str), - /// Integer key. Only valid when the parameters list is an array. - Index(usize), + /// String key. Only valid when the parameters list is a map. + String(&'a str), + /// Integer key. Only valid when the parameters list is an array. + Index(usize), } impl<'a> Params<'a> { - /// Wraps around a `&common::Params` and provides utility functions for the user. - pub(crate) fn from(params: &'a common::Params) -> Params<'a> { - Params { params } - } - - /// Returns a parameter of the request by name and decodes it. - /// - /// Returns an error if the parameter doesn't exist or is of the wrong type. - pub fn get<'k, T>(self, param: impl Into>) -> Result - where - T: serde::de::DeserializeOwned, - { - let val = self.get_raw(param).ok_or(())?; - serde_json::from_value(val.clone()).map_err(|_| ()) - } - - /// Returns a parameter of the request by name. - pub fn get_raw<'k>(self, param: impl Into>) -> Option<&'a common::JsonValue> { - match (self.params, param.into()) { - (common::Params::None, _) => None, - (common::Params::Map(map), ParamKey::String(key)) => map.get(key), - (common::Params::Map(_), ParamKey::Index(_)) => None, - (common::Params::Array(_), ParamKey::String(_)) => None, - (common::Params::Array(array), ParamKey::Index(index)) => { - if index < array.len() { - Some(&array[index]) - } else { - None - } - } - } - } + /// Wraps around a `&common::Params` and provides utility functions for the user. + pub(crate) fn from(params: &'a common::Params) -> Params<'a> { + Params { params } + } + + /// Returns a parameter of the request by name and decodes it. + /// + /// Returns an error if the parameter doesn't exist or is of the wrong type. + pub fn get<'k, T>(self, param: impl Into>) -> Result + where + T: serde::de::DeserializeOwned, + { + let val = self.get_raw(param).ok_or(())?; + serde_json::from_value(val.clone()).map_err(|_| ()) + } + + /// Returns a parameter of the request by name. + pub fn get_raw<'k>(self, param: impl Into>) -> Option<&'a common::JsonValue> { + match (self.params, param.into()) { + (common::Params::None, _) => None, + (common::Params::Map(map), ParamKey::String(key)) => map.get(key), + (common::Params::Map(_), ParamKey::Index(_)) => None, + (common::Params::Array(_), ParamKey::String(_)) => None, + (common::Params::Array(array), ParamKey::Index(index)) => { + if index < array.len() { + Some(&array[index]) + } else { + None + } + } + } + } } impl<'a> IntoIterator for Params<'a> { - type Item = (ParamKey<'a>, &'a common::JsonValue); - type IntoIter = Iter<'a>; - - fn into_iter(self) -> Self::IntoIter { - Iter(match self.params { - common::Params::None => IterInner::Empty, - common::Params::Array(arr) => IterInner::Array(arr.iter()), - common::Params::Map(map) => IterInner::Map(map.iter()), - }) - } + type Item = (ParamKey<'a>, &'a common::JsonValue); + type IntoIter = Iter<'a>; + + fn into_iter(self) -> Self::IntoIter { + Iter(match self.params { + common::Params::None => IterInner::Empty, + common::Params::Array(arr) => IterInner::Array(arr.iter()), + common::Params::Map(map) => IterInner::Map(map.iter()), + }) + } } impl<'a> fmt::Debug for Params<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_map().entries(self.into_iter()).finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_map().entries(self.into_iter()).finish() + } } impl<'a> AsRef for Params<'a> { - fn as_ref(&self) -> &common::Params { - self.params - } + fn as_ref(&self) -> &common::Params { + self.params + } } impl<'a> From> for &'a common::Params { - fn from(params: Params<'a>) -> &'a common::Params { - params.params - } + fn from(params: Params<'a>) -> &'a common::Params { + params.params + } } impl<'a> From<&'a str> for ParamKey<'a> { - fn from(s: &'a str) -> Self { - ParamKey::String(s) - } + fn from(s: &'a str) -> Self { + ParamKey::String(s) + } } impl<'a> From<&'a String> for ParamKey<'a> { - fn from(s: &'a String) -> Self { - ParamKey::String(&s[..]) - } + fn from(s: &'a String) -> Self { + ParamKey::String(&s[..]) + } } impl<'a> From for ParamKey<'a> { - fn from(i: usize) -> Self { - ParamKey::Index(i) - } + fn from(i: usize) -> Self { + ParamKey::Index(i) + } } impl<'a> fmt::Debug for ParamKey<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - ParamKey::String(s) => fmt::Debug::fmt(s, f), - ParamKey::Index(s) => fmt::Debug::fmt(s, f), - } - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + ParamKey::String(s) => fmt::Debug::fmt(s, f), + ParamKey::Index(s) => fmt::Debug::fmt(s, f), + } + } } /// Iterator to all the parameters of a request. pub struct Iter<'a>(IterInner<'a>); enum IterInner<'a> { - Empty, - Map(serde_json::map::Iter<'a>), - Array(std::slice::Iter<'a, serde_json::Value>), + Empty, + Map(serde_json::map::Iter<'a>), + Array(std::slice::Iter<'a, serde_json::Value>), } impl<'a> Iterator for Iter<'a> { - type Item = (ParamKey<'a>, &'a common::JsonValue); - - fn next(&mut self) -> Option { - match &mut self.0 { - IterInner::Empty => None, - IterInner::Map(iter) => iter.next().map(|(k, v)| (ParamKey::String(&k[..]), v)), - IterInner::Array(iter) => iter.next().map(|v| (ParamKey::String(""), v)), - } - } - - fn size_hint(&self) -> (usize, Option) { - match &self.0 { - IterInner::Empty => (0, Some(0)), - IterInner::Map(iter) => iter.size_hint(), - IterInner::Array(iter) => iter.size_hint(), - } - } + type Item = (ParamKey<'a>, &'a common::JsonValue); + + fn next(&mut self) -> Option { + match &mut self.0 { + IterInner::Empty => None, + IterInner::Map(iter) => iter.next().map(|(k, v)| (ParamKey::String(&k[..]), v)), + IterInner::Array(iter) => iter.next().map(|v| (ParamKey::String(""), v)), + } + } + + fn size_hint(&self) -> (usize, Option) { + match &self.0 { + IterInner::Empty => (0, Some(0)), + IterInner::Map(iter) => iter.size_hint(), + IterInner::Array(iter) => iter.size_hint(), + } + } } impl<'a> ExactSizeIterator for Iter<'a> {} impl<'a> fmt::Debug for Iter<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("ParamsIter").finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("ParamsIter").finish() + } } diff --git a/src/ws/raw/tests.rs b/src/ws/raw/tests.rs index f24f04767e..bb1de0410d 100644 --- a/src/ws/raw/tests.rs +++ b/src/ws/raw/tests.rs @@ -34,71 +34,61 @@ use serde_json::Value; use std::net::SocketAddr; async fn raw_server() -> (RawWsServer, SocketAddr) { - let server = WsTransportServer::builder("127.0.0.1:0".parse().unwrap()) - .build() - .await - .unwrap(); - let addr = *server.local_addr(); - (server.into(), addr) + let server = WsTransportServer::builder("127.0.0.1:0".parse().unwrap()).build().await.unwrap(); + let addr = *server.local_addr(); + (server.into(), addr) } #[tokio::test] async fn request_work() { - let (mut server, server_addr) = raw_server().await; + let (mut server, server_addr) = raw_server().await; - tokio::spawn(async move { - let mut client = WsTransportClient::new(&to_ws_uri_string(server_addr)) - .await - .unwrap(); - let call = Call::MethodCall(MethodCall { - jsonrpc: Version::V2, - method: "hello_world".to_owned(), - params: Params::Array(vec![Value::from(1), Value::from(2)]), - id: common::Id::Num(3), - }); - client.send_request(Request::Single(call)).await.unwrap(); - }); + tokio::spawn(async move { + let mut client = WsTransportClient::new(&to_ws_uri_string(server_addr)).await.unwrap(); + let call = Call::MethodCall(MethodCall { + jsonrpc: Version::V2, + method: "hello_world".to_owned(), + params: Params::Array(vec![Value::from(1), Value::from(2)]), + id: common::Id::Num(3), + }); + client.send_request(Request::Single(call)).await.unwrap(); + }); - match server.next_event().await { - RawWsServerEvent::Request(r) => { - assert_eq!(r.method(), "hello_world"); - let p1: i32 = r.params().get(0).unwrap(); - let p2: i32 = r.params().get(1).unwrap(); - assert_eq!(p1, 1); - assert_eq!(p2, 2); - assert_eq!(r.request_id(), &common::Id::Num(3)); - } - e @ _ => panic!("Invalid server event: {:?} expected Request", e), - } + match server.next_event().await { + RawWsServerEvent::Request(r) => { + assert_eq!(r.method(), "hello_world"); + let p1: i32 = r.params().get(0).unwrap(); + let p2: i32 = r.params().get(1).unwrap(); + assert_eq!(p1, 1); + assert_eq!(p2, 2); + assert_eq!(r.request_id(), &common::Id::Num(3)); + } + e @ _ => panic!("Invalid server event: {:?} expected Request", e), + } } #[tokio::test] async fn notification_work() { - let (mut server, server_addr) = raw_server().await; + let (mut server, server_addr) = raw_server().await; - tokio::spawn(async move { - let mut client = WsTransportClient::new(&to_ws_uri_string(server_addr)) - .await - .unwrap(); - let n = Notification { - jsonrpc: Version::V2, - method: "hello_world".to_owned(), - params: Params::Array(vec![Value::from("lo"), Value::from(2)]), - }; - client - .send_request(Request::Single(Call::Notification(n))) - .await - .unwrap(); - }); + tokio::spawn(async move { + let mut client = WsTransportClient::new(&to_ws_uri_string(server_addr)).await.unwrap(); + let n = Notification { + jsonrpc: Version::V2, + method: "hello_world".to_owned(), + params: Params::Array(vec![Value::from("lo"), Value::from(2)]), + }; + client.send_request(Request::Single(Call::Notification(n))).await.unwrap(); + }); - match server.next_event().await { - RawWsServerEvent::Notification(r) => { - assert_eq!(r.method(), "hello_world"); - let p1: String = r.params().get(0).unwrap(); - let p2: i32 = r.params().get(1).unwrap(); - assert_eq!(p1, "lo"); - assert_eq!(p2, 2); - } - e @ _ => panic!("Invalid server event: {:?} expected Notification", e), - } + match server.next_event().await { + RawWsServerEvent::Notification(r) => { + assert_eq!(r.method(), "hello_world"); + let p1: String = r.params().get(0).unwrap(); + let p2: i32 = r.params().get(1).unwrap(); + assert_eq!(p1, "lo"); + assert_eq!(p2, 2); + } + e @ _ => panic!("Invalid server event: {:?} expected Notification", e), + } } diff --git a/src/ws/raw/typed_rp.rs b/src/ws/raw/typed_rp.rs index 3f355161b3..aaa1f3da73 100644 --- a/src/ws/raw/typed_rp.rs +++ b/src/ws/raw/typed_rp.rs @@ -29,43 +29,39 @@ use core::marker::PhantomData; /// Allows responding to a server request in a more elegant and strongly-typed fashion. pub struct TypedResponder<'a, T> { - /// The request to answer. - rq: RawServerRequest<'a>, - /// Marker that pins the type of the response. - response_ty: PhantomData, + /// The request to answer. + rq: RawServerRequest<'a>, + /// Marker that pins the type of the response. + response_ty: PhantomData, } impl<'a, T> From> for TypedResponder<'a, T> { - fn from(rq: RawServerRequest<'a>) -> TypedResponder<'a, T> { - TypedResponder { - rq, - response_ty: PhantomData, - } - } + fn from(rq: RawServerRequest<'a>) -> TypedResponder<'a, T> { + TypedResponder { rq, response_ty: PhantomData } + } } impl<'a, T> TypedResponder<'a, T> where - T: serde::Serialize, + T: serde::Serialize, { - /// Returns a successful response. - pub fn ok(self, response: impl Into) { - self.respond(Ok(response)) - } + /// Returns a successful response. + pub fn ok(self, response: impl Into) { + self.respond(Ok(response)) + } - /// Returns an erroneous response. - pub fn err(self, err: crate::common::Error) { - self.respond(Err::(err)) - } + /// Returns an erroneous response. + pub fn err(self, err: crate::common::Error) { + self.respond(Err::(err)) + } - /// Returns a response. - pub fn respond(self, response: Result, crate::common::Error>) { - let response = match response { - Ok(v) => crate::common::to_value(v.into()) - .map_err(|_| crate::common::Error::internal_error()), - Err(err) => Err(err), - }; + /// Returns a response. + pub fn respond(self, response: Result, crate::common::Error>) { + let response = match response { + Ok(v) => crate::common::to_value(v.into()).map_err(|_| crate::common::Error::internal_error()), + Err(err) => Err(err), + }; - self.rq.respond(response) - } + self.rq.respond(response) + } } diff --git a/src/ws/server.rs b/src/ws/server.rs index c2b995fd5b..a049c977bc 100644 --- a/src/ws/server.rs +++ b/src/ws/server.rs @@ -31,10 +31,10 @@ use crate::ws::transport::WsTransportServer; use futures::{channel::mpsc, future::Either, pin_mut, prelude::*}; use parking_lot::Mutex; use std::{ - collections::{HashMap, HashSet}, - error, - net::SocketAddr, - sync::{atomic, Arc}, + collections::{HashMap, HashSet}, + error, + net::SocketAddr, + sync::{atomic, Arc}, }; /// Server that can be cloned. @@ -44,481 +44,434 @@ use std::{ /// > [`RawServer`] struct instead. #[derive(Clone)] pub struct Server { - /// Channel to send requests to the background task. - to_back: mpsc::UnboundedSender, - /// List of methods (for RPC queries, subscriptions, and unsubscriptions) that have been - /// registered. Serves no purpose except to check for duplicates. - registered_methods: Arc>>, - /// Next unique ID used when registering a subscription. - next_subscription_unique_id: Arc, - /// Local socket address of the transport server. - local_addr: SocketAddr, + /// Channel to send requests to the background task. + to_back: mpsc::UnboundedSender, + /// List of methods (for RPC queries, subscriptions, and unsubscriptions) that have been + /// registered. Serves no purpose except to check for duplicates. + registered_methods: Arc>>, + /// Next unique ID used when registering a subscription. + next_subscription_unique_id: Arc, + /// Local socket address of the transport server. + local_addr: SocketAddr, } /// Notification method that's been registered. pub struct RegisteredNotification { - /// Receives notifications that the client sent to us. - queries_rx: mpsc::Receiver, + /// Receives notifications that the client sent to us. + queries_rx: mpsc::Receiver, } /// Method that's been registered. pub struct RegisteredMethod { - /// Clone of [`Server::to_back`]. - to_back: mpsc::UnboundedSender, - /// Receives requests that the client sent to us. - queries_rx: mpsc::Receiver<(RawServerRequestId, common::Params)>, + /// Clone of [`Server::to_back`]. + to_back: mpsc::UnboundedSender, + /// Receives requests that the client sent to us. + queries_rx: mpsc::Receiver<(RawServerRequestId, common::Params)>, } /// Pub-sub subscription that's been registered. // TODO: unregister on drop pub struct RegisteredSubscription { - /// Clone of [`Server::to_back`]. - to_back: mpsc::UnboundedSender, - /// Value passed to [`FrontToBack::RegisterSubscription::unique_id`]. - unique_id: usize, + /// Clone of [`Server::to_back`]. + to_back: mpsc::UnboundedSender, + /// Value passed to [`FrontToBack::RegisterSubscription::unique_id`]. + unique_id: usize, } /// Active request that needs to be answered. pub struct IncomingRequest { - /// Clone of [`Server::to_back`]. - to_back: mpsc::UnboundedSender, - /// Identifier of the request towards the server. - request_id: RawServerRequestId, - /// Parameters of the request. - params: common::Params, + /// Clone of [`Server::to_back`]. + to_back: mpsc::UnboundedSender, + /// Identifier of the request towards the server. + request_id: RawServerRequestId, + /// Parameters of the request. + params: common::Params, } /// Message that the [`Server`] can send to the background task. enum FrontToBack { - /// Registers a notifications endpoint. - RegisterNotifications { - /// Name of the method. - name: String, - /// Where to send incoming notifications. - handler: mpsc::Sender, - /// See the documentation of [`Server::register_notifications`]. - allow_losses: bool, - }, - - /// Registers a method. The server will then handle requests using this method. - RegisterMethod { - /// Name of the method. - name: String, - /// Where to send requests. - handler: mpsc::Sender<(RawServerRequestId, common::Params)>, - }, - - /// Send a response to a request that a client made. - AnswerRequest { - /// Request to answer. - request_id: RawServerRequestId, - /// Response to send back. - answer: Result, - }, - - /// Registers a subscription. The server will then handle subscription requests of that - /// method. - RegisterSubscription { - /// Unique identifier decided by the front-end in order to identify this registered - /// subscription. - unique_id: usize, - /// Name of the method that registers the subscription. - subscribe_method: String, - /// Name of the method that unregisters the subscription. - unsubscribe_method: String, - }, - - /// Send out a notification to all the clients registered to a subscription. - SendOutNotif { - /// The value that was passed in [`FrontToBack::RegisterSubscription::unique_id`] earlier. - unique_id: usize, - /// Notification to send to the subscribed clients. - notification: JsonValue, - }, + /// Registers a notifications endpoint. + RegisterNotifications { + /// Name of the method. + name: String, + /// Where to send incoming notifications. + handler: mpsc::Sender, + /// See the documentation of [`Server::register_notifications`]. + allow_losses: bool, + }, + + /// Registers a method. The server will then handle requests using this method. + RegisterMethod { + /// Name of the method. + name: String, + /// Where to send requests. + handler: mpsc::Sender<(RawServerRequestId, common::Params)>, + }, + + /// Send a response to a request that a client made. + AnswerRequest { + /// Request to answer. + request_id: RawServerRequestId, + /// Response to send back. + answer: Result, + }, + + /// Registers a subscription. The server will then handle subscription requests of that + /// method. + RegisterSubscription { + /// Unique identifier decided by the front-end in order to identify this registered + /// subscription. + unique_id: usize, + /// Name of the method that registers the subscription. + subscribe_method: String, + /// Name of the method that unregisters the subscription. + unsubscribe_method: String, + }, + + /// Send out a notification to all the clients registered to a subscription. + SendOutNotif { + /// The value that was passed in [`FrontToBack::RegisterSubscription::unique_id`] earlier. + unique_id: usize, + /// Notification to send to the subscribed clients. + notification: JsonValue, + }, } impl Server { - /// Initializes a new server. - pub async fn new(url: &str) -> Result> { - let sockaddr = url.parse()?; - let transport_server = WsTransportServer::builder(sockaddr).build().await?; - let local_addr = *transport_server.local_addr(); - - // We use an unbounded channel because the only exchanged messages concern registering - // methods. The volume of messages is therefore very low and it doesn't make sense to have - // a backpressure mechanism. - // TODO: that's not true anymore ^ - let (to_back, from_front) = mpsc::unbounded(); - - async_std::task::spawn(async move { - background_task(transport_server.into(), from_front).await; - }); - - Ok(Server { - to_back, - registered_methods: Arc::new(Mutex::new(Default::default())), - next_subscription_unique_id: Arc::new(atomic::AtomicUsize::new(0)), - local_addr, - }) - } - - /// Local socket address of the underlying transport server. - pub fn local_addr(&self) -> &SocketAddr { - &self.local_addr - } - - /// Registers a notification method name towards the server. - /// - /// Clients will then be able to call this method. - /// The returned object allows you to process incoming notifications. - /// - /// If `allow_losses` is true, then the server is allowed to drop notifications if the - /// notifications handler (i.e. the code that uses [`RegisteredNotifications`]) is too slow - /// to process notifications. - /// - /// Returns an error if the method name was already registered. - pub fn register_notification( - &self, - method_name: String, - allow_losses: bool, - ) -> Result { - log::debug!("[frontend]: register_notification={}", method_name); - if !self.registered_methods.lock().insert(method_name.clone()) { - return Err(()); - } - - let (tx, rx) = mpsc::channel(32); - - let _ = self - .to_back - .unbounded_send(FrontToBack::RegisterNotifications { - name: method_name, - handler: tx, - allow_losses, - }); - - Ok(RegisteredNotification { queries_rx: rx }) - } - - /// Registers a method towards the server. - /// - /// Clients will then be able to call this method. - /// The returned object allows you to handle incoming requests. - /// - /// Contrary to [`register_notifications`](Server::register_notifications), there is no - /// `allow_losses` parameter here. If the handler is too slow to process requests, then the - /// server automatically returns an "internal error" to the client. - /// - /// Returns an error if the method name was already registered. - pub fn register_method(&self, method_name: String) -> Result { - log::debug!("[frontend]: register_method={}", method_name); - if !self.registered_methods.lock().insert(method_name.clone()) { - return Err(()); - } - - let (tx, rx) = mpsc::channel(32); - - let _ = self.to_back.unbounded_send(FrontToBack::RegisterMethod { - name: method_name, - handler: tx, - }); - - Ok(RegisteredMethod { - to_back: self.to_back.clone(), - queries_rx: rx, - }) - } - - /// Registers a subscription towards the server. - /// - /// Clients will then be able to call this method. - /// The returned object allows you to send out notifications. - /// - /// Returns an error if one of the method names was already registered. - pub fn register_subscription( - &self, - subscribe_method_name: String, - unsubscribe_method_name: String, - ) -> Result { - log::debug!( - "[frontend]: server register subscription: subscribe_method={}, unsubscribe_method={}", - subscribe_method_name, - unsubscribe_method_name - ); - { - let mut registered_methods = self.registered_methods.lock(); - if registered_methods.contains(&subscribe_method_name) - || registered_methods.contains(&unsubscribe_method_name) - { - return Err(()); - } - registered_methods.insert(subscribe_method_name.clone()); - registered_methods.insert(unsubscribe_method_name.clone()); - } - - let unique_id = self - .next_subscription_unique_id - .fetch_add(1, atomic::Ordering::Relaxed); - - self.to_back - .unbounded_send(FrontToBack::RegisterSubscription { - unique_id, - subscribe_method: subscribe_method_name, - unsubscribe_method: unsubscribe_method_name, - }) - .map_err(|_| ())?; - - Ok(RegisteredSubscription { - to_back: self.to_back.clone(), - unique_id, - }) - } + /// Initializes a new server. + pub async fn new(url: &str) -> Result> { + let sockaddr = url.parse()?; + let transport_server = WsTransportServer::builder(sockaddr).build().await?; + let local_addr = *transport_server.local_addr(); + + // We use an unbounded channel because the only exchanged messages concern registering + // methods. The volume of messages is therefore very low and it doesn't make sense to have + // a backpressure mechanism. + // TODO: that's not true anymore ^ + let (to_back, from_front) = mpsc::unbounded(); + + async_std::task::spawn(async move { + background_task(transport_server.into(), from_front).await; + }); + + Ok(Server { + to_back, + registered_methods: Arc::new(Mutex::new(Default::default())), + next_subscription_unique_id: Arc::new(atomic::AtomicUsize::new(0)), + local_addr, + }) + } + + /// Local socket address of the underlying transport server. + pub fn local_addr(&self) -> &SocketAddr { + &self.local_addr + } + + /// Registers a notification method name towards the server. + /// + /// Clients will then be able to call this method. + /// The returned object allows you to process incoming notifications. + /// + /// If `allow_losses` is true, then the server is allowed to drop notifications if the + /// notifications handler (i.e. the code that uses [`RegisteredNotifications`]) is too slow + /// to process notifications. + /// + /// Returns an error if the method name was already registered. + pub fn register_notification(&self, method_name: String, allow_losses: bool) -> Result { + log::debug!("[frontend]: register_notification={}", method_name); + if !self.registered_methods.lock().insert(method_name.clone()) { + return Err(()); + } + + let (tx, rx) = mpsc::channel(32); + + let _ = self.to_back.unbounded_send(FrontToBack::RegisterNotifications { + name: method_name, + handler: tx, + allow_losses, + }); + + Ok(RegisteredNotification { queries_rx: rx }) + } + + /// Registers a method towards the server. + /// + /// Clients will then be able to call this method. + /// The returned object allows you to handle incoming requests. + /// + /// Contrary to [`register_notifications`](Server::register_notifications), there is no + /// `allow_losses` parameter here. If the handler is too slow to process requests, then the + /// server automatically returns an "internal error" to the client. + /// + /// Returns an error if the method name was already registered. + pub fn register_method(&self, method_name: String) -> Result { + log::debug!("[frontend]: register_method={}", method_name); + if !self.registered_methods.lock().insert(method_name.clone()) { + return Err(()); + } + + let (tx, rx) = mpsc::channel(32); + + let _ = self.to_back.unbounded_send(FrontToBack::RegisterMethod { name: method_name, handler: tx }); + + Ok(RegisteredMethod { to_back: self.to_back.clone(), queries_rx: rx }) + } + + /// Registers a subscription towards the server. + /// + /// Clients will then be able to call this method. + /// The returned object allows you to send out notifications. + /// + /// Returns an error if one of the method names was already registered. + pub fn register_subscription( + &self, + subscribe_method_name: String, + unsubscribe_method_name: String, + ) -> Result { + log::debug!( + "[frontend]: server register subscription: subscribe_method={}, unsubscribe_method={}", + subscribe_method_name, + unsubscribe_method_name + ); + { + let mut registered_methods = self.registered_methods.lock(); + if registered_methods.contains(&subscribe_method_name) + || registered_methods.contains(&unsubscribe_method_name) + { + return Err(()); + } + registered_methods.insert(subscribe_method_name.clone()); + registered_methods.insert(unsubscribe_method_name.clone()); + } + + let unique_id = self.next_subscription_unique_id.fetch_add(1, atomic::Ordering::Relaxed); + + self.to_back + .unbounded_send(FrontToBack::RegisterSubscription { + unique_id, + subscribe_method: subscribe_method_name, + unsubscribe_method: unsubscribe_method_name, + }) + .map_err(|_| ())?; + + Ok(RegisteredSubscription { to_back: self.to_back.clone(), unique_id }) + } } impl RegisteredNotification { - /// Returns the next notification. - pub async fn next(&mut self) -> common::Params { - loop { - match self.queries_rx.next().await { - Some(v) => break v, - None => futures::pending!(), - } - } - } + /// Returns the next notification. + pub async fn next(&mut self) -> common::Params { + loop { + match self.queries_rx.next().await { + Some(v) => break v, + None => futures::pending!(), + } + } + } } impl RegisteredMethod { - /// Returns the next request. - pub async fn next(&mut self) -> IncomingRequest { - let (request_id, params) = loop { - match self.queries_rx.next().await { - Some(v) => break v, - None => futures::pending!(), - } - }; - IncomingRequest { - to_back: self.to_back.clone(), - request_id, - params, - } - } + /// Returns the next request. + pub async fn next(&mut self) -> IncomingRequest { + let (request_id, params) = loop { + match self.queries_rx.next().await { + Some(v) => break v, + None => futures::pending!(), + } + }; + IncomingRequest { to_back: self.to_back.clone(), request_id, params } + } } impl RegisteredSubscription { - /// Sends out a value to all the registered clients. - // TODO: return `Result<(), ()>` - pub async fn send(&mut self, value: JsonValue) { - let _ = self - .to_back - .send(FrontToBack::SendOutNotif { - unique_id: self.unique_id, - notification: value, - }) - .await; - } + /// Sends out a value to all the registered clients. + // TODO: return `Result<(), ()>` + pub async fn send(&mut self, value: JsonValue) { + let _ = self.to_back.send(FrontToBack::SendOutNotif { unique_id: self.unique_id, notification: value }).await; + } } impl IncomingRequest { - /// Returns the parameters of the request. - pub fn params(&self) -> &common::Params { - &self.params - } - - /// Respond to the request. - // TODO: return `Result<(), ()>` - pub async fn respond(mut self, response: impl Into>) { - let _ = self - .to_back - .send(FrontToBack::AnswerRequest { - request_id: self.request_id, - answer: response.into(), - }) - .await; - } + /// Returns the parameters of the request. + pub fn params(&self) -> &common::Params { + &self.params + } + + /// Respond to the request. + // TODO: return `Result<(), ()>` + pub async fn respond(mut self, response: impl Into>) { + let _ = self + .to_back + .send(FrontToBack::AnswerRequest { request_id: self.request_id, answer: response.into() }) + .await; + } } /// Function being run in the background that processes messages from the frontend. -async fn background_task( - mut server: RawServer, - mut from_front: mpsc::UnboundedReceiver, -) { - // List of notifications methods that the user has registered, and the channels to dispatch - // incoming notifications. - let mut registered_notifications: HashMap, bool)> = HashMap::new(); - // List of methods that the user has registered, and the channels to dispatch incoming - // requests. - let mut registered_methods: HashMap> = HashMap::new(); - // For each registered subscription, a subscribe method linked to a unique identifier for - // that subscription. - let mut subscribe_methods: HashMap = HashMap::new(); - // For each registered subscription, an unsubscribe method linked to a unique identifier for - // that subscription. - let mut unsubscribe_methods: HashMap = HashMap::new(); - // For each registered subscription, a list of clients that are registered towards us. - let mut subscribed_clients: HashMap> = HashMap::new(); - // Reversed mapping of `subscribed_clients`. Must always be in sync. - let mut active_subscriptions: HashMap = HashMap::new(); - - loop { - // We need to do a little transformation in order to destroy the borrow to `client` - // and `from_front`. - let outcome = { - let next_message = from_front.next(); - let next_event = server.next_event(); - pin_mut!(next_message); - pin_mut!(next_event); - match future::select(next_message, next_event).await { - Either::Left((v, _)) => Either::Left(v), - Either::Right((v, _)) => Either::Right(v), - } - }; - - match outcome { - Either::Left(None) => return, - Either::Left(Some(FrontToBack::AnswerRequest { request_id, answer })) => { - server.request_by_id(&request_id).unwrap().respond(answer); - } - Either::Left(Some(FrontToBack::RegisterNotifications { - name, - handler, - allow_losses, - })) => { - registered_notifications.insert(name, (handler, allow_losses)); - } - Either::Left(Some(FrontToBack::RegisterMethod { name, handler })) => { - registered_methods.insert(name, handler); - } - Either::Left(Some(FrontToBack::RegisterSubscription { - unique_id, - subscribe_method, - unsubscribe_method, - })) => { - log::debug!("[backend]: server register subscription=id={:?}, subscribe_method:{}, unsubscribe_method={}", unique_id, subscribe_method, unsubscribe_method); - debug_assert_ne!(subscribe_method, unsubscribe_method); - debug_assert!(!subscribe_methods.contains_key(&subscribe_method)); - debug_assert!(!subscribe_methods.contains_key(&unsubscribe_method)); - debug_assert!(!unsubscribe_methods.contains_key(&subscribe_method)); - debug_assert!(!unsubscribe_methods.contains_key(&unsubscribe_method)); - debug_assert!(!registered_methods.contains_key(&subscribe_method)); - debug_assert!(!registered_methods.contains_key(&unsubscribe_method)); - debug_assert!(!registered_notifications.contains_key(&subscribe_method)); - debug_assert!(!registered_notifications.contains_key(&unsubscribe_method)); - debug_assert!(!subscribed_clients.contains_key(&unique_id)); - subscribe_methods.insert(subscribe_method, unique_id); - unsubscribe_methods.insert(unsubscribe_method, unique_id); - subscribed_clients.insert(unique_id, Vec::new()); - } - Either::Left(Some(FrontToBack::SendOutNotif { - unique_id, - notification, - })) => { - log::debug!( - "[backend]: server preparing response to subscription={:?}", - unique_id - ); - debug_assert!(subscribed_clients.contains_key(&unique_id)); - if let Some(clients) = subscribed_clients.get(&unique_id) { - log::trace!( - "[backend]: {} client(s) is subscribing to subscription={:?}", - clients.len(), - unique_id - ); - for client in clients { - debug_assert_eq!(active_subscriptions.get(client), Some(&unique_id)); - debug_assert!(server.subscription_by_id(*client).is_some()); - if let Some(sub) = server.subscription_by_id(*client) { - sub.push(notification.clone()).await; - } - } - } else { - log::warn!( - "[backend]: server received invalid subscription={:?}", - unique_id - ); - } - } - Either::Right(RawServerEvent::Notification(notification)) => { - log::debug!( - "[backend]: server received notification: {:?}", - notification - ); - if let Some((handler, allow_losses)) = - registered_notifications.get_mut(notification.method()) - { - let params: &common::Params = notification.params().into(); - // Note: we just ignore errors. It doesn't make sense logically speaking to - // unregister the notification here. - if *allow_losses { - let _ = handler.send(params.clone()).now_or_never(); - } else { - let _ = handler.send(params.clone()).await; - } - } - } - Either::Right(RawServerEvent::Request(request)) => { - log::debug!("[backend]: server received request: {:?}", request); - if let Some(handler) = registered_methods.get_mut(request.method()) { - let params: &common::Params = request.params().into(); - log::debug!("server called handler"); - match handler.send((request.id(), params.clone())).now_or_never() { - Some(Ok(())) => {} - Some(Err(_)) | None => { - request.respond(Err(From::from(common::ErrorCode::ServerError(0)))); - } - } - } else if let Some(sub_unique_id) = subscribe_methods.get(request.method()) { - if let Ok(sub_id) = request.into_subscription() { - debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); - if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { - debug_assert!(clients.iter().all(|c| *c != sub_id)); - clients.push(sub_id); - } - - debug_assert!(!active_subscriptions.contains_key(&sub_id)); - active_subscriptions.insert(sub_id, *sub_unique_id); - } - } else if let Some(sub_unique_id) = unsubscribe_methods.get(request.method()) { - if let Ok(sub_id) = RawServerSubscriptionId::from_wire_message(&JsonValue::Null) - { - // FIXME: from request params - debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); - if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { - // TODO: we don't actually check whether the unsubscribe comes from the right - // clients, but since this the ID is randomly-generated, it should be - // fine - if let Some(client_pos) = clients.iter().position(|c| *c == sub_id) { - clients.remove(client_pos); - } - - if let Some(s_u_id) = active_subscriptions.remove(&sub_id) { - debug_assert_eq!(s_u_id, *sub_unique_id); - } - } - } - } else { - // TODO: we assert that the request is valid because the parsing succeeded but - // not registered. - request.respond(Err(From::from(common::ErrorCode::MethodNotFound))); - } - } - Either::Right(RawServerEvent::SubscriptionsReady(_)) => { - // We don't really care whether subscriptions are now ready. - } - Either::Right(RawServerEvent::SubscriptionsClosed(subscriptions)) => { - log::debug!("[backend]: server close subscriptions: {:?}", subscriptions); - // Remove all the subscriptions from `active_subscriptions` and - // `subscribed_clients`. - for sub_id in subscriptions { - debug_assert!(active_subscriptions.contains_key(&sub_id)); - if let Some(unique_id) = active_subscriptions.remove(&sub_id) { - debug_assert!(subscribed_clients.contains_key(&unique_id)); - if let Some(clients) = subscribed_clients.get_mut(&unique_id) { - assert_eq!(clients.iter().filter(|c| **c == sub_id).count(), 1); - clients.retain(|c| *c != sub_id); - } - } - } - } - } - } +async fn background_task(mut server: RawServer, mut from_front: mpsc::UnboundedReceiver) { + // List of notifications methods that the user has registered, and the channels to dispatch + // incoming notifications. + let mut registered_notifications: HashMap, bool)> = HashMap::new(); + // List of methods that the user has registered, and the channels to dispatch incoming + // requests. + let mut registered_methods: HashMap> = HashMap::new(); + // For each registered subscription, a subscribe method linked to a unique identifier for + // that subscription. + let mut subscribe_methods: HashMap = HashMap::new(); + // For each registered subscription, an unsubscribe method linked to a unique identifier for + // that subscription. + let mut unsubscribe_methods: HashMap = HashMap::new(); + // For each registered subscription, a list of clients that are registered towards us. + let mut subscribed_clients: HashMap> = HashMap::new(); + // Reversed mapping of `subscribed_clients`. Must always be in sync. + let mut active_subscriptions: HashMap = HashMap::new(); + + loop { + // We need to do a little transformation in order to destroy the borrow to `client` + // and `from_front`. + let outcome = { + let next_message = from_front.next(); + let next_event = server.next_event(); + pin_mut!(next_message); + pin_mut!(next_event); + match future::select(next_message, next_event).await { + Either::Left((v, _)) => Either::Left(v), + Either::Right((v, _)) => Either::Right(v), + } + }; + + match outcome { + Either::Left(None) => return, + Either::Left(Some(FrontToBack::AnswerRequest { request_id, answer })) => { + server.request_by_id(&request_id).unwrap().respond(answer); + } + Either::Left(Some(FrontToBack::RegisterNotifications { name, handler, allow_losses })) => { + registered_notifications.insert(name, (handler, allow_losses)); + } + Either::Left(Some(FrontToBack::RegisterMethod { name, handler })) => { + registered_methods.insert(name, handler); + } + Either::Left(Some(FrontToBack::RegisterSubscription { + unique_id, + subscribe_method, + unsubscribe_method, + })) => { + log::debug!( + "[backend]: server register subscription=id={:?}, subscribe_method:{}, unsubscribe_method={}", + unique_id, + subscribe_method, + unsubscribe_method + ); + debug_assert_ne!(subscribe_method, unsubscribe_method); + debug_assert!(!subscribe_methods.contains_key(&subscribe_method)); + debug_assert!(!subscribe_methods.contains_key(&unsubscribe_method)); + debug_assert!(!unsubscribe_methods.contains_key(&subscribe_method)); + debug_assert!(!unsubscribe_methods.contains_key(&unsubscribe_method)); + debug_assert!(!registered_methods.contains_key(&subscribe_method)); + debug_assert!(!registered_methods.contains_key(&unsubscribe_method)); + debug_assert!(!registered_notifications.contains_key(&subscribe_method)); + debug_assert!(!registered_notifications.contains_key(&unsubscribe_method)); + debug_assert!(!subscribed_clients.contains_key(&unique_id)); + subscribe_methods.insert(subscribe_method, unique_id); + unsubscribe_methods.insert(unsubscribe_method, unique_id); + subscribed_clients.insert(unique_id, Vec::new()); + } + Either::Left(Some(FrontToBack::SendOutNotif { unique_id, notification })) => { + log::debug!("[backend]: server preparing response to subscription={:?}", unique_id); + debug_assert!(subscribed_clients.contains_key(&unique_id)); + if let Some(clients) = subscribed_clients.get(&unique_id) { + log::trace!( + "[backend]: {} client(s) is subscribing to subscription={:?}", + clients.len(), + unique_id + ); + for client in clients { + debug_assert_eq!(active_subscriptions.get(client), Some(&unique_id)); + debug_assert!(server.subscription_by_id(*client).is_some()); + if let Some(sub) = server.subscription_by_id(*client) { + sub.push(notification.clone()).await; + } + } + } else { + log::warn!("[backend]: server received invalid subscription={:?}", unique_id); + } + } + Either::Right(RawServerEvent::Notification(notification)) => { + log::debug!("[backend]: server received notification: {:?}", notification); + if let Some((handler, allow_losses)) = registered_notifications.get_mut(notification.method()) { + let params: &common::Params = notification.params().into(); + // Note: we just ignore errors. It doesn't make sense logically speaking to + // unregister the notification here. + if *allow_losses { + let _ = handler.send(params.clone()).now_or_never(); + } else { + let _ = handler.send(params.clone()).await; + } + } + } + Either::Right(RawServerEvent::Request(request)) => { + log::debug!("[backend]: server received request: {:?}", request); + if let Some(handler) = registered_methods.get_mut(request.method()) { + let params: &common::Params = request.params().into(); + log::debug!("server called handler"); + match handler.send((request.id(), params.clone())).now_or_never() { + Some(Ok(())) => {} + Some(Err(_)) | None => { + request.respond(Err(From::from(common::ErrorCode::ServerError(0)))); + } + } + } else if let Some(sub_unique_id) = subscribe_methods.get(request.method()) { + if let Ok(sub_id) = request.into_subscription() { + debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); + if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { + debug_assert!(clients.iter().all(|c| *c != sub_id)); + clients.push(sub_id); + } + + debug_assert!(!active_subscriptions.contains_key(&sub_id)); + active_subscriptions.insert(sub_id, *sub_unique_id); + } + } else if let Some(sub_unique_id) = unsubscribe_methods.get(request.method()) { + if let Ok(sub_id) = RawServerSubscriptionId::from_wire_message(&JsonValue::Null) { + // FIXME: from request params + debug_assert!(subscribed_clients.contains_key(&sub_unique_id)); + if let Some(clients) = subscribed_clients.get_mut(&sub_unique_id) { + // TODO: we don't actually check whether the unsubscribe comes from the right + // clients, but since this the ID is randomly-generated, it should be + // fine + if let Some(client_pos) = clients.iter().position(|c| *c == sub_id) { + clients.remove(client_pos); + } + + if let Some(s_u_id) = active_subscriptions.remove(&sub_id) { + debug_assert_eq!(s_u_id, *sub_unique_id); + } + } + } + } else { + // TODO: we assert that the request is valid because the parsing succeeded but + // not registered. + request.respond(Err(From::from(common::ErrorCode::MethodNotFound))); + } + } + Either::Right(RawServerEvent::SubscriptionsReady(_)) => { + // We don't really care whether subscriptions are now ready. + } + Either::Right(RawServerEvent::SubscriptionsClosed(subscriptions)) => { + log::debug!("[backend]: server close subscriptions: {:?}", subscriptions); + // Remove all the subscriptions from `active_subscriptions` and + // `subscribed_clients`. + for sub_id in subscriptions { + debug_assert!(active_subscriptions.contains_key(&sub_id)); + if let Some(unique_id) = active_subscriptions.remove(&sub_id) { + debug_assert!(subscribed_clients.contains_key(&unique_id)); + if let Some(clients) = subscribed_clients.get_mut(&unique_id) { + assert_eq!(clients.iter().filter(|c| **c == sub_id).count(), 1); + clients.retain(|c| *c != sub_id); + } + } + } + } + } + } } diff --git a/src/ws/tests.rs b/src/ws/tests.rs index f01f03c575..5d2e0f0ec5 100644 --- a/src/ws/tests.rs +++ b/src/ws/tests.rs @@ -16,294 +16,245 @@ use std::net::SocketAddr; // // TODO: not sure why `tokio::spawn` doesn't works for this. pub fn server_subscribe_only(server_started: Sender) { - std::thread::spawn(move || { - use async_std::task::block_on; - let server = block_on(WsServer::new("127.0.0.1:0")).unwrap(); - let mut hello = server - .register_subscription("subscribe_hello".to_owned(), "unsubscribe_hello".to_owned()) - .unwrap(); - let mut foo = server - .register_subscription("subscribe_foo".to_owned(), "unsubscribe_foo".to_owned()) - .unwrap(); - server_started.send(*server.local_addr()).unwrap(); - - loop { - block_on(hello.send(JsonValue::String("hello from subscription".to_owned()))); - std::thread::sleep(std::time::Duration::from_millis(100)); - block_on(foo.send(JsonValue::Number(1337_u64.into()))); - std::thread::sleep(std::time::Duration::from_millis(100)); - } - }); + std::thread::spawn(move || { + use async_std::task::block_on; + let server = block_on(WsServer::new("127.0.0.1:0")).unwrap(); + let mut hello = + server.register_subscription("subscribe_hello".to_owned(), "unsubscribe_hello".to_owned()).unwrap(); + let mut foo = server.register_subscription("subscribe_foo".to_owned(), "unsubscribe_foo".to_owned()).unwrap(); + server_started.send(*server.local_addr()).unwrap(); + + loop { + block_on(hello.send(JsonValue::String("hello from subscription".to_owned()))); + std::thread::sleep(std::time::Duration::from_millis(100)); + block_on(foo.send(JsonValue::Number(1337_u64.into()))); + std::thread::sleep(std::time::Duration::from_millis(100)); + } + }); } /// Spawns a dummy `JSONRPC v2 WebSocket` /// It has two hardcoded methods "say_hello" and "add", one hardcoded notification "notif" pub async fn server(server_started: Sender) { - let server = WsServer::new("127.0.0.1:0").await.unwrap(); - let mut hello = server.register_method("say_hello".to_owned()).unwrap(); - let mut add = server.register_method("add".to_owned()).unwrap(); - let mut notif = server - .register_notification("notif".to_owned(), false) - .unwrap(); - server_started.send(*server.local_addr()).unwrap(); - - loop { - let hello_fut = async { - let handle = hello.next().await; - log::debug!("server respond to hello"); - handle - .respond(Ok(JsonValue::String("hello".to_owned()))) - .await; - } - .fuse(); - - let add_fut = async { - let handle = add.next().await; - let params: Vec = handle.params().clone().parse().unwrap(); - let sum: u64 = params.iter().sum(); - handle.respond(Ok(JsonValue::Number(sum.into()))).await; - } - .fuse(); - - let notif_fut = async { - let params = notif.next().await; - println!("received notification: say_hello params[{:?}]", params); - } - .fuse(); - - pin_mut!(hello_fut, add_fut, notif_fut); - select! { - say_hello = hello_fut => (), - add = add_fut => (), - notif = notif_fut => (), - complete => (), - }; - } + let server = WsServer::new("127.0.0.1:0").await.unwrap(); + let mut hello = server.register_method("say_hello".to_owned()).unwrap(); + let mut add = server.register_method("add".to_owned()).unwrap(); + let mut notif = server.register_notification("notif".to_owned(), false).unwrap(); + server_started.send(*server.local_addr()).unwrap(); + + loop { + let hello_fut = async { + let handle = hello.next().await; + log::debug!("server respond to hello"); + handle.respond(Ok(JsonValue::String("hello".to_owned()))).await; + } + .fuse(); + + let add_fut = async { + let handle = add.next().await; + let params: Vec = handle.params().clone().parse().unwrap(); + let sum: u64 = params.iter().sum(); + handle.respond(Ok(JsonValue::Number(sum.into()))).await; + } + .fuse(); + + let notif_fut = async { + let params = notif.next().await; + println!("received notification: say_hello params[{:?}]", params); + } + .fuse(); + + pin_mut!(hello_fut, add_fut, notif_fut); + select! { + say_hello = hello_fut => (), + add = add_fut => (), + notif = notif_fut => (), + complete => (), + }; + } } #[tokio::test] async fn single_method_call_works() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); - - for i in 0..10 { - let req = format!(r#"{{"jsonrpc":"2.0","method":"say_hello","id":{}}}"#, i); - let response = client.send_request_text(req).await.unwrap(); - assert_eq!( - response, - ok_response(JsonValue::String("hello".to_owned()), Id::Num(i)) - ); - } + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + + for i in 0..10 { + let req = format!(r#"{{"jsonrpc":"2.0","method":"say_hello","id":{}}}"#, i); + let response = client.send_request_text(req).await.unwrap(); + assert_eq!(response, ok_response(JsonValue::String("hello".to_owned()), Id::Num(i))); + } } // TODO: technically more of a integration test because the "real" client is used. #[tokio::test] async fn subscription_works() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - server_subscribe_only(server_started_tx); - let server_addr = server_started_rx.await.unwrap(); - let uri = format!("ws://{}", server_addr); - let client = WsClient::new(&uri).await.unwrap(); - let mut hello_sub: WsSubscription = client - .subscribe("subscribe_hello", Params::None, "unsubscribe_hello") - .await - .unwrap(); - let mut foo_sub: WsSubscription = client - .subscribe("subscribe_foo", Params::None, "unsubscribe_foo") - .await - .unwrap(); - - for _ in 0..10 { - let hello = hello_sub.next().await; - let foo = foo_sub.next().await; - assert_eq!( - hello, - JsonValue::String("hello from subscription".to_owned()) - ); - assert_eq!(foo, JsonValue::Number(1337_u64.into())); - } + let (server_started_tx, server_started_rx) = oneshot::channel::(); + server_subscribe_only(server_started_tx); + let server_addr = server_started_rx.await.unwrap(); + let uri = format!("ws://{}", server_addr); + let client = WsClient::new(&uri).await.unwrap(); + let mut hello_sub: WsSubscription = + client.subscribe("subscribe_hello", Params::None, "unsubscribe_hello").await.unwrap(); + let mut foo_sub: WsSubscription = + client.subscribe("subscribe_foo", Params::None, "unsubscribe_foo").await.unwrap(); + + for _ in 0..10 { + let hello = hello_sub.next().await; + let foo = foo_sub.next().await; + assert_eq!(hello, JsonValue::String("hello from subscription".to_owned())); + assert_eq!(foo, JsonValue::Number(1337_u64.into())); + } } #[tokio::test] async fn subscription_several_clients() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - server_subscribe_only(server_started_tx); - let server_addr = server_started_rx.await.unwrap(); - - let mut clients = Vec::with_capacity(10); - for _ in 0..10 { - let uri = format!("ws://{}", server_addr); - let client = WsClient::new(&uri).await.unwrap(); - let hello_sub: WsSubscription = client - .subscribe("subscribe_hello", Params::None, "unsubscribe_hello") - .await - .unwrap(); - let foo_sub: WsSubscription = client - .subscribe("subscribe_foo", Params::None, "unsubscribe_foo") - .await - .unwrap(); - clients.push((client, hello_sub, foo_sub)) - } - - for _ in 0..10 { - for (_client, hello_sub, foo_sub) in &mut clients { - let hello = hello_sub.next().await; - let foo = foo_sub.next().await; - assert_eq!( - hello, - JsonValue::String("hello from subscription".to_owned()) - ); - assert_eq!(foo, JsonValue::Number(1337_u64.into())); - } - } - - for i in 0..5 { - let (client, _, _) = clients.remove(i); - drop(client); - } - - // make sure nothing weird happend after dropping half the clients (should be `unsubscribed` in the server) - // would be good to know that subscriptions actually were removed but not possible to verify at - // this layer. - for _ in 0..10 { - for (_client, hello_sub, foo_sub) in &mut clients { - let hello = hello_sub.next().await; - let foo = foo_sub.next().await; - assert_eq!( - hello, - JsonValue::String("hello from subscription".to_owned()) - ); - assert_eq!(foo, JsonValue::Number(1337_u64.into())); - } - } + let (server_started_tx, server_started_rx) = oneshot::channel::(); + server_subscribe_only(server_started_tx); + let server_addr = server_started_rx.await.unwrap(); + + let mut clients = Vec::with_capacity(10); + for _ in 0..10 { + let uri = format!("ws://{}", server_addr); + let client = WsClient::new(&uri).await.unwrap(); + let hello_sub: WsSubscription = + client.subscribe("subscribe_hello", Params::None, "unsubscribe_hello").await.unwrap(); + let foo_sub: WsSubscription = + client.subscribe("subscribe_foo", Params::None, "unsubscribe_foo").await.unwrap(); + clients.push((client, hello_sub, foo_sub)) + } + + for _ in 0..10 { + for (_client, hello_sub, foo_sub) in &mut clients { + let hello = hello_sub.next().await; + let foo = foo_sub.next().await; + assert_eq!(hello, JsonValue::String("hello from subscription".to_owned())); + assert_eq!(foo, JsonValue::Number(1337_u64.into())); + } + } + + for i in 0..5 { + let (client, _, _) = clients.remove(i); + drop(client); + } + + // make sure nothing weird happend after dropping half the clients (should be `unsubscribed` in the server) + // would be good to know that subscriptions actually were removed but not possible to verify at + // this layer. + for _ in 0..10 { + for (_client, hello_sub, foo_sub) in &mut clients { + let hello = hello_sub.next().await; + let foo = foo_sub.next().await; + assert_eq!(hello, JsonValue::String("hello from subscription".to_owned())); + assert_eq!(foo, JsonValue::Number(1337_u64.into())); + } + } } #[tokio::test] async fn single_method_call_with_params_works() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); - - let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#; - let response = client.send_request_text(req).await.unwrap(); - assert_eq!( - response, - ok_response(JsonValue::Number(3.into()), Id::Num(1)) - ); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#; + let response = client.send_request_text(req).await.unwrap(); + assert_eq!(response, ok_response(JsonValue::Number(3.into()), Id::Num(1))); } #[tokio::test] async fn single_method_send_binary() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); - - let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#; - let response = client.send_request_binary(req.as_bytes()).await.unwrap(); - assert_eq!( - response, - ok_response(JsonValue::Number(3.into()), Id::Num(1)) - ); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"add", "params":[1, 2],"id":1}"#; + let response = client.send_request_binary(req.as_bytes()).await.unwrap(); + assert_eq!(response, ok_response(JsonValue::Number(3.into()), Id::Num(1))); } #[tokio::test] async fn should_return_method_not_found() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); - - let req = r#"{"jsonrpc":"2.0","method":"bar","id":"foo"}"#; - let response = client.send_request_text(req).await.unwrap(); - assert_eq!(response, method_not_found(Id::Str("foo".into()))); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + + let req = r#"{"jsonrpc":"2.0","method":"bar","id":"foo"}"#; + let response = client.send_request_text(req).await.unwrap(); + assert_eq!(response, method_not_found(Id::Str("foo".into()))); } #[tokio::test] async fn invalid_json_id_missing_value() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); - let req = r#"{"jsonrpc":"2.0","method":"say_hello","id"}"#; - let response = client.send_request_text(req).await.unwrap(); - // If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), it MUST be Null. - assert_eq!(response, parse_error(Id::Null)); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + + let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + let req = r#"{"jsonrpc":"2.0","method":"say_hello","id"}"#; + let response = client.send_request_text(req).await.unwrap(); + // If there was an error in detecting the id in the Request object (e.g. Parse error/Invalid Request), it MUST be Null. + assert_eq!(response, parse_error(Id::Null)); } #[tokio::test] async fn invalid_request_object() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); - let req = r#"{"jsonrpc":"2.0","method":"bar","id":1,"is_not_request_object":1}"#; - let response = client.send_request_text(req).await.unwrap(); - assert_eq!(response, invalid_request(Id::Num(1))); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + + let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + let req = r#"{"jsonrpc":"2.0","method":"bar","id":1,"is_not_request_object":1}"#; + let response = client.send_request_text(req).await.unwrap(); + assert_eq!(response, invalid_request(Id::Num(1))); } #[tokio::test] async fn register_methods_works() { - let server = WsServer::new("127.0.0.1:0").await.unwrap(); - assert!(server.register_method("say_hello".to_owned()).is_ok()); - assert!(server.register_method("say_hello".to_owned()).is_err()); - assert!(server - .register_notification("notif".to_owned(), false) - .is_ok()); - assert!(server - .register_notification("notif".to_owned(), false) - .is_err()); - assert!(server - .register_subscription("subscribe_hello".to_owned(), "unsubscribe_hello".to_owned()) - .is_ok()); - assert!(server - .register_subscription("subscribe_hello_again".to_owned(), "notif".to_owned()) - .is_err()); - assert!( - server - .register_method("subscribe_hello_again".to_owned()) - .is_ok(), - "Failed register_subscription should not have side-effects" - ); + let server = WsServer::new("127.0.0.1:0").await.unwrap(); + assert!(server.register_method("say_hello".to_owned()).is_ok()); + assert!(server.register_method("say_hello".to_owned()).is_err()); + assert!(server.register_notification("notif".to_owned(), false).is_ok()); + assert!(server.register_notification("notif".to_owned(), false).is_err()); + assert!(server.register_subscription("subscribe_hello".to_owned(), "unsubscribe_hello".to_owned()).is_ok()); + assert!(server.register_subscription("subscribe_hello_again".to_owned(), "notif".to_owned()).is_err()); + assert!( + server.register_method("subscribe_hello_again".to_owned()).is_ok(), + "Failed register_subscription should not have side-effects" + ); } #[tokio::test] async fn parse_error_request_should_not_close_connection() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); - let invalid_request = r#"{"jsonrpc":"2.0","method":"bar","params":[1,"id":99}"#; - let response1 = client.send_request_text(invalid_request).await.unwrap(); - assert_eq!(response1, parse_error(Id::Null)); - let request = r#"{"jsonrpc":"2.0","method":"say_hello","id":33}"#; - let response2 = client.send_request_text(request).await.unwrap(); - assert_eq!( - response2, - ok_response(JsonValue::String("hello".to_owned()), Id::Num(33)) - ); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + + let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + let invalid_request = r#"{"jsonrpc":"2.0","method":"bar","params":[1,"id":99}"#; + let response1 = client.send_request_text(invalid_request).await.unwrap(); + assert_eq!(response1, parse_error(Id::Null)); + let request = r#"{"jsonrpc":"2.0","method":"say_hello","id":33}"#; + let response2 = client.send_request_text(request).await.unwrap(); + assert_eq!(response2, ok_response(JsonValue::String("hello".to_owned()), Id::Num(33))); } #[tokio::test] async fn invalid_request_should_not_close_connection() { - let (server_started_tx, server_started_rx) = oneshot::channel::(); - tokio::spawn(server(server_started_tx)); - let server_addr = server_started_rx.await.unwrap(); - - let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); - let req = r#"{"jsonrpc":"2.0","method":"bar","id":1,"is_not_request_object":1}"#; - let response = client.send_request_text(req).await.unwrap(); - assert_eq!(response, invalid_request(Id::Num(1))); - let request = r#"{"jsonrpc":"2.0","method":"say_hello","id":33}"#; - let response = client.send_request_text(request).await.unwrap(); - assert_eq!( - response, - ok_response(JsonValue::String("hello".to_owned()), Id::Num(33)) - ); + let (server_started_tx, server_started_rx) = oneshot::channel::(); + tokio::spawn(server(server_started_tx)); + let server_addr = server_started_rx.await.unwrap(); + + let mut client = WebSocketTestClient::new(server_addr).await.unwrap(); + let req = r#"{"jsonrpc":"2.0","method":"bar","id":1,"is_not_request_object":1}"#; + let response = client.send_request_text(req).await.unwrap(); + assert_eq!(response, invalid_request(Id::Num(1))); + let request = r#"{"jsonrpc":"2.0","method":"say_hello","id":33}"#; + let response = client.send_request_text(request).await.unwrap(); + assert_eq!(response, ok_response(JsonValue::String("hello".to_owned()), Id::Num(33))); } diff --git a/src/ws/transport.rs b/src/ws/transport.rs index 67ab2631de..ea549f9bf2 100644 --- a/src/ws/transport.rs +++ b/src/ws/transport.rs @@ -30,32 +30,32 @@ use async_std::net::{TcpListener, TcpStream}; use futures::{channel::mpsc, prelude::*}; use soketto::handshake::{server::Response, Server}; use std::{ - collections::HashMap, - fmt, io, - net::SocketAddr, - pin::Pin, - sync::{atomic, Arc}, + collections::HashMap, + fmt, io, + net::SocketAddr, + pin::Pin, + sync::{atomic, Arc}, }; /// Event that the [`TransportServer`] can generate. #[derive(Debug, PartialEq)] pub enum TransportServerEvent { - /// A new request has arrived on the wire. - /// - /// This generates a new "request object" within the state of the [`TransportServer`] that is - /// identified through the returned `id`. You can then use the other methods of the - /// [`TransportServer`] trait in order to manipulate that request. - Request { - /// Identifier of the request within the state of the [`TransportServer`]. - id: T, - /// Body of the request. - request: common::Request, - }, - - /// A request has been cancelled, most likely because the client has closed the connection. - /// - /// The corresponding request is no longer valid to manipulate. - Closed(T), + /// A new request has arrived on the wire. + /// + /// This generates a new "request object" within the state of the [`TransportServer`] that is + /// identified through the returned `id`. You can then use the other methods of the + /// [`TransportServer`] trait in order to manipulate that request. + Request { + /// Identifier of the request within the state of the [`TransportServer`]. + id: T, + /// Body of the request. + request: common::Request, + }, + + /// A request has been cancelled, most likely because the client has closed the connection. + /// + /// The corresponding request is no longer valid to manipulate. + Closed(T), } /// Implementation of a raw server for WebSockets requests. @@ -74,43 +74,38 @@ pub enum TransportServerEvent { // If a task finishes, it must return the list of requests that were assigned to it so that they // get removed from [`WsTransportServer::to_connections`]. pub struct WsTransportServer { - /// Local socket address. - local_addr: SocketAddr, - /// List of events to for `next_request` to immediately produce. - pending_events: Vec>, - /// Endpoint for incoming TCP sockets. - listener: TcpListener, - /// Next identifier to assign to a request. Shared amongst all the tasks in the server so that - /// they all assign from the same pool. - next_request_id: Arc, - /// Events received from connections. - from_connections: mpsc::Receiver, - /// Sending side of [`WsTransportServer::from_connections`]. Cloned in each member of - /// [`WsTransportServer::connections_tasks`]. - to_front: mpsc::Sender, - /// List of connections, and senders to send them messages. - to_connections: HashMap>, - /// List of connections. Must be processed for the system to work. When a task finishes, it - /// returns the list of pending requests that should now be closed. - connections_tasks: - stream::FuturesUnordered> + Send>>>, + /// Local socket address. + local_addr: SocketAddr, + /// List of events to for `next_request` to immediately produce. + pending_events: Vec>, + /// Endpoint for incoming TCP sockets. + listener: TcpListener, + /// Next identifier to assign to a request. Shared amongst all the tasks in the server so that + /// they all assign from the same pool. + next_request_id: Arc, + /// Events received from connections. + from_connections: mpsc::Receiver, + /// Sending side of [`WsTransportServer::from_connections`]. Cloned in each member of + /// [`WsTransportServer::connections_tasks`]. + to_front: mpsc::Sender, + /// List of connections, and senders to send them messages. + to_connections: HashMap>, + /// List of connections. Must be processed for the system to work. When a task finishes, it + /// returns the list of pending requests that should now be closed. + connections_tasks: stream::FuturesUnordered> + Send>>>, } /// Message sent from a per-connection task to the main frontend. enum BackToFront { - NewRequest { - id: WsRequestId, - body: common::Request, - sender: mpsc::Sender, - }, + NewRequest { id: WsRequestId, body: common::Request, sender: mpsc::Sender }, } /// Message sent from the main frontend to a per-connection task. enum FrontToBack { - /// Send a payload to the client. - Send(String), - /// No more data concerning that request will be sent. - Finished(WsRequestId), + /// Send a payload to the client. + Send(String), + /// No more data concerning that request will be sent. + Finished(WsRequestId), } /// Identifier for a request made to a WebSocket server. @@ -119,232 +114,212 @@ pub struct WsRequestId(u64); /// Builder for a [`WsTransportServer`]. pub struct WsTransportServerBuilder { - /// IP address to try to bind to. - bind: SocketAddr, + /// IP address to try to bind to. + bind: SocketAddr, } impl WsTransportServer { - /// Creates a new [`WsTransportServerBuilder`] containing the given address and hostname. - pub fn builder(bind: SocketAddr) -> WsTransportServerBuilder { - WsTransportServerBuilder { bind } - } - - /// Local socket address. - pub fn local_addr(&self) -> &SocketAddr { - &self.local_addr - } + /// Creates a new [`WsTransportServerBuilder`] containing the given address and hostname. + pub fn builder(bind: SocketAddr) -> WsTransportServerBuilder { + WsTransportServerBuilder { bind } + } + + /// Local socket address. + pub fn local_addr(&self) -> &SocketAddr { + &self.local_addr + } } // former `trait TransportServer` impl. impl WsTransportServer { - /// Returns the next event that the raw server wants to notify us. - pub fn next_request<'a>( - &'a mut self, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - loop { - if !self.pending_events.is_empty() { - return self.pending_events.remove(0); - } else { - self.pending_events.shrink_to_fit(); - } - - enum Event { - TaskFinished(Vec), - NewConnection(TcpStream), - Event(BackToFront), - } - - let next = { - let next_connection = { - let listener = &self.listener; - async move { - loop { - if let Ok((connec, _)) = listener.accept().await { - break Event::NewConnection(connec); - } - } - } - }; - - let next_event = { - let from_connections = &mut self.from_connections; - async move { Event::Event(from_connections.next().await.unwrap()) } - }; - - let next_finished_task = { - let connections_tasks = &mut self.connections_tasks; - async move { Event::TaskFinished(connections_tasks.next().await.unwrap()) } - }; - - futures::pin_mut!(next_connection, next_event, next_finished_task); - match future::select( - future::select(next_connection, next_event), - next_finished_task, - ) - .await - { - future::Either::Left((future::Either::Left((ev, _)), _)) => ev, - future::Either::Left((future::Either::Right((ev, _)), _)) => ev, - future::Either::Right((ev, _)) => ev, - } - }; - - match next { - Event::NewConnection(connec) => { - log::debug!("new connection with id: {:?}", self.next_request_id); - self.connections_tasks.push( - per_connection_task( - connec, - self.next_request_id.clone(), - self.to_front.clone(), - ) - .boxed(), - ); - } - Event::Event(BackToFront::NewRequest { id, body, sender }) => { - log::debug!("new request with id: {:?}", id); - let _was_in = self.to_connections.insert(id.clone(), sender); - debug_assert!(_was_in.is_none()); - return TransportServerEvent::Request { id, request: body }; - } - Event::TaskFinished(list) => { - for rq_id in list { - let was_in = self.to_connections.remove(&rq_id); - - // It is possible that `per_connection_task` returns a list - // with non-existing connections if `task stream` - // gets read twice before the actual `event stream` gets read. - // - // -> Poll #1 Ok(read x bytes) - // -> Poll #2 Err(receiver closed) - // - // The `ws::raw::tests::request_work` is emperic proof of it. - if was_in.is_some() { - log::debug!("closed connection with id: {:?}", rq_id); - self.pending_events - .push(TransportServerEvent::Closed(rq_id)); - } - } - } - } - } - }) - } - - /// Sends back a response and destroys the request. - /// - /// You can pass `None` in order to destroy the request object without sending back anything. - /// - /// The implementation blindly sends back the response and doesn't check whether there is any - /// correspondance with the request in terms of logic. For example, `respond` will accept - /// sending back a batch of six responses even if the original request was a single - /// notification. - /// - /// > **Note**: While this method returns a `Future` that must be driven to completion, - /// > implementations must be aware that the entire requests processing logic is - /// > blocked for as long as this `Future` is pending. As an example, you shouldn't - /// > use this `Future` to send back a TCP message, because if the remote is - /// > unresponsive and the buffers full, the `Future` would then wait for a long time. - /// - pub fn finish<'a>( - &'a mut self, - request_id: &'a WsRequestId, - response: Option<&'a common::Response>, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - if let Some(mut sender) = self.to_connections.remove(request_id) { - let serialized = serde_json::to_string(&response).map_err(|_| ())?; - sender - .send(FrontToBack::Send(serialized)) - .await - .map_err(|_| ())?; - sender - .send(FrontToBack::Finished(*request_id)) - .await - .map_err(|_| ())?; - Ok(()) - } else { - Err(()) - } - }) - } - - /// Returns true if this implementation supports sending back data on this request without - /// closing it. - /// - /// Returns an error if the request id is invalid. - pub fn supports_resuming(&self, request_id: &WsRequestId) -> Result { - if self.to_connections.contains_key(request_id) { - Ok(true) - } else { - Err(()) - } - } - - /// Sends back some data on the request and keeps the request alive. - /// - /// You can continue sending data on that same request later. - /// - /// Returns an error if the request identifier is incorrect, or if the implementation doesn't - /// support that operation (see [`supports_resuming`](TransportServer::supports_resuming)). - pub fn send<'a>( - &'a mut self, - request_id: &'a WsRequestId, - response: &'a common::Response, - ) -> Pin> + Send + 'a>> { - Box::pin(async move { - if let Some(sender) = self.to_connections.get_mut(request_id) { - let serialized = serde_json::to_string(&response).map_err(|_| ())?; - sender - .send(FrontToBack::Send(serialized)) - .await - .map_err(|_| ())?; - } - Ok(()) - }) - } + /// Returns the next event that the raw server wants to notify us. + pub fn next_request<'a>( + &'a mut self, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + loop { + if !self.pending_events.is_empty() { + return self.pending_events.remove(0); + } else { + self.pending_events.shrink_to_fit(); + } + + enum Event { + TaskFinished(Vec), + NewConnection(TcpStream), + Event(BackToFront), + } + + let next = { + let next_connection = { + let listener = &self.listener; + async move { + loop { + if let Ok((connec, _)) = listener.accept().await { + break Event::NewConnection(connec); + } + } + } + }; + + let next_event = { + let from_connections = &mut self.from_connections; + async move { Event::Event(from_connections.next().await.unwrap()) } + }; + + let next_finished_task = { + let connections_tasks = &mut self.connections_tasks; + async move { Event::TaskFinished(connections_tasks.next().await.unwrap()) } + }; + + futures::pin_mut!(next_connection, next_event, next_finished_task); + match future::select(future::select(next_connection, next_event), next_finished_task).await { + future::Either::Left((future::Either::Left((ev, _)), _)) => ev, + future::Either::Left((future::Either::Right((ev, _)), _)) => ev, + future::Either::Right((ev, _)) => ev, + } + }; + + match next { + Event::NewConnection(connec) => { + log::debug!("new connection with id: {:?}", self.next_request_id); + self.connections_tasks.push( + per_connection_task(connec, self.next_request_id.clone(), self.to_front.clone()).boxed(), + ); + } + Event::Event(BackToFront::NewRequest { id, body, sender }) => { + log::debug!("new request with id: {:?}", id); + let _was_in = self.to_connections.insert(id.clone(), sender); + debug_assert!(_was_in.is_none()); + return TransportServerEvent::Request { id, request: body }; + } + Event::TaskFinished(list) => { + for rq_id in list { + let was_in = self.to_connections.remove(&rq_id); + + // It is possible that `per_connection_task` returns a list + // with non-existing connections if `task stream` + // gets read twice before the actual `event stream` gets read. + // + // -> Poll #1 Ok(read x bytes) + // -> Poll #2 Err(receiver closed) + // + // The `ws::raw::tests::request_work` is emperic proof of it. + if was_in.is_some() { + log::debug!("closed connection with id: {:?}", rq_id); + self.pending_events.push(TransportServerEvent::Closed(rq_id)); + } + } + } + } + } + }) + } + + /// Sends back a response and destroys the request. + /// + /// You can pass `None` in order to destroy the request object without sending back anything. + /// + /// The implementation blindly sends back the response and doesn't check whether there is any + /// correspondance with the request in terms of logic. For example, `respond` will accept + /// sending back a batch of six responses even if the original request was a single + /// notification. + /// + /// > **Note**: While this method returns a `Future` that must be driven to completion, + /// > implementations must be aware that the entire requests processing logic is + /// > blocked for as long as this `Future` is pending. As an example, you shouldn't + /// > use this `Future` to send back a TCP message, because if the remote is + /// > unresponsive and the buffers full, the `Future` would then wait for a long time. + /// + pub fn finish<'a>( + &'a mut self, + request_id: &'a WsRequestId, + response: Option<&'a common::Response>, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + if let Some(mut sender) = self.to_connections.remove(request_id) { + let serialized = serde_json::to_string(&response).map_err(|_| ())?; + sender.send(FrontToBack::Send(serialized)).await.map_err(|_| ())?; + sender.send(FrontToBack::Finished(*request_id)).await.map_err(|_| ())?; + Ok(()) + } else { + Err(()) + } + }) + } + + /// Returns true if this implementation supports sending back data on this request without + /// closing it. + /// + /// Returns an error if the request id is invalid. + pub fn supports_resuming(&self, request_id: &WsRequestId) -> Result { + if self.to_connections.contains_key(request_id) { + Ok(true) + } else { + Err(()) + } + } + + /// Sends back some data on the request and keeps the request alive. + /// + /// You can continue sending data on that same request later. + /// + /// Returns an error if the request identifier is incorrect, or if the implementation doesn't + /// support that operation (see [`supports_resuming`](TransportServer::supports_resuming)). + pub fn send<'a>( + &'a mut self, + request_id: &'a WsRequestId, + response: &'a common::Response, + ) -> Pin> + Send + 'a>> { + Box::pin(async move { + if let Some(sender) = self.to_connections.get_mut(request_id) { + let serialized = serde_json::to_string(&response).map_err(|_| ())?; + sender.send(FrontToBack::Send(serialized)).await.map_err(|_| ())?; + } + Ok(()) + }) + } } impl fmt::Debug for WsTransportServer { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_tuple("WsTransportServer").finish() - } + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_tuple("WsTransportServer").finish() + } } impl WsTransportServerBuilder { - /// Try establish the connection. - pub async fn build(self) -> Result { - let listener = TcpListener::bind(self.bind).await?; - let local_addr = listener.local_addr()?; - - let connections_tasks = { - let futures = stream::FuturesUnordered::new(); - // We push a dummy future in order for the `FuturesUnordered` to never produce `None`. - futures.push( - async move { - loop { - futures::pending!() - } - } - .boxed(), - ); - futures - }; - - let (to_front, from_connections) = mpsc::channel(256); - - Ok(WsTransportServer { - local_addr, - pending_events: Vec::new(), - listener, - next_request_id: Arc::new(atomic::AtomicU64::new(1)), - connections_tasks, - to_front, - from_connections, - to_connections: HashMap::new(), - }) - } + /// Try establish the connection. + pub async fn build(self) -> Result { + let listener = TcpListener::bind(self.bind).await?; + let local_addr = listener.local_addr()?; + + let connections_tasks = { + let futures = stream::FuturesUnordered::new(); + // We push a dummy future in order for the `FuturesUnordered` to never produce `None`. + futures.push( + async move { + loop { + futures::pending!() + } + } + .boxed(), + ); + futures + }; + + let (to_front, from_connections) = mpsc::channel(256); + + Ok(WsTransportServer { + local_addr, + pending_events: Vec::new(), + listener, + next_request_id: Arc::new(atomic::AtomicU64::new(1)), + connections_tasks, + to_front, + from_connections, + to_connections: HashMap::new(), + }) + } } /// Processes a single connection. @@ -352,163 +327,147 @@ impl WsTransportServerBuilder { // TODO: document this function it is quite hard to understand the outcome when it returns `Vec` // both when an error or if the actual connection was terminated. async fn per_connection_task( - socket: TcpStream, - next_request_id: Arc, - mut to_front: mpsc::Sender, + socket: TcpStream, + next_request_id: Arc, + mut to_front: mpsc::Sender, ) -> Vec { - let mut server = Server::new(socket); - - // Process the handshake from the client. - let websocket_key = match server.receive_request().await { - Ok(req) => req.into_key(), - Err(_) => return Vec::new(), - }; - - // Accept the client unconditionally. - { - let res = server - .send_response(&{ - Response::Accept { - key: &websocket_key, - protocol: None, - } - }) - .await; - if res.is_err() { - return Vec::new(); - } - } - - let (mut sender, receiver) = server.into_builder().finish(); - let mut pending_requests = Vec::new(); - let (to_connec, mut from_front) = mpsc::channel(16); - - let socket_packets = stream::unfold(receiver, move |mut receiver| async { - let mut buf = Vec::new(); - let ret = match receiver.receive_data(&mut buf).await { - // data is text or binary. - Ok(_) => Ok(buf), - Err(err) => Err(err), - }; - Some((ret, receiver)) - }); - futures::pin_mut!(socket_packets); - - loop { - let next_from_front = from_front.next(); - let next_socket_packet = socket_packets.next(); - futures::pin_mut!(next_socket_packet, next_from_front); - match future::select(next_socket_packet, next_from_front).await { - future::Either::Left((socket_packet, _)) => { - let socket_packet = match socket_packet { - Some(Ok(pq)) => { - log::debug!("received text data from WebSocket: {:?}", pq); - pq - } - Some(Err(err)) => { - log::error!("failed to receive data from WebSocket: {:?}", err); - return pending_requests; - } - None => { - log::error!("failed to receive data from Websocket channel closed"); - return pending_requests; - } - }; - - let body = match serde_json::from_slice(socket_packet.as_ref()) { - Ok(b) => b, - Err(err) => { - log::debug!("Deserialization of incoming request failed: {:?}", err); - let response = serde_json::to_string(&crate::common::Response::from( - crate::common::Error::parse_error(), - crate::common::Version::V2, - )) - .expect("valid JSON; qed"); - - match sender.send_text(&response).await { - // deserialization failed but the client is still alive - Ok(_) => continue, - // deserialization failed and the client is not alive - Err(e) => { - log::warn!( - "Failed to send: {:?} over WebSocket transport with error: {:?}", - response, - e - ); - return pending_requests; - } - } - } - }; - - let request_id = - WsRequestId(next_request_id.fetch_add(1, atomic::Ordering::Relaxed)); - debug_assert_ne!(request_id.0, u64::max_value()); - - // Important note: since the background task sends messages to the front task via - // a channel, and the front task sends messages to the background task via a - // channel as well, and considering that these channels are bounded, a deadlock - // situation would arise if both the front and background task waited while trying - // to send something while both channels are full. - // In order to prevent this from happening, the background -> front sending never - // blocks. If the back -> front channel is full, we simply kill the task, which - // has the same effect as a disconnect. - // The channel is normally large enough for this to never happen unless the server - // is considerably slowed down or subject to a DoS attack. - let result = to_front - .send(BackToFront::NewRequest { - id: request_id, - body, - sender: to_connec.clone(), - }) - .now_or_never(); - - match result { - // Request was succesfully transmitted to the frontend. - Some(Ok(_)) => pending_requests.push(request_id), - // The channel is down or full. - Some(Err(e)) => { - log::error!( - "send request={:?} to frontend failed because of {:?}, terminating the connection", - request_id, - e, - ); - return pending_requests; - } - // The future wasn't ready. - // TODO(niklasad1): verify if this is possible to happen "in practice". - None => { - log::error!( - "send request={:?} to frontend failed future not ready, terminating the connection", - request_id, - ); - return pending_requests; - } - } - } - - // Received data to send on the connection. - future::Either::Right((Some(FrontToBack::Send(to_send)), _)) => { - log::trace!("transmit: {:?}", to_send); - if let Err(err) = sender.send_text(&to_send).await { - log::warn!( - "failed to send: {:?} over WebSocket transport with error: {:?}", - to_send, - err - ); - return pending_requests; - } - } - - // Received data to send on the connection. - future::Either::Right((Some(FrontToBack::Finished(rq_id)), _)) => { - log::trace!("finished request_id={:?}", rq_id); - let pos = pending_requests.iter().position(|r| *r == rq_id).unwrap(); - pending_requests.remove(pos); - } - - // Channel to main WS server struct has closed. Let's close the task. - future::Either::Right((None, _)) => return pending_requests, - } - } + let mut server = Server::new(socket); + + // Process the handshake from the client. + let websocket_key = match server.receive_request().await { + Ok(req) => req.into_key(), + Err(_) => return Vec::new(), + }; + + // Accept the client unconditionally. + { + let res = server.send_response(&{ Response::Accept { key: &websocket_key, protocol: None } }).await; + if res.is_err() { + return Vec::new(); + } + } + + let (mut sender, receiver) = server.into_builder().finish(); + let mut pending_requests = Vec::new(); + let (to_connec, mut from_front) = mpsc::channel(16); + + let socket_packets = stream::unfold(receiver, move |mut receiver| async { + let mut buf = Vec::new(); + let ret = match receiver.receive_data(&mut buf).await { + // data is text or binary. + Ok(_) => Ok(buf), + Err(err) => Err(err), + }; + Some((ret, receiver)) + }); + futures::pin_mut!(socket_packets); + + loop { + let next_from_front = from_front.next(); + let next_socket_packet = socket_packets.next(); + futures::pin_mut!(next_socket_packet, next_from_front); + match future::select(next_socket_packet, next_from_front).await { + future::Either::Left((socket_packet, _)) => { + let socket_packet = match socket_packet { + Some(Ok(pq)) => { + log::debug!("received text data from WebSocket: {:?}", pq); + pq + } + Some(Err(err)) => { + log::error!("failed to receive data from WebSocket: {:?}", err); + return pending_requests; + } + None => { + log::error!("failed to receive data from Websocket channel closed"); + return pending_requests; + } + }; + + let body = match serde_json::from_slice(socket_packet.as_ref()) { + Ok(b) => b, + Err(err) => { + log::debug!("Deserialization of incoming request failed: {:?}", err); + let response = serde_json::to_string(&crate::common::Response::from( + crate::common::Error::parse_error(), + crate::common::Version::V2, + )) + .expect("valid JSON; qed"); + + match sender.send_text(&response).await { + // deserialization failed but the client is still alive + Ok(_) => continue, + // deserialization failed and the client is not alive + Err(e) => { + log::warn!( + "Failed to send: {:?} over WebSocket transport with error: {:?}", + response, + e + ); + return pending_requests; + } + } + } + }; + + let request_id = WsRequestId(next_request_id.fetch_add(1, atomic::Ordering::Relaxed)); + debug_assert_ne!(request_id.0, u64::max_value()); + + // Important note: since the background task sends messages to the front task via + // a channel, and the front task sends messages to the background task via a + // channel as well, and considering that these channels are bounded, a deadlock + // situation would arise if both the front and background task waited while trying + // to send something while both channels are full. + // In order to prevent this from happening, the background -> front sending never + // blocks. If the back -> front channel is full, we simply kill the task, which + // has the same effect as a disconnect. + // The channel is normally large enough for this to never happen unless the server + // is considerably slowed down or subject to a DoS attack. + let result = to_front + .send(BackToFront::NewRequest { id: request_id, body, sender: to_connec.clone() }) + .now_or_never(); + + match result { + // Request was succesfully transmitted to the frontend. + Some(Ok(_)) => pending_requests.push(request_id), + // The channel is down or full. + Some(Err(e)) => { + log::error!( + "send request={:?} to frontend failed because of {:?}, terminating the connection", + request_id, + e, + ); + return pending_requests; + } + // The future wasn't ready. + // TODO(niklasad1): verify if this is possible to happen "in practice". + None => { + log::error!( + "send request={:?} to frontend failed future not ready, terminating the connection", + request_id, + ); + return pending_requests; + } + } + } + + // Received data to send on the connection. + future::Either::Right((Some(FrontToBack::Send(to_send)), _)) => { + log::trace!("transmit: {:?}", to_send); + if let Err(err) = sender.send_text(&to_send).await { + log::warn!("failed to send: {:?} over WebSocket transport with error: {:?}", to_send, err); + return pending_requests; + } + } + + // Received data to send on the connection. + future::Either::Right((Some(FrontToBack::Finished(rq_id)), _)) => { + log::trace!("finished request_id={:?}", rq_id); + let pos = pending_requests.iter().position(|r| *r == rq_id).unwrap(); + pending_requests.remove(pos); + } + + // Channel to main WS server struct has closed. Let's close the task. + future::Either::Right((None, _)) => return pending_requests, + } + } } diff --git a/test-utils/src/helpers.rs b/test-utils/src/helpers.rs index bb999f98e3..42dd45e69c 100644 --- a/test-utils/src/helpers.rs +++ b/test-utils/src/helpers.rs @@ -4,76 +4,60 @@ use std::net::SocketAddr; /// Converts a sockaddress to a WebSocket URI. pub fn to_ws_uri_string(addr: SocketAddr) -> String { - let mut s = String::new(); - s.push_str("ws://"); - s.push_str(&addr.to_string()); - s + let mut s = String::new(); + s.push_str("ws://"); + s.push_str(&addr.to_string()); + s } /// Converts a sockaddress to a HTTP URI. pub fn to_http_uri(sockaddr: SocketAddr) -> Uri { - let s = sockaddr.to_string(); - Uri::builder() - .scheme("http") - .authority(s.as_str()) - .path_and_query("/") - .build() - .unwrap() + let s = sockaddr.to_string(); + Uri::builder().scheme("http").authority(s.as_str()).path_and_query("/").build().unwrap() } pub fn ok_response(result: Value, id: Id) -> String { - format!( - r#"{{"jsonrpc":"2.0","result":{},"id":{}}}"#, - result, - serde_json::to_string(&id).unwrap() - ) + format!(r#"{{"jsonrpc":"2.0","result":{},"id":{}}}"#, result, serde_json::to_string(&id).unwrap()) } pub fn method_not_found(id: Id) -> String { - format!( - r#"{{"jsonrpc":"2.0","error":{{"code":-32601,"message":"Method not found"}},"id":{}}}"#, - serde_json::to_string(&id).unwrap() - ) + format!( + r#"{{"jsonrpc":"2.0","error":{{"code":-32601,"message":"Method not found"}},"id":{}}}"#, + serde_json::to_string(&id).unwrap() + ) } pub fn parse_error(id: Id) -> String { - format!( - r#"{{"jsonrpc":"2.0","error":{{"code":-32700,"message":"Parse error"}},"id":{}}}"#, - serde_json::to_string(&id).unwrap() - ) + format!( + r#"{{"jsonrpc":"2.0","error":{{"code":-32700,"message":"Parse error"}},"id":{}}}"#, + serde_json::to_string(&id).unwrap() + ) } pub fn invalid_request(id: Id) -> String { - format!( - r#"{{"jsonrpc":"2.0","error":{{"code":-32600,"message":"Invalid request"}},"id":{}}}"#, - serde_json::to_string(&id).unwrap() - ) + format!( + r#"{{"jsonrpc":"2.0","error":{{"code":-32600,"message":"Invalid request"}},"id":{}}}"#, + serde_json::to_string(&id).unwrap() + ) } pub fn invalid_params(id: Id) -> String { - format!( - r#"{{"jsonrpc":"2.0","error":{{"code":-32602,"message":"Invalid params"}},"id":{}}}"#, - serde_json::to_string(&id).unwrap() - ) + format!( + r#"{{"jsonrpc":"2.0","error":{{"code":-32602,"message":"Invalid params"}},"id":{}}}"#, + serde_json::to_string(&id).unwrap() + ) } pub async fn http_request(body: Body, uri: Uri) -> Result { - let client = hyper::Client::new(); - let r = hyper::Request::post(uri) - .header( - hyper::header::CONTENT_TYPE, - hyper::header::HeaderValue::from_static("application/json"), - ) - .body(body) - .expect("uri and request headers are valid; qed"); - let res = client.request(r).await.map_err(|e| format!("{:?}", e))?; + let client = hyper::Client::new(); + let r = hyper::Request::post(uri) + .header(hyper::header::CONTENT_TYPE, hyper::header::HeaderValue::from_static("application/json")) + .body(body) + .expect("uri and request headers are valid; qed"); + let res = client.request(r).await.map_err(|e| format!("{:?}", e))?; - let (parts, body) = res.into_parts(); - let bytes = hyper::body::to_bytes(body).await.unwrap(); + let (parts, body) = res.into_parts(); + let bytes = hyper::body::to_bytes(body).await.unwrap(); - Ok(HttpResponse { - status: parts.status, - header: parts.headers, - body: String::from_utf8(bytes.to_vec()).unwrap(), - }) + Ok(HttpResponse { status: parts.status, header: parts.headers, body: String::from_utf8(bytes.to_vec()).unwrap() }) } diff --git a/test-utils/src/types.rs b/test-utils/src/types.rs index e4aabd14a2..dcc182f5bf 100644 --- a/test-utils/src/types.rs +++ b/test-utils/src/types.rs @@ -14,61 +14,57 @@ type Error = Box; #[serde(deny_unknown_fields)] #[serde(untagged)] pub enum Id { - /// No id (notification) - Null, - /// Numeric id - Num(u64), - /// String id - Str(String), + /// No id (notification) + Null, + /// Numeric id + Num(u64), + /// String id + Str(String), } #[derive(Debug)] pub struct HttpResponse { - pub status: StatusCode, - pub header: HeaderMap, - pub body: String, + pub status: StatusCode, + pub header: HeaderMap, + pub body: String, } /// WebSocket client to construct with arbitrary payload to construct bad payloads. pub struct WebSocketTestClient { - tx: soketto::Sender>>>, - rx: soketto::Receiver>>>, + tx: soketto::Sender>>>, + rx: soketto::Receiver>>>, } impl WebSocketTestClient { - pub async fn new(url: SocketAddr) -> Result { - let socket = TcpStream::connect(url).await?; - let mut client = handshake::Client::new( - BufReader::new(BufWriter::new(socket.compat())), - "test-client", - "/", - ); - match client.handshake().await { - Ok(handshake::ServerResponse::Accepted { .. }) => { - let (tx, rx) = client.into_builder().finish(); - Ok(Self { tx, rx }) - } - r @ _ => Err(format!("WebSocketHandshake failed: {:?}", r).into()), - } - } + pub async fn new(url: SocketAddr) -> Result { + let socket = TcpStream::connect(url).await?; + let mut client = handshake::Client::new(BufReader::new(BufWriter::new(socket.compat())), "test-client", "/"); + match client.handshake().await { + Ok(handshake::ServerResponse::Accepted { .. }) => { + let (tx, rx) = client.into_builder().finish(); + Ok(Self { tx, rx }) + } + r @ _ => Err(format!("WebSocketHandshake failed: {:?}", r).into()), + } + } - pub async fn send_request_text(&mut self, msg: impl AsRef) -> Result { - self.tx.send_text(msg).await?; - self.tx.flush().await?; - let mut data = Vec::new(); - self.rx.receive_data(&mut data).await?; - String::from_utf8(data).map_err(Into::into) - } + pub async fn send_request_text(&mut self, msg: impl AsRef) -> Result { + self.tx.send_text(msg).await?; + self.tx.flush().await?; + let mut data = Vec::new(); + self.rx.receive_data(&mut data).await?; + String::from_utf8(data).map_err(Into::into) + } - pub async fn send_request_binary(&mut self, msg: &[u8]) -> Result { - self.tx.send_binary(msg).await?; - self.tx.flush().await?; - let mut data = Vec::new(); - self.rx.receive_data(&mut data).await?; - String::from_utf8(data).map_err(Into::into) - } + pub async fn send_request_binary(&mut self, msg: &[u8]) -> Result { + self.tx.send_binary(msg).await?; + self.tx.flush().await?; + let mut data = Vec::new(); + self.rx.receive_data(&mut data).await?; + String::from_utf8(data).map_err(Into::into) + } - pub async fn close(&mut self) -> Result<(), Error> { - self.tx.close().await.map_err(Into::into) - } + pub async fn close(&mut self) -> Result<(), Error> { + self.tx.close().await.map_err(Into::into) + } }