Skip to content

Commit

Permalink
feat(optimizer): outer join simplify (risingwavelabs#2349)
Browse files Browse the repository at this point in the history
* add simplify outer join

* apply planner test

* q15 plan test

* use ExprType
  • Loading branch information
st1page authored May 6, 2022
1 parent e2229af commit afffdec
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 38 deletions.
53 changes: 53 additions & 0 deletions src/frontend/src/optimizer/rule/filter_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use risingwave_pb::plan_common::JoinType;

use super::super::plan_node::*;
use super::Rule;
use crate::expr::{ExprImpl, ExprType};
use crate::optimizer::rule::BoxedRule;
use crate::utils::{ColIndexMapping, Condition};

Expand Down Expand Up @@ -56,6 +57,8 @@ impl Rule for FilterJoinRule {

let mut new_filter_predicate = filter.predicate().clone();

let join_type = self.simplify_outer(filter.predicate(), left_col_num, join_type);

let (left_from_filter, right_from_filter, on) = self.push_down(
&mut new_filter_predicate,
left_col_num,
Expand Down Expand Up @@ -180,6 +183,56 @@ impl FilterJoinRule {
JoinType::Inner | JoinType::LeftOuter | JoinType::RightSemi
)
}

/// Downgrades an outer join to a stricter join type when the filter sitting on top of the join
/// already discards the NULL-padded rows the outer join would produce.
///
/// `predicate` is the filter condition above the join, `left_col_num` is the number of columns
/// coming from the join's left input (column indices below it belong to the left side, the
/// rest to the right side), and `join_type` is the join's current type. Join types other than
/// `LeftOuter`, `RightOuter` and `FullOuter` are returned unchanged.
///
/// This is a deliberately naive check: only top-level conjuncts that are comparison calls
/// (`=`, `<>`, `<`, `<=`, `>`, `>=`) with a column reference as a direct operand are
/// considered. A more general version could be built on constant folding in the future.
fn simplify_outer(
    &self,
    predicate: &Condition,
    left_col_num: usize,
    join_type: JoinType,
) -> JoinType {
    // Which side(s) of the join output may be padded with NULLs by this join type.
    let (mut left_padded, mut right_padded) = match join_type {
        JoinType::LeftOuter => (false, true),
        JoinType::RightOuter => (true, false),
        JoinType::FullOuter => (true, true),
        _ => return join_type,
    };

    for conjunct in &predicate.conjunctions {
        // Only function calls can be comparisons; skip everything else.
        let call = match conjunct {
            ExprImpl::FunctionCall(call) => call,
            _ => continue,
        };

        let is_comparison = matches!(
            call.get_expr_type(),
            ExprType::Equal
                | ExprType::NotEqual
                | ExprType::LessThan
                | ExprType::LessThanOrEqual
                | ExprType::GreaterThan
                | ExprType::GreaterThanOrEqual
        );
        if !is_comparison {
            continue;
        }

        // A column compared directly cannot be NULL in any row that passes the filter, so
        // the side that column belongs to no longer needs NULL padding.
        for operand in call.inputs() {
            if let ExprImpl::InputRef(column) = operand {
                if column.index < left_col_num {
                    left_padded = false;
                } else {
                    right_padded = false;
                }
            }
        }
    }

    // Map the remaining padding requirements back to a join type.
    if left_padded && right_padded {
        JoinType::FullOuter
    } else if left_padded {
        JoinType::RightOuter
    } else if right_padded {
        JoinType::LeftOuter
    } else {
        JoinType::Inner
    }
}
}

#[cfg(test)]
Expand Down
108 changes: 70 additions & 38 deletions src/frontend/test_runner/tests/testdata/tpch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1017,29 +1017,28 @@
LogicalScan { table: nation, columns: [_row_id#0, n_nationkey, n_name, n_regionkey, n_comment] }
optimized_logical_plan: |
LogicalProject { exprs: [$0, $1], expr_alias: [ps_partkey, value] }
LogicalFilter { predicate: ($2 > $3) }
LogicalJoin { type: LeftOuter, on: always }
LogicalAgg { group_keys: [0], agg_calls: [sum($1), sum($1)] }
LogicalProject { exprs: [$0, ($2 * $1)], expr_alias: [ , ] }
LogicalJoin { type: Inner, on: ($3 = $4) }
LogicalProject { exprs: [$0, $2, $3, $5], expr_alias: [ , , , ] }
LogicalJoin { type: Inner, on: ($1 = $4) }
LogicalScan { table: partsupp, columns: [ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] }
LogicalJoin { type: Inner, on: ($2 > $3) }
LogicalAgg { group_keys: [0], agg_calls: [sum($1), sum($1)] }
LogicalProject { exprs: [$0, ($2 * $1)], expr_alias: [ , ] }
LogicalJoin { type: Inner, on: ($3 = $4) }
LogicalProject { exprs: [$0, $2, $3, $5], expr_alias: [ , , , ] }
LogicalJoin { type: Inner, on: ($1 = $4) }
LogicalScan { table: partsupp, columns: [ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] }
LogicalScan { table: supplier, columns: [s_suppkey, s_nationkey] }
LogicalProject { exprs: [$0], expr_alias: [ ] }
LogicalFilter { predicate: ($1 = 'ARGENTINA':Varchar) }
LogicalScan { table: nation, columns: [n_nationkey, n_name] }
LogicalProject { exprs: [($0 * 0.0001000000:Decimal)], expr_alias: [ ] }
LogicalAgg { group_keys: [], agg_calls: [sum($0)] }
LogicalProject { exprs: [($1 * $0)], expr_alias: [ ] }
LogicalJoin { type: Inner, on: ($2 = $3) }
LogicalProject { exprs: [$1, $2, $4], expr_alias: [ , , ] }
LogicalJoin { type: Inner, on: ($0 = $3) }
LogicalScan { table: partsupp, columns: [ps_suppkey, ps_availqty, ps_supplycost] }
LogicalScan { table: supplier, columns: [s_suppkey, s_nationkey] }
LogicalProject { exprs: [$0], expr_alias: [ ] }
LogicalFilter { predicate: ($1 = 'ARGENTINA':Varchar) }
LogicalScan { table: nation, columns: [n_nationkey, n_name] }
LogicalProject { exprs: [($0 * 0.0001000000:Decimal)], expr_alias: [ ] }
LogicalAgg { group_keys: [], agg_calls: [sum($0)] }
LogicalProject { exprs: [($1 * $0)], expr_alias: [ ] }
LogicalJoin { type: Inner, on: ($2 = $3) }
LogicalProject { exprs: [$1, $2, $4], expr_alias: [ , , ] }
LogicalJoin { type: Inner, on: ($0 = $3) }
LogicalScan { table: partsupp, columns: [ps_suppkey, ps_availqty, ps_supplycost] }
LogicalScan { table: supplier, columns: [s_suppkey, s_nationkey] }
LogicalProject { exprs: [$0], expr_alias: [ ] }
LogicalFilter { predicate: ($1 = 'ARGENTINA':Varchar) }
LogicalScan { table: nation, columns: [n_nationkey, n_name] }
- id: tpch_q12
before:
- create_tables
Expand Down Expand Up @@ -1237,25 +1236,58 @@
)
order by
s_suppkey;
optimized_logical_plan: |
LogicalProject { exprs: [$0, $1, $2, $3, $4], expr_alias: [s_suppkey, s_name, s_address, s_phone, total_revenue] }
LogicalFilter { predicate: ($4 = $5) }
LogicalJoin { type: LeftOuter, on: always }
LogicalProject { exprs: [$0, $1, $2, $3, $5], expr_alias: [ , , , , ] }
LogicalJoin { type: Inner, on: ($0 = $4) }
LogicalScan { table: supplier, columns: [s_suppkey, s_name, s_address, s_phone] }
LogicalProject { exprs: [$0, $1], expr_alias: [l_suppkey, total_revenue] }
LogicalAgg { group_keys: [0], agg_calls: [sum($1)] }
LogicalProject { exprs: [$0, ($1 * (1:Int32 - $2))], expr_alias: [ , ] }
LogicalFilter { predicate: ($3 >= '1993-01-01':Varchar::Date) AND ($3 < ('1993-01-01':Varchar::Date + '3 mons 00:00:00':Interval)) }
LogicalScan { table: lineitem, columns: [l_suppkey, l_extendedprice, l_discount, l_shipdate] }
LogicalProject { exprs: [$0], expr_alias: [max_revenue] }
LogicalAgg { group_keys: [], agg_calls: [max($0)] }
LogicalProject { exprs: [$1], expr_alias: [ ] }
LogicalAgg { group_keys: [0], agg_calls: [sum($1)] }
LogicalProject { exprs: [$0, ($1 * (1:Int32 - $2))], expr_alias: [ , ] }
LogicalFilter { predicate: ($3 >= '1993-01-01':Varchar::Date) AND ($3 < ('1993-01-01':Varchar::Date + '3 mons 00:00:00':Interval)) }
LogicalScan { table: lineitem, columns: [l_suppkey, l_extendedprice, l_discount, l_shipdate] }
batch_plan: |
BatchExchange { order: [$0 ASC], dist: Single }
BatchSort { order: [$0 ASC] }
BatchProject { exprs: [$0, $1, $2, $3, $4], expr_alias: [s_suppkey, s_name, s_address, s_phone, total_revenue] }
BatchHashJoin { type: Inner, predicate: $4 = $5 }
BatchProject { exprs: [$0, $1, $2, $3, $5], expr_alias: [ , , , , ] }
BatchExchange { order: [], dist: HashShard([5]) }
BatchHashJoin { type: Inner, predicate: $0 = $4 }
BatchExchange { order: [], dist: HashShard([0]) }
BatchScan { table: supplier, columns: [s_suppkey, s_name, s_address, s_phone] }
BatchProject { exprs: [$0, $1], expr_alias: [l_suppkey, total_revenue] }
BatchHashAgg { group_keys: [$0], aggs: [sum($1)] }
BatchProject { exprs: [$0, ($1 * (1:Int32 - $2))], expr_alias: [ , ] }
BatchExchange { order: [], dist: HashShard([0]) }
BatchFilter { predicate: ($3 >= '1993-01-01':Varchar::Date) AND ($3 < ('1993-01-01':Varchar::Date + '3 mons 00:00:00':Interval)) }
BatchScan { table: lineitem, columns: [l_suppkey, l_extendedprice, l_discount, l_shipdate] }
BatchProject { exprs: [$0], expr_alias: [max_revenue] }
BatchExchange { order: [], dist: HashShard([0]) }
BatchSimpleAgg { aggs: [max($0)] }
BatchExchange { order: [], dist: Single }
BatchProject { exprs: [$1], expr_alias: [ ] }
BatchHashAgg { group_keys: [$0], aggs: [sum($1)] }
BatchProject { exprs: [$0, ($1 * (1:Int32 - $2))], expr_alias: [ , ] }
BatchExchange { order: [], dist: HashShard([0]) }
BatchFilter { predicate: ($3 >= '1993-01-01':Varchar::Date) AND ($3 < ('1993-01-01':Varchar::Date + '3 mons 00:00:00':Interval)) }
BatchScan { table: lineitem, columns: [l_suppkey, l_extendedprice, l_discount, l_shipdate] }
stream_plan: |
StreamMaterialize { columns: [s_suppkey, s_name, s_address, s_phone, total_revenue, _row_id#0(hidden), l_suppkey(hidden), agg#0(hidden), max_revenue(hidden)], pk_columns: [_row_id#0, l_suppkey, agg#0, max_revenue], order_descs: [s_suppkey, _row_id#0, l_suppkey, agg#0, max_revenue] }
StreamExchange { dist: HashShard([5, 6, 7, 8]) }
StreamProject { exprs: [$0, $1, $2, $3, $4, $5, $6, $8, $7], expr_alias: [s_suppkey, s_name, s_address, s_phone, total_revenue, , , , ] }
StreamHashJoin { type: Inner, predicate: $4 = $7 }
StreamProject { exprs: [$0, $1, $2, $3, $6, $4, $5], expr_alias: [ , , , , , , ] }
StreamExchange { dist: HashShard([6]) }
StreamHashJoin { type: Inner, predicate: $0 = $5 }
StreamExchange { dist: HashShard([0]) }
StreamTableScan { table: supplier, columns: [s_suppkey, s_name, s_address, s_phone, _row_id#0], pk_indices: [4] }
StreamProject { exprs: [$0, $2], expr_alias: [l_suppkey, total_revenue] }
StreamHashAgg { group_keys: [$0], aggs: [count, sum($1)] }
StreamProject { exprs: [$0, ($1 * (1:Int32 - $2)), $4], expr_alias: [ , , ] }
StreamExchange { dist: HashShard([0]) }
StreamFilter { predicate: ($3 >= '1993-01-01':Varchar::Date) AND ($3 < ('1993-01-01':Varchar::Date + '3 mons 00:00:00':Interval)) }
StreamTableScan { table: lineitem, columns: [l_suppkey, l_extendedprice, l_discount, l_shipdate, _row_id#0], pk_indices: [4] }
StreamProject { exprs: [$1, $0], expr_alias: [max_revenue, ] }
StreamExchange { dist: HashShard([1]) }
StreamSimpleAgg { aggs: [count, max($0)] }
StreamExchange { dist: Single }
StreamProject { exprs: [$2, $0], expr_alias: [ , ] }
StreamHashAgg { group_keys: [$0], aggs: [count, sum($1)] }
StreamProject { exprs: [$0, ($1 * (1:Int32 - $2)), $4], expr_alias: [ , , ] }
StreamExchange { dist: HashShard([0]) }
StreamFilter { predicate: ($3 >= '1993-01-01':Varchar::Date) AND ($3 < ('1993-01-01':Varchar::Date + '3 mons 00:00:00':Interval)) }
StreamTableScan { table: lineitem, columns: [l_suppkey, l_extendedprice, l_discount, l_shipdate, _row_id#0], pk_indices: [4] }
- id: tpch_q16
before:
- create_tables
Expand Down

0 comments on commit afffdec

Please sign in to comment.