diff --git a/Cargo.lock b/Cargo.lock index c9549f2..30ac38b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -48,6 +48,17 @@ dependencies = [ "syn 2.0.18", ] +[[package]] +name = "async-trait" +version = "0.1.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2d0f03b3640e3a630367e40c468cb7f309529c708ed1d88597047b0e7c6ef7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + [[package]] name = "atomic" version = "0.5.3" @@ -125,6 +136,17 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "derated" +version = "0.1.0" +dependencies = [ + "async-trait", + "once_cell", + "ordered-float", + "tokio", + "tracing", +] + [[package]] name = "diff" version = "0.1.13" @@ -171,8 +193,10 @@ version = "0.1.0" dependencies = [ "async-stream", "cc", + "derated", "futures", "membrane", + "once_cell", "serde", "serde_bytes", "tokio", @@ -539,6 +563,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.15.0" @@ -555,6 +588,15 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +[[package]] +name = "ordered-float" +version = "4.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +dependencies = [ + "num-traits", +] + [[package]] name = "output_vt100" version = "0.1.3" diff --git a/dart_example/test/main_test.dart b/dart_example/test/main_test.dart index e8a0851..59bef4e 100644 --- a/dart_example/test/main_test.dart +++ b/dart_example/test/main_test.dart @@ -506,4 +506,17 @@ void main() { id: Filter(value: [Match(field: "id", value: "1")]), withinGdpr: GDPR(value: true)); }); + + test('test that functions can be rate limited', () async { + final contact = + Contact(id: 1, fullName: "Alice Smith", status: Status.pending); + final accounts = AccountsApi(); + + assert(await accounts.rateLimitedFunction(contact: contact) == + contact.fullName); + expect(() async => await accounts.rateLimitedFunction(contact: contact), + throwsA(isA())); + expect(() async => await accounts.rateLimitedFunction(contact: contact), + throwsA(isA())); + }); } diff --git a/example/Cargo.toml b/example/Cargo.toml index f99b1af..e3a8902 100644 --- a/example/Cargo.toml +++ b/example/Cargo.toml @@ -22,8 +22,10 @@ skip-codegen = ["membrane/skip-generate"] [dependencies] async-stream = "0.3" +derated = {path = "../../derated"} futures = "0.3" membrane = {path = "../membrane"} +once_cell = "*" serde = {version = "1.0", features = ["derive"]} serde_bytes = "0.11" tokio = {version = "1", features = ["full"]} diff --git a/example/src/application/advanced.rs b/example/src/application/advanced.rs index 09a68d8..877ec27 100644 --- a/example/src/application/advanced.rs +++ b/example/src/application/advanced.rs @@ -1,8 +1,10 @@ use data::OptionsDemo; use membrane::emitter::{emitter, Emitter, StreamEmitter}; use membrane::{async_dart, sync_dart}; +use once_cell::sync::Lazy; use tokio_stream::Stream; +use std::collections::hash_map::DefaultHasher; // used for background threading examples use std::{thread, time::Duration}; @@ -584,3 +586,30 @@ pub async fn get_org_with_borrowed_type( pub async fn unused_duplicate_borrows(_id: i64) -> Result { todo!() } + +struct MyLimit(RateLimit); + +impl MyLimit { + fn per_milliseconds(milliseconds: u64, max_queued: Option) -> Self { + Self(RateLimit::per_milliseconds(milliseconds, max_queued)) + } + + fn hash_rate_limited_function(&self, fn_name: &str, contact: &data::Contact) -> u64 { + use std::hash::{Hash, Hasher}; + let mut s = DefaultHasher::new(); + (fn_name, contact.id).hash(&mut s); + s.finish() + } + + async fn check(&self, key: &'static str, hash: u64) -> Result<(), derated::Dropped> { + self.0.check(key, hash).await + } +} + +use derated::RateLimit; +static RATE_LIMIT: Lazy = Lazy::new(|| MyLimit::per_milliseconds(100, None)); + +#[async_dart(namespace = "accounts", rate_limit = RATE_LIMIT)] +pub async fn rate_limited_function(contact: data::Contact) -> Result { + Ok(contact.full_name) +} diff --git a/example/src/data.rs b/example/src/data.rs index db83ae9..1c2b0a4 100644 --- a/example/src/data.rs +++ b/example/src/data.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; #[dart_enum(namespace = "accounts")] #[dart_enum(namespace = "orgs")] -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize, Hash)] pub enum Status { Pending, Active, @@ -42,7 +42,7 @@ pub struct Mixed { three: Option, } -#[derive(Debug, Clone, Deserialize, Serialize)] +#[derive(Debug, Clone, Deserialize, Serialize, Hash)] pub struct Contact { pub id: i64, pub full_name: String, diff --git a/membrane/src/generators/exceptions.rs b/membrane/src/generators/exceptions.rs index ee782b9..d4632b7 100644 --- a/membrane/src/generators/exceptions.rs +++ b/membrane/src/generators/exceptions.rs @@ -30,6 +30,10 @@ class MembraneRustPanicException extends MembraneException { class MembraneUnknownResponseVariantException extends MembraneException { const MembraneUnknownResponseVariantException([String? message]) : super(message); } + +class MembraneRateLimited extends MembraneException { + const MembraneRateLimited([String? message]) : super(message); +} "# .to_string() } diff --git a/membrane/src/generators/functions.rs b/membrane/src/generators/functions.rs index fe99ef4..17df36a 100644 --- a/membrane/src/generators/functions.rs +++ b/membrane/src/generators/functions.rs @@ -338,8 +338,11 @@ impl Callable for Ffi { _log.{fine_logger}('Deserializing data from {fn_name}'); }} final deserializer = BincodeDeserializer(data.asTypedList(length + 8).sublist(8)); - if (deserializer.deserializeUint8() == MembraneMsgKind.ok) {{ + final msgCode = deserializer.deserializeUint8(); + if (msgCode == MembraneMsgKind.ok) {{ return {return_de}; + }} else if (msgCode == MembraneMsgKind.rateLimited) {{ + throw MembraneRateLimited(); }} throw {class_name}ApiError({error_de}); }} finally {{ @@ -362,8 +365,11 @@ impl Callable for Ffi { _log.{fine_logger}('Deserializing data from {fn_name}'); }} final deserializer = BincodeDeserializer(input as Uint8List); - if (deserializer.deserializeUint8() == MembraneMsgKind.ok) {{ + final msgCode = deserializer.deserializeUint8(); + if (msgCode == MembraneMsgKind.ok) {{ return {return_de}; + }} else if (msgCode == MembraneMsgKind.rateLimited) {{ + throw MembraneRateLimited(); }} throw {class_name}ApiError({error_de}); }}); @@ -394,8 +400,11 @@ impl Callable for Ffi { _log.{fine_logger}('Deserializing data from {fn_name}'); }} final deserializer = BincodeDeserializer(await _port.first{timeout} as Uint8List); - if (deserializer.deserializeUint8() == MembraneMsgKind.ok) {{ + final msgCode = deserializer.deserializeUint8(); + if (msgCode == MembraneMsgKind.ok) {{ return {return_de}; + }} else if (msgCode == MembraneMsgKind.rateLimited) {{ + throw MembraneRateLimited(); }} throw {class_name}ApiError({error_de}); }} finally {{ diff --git a/membrane/src/lib.rs b/membrane/src/lib.rs index 1b879e8..e9648e1 100644 --- a/membrane/src/lib.rs +++ b/membrane/src/lib.rs @@ -656,6 +656,7 @@ impl<'a> Membrane { typedef enum MembraneMsgKind { Ok, Error, + RateLimited, } MembraneMsgKind; typedef enum MembraneResponseKind { @@ -826,6 +827,7 @@ enums: 'Error': 'error' 'Ok': 'ok' 'Panic': 'panic' + 'RateLimited': 'rateLimited' macros: include: - __none__ @@ -1354,10 +1356,11 @@ pub struct MembraneResponse { #[doc(hidden)] #[repr(u8)] -#[derive(serde::Serialize)] +#[derive(serde::Serialize, PartialEq)] pub enum MembraneMsgKind { Ok, Error, + RateLimited, } #[doc(hidden)] diff --git a/membrane/src/utils.rs b/membrane/src/utils.rs index a6b49a7..194e5fe 100644 --- a/membrane/src/utils.rs +++ b/membrane/src/utils.rs @@ -2,6 +2,14 @@ use crate::SourceCodeLocation; use allo_isolate::Isolate; use serde::ser::Serialize; +pub fn send_rate_limited(isolate: Isolate) -> bool { + if let Ok(buffer) = crate::bincode::serialize(&(crate::MembraneMsgKind::RateLimited as u8)) { + isolate.post(crate::allo_isolate::ZeroCopyBuffer(buffer)) + } else { + false + } +} + pub fn send(isolate: Isolate, result: Result) -> bool { match result { Ok(value) => { diff --git a/membrane_macro/src/lib.rs b/membrane_macro/src/lib.rs index 1ee7b50..9488507 100644 --- a/membrane_macro/src/lib.rs +++ b/membrane_macro/src/lib.rs @@ -172,6 +172,7 @@ fn to_token_stream( timeout, os_thread, borrow, + rate_limit, } = options; let mut functions = TokenStream::new(); @@ -204,6 +205,18 @@ fn to_token_stream( let dart_transforms: Vec = DartTransforms::try_from(&inputs)?.into(); let dart_inner_args: Vec = DartArgs::from(&inputs).into(); + let rate_limit_condition = if let Some(limiter) = rate_limit { + let hasher_function = Ident::new(&format!("hash_{}", &rust_fn_name), Span::call_site()); + quote! { + let ::std::result::Result::Err(err) = { + let hash = #limiter.#hasher_function(#rust_fn_name, #(&#rust_inner_args),*); + #limiter.check(#rust_fn_name, hash).await + } + } + } else { + quote!(false) + }; + let return_statement = match output_style { OutputStyle::EmitterSerialized | OutputStyle::StreamEmitterSerialized if sync => { syn::Error::new( @@ -281,9 +294,13 @@ fn to_token_stream( OutputStyle::Serialized => quote! { let membrane_join_handle = crate::RUNTIME.get().info_spawn( async move { - let result: ::std::result::Result<#output, #error> = #fn_name(#(#rust_inner_args),*).await; let isolate = ::membrane::allo_isolate::Isolate::new(membrane_port); - ::membrane::utils::send::<#output, #error>(isolate, result); + if #rate_limit_condition { + ::membrane::utils::send_rate_limited(isolate); + } else { + let result: ::std::result::Result<#output, #error> = #fn_name(#(#rust_inner_args),*).await; + ::membrane::utils::send::<#output, #error>(isolate, result); + } }, ::membrane::runtime::Info { name: #rust_fn_name } ); diff --git a/membrane_macro/src/options.rs b/membrane_macro/src/options.rs index 010ba5b..bf4e611 100644 --- a/membrane_macro/src/options.rs +++ b/membrane_macro/src/options.rs @@ -8,6 +8,7 @@ pub(crate) struct Options { pub timeout: Option, pub os_thread: bool, pub borrow: Vec, + pub rate_limit: Option, } pub(crate) fn extract_options( @@ -64,6 +65,13 @@ pub(crate) fn extract_options( options.disable_logging = val.value(); options } + Some((ident, syn::Expr::Path(syn::ExprPath { path, .. }))) + if ident == "rate_limit" && !sync => + { + options.rate_limit = Some(path); + options + } + // TODO handle the invalid rate_limit case Some(( ident, Lit(ExprLit {