Skip to content

Commit

Permalink
WIP: post notification API and API versioning
Browse files Browse the repository at this point in the history
  • Loading branch information
MeirShpilraien committed Apr 10, 2023
1 parent 0fa0b24 commit af46fd4
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 4 deletions.
9 changes: 7 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,21 +108,26 @@ strum_macros = "0.24"
backtrace = "0.3"
linkme = "0.3"
serde = { version = "1", features = ["derive"] }
cfg-if = "1"
redis-module-derive = { path = "./redismodule-rs-derive"}

[dev-dependencies]
anyhow = "1.0.38"
redis = "0.22.1"
lazy_static = "1.4.0"
redis-module-derive = { path = "./redismodule-rs-derive"}
redis-module = { path = "./", features = ["test", "experimental-api"]}

[build-dependencies]
bindgen = "0.64"
cc = "1.0"

[features]
default = ["experimental-api"]
default = ["experimental-api", "min-redis-compatibility-version-6-0"]
experimental-api = []
min-redis-compatibility-version-7-2 = []
min-redis-compatibility-version-7-0 = []
min-redis-compatibility-version-6-2 = []
min-redis-compatibility-version-6-0 = []

# Workaround to allow cfg(feature = "test") in dependencies:
# https://github.com/rust-lang/rust/issues/59168#issuecomment-472653680
Expand Down
2 changes: 2 additions & 0 deletions redismodule-rs-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ categories = ["database", "api-bindings"]
[dependencies]
syn = { version="1.0", features = ["full"]}
quote = "1.0"
lazy_static = "1.4.0"
proc-macro2 = "1.0.56"

[lib]
name = "redis_module_derive"
Expand Down
40 changes: 40 additions & 0 deletions redismodule-rs-derive/src/api_versions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
use std::collections::HashMap;

use proc_macro2::TokenStream;
use quote::quote;

lazy_static::lazy_static! {
pub(crate) static ref API_VERSION_MAPPING: HashMap<String, usize> = HashMap::from([
("RedisModule_AddPostNotificationJob".to_string(), 70200),
("RedisModule_SetCommandACLCategories".to_string(), 70200),
("RedisModule_GetOpenKeyModesAll".to_string(), 70200),
("RedisModule_CallReplyPromiseSetUnblockHandler".to_string(), 70200),
("RedisModule_CallReplyPromiseAbort".to_string(), 70200),
("RedisModule_Microseconds".to_string(), 70200),
("RedisModule_CachedMicroseconds".to_string(), 70200),
("RedisModule_RegisterAuthCallback".to_string(), 70200),
("RedisModule_BlockClientOnKeysWithFlags".to_string(), 70200),
("RedisModule_GetModuleOptionsAll".to_string(), 70200),
("RedisModule_BlockClientGetPrivateData".to_string(), 70200),
("RedisModule_BlockClientSetPrivateData".to_string(), 70200),
("RedisModule_BlockClientOnAuth".to_string(), 70200),
("RedisModule_ACLAddLogEntryByUserName".to_string(), 70200),
]);

pub(crate) static ref API_OLDEST_VERSION: usize = 60000;
pub(crate) static ref ALL_VERSIONS: Vec<(usize, String)> = vec![
(60000, "min-redis-compatibility-version-6-0".to_string()),
(60200, "min-redis-compatibility-version-6-2".to_string()),
(70000, "min-redis-compatibility-version-7-0".to_string()),
(70200, "min-redis-compatibility-version-7-2".to_string()),
];
}

