From f490e9e46b035d7ec6ba76ae7a83cc252c5f1cb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20Heres?= Date: Fri, 6 Sep 2024 12:26:48 +0200 Subject: [PATCH] Cherry-pick is not null pushdown (#268) --- .../optimizer/src/filter_null_join_keys.rs | 54 +++++++++++++----- datafusion/optimizer/src/push_down_filter.rs | 55 ++++++++++--------- .../optimizer/tests/optimizer_integration.rs | 28 +++++++--- 3 files changed, 88 insertions(+), 49 deletions(-) diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index ecd1901abe58..7b75d8e6c26d 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -18,18 +18,16 @@ //! [`FilterNullJoinKeys`] adds filters to join inputs when input isn't nullable use crate::optimizer::ApplyOrder; +use crate::push_down_filter::on_lr_is_preserved; use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::tree_node::Transformed; use datafusion_common::{internal_err, Result}; use datafusion_expr::utils::conjunction; -use datafusion_expr::{ - logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, LogicalPlan, -}; +use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan}; use std::sync::Arc; -/// The FilterNullJoinKeys rule will identify inner joins with equi-join conditions -/// where the join key is nullable on one side and non-nullable on the other side -/// and then insert an `IsNotNull` filter on the nullable side since null values +/// The FilterNullJoinKeys rule will identify joins with equi-join conditions +/// where the join key is nullable and then insert an `IsNotNull` filter on the nullable side since null values /// can never match. #[derive(Default)] pub struct FilterNullJoinKeys {} @@ -59,9 +57,11 @@ impl OptimizerRule for FilterNullJoinKeys { if !config.options().optimizer.filter_null_join_keys { return Ok(Transformed::no(plan)); } - match plan { - LogicalPlan::Join(mut join) if join.join_type == JoinType::Inner => { + LogicalPlan::Join(mut join) if !join.on.is_empty() => { + let (left_preserved, right_preserved) = + on_lr_is_preserved(join.join_type); + let left_schema = join.left.schema(); let right_schema = join.right.schema(); @@ -69,11 +69,11 @@ impl OptimizerRule for FilterNullJoinKeys { let mut right_filters = vec![]; for (l, r) in &join.on { - if l.nullable(left_schema)? { + if left_preserved && l.nullable(left_schema)? { left_filters.push(l.clone()); } - if r.nullable(right_schema)? { + if right_preserved && r.nullable(right_schema)? { right_filters.push(r.clone()); } } @@ -117,7 +117,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Column; use datafusion_expr::logical_plan::table_scan; - use datafusion_expr::{col, lit, LogicalPlanBuilder}; + use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder}; fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) @@ -126,7 +126,7 @@ mod tests { #[test] fn left_nullable() -> Result<()> { let (t1, t2) = test_tables()?; - let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?; + let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?; let expected = "Inner Join: t1.optional_id = t2.id\ \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ @@ -134,10 +134,33 @@ mod tests { assert_optimized_plan_equal(plan, expected) } + #[test] + fn left_nullable_left_join() -> Result<()> { + let (t1, t2) = test_tables()?; + let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?; + let expected = "Left Join: t1.optional_id = t2.id\ + \n TableScan: t1\ + \n TableScan: t2"; + assert_optimized_plan_equal(plan, expected) + } + + #[test] + fn left_nullable_left_join_reordered() -> Result<()> { + let (t_left, t_right) = test_tables()?; + // Note: order of tables is reversed + let plan = + build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?; + let expected = "Left Join: t2.id = t1.optional_id\ + \n TableScan: t2\ + \n Filter: t1.optional_id IS NOT NULL\ + \n TableScan: t1"; + assert_optimized_plan_equal(plan, expected) + } + #[test] fn left_nullable_on_condition_reversed() -> Result<()> { let (t1, t2) = test_tables()?; - let plan = build_plan(t1, t2, "t2.id", "t1.optional_id")?; + let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?; let expected = "Inner Join: t1.optional_id = t2.id\ \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ @@ -148,7 +171,7 @@ mod tests { #[test] fn nested_join_multiple_filter_expr() -> Result<()> { let (t1, t2) = test_tables()?; - let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?; + let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?; let schema = Schema::new(vec![ Field::new("id", DataType::UInt32, false), Field::new("t1_id", DataType::UInt32, true), @@ -252,11 +275,12 @@ mod tests { right_table: LogicalPlan, left_key: &str, right_key: &str, + join_type: JoinType, ) -> Result { LogicalPlanBuilder::from(left_table) .join( right_table, - JoinType::Inner, + join_type, ( vec![Column::from_qualified_name(left_key)], vec![Column::from_qualified_name(right_key)], diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index e56bfd051fe2..4ccb01ec4d14 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -146,38 +146,43 @@ pub struct PushDownFilter {} /// there may be rows in the output that don't directly map to a row in the /// right input (due to nulls filling where there is no match on the right). /// -/// This is important because we can always push down post-join filters to a preserved -/// side of the join, assuming the filter only references columns from that side. For the -/// non-preserved side it can be more tricky. -/// -/// Returns a tuple of booleans - (left_preserved, right_preserved). -fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { +/// - In a left join, the left side is preserved (we can push predicates) but +/// the right is not, because there may be rows in the output that don't +/// directly map to a row in the right input (due to nulls filling where there +/// is no match on the right). +pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) { match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((true, false)), - JoinType::Right => Ok((false, true)), - JoinType::Full => Ok((false, false)), + JoinType::Inner => (true, true), + JoinType::Left => (true, false), + JoinType::Right => (false, true), + JoinType::Full => (false, false), // No columns from the right side of the join can be referenced in output // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)), + JoinType::LeftSemi | JoinType::LeftAnti => (true, false), // No columns from the left side of the join can be referenced in output // predicates for semi/anti joins, so whether we specify t/f doesn't matter. - JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)), + JoinType::RightSemi | JoinType::RightAnti => (false, true), } } -/// For a given JOIN logical plan, determine whether each side of the join is preserved -/// in terms on join filtering. -/// Predicates from join filter can only be pushed to preserved join side. -fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> { +/// For a given JOIN type, determine whether each input of the join is preserved +/// for the join condition (`ON` clause filters). +/// +/// It is only correct to push filters below a join for preserved inputs. +/// +/// # Return Value +/// A tuple of booleans - (left_preserved, right_preserved). +/// +/// See [`lr_is_preserved`] for a definition of "preserved". +pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) { match join_type { - JoinType::Inner => Ok((true, true)), - JoinType::Left => Ok((false, true)), - JoinType::Right => Ok((true, false)), - JoinType::Full => Ok((false, false)), - JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)), - JoinType::LeftAnti => Ok((false, true)), - JoinType::RightAnti => Ok((true, false)), + JoinType::Inner => (true, true), + JoinType::Left => (false, true), + JoinType::Right => (true, false), + JoinType::Full => (false, false), + JoinType::LeftSemi | JoinType::RightSemi => (true, true), + JoinType::LeftAnti => (false, true), + JoinType::RightAnti => (true, false), } } @@ -395,7 +400,7 @@ fn push_down_all_join( ) -> Result> { let is_inner_join = join.join_type == JoinType::Inner; // Get pushable predicates from current optimizer state - let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?; + let (left_preserved, right_preserved) = lr_is_preserved(join.join_type); // The predicates can be divided to three categories: // 1) can push through join to its children(left or right) @@ -435,7 +440,7 @@ fn push_down_all_join( } if !on_filter.is_empty() { - let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?; + let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type); for on in on_filter { if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? { left_push.push(on) diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index ae3feafbb753..49c2377e0657 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -121,10 +121,12 @@ fn semi_join_with_join_filter() -> Result<()> { let plan = test_sql(sql)?; let expected = "Projection: test.col_utf8\ \n LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32\ - \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ \n SubqueryAlias: __correlated_sq_1\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32]"; + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -141,7 +143,8 @@ fn anti_join_with_join_filter() -> Result<()> { \n TableScan: test projection=[col_int32, col_uint32, col_utf8]\ \n SubqueryAlias: __correlated_sq_1\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32, col_uint32]"; + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32, col_uint32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -152,11 +155,13 @@ fn where_exists_distinct() -> Result<()> { SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)"; let plan = test_sql(sql)?; let expected = "LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32\ - \n TableScan: test projection=[col_int32]\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: __correlated_sq_1\ \n Aggregate: groupBy=[[t2.col_int32]], aggr=[[]]\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32]"; + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -172,9 +177,12 @@ fn intersect() -> Result<()> { \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ \n LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\ \n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\ + \n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\ + \n TableScan: test projection=[col_int32, col_utf8]\ + \n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\ \n TableScan: test projection=[col_int32, col_utf8]\ - \n TableScan: test projection=[col_int32, col_utf8]\ - \n TableScan: test projection=[col_int32, col_utf8]"; + \n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\ + \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } @@ -270,9 +278,11 @@ fn test_same_name_but_not_ambiguous() { let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\ \n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\ \n SubqueryAlias: t1\ - \n TableScan: test projection=[col_int32]\ + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]\ \n SubqueryAlias: t2\ - \n TableScan: test projection=[col_int32]"; + \n Filter: test.col_int32 IS NOT NULL\ + \n TableScan: test projection=[col_int32]"; assert_eq!(expected, format!("{plan:?}")); }