From 30794414c25fbdf358d0e2c31b7cbeb61e71be2d Mon Sep 17 00:00:00 2001 From: xiangjinwu <17769960+xiangjinwu@users.noreply.github.com> Date: Tue, 21 Mar 2023 13:16:49 +0800 Subject: [PATCH] fix(optimizer): `PlanCorrelatedIdFinder` should be aware of agg filter (#8667) --- .../testdata/subquery_expr_correlated.yaml | 17 ++++++++++++++++ .../plan_visitor/plan_correlated_id_finder.rs | 20 ++++++++++++++++--- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/subquery_expr_correlated.yaml b/src/frontend/planner_test/tests/testdata/subquery_expr_correlated.yaml index 7de33fc12a1a..22386033deee 100644 --- a/src/frontend/planner_test/tests/testdata/subquery_expr_correlated.yaml +++ b/src/frontend/planner_test/tests/testdata/subquery_expr_correlated.yaml @@ -755,6 +755,23 @@ ├─LogicalAgg { group_key: [strings.v1], aggs: [] } | └─LogicalScan { table: strings, columns: [strings.v1] } └─LogicalScan { table: strings, columns: [strings.v1] } +- name: issue 7574 correlated input in agg filter in having + sql: | + CREATE TABLE strings(v1 VARCHAR); + SELECT (SELECT 1 FROM strings HAVING COUNT(v1) FILTER (WHERE t.v1 < 'b') > 2) FROM strings AS t; + optimized_logical_plan_for_batch: | + LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(strings.v1, strings.v1), output: [1:Int32] } + ├─LogicalScan { table: strings, columns: [strings.v1] } + └─LogicalProject { exprs: [strings.v1, 1:Int32] } + └─LogicalFilter { predicate: (count(strings.v1) filter((strings.v1 < 'b':Varchar)) > 2:Int32) } + └─LogicalAgg { group_key: [strings.v1], aggs: [count(strings.v1) filter((strings.v1 < 'b':Varchar))] } + └─LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(strings.v1, strings.v1), output: [strings.v1, strings.v1] } + ├─LogicalAgg { group_key: [strings.v1], aggs: [] } + | └─LogicalScan { table: strings, columns: [strings.v1] } + └─LogicalJoin { type: Inner, on: true, output: all } + ├─LogicalAgg { group_key: [strings.v1], aggs: [] } + | └─LogicalScan { table: strings, columns: [strings.v1] } + └─LogicalScan { table: strings, columns: [strings.v1] } - name: Existential join on outer join with correlated condition sql: | create table t1(x int, y int); diff --git a/src/frontend/src/optimizer/plan_visitor/plan_correlated_id_finder.rs b/src/frontend/src/optimizer/plan_visitor/plan_correlated_id_finder.rs index 3929da18900d..5ed18cccedd7 100644 --- a/src/frontend/src/optimizer/plan_visitor/plan_correlated_id_finder.rs +++ b/src/frontend/src/optimizer/plan_visitor/plan_correlated_id_finder.rs @@ -15,7 +15,9 @@ use std::collections::HashSet; use crate::expr::{CorrelatedId, CorrelatedInputRef, ExprVisitor}; -use crate::optimizer::plan_node::{LogicalFilter, LogicalJoin, LogicalProject, PlanTreeNode}; +use crate::optimizer::plan_node::{ + LogicalAgg, LogicalFilter, LogicalJoin, LogicalProject, PlanTreeNode, +}; use crate::optimizer::plan_visitor::PlanVisitor; use crate::PlanRef; @@ -37,8 +39,8 @@ impl PlanCorrelatedIdFinder { } impl PlanVisitor<()> for PlanCorrelatedIdFinder { - /// `correlated_input_ref` can only appear in `LogicalProject`, `LogicalFilter` and - /// `LogicalJoin` now. + /// `correlated_input_ref` can only appear in `LogicalProject`, `LogicalFilter`, + /// `LogicalJoin` or the `filter` clause of `PlanAggCall` of `LogicalAgg` now. fn merge(_: (), _: ()) {} @@ -71,6 +73,18 @@ impl PlanVisitor<()> for PlanCorrelatedIdFinder { .into_iter() .for_each(|input| self.visit(input)); } + + fn visit_logical_agg(&mut self, plan: &LogicalAgg) { + let mut finder = ExprCorrelatedIdFinder::default(); + plan.agg_calls() + .iter() + .for_each(|agg_call| agg_call.filter.visit_expr(&mut finder)); + self.correlated_id_set.extend(finder.correlated_id_set); + + plan.inputs() + .into_iter() + .for_each(|input| self.visit(input)); + } } #[derive(Default)]