Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve volatile expression handling in CommonSubexprEliminate #11265

Merged
merged 8 commits into from
Jul 8, 2024
13 changes: 10 additions & 3 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1413,12 +1413,19 @@ impl Expr {
.unwrap()
}

/// Returns true if the expression node is volatile, i.e. whether it can return
/// different results when evaluated multiple times with the same input.
/// Note: unlike [`Self::is_volatile`], this function does not consider inputs:
/// - `rand()` returns `true`,
/// - `a + rand()` returns `false`
pub fn is_volatile_node(&self) -> bool {
alamb marked this conversation as resolved.
Show resolved Hide resolved
matches!(self, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile)
}

/// Returns true if the expression is volatile, i.e. whether it can return different
/// results when evaluated multiple times with the same input.
pub fn is_volatile(&self) -> Result<bool> {
self.exists(|expr| {
Ok(matches!(expr, Expr::ScalarFunction(func) if func.func.signature().volatility == Volatility::Volatile ))
})
self.exists(|expr| Ok(expr.is_volatile_node()))
}

/// Recursively find all [`Expr::Placeholder`] expressions, and
Expand Down
196 changes: 158 additions & 38 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,24 +191,19 @@ impl CommonSubexprEliminate {
id_array: &mut IdArray<'n>,
expr_mask: ExprMask,
) -> Result<bool> {
// Don't consider volatile expressions for CSE.
Ok(if expr.is_volatile()? {
false
} else {
let mut visitor = ExprIdentifierVisitor {
expr_stats,
id_array,
visit_stack: vec![],
down_index: 0,
up_index: 0,
expr_mask,
random_state: &self.random_state,
found_common: false,
};
expr.visit(&mut visitor)?;
let mut visitor = ExprIdentifierVisitor {
expr_stats,
id_array,
visit_stack: vec![],
down_index: 0,
up_index: 0,
expr_mask,
random_state: &self.random_state,
found_common: false,
};
expr.visit(&mut visitor)?;

visitor.found_common
})
Ok(visitor.found_common)
}

