diff --git a/datafusion/src/optimizer/simplify_expressions.rs b/datafusion/src/optimizer/simplify_expressions.rs index ff2c05c76f18..6ab3cdd6f00d 100644 --- a/datafusion/src/optimizer/simplify_expressions.rs +++ b/datafusion/src/optimizer/simplify_expressions.rs @@ -45,33 +45,18 @@ use crate::{error::Result, logical_plan::Operator}; /// pub struct SimplifyExpressions {} -fn expr_contains(expr: &Expr, needle: &Expr) -> bool { +/// returns true if `needle` is found in a chain of search_op +/// expressions. Such as: (A AND B) AND C +fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool { match expr { - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } => expr_contains(left, needle) || expr_contains(right, needle), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } => expr_contains(left, needle) || expr_contains(right, needle), + Expr::BinaryExpr { left, op, right } if *op == search_op => { + expr_contains(left, needle, search_op) + || expr_contains(right, needle, search_op) + } _ => expr == needle, } } -fn as_binary_expr(expr: &Expr) -> Option<&Expr> { - match expr { - Expr::BinaryExpr { .. 
} => Some(expr), - _ => None, - } -} - -fn operator_is_boolean(op: Operator) -> bool { - op == Operator::And || op == Operator::Or -} - fn is_one(s: &Expr) -> bool { match s { Expr::Literal(ScalarValue::Int8(Some(1))) @@ -95,6 +80,22 @@ fn is_true(expr: &Expr) -> bool { } } +/// returns true if expr is a +/// `Expr::Literal(ScalarValue::Boolean(v))` , false otherwise +fn is_bool_lit(expr: &Expr) -> bool { + matches!(expr, Expr::Literal(ScalarValue::Boolean(_))) +} + +/// Return a literal NULL value +fn lit_null() -> Expr { + Expr::Literal(ScalarValue::Boolean(None)) +} + +/// returns true if expr is a `Not(_)`, false otherwise +fn is_not(expr: &Expr) -> bool { + matches!(expr, Expr::Not(_)) +} + fn is_null(expr: &Expr) -> bool { match expr { Expr::Literal(v) => v.is_null(), @@ -109,160 +110,27 @@ fn is_false(expr: &Expr) -> bool { } } -fn simplify(expr: &Expr) -> Expr { - match expr { - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if is_true(left) || is_true(right) => lit(true), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if is_false(left) => simplify(right), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if is_false(right) => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if left == right => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if is_false(left) || is_false(right) => lit(false), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if is_true(right) => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if is_true(left) => simplify(right), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if left == right => simplify(right), - Expr::BinaryExpr { - left, - op: Operator::Multiply, - right, - } if is_one(left) => simplify(right), - Expr::BinaryExpr { - left, - op: Operator::Multiply, - right, - } if is_one(right) => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::Divide, - right, - } if 
is_one(right) => simplify(left), - Expr::BinaryExpr { - left, - op: Operator::Divide, - right, - } if left == right && is_null(left) => *left.clone(), - Expr::BinaryExpr { - left, - op: Operator::Divide, - right, - } if left == right => lit(1), +/// returns true if `haystack` looks like (needle OP X) or (X OP needle) +fn is_op_with(target_op: Operator, haystack: &Expr, needle: &Expr) -> bool { + match haystack { Expr::BinaryExpr { left, op, right } - if left == right && operator_is_boolean(*op) => + if op == &target_op + && (needle == left.as_ref() || needle == right.as_ref()) => { - simplify(left) + true } - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if expr_contains(left, right) => as_binary_expr(left) - .map(|x| match x { - Expr::BinaryExpr { - left: _, - op: Operator::Or, - right: _, - } => simplify(&x.clone()), - Expr::BinaryExpr { - left: _, - op: Operator::And, - right: _, - } => simplify(&*right.clone()), - _ => expr.clone(), - }) - .unwrap_or_else(|| expr.clone()), - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } if expr_contains(right, left) => as_binary_expr(right) - .map(|x| match x { - Expr::BinaryExpr { - left: _, - op: Operator::Or, - right: _, - } => simplify(&*right.clone()), - Expr::BinaryExpr { - left: _, - op: Operator::And, - right: _, - } => simplify(&*left.clone()), - _ => expr.clone(), - }) - .unwrap_or_else(|| expr.clone()), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if expr_contains(left, right) => as_binary_expr(left) - .map(|x| match x { - Expr::BinaryExpr { - left: _, - op: Operator::Or, - right: _, - } => simplify(&*right.clone()), - Expr::BinaryExpr { - left: _, - op: Operator::And, - right: _, - } => simplify(&x.clone()), - _ => expr.clone(), - }) - .unwrap_or_else(|| expr.clone()), - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } if expr_contains(right, left) => as_binary_expr(right) - .map(|x| match x { - Expr::BinaryExpr { - left: _, - op: Operator::Or, - right: _, - 
} => simplify(&*left.clone()), - Expr::BinaryExpr { - left: _, - op: Operator::And, - right: _, - } => simplify(&x.clone()), - _ => expr.clone(), - }) - .unwrap_or_else(|| expr.clone()), - Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr { - left: Box::new(simplify(left)), - op: *op, - right: Box::new(simplify(right)), - }, - _ => expr.clone(), + _ => false, + } +} + +/// returns the boolean `v` contained in `expr` when it is +/// `Expr::Literal(ScalarValue::Boolean(v))` (`None` for a NULL literal). +/// +/// panics if expr is not a literal boolean +fn as_bool_lit(expr: Expr) -> Option<bool> { + match expr { + Expr::Literal(ScalarValue::Boolean(v)) => v, + _ => panic!("Expected boolean literal, got {:?}", expr), } } @@ -281,11 +149,9 @@ impl OptimizerRule for SimplifyExpressions { // projected columns. With just the projected schema, it's not possible to infer types for // expressions that references non-projected columns within the same project plan or its // children plans. - let mut simplifier = - super::simplify_expressions::Simplifier::new(plan.all_schemas()); + let mut simplifier = Simplifier::new(plan.all_schemas()); - let mut const_evaluator = - super::simplify_expressions::ConstEvaluator::new(execution_props); + let mut const_evaluator = ConstEvaluator::new(execution_props); let new_inputs = plan .inputs() @@ -301,9 +167,6 @@ impl OptimizerRule for SimplifyExpressions { // Constant folding should not change expression name. 
let name = &e.name(plan.schema()); - // TODO combine simplify into Simplifier - let e = simplify(&e); - // TODO iterate until no changes are made // during rewrite (evaluating constants can // enable new simplifications and @@ -316,7 +179,6 @@ impl OptimizerRule for SimplifyExpressions { let new_name = &new_e.name(plan.schema()); - // TODO simplify this logic if let (Ok(expr_name), Ok(new_expr_name)) = (name, new_name) { if expr_name != new_expr_name { Ok(new_e.alias(expr_name)) @@ -554,212 +416,250 @@ impl<'a> Simplifier<'a> { false } - fn boolean_folding_for_or( - const_bool: &Option<bool>, - bool_expr: Box<Expr>, - left_right_order: bool, - ) -> Expr { - // See if we can fold 'const_bool OR bool_expr' to a constant boolean - match const_bool { - // TRUE or expr (including NULL) = TRUE - Some(true) => Expr::Literal(ScalarValue::Boolean(Some(true))), - // FALSE or expr (including NULL) = expr - Some(false) => *bool_expr, - None => match *bool_expr { - // NULL or TRUE = TRUE - Expr::Literal(ScalarValue::Boolean(Some(true))) => { - Expr::Literal(ScalarValue::Boolean(Some(true))) - } - // NULL or FALSE = NULL - Expr::Literal(ScalarValue::Boolean(Some(false))) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - // NULL or NULL = NULL - Expr::Literal(ScalarValue::Boolean(None)) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - // NULL or expr can be either NULL or TRUE - // So let us not rewrite it - _ => { - let mut left = - Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool))); - let mut right = bool_expr; - if !left_right_order { - std::mem::swap(&mut left, &mut right); - } - - Expr::BinaryExpr { - left, - op: Operator::Or, - right, - } - } - }, + /// Returns true if expr is nullable, per the first input schema that can resolve it + fn nullable(&self, expr: &Expr) -> Result<bool> { + for schema in &self.schemas { + if let Ok(res) = expr.nullable(schema.as_ref()) { + return Ok(res); + } + // expr may be from another input, so keep trying } - } - fn boolean_folding_for_and( - const_bool: &Option<bool>, - bool_expr: Box<Expr>, - 
left_right_order: bool, - ) -> Expr { - // See if we can fold 'const_bool AND bool_expr' to a constant boolean - match const_bool { - // TRUE and expr (including NULL) = expr - Some(true) => *bool_expr, - // FALSE and expr (including NULL) = FALSE - Some(false) => Expr::Literal(ScalarValue::Boolean(Some(false))), - None => match *bool_expr { - // NULL and TRUE = NULL - Expr::Literal(ScalarValue::Boolean(Some(true))) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - // NULL and FALSE = FALSE - Expr::Literal(ScalarValue::Boolean(Some(false))) => { - Expr::Literal(ScalarValue::Boolean(Some(false))) - } - // NULL and NULL = NULL - Expr::Literal(ScalarValue::Boolean(None)) => { - Expr::Literal(ScalarValue::Boolean(None)) - } - // NULL and expr can either be NULL or FALSE - // So let us not rewrite it - _ => { - let mut left = - Box::new(Expr::Literal(ScalarValue::Boolean(*const_bool))); - let mut right = bool_expr; - if !left_right_order { - std::mem::swap(&mut left, &mut right); - } - - Expr::BinaryExpr { - left, - op: Operator::And, - right, - } - } - }, - } + // This means we weren't able to compute `Expr::nullable` with + // any input schemas, signalling a problem + Err(DataFusionError::Internal(format!( + "Could not find columns in '{}' during simplify", + expr + ))) } } impl<'a> ExprRewriter for Simplifier<'a> { /// rewrite the expression simplifying any constant expressions fn mutate(&mut self, expr: Expr) -> Result<Expr> { + use Expr::*; + use Operator::{And, Divide, Eq, Multiply, NotEq, Or}; + let new_expr = match expr { - Expr::BinaryExpr { left, op, right } => match op { - Operator::Eq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l == r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { 
- match b { - Some(true) => *right, - Some(false) => Expr::Not(right), - None => Expr::Literal(ScalarValue::Boolean(None)), - } - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => *left, - Some(false) => Expr::Not(left), - None => Expr::Literal(ScalarValue::Boolean(None)), - } - } - _ => Expr::BinaryExpr { - left, - op: Operator::Eq, - right, - }, - }, - Operator::NotEq => match (left.as_ref(), right.as_ref()) { - ( - Expr::Literal(ScalarValue::Boolean(l)), - Expr::Literal(ScalarValue::Boolean(r)), - ) => match (l, r) { - (Some(l), Some(r)) => { - Expr::Literal(ScalarValue::Boolean(Some(l != r))) - } - _ => Expr::Literal(ScalarValue::Boolean(None)), - }, - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - match b { - Some(true) => Expr::Not(right), - Some(false) => *right, - None => Expr::Literal(ScalarValue::Boolean(None)), - } - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - match b { - Some(true) => Expr::Not(left), - Some(false) => *left, - None => Expr::Literal(ScalarValue::Boolean(None)), - } - } - _ => Expr::BinaryExpr { - left, - op: Operator::NotEq, - right, - }, - }, - Operator::Or => match (left.as_ref(), right.as_ref()) { - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - Self::boolean_folding_for_or(b, right, true) - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - Self::boolean_folding_for_or(b, left, false) - } - _ => Expr::BinaryExpr { - left, - op: Operator::Or, - right, - }, - }, - Operator::And => match (left.as_ref(), right.as_ref()) { - (Expr::Literal(ScalarValue::Boolean(b)), _) - if self.is_boolean_type(&right) => - { - Self::boolean_folding_for_and(b, right, true) - } - (_, Expr::Literal(ScalarValue::Boolean(b))) - if self.is_boolean_type(&left) => - { - Self::boolean_folding_for_and(b, left, false) - } - _ => 
Expr::BinaryExpr { - left, - op: Operator::And, - right, - }, - }, - _ => Expr::BinaryExpr { left, op, right }, - }, - // Not(Not(expr)) --> expr - Expr::Not(inner) => { - if let Expr::Not(negated_inner) = *inner { - *negated_inner - } else { - Expr::Not(inner) + // + // Rules for Eq + // + + // true = A --> A + // false = A --> !A + // null = A --> null + BinaryExpr { + left, + op: Eq, + right, + } if is_bool_lit(&left) && self.is_boolean_type(&right) => { + match as_bool_lit(*left) { + Some(true) => *right, + Some(false) => Not(right), + None => lit_null(), + } + } + // A = true --> A + // A = false --> !A + // A = null --> null + BinaryExpr { + left, + op: Eq, + right, + } if is_bool_lit(&right) && self.is_boolean_type(&left) => { + match as_bool_lit(*right) { + Some(true) => *left, + Some(false) => Not(left), + None => lit_null(), } } + + // + // Rules for NotEq + // + + // true != A --> !A + // false != A --> A + // null != A --> null + BinaryExpr { + left, + op: NotEq, + right, + } if is_bool_lit(&left) && self.is_boolean_type(&right) => { + match as_bool_lit(*left) { + Some(true) => Not(right), + Some(false) => *right, + None => lit_null(), + } + } + // A != true --> !A + // A != false --> A + // A != null --> null, + BinaryExpr { + left, + op: NotEq, + right, + } if is_bool_lit(&right) && self.is_boolean_type(&left) => { + match as_bool_lit(*right) { + Some(true) => Not(left), + Some(false) => *left, + None => lit_null(), + } + } + + // + // Rules for OR + // + + // true OR A --> true (even if A is null) + BinaryExpr { + left, + op: Or, + right: _, + } if is_true(&left) => *left, + // false OR A --> A + BinaryExpr { + left, + op: Or, + right, + } if is_false(&left) => *right, + // A OR true --> true (even if A is null) + BinaryExpr { + left: _, + op: Or, + right, + } if is_true(&right) => *right, + // A OR false --> A + BinaryExpr { + left, + op: Or, + right, + } if is_false(&right) => *left, + // (..A..) OR A --> (..A..) 
+ BinaryExpr { + left, + op: Or, + right, + } if expr_contains(&left, &right, Or) => *left, + // A OR (..A..) --> (..A..) + BinaryExpr { + left, + op: Or, + right, + } if expr_contains(&right, &left, Or) => *right, + // A OR (A AND B) --> A (if (A AND B) is not nullable) + BinaryExpr { + left, + op: Or, + right, + } if is_op_with(And, &right, &left) && !self.nullable(&right)? => *left, + // (A AND B) OR A --> A (if (A AND B) is not nullable) + BinaryExpr { + left, + op: Or, + right, + } if is_op_with(And, &left, &right) && !self.nullable(&left)? => *right, + + // + // Rules for AND + // + + // true AND A --> A + BinaryExpr { + left, + op: And, + right, + } if is_true(&left) => *right, + // false AND A --> false (even if A is null) + BinaryExpr { + left, + op: And, + right: _, + } if is_false(&left) => *left, + // A AND true --> A + BinaryExpr { + left, + op: And, + right, + } if is_true(&right) => *left, + // A AND false --> false (even if A is null) + BinaryExpr { + left: _, + op: And, + right, + } if is_false(&right) => *right, + // (..A..) AND A --> (..A..) + BinaryExpr { + left, + op: And, + right, + } if expr_contains(&left, &right, And) => *left, + // A AND (..A..) --> (..A..) + BinaryExpr { + left, + op: And, + right, + } if expr_contains(&right, &left, And) => *right, + // A AND (A OR B) --> A (if (A OR B) is not nullable) + BinaryExpr { + left, + op: And, + right, + } if is_op_with(Or, &right, &left) && !self.nullable(&right)? => *left, + // (A OR B) AND A --> A (if (A OR B) is not nullable) + BinaryExpr { + left, + op: And, + right, + } if is_op_with(Or, &left, &right) 
&& !self.nullable(&left)? => *right, + + // + // Rules for Multiply + // + BinaryExpr { + left, + op: Multiply, + right, + } if is_one(&right) => *left, + BinaryExpr { + left, + op: Multiply, + right, + } if is_one(&left) => *right, + + // + // Rules for Divide + // + + // A / 1 --> A + BinaryExpr { + left, + op: Divide, + right, + } if is_one(&right) => *left, + // null / null --> null + BinaryExpr { + left, + op: Divide, + right, + } if left == right && is_null(&left) => *left, + // A / A --> 1 (if a is not nullable) + BinaryExpr { + left, + op: Divide, + right, + } if left == right && !self.nullable(&left)? => lit(1), + + // + // Rules for Not + // + + // !(!A) --> A + Not(inner) if is_not(&inner) => match *inner { + Not(negated_inner) => *negated_inner, + _ => unreachable!(), + }, + expr => { // no additional rewrites possible expr @@ -791,8 +691,8 @@ mod tests { let expr_b = lit(true).or(col("c2")); let expected = lit(true); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -801,8 +701,8 @@ mod tests { let expr_b = col("c2").or(lit(false)); let expected = col("c2"); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -810,7 +710,7 @@ mod tests { let expr = col("c2").or(col("c2")); let expected = col("c2"); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] @@ -819,8 +719,8 @@ mod tests { let expr_b = col("c2").and(lit(false)); let expected = lit(false); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -828,7 +728,7 @@ mod tests { let expr = col("c2").and(col("c2")); let expected = col("c2"); - assert_eq!(simplify(&expr), 
expected); + assert_eq!(simplify(expr), expected); } #[test] @@ -837,8 +737,8 @@ mod tests { let expr_b = col("c2").and(lit(true)); let expected = col("c2"); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -847,8 +747,8 @@ mod tests { let expr_b = binary_expr(lit(1), Operator::Multiply, col("c2")); let expected = col("c2"); - assert_eq!(simplify(&expr_a), expected); - assert_eq!(simplify(&expr_b), expected); + assert_eq!(simplify(expr_a), expected); + assert_eq!(simplify(expr_b), expected); } #[test] @@ -856,15 +756,24 @@ mod tests { let expr = binary_expr(col("c2"), Operator::Divide, lit(1)); let expected = col("c2"); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_divide_by_same() { let expr = binary_expr(col("c2"), Operator::Divide, col("c2")); + // if c2 is null, c2 / c2 = null, so can't simplify + let expected = expr.clone(); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_divide_by_same_non_null() { + let expr = binary_expr(col("c2_non_null"), Operator::Divide, col("c2_non_null")); let expected = lit(1); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] @@ -873,21 +782,21 @@ mod tests { let expr = (col("c2").gt(lit(5))).and(col("c2").gt(lit(5))); let expected = col("c2").gt(lit(5)); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_composed_and() { - // ((c > 5) AND (d < 6)) AND (c > 5) + // ((c > 5) AND (c1 < 6)) AND (c > 5) let expr = binary_expr( - binary_expr(col("c2").gt(lit(5)), Operator::And, col("d").lt(lit(6))), + binary_expr(col("c2").gt(lit(5)), Operator::And, col("c1").lt(lit(6))), Operator::And, col("c2").gt(lit(5)), ); let expected = - binary_expr(col("c2").gt(lit(5)), Operator::And, col("d").lt(lit(6))); + 
binary_expr(col("c2").gt(lit(5)), Operator::And, col("c1").lt(lit(6))); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] @@ -900,20 +809,91 @@ mod tests { ); let expected = expr.clone(); - assert_eq!(simplify(&expr), expected); + assert_eq!(simplify(expr), expected); } #[test] fn test_simplify_or_and() { - // (c > 5) OR ((d < 6) AND (c > 5) -- can remove - let expr = binary_expr( - col("c2").gt(lit(5)), + let l = col("c2").gt(lit(5)); + let r = binary_expr(col("c1").lt(lit(6)), Operator::And, col("c2").gt(lit(5))); + + // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) + let expr = binary_expr(l.clone(), Operator::Or, r.clone()); + + // no rewrites if c1 can be null + let expected = expr.clone(); + assert_eq!(simplify(expr), expected); + + // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) + let expr = binary_expr(r, Operator::Or, l); + + // no rewrites if c1 can be null + let expected = expr.clone(); + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_or_and_non_null() { + let l = col("c2_non_null").gt(lit(5)); + let r = binary_expr( + col("c1_non_null").lt(lit(6)), + Operator::And, + col("c2_non_null").gt(lit(5)), + ); + + // (c2 > 5) OR ((c1 < 6) AND (c2 > 5)) --> c2 > 5 + let expr = binary_expr(l.clone(), Operator::Or, r.clone()); + + // This is only true if `c1 < 6` is not nullable / can not be null. 
+ let expected = col("c2_non_null").gt(lit(5)); + + assert_eq!(simplify(expr), expected); + + // ((c1 < 6) AND (c2 > 5)) OR (c2 > 5) --> c2 > 5 + let expr = binary_expr(r, Operator::Or, l); + + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_and_or() { + let l = col("c2").gt(lit(5)); + let r = binary_expr(col("c1").lt(lit(6)), Operator::Or, col("c2").gt(lit(5))); + + // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) + let expr = binary_expr(l.clone(), Operator::And, r.clone()); + + // no rewrites if c1 can be null + let expected = expr.clone(); + assert_eq!(simplify(expr), expected); + + // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) + let expr = binary_expr(r, Operator::And, l); + let expected = expr.clone(); + assert_eq!(simplify(expr), expected); + } + + #[test] + fn test_simplify_and_or_non_null() { + let l = col("c2_non_null").gt(lit(5)); + let r = binary_expr( + col("c1_non_null").lt(lit(6)), Operator::Or, - binary_expr(col("d").lt(lit(6)), Operator::And, col("c2").gt(lit(5))), + col("c2_non_null").gt(lit(5)), ); - let expected = col("c2").gt(lit(5)); - assert_eq!(simplify(&expr), expected); + // (c2 > 5) AND ((c1 < 6) OR (c2 > 5)) --> c2 > 5 + let expr = binary_expr(l.clone(), Operator::And, r.clone()); + + // This is only true if `c1 < 6` is not nullable / can not be null. 
+ let expected = col("c2_non_null").gt(lit(5)); + + assert_eq!(simplify(expr), expected); + + // ((c1 < 6) OR (c2 > 5)) AND (c2 > 5) --> c2 > 5 + let expr = binary_expr(r, Operator::And, l); + + assert_eq!(simplify(expr), expected); } #[test] @@ -921,7 +901,7 @@ mod tests { let expr = binary_expr(lit_null(), Operator::And, lit(false)); let expr_eq = lit(false); - assert_eq!(simplify(&expr), expr_eq); + assert_eq!(simplify(expr), expr_eq); } #[test] @@ -930,16 +910,16 @@ mod tests { let expr_plus = binary_expr(null.clone(), Operator::Divide, null.clone()); let expr_eq = null; - assert_eq!(simplify(&expr_plus), expr_eq); + assert_eq!(simplify(expr_plus), expr_eq); } #[test] - fn test_simplify_do_not_simplify_arithmetic_expr() { + fn test_simplify_simplify_arithmetic_expr() { let expr_plus = binary_expr(lit(1), Operator::Plus, lit(1)); let expr_eq = binary_expr(lit(1), Operator::Eq, lit(1)); - assert_eq!(simplify(&expr_plus), expr_plus); - assert_eq!(simplify(&expr_eq), expr_eq); + assert_eq!(simplify(expr_plus), lit(2)); + assert_eq!(simplify(expr_eq), lit(true)); } // ------------------------------ @@ -1182,11 +1162,17 @@ mod tests { // ----- Simplifier tests ------- // ------------------------------ - // TODO rename to simplify - fn do_simplify(expr: Expr) -> Expr { + fn simplify(expr: Expr) -> Expr { let schema = expr_test_schema(); let mut rewriter = Simplifier::new(vec![&schema]); - expr.rewrite(&mut rewriter).expect("expected to simplify") + + let execution_props = ExecutionProps::new(); + let mut const_evaluator = ConstEvaluator::new(&execution_props); + + expr.rewrite(&mut rewriter) + .expect("expected to simplify") + .rewrite(&mut const_evaluator) + .expect("expected to const evaluate") } fn expr_test_schema() -> DFSchemaRef { @@ -1194,6 +1180,8 @@ mod tests { DFSchema::new(vec![ DFField::new(None, "c1", DataType::Utf8, true), DFField::new(None, "c2", DataType::Boolean, true), + DFField::new(None, "c1_non_null", DataType::Utf8, false), + DFField::new(None, 
"c2_non_null", DataType::Boolean, false), ]) .unwrap(), ) @@ -1201,20 +1189,20 @@ mod tests { #[test] fn simplify_expr_not_not() { - assert_eq!(do_simplify(col("c2").not().not().not()), col("c2").not(),); + assert_eq!(simplify(col("c2").not().not().not()), col("c2").not(),); } #[test] fn simplify_expr_null_comparison() { // x = null is always null assert_eq!( - do_simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))), + simplify(lit(true).eq(lit(ScalarValue::Boolean(None)))), lit(ScalarValue::Boolean(None)), ); // null != null is always null assert_eq!( - do_simplify( + simplify( lit(ScalarValue::Boolean(None)).not_eq(lit(ScalarValue::Boolean(None))) ), lit(ScalarValue::Boolean(None)), @@ -1222,13 +1210,13 @@ mod tests { // x != null is always null assert_eq!( - do_simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))), + simplify(col("c2").not_eq(lit(ScalarValue::Boolean(None)))), lit(ScalarValue::Boolean(None)), ); // null = x is always null assert_eq!( - do_simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))), + simplify(lit(ScalarValue::Boolean(None)).eq(col("c2"))), lit(ScalarValue::Boolean(None)), ); } @@ -1239,16 +1227,16 @@ mod tests { assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean); // true = ture -> true - assert_eq!(do_simplify(lit(true).eq(lit(true))), lit(true)); + assert_eq!(simplify(lit(true).eq(lit(true))), lit(true)); // true = false -> false - assert_eq!(do_simplify(lit(true).eq(lit(false))), lit(false),); + assert_eq!(simplify(lit(true).eq(lit(false))), lit(false),); // c2 = true -> c2 - assert_eq!(do_simplify(col("c2").eq(lit(true))), col("c2")); + assert_eq!(simplify(col("c2").eq(lit(true))), col("c2")); // c2 = false => !c2 - assert_eq!(do_simplify(col("c2").eq(lit(false))), col("c2").not(),); + assert_eq!(simplify(col("c2").eq(lit(false))), col("c2").not(),); } #[test] @@ -1262,25 +1250,8 @@ mod tests { // Make sure c1 column to be used in tests is not boolean type assert_eq!(col("c1").get_type(&schema).unwrap(), 
DataType::Utf8); - // don't fold c1 = true - assert_eq!( - do_simplify(col("c1").eq(lit(true))), - col("c1").eq(lit(true)), - ); - - // don't fold c1 = false - assert_eq!( - do_simplify(col("c1").eq(lit(false))), - col("c1").eq(lit(false)), - ); - - // test constant operands - assert_eq!(do_simplify(lit(1).eq(lit(true))), lit(1).eq(lit(true)),); - - assert_eq!( - do_simplify(lit("a").eq(lit(false))), - lit("a").eq(lit(false)), - ); + // don't fold c1 = foo + assert_eq!(simplify(col("c1").eq(lit("foo"))), col("c1").eq(lit("foo")),); } #[test] @@ -1290,15 +1261,15 @@ mod tests { assert_eq!(col("c2").get_type(&schema).unwrap(), DataType::Boolean); // c2 != true -> !c2 - assert_eq!(do_simplify(col("c2").not_eq(lit(true))), col("c2").not(),); + assert_eq!(simplify(col("c2").not_eq(lit(true))), col("c2").not(),); // c2 != false -> c2 - assert_eq!(do_simplify(col("c2").not_eq(lit(false))), col("c2"),); + assert_eq!(simplify(col("c2").not_eq(lit(false))), col("c2"),); // test constant - assert_eq!(do_simplify(lit(true).not_eq(lit(true))), lit(false),); + assert_eq!(simplify(lit(true).not_eq(lit(true))), lit(false),); - assert_eq!(do_simplify(lit(true).not_eq(lit(false))), lit(true),); + assert_eq!(simplify(lit(true).not_eq(lit(false))), lit(true),); } #[test] @@ -1311,44 +1282,25 @@ mod tests { assert_eq!(col("c1").get_type(&schema).unwrap(), DataType::Utf8); assert_eq!( - do_simplify(col("c1").not_eq(lit(true))), - col("c1").not_eq(lit(true)), - ); - - assert_eq!( - do_simplify(col("c1").not_eq(lit(false))), - col("c1").not_eq(lit(false)), - ); - - // test constants - assert_eq!( - do_simplify(lit(1).not_eq(lit(true))), - lit(1).not_eq(lit(true)), - ); - - assert_eq!( - do_simplify(lit("a").not_eq(lit(false))), - lit("a").not_eq(lit(false)), + simplify(col("c1").not_eq(lit("foo"))), + col("c1").not_eq(lit("foo")), ); } #[test] fn simplify_expr_case_when_then_else() { assert_eq!( - do_simplify(Expr::Case { + simplify(Expr::Case { expr: None, when_then_expr: vec![( 
Box::new(col("c2").not_eq(lit(false))), - Box::new(lit("ok").eq(lit(true))), + Box::new(lit("ok").eq(lit("not_ok"))), )], else_expr: Some(Box::new(col("c2").eq(lit(true)))), }), Expr::Case { expr: None, - when_then_expr: vec![( - Box::new(col("c2")), - Box::new(lit("ok").eq(lit(true))) - )], + when_then_expr: vec![(Box::new(col("c2")), Box::new(lit(false)))], else_expr: Some(Box::new(col("c2"))), } ); @@ -1362,22 +1314,22 @@ mod tests { #[test] fn simplify_expr_bool_or() { // col || true is always true - assert_eq!(do_simplify(col("c2").or(lit(true))), lit(true),); + assert_eq!(simplify(col("c2").or(lit(true))), lit(true),); // col || false is always col - assert_eq!(do_simplify(col("c2").or(lit(false))), col("c2"),); + assert_eq!(simplify(col("c2").or(lit(false))), col("c2"),); // true || null is always true - assert_eq!(do_simplify(lit(true).or(lit_null())), lit(true),); + assert_eq!(simplify(lit(true).or(lit_null())), lit(true),); // null || true is always true - assert_eq!(do_simplify(lit_null().or(lit(true))), lit(true),); + assert_eq!(simplify(lit_null().or(lit(true))), lit(true),); // false || null is always null - assert_eq!(do_simplify(lit(false).or(lit_null())), lit_null(),); + assert_eq!(simplify(lit(false).or(lit_null())), lit_null(),); // null || false is always null - assert_eq!(do_simplify(lit_null().or(lit(false))), lit_null(),); + assert_eq!(simplify(lit_null().or(lit(false))), lit_null(),); // ( c1 BETWEEN Int32(0) AND Int32(10) ) OR Boolean(NULL) // it can be either NULL or TRUE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10)` @@ -1389,28 +1341,28 @@ mod tests { high: Box::new(lit(10)), }; let expr = expr.or(lit_null()); - let result = do_simplify(expr.clone()); + let result = simplify(expr.clone()); assert_eq!(expr, result); } #[test] fn simplify_expr_bool_and() { // col & true is always col - assert_eq!(do_simplify(col("c2").and(lit(true))), col("c2"),); + assert_eq!(simplify(col("c2").and(lit(true))), col("c2"),); // col & false 
is always false - assert_eq!(do_simplify(col("c2").and(lit(false))), lit(false),); + assert_eq!(simplify(col("c2").and(lit(false))), lit(false),); // true && null is always null - assert_eq!(do_simplify(lit(true).and(lit_null())), lit_null(),); + assert_eq!(simplify(lit(true).and(lit_null())), lit_null(),); // null && true is always null - assert_eq!(do_simplify(lit_null().and(lit(true))), lit_null(),); + assert_eq!(simplify(lit_null().and(lit(true))), lit_null(),); // false && null is always false - assert_eq!(do_simplify(lit(false).and(lit_null())), lit(false),); + assert_eq!(simplify(lit(false).and(lit_null())), lit(false),); // null && false is always false - assert_eq!(do_simplify(lit_null().and(lit(false))), lit(false),); + assert_eq!(simplify(lit_null().and(lit(false))), lit(false),); // c1 BETWEEN Int32(0) AND Int32(10) AND Boolean(NULL) // it can be either NULL or FALSE depending on the value of `c1 BETWEEN Int32(0) AND Int32(10` @@ -1422,7 +1374,7 @@ mod tests { high: Box::new(lit(10)), }; let expr = expr.and(lit_null()); - let result = do_simplify(expr.clone()); + let result = simplify(expr.clone()); assert_eq!(expr, result); } @@ -1473,12 +1425,12 @@ mod tests { ); } - // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6) #[test] fn test_simplify_optimized_plan_with_composed_and() { let table_scan = test_table_scan(); + // ((c > 5) AND (d < 6)) AND (c > 5) --> (c > 5) AND (d < 6) let plan = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")]) + .project(vec![col("a"), col("b")]) .unwrap() .filter(and( and(col("a").gt(lit(5)), col("b").lt(lit(6))), @@ -1492,7 +1444,7 @@ mod tests { &plan, "\ Filter: #test.a > Int32(5) AND #test.b < Int32(6) AS test.a > Int32(5) AND test.b < Int32(6) AND test.a > Int32(5)\ - \n Projection: #test.a\ + \n Projection: #test.a, #test.b\ \n TableScan: test projection=None", ); }