From 43dc7b32b8db7adf55b8fab29b123cbdcb946dc9 Mon Sep 17 00:00:00 2001 From: stonepage <40830455+st1page@users.noreply.github.com> Date: Wed, 30 Mar 2022 20:43:43 +0800 Subject: [PATCH] feat(optimizer): filter elim (#1430) * add filter if need * filter only if need * fix * fix * rename * clippy fix --- .../src/optimizer/plan_node/logical_filter.rs | 14 +++++++++++--- rust/frontend/src/optimizer/rule/filter_agg.rs | 11 ++++------- rust/frontend/src/optimizer/rule/filter_join.rs | 10 +++------- rust/frontend/src/optimizer/rule/filter_project.rs | 4 ++-- rust/frontend/src/planner/delete.rs | 2 +- rust/frontend/src/planner/select.rs | 7 +++---- .../test_runner/tests/testdata/column_pruning.yaml | 3 +-- rust/frontend/test_runner/tests/testdata/tpch.yaml | 12 ++++-------- 8 files changed, 29 insertions(+), 34 deletions(-) diff --git a/rust/frontend/src/optimizer/plan_node/logical_filter.rs b/rust/frontend/src/optimizer/plan_node/logical_filter.rs index 1c429db902daf..ff1823a5dfe90 100644 --- a/rust/frontend/src/optimizer/plan_node/logical_filter.rs +++ b/rust/frontend/src/optimizer/plan_node/logical_filter.rs @@ -15,7 +15,6 @@ use std::fmt; use fixedbitset::FixedBitSet; -use risingwave_common::error::Result; use super::{ ColPrunable, CollectInputRef, LogicalProject, PlanBase, PlanNode, PlanRef, PlanTreeNodeUnary, @@ -52,10 +51,19 @@ impl LogicalFilter { } } + /// Create a `LogicalFilter` unless the predicate is always true + pub fn create(input: PlanRef, predicate: Condition) -> PlanRef { + if predicate.always_true() { + input + } else { + LogicalFilter::new(input, predicate).into() + } + } + /// the function will check if the predicate is bool expression - pub fn create(input: PlanRef, predicate: ExprImpl) -> Result { + pub fn create_with_expr(input: PlanRef, predicate: ExprImpl) -> PlanRef { let predicate = Condition::with_expr(predicate); - Ok(Self::new(input, predicate).into()) + Self::new(input, predicate).into() } /// Get the predicate of the logical join. diff --git a/rust/frontend/src/optimizer/rule/filter_agg.rs b/rust/frontend/src/optimizer/rule/filter_agg.rs index 86af3c6974329..2ab79ff969f0c 100644 --- a/rust/frontend/src/optimizer/rule/filter_agg.rs +++ b/rust/frontend/src/optimizer/rule/filter_agg.rs @@ -53,13 +53,10 @@ impl Rule for FilterAggRule { let pushed_predicate = pushed_predicate.rewrite_expr(&mut subst); let input = agg.input(); - let pushed_filter = LogicalFilter::new(input, pushed_predicate); - let new_agg = agg.clone_with_input(pushed_filter.into()).into(); - if agg_call_pred.always_true() { - Some(new_agg) - } else { - Some(LogicalFilter::new(new_agg, agg_call_pred).into()) - } + let pushed_filter = LogicalFilter::create(input, pushed_predicate); + let new_agg = agg.clone_with_input(pushed_filter).into(); + + Some(LogicalFilter::create(new_agg, agg_call_pred)) } } diff --git a/rust/frontend/src/optimizer/rule/filter_join.rs b/rust/frontend/src/optimizer/rule/filter_join.rs index efd6efa11e971..73e351ba205f1 100644 --- a/rust/frontend/src/optimizer/rule/filter_join.rs +++ b/rust/frontend/src/optimizer/rule/filter_join.rs @@ -84,22 +84,18 @@ impl Rule for FilterJoinRule { let right_predicate = right_from_filter.and_then(|c1| right_from_on.map(|c2| c1.and(c2))); let new_left: PlanRef = if let Some(predicate) = left_predicate { - LogicalFilter::new(join.left(), predicate).into() + LogicalFilter::create(join.left(), predicate) } else { join.left() }; let new_right: PlanRef = if let Some(predicate) = right_predicate { - LogicalFilter::new(join.right(), predicate).into() + LogicalFilter::create(join.right(), predicate) } else { join.right() }; let new_join = LogicalJoin::new(new_left, new_right, join_type, new_on); - if new_filter_predicate.always_true() { - Some(new_join.into()) - } else { - Some(LogicalFilter::new(new_join.into(), new_filter_predicate).into()) - } + Some(LogicalFilter::create(new_join.into(), new_filter_predicate)) } } diff --git a/rust/frontend/src/optimizer/rule/filter_project.rs b/rust/frontend/src/optimizer/rule/filter_project.rs index 2a5d2f63f1989..691349495a52c 100644 --- a/rust/frontend/src/optimizer/rule/filter_project.rs +++ b/rust/frontend/src/optimizer/rule/filter_project.rs @@ -31,8 +31,8 @@ impl Rule for FilterProjectRule { let predicate = filter.predicate().clone().rewrite_expr(&mut subst); let input = project.input(); - let pushed_filter = LogicalFilter::new(input, predicate); - Some(project.clone_with_input(pushed_filter.into()).into()) + let pushed_filter = LogicalFilter::create(input, predicate); + Some(project.clone_with_input(pushed_filter).into()) } } diff --git a/rust/frontend/src/planner/delete.rs b/rust/frontend/src/planner/delete.rs index 6854ccf86f2c6..1d0b3558f167a 100644 --- a/rust/frontend/src/planner/delete.rs +++ b/rust/frontend/src/planner/delete.rs @@ -27,7 +27,7 @@ impl Planner { let source_id = delete.table_source.source_id; let scan = self.plan_base_table(delete.table)?; let input = if let Some(expr) = delete.selection { - LogicalFilter::create(scan, expr)? + LogicalFilter::create_with_expr(scan, expr) } else { scan }; diff --git a/rust/frontend/src/planner/select.rs b/rust/frontend/src/planner/select.rs index aca171aae5fde..a0742c4418a57 100644 --- a/rust/frontend/src/planner/select.rs +++ b/rust/frontend/src/planner/select.rs @@ -95,7 +95,7 @@ impl Planner { /// [`LogicalJoin`] using [`substitute_subqueries`]. fn plan_where(&mut self, mut input: PlanRef, where_clause: ExprImpl) -> Result { if !where_clause.has_subquery() { - return LogicalFilter::create(input, where_clause); + return Ok(LogicalFilter::create_with_expr(input, where_clause)); } let (subquery_conjunctions, not_subquery_conjunctions, others) = @@ -154,13 +154,12 @@ impl Planner { Ok(input) } else { let (input, others) = self.substitute_subqueries(input, others.conjunctions)?; - Ok(LogicalFilter::new( + Ok(LogicalFilter::create( input, Condition { conjunctions: others, }, - ) - .into()) + )) } } diff --git a/rust/frontend/test_runner/tests/testdata/column_pruning.yaml b/rust/frontend/test_runner/tests/testdata/column_pruning.yaml index 825644b83788f..46184f82bf778 100644 --- a/rust/frontend/test_runner/tests/testdata/column_pruning.yaml +++ b/rust/frontend/test_runner/tests/testdata/column_pruning.yaml @@ -115,8 +115,7 @@ LogicalProject { exprs: [$0, $1], expr_alias: [ , ] } LogicalFilter { predicate: ($2 < 1:Int32) } LogicalScan { table: t1, columns: [v1, v2, v3] } - LogicalFilter { predicate: } - LogicalScan { table: t2, columns: [v1, v2] } + LogicalScan { table: t2, columns: [v1, v2] } - sql: | /* mixed */ create table t (v1 bigint, v2 double precision, v3 int); diff --git a/rust/frontend/test_runner/tests/testdata/tpch.yaml b/rust/frontend/test_runner/tests/testdata/tpch.yaml index bf46ae911f4e3..a3ab7f6f66170 100644 --- a/rust/frontend/test_runner/tests/testdata/tpch.yaml +++ b/rust/frontend/test_runner/tests/testdata/tpch.yaml @@ -125,11 +125,9 @@ BatchFilter { predicate: ($1 = "FURNITURE":Varchar) AND true:Boolean AND true:Boolean } BatchScan { table: customer, columns: [c_custkey, c_mktsegment] } BatchExchange { order: [], dist: HashShard([1]) } - BatchFilter { predicate: } - BatchScan { table: orders, columns: [o_orderkey, o_custkey, o_orderdate, o_shippriority] } + BatchScan { table: orders, columns: [o_orderkey, o_custkey, o_orderdate, o_shippriority] } BatchExchange { order: [], dist: HashShard([0]) } - BatchFilter { predicate: } - BatchScan { table: lineitem, columns: [l_orderkey, l_extendedprice, l_discount] } + BatchScan { table: lineitem, columns: [l_orderkey, l_extendedprice, l_discount] } stream_plan: | StreamMaterialize { columns: [l_orderkey, revenue, o_orderdate, o_shippriority], pk_columns: [revenue, o_orderdate, l_orderkey, o_shippriority] } StreamProject { exprs: [$0, $4, $1, $2], expr_alias: [l_orderkey, revenue, o_orderdate, o_shippriority] } @@ -145,11 +143,9 @@ StreamFilter { predicate: ($1 = "FURNITURE":Varchar) AND true:Boolean AND true:Boolean } StreamTableScan { table: customer, columns: [c_custkey, c_mktsegment, _row_id#0], pk_indices: [2] } StreamExchange { dist: HashShard([1]) } - StreamFilter { predicate: } - StreamTableScan { table: orders, columns: [o_orderkey, o_custkey, o_orderdate, o_shippriority, _row_id#0], pk_indices: [4] } + StreamTableScan { table: orders, columns: [o_orderkey, o_custkey, o_orderdate, o_shippriority, _row_id#0], pk_indices: [4] } StreamExchange { dist: HashShard([0]) } - StreamFilter { predicate: } - StreamTableScan { table: lineitem, columns: [l_orderkey, l_extendedprice, l_discount, _row_id#0], pk_indices: [3] } + StreamTableScan { table: lineitem, columns: [l_orderkey, l_extendedprice, l_discount, _row_id#0], pk_indices: [3] } - id: tpch_q6 before: - create_tables