/// Rewrites `exprs_list` with common sub-expressions replaced with a new
Expand Down Expand Up @@ -917,27 +912,50 @@ struct ExprIdentifierVisitor<'a, 'n> {

/// Record item that used when traversing an expression tree.
enum VisitRecord<'n> {
/// Contains the post-order index assigned in during the first, visiting traversal and
/// a boolean flag to indicate if the record marks an expression subtree (not just a
/// single node).
/// Marks the beginning of expression. It contains:
/// - The post-order index assigned during the first, visiting traversal.
/// - A boolean flag if the record marks an expression subtree (not just a single
/// node).
EnterMark(usize, bool),
/// Accumulated identifier of sub expression.
ExprItem(Identifier<'n>),

/// Marks an accumulated subexpression tree. It contains:
/// - The accumulated identifier of a subexpression.
/// - A boolean flag if the expression is valid for subexpression elimination.
/// The flag is propagated up from children to parent. (E.g. volatile expressions
/// are not valid and can't be extracted, but non-volatile children of volatile
/// expressions can be extracted.)
ExprItem(Identifier<'n>, bool),
}

impl<'n> ExprIdentifierVisitor<'_, 'n> {
/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem`
/// before it.
fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>) {
/// Find the first `EnterMark` in the stack, and accumulates every `ExprItem` before
/// it. Returns a tuple that contains:
/// - The pre-order index of the expression we marked.
/// - A boolean flag if we marked an expression subtree (not just a single node).
/// If true we didn't recurse into the node's children, so we need to calculate the
/// hash of the marked expression tree (not just the node) and we need to validate
/// the expression tree (not just the node).
/// - The accumulated identifier of the children of the marked expression.
/// - An accumulated boolean flag from the children of the marked expression if all
/// children are valid for subexpression elimination (i.e. it is safe to extract the
/// expression as a common expression from its children POV).
/// (E.g. if any of the children of the marked expression is not valid (e.g. is
/// volatile) then the expression is also not valid, so we can propagate this
/// information up from children to parents via `visit_stack` during the first,
/// visiting traversal and no need to test the expression's validity beforehand with
/// an extra traversal).
fn pop_enter_mark(&mut self) -> (usize, bool, Option<Identifier<'n>>, bool) {
let mut expr_id = None;
let mut is_valid = true;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please document what "valid" means in this context? I think it means "is valid for CSE" as in "this sub expression could potentially be removed via CSE" but I am not quite sure

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added a comment to the method in 79b2e02, I think it explains what "valid" means in this case, but let me know if we should document the variable as well.


while let Some(item) = self.visit_stack.pop() {
match item {
VisitRecord::EnterMark(down_index, tree) => {
return (down_index, tree, expr_id);
VisitRecord::EnterMark(down_index, is_tree) => {
return (down_index, is_tree, expr_id, is_valid);
}
VisitRecord::ExprItem(id) => {
expr_id = Some(id.combine(expr_id));
VisitRecord::ExprItem(sub_expr_id, sub_expr_is_valid) => {
expr_id = Some(sub_expr_id.combine(expr_id));
is_valid &= sub_expr_is_valid;
}
}
}
Expand All @@ -949,8 +967,6 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
type Node = Expr;

fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
// TODO: consider non-volatile sub-expressions for CSE

// If an expression can short circuit its children then don't consider its
// children for CSE (https://github.com/apache/arrow-datafusion/issues/8814).
// This means that we don't recurse into its children, but handle the expression
Expand All @@ -972,21 +988,31 @@ impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
}

fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
let (down_index, is_tree, sub_expr_id) = self.pop_enter_mark();
let (down_index, is_tree, sub_expr_id, sub_expr_is_valid) = self.pop_enter_mark();

let expr_id =
Identifier::new(expr, is_tree, self.random_state).combine(sub_expr_id);
let (expr_id, is_valid) = if is_tree {
(
Identifier::new(expr, true, self.random_state),
!expr.is_volatile()?,
)
} else {
(
Identifier::new(expr, false, self.random_state).combine(sub_expr_id),
!expr.is_volatile_node() && sub_expr_is_valid,
)
};

self.id_array[down_index].0 = self.up_index;
if !self.expr_mask.ignores(expr) {
if is_valid && !self.expr_mask.ignores(expr) {
self.id_array[down_index].1 = Some(expr_id);
let count = self.expr_stats.entry(expr_id).or_insert(0);
*count += 1;
if *count > 1 {
self.found_common = true;
}
}
self.visit_stack.push(VisitRecord::ExprItem(expr_id));
self.visit_stack
.push(VisitRecord::ExprItem(expr_id, is_valid));
self.up_index += 1;

Ok(TreeNodeRecursion::Continue)
Expand Down Expand Up @@ -1101,15 +1127,17 @@ fn replace_common_expr<'n>(

#[cfg(test)]
mod test {
use std::any::Any;
use std::collections::HashSet;
use std::iter;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_expr::expr::AggregateFunction;
use datafusion_expr::logical_plan::{table_scan, JoinType};
use datafusion_expr::{
grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr, Signature,
SimpleAggregateUDF, Volatility,
grouping_set, AccumulatorFactoryFunction, AggregateUDF, BinaryExpr,
ColumnarValue, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF,
Volatility,
};
use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder};

Expand Down Expand Up @@ -1838,4 +1866,96 @@ mod test {

Ok(())
}

#[test]
fn test_volatile() -> Result<()> {
let table_scan = test_table_scan()?;

let extracted_child = col("a") + col("b");
let rand = rand_func().call(vec![]);
let not_extracted_volatile = extracted_child + rand;
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
not_extracted_volatile.clone().alias("c1"),
not_extracted_volatile.alias("c2"),
])?
.build()?;

let expected = "Projection: __common_expr_1 + random() AS c1, __common_expr_1 + random() AS c2\
\n Projection: test.a + test.b AS __common_expr_1, test.a, test.b, test.c\
\n TableScan: test";

assert_optimized_plan_eq(expected, plan, None);

Ok(())
}

#[test]
fn test_volatile_short_circuits() -> Result<()> {
let table_scan = test_table_scan()?;

let rand = rand_func().call(vec![]);
let not_extracted_volatile_short_circuit_2 =
rand.clone().eq(lit(0)).or(col("b").eq(lit(0)));
let not_extracted_volatile_short_circuit_1 =
col("a").eq(lit(0)).or(rand.eq(lit(0)));
let plan = LogicalPlanBuilder::from(table_scan.clone())
.project(vec![
not_extracted_volatile_short_circuit_1.clone().alias("c1"),
not_extracted_volatile_short_circuit_1.alias("c2"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤔 I was thinking -- why can't we extract the a = 0 part of a = 0 OR random() = 0 out?

It seems like it would be ok to rewrite

SELECT 
  a = 0 OR random() = 0 as c1,
  a = 0 OR random() = 0 as c2,
 ..

to something like this:

SELECT 
  __subexpr_1 OR random() as c1,
  __subexpr_1 OR random() as c2,
...
FROM (
  SELECT 
   a=0 as  __subexpr_1 
  FROM 
...

Copy link
Contributor Author

@peter-toth peter-toth Jul 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This question is similar to #11197 (comment). We can extract the surely evaluated expressions (1st legs) from those short circuiting Or, And and CaseWhen expressions, but we are not there yet with this PR.

My 3rd PR in the #11194 epic will implement that logic, but I want to add the improvements gradually as it will require to recurse into only certain children of a parent and that's not straightforward with the current treenode APIs (i.e. we need to stop recursion at the parent with a Jump and start a new recursion only on the interresting children).

not_extracted_volatile_short_circuit_2.clone().alias("c3"),
not_extracted_volatile_short_circuit_2.alias("c4"),
])?
.build()?;

let expected = "Projection: test.a = Int32(0) OR random() = Int32(0) AS c1, test.a = Int32(0) OR random() = Int32(0) AS c2, random() = Int32(0) OR test.b = Int32(0) AS c3, random() = Int32(0) OR test.b = Int32(0) AS c4\
\n TableScan: test";

assert_non_optimized_plan_eq(expected, plan, None);

Ok(())
}

/// returns a "random" function that is marked volatile (aka each invocation
/// returns a different value)
///
/// Does not use datafusion_functions::rand to avoid introducing a
/// dependency on that crate.
fn rand_func() -> ScalarUDF {
ScalarUDF::new_from_impl(RandomStub::new())
}

#[derive(Debug)]
struct RandomStub {
signature: Signature,
}

impl RandomStub {
fn new() -> Self {
Self {
signature: Signature::exact(vec![], Volatility::Volatile),
}
}
}
impl ScalarUDFImpl for RandomStub {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"random"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!()
}
}
}
Loading