Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Top-down evaluation #177

Merged
merged 1 commit into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/regorus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ fn rego_eval(

// Evaluate query.
let results = engine.eval_query(query, enable_tracing)?;

println!("{}", serde_json::to_string_pretty(&results)?);

#[cfg(feature = "coverage")]
Expand Down Expand Up @@ -147,11 +148,11 @@ enum RegorusCommand {
#[arg(long, short)]
trace: bool,

// Non strict execution
/// Perform non-strict evaluation. (default behavior of OPA).
#[arg(long, short)]
non_strict: bool,

// Display coverage information
/// Display coverage information
#[cfg(feature = "coverage")]
#[arg(long, short)]
coverage: bool,
Expand Down
45 changes: 42 additions & 3 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,12 @@ impl Engine {
/// assert_eq!(results.result[0].expressions[0].value, Value::from(true));
/// # Ok(())
/// # }
/// ```
pub fn eval_query(&mut self, query: String, enable_tracing: bool) -> Result<QueryResults> {
self.eval_modules(enable_tracing)?;
self.prepare_for_eval(enable_tracing)?;
self.interpreter.clean_internal_evaluation_state();

self.interpreter.create_rule_prefixes()?;
let query_module = {
let source = Source::new(
"<query_module.rego>".to_owned(),
Expand All @@ -256,6 +259,9 @@ impl Engine {
let query_source = Source::new("<query.rego>".to_string(), query);
let mut parser = Parser::new(&query_source)?;
let query_node = parser.parse_user_query()?;
if query_node.span.text() == "data" {
self.eval_modules(enable_tracing)?;
}
let query_schedule = Analyzer::new().analyze_query_snippet(&self.modules, &query_node)?;
self.interpreter.eval_user_query(
&query_module,
Expand Down Expand Up @@ -290,6 +296,7 @@ impl Engine {
/// assert!(engine.eval_bool_query("true; false; true".to_string(), enable_tracing).is_err());
/// # Ok(())
/// # }
/// ```
pub fn eval_bool_query(&mut self, query: String, enable_tracing: bool) -> Result<bool> {
let results = self.eval_query(query, enable_tracing)?;
match results.result.len() {
Expand Down Expand Up @@ -346,6 +353,38 @@ impl Engine {
!matches!(self.eval_bool_query(query, enable_tracing), Ok(false))
}

#[doc(hidden)]
/// Evaluate the given query and all the rules in the supplied policies.
///
/// This is mainly used for testing Regorus itself.
pub fn eval_query_and_all_rules(
&mut self,
query: String,
enable_tracing: bool,
) -> Result<QueryResults> {
self.eval_modules(enable_tracing)?;

let query_module = {
let source = Source::new(
"<query_module.rego>".to_owned(),
"package __internal_query_module".to_owned(),
);
Ref::new(Parser::new(&source)?.parse()?)
};

// Parse the query.
let query_source = Source::new("<query.rego>".to_string(), query);
let mut parser = Parser::new(&query_source)?;
let query_node = parser.parse_user_query()?;
let query_schedule = Analyzer::new().analyze_query_snippet(&self.modules, &query_node)?;
self.interpreter.eval_user_query(
&query_module,
&query_node,
&query_schedule,
enable_tracing,
)
}

#[doc(hidden)]
fn prepare_for_eval(&mut self, enable_tracing: bool) -> Result<()> {
self.interpreter.set_traces(enable_tracing);
Expand Down Expand Up @@ -513,8 +552,8 @@ impl Engine {
/// "#.to_string()
/// )?;
///
/// // Evaluation fails since y is not defined.
/// assert!(engine.eval_query("data.invalid.y".to_string(), false).is_err());
/// // Evaluation fails since rule x calls an extension with out parameter.
/// assert!(engine.eval_query("data.invalid.x".to_string(), false).is_err());
/// # Ok(())
/// # }
/// ```
Expand Down
138 changes: 83 additions & 55 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2437,14 +2437,11 @@ impl Interpreter {
|| &module_path[path.len()..path.len() + 1] == ".")
{
// Ensure that the module is created.
{
let path = Parser::get_path_ref_components(&module.package.refr)?;
let path: Vec<&str> = path.iter().map(|s| s.text()).collect();
let vref = Self::make_or_get_value_mut(&mut self.data, &path[..])?;
if *vref == Value::Undefined {
*vref = Value::new_object();
}
self.mark_processed(&path)?;
let path = Parser::get_path_ref_components(&module.package.refr)?;
let path: Vec<&str> = path.iter().map(|s| s.text()).collect();
let vref = Self::make_or_get_value_mut(&mut self.data, &path[..])?;
if *vref == Value::Undefined {
*vref = Value::new_object();
}

for rule in &module.policy {
Expand All @@ -2460,14 +2457,17 @@ impl Interpreter {
}
}
self.set_current_module(prev_module)?;
self.mark_processed(&path)?;
}
}

Ok(())
}

fn ensure_rule_evaluated(&mut self, path: String) -> Result<()> {
let mut matched = false;
if let Some(rules) = self.rules.get(&path) {
matched = true;
for r in rules.clone() {
if !self.processed.contains(&r) {
let module = self.get_rule_module(&r)?;
Expand All @@ -2478,6 +2478,7 @@ impl Interpreter {

// Evaluate the associated default rules after non-default rules
if let Some(rules) = self.default_rules.get(&path) {
matched = true;
for (r, _) in rules.clone() {
if !self.processed.contains(&r) {
let module = self.get_rule_module(&r)?;
Expand All @@ -2488,8 +2489,11 @@ impl Interpreter {
}
}

let comps: Vec<&str> = path.split('.').collect();
self.mark_processed(&comps[1..])
if matched {
let comps: Vec<&str> = path.split('.').collect();
self.mark_processed(&comps[1..])?;
}
Ok(())
}

fn is_processed(&self, path: &[&str]) -> Result<bool> {
Expand Down Expand Up @@ -2547,6 +2551,15 @@ impl Interpreter {
return Ok(Self::get_value_chained(self.data.clone(), fields));
}

// If "data" is used in a query, without any fields, then evaluate all the modules.
if fields.is_empty() && self.active_rules.is_empty() {
for module in self.modules.clone() {
for rule in &module.policy {
self.eval_rule(&module, rule)?;
}
}
}

// With modifiers may be used to specify part of a module that that not yet been
// evaluated. Therefore ensure that module is evaluated first.
let path = "data.".to_owned() + &fields.join(".");
Expand Down Expand Up @@ -3381,31 +3394,29 @@ impl Interpreter {
debug!("processing module {module_path:?}");

for rule in &module.policy {
let mut rule_refr = Self::get_rule_refr(rule);
debug!("rule refr: {}", rule_refr.span().text());
debug!("rule : {:?}", rule);
if let Rule::Spec {
head:
RuleHead::Set {
refr, key: None, ..
},
..
} = rule.as_ref()
{
rule_refr = match refr.as_ref() {
Expr::RefDot { refr, .. } => refr,

_ => refr,
let rule_refr = Self::get_rule_refr(rule);
let mut prefix_path = module_path.clone();
let mut components = Self::get_rule_path_components(rule_refr)?;
let is_old_set = matches!(
rule.as_ref(),
Rule::Spec {
head: RuleHead::Set { key: None, .. },
..
}
);

if components.len() >= 2 && is_old_set {
components.pop();
}

let mut prefix_path = module_path.clone();
prefix_path.append(&mut Self::get_rule_path_components(rule_refr)?);
let prefix_path: Vec<&str> = prefix_path[0..prefix_path.len() - 1]
.iter()
.map(|s| s.as_ref())
.collect();
if components.len() > 1 {
components.pop();
} else {
continue;
}

prefix_path.append(&mut components);
let prefix_path: Vec<&str> = prefix_path.iter().map(|s| s.as_ref()).collect();
if Self::get_value_chained(self.data.clone(), &prefix_path) == Value::Undefined {
self.update_data(
rule_refr.span(),
Expand Down Expand Up @@ -3439,6 +3450,42 @@ impl Interpreter {
Ok(())
}

fn record_default_rule(
&mut self,
refr: &Ref<Expr>,
rule: &Ref<Rule>,
index: Option<String>,
) -> Result<()> {
let comps = Parser::get_path_ref_components(refr)?;
let comps: Vec<&str> = comps.iter().map(|s| s.text()).collect();
for (idx, c) in (0..comps.len()).enumerate() {
let path = self.current_module_path.clone() + "." + &comps[0..c + 1].join(".");
match self.default_rules.entry(path) {
Entry::Occupied(o) => {
if idx + 1 == comps.len() {
for (_, i) in o.get() {
if index.is_some() && i.is_some() {
let old = i.as_ref().unwrap();
let new = index.as_ref().unwrap();
if old == new {
bail!(refr.span().error("multiple default rules for the variable with the same index"));
}
} else if index.is_some() || i.is_some() {
bail!(refr.span().error("conflict type with the default rules"));
}
}
}
o.into_mut().push((rule.clone(), index.clone()));
}
Entry::Vacant(v) => {
v.insert(vec![(rule.clone(), index.clone())]);
}
}
}

Ok(())
}

pub fn process_imports(&mut self) -> Result<()> {
for module in &self.modules {
let module_path = get_path_string(&module.package.refr, Some("data"))?;
Expand All @@ -3455,7 +3502,10 @@ impl Interpreter {
// Warn redundant import of input. Ignore it.
eprintln!(
"{}",
import.refr.span().error("redundant import of `input`")
import
.refr
.span()
.message("warning", "redundant import of `input`")
);
continue;
}
Expand Down Expand Up @@ -3513,29 +3563,7 @@ impl Interpreter {
_ => (refr, None),
};

let path = Self::get_path_string(refr, None)?;
let path = self.current_module_path.clone() + "." + &path;
match self.default_rules.entry(path) {
Entry::Occupied(o) => {
for (_, i) in o.get() {
if index.is_some() && i.is_some() {
let old = i.as_ref().unwrap();
let new = index.as_ref().unwrap();
if old == new {
bail!(refr.span().error("multiple default rules for the variable with the same index"));
}
} else if index.is_some() || i.is_some() {
bail!(refr
.span()
.error("conflict type with the default rules"));
}
}
o.into_mut().push((rule.clone(), index));
}
Entry::Vacant(v) => {
v.insert(vec![(rule.clone(), index)]);
}
}
self.record_default_rule(refr, rule, index)?;
}
}
self.set_current_module(prev_module)?;
Expand Down
8 changes: 4 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use std::rc::Rc;
/// # }
/// ````
/// See also [`QueryResult`].
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Eq, PartialEq)]
pub struct Location {
/// Line number. Starts at 1.
pub row: u16,
Expand All @@ -69,7 +69,7 @@ pub struct Location {
/// # }
/// ```
/// See also [`QueryResult`].
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Eq, PartialEq)]
pub struct Expression {
/// Computed value of the expression.
pub value: Value,
Expand Down Expand Up @@ -157,7 +157,7 @@ pub struct Expression {
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Serialize)]
#[derive(Debug, Clone, Serialize, Eq, PartialEq)]
pub struct QueryResult {
/// Expressions in the query.
///
Expand Down Expand Up @@ -296,7 +296,7 @@ impl Default for QueryResult {
/// ```
///
/// See [QueryResult] for examples of different kinds of results.
#[derive(Debug, Clone, Default, Serialize)]
#[derive(Debug, Clone, Default, Serialize, Eq, PartialEq)]
pub struct QueryResults {
/// Collection of results of evaluting a query.
#[serde(skip_serializing_if = "Vec::is_empty")]
Expand Down
Loading
Loading