diff --git a/src/cached_enforcer.rs b/src/cached_enforcer.rs index 94442fd7..0da85503 100644 --- a/src/cached_enforcer.rs +++ b/src/cached_enforcer.rs @@ -8,7 +8,7 @@ use crate::{ emitter::{clear_cache, Event, EventData, EventEmitter}, enforcer::EnforceContext, enforcer::Enforcer, - model::Model, + model::{Model, OperatorFunction}, rbac::RoleManager, Result, }; @@ -27,7 +27,7 @@ use crate::{error::ModelError, get_or_err}; use async_trait::async_trait; use parking_lot::RwLock; -use rhai::{Dynamic, ImmutableString}; +use rhai::Dynamic; use std::{collections::HashMap, sync::Arc}; @@ -123,11 +123,7 @@ impl CoreApi for CachedEnforcer { } #[inline] - fn add_function( - &mut self, - fname: &str, - f: fn(ImmutableString, ImmutableString) -> bool, - ) { + fn add_function(&mut self, fname: &str, f: OperatorFunction) { self.enforcer.add_function(fname, f); } diff --git a/src/core_api.rs b/src/core_api.rs index 131f1676..352bf432 100644 --- a/src/core_api.rs +++ b/src/core_api.rs @@ -1,7 +1,7 @@ use crate::{ - enforcer::EnforceContext, Adapter, Effector, EnforceArgs, Event, - EventEmitter, Filter, Model, Result, RoleManager, TryIntoAdapter, - TryIntoModel, + enforcer::EnforceContext, model::OperatorFunction, Adapter, Effector, + EnforceArgs, Event, EventEmitter, Filter, Model, Result, RoleManager, + TryIntoAdapter, TryIntoModel, }; #[cfg(feature = "watcher")] @@ -15,7 +15,6 @@ use crate::emitter::EventData; use async_trait::async_trait; use parking_lot::RwLock; -use rhai::ImmutableString; use std::sync::Arc; @@ -33,11 +32,7 @@ pub trait CoreApi: Send + Sync { ) -> Result where Self: Sized; - fn add_function( - &mut self, - fname: &str, - f: fn(ImmutableString, ImmutableString) -> bool, - ); + fn add_function(&mut self, fname: &str, f: OperatorFunction); fn get_model(&self) -> &dyn Model; fn get_mut_model(&mut self) -> &mut dyn Model; fn get_adapter(&self) -> &dyn Adapter; diff --git a/src/enforcer.rs b/src/enforcer.rs index ddcbd589..7f87eae1 100644 --- a/src/enforcer.rs +++ b/src/enforcer.rs @@ -7,7 +7,7 @@ use crate::{ error::{ModelError, PolicyError, RequestError}, get_or_err, get_or_err_with_context, management_api::MgmtApi, - model::{FunctionMap, Model}, + model::{FunctionMap, Model, OperatorFunction}, rbac::{DefaultRoleManager, RoleManager}, register_g_function, util::{escape_assertion, escape_eval}, @@ -351,6 +351,23 @@ impl Enforcer { })) } + fn register_function(engine: &mut Engine, key: &str, f: OperatorFunction) { + match f { + OperatorFunction::Arg0(func) => { + engine.register_fn(key, func); + } + OperatorFunction::Arg1(func) => { + engine.register_fn(key, func); + } + OperatorFunction::Arg2(func) => { + engine.register_fn(key, func); + } + OperatorFunction::Arg3(func) => { + engine.register_fn(key, func); + } + } + } + pub(crate) fn register_g_functions(&mut self) -> Result<()> { if let Some(ast_map) = self.model.get_model().get("g") { for (fname, ast) in ast_map { @@ -380,7 +397,7 @@ impl CoreApi for Enforcer { engine.register_global_module(CASBIN_PACKAGE.as_shared_module()); for (key, &func) in fm.get_functions() { - engine.register_fn(key, func); + Self::register_function(&mut engine, key, func); } let mut e = Self { @@ -425,13 +442,9 @@ impl CoreApi for Enforcer { } #[inline] - fn add_function( - &mut self, - fname: &str, - f: fn(ImmutableString, ImmutableString) -> bool, - ) { + fn add_function(&mut self, fname: &str, f: OperatorFunction) { self.fm.add_function(fname, f); - self.engine.register_fn(fname, f); + Self::register_function(&mut self.engine, fname, f); } #[inline] @@ -1340,7 +1353,11 @@ mod tests { e.add_function( "keyMatchCustom", - |s1: ImmutableString, s2: ImmutableString| key_match(&s1, &s2), + OperatorFunction::Arg2( + |s1: ImmutableString, s2: ImmutableString| { + key_match(&s1, &s2).into() + }, + ), ); assert_eq!( diff --git a/src/model/function_map.rs b/src/model/function_map.rs index 01285e0d..bc9faddf 100644 --- a/src/model/function_map.rs +++ b/src/model/function_map.rs @@ -10,51 +10,107 @@ use globset::GlobBuilder; use ip_network::IpNetwork; use once_cell::sync::Lazy; use regex::Regex; -use rhai::ImmutableString; +use rhai::{Dynamic, ImmutableString}; static MAT_B: Lazy = Lazy::new(|| Regex::new(r":[^/]*").unwrap()); static MAT_P: Lazy = Lazy::new(|| Regex::new(r"\{[^/]*\}").unwrap()); use std::{borrow::Cow, collections::HashMap}; +#[derive(Clone, Copy)] +pub enum OperatorFunction { + Arg0(fn() -> Dynamic), + Arg1(fn(ImmutableString) -> Dynamic), + Arg2(fn(ImmutableString, ImmutableString) -> Dynamic), + Arg3(fn(ImmutableString, ImmutableString, ImmutableString) -> Dynamic), +} + pub struct FunctionMap { - pub(crate) fm: - HashMap bool>, + pub(crate) fm: HashMap, } impl Default for FunctionMap { fn default() -> FunctionMap { - let mut fm: HashMap< - String, - fn(ImmutableString, ImmutableString) -> bool, - > = HashMap::new(); + let mut fm: HashMap = HashMap::new(); fm.insert( "keyMatch".to_owned(), - |s1: ImmutableString, s2: ImmutableString| key_match(&s1, &s2), + OperatorFunction::Arg2( + |s1: ImmutableString, s2: ImmutableString| { + key_match(&s1, &s2).into() + }, + ), + ); + fm.insert( + "keyGet".to_owned(), + OperatorFunction::Arg2( + |s1: ImmutableString, s2: ImmutableString| { + key_get(&s1, &s2).into() + }, + ), ); fm.insert( "keyMatch2".to_owned(), - |s1: ImmutableString, s2: ImmutableString| key_match2(&s1, &s2), + OperatorFunction::Arg2( + |s1: ImmutableString, s2: ImmutableString| { + key_match2(&s1, &s2).into() + }, + ), + ); + fm.insert( + "keyGet2".to_owned(), + OperatorFunction::Arg3( + |s1: ImmutableString, + s2: ImmutableString, + s3: ImmutableString| { + key_get2(&s1, &s2, &s3).into() + }, + ), ); fm.insert( "keyMatch3".to_owned(), - |s1: ImmutableString, s2: ImmutableString| key_match3(&s1, &s2), + OperatorFunction::Arg2( + |s1: ImmutableString, s2: ImmutableString| { + key_match3(&s1, &s2).into() + }, + ), + ); + fm.insert( + "keyGet3".to_owned(), + OperatorFunction::Arg3( + |s1: ImmutableString, + s2: ImmutableString, + s3: ImmutableString| { + key_get3(&s1, &s2, &s3).into() + }, + ), ); fm.insert( "regexMatch".to_owned(), - |s1: ImmutableString, s2: ImmutableString| regex_match(&s1, &s2), + OperatorFunction::Arg2( + |s1: ImmutableString, s2: ImmutableString| { + regex_match(&s1, &s2).into() + }, + ), ); #[cfg(feature = "glob")] fm.insert( "globMatch".to_owned(), - |s1: ImmutableString, s2: ImmutableString| glob_match(&s1, &s2), + OperatorFunction::Arg2( + |s1: ImmutableString, s2: ImmutableString| { + glob_match(&s1, &s2).into() + }, + ), ); #[cfg(feature = "ip")] fm.insert( "ipMatch".to_owned(), - |s1: ImmutableString, s2: ImmutableString| ip_match(&s1, &s2), + OperatorFunction::Arg2( + |s1: ImmutableString, s2: ImmutableString| { + ip_match(&s1, &s2).into() + }, + ), ); FunctionMap { fm } @@ -63,19 +119,14 @@ impl Default for FunctionMap { impl FunctionMap { #[inline] - pub fn add_function( - &mut self, - fname: &str, - f: fn(ImmutableString, ImmutableString) -> bool, - ) { + pub fn add_function(&mut self, fname: &str, f: OperatorFunction) { self.fm.insert(fname.to_owned(), f); } #[inline] pub fn get_functions( &self, - ) -> impl Iterator bool)> - { + ) -> impl Iterator { self.fm.iter() } } @@ -93,6 +144,18 @@ pub fn key_match(key1: &str, key2: &str) -> bool { } } +// key_get returns the matched part +// For example, "/foo/bar/foo" matches "/foo/*" +// "bar/foo" will be returned. +pub fn key_get(key1: &str, key2: &str) -> String { + if let Some(i) = key2.find('*') { + if key1.len() > i && key1[..i] == key2[..i] { + return key1[i..].to_string(); + } + } + "".to_string() +} + // key_match2 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a * // For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/:resource" pub fn key_match2(key1: &str, key2: &str) -> bool { @@ -107,6 +170,35 @@ pub fn key_match2(key1: &str, key2: &str) -> bool { regex_match(key1, &format!("^{}$", key2)) } +// key_get2 returns value matched pattern +// For example, "/resource1" matches "/:resource" +// if the pathVar == "resource", then "resource1" will be returned. +pub fn key_get2(key1: &str, key2: &str, path_var: &str) -> String { + let key2: Cow = if key2.contains("/*") { + key2.replace("/*", "/.*").into() + } else { + key2.into() + }; + + let re = Regex::new(r":[^/]+").unwrap(); + let keys: Vec<_> = re.find_iter(&key2).collect(); + let key2 = re.replace_all(&key2, "([^/]+)").to_string(); + let key2 = format!("^{}$", key2); + + if let Ok(re2) = Regex::new(&key2) { + if let Some(caps) = re2.captures(key1) { + for (i, key) in keys.iter().enumerate() { + if path_var == &key.as_str()[1..] { + return caps + .get(i + 1) + .map_or("".to_string(), |m| m.as_str().to_string()); + } + } + } + } + "".to_string() +} + // key_match3 determines whether key1 matches the pattern of key2 (similar to RESTful path), key2 can contain a * // For example, "/foo/bar" matches "/foo/*", "/resource1" matches "/{resource}" pub fn key_match3(key1: &str, key2: &str) -> bool { @@ -121,6 +213,38 @@ pub fn key_match3(key1: &str, key2: &str) -> bool { regex_match(key1, &format!("^{}$", key2)) } +// key_get3 returns value matched pattern +// For example, "project/proj_project1_admin/" matches "project/proj_{project}_admin/" +// if the pathVar == "project", then "project1" will be returned. +pub fn key_get3(key1: &str, key2: &str, path_var: &str) -> String { + let key2: Cow = if key2.contains("/*") { + key2.replace("/*", "/.*").into() + } else { + key2.into() + }; + + let re = Regex::new(r"\{[^/]+?\}").unwrap(); + let keys: Vec<_> = re.find_iter(&key2).collect(); + let key2 = re.replace_all(&key2, "([^/]+?)").to_string(); + let key2 = Regex::new(r"\{") + .unwrap() + .replace_all(&key2, "\\{") + .to_string(); + let key2 = format!("^{}$", key2); + + let re2 = Regex::new(&key2).unwrap(); + if let Some(caps) = re2.captures(key1) { + for (i, key) in keys.iter().enumerate() { + if path_var == &key.as_str()[1..key.as_str().len() - 1] { + return caps + .get(i + 1) + .map_or("".to_string(), |m| m.as_str().to_string()); + } + } + } + "".to_string() +} + // regex_match determines whether key1 matches the pattern of key2 in regular expression. pub fn regex_match(key1: &str, key2: &str) -> bool { Regex::new(key2).unwrap().is_match(key1) @@ -184,6 +308,19 @@ mod tests { assert!(key_match("/bar", "/ba*")); } + #[test] + fn test_key_get() { + assert_eq!(key_get("/foo", "/foo"), ""); + assert_eq!(key_get("/foo", "/foo*"), ""); + assert_eq!(key_get("/foo", "/foo/*"), ""); + assert_eq!(key_get("/foo/bar", "/foo"), ""); + assert_eq!(key_get("/foo/bar", "/foo*"), "/bar"); + assert_eq!(key_get("/foo/bar", "/foo/*"), "bar"); + assert_eq!(key_get("/foobar", "/foo"), ""); + assert_eq!(key_get("/foobar", "/foo*"), "bar"); + assert_eq!(key_get("/foobar", "/foo/*"), ""); + } + #[test] fn test_key_match2() { assert!(key_match2("/foo/bar", "/foo/*")); @@ -200,6 +337,58 @@ mod tests { assert!(!key_match2("/foo/bar", "/foo/:/baz")); } + #[test] + fn test_key_get2() { + assert_eq!(key_get2("/foo", "/foo", "id"), ""); + assert_eq!(key_get2("/foo", "/foo*", "id"), ""); + assert_eq!(key_get2("/foo", "/foo/*", "id"), ""); + assert_eq!(key_get2("/foo/bar", "/foo", "id"), ""); + assert_eq!(key_get2("/foo/bar", "/foo*", "id"), ""); + assert_eq!(key_get2("/foo/bar", "/foo/*", "id"), ""); + assert_eq!(key_get2("/foobar", "/foo", "id"), ""); + assert_eq!(key_get2("/foobar", "/foo*", "id"), ""); + assert_eq!(key_get2("/foobar", "/foo/*", "id"), ""); + + assert_eq!(key_get2("/", "/:resource", "resource"), ""); + assert_eq!( + key_get2("/resource1", "/:resource", "resource"), + "resource1" + ); + assert_eq!(key_get2("/myid", "/:id/using/:resId", "id"), ""); + assert_eq!( + key_get2("/myid/using/myresid", "/:id/using/:resId", "id"), + "myid" + ); + assert_eq!( + key_get2("/myid/using/myresid", "/:id/using/:resId", "resId"), + "myresid" + ); + + assert_eq!(key_get2("/proxy/myid", "/proxy/:id/*", "id"), ""); + assert_eq!(key_get2("/proxy/myid/", "/proxy/:id/*", "id"), "myid"); + assert_eq!(key_get2("/proxy/myid/res", "/proxy/:id/*", "id"), "myid"); + assert_eq!( + key_get2("/proxy/myid/res/res2", "/proxy/:id/*", "id"), + "myid" + ); + assert_eq!( + key_get2("/proxy/myid/res/res2/res3", "/proxy/:id/*", "id"), + "myid" + ); + assert_eq!( + key_get2("/proxy/myid/res/res2/res3", "/proxy/:id/res/*", "id"), + "myid" + ); + assert_eq!(key_get2("/proxy/", "/proxy/:id/*", "id"), ""); + + assert_eq!(key_get2("/alice", "/:id", "id"), "alice"); + assert_eq!(key_get2("/alice/all", "/:id/all", "id"), "alice"); + assert_eq!(key_get2("/alice", "/:id/all", "id"), ""); + assert_eq!(key_get2("/alice/all", "/:id", "id"), ""); + + assert_eq!(key_get2("/alice/all", "/:/all", ""), ""); + } + #[test] fn test_regex_match() { assert!(regex_match("foobar", "^foo*")); @@ -222,6 +411,109 @@ mod tests { assert!(key_match3("/foo/bar/baz", "/foo/{}/baz")); } + #[test] + fn test_key_get3() { + assert_eq!(key_get3("/foo", "/foo", "id"), ""); + assert_eq!(key_get3("/foo", "/foo*", "id"), ""); + assert_eq!(key_get3("/foo", "/foo/*", "id"), ""); + assert_eq!(key_get3("/foo/bar", "/foo", "id"), ""); + assert_eq!(key_get3("/foo/bar", "/foo*", "id"), ""); + assert_eq!(key_get3("/foo/bar", "/foo/*", "id"), ""); + assert_eq!(key_get3("/foobar", "/foo", "id"), ""); + assert_eq!(key_get3("/foobar", "/foo*", "id"), ""); + assert_eq!(key_get3("/foobar", "/foo/*", "id"), ""); + + assert_eq!(key_get3("/", "/{resource}", "resource"), ""); + assert_eq!( + key_get3("/resource1", "/{resource}", "resource"), + "resource1" + ); + assert_eq!(key_get3("/myid", "/{id}/using/{resId}", "id"), ""); + assert_eq!( + key_get3("/myid/using/myresid", "/{id}/using/{resId}", "id"), + "myid" + ); + assert_eq!( + key_get3("/myid/using/myresid", "/{id}/using/{resId}", "resId"), + "myresid" + ); + + assert_eq!(key_get3("/proxy/myid", "/proxy/{id}/*", "id"), ""); + assert_eq!(key_get3("/proxy/myid/", "/proxy/{id}/*", "id"), "myid"); + assert_eq!(key_get3("/proxy/myid/res", "/proxy/{id}/*", "id"), "myid"); + assert_eq!( + key_get3("/proxy/myid/res/res2", "/proxy/{id}/*", "id"), + "myid" + ); + assert_eq!( + key_get3("/proxy/myid/res/res2/res3", "/proxy/{id}/*", "id"), + "myid" + ); + assert_eq!( + key_get3("/proxy/myid/res/res2/res3", "/proxy/{id}/res/*", "id"), + "myid" + ); + assert_eq!(key_get3("/proxy/", "/proxy/{id}/*", "id"), ""); + + assert_eq!( + key_get3( + "/api/group1_group_name/project1_admin/info", + "/api/{proj}_admin/info", + "proj" + ), + "" + ); + assert_eq!( + key_get3("/{id/using/myresid", "/{id/using/{resId}", "resId"), + "myresid" + ); + assert_eq!( + key_get3( + "/{id/using/myresid/status}", + "/{id/using/{resId}/status}", + "resId" + ), + "myresid" + ); + + assert_eq!( + key_get3("/proxy/myid/res/res2/res3", "/proxy/{id}/*/{res}", "res"), + "res3" + ); + assert_eq!( + key_get3( + "/api/project1_admin/info", + "/api/{proj}_admin/info", + "proj" + ), + "project1" + ); + assert_eq!( + key_get3( + "/api/group1_group_name/project1_admin/info", + "/api/{g}_{gn}/{proj}_admin/info", + "g" + ), + "group1" + ); + assert_eq!( + key_get3( + "/api/group1_group_name/project1_admin/info", + "/api/{g}_{gn}/{proj}_admin/info", + "gn" + ), + "group_name" + ); + assert_eq!( + key_get3( + "/api/group1_group_name/project1_admin/info", + "/api/{g}_{gn}/{proj}_admin/info", + "proj" + ), + "project1" + ); + } + #[cfg(feature = "ip")] #[test] fn test_ip_match() {