pub(crate) fn get_feature_flags(min_required_version: usize) -> (Vec<TokenStream>, Vec<TokenStream>) {
let all_lower_versions: Vec<&str> = ALL_VERSIONS.iter().filter_map(|(v, s)| if *v < min_required_version {Some(s.as_str())} else {None}).collect();
let all_upper_versions: Vec<&str> = ALL_VERSIONS.iter().filter_map(|(v, s)| if *v >= min_required_version {Some(s.as_str())} else {None}).collect();
(
all_lower_versions.into_iter().map(|s| quote!(feature = #s).into()).collect(),
all_upper_versions.into_iter().map(|s| quote!(feature = #s).into()).collect(),
)
}
108 changes: 106 additions & 2 deletions redismodule-rs-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
extern crate proc_macro;
use proc_macro::TokenStream;

mod api_versions;

use proc_macro::{TokenStream};
use quote::quote;
use syn;
use syn::parse::{Parse, ParseStream, Result};
use syn::punctuated::Punctuated;
use syn::token::{RArrow, Paren};
use syn::{self, parse_macro_input, Token, Type, ReturnType, TypeTuple};
use syn::ItemFn;
use syn::Ident;
use api_versions::{API_VERSION_MAPPING, API_OLDEST_VERSION, get_feature_flags};

#[proc_macro_attribute]
pub fn role_changed_event_handler(_attr: TokenStream, item: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -55,3 +63,99 @@ pub fn module_changed_event_handler(_attr: TokenStream, item: TokenStream) -> To
};
gen.into()
}

#[derive(Debug)]
struct Args{
requested_apis: Vec<Ident>
}

impl Parse for Args{
fn parse(input: ParseStream) -> Result<Self> {
// parses a,b,c, or a,b,c where a,b and c are Indent
let vars = Punctuated::<Ident, Token![,]>::parse_terminated(input)?;
Ok(Args {
requested_apis: vars.into_iter().collect(),
})
}
}

#[proc_macro_attribute]
pub fn redismodule_api(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as Args);
let original_func = item.clone();
let mut original_func = parse_macro_input!(original_func as ItemFn);
let original_func_name = original_func.sig.ident.clone();
let original_func_name: Ident = Ident::new(&format!("{}_inner", original_func_name.to_string()), original_func_name.span());
original_func.sig.ident = original_func_name.clone();

let mut use_self = false;
let input_names:Vec<Ident> = original_func.sig.inputs.clone().into_iter().filter_map(|v| {
match v {
syn::FnArg::Receiver(_) => use_self = true,
syn::FnArg::Typed(pat_type) => {
if let syn::Pat::Ident(pat_ident) = *pat_type.pat.clone() {
return Some(pat_ident.ident)
}
}
}
None
}).collect();
let func = parse_macro_input!(item as ItemFn);

let minimum_require_version = args.requested_apis.iter().fold(*API_OLDEST_VERSION, |min_api_version, item|{
// if we do not have a version mapping, we assume the API exists and return the minimum version.
let api_version = API_VERSION_MAPPING.get(&item.to_string()).map(|v| *v).unwrap_or(*API_OLDEST_VERSION);
api_version.max(min_api_version)
});

if *API_OLDEST_VERSION == minimum_require_version {
// all API exists on the older version supported so we can just return the function as is.
return quote!(#original_func).into();
}

let requested_apis = args.requested_apis;
let requested_apis_str: Vec<String> = requested_apis.iter().map(|e| e.to_string()).collect();
let vis = func.vis;
let inner_return_return_type = match func.sig.output.clone() {
ReturnType::Default => Box::new(Type::Tuple(TypeTuple{paren_token: Paren::default(), elems: Punctuated::new()})),
ReturnType::Type(_, t) => t,
};
let new_return_return_type = Type::Path(syn::parse(quote!(
crate::apierror::APIResult<#inner_return_return_type>
).into()).unwrap());
let mut sig = func.sig;
sig.output = ReturnType::Type(RArrow::default(), Box::new(new_return_return_type));

let original_function_call = if use_self {
quote!(self.#original_func_name(#(#input_names, )*))
} else {
quote!(#original_func_name(#(#input_names, )*))
};

let new_func = quote!(
#original_func

#vis #sig {
#(
unsafe{crate::raw::#requested_apis.ok_or(concat!(#requested_apis_str, " does not exists"))?};
)*

Ok(#original_function_call)
}
);

let (all_lower_features, all_upper_features) = get_feature_flags(minimum_require_version);

let gen = quote! {
cfg_if::cfg_if! {
if #[cfg(any(#(#all_lower_features, )*))] {
#new_func
} else if #[cfg(any(#(#all_upper_features, )*))] {
#original_func
} else {
compile_error!("min-redis-compatibility-version is not set correctly")
}
}
};
gen.into()
}
2 changes: 2 additions & 0 deletions src/apierror.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub type APIError = String;
pub type APIResult<T> = Result<T, APIError>;
28 changes: 28 additions & 0 deletions src/context/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use bitflags::bitflags;
use redis_module_derive::redismodule_api;
use std::ffi::CString;
use std::os::raw::c_void;
use std::os::raw::{c_char, c_int, c_long, c_longlong};
use std::ptr::{self, NonNull};
use std::sync::atomic::{AtomicPtr, Ordering};
Expand Down Expand Up @@ -671,6 +673,32 @@ impl Context {
let acl_permission_result: Result<(), &str> = acl_permission_result.into();
acl_permission_result.map_err(|_e| RedisError::Str("User does not have permissions on key"))
}

#[redismodule_api(RedisModule_AddPostNotificationJob)]
pub fn add_post_notification_job<F: Fn(&Context)>(self, callback: F) {
let callback = Box::into_raw(Box::new(callback));
unsafe {
raw::RedisModule_AddPostNotificationJob.unwrap()(
self.ctx,
Some(post_notification_job::<F>),
callback as *mut c_void,
Some(post_notification_job_free_callback::<F>),
);
}
}
}

extern "C" fn post_notification_job_free_callback<F: Fn(&Context)>(pd: *mut c_void) {
unsafe { Box::from_raw(pd as *mut F) };
}

extern "C" fn post_notification_job<F: Fn(&Context)>(
ctx: *mut raw::RedisModuleCtx,
pd: *mut c_void,
) {
let callback = unsafe { &*(pd as *mut F) };
let ctx = Context::new(ctx);
callback(&ctx);
}

unsafe impl RedisLockIndicator for Context {}
Expand Down
3 changes: 3 additions & 0 deletions src/include/redismodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ typedef struct RedisModuleKeyOptCtx RedisModuleKeyOptCtx;
typedef int (*RedisModuleCmdFunc)(RedisModuleCtx *ctx, RedisModuleString **argv, int argc);
typedef void (*RedisModuleDisconnectFunc)(RedisModuleCtx *ctx, RedisModuleBlockedClient *bc);
typedef int (*RedisModuleNotificationFunc)(RedisModuleCtx *ctx, int type, const char *event, RedisModuleString *key);
typedef void (*RedisModulePostNotificationJobFunc) (RedisModuleCtx *ctx, void *pd);
typedef void *(*RedisModuleTypeLoadFunc)(RedisModuleIO *rdb, int encver);
typedef void (*RedisModuleTypeSaveFunc)(RedisModuleIO *rdb, void *value);
typedef int (*RedisModuleTypeAuxLoadFunc)(RedisModuleIO *rdb, int encver, int when);
Expand Down Expand Up @@ -1135,6 +1136,7 @@ REDISMODULE_API void (*RedisModule_ThreadSafeContextLock)(RedisModuleCtx *ctx) R
REDISMODULE_API int (*RedisModule_ThreadSafeContextTryLock)(RedisModuleCtx *ctx) REDISMODULE_ATTR;
REDISMODULE_API void (*RedisModule_ThreadSafeContextUnlock)(RedisModuleCtx *ctx) REDISMODULE_ATTR;
REDISMODULE_API int (*RedisModule_SubscribeToKeyspaceEvents)(RedisModuleCtx *ctx, int types, RedisModuleNotificationFunc cb) REDISMODULE_ATTR;
REDISMODULE_API int (*RedisModule_AddPostNotificationJob)(RedisModuleCtx *ctx, RedisModulePostNotificationJobFunc callback, void *pd, void (*free_pd)(void*)) REDISMODULE_ATTR;
REDISMODULE_API int (*RedisModule_NotifyKeyspaceEvent)(RedisModuleCtx *ctx, int type, const char *event, RedisModuleString *key) REDISMODULE_ATTR;
REDISMODULE_API int (*RedisModule_GetNotifyKeyspaceEvents)() REDISMODULE_ATTR;
REDISMODULE_API int (*RedisModule_BlockedClientDisconnected)(RedisModuleCtx *ctx) REDISMODULE_ATTR;
Expand Down Expand Up @@ -1477,6 +1479,7 @@ static void RedisModule_InitAPI(RedisModuleCtx *ctx) {
REDISMODULE_GET_API(BlockedClientMeasureTimeEnd);
REDISMODULE_GET_API(SetDisconnectCallback);
REDISMODULE_GET_API(SubscribeToKeyspaceEvents);
REDISMODULE_GET_API(AddPostNotificationJob);
REDISMODULE_GET_API(NotifyKeyspaceEvent);
REDISMODULE_GET_API(GetNotifyKeyspaceEvents);
REDISMODULE_GET_API(BlockedClientDisconnected);
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use strum_macros::AsRefStr;
extern crate num_traits;

pub mod alloc;
pub mod apierror;
pub mod error;
pub mod native_types;
pub mod raw;
Expand Down

0 comments on commit af46fd4

Please sign in to comment.