diff --git a/scripts/yaml-test-eval b/scripts/yaml-test-eval index 7bc185c2..5f6f9435 100755 --- a/scripts/yaml-test-eval +++ b/scripts/yaml-test-eval @@ -5,4 +5,4 @@ set -e yaml=$(realpath -e $1) -RUST_BACKTRACE=1 cargo test interpreter::one_yaml -- --include-ignored --nocapture "$yaml" +RUST_BACKTRACE=1 cargo test interpreter::one_yaml -- --include-ignored --nocapture "$yaml" $2 diff --git a/src/builtins/aggregates.rs b/src/builtins/aggregates.rs index 65422dac..98e9ab07 100644 --- a/src/builtins/aggregates.rs +++ b/src/builtins/aggregates.rs @@ -12,12 +12,12 @@ use std::collections::HashMap; use anyhow::{bail, Result}; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("count", count); - m.insert("max", max); - m.insert("min", min); - m.insert("product", product); - m.insert("sort", sort); - m.insert("sum", sum); + m.insert("count", (count, 1)); + m.insert("max", (max, 1)); + m.insert("min", (min, 1)); + m.insert("product", (product, 1)); + m.insert("sort", (sort, 1)); + m.insert("sum", (sum, 1)); } fn count(span: &Span, params: &[Expr], args: &[Value]) -> Result { @@ -57,9 +57,9 @@ fn min(span: &Span, params: &[Expr], args: &[Value]) -> Result { Ok(match &args[0] { Value::Array(a) if a.is_empty() => Value::Undefined, - Value::Array(a) => a.iter().max().unwrap().clone(), + Value::Array(a) => a.iter().min().unwrap().clone(), Value::Set(a) if a.is_empty() => Value::Undefined, - Value::Set(a) => a.iter().max().unwrap().clone(), + Value::Set(a) => a.iter().min().unwrap().clone(), a => { let span = params[0].span(); bail!(span.error(format!("`min` requires array/set argument. Got `{a}`.").as_str())) @@ -68,7 +68,7 @@ fn min(span: &Span, params: &[Expr], args: &[Value]) -> Result { } fn product(span: &Span, params: &[Expr], args: &[Value]) -> Result { - ensure_args_count(span, "min", params, args, 1)?; + ensure_args_count(span, "product", params, args, 1)?; let mut v = 1 as Float; Ok(match &args[0] { diff --git a/src/builtins/arrays.rs b/src/builtins/arrays.rs index 4984a891..d9a8294b 100644 --- a/src/builtins/arrays.rs +++ b/src/builtins/arrays.rs @@ -13,9 +13,9 @@ use anyhow::Result; use std::rc::Rc; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("array.concat", concat); - m.insert("array.reverse", reverse); - m.insert("array.slice", slice); + m.insert("array.concat", (concat, 2)); + m.insert("array.reverse", (reverse, 1)); + m.insert("array.slice", (slice, 3)); } fn concat(span: &Span, params: &[Expr], args: &[Value]) -> Result { diff --git a/src/builtins/bitwise.rs b/src/builtins/bitwise.rs index ef1efcf9..c866e404 100644 --- a/src/builtins/bitwise.rs +++ b/src/builtins/bitwise.rs @@ -13,12 +13,12 @@ use std::collections::HashMap; use anyhow::Result; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("bits.and", and); - m.insert("bits.lsh", lsh); - m.insert("bits.negate", negate); - m.insert("bits.or", or); - m.insert("bits.rsh", rsh); - m.insert("bits.xor", xor); + m.insert("bits.and", (and, 2)); + m.insert("bits.lsh", (lsh, 2)); + m.insert("bits.negate", (negate, 1)); + m.insert("bits.or", (or, 2)); + m.insert("bits.rsh", (rsh, 2)); + m.insert("bits.xor", (xor, 2)); } fn and(span: &Span, params: &[Expr], args: &[Value]) -> Result { diff --git a/src/builtins/conversions.rs b/src/builtins/conversions.rs index 58e936c3..88aac5fd 100644 --- a/src/builtins/conversions.rs +++ b/src/builtins/conversions.rs @@ -12,7 +12,7 @@ use std::collections::HashMap; use anyhow::{bail, Result}; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("to_number", to_number); + m.insert("to_number", (to_number, 1)); } fn to_number(span: &Span, params: &[Expr], args: &[Value]) -> Result { diff --git a/src/builtins/debugging.rs b/src/builtins/debugging.rs index 34d06243..3c76fcfc 100644 --- a/src/builtins/debugging.rs +++ b/src/builtins/debugging.rs @@ -8,16 +8,23 @@ use crate::value::Value; use std::collections::HashMap; -use anyhow::Result; +use anyhow::{bail, Result}; + +// TODO: Should we avoid this limit? +const MAX_ARGS: u8 = std::u8::MAX; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("print", print); + m.insert("print", (print, MAX_ARGS)); } // Symbol analyzer must ensure that vars used by print are defined before // the print statement. Scheduler must ensure the above constraint. // Additionally interpreter must allow undefined inputs to print. fn print(span: &Span, _params: &[Expr], args: &[Value]) -> Result { + if args.len() > MAX_ARGS as usize { + bail!(span.error("print supports up to 100 arguments")); + } + let mut msg = String::default(); for a in args { match a { diff --git a/src/builtins/mod.rs b/src/builtins/mod.rs index e3289700..b040aad4 100644 --- a/src/builtins/mod.rs +++ b/src/builtins/mod.rs @@ -24,7 +24,7 @@ use std::collections::HashMap; use anyhow::Result; use lazy_static::lazy_static; -pub type BuiltinFcn = fn(&Span, &[Expr], &[Value]) -> Result; +pub type BuiltinFcn = (fn(&Span, &[Expr], &[Value]) -> Result, u8); #[rustfmt::skip] lazy_static! { diff --git a/src/builtins/numbers.rs b/src/builtins/numbers.rs index 08eff660..7f65c770 100644 --- a/src/builtins/numbers.rs +++ b/src/builtins/numbers.rs @@ -13,12 +13,12 @@ use anyhow::Result; use rand::{thread_rng, Rng}; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("abs", abs); - m.insert("ceil", ceil); - m.insert("floor", floor); - m.insert("numbers.range", range); - m.insert("rand.intn", intn); - m.insert("round", round); + m.insert("abs", (abs, 1)); + m.insert("ceil", (ceil, 1)); + m.insert("floor", (floor, 1)); + m.insert("numbers.range", (range, 2)); + m.insert("rand.intn", (intn, 2)); + m.insert("round", (round, 1)); } pub fn arithmetic_operation( diff --git a/src/builtins/objects.rs b/src/builtins/objects.rs index 3bb89669..f88aadcc 100644 --- a/src/builtins/objects.rs +++ b/src/builtins/objects.rs @@ -14,12 +14,12 @@ use std::rc::Rc; use anyhow::{bail, Result}; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("json.filter", json_filter); - // m.insert("json.patch", json_patch); - m.insert("object.filter", filter); - m.insert("object.get", get); - m.insert("object.keys", keys); - m.insert("object.remove", remove); + m.insert("json.filter", (json_filter, 2)); + // m.insert("json.patch", (json_patch)); + m.insert("object.filter", (filter, 2)); + m.insert("object.get", (get, 3)); + m.insert("object.keys", (keys, 1)); + m.insert("object.remove", (remove, 2)); } fn json_filter_impl(v: &Value, filter: &Value) -> Value { diff --git a/src/builtins/sets.rs b/src/builtins/sets.rs index de5e65be..d62fd448 100644 --- a/src/builtins/sets.rs +++ b/src/builtins/sets.rs @@ -12,8 +12,8 @@ use std::collections::{BTreeSet, HashMap}; use anyhow::{bail, Result}; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("intersection", intersection_of_set_of_sets); - m.insert("union", union_of_set_of_sets); + m.insert("intersection", (intersection_of_set_of_sets, 1)); + m.insert("union", (union_of_set_of_sets, 1)); } pub fn intersection(expr1: &Expr, expr2: &Expr, v1: Value, v2: Value) -> Result { diff --git a/src/builtins/strings.rs b/src/builtins/strings.rs index 446698bb..0392768c 100644 --- a/src/builtins/strings.rs +++ b/src/builtins/strings.rs @@ -15,29 +15,29 @@ use std::collections::HashMap; use anyhow::{bail, Result}; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("concat", concat); - m.insert("contains", contains); - m.insert("endswith", endswith); - m.insert("format_int", format_int); - m.insert("indexof", indexof); - m.insert("indexof_n", indexof_n); - m.insert("lower", lower); - m.insert("replace", replace); - m.insert("split", split); - m.insert("sprintf", sprintf); - m.insert("startswith", startswith); - m.insert("strings.any_prefix_match", any_prefix_match); - m.insert("strings.any_suffix_match", any_suffix_match); - m.insert("strings.replace_n", replace_n); - m.insert("strings.reverse", reverse); - m.insert("strings.substring", substring); - m.insert("trim", trim); - m.insert("trim_left", trim_left); - m.insert("trim_prefix", trim_prefix); - m.insert("trim_right", trim_right); - m.insert("trim_space", trim_space); - m.insert("trim_suffix", trim_suffix); - m.insert("upper", upper); + m.insert("concat", (concat, 2)); + m.insert("contains", (contains, 2)); + m.insert("endswith", (endswith, 2)); + m.insert("format_int", (format_int, 2)); + m.insert("indexof", (indexof, 2)); + m.insert("indexof_n", (indexof_n, 2)); + m.insert("lower", (lower, 1)); + m.insert("replace", (replace, 3)); + m.insert("split", (split, 2)); + m.insert("sprintf", (sprintf, 2)); + m.insert("startswith", (startswith, 2)); + m.insert("strings.any_prefix_match", (any_prefix_match, 2)); + m.insert("strings.any_suffix_match", (any_suffix_match, 2)); + m.insert("strings.replace_n", (replace_n, 2)); + m.insert("strings.reverse", (reverse, 1)); + m.insert("strings.substring", (substring, 3)); + m.insert("trim", (trim, 2)); + m.insert("trim_left", (trim_left, 2)); + m.insert("trim_prefix", (trim_prefix, 2)); + m.insert("trim_right", (trim_right, 2)); + m.insert("trim_space", (trim_space, 1)); + m.insert("trim_suffix", (trim_suffix, 2)); + m.insert("upper", (upper, 1)); } fn concat(span: &Span, params: &[Expr], args: &[Value]) -> Result { diff --git a/src/builtins/tracing.rs b/src/builtins/tracing.rs index 7f68fb2b..d63cc302 100644 --- a/src/builtins/tracing.rs +++ b/src/builtins/tracing.rs @@ -12,7 +12,7 @@ use std::collections::HashMap; use anyhow::Result; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("trace", trace); + m.insert("trace", (trace, 1)); } // Symbol analyzer must ensure that vars used by trace are defined before diff --git a/src/builtins/types.rs b/src/builtins/types.rs index 20a79ff3..9b70f827 100644 --- a/src/builtins/types.rs +++ b/src/builtins/types.rs @@ -12,14 +12,14 @@ use std::collections::HashMap; use anyhow::Result; pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) { - m.insert("is_array", is_array); - m.insert("is_boolean", is_boolean); - m.insert("is_null", is_null); - m.insert("is_number", is_number); - m.insert("is_object", is_object); - m.insert("is_set", is_set); - m.insert("is_string", is_string); - m.insert("type_name", type_name); + m.insert("is_array", (is_array, 1)); + m.insert("is_boolean", (is_boolean, 1)); + m.insert("is_null", (is_null, 1)); + m.insert("is_number", (is_number, 1)); + m.insert("is_object", (is_object, 1)); + m.insert("is_set", (is_set, 1)); + m.insert("is_string", (is_string, 1)); + m.insert("type_name", (type_name, 1)); } fn is_array(span: &Span, params: &[Expr], args: &[Value]) -> Result { diff --git a/src/interpreter.rs b/src/interpreter.rs index 64497c6a..236773fb 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -6,6 +6,7 @@ use crate::builtins; use crate::lexer::Span; use crate::parser::Parser; use crate::scheduler::*; +use crate::utils::*; use crate::value::*; use anyhow::{anyhow, bail, Result}; @@ -164,6 +165,8 @@ impl<'source> Interpreter<'source> { if let Some(variable) = self.current_scope_mut()?.get_mut(name) { *variable = value.clone(); Ok(()) + } else if name == "_" { + Ok(()) } else { bail!("variable {} is undefined", name) } @@ -783,7 +786,17 @@ impl<'source> Interpreter<'source> { let r = Ok(match &stmt.literal { Literal::Expr { expr, .. } => { - let value = self.eval_expr(expr)?; + let value = match expr { + Expr::Call { span, fcn, params } => self.eval_call( + span, + fcn, + params, + get_extra_arg(expr, &HashMap::new()), + true, + )?, + _ => self.eval_expr(expr)?, + }; + if let Value::Bool(bool) = value { bool } else { @@ -793,6 +806,21 @@ impl<'source> Interpreter<'source> { value != Value::Undefined } } + Literal::NotExpr { expr, .. } => { + let value = match expr { + // Extra parameter is allowed; but a return argument is not allowed. + Expr::Call { span, fcn, params } => self.eval_call( + span, + fcn, + params, + get_extra_arg(expr, &HashMap::new()), + false, + )?, + _ => self.eval_expr(expr)?, + }; + // https://github.com/open-policy-agent/opa/issues/1622#issuecomment-520547385 + matches!(value, Value::Bool(false) | Value::Undefined) + } Literal::SomeVars { vars, .. } => { for var in vars { let name = var.text(); @@ -813,10 +841,6 @@ impl<'source> Interpreter<'source> { value, collection, } => self.eval_some_in(span, key, value, collection, stmts)?, - Literal::NotExpr { expr, .. } => { - // https://github.com/open-policy-agent/opa/issues/1622#issuecomment-520547385 - matches!(self.eval_expr(expr)?, Value::Bool(false) | Value::Undefined) - } Literal::Every { span, key, @@ -1278,7 +1302,7 @@ impl<'source> Interpreter<'source> { span: &'source Span<'source>, name: String, builtin: builtins::BuiltinFcn, - params: &'source Vec>, + params: &'source [Expr<'source>], ) -> Result { let mut args = vec![]; let allow_undefined = name == "print"; // TODO: with modifier @@ -1297,7 +1321,7 @@ impl<'source> Interpreter<'source> { } } - let v = builtin(span, ¶ms[..], &args[..])?; + let v = builtin.0(span, params, &args[..])?; // Handle trace function. // TODO: with modifier. @@ -1312,11 +1336,11 @@ impl<'source> Interpreter<'source> { Ok(v) } - fn eval_call( + fn eval_call_impl( &mut self, span: &'source Span<'source>, fcn: &'source Expr<'source>, - params: &'source Vec>, + params: &'source [Expr<'source>], ) -> Result { let fcns_rules = match self.lookup_function(fcn) { Ok(r) => r, @@ -1438,6 +1462,34 @@ impl<'source> Interpreter<'source> { Ok(results[0].clone()) } + fn eval_call( + &mut self, + span: &'source Span<'source>, + fcn: &'source Expr<'source>, + params: &'source Vec>, + extra_arg: Option<&'source Expr<'source>>, + allow_return_arg: bool, + ) -> Result { + // TODO: global var check; interop with `some var` + match extra_arg { + Some(Expr::Var(var)) + if allow_return_arg && self.lookup_local_var(var.text()).is_none() => + { + let value = self.eval_call_impl(span, fcn, ¶ms[..params.len() - 1])?; + if var.text() != "_" { + self.add_variable(var.text(), value)?; + } + Ok(Value::Bool(true)) + } + Some(expr) => { + let ret_value = self.eval_call_impl(span, fcn, ¶ms[..params.len() - 1])?; + let value = self.eval_expr(expr)?; + Ok(Value::Bool(ret_value == value)) + } + None => self.eval_call_impl(span, fcn, params), + } + } + fn lookup_local_var(&self, name: &str) -> Option { // Lookup local variables and arguments. for scope in self.scopes.iter().rev() { @@ -1566,7 +1618,7 @@ impl<'source> Interpreter<'source> { } => self.eval_object_compr(key, value, query), Expr::SetCompr { term, query, .. } => self.eval_set_compr(term, query), Expr::UnaryExpr { .. } => unimplemented!("unar expr is umplemented"), - Expr::Call { span, fcn, params } => self.eval_call(span, fcn, params), + Expr::Call { span, fcn, params } => self.eval_call(span, fcn, params, None, false), } } @@ -2150,8 +2202,21 @@ impl<'source> Interpreter<'source> { let prev_module = self.set_current_module(self.modules.last().copied())?; let value = self.eval_expr(snippet)?; // Pop the scope. + let scope = self.scopes.pop(); - let r = match snippet { + let r = match scope { + Some(scope) if !scope.is_empty() => { + let mut r = Value::new_object(); + let map = r.as_object_mut()?; + // Capture each binding. + for (name, v) in scope { + map.insert(Value::String(name), v); + } + Ok(r) + } + _ => Ok(value), + }; + /* let r = match snippet { Expr::AssignExpr { .. } => { if let Some(scope) = scope { let mut r = Value::new_object(); @@ -2166,7 +2231,7 @@ impl<'source> Interpreter<'source> { } } _ => Ok(value), - }; + };*/ self.set_current_module(prev_module)?; r } diff --git a/src/lib.rs b/src/lib.rs index a86cc639..f3787594 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,6 +7,7 @@ pub mod interpreter; pub mod lexer; pub mod parser; pub mod scheduler; +mod utils; pub mod value; pub use ast::*; diff --git a/src/parser.rs b/src/parser.rs index 8e127c37..9fecddc5 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -738,7 +738,7 @@ impl<'source> Parser<'source> { Ok(expr) } - fn parse_assign_expr(&mut self) -> Result> { + pub fn parse_assign_expr(&mut self) -> Result> { let state = self.clone(); let start = self.tok.1.start; let expr = self.parse_ref()?; diff --git a/src/scheduler.rs b/src/scheduler.rs index 1415b194..d8340d2e 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -3,10 +3,10 @@ use crate::ast::Expr::*; use crate::ast::*; -use crate::interpreter::Interpreter; use crate::lexer::Span; +use crate::utils; -use std::collections::{BTreeMap, BTreeSet, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, HashMap, VecDeque}; use std::string::String; use anyhow::{bail, Result}; @@ -187,7 +187,7 @@ pub fn schedule<'a>(infos: &mut [StmtInfo<'a>]) -> Result { } if order.len() != num_statements { - bail!("could not schedule all statements {order:?} {num_statements}"); + bail!("could not schedule all statements {order:?} {num_statements} {tmp:?}"); } // TODO: determine cycles. @@ -405,7 +405,7 @@ impl<'a> Analyzer<'a> { pub fn analyze(mut self, modules: &'a [Module<'a>]) -> Result { for m in modules { - let path = Interpreter::get_path_string(&m.package.refr, Some("data"))?; + let path = utils::get_path_string(&m.package.refr, Some("data"))?; let scope: &mut Scope = self.packages.entry(path).or_default(); for r in &m.policy { let var = match r { @@ -433,7 +433,7 @@ impl<'a> Analyzer<'a> { } fn analyze_module(&mut self, m: &'a Module<'a>) -> Result<()> { - let path = Interpreter::get_path_string(&m.package.refr, Some("data"))?; + let path = utils::get_path_string(&m.package.refr, Some("data"))?; let scope = match self.packages.get(&path) { Some(s) => s, _ => bail!("internal error: package scope missing"), @@ -932,7 +932,36 @@ impl<'a> Analyzer<'a> { definitions.push(Definition { var: "", used_vars }); // TODO: vars in compr } - Literal::Expr { expr, .. } | Literal::NotExpr { expr, .. } => { + Literal::Expr { expr, .. } => { + if let Some(Expr::Var(return_arg)) = utils::get_extra_arg(expr, &HashMap::new()) + { + let (mut used_vars, comprs) = Self::gather_used_vars_comprs_index_vars( + expr, + &mut scope, + &mut first_use, + &mut definitions, + )?; + let var = if return_arg.text() != "_" { + // The var in the return argument slot would have been processed as + // an used var. Remove it from used vars and add it as the variable being + // defined. + used_vars.pop(); + return_arg.text() + } else { + "" + }; + self.process_comprs( + &comprs[..], + &mut scope, + &mut first_use, + &mut used_vars, + )?; + definitions.push(Definition { var, used_vars }); + } else { + self.process_expr(expr, &mut scope, &mut first_use, &mut definitions)?; + } + } + Literal::NotExpr { expr, .. } => { self.process_expr(expr, &mut scope, &mut first_use, &mut definitions)?; } Literal::Every { diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 00000000..a554d4ee --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use crate::ast::*; +use crate::builtins::*; + +use std::collections::HashMap; + +use anyhow::{bail, Result}; + +pub fn get_path_string(refr: &Expr, document: Option<&str>) -> Result { + let mut comps = vec![]; + let mut expr = Some(refr); + while expr.is_some() { + match expr { + Some(Expr::RefDot { refr, field, .. }) => { + comps.push(field.text()); + expr = Some(refr); + } + Some(Expr::RefBrack { refr, index, .. }) + if matches!(index.as_ref(), Expr::String(_)) => + { + if let Expr::String(s) = index.as_ref() { + comps.push(s.text()); + expr = Some(refr); + } + } + Some(Expr::Var(v)) => { + comps.push(v.text()); + expr = None; + } + _ => bail!("internal error: not a simple ref"), + } + } + if let Some(d) = document { + comps.push(d); + }; + comps.reverse(); + Ok(comps.join(".")) +} + +pub fn get_extra_arg<'a>(expr: &'a Expr, arities: &HashMap) -> Option<&'a Expr<'a>> { + if let Expr::Call { fcn, params, .. } = expr { + if let Ok(path) = get_path_string(fcn, None) { + let n_args = if let Some(n_args) = arities.get(&path) { + *n_args + } else if let Some((_, n_args)) = BUILTINS.get(path.as_str()) { + *n_args + } else { + return None; + }; + if n_args as usize == params.len() - 1 { + return params.last(); + } + } + } + + None +} diff --git a/tests/interpreter/cases/builtins/aggregates/min.yaml b/tests/interpreter/cases/builtins/aggregates/min.yaml index 20a657d5..01557612 100644 --- a/tests/interpreter/cases/builtins/aggregates/min.yaml +++ b/tests/interpreter/cases/builtins/aggregates/min.yaml @@ -19,7 +19,7 @@ cases: u3 = min(set()) query: data.test want_result: - x: [ 3, 2 ] + x: [ -1, -1 ] - note: invalid-null data: {} diff --git a/tests/interpreter/cases/call/oldstyle.yaml b/tests/interpreter/cases/call/oldstyle.yaml new file mode 100644 index 00000000..a7a998af --- /dev/null +++ b/tests/interpreter/cases/call/oldstyle.yaml @@ -0,0 +1,19 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +cases: + - note: basic + data: {} + modules: + - | + package test + + a = [1, 2, 3, 4] + + p = x { + count(a, x) + } + + query: data.test.p = x + want_result: + x: 4 diff --git a/tests/interpreter/cases/compr/mod.rs b/tests/interpreter/cases/compr/mod.rs index f7cadbc5..6e0ba091 100644 --- a/tests/interpreter/cases/compr/mod.rs +++ b/tests/interpreter/cases/compr/mod.rs @@ -32,7 +32,7 @@ fn basic_array() -> Result<()> { array_compr_7 = [ 1 | [1, 2, 3][_]; [1, 2][_] >= 2 ] "#; - let expected = vec![Value::from_json_str( + let expected = [Value::from_json_str( r#" { "array": [1, 2, 3], "array_compr_0": [1], @@ -46,11 +46,10 @@ fn basic_array() -> Result<()> { }"#, )?]; - assert_match( - eval_file(&[rego.to_owned()], None, None, "data.test", false)?, - expected, - ); - Ok(()) + check_output( + &eval_file(&[rego.to_owned()], None, None, "data.test", false)?, + &expected, + ) } #[test] @@ -81,7 +80,7 @@ fn basic_set() -> Result<()> { set_compr_7 = { a | a = [1, 2, 3][_]; [1, 2][_] >= 2 } "#; - let expected = vec![Value::from_json_str( + let expected = [Value::from_json_str( r#" { "set": { "set!": [1, "string", [2, 3, 4], 567, false] @@ -115,9 +114,8 @@ fn basic_set() -> Result<()> { }"#, )?]; - assert_match( - eval_file(&[rego.to_owned()], None, None, "data.test", false)?, - expected, - ); - Ok(()) + check_output( + &eval_file(&[rego.to_owned()], None, None, "data.test", false)?, + &expected, + ) } diff --git a/tests/interpreter/cases/in/mod.rs b/tests/interpreter/cases/in/mod.rs index 1ac45aea..7be01a87 100644 --- a/tests/interpreter/cases/in/mod.rs +++ b/tests/interpreter/cases/in/mod.rs @@ -99,7 +99,7 @@ fn basic() -> Result<()> { } "#; - let expected = vec![Value::from_json_str( + let expected = [Value::from_json_str( r#" { "array": [1, 2, 3], "in_array_key_value": true, @@ -119,9 +119,8 @@ fn basic() -> Result<()> { }"#, )?]; - assert_match( - eval_file(&[rego.to_owned()], None, None, "data.test", false)?, - expected, - ); - Ok(()) + check_output( + &eval_file(&[rego.to_owned()], None, None, "data.test", false)?, + &expected, + ) } diff --git a/tests/interpreter/cases/input/mod.rs b/tests/interpreter/cases/input/mod.rs index f1e6943f..c8efccfb 100644 --- a/tests/interpreter/cases/input/mod.rs +++ b/tests/interpreter/cases/input/mod.rs @@ -25,7 +25,7 @@ fn basic() -> Result<()> { Value::from_json_str(r#"{"x": 6}"#)?, ]); - let expected = vec![ + let expected = [ Value::from_json_str( r#" { "y": {"set!": [6]}, @@ -40,9 +40,8 @@ fn basic() -> Result<()> { )?, ]; - assert_match( - eval_file_first_rule(&[rego.to_owned()], None, Some(input), "data.test", false)?, - expected, - ); - Ok(()) + check_output( + &eval_file_first_rule(&[rego.to_owned()], None, Some(input), "data.test", false)?, + &expected, + ) } diff --git a/tests/interpreter/cases/variables/mod.rs b/tests/interpreter/cases/variables/mod.rs index 592d9a3f..35d34acf 100644 --- a/tests/interpreter/cases/variables/mod.rs +++ b/tests/interpreter/cases/variables/mod.rs @@ -31,7 +31,7 @@ fn basic() -> Result<()> { set = {1, 2, 3} "#; - let expected = vec![Value::from_json_str( + let expected = [Value::from_json_str( r#" { "array": [1, 2, 3], "nested_array": [1, [2, 3, 4], 5, 6], @@ -46,9 +46,8 @@ fn basic() -> Result<()> { }"#, )?]; - assert_match( - eval_file(&[rego.to_owned()], None, None, "data.test", false)?, - expected, - ); - Ok(()) + check_output( + &eval_file(&[rego.to_owned()], None, None, "data.test", false)?, + &expected, + ) } diff --git a/tests/interpreter/mod.rs b/tests/interpreter/mod.rs index 8638c138..72201ac9 100644 --- a/tests/interpreter/mod.rs +++ b/tests/interpreter/mod.rs @@ -4,12 +4,13 @@ #![cfg(test)] use std::env; +use std::path::Path; use anyhow::{bail, Result}; use regorus::*; use serde::{ser::SerializeMap, Deserialize, Deserializer, Serialize, Serializer}; use test_generator::test_resources; -//use walkdir::WalkDir; +use walkdir::WalkDir; mod cases; @@ -154,28 +155,30 @@ fn match_values(computed: &Value, expected: &Value) -> Result<()> { } } -pub fn assert_match(computed_results: Vec, expected_results: Vec) { +pub fn check_output(computed_results: &[Value], expected_results: &[Value]) -> Result<()> { if computed_results.len() != expected_results.len() { - panic!( + bail!( "the number of computed results ({}) and expected results ({}) is not equal", computed_results.len(), expected_results.len() ); } - for (n, expected_result) in expected_results.into_iter().enumerate() { - let expected = match process_value(&expected_result) { + for (n, expected_result) in expected_results.iter().enumerate() { + let expected = match process_value(expected_result) { Ok(e) => e, - _ => panic!("unable to process value :\n {expected_result:?}"), + _ => bail!("unable to process value :\n {expected_result:?}"), }; if let Some(computed_result) = computed_results.get(n) { match match_values(computed_result, &expected) { Ok(()) => (), - Err(e) => panic!("{}", e), + Err(e) => bail!("{e}"), } } } + + Ok(()) } pub fn eval_file_first_rule( @@ -198,7 +201,7 @@ pub fn eval_file_first_rule( lines: query.split('\n').collect(), }; let mut parser = Parser::new(&source)?; - let expr = parser.parse_membership_expr()?; + let expr = parser.parse_assign_expr()?; for (idx, _) in regos.iter().enumerate() { files.push(format!("rego_{idx}")); @@ -278,7 +281,7 @@ pub fn eval_file( lines: query.split('\n').collect(), }; let mut parser = Parser::new(&source)?; - let expr = parser.parse_membership_expr()?; + let expr = parser.parse_assign_expr()?; for (idx, _) in regos.iter().enumerate() { files.push(format!("rego_{idx}")); @@ -435,7 +438,7 @@ struct YamlTest { cases: Vec, } -fn yaml_test_impl(file: &str) -> Result<()> { +fn yaml_test_impl(file: &str, is_opa_test: bool) -> Result<()> { let yaml_str = std::fs::read_to_string(file)?; let test: YamlTest = serde_yaml::from_str(&yaml_str)?; @@ -472,7 +475,14 @@ fn yaml_test_impl(file: &str) -> Result<()> { } } - assert_match(results, expected_results); + if is_opa_test { + // Convert value to json compatible representation. + let results = + Value::from_json_str(serde_json::to_string(&results)?.as_str())?; + match_values(&results, &expected_results[0])?; + } else { + check_output(&results, &expected_results)?; + } } _ => panic!("eval succeeded and did not produce any errors"), }, @@ -498,8 +508,8 @@ fn yaml_test_impl(file: &str) -> Result<()> { Ok(()) } -fn yaml_test(file: &str) -> Result<()> { - match yaml_test_impl(file) { +fn yaml_test(file: &str, is_opa_test: bool) -> Result<()> { + match yaml_test_impl(file, is_opa_test) { Ok(_) => Ok(()), Err(e) => { // If Err is returned, it doesn't always get printed by cargo test. @@ -511,28 +521,65 @@ fn yaml_test(file: &str) -> Result<()> { #[test] fn yaml_test_basic() -> Result<()> { - yaml_test("tests/interpreter/cases/basic_001.yaml") + yaml_test("tests/interpreter/cases/basic_001.yaml", false) } #[test] #[ignore = "intended for use by scripts/yaml-test-eval"] fn one_yaml() -> Result<()> { let mut file = String::default(); + let mut is_opa_test = false; + for a in env::args() { if a.ends_with(".yaml") { file = a; - break; + } else if a == "opa-test" { + is_opa_test = true; } } if file.is_empty() { - bail!("missing "); + bail!("missing "); } - yaml_test(file.as_str()) + yaml_test(file.as_str(), is_opa_test) } #[test_resources("tests/interpreter/**/*.yaml")] fn run(path: &str) { - yaml_test(path).unwrap() + yaml_test(path, false).unwrap() +} + +#[test] +#[ignore = "intended for running opa test suite"] +fn run_opa_tests() -> Result<()> { + let mut failures = vec![]; + for a in env::args() { + if !Path::new(&a).is_dir() { + continue; + } + for entry in WalkDir::new(a) + .sort_by_file_name() + .into_iter() + .filter_map(|e| e.ok()) + { + let path = entry.path().to_string_lossy().to_string(); + if Path::new(&path).is_dir() { + continue; + } + let yaml = path; + match yaml_test_impl(yaml.as_str(), true) { + Ok(_) => (), + Err(e) => { + failures.push((yaml, e)); + } + } + } + } + + if !failures.is_empty() { + dbg!(failures); + panic!("failed"); + } + Ok(()) }