diff --git a/Cargo.toml b/Cargo.toml index 6d5f3240..f8bab903 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,8 +18,9 @@ keywords = ["interpreter", "opa", "policy-as-code", "rego"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["full-opa"] +default = ["full-opa", "arc"] +arc = ["scientific/arc"] base64 = ["dep:data-encoding"] base64url = ["dep:data-encoding"] crypto = ["dep:constant_time_eq", "dep:hmac", "dep:hex", "dep:md-5", "dep:sha1", "dep:sha2"] diff --git a/src/ast.rs b/src/ast.rs index 40cb219b..4b425ba8 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -2,6 +2,7 @@ // Licensed under the MIT License. use crate::lexer::*; +use crate::Rc; use std::ops::Deref; @@ -37,7 +38,7 @@ pub enum AssignOp { } pub struct NodeRef { - r: std::rc::Rc, + r: Rc, } impl Clone for NodeRef { @@ -54,7 +55,7 @@ impl std::fmt::Debug for NodeRef { impl std::cmp::PartialEq for NodeRef { fn eq(&self, other: &Self) -> bool { - std::rc::Rc::as_ptr(&self.r).eq(&std::rc::Rc::as_ptr(&other.r)) + Rc::as_ptr(&self.r).eq(&Rc::as_ptr(&other.r)) } } @@ -62,7 +63,7 @@ impl std::cmp::Eq for NodeRef {} impl std::cmp::Ord for NodeRef { fn cmp(&self, other: &Self) -> std::cmp::Ordering { - std::rc::Rc::as_ptr(&self.r).cmp(&std::rc::Rc::as_ptr(&other.r)) + Rc::as_ptr(&self.r).cmp(&Rc::as_ptr(&other.r)) } } @@ -88,9 +89,7 @@ impl AsRef for NodeRef { impl NodeRef { pub fn new(t: T) -> Self { - Self { - r: std::rc::Rc::new(t), - } + Self { r: Rc::new(t) } } } diff --git a/src/builtins/arrays.rs b/src/builtins/arrays.rs index e1dea3ca..bd80d7b2 100644 --- a/src/builtins/arrays.rs +++ b/src/builtins/arrays.rs @@ -5,7 +5,8 @@ use crate::ast::{Expr, Ref}; use crate::builtins; use crate::builtins::utils::{ensure_args_count, ensure_array, ensure_numeric}; use crate::lexer::Span; -use crate::value::{Rc, Value}; +use crate::Rc; +use crate::Value; use std::collections::HashMap; diff --git a/src/builtins/objects.rs b/src/builtins/objects.rs index 3e942862..995a4ae6 100644 --- a/src/builtins/objects.rs +++ b/src/builtins/objects.rs @@ -5,7 +5,8 @@ use crate::ast::{Expr, Ref}; use crate::builtins; use crate::builtins::utils::{ensure_args_count, ensure_array, ensure_object}; use crate::lexer::Span; -use crate::value::{Rc, Value}; +use crate::Rc; +use crate::Value; use std::collections::{BTreeMap, BTreeSet, HashMap}; use std::iter::Iterator; diff --git a/src/builtins/utils.rs b/src/builtins/utils.rs index a1200a52..4996e410 100644 --- a/src/builtins/utils.rs +++ b/src/builtins/utils.rs @@ -4,7 +4,8 @@ use crate::ast::{Expr, Ref}; use crate::lexer::Span; use crate::number::Number; -use crate::value::{Rc, Value}; +use crate::Rc; +use crate::Value; use std::collections::{BTreeMap, BTreeSet}; diff --git a/src/engine.rs b/src/engine.rs index f1d64c6e..2de5af94 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -17,7 +17,7 @@ use anyhow::{bail, Result}; /// The Rego evaluation engine. /// -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Engine { modules: Vec>, interpreter: Interpreter, @@ -120,6 +120,11 @@ impl Engine { self.interpreter.set_input(input); } + pub fn set_input_json(&mut self, input_json: &str) -> Result<()> { + self.set_input(Value::from_json_str(input_json)?); + Ok(()) + } + /// Clear the data document. /// /// The data document will be reset to an empty object. @@ -182,6 +187,10 @@ impl Engine { self.interpreter.get_data_mut().merge(data) } + pub fn add_data_json(&mut self, data_json: &str) -> Result<()> { + self.add_data(Value::from_json_str(data_json)?) + } + /// Set whether builtins should raise errors strictly or not. /// /// Regorus differs from OPA in that by default builtins will @@ -256,6 +265,14 @@ impl Engine { ) } + pub fn eval_bool_query(&mut self, query: String, enable_tracing: bool) -> Result { + let results = self.eval_query(query, enable_tracing)?; + if results.result.len() != 1 || results.result[0].expressions.len() != 1 { + bail!("query did not produce exactly one value"); + } + results.result[0].expressions[0].value.as_bool().copied() + } + #[doc(hidden)] fn prepare_for_eval(&mut self, enable_tracing: bool) -> Result<()> { self.interpreter.set_traces(enable_tracing); diff --git a/src/interpreter.rs b/src/interpreter.rs index 9646c231..8bf7cb2d 100644 --- a/src/interpreter.rs +++ b/src/interpreter.rs @@ -9,6 +9,7 @@ use crate::parser::Parser; use crate::scheduler::*; use crate::utils::*; use crate::value::*; +use crate::Rc; use crate::{Expression, Extension, Location, QueryResult, QueryResults}; use anyhow::{anyhow, bail, Result}; @@ -37,7 +38,7 @@ enum FunctionModifier { Value(Value), } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Interpreter { modules: Vec>, module: Option>, @@ -64,7 +65,7 @@ pub struct Interpreter { allow_deprecated: bool, strict_builtin_errors: bool, imports: BTreeMap>, - extensions: HashMap)>, + extensions: HashMap>)>, } impl Default for Interpreter { @@ -2165,7 +2166,7 @@ impl Interpreter { if param_values.len() != *nargs as usize { bail!(span.error("incorrect number of parameters supplied to extension")); } - let r = ext(param_values); + let r = Rc::make_mut(ext)(param_values); // Restore with_functions. if let Some(with_functions) = with_functions_saved { self.with_functions = with_functions; @@ -3458,7 +3459,7 @@ impl Interpreter { extension: Box, ) -> Result<()> { if let std::collections::hash_map::Entry::Vacant(v) = self.extensions.entry(path) { - v.insert((nargs, extension)); + v.insert((nargs, Rc::new(extension))); Ok(()) } else { bail!("extension already added"); diff --git a/src/lexer.rs b/src/lexer.rs index 8d8dc6f4..67aec674 100644 --- a/src/lexer.rs +++ b/src/lexer.rs @@ -8,7 +8,9 @@ use core::str::CharIndices; use std::convert::AsRef; use std::path::Path; -use crate::value::Value; +use crate::Rc; +use crate::Value; + use anyhow::{anyhow, bail, Result}; #[derive(Clone)] @@ -20,7 +22,7 @@ struct SourceInternal { #[derive(Clone)] pub struct Source { - src: std::rc::Rc, + src: Rc, } #[derive(Clone)] @@ -108,7 +110,7 @@ impl Source { lines.push((s, s)); } Self { - src: std::rc::Rc::new(SourceInternal { + src: Rc::new(SourceInternal { file, contents, lines, diff --git a/src/lib.rs b/src/lib.rs index 28082199..3d91ff89 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,12 @@ mod value; pub use engine::Engine; pub use value::Value; +#[cfg(feature = "arc")] +use std::sync::Arc as Rc; + +#[cfg(not(feature = "arc"))] +use std::rc::Rc; + /// Location of an [`Expression`] in a Rego query. /// /// ``` @@ -68,7 +74,7 @@ pub struct Expression { pub value: Value, /// The Rego expression. - pub text: std::rc::Rc, + pub text: Rc, /// Location of the expression in the query string. pub location: Location, @@ -263,7 +269,7 @@ pub struct QueryResults { /// A user defined builtin function implementation. /// /// It is not necessary to implement this trait directly. -pub trait Extension: FnMut(Vec) -> anyhow::Result { +pub trait Extension: FnMut(Vec) -> anyhow::Result + Send + Sync { /// Fn, FnMut etc are not sized and cannot be cloned in their boxed form. /// clone_box exists to overcome that. fn clone_box<'a>(&self) -> Box @@ -274,7 +280,7 @@ pub trait Extension: FnMut(Vec) -> anyhow::Result { /// Automatically make matching closures a valid [`Extension`]. impl Extension for F where - F: FnMut(Vec) -> anyhow::Result + Clone, + F: FnMut(Vec) -> anyhow::Result + Clone + Send + Sync, { fn clone_box<'a>(&self) -> Box where @@ -291,6 +297,12 @@ impl<'a> Clone for Box { } } +impl std::fmt::Debug for dyn Extension { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { + f.write_fmt(format_args!("")) + } +} + /// Items in `unstable` are likely to change. #[doc(hidden)] pub mod unstable { diff --git a/src/number.rs b/src/number.rs index 8b7eac4f..735695dd 100644 --- a/src/number.rs +++ b/src/number.rs @@ -3,7 +3,6 @@ use core::fmt::{Debug, Formatter}; use std::cmp::{Ord, Ordering}; -use std::rc::Rc; use std::str::FromStr; use anyhow::{anyhow, bail, Result}; @@ -11,6 +10,8 @@ use anyhow::{anyhow, bail, Result}; use serde::ser::Serializer; use serde::Serialize; +use crate::Rc; + pub type BigInt = i128; type BigFloat = scientific::Scientific; diff --git a/src/scheduler.rs b/src/scheduler.rs index 77c7be3e..0eb79693 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -380,7 +380,7 @@ pub struct Analyzer { current_module_path: String, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Schedule { pub scopes: BTreeMap, Scope>, pub order: BTreeMap, Vec>, diff --git a/src/value.rs b/src/value.rs index f6ccb7e8..5283feeb 100644 --- a/src/value.rs +++ b/src/value.rs @@ -15,7 +15,7 @@ use serde::de::{self, Deserializer, MapAccess, SeqAccess, Visitor}; use serde::ser::{SerializeMap, Serializer}; use serde::{Deserialize, Serialize}; -pub type Rc = compact_rc::Rc16; +use crate::Rc; /// A value in a Rego document. /// diff --git a/tests/arc.rs b/tests/arc.rs new file mode 100644 index 00000000..528671b7 --- /dev/null +++ b/tests/arc.rs @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +use lazy_static::lazy_static; +use std::sync::Mutex; + +use regorus::*; + +// Ensure that types can be s +lazy_static! { + static ref VALUE: Value = Value::Null; + static ref ENGINE: Mutex = Mutex::new(Engine::new()); +// static ref ENGINE: Engine = Engine::new(); +} + +#[test] +fn shared_engine() -> anyhow::Result<()> { + let e_guard = ENGINE.lock(); + let mut engine = e_guard.expect("failed to lock engine"); + + engine.add_policy( + "hello.rego".to_string(), + r#" +package test +allow = true +"# + .to_string(), + )?; + + let results = engine.eval_query("data.test.allow".to_string(), false)?; + assert_eq!(results.result[0].expressions[0].value, Value::from(true)); + Ok(()) +} diff --git a/tests/mod.rs b/tests/mod.rs index 4782498e..7e739bf3 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -5,3 +5,6 @@ mod engine; mod lexer; mod parser; mod value; + +#[cfg(feature = "arc")] +mod arc;