diff --git a/Cargo.toml b/Cargo.toml index a8fb9782..93a10a1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -108,12 +108,13 @@ strum_macros = "0.24" backtrace = "0.3" linkme = "0.3" serde = { version = "1", features = ["derive"] } +cfg-if = "1" +redis-module-derive = { path = "./redismodule-rs-derive"} [dev-dependencies] anyhow = "1.0.38" redis = "0.22.1" lazy_static = "1.4.0" -redis-module-derive = { path = "./redismodule-rs-derive"} redis-module = { path = "./", features = ["test", "experimental-api"]} [build-dependencies] @@ -121,8 +122,12 @@ bindgen = "0.64" cc = "1.0" [features] -default = ["experimental-api"] +default = ["experimental-api", "min-redis-compatibility-version-6-0"] experimental-api = [] +min-redis-compatibility-version-7-2 = [] +min-redis-compatibility-version-7-0 = [] +min-redis-compatibility-version-6-2 = [] +min-redis-compatibility-version-6-0 = [] # Workaround to allow cfg(feature = "test") in dependencies: # https://github.com/rust-lang/rust/issues/59168#issuecomment-472653680 diff --git a/redismodule-rs-derive/Cargo.toml b/redismodule-rs-derive/Cargo.toml index be380b9e..daab1c71 100644 --- a/redismodule-rs-derive/Cargo.toml +++ b/redismodule-rs-derive/Cargo.toml @@ -14,6 +14,8 @@ categories = ["database", "api-bindings"] [dependencies] syn = { version="1.0", features = ["full"]} quote = "1.0" +lazy_static = "1.4.0" +proc-macro2 = "1.0.56" [lib] name = "redis_module_derive" diff --git a/redismodule-rs-derive/src/api_versions.rs b/redismodule-rs-derive/src/api_versions.rs new file mode 100644 index 00000000..d11c2975 --- /dev/null +++ b/redismodule-rs-derive/src/api_versions.rs @@ -0,0 +1,40 @@ +use std::collections::HashMap; + +use proc_macro2::TokenStream; +use quote::quote; + +lazy_static::lazy_static! { + pub(crate) static ref API_VERSION_MAPPING: HashMap = HashMap::from([ + ("RedisModule_AddPostNotificationJob".to_string(), 70200), + ("RedisModule_SetCommandACLCategories".to_string(), 70200), + ("RedisModule_GetOpenKeyModesAll".to_string(), 70200), + ("RedisModule_CallReplyPromiseSetUnblockHandler".to_string(), 70200), + ("RedisModule_CallReplyPromiseAbort".to_string(), 70200), + ("RedisModule_Microseconds".to_string(), 70200), + ("RedisModule_CachedMicroseconds".to_string(), 70200), + ("RedisModule_RegisterAuthCallback".to_string(), 70200), + ("RedisModule_BlockClientOnKeysWithFlags".to_string(), 70200), + ("RedisModule_GetModuleOptionsAll".to_string(), 70200), + ("RedisModule_BlockClientGetPrivateData".to_string(), 70200), + ("RedisModule_BlockClientSetPrivateData".to_string(), 70200), + ("RedisModule_BlockClientOnAuth".to_string(), 70200), + ("RedisModule_ACLAddLogEntryByUserName".to_string(), 70200), + ]); + + pub(crate) static ref API_OLDEST_VERSION: usize = 60000; + pub(crate) static ref ALL_VERSIONS: Vec<(usize, String)> = vec![ + (60000, "min-redis-compatibility-version-6-0".to_string()), + (60200, "min-redis-compatibility-version-6-2".to_string()), + (70000, "min-redis-compatibility-version-7-0".to_string()), + (70200, "min-redis-compatibility-version-7-2".to_string()), + ]; +} + +pub(crate) fn get_feature_flags(min_required_version: usize) -> (Vec, Vec) { + let all_lower_versions: Vec<&str> = ALL_VERSIONS.iter().filter_map(|(v, s)| if *v < min_required_version {Some(s.as_str())} else {None}).collect(); + let all_upper_versions: Vec<&str> = ALL_VERSIONS.iter().filter_map(|(v, s)| if *v >= min_required_version {Some(s.as_str())} else {None}).collect(); + ( + all_lower_versions.into_iter().map(|s| quote!(feature = #s).into()).collect(), + all_upper_versions.into_iter().map(|s| quote!(feature = #s).into()).collect(), + ) +} \ No newline at end of file diff --git a/redismodule-rs-derive/src/lib.rs b/redismodule-rs-derive/src/lib.rs index 3c0153db..07231935 100644 --- a/redismodule-rs-derive/src/lib.rs +++ b/redismodule-rs-derive/src/lib.rs @@ -1,8 +1,16 @@ extern crate proc_macro; -use proc_macro::TokenStream; + +mod api_versions; + +use proc_macro::{TokenStream}; use quote::quote; -use syn; +use syn::parse::{Parse, ParseStream, Result}; +use syn::punctuated::Punctuated; +use syn::token::{RArrow, Paren}; +use syn::{self, parse_macro_input, Token, Type, ReturnType, TypeTuple}; use syn::ItemFn; +use syn::Ident; +use api_versions::{API_VERSION_MAPPING, API_OLDEST_VERSION, get_feature_flags}; #[proc_macro_attribute] pub fn role_changed_event_handler(_attr: TokenStream, item: TokenStream) -> TokenStream { @@ -55,3 +63,99 @@ pub fn module_changed_event_handler(_attr: TokenStream, item: TokenStream) -> To }; gen.into() } + +#[derive(Debug)] +struct Args{ + requested_apis: Vec +} + +impl Parse for Args{ + fn parse(input: ParseStream) -> Result { + // parses a,b,c, or a,b,c where a,b and c are Indent + let vars = Punctuated::::parse_terminated(input)?; + Ok(Args { + requested_apis: vars.into_iter().collect(), + }) + } +} + +#[proc_macro_attribute] +pub fn redismodule_api(attr: TokenStream, item: TokenStream) -> TokenStream { + let args = parse_macro_input!(attr as Args); + let original_func = item.clone(); + let mut original_func = parse_macro_input!(original_func as ItemFn); + let original_func_name = original_func.sig.ident.clone(); + let original_func_name: Ident = Ident::new(&format!("{}_inner", original_func_name.to_string()), original_func_name.span()); + original_func.sig.ident = original_func_name.clone(); + + let mut use_self = false; + let input_names:Vec = original_func.sig.inputs.clone().into_iter().filter_map(|v| { + match v { + syn::FnArg::Receiver(_) => use_self = true, + syn::FnArg::Typed(pat_type) => { + if let syn::Pat::Ident(pat_ident) = *pat_type.pat.clone() { + return Some(pat_ident.ident) + } + } + } + None + }).collect(); + let func = parse_macro_input!(item as ItemFn); + + let minimum_require_version = args.requested_apis.iter().fold(*API_OLDEST_VERSION, |min_api_version, item|{ + // if we do not have a version mapping, we assume the API exists and return the minimum version. + let api_version = API_VERSION_MAPPING.get(&item.to_string()).map(|v| *v).unwrap_or(*API_OLDEST_VERSION); + api_version.max(min_api_version) + }); + + if *API_OLDEST_VERSION == minimum_require_version { + // all API exists on the older version supported so we can just return the function as is. + return quote!(#original_func).into(); + } + + let requested_apis = args.requested_apis; + let requested_apis_str: Vec = requested_apis.iter().map(|e| e.to_string()).collect(); + let vis = func.vis; + let inner_return_return_type = match func.sig.output.clone() { + ReturnType::Default => Box::new(Type::Tuple(TypeTuple{paren_token: Paren::default(), elems: Punctuated::new()})), + ReturnType::Type(_, t) => t, + }; + let new_return_return_type = Type::Path(syn::parse(quote!( + crate::apierror::APIResult<#inner_return_return_type> + ).into()).unwrap()); + let mut sig = func.sig; + sig.output = ReturnType::Type(RArrow::default(), Box::new(new_return_return_type)); + + let original_function_call = if use_self { + quote!(self.#original_func_name(#(#input_names, )*)) + } else { + quote!(#original_func_name(#(#input_names, )*)) + }; + + let new_func = quote!( + #original_func + + #vis #sig { + #( + unsafe{crate::raw::#requested_apis.ok_or(concat!(#requested_apis_str, " does not exists"))?}; + )* + + Ok(#original_function_call) + } + ); + + let (all_lower_features, all_upper_features) = get_feature_flags(minimum_require_version); + + let gen = quote! { + cfg_if::cfg_if! { + if #[cfg(any(#(#all_lower_features, )*))] { + #new_func + } else if #[cfg(any(#(#all_upper_features, )*))] { + #original_func + } else { + compile_error!("min-redis-compatibility-version is not set correctly") + } + } + }; + gen.into() +} diff --git a/src/apierror.rs b/src/apierror.rs new file mode 100644 index 00000000..05076681 --- /dev/null +++ b/src/apierror.rs @@ -0,0 +1,2 @@ +pub type APIError = String; +pub type APIResult = Result; diff --git a/src/context/mod.rs b/src/context/mod.rs index 22206afa..fc542f78 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -1,5 +1,7 @@ use bitflags::bitflags; +use redis_module_derive::redismodule_api; use std::ffi::CString; +use std::os::raw::c_void; use std::os::raw::{c_char, c_int, c_long, c_longlong}; use std::ptr::{self, NonNull}; use std::sync::atomic::{AtomicPtr, Ordering}; @@ -671,6 +673,32 @@ impl Context { let acl_permission_result: Result<(), &str> = acl_permission_result.into(); acl_permission_result.map_err(|_e| RedisError::Str("User does not have permissions on key")) } + + #[redismodule_api(RedisModule_AddPostNotificationJob)] + pub fn add_post_notification_job(self, callback: F) { + let callback = Box::into_raw(Box::new(callback)); + unsafe { + raw::RedisModule_AddPostNotificationJob.unwrap()( + self.ctx, + Some(post_notification_job::), + callback as *mut c_void, + Some(post_notification_job_free_callback::), + ); + } + } +} + +extern "C" fn post_notification_job_free_callback(pd: *mut c_void) { + unsafe { Box::from_raw(pd as *mut F) }; +} + +extern "C" fn post_notification_job( + ctx: *mut raw::RedisModuleCtx, + pd: *mut c_void, +) { + let callback = unsafe { &*(pd as *mut F) }; + let ctx = Context::new(ctx); + callback(&ctx); } unsafe impl RedisLockIndicator for Context {} diff --git a/src/include/redismodule.h b/src/include/redismodule.h index efd2fb35..558c3e08 100644 --- a/src/include/redismodule.h +++ b/src/include/redismodule.h @@ -803,6 +803,7 @@ typedef struct RedisModuleKeyOptCtx RedisModuleKeyOptCtx; typedef int (*RedisModuleCmdFunc)(RedisModuleCtx *ctx, RedisModuleString **argv, int argc); typedef void (*RedisModuleDisconnectFunc)(RedisModuleCtx *ctx, RedisModuleBlockedClient *bc); typedef int (*RedisModuleNotificationFunc)(RedisModuleCtx *ctx, int type, const char *event, RedisModuleString *key); +typedef void (*RedisModulePostNotificationJobFunc) (RedisModuleCtx *ctx, void *pd); typedef void *(*RedisModuleTypeLoadFunc)(RedisModuleIO *rdb, int encver); typedef void (*RedisModuleTypeSaveFunc)(RedisModuleIO *rdb, void *value); typedef int (*RedisModuleTypeAuxLoadFunc)(RedisModuleIO *rdb, int encver, int when); @@ -1135,6 +1136,7 @@ REDISMODULE_API void (*RedisModule_ThreadSafeContextLock)(RedisModuleCtx *ctx) R REDISMODULE_API int (*RedisModule_ThreadSafeContextTryLock)(RedisModuleCtx *ctx) REDISMODULE_ATTR; REDISMODULE_API void (*RedisModule_ThreadSafeContextUnlock)(RedisModuleCtx *ctx) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_SubscribeToKeyspaceEvents)(RedisModuleCtx *ctx, int types, RedisModuleNotificationFunc cb) REDISMODULE_ATTR; +REDISMODULE_API int (*RedisModule_AddPostNotificationJob)(RedisModuleCtx *ctx, RedisModulePostNotificationJobFunc callback, void *pd, void (*free_pd)(void*)) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_NotifyKeyspaceEvent)(RedisModuleCtx *ctx, int type, const char *event, RedisModuleString *key) REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_GetNotifyKeyspaceEvents)() REDISMODULE_ATTR; REDISMODULE_API int (*RedisModule_BlockedClientDisconnected)(RedisModuleCtx *ctx) REDISMODULE_ATTR; @@ -1477,6 +1479,7 @@ static void RedisModule_InitAPI(RedisModuleCtx *ctx) { REDISMODULE_GET_API(BlockedClientMeasureTimeEnd); REDISMODULE_GET_API(SetDisconnectCallback); REDISMODULE_GET_API(SubscribeToKeyspaceEvents); + REDISMODULE_GET_API(AddPostNotificationJob); REDISMODULE_GET_API(NotifyKeyspaceEvent); REDISMODULE_GET_API(GetNotifyKeyspaceEvents); REDISMODULE_GET_API(BlockedClientDisconnected); diff --git a/src/lib.rs b/src/lib.rs index ebf05809..e863c558 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ use strum_macros::AsRefStr; extern crate num_traits; pub mod alloc; +pub mod apierror; pub mod error; pub mod native_types; pub mod raw;