diff --git a/src/builtins/numbers.rs b/src/builtins/numbers.rs index 7b48eae6..ad199de0 100644 --- a/src/builtins/numbers.rs +++ b/src/builtins/numbers.rs @@ -78,8 +78,7 @@ fn range(span: &Span, params: &[Ref], args: &[Value]) -> Result { } let incr = if v2 >= v1 { 1 } else { -1 } as Float; - let mut values = vec![]; - values.reserve((v2 - v1).abs() as usize + 1); + let mut values = Vec::with_capacity((v2 - v1).abs() as usize + 1); let mut v = v1; while v != v2 { diff --git a/src/engine.rs b/src/engine.rs index 623d8aca..ef8b9b8e 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -6,6 +6,7 @@ use crate::interpreter::*; use crate::lexer::*; use crate::parser::*; use crate::scheduler::*; +use crate::utils::gather_functions; use crate::value::*; use anyhow::Result; @@ -13,8 +14,8 @@ use anyhow::Result; #[derive(Clone)] pub struct Engine { modules: Vec>, - input: Value, - data: Value, + interpreter: Interpreter, + prepared: bool, } impl Default for Engine { @@ -27,8 +28,8 @@ impl Engine { pub fn new() -> Self { Self { modules: vec![], - input: Value::new_object(), - data: Value::new_object(), + interpreter: Interpreter::new(), + prepared: false, } } @@ -36,6 +37,8 @@ impl Engine { let source = Source::new(path, rego); let mut parser = Parser::new(&source)?; self.modules.push(Ref::new(parser.parse()?)); + // if policies change, interpreter needs to be prepared again + self.prepared = false; Ok(()) } @@ -43,36 +46,105 @@ impl Engine { let source = Source::from_file(path)?; let mut parser = Parser::new(&source)?; self.modules.push(Ref::new(parser.parse()?)); + self.prepared = false; Ok(()) } pub fn set_input(&mut self, input: Value) { - self.input = input; + self.interpreter.set_input(input); } pub fn clear_data(&mut self) { - self.data = Value::new_object(); + self.interpreter.set_data(Value::new_object()); + self.prepared = false; } pub fn add_data(&mut self, data: Value) -> Result<()> { - self.data.merge(data) + self.prepared = false; + self.interpreter.get_data_mut().merge(data) } - pub fn eval_query(&self, query: String, enable_tracing: bool) -> Result { - // Analyze the modules and determine how statements must be scheduled. - let analyzer = Analyzer::new(); - let schedule = analyzer.analyze(&self.modules)?; + pub fn get_modules(&mut self) -> &Vec> { + &self.modules + } + + fn prepare_for_eval(&mut self, enable_tracing: bool) -> Result<()> { + self.interpreter.set_traces(enable_tracing); + + // if the data/policies have changed or the interpreter has never been prepared + if !self.prepared { + // Analyze the modules and determine how statements must be scheduled. + let analyzer = Analyzer::new(); + let schedule = analyzer.analyze(&self.modules)?; + + self.interpreter.init_with_document()?; + self.interpreter.set_schedule(Some(schedule)); + self.interpreter.set_modules(&self.modules); + + self.interpreter.clear_builtins_cache(); + // when the interpreter is prepared the initial data is saved + // the data will be reset to init_data each time clean_internal_evaluation_state is called + let init_data = self.interpreter.get_data_mut().clone(); + self.interpreter.set_init_data(init_data); + + self.interpreter + .set_functions(gather_functions(&self.modules)?); + self.interpreter.gather_rules()?; + self.prepared = true; + } + + Ok(()) + } + + pub fn eval_rule( + &mut self, + module: &Ref, + rule: &Ref, + enable_tracing: bool, + ) -> Result { + self.prepare_for_eval(enable_tracing)?; + self.interpreter.clean_internal_evaluation_state(); - // Create interpreter object. - let mut interpreter = Interpreter::new(&self.modules)?; + self.interpreter.eval_rule(module, rule)?; + + Ok(self.interpreter.get_data_mut().clone()) + } + + pub fn eval_modules(&mut self, enable_tracing: bool) -> Result { + self.prepare_for_eval(enable_tracing)?; + self.interpreter.clean_internal_evaluation_state(); + + // Ensure that each module has an empty object + for m in &self.modules { + let path = Parser::get_path_ref_components(&m.package.refr)?; + let path: Vec<&str> = path.iter().map(|s| *s.text()).collect(); + let vref = + Interpreter::make_or_get_value_mut(self.interpreter.get_data_mut(), &path[..])?; + if *vref == Value::Undefined { + *vref = Value::new_object(); + } + } + + self.interpreter.check_default_rules()?; + for module in self.modules.clone() { + for rule in &module.policy { + self.interpreter.eval_rule(&module, rule)?; + } + } + // Defer the evaluation of the default rules to here + for module in self.modules.clone() { + let prev_module = self.interpreter.set_current_module(Some(module.clone()))?; + for rule in &module.policy { + self.interpreter.eval_default_rule(rule)?; + } + self.interpreter.set_current_module(prev_module)?; + } + + Ok(self.interpreter.get_data_mut().clone()) + } - // Evaluate all the modules. - interpreter.eval( - &Some(self.data.clone()), - &Some(self.input.clone()), - false, - Some(schedule), - )?; + pub fn eval_query(&mut self, query: String, enable_tracing: bool) -> Result { + self.eval_modules(false)?; // Parse the query. let query_len = query.len(); @@ -88,7 +160,9 @@ impl Engine { let query_node = Ref::new(parser.parse_query(query_span, "")?); let query_schedule = Analyzer::new().analyze_query_snippet(&self.modules, &query_node)?; - let results = interpreter.eval_user_query(&query_node, &query_schedule, enable_tracing)?; + let results = + self.interpreter + .eval_user_query(&query_node, &query_schedule, enable_tracing)?; Ok(results) } } diff --git a/src/interpreter.rs b/src/interpreter.rs index efcbd08f..7c7f7c3f 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -20,12 +20,12 @@ type Scope = BTreeMap; type DefaultRuleInfo = (Ref, Option); type ContextExprs = (Option>, Option>); +#[derive(Clone)] pub struct Interpreter { modules: Vec>, module: Option>, schedule: Option, current_module_path: String, - prepared: bool, input: Value, data: Value, init_data: Value, @@ -46,6 +46,12 @@ pub struct Interpreter { allow_deprecated: bool, } +impl Default for Interpreter { + fn default() -> Self { + Self::new() + } +} + #[derive(Debug, Clone, Serialize)] pub struct QueryResult { // Expressions is shown first to match OPA. @@ -88,20 +94,16 @@ struct LoopExpr { } impl Interpreter { - pub fn new(modules: &[Ref]) -> Result { - let mut with_document = Value::new_object(); - *Self::make_or_get_value_mut(&mut with_document, &["data"])? = Value::new_object(); - *Self::make_or_get_value_mut(&mut with_document, &["input"])? = Value::new_object(); - Ok(Interpreter { - modules: modules.to_vec(), + pub fn new() -> Interpreter { + Interpreter { + modules: vec![], module: None, schedule: None, current_module_path: String::default(), - prepared: false, input: Value::new_object(), data: Value::new_object(), init_data: Value::new_object(), - with_document, + with_document: Value::new_object(), with_functions: BTreeMap::new(), scopes: vec![Scope::new()], contexts: vec![], @@ -115,47 +117,68 @@ impl Interpreter { no_rules_lookup: false, traces: None, allow_deprecated: true, - }) + } + } + + pub fn set_schedule(&mut self, schedule: Option) { + self.schedule = schedule; + } + + pub fn set_functions(&mut self, functions: FunctionTable) { + self.functions = functions; } - pub fn get_modules(&mut self) -> &mut Vec> { + pub fn set_modules(&mut self, modules: &[Ref]) { + self.modules = modules.to_vec(); + } + + pub fn get_modules_mut(&mut self) -> &mut Vec> { &mut self.modules } + pub fn set_init_data(&mut self, init_data: Value) { + self.init_data = init_data; + } + pub fn set_data(&mut self, data: Value) { self.data = data; } - pub fn get_data(&mut self) -> &mut Value { + pub fn get_data_mut(&mut self) -> &mut Value { &mut self.data } - fn clean_internal_evaluation_state(&mut self) { - self.data = self.init_data.clone(); - self.processed.clear(); - self.loop_var_values.clear(); - self.scopes = vec![Scope::new()]; - self.contexts = vec![]; - } - - fn checks_for_eval(&mut self, input: &Option, enable_tracing: bool) -> Result<()> { - if !self.prepared { - bail!("prepare_for_eval should be called before eval_modules"); - } - + pub fn set_traces(&mut self, enable_tracing: bool) { self.traces = match enable_tracing { true => Some(vec![]), false => None, }; + } - if let Some(input) = input { - self.input = input.clone(); - info!("input: {:#?}", self.input); - } + pub fn set_input(&mut self, input: Value) { + self.input = input; + info!("input: {:#?}", self.input); + } + + pub fn init_with_document(&mut self) -> Result<()> { + *Self::make_or_get_value_mut(&mut self.with_document, &["data"])? = Value::new_object(); + *Self::make_or_get_value_mut(&mut self.with_document, &["input"])? = Value::new_object(); Ok(()) } + pub fn clear_builtins_cache(&mut self) { + self.builtins_cache.clear(); + } + + pub fn clean_internal_evaluation_state(&mut self) { + self.data = self.init_data.clone(); + self.processed.clear(); + self.loop_var_values.clear(); + self.scopes = vec![Scope::new()]; + self.contexts = vec![]; + } + fn current_module(&self) -> Result> { self.module .clone() @@ -2039,7 +2062,7 @@ impl Interpreter { } #[inline] - fn make_or_get_value_mut<'a>(obj: &'a mut Value, paths: &[&str]) -> Result<&'a mut Value> { + pub fn make_or_get_value_mut<'a>(obj: &'a mut Value, paths: &[&str]) -> Result<&'a mut Value> { if paths.is_empty() { return Ok(obj); } @@ -2107,7 +2130,10 @@ impl Interpreter { Ok(comps.join(".")) } - fn set_current_module(&mut self, module: Option>) -> Result>> { + pub fn set_current_module( + &mut self, + module: Option>, + ) -> Result>> { let m = self.module.clone(); if let Some(m) = &module { self.current_module_path = Self::get_path_string(&m.package.refr, Some("data"))?; @@ -2172,7 +2198,7 @@ impl Interpreter { Err(span.error(format!("invalid `{kind}` in default value").as_str())) } - fn check_default_rules(&self) -> Result<()> { + pub fn check_default_rules(&self) -> Result<()> { for module in &self.modules { for rule in &module.policy { if let Rule::Default { value, .. } = rule.as_ref() { @@ -2183,7 +2209,7 @@ impl Interpreter { Ok(()) } - fn eval_default_rule(&mut self, rule: &Ref) -> Result<()> { + pub fn eval_default_rule(&mut self, rule: &Ref) -> Result<()> { // Skip reprocessing rule. if self.processed.contains(rule) { return Ok(()); @@ -2264,7 +2290,7 @@ impl Interpreter { } } - fn eval_rule(&mut self, module: &Ref, rule: &Ref) -> Result<()> { + pub fn eval_rule(&mut self, module: &Ref, rule: &Ref) -> Result<()> { // Skip reprocessing rule if self.processed.contains(rule) { return Ok(()); @@ -2357,85 +2383,6 @@ impl Interpreter { } } - pub fn eval_rule_with_input( - &mut self, - module: &Ref, - rule: &Ref, - input: &Option, - enable_tracing: bool, - ) -> Result { - self.checks_for_eval(input, enable_tracing)?; - self.clean_internal_evaluation_state(); - - self.eval_rule(module, rule)?; - - Ok(self.data.clone()) - } - - pub fn prepare_for_eval( - &mut self, - schedule: Option, - data: &Option, - ) -> Result<()> { - self.schedule = schedule; - self.builtins_cache.clear(); - - if let Some(data) = data { - self.data = data.clone(); - self.init_data = data.clone(); - } - - self.functions = gather_functions(&self.modules)?; - - self.gather_rules()?; - self.prepared = true; - - Ok(()) - } - - pub fn eval_modules(&mut self, input: &Option, enable_tracing: bool) -> Result { - self.checks_for_eval(input, enable_tracing)?; - self.clean_internal_evaluation_state(); - - // Ensure that each module has an empty object - for m in &self.modules { - let path = Parser::get_path_ref_components(&m.package.refr)?; - let path: Vec<&str> = path.iter().map(|s| *s.text()).collect(); - let vref = Self::make_or_get_value_mut(&mut self.data, &path[..])?; - if *vref == Value::Undefined { - *vref = Value::new_object(); - } - } - - self.check_default_rules()?; - for module in self.modules.clone() { - for rule in &module.policy { - self.eval_rule(&module, rule)?; - } - } - // Defer the evaluation of the default rules to here - for module in self.modules.clone() { - let prev_module = self.set_current_module(Some(module.clone()))?; - for rule in &module.policy { - self.eval_default_rule(rule)?; - } - self.set_current_module(prev_module)?; - } - - Ok(self.data.clone()) - } - - pub fn eval( - &mut self, - data: &Option, - input: &Option, - enable_tracing: bool, - schedule: Option, - ) -> Result { - self.prepare_for_eval(schedule, data)?; - self.eval_modules(input, enable_tracing) - } - pub fn eval_user_query( &mut self, query: &Ref, @@ -2503,7 +2450,7 @@ impl Interpreter { } } - fn gather_rules(&mut self) -> Result<()> { + pub fn gather_rules(&mut self) -> Result<()> { for module in self.modules.clone() { let prev_module = self.set_current_module(Some(module.clone()))?; for rule in &module.policy { diff --git a/src/scheduler.rs b/src/scheduler.rs index 7f70c4cd..0875bfd4 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -58,8 +58,7 @@ pub fn schedule( } // Order of execution for statements. - let mut order = vec![]; - order.reserve(infos.len()); + let mut order = Vec::with_capacity(infos.len()); // Keep track of whether a var has been defined or not. let mut defined_vars = BTreeSet::new(); diff --git a/tests/interpreter/mod.rs b/tests/interpreter/mod.rs index bf108382..57b6a7f4 100644 --- a/tests/interpreter/mod.rs +++ b/tests/interpreter/mod.rs @@ -192,10 +192,10 @@ pub fn eval_file( query: &str, enable_tracing: bool, ) -> Result> { + let mut engine: Engine = engine::Engine::new(); + let mut results = vec![]; let mut files = vec![]; - let mut sources = vec![]; - let mut modules = vec![]; for (idx, _) in regos.iter().enumerate() { files.push(format!("rego_{idx}")); @@ -203,35 +203,15 @@ pub fn eval_file( for (idx, file) in files.iter().enumerate() { let contents = regos[idx].as_str(); - sources.push(Source::new(file.to_string(), contents.to_string())); + engine.add_policy(file.to_string(), contents.to_string())?; } - for source in &sources { - let mut parser = Parser::new(source)?; - modules.push(Ref::new(parser.parse()?)); + if let Some(data) = data_opt { + engine.add_data(data)?; } - let query_source = regorus::Source::new(" inputs.push(single_input), @@ -239,21 +219,20 @@ pub fn eval_file( } for input in inputs { - interpreter.eval_modules(&Some(input), enable_tracing)?; + engine.set_input(input); + engine.eval_modules(enable_tracing)?; // Now eval the query. push_query_results( - interpreter.eval_user_query(&query_node, &query_schedule, enable_tracing)?, + engine.eval_query(query.to_string(), enable_tracing)?, &mut results, ); } } else { // it no input is defined then one evaluation of all modules is performed - interpreter.eval(&data_opt, &None, enable_tracing, Some(schedule))?; - // Now eval the query. push_query_results( - interpreter.eval_user_query(&query_node, &query_schedule, enable_tracing)?, + engine.eval_query(query.to_string(), enable_tracing)?, &mut results, ); } diff --git a/tests/opa.rs b/tests/opa.rs index 3d8e2f95..8ea2a6da 100644 --- a/tests/opa.rs +++ b/tests/opa.rs @@ -136,7 +136,7 @@ fn run_opa_tests(opa_tests_dir: String, folders: &[String]) -> Result<()> { let path = Path::new("target/opa/failures").join(path_dir); std::fs::create_dir_all(path.clone())?; - let mut cmd = "cargo run --example dregorus eval".to_string(); + let mut cmd = "cargo run --example regorus eval".to_string(); if let Some(data) = &case.data { let json_path = path.join(format!("data{n}.json")); cmd += format!(" -d {}", json_path.display()).as_str();