Skip to content

Commit

Permalink
Cherry-pick is not null pushdown (#268)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dandandan authored Sep 6, 2024
1 parent f49d6a8 commit f490e9e
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 49 deletions.
54 changes: 39 additions & 15 deletions datafusion/optimizer/src/filter_null_join_keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@
//! [`FilterNullJoinKeys`] adds filters to join inputs when input isn't nullable

use crate::optimizer::ApplyOrder;
use crate::push_down_filter::on_lr_is_preserved;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::tree_node::Transformed;
use datafusion_common::{internal_err, Result};
use datafusion_expr::utils::conjunction;
use datafusion_expr::{
logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, LogicalPlan,
};
use datafusion_expr::{logical_plan::Filter, Expr, ExprSchemable, LogicalPlan};
use std::sync::Arc;

/// The FilterNullJoinKeys rule will identify inner joins with equi-join conditions
/// where the join key is nullable on one side and non-nullable on the other side
/// and then insert an `IsNotNull` filter on the nullable side since null values
/// The FilterNullJoinKeys rule will identify joins with equi-join conditions
/// where the join key is nullable and then insert an `IsNotNull` filter on the nullable side since null values
/// can never match.
#[derive(Default)]
pub struct FilterNullJoinKeys {}
Expand Down Expand Up @@ -59,21 +57,23 @@ impl OptimizerRule for FilterNullJoinKeys {
if !config.options().optimizer.filter_null_join_keys {
return Ok(Transformed::no(plan));
}

match plan {
LogicalPlan::Join(mut join) if join.join_type == JoinType::Inner => {
LogicalPlan::Join(mut join) if !join.on.is_empty() => {
let (left_preserved, right_preserved) =
on_lr_is_preserved(join.join_type);

let left_schema = join.left.schema();
let right_schema = join.right.schema();

let mut left_filters = vec![];
let mut right_filters = vec![];

for (l, r) in &join.on {
if l.nullable(left_schema)? {
if left_preserved && l.nullable(left_schema)? {
left_filters.push(l.clone());
}

if r.nullable(right_schema)? {
if right_preserved && r.nullable(right_schema)? {
right_filters.push(r.clone());
}
}
Expand Down Expand Up @@ -117,7 +117,7 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::Column;
use datafusion_expr::logical_plan::table_scan;
use datafusion_expr::{col, lit, LogicalPlanBuilder};
use datafusion_expr::{col, lit, JoinType, LogicalPlanBuilder};

fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected)
Expand All @@ -126,18 +126,41 @@ mod tests {
#[test]
fn left_nullable() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;
let expected = "Inner Join: t1.optional_id = t2.id\
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
\n TableScan: t2";
assert_optimized_plan_equal(plan, expected)
}

#[test]
fn left_nullable_left_join() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Left)?;
let expected = "Left Join: t1.optional_id = t2.id\
\n TableScan: t1\
\n TableScan: t2";
assert_optimized_plan_equal(plan, expected)
}

#[test]
fn left_nullable_left_join_reordered() -> Result<()> {
let (t_left, t_right) = test_tables()?;
// Note: order of tables is reversed
let plan =
build_plan(t_right, t_left, "t2.id", "t1.optional_id", JoinType::Left)?;
let expected = "Left Join: t2.id = t1.optional_id\
\n TableScan: t2\
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1";
assert_optimized_plan_equal(plan, expected)
}

#[test]
fn left_nullable_on_condition_reversed() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t2.id", "t1.optional_id")?;
let plan = build_plan(t1, t2, "t2.id", "t1.optional_id", JoinType::Inner)?;
let expected = "Inner Join: t1.optional_id = t2.id\
\n Filter: t1.optional_id IS NOT NULL\
\n TableScan: t1\
Expand All @@ -148,7 +171,7 @@ mod tests {
#[test]
fn nested_join_multiple_filter_expr() -> Result<()> {
let (t1, t2) = test_tables()?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id")?;
let plan = build_plan(t1, t2, "t1.optional_id", "t2.id", JoinType::Inner)?;
let schema = Schema::new(vec![
Field::new("id", DataType::UInt32, false),
Field::new("t1_id", DataType::UInt32, true),
Expand Down Expand Up @@ -252,11 +275,12 @@ mod tests {
right_table: LogicalPlan,
left_key: &str,
right_key: &str,
join_type: JoinType,
) -> Result<LogicalPlan> {
LogicalPlanBuilder::from(left_table)
.join(
right_table,
JoinType::Inner,
join_type,
(
vec![Column::from_qualified_name(left_key)],
vec![Column::from_qualified_name(right_key)],
Expand Down
55 changes: 30 additions & 25 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,38 +146,43 @@ pub struct PushDownFilter {}
/// there may be rows in the output that don't directly map to a row in the
/// right input (due to nulls filling where there is no match on the right).
///
/// This is important because we can always push down post-join filters to a preserved
/// side of the join, assuming the filter only references columns from that side. For the
/// non-preserved side it can be more tricky.
///
/// Returns a tuple of booleans - (left_preserved, right_preserved).
fn lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
/// - In a left join, the left side is preserved (we can push predicates) but
/// the right is not, because there may be rows in the output that don't
/// directly map to a row in the right input (due to nulls filling where there
/// is no match on the right).
pub(crate) fn lr_is_preserved(join_type: JoinType) -> (bool, bool) {
match join_type {
JoinType::Inner => Ok((true, true)),
JoinType::Left => Ok((true, false)),
JoinType::Right => Ok((false, true)),
JoinType::Full => Ok((false, false)),
JoinType::Inner => (true, true),
JoinType::Left => (true, false),
JoinType::Right => (false, true),
JoinType::Full => (false, false),
// No columns from the right side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
JoinType::LeftSemi | JoinType::LeftAnti => (true, false),
// No columns from the left side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)),
JoinType::RightSemi | JoinType::RightAnti => (false, true),
}
}

/// For a given JOIN logical plan, determine whether each side of the join is preserved
/// in terms on join filtering.
/// Predicates from join filter can only be pushed to preserved join side.
fn on_lr_is_preserved(join_type: JoinType) -> Result<(bool, bool)> {
/// For a given JOIN type, determine whether each input of the join is preserved
/// for the join condition (`ON` clause filters).
///
/// It is only correct to push filters below a join for preserved inputs.
///
/// # Return Value
/// A tuple of booleans - (left_preserved, right_preserved).
///
/// See [`lr_is_preserved`] for a definition of "preserved".
pub(crate) fn on_lr_is_preserved(join_type: JoinType) -> (bool, bool) {
match join_type {
JoinType::Inner => Ok((true, true)),
JoinType::Left => Ok((false, true)),
JoinType::Right => Ok((true, false)),
JoinType::Full => Ok((false, false)),
JoinType::LeftSemi | JoinType::RightSemi => Ok((true, true)),
JoinType::LeftAnti => Ok((false, true)),
JoinType::RightAnti => Ok((true, false)),
JoinType::Inner => (true, true),
JoinType::Left => (false, true),
JoinType::Right => (true, false),
JoinType::Full => (false, false),
JoinType::LeftSemi | JoinType::RightSemi => (true, true),
JoinType::LeftAnti => (false, true),
JoinType::RightAnti => (true, false),
}
}

Expand Down Expand Up @@ -395,7 +400,7 @@ fn push_down_all_join(
) -> Result<Transformed<LogicalPlan>> {
let is_inner_join = join.join_type == JoinType::Inner;
// Get pushable predicates from current optimizer state
let (left_preserved, right_preserved) = lr_is_preserved(join.join_type)?;
let (left_preserved, right_preserved) = lr_is_preserved(join.join_type);

// The predicates can be divided to three categories:
// 1) can push through join to its children(left or right)
Expand Down Expand Up @@ -435,7 +440,7 @@ fn push_down_all_join(
}

if !on_filter.is_empty() {
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type)?;
let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join.join_type);
for on in on_filter {
if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? {
left_push.push(on)
Expand Down
28 changes: 19 additions & 9 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,12 @@ fn semi_join_with_join_filter() -> Result<()> {
let plan = test_sql(sql)?;
let expected = "Projection: test.col_utf8\
\n LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32 Filter: test.col_uint32 != __correlated_sq_1.col_uint32\
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
\n SubqueryAlias: __correlated_sq_1\
\n SubqueryAlias: t2\
\n TableScan: test projection=[col_int32, col_uint32]";
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
Expand All @@ -141,7 +143,8 @@ fn anti_join_with_join_filter() -> Result<()> {
\n TableScan: test projection=[col_int32, col_uint32, col_utf8]\
\n SubqueryAlias: __correlated_sq_1\
\n SubqueryAlias: t2\
\n TableScan: test projection=[col_int32, col_uint32]";
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
Expand All @@ -152,11 +155,13 @@ fn where_exists_distinct() -> Result<()> {
SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)";
let plan = test_sql(sql)?;
let expected = "LeftSemi Join: test.col_int32 = __correlated_sq_1.col_int32\
\n TableScan: test projection=[col_int32]\
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32]\
\n SubqueryAlias: __correlated_sq_1\
\n Aggregate: groupBy=[[t2.col_int32]], aggr=[[]]\
\n SubqueryAlias: t2\
\n TableScan: test projection=[col_int32]";
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
Expand All @@ -172,9 +177,12 @@ fn intersect() -> Result<()> {
\n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\
\n LeftSemi Join: test.col_int32 = test.col_int32, test.col_utf8 = test.col_utf8\
\n Aggregate: groupBy=[[test.col_int32, test.col_utf8]], aggr=[[]]\
\n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_utf8]\
\n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_utf8]\
\n TableScan: test projection=[col_int32, col_utf8]\
\n TableScan: test projection=[col_int32, col_utf8]";
\n Filter: test.col_int32 IS NOT NULL AND test.col_utf8 IS NOT NULL\
\n TableScan: test projection=[col_int32, col_utf8]";
assert_eq!(expected, format!("{plan:?}"));
Ok(())
}
Expand Down Expand Up @@ -270,9 +278,11 @@ fn test_same_name_but_not_ambiguous() {
let expected = "LeftSemi Join: t1.col_int32 = t2.col_int32\
\n Aggregate: groupBy=[[t1.col_int32]], aggr=[[]]\
\n SubqueryAlias: t1\
\n TableScan: test projection=[col_int32]\
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32]\
\n SubqueryAlias: t2\
\n TableScan: test projection=[col_int32]";
\n Filter: test.col_int32 IS NOT NULL\
\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{plan:?}"));
}

Expand Down

0 comments on commit f490e9e

Please sign in to comment.