Skip to content

Commit

Permalink
fix(optimizer): ApplyAggTransposeRule should handle `CorrelatedInputRef` in agg filter (#8650)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangjinwu authored Mar 20, 2023
1 parent 32f4925 commit eddb2fc
Show file tree
Hide file tree
Showing 6 changed files with 108 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,22 @@
├─LogicalAgg { group_key: [strings.v1], aggs: [] }
| └─LogicalScan { table: strings, columns: [strings.v1] }
└─LogicalScan { table: strings, columns: [strings.v1] }
- name: issue 4762 correlated input in agg filter
sql: |
CREATE TABLE strings(v1 VARCHAR);
SELECT (SELECT STRING_AGG(v1, ',') FILTER (WHERE v1 < t.v1) FROM strings) FROM strings AS t;
optimized_logical_plan_for_batch: |
LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(strings.v1, strings.v1), output: [string_agg(strings.v1, ',':Varchar) filter((strings.v1 < strings.v1))] }
├─LogicalScan { table: strings, columns: [strings.v1] }
└─LogicalAgg { group_key: [strings.v1], aggs: [string_agg(strings.v1, ',':Varchar) filter((strings.v1 < strings.v1))] }
└─LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(strings.v1, strings.v1), output: [strings.v1, strings.v1, ',':Varchar] }
├─LogicalAgg { group_key: [strings.v1], aggs: [] }
| └─LogicalScan { table: strings, columns: [strings.v1] }
└─LogicalProject { exprs: [strings.v1, strings.v1, ',':Varchar] }
└─LogicalJoin { type: Inner, on: true, output: all }
├─LogicalAgg { group_key: [strings.v1], aggs: [] }
| └─LogicalScan { table: strings, columns: [strings.v1] }
└─LogicalScan { table: strings, columns: [strings.v1] }
- name: Existential join on outer join with correlated condition
sql: |
create table t1(x int, y int);
Expand Down
14 changes: 7 additions & 7 deletions src/frontend/src/optimizer/rule/apply_agg_transpose_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ use risingwave_common::types::DataType;
use risingwave_expr::expr::AggKind;
use risingwave_pb::plan_common::JoinType;

use super::{BoxedRule, Rule};
use super::{ApplyOffsetRewriter, BoxedRule, Rule};
use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef};
use crate::optimizer::plan_node::{LogicalAgg, LogicalApply, LogicalFilter, LogicalProject};
use crate::optimizer::PlanRef;
use crate::utils::{ColIndexMapping, Condition};
use crate::utils::Condition;

/// Transpose `LogicalApply` and `LogicalAgg`.
///
Expand Down Expand Up @@ -53,7 +53,6 @@ impl Rule for ApplyAggTransposeRule {
let agg: &LogicalAgg = right.as_logical_agg()?;
let (mut agg_calls, agg_group_key, agg_input) = agg.clone().decompose();
let is_scalar_agg = agg_group_key.is_empty();
let agg_input_len = agg_input.schema().len();
let apply_left_len = left.schema().len();

if !is_scalar_agg && max_one_row {
Expand Down Expand Up @@ -102,7 +101,7 @@ impl Rule for ApplyAggTransposeRule {
JoinType::LeftOuter,
Condition::true_cond(),
correlated_id,
correlated_indices,
correlated_indices.clone(),
false,
)
.translate_apply(left, eq_predicates)
Expand All @@ -113,7 +112,7 @@ impl Rule for ApplyAggTransposeRule {
JoinType::Inner,
Condition::true_cond(),
correlated_id,
correlated_indices,
correlated_indices.clone(),
false,
)
.into()
Expand All @@ -122,7 +121,8 @@ impl Rule for ApplyAggTransposeRule {
let group_agg = {
// shift index of agg_calls' `InputRef` with `apply_left_len`.
let offset = apply_left_len as isize;
let mut shift_index = ColIndexMapping::with_shift_offset(agg_input_len, offset);
let mut rewriter =
ApplyOffsetRewriter::new(apply_left_len, &correlated_indices, correlated_id);
agg_calls.iter_mut().for_each(|agg_call| {
agg_call.inputs.iter_mut().for_each(|input_ref| {
input_ref.shift_with_offset(offset);
Expand All @@ -131,7 +131,7 @@ impl Rule for ApplyAggTransposeRule {
.order_by
.iter_mut()
.for_each(|o| o.shift_with_offset(offset));
agg_call.filter = agg_call.filter.clone().rewrite_expr(&mut shift_index);
agg_call.filter = agg_call.filter.clone().rewrite_expr(&mut rewriter);
});
if is_scalar_agg {
// convert count(*) to count(1).
Expand Down
55 changes: 7 additions & 48 deletions src/frontend/src/optimizer/rule/apply_filter_transpose_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
use itertools::{Either, Itertools};
use risingwave_pb::plan_common::JoinType;

use super::{BoxedRule, Rule};
use crate::expr::{CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, InputRef};
use super::{ApplyOffsetRewriter, BoxedRule, Rule};
use crate::expr::ExprRewriter;
use crate::optimizer::plan_node::{LogicalApply, LogicalFilter, PlanTreeNodeUnary};
use crate::optimizer::PlanRef;
use crate::utils::{ColIndexMapping, Condition};
use crate::utils::Condition;

/// Transpose `LogicalApply` and `LogicalFilter`.
///
Expand Down Expand Up @@ -57,19 +57,8 @@ impl Rule for ApplyFilterTransposeRule {
let filter = right.as_logical_filter()?;
let input = filter.input();

let mut rewriter = Rewriter {
offset: left.schema().len(),
index_mapping: ColIndexMapping::new(
correlated_indices
.clone()
.into_iter()
.map(Some)
.collect_vec(),
)
.inverse(),
has_correlated_input_ref: false,
correlated_id,
};
let mut rewriter =
ApplyOffsetRewriter::new(left.schema().len(), &correlated_indices, correlated_id);
// Split predicates in LogicalFilter into correlated expressions and uncorrelated
// expressions.
let (cor_exprs, uncor_exprs) =
Expand All @@ -79,8 +68,8 @@ impl Rule for ApplyFilterTransposeRule {
.into_iter()
.partition_map(|expr| {
let expr = rewriter.rewrite_expr(expr);
if rewriter.has_correlated_input_ref {
rewriter.has_correlated_input_ref = false;
if rewriter.has_correlated_input_ref() {
rewriter.reset_state();
Either::Left(expr)
} else {
Either::Right(expr)
Expand Down Expand Up @@ -115,33 +104,3 @@ impl ApplyFilterTransposeRule {
Box::new(ApplyFilterTransposeRule {})
}
}

/// Expression rewriter for `ApplyFilterTransposeRule`.
///
/// A `CorrelatedInputRef` carrying `correlated_id` is replaced by an `InputRef` whose index is
/// looked up in `index_mapping`; every plain `InputRef` has `offset` added to its index.
struct Rewriter {
    /// Only `CorrelatedInputRef`s with this id are converted; others pass through unchanged.
    correlated_id: CorrelatedId,
    /// Lookup applied to the index of a matching `CorrelatedInputRef`.
    index_mapping: ColIndexMapping,
    /// Added to the index of every plain `InputRef`.
    offset: usize,
    /// Becomes `true` once a matching `CorrelatedInputRef` has been rewritten.
    has_correlated_input_ref: bool,
}
impl ExprRewriter for Rewriter {
    /// Turns a `CorrelatedInputRef` that carries our `correlated_id` into an `InputRef` and
    /// records that one was seen. A reference with any other id is returned untouched.
    fn rewrite_correlated_input_ref(
        &mut self,
        correlated_input_ref: CorrelatedInputRef,
    ) -> ExprImpl {
        let is_ours = correlated_input_ref.correlated_id() == self.correlated_id;
        if !is_ours {
            // Not ours: leave both the expression and the flag alone.
            return correlated_input_ref.into();
        }

        self.has_correlated_input_ref = true;
        let new_index = self.index_mapping.map(correlated_input_ref.index());
        InputRef::new(new_index, correlated_input_ref.return_type()).into()
    }

    /// Shifts a plain `InputRef` by `offset`, keeping its return type.
    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
        let new_index = input_ref.index() + self.offset;
        InputRef::new(new_index, input_ref.return_type()).into()
    }
}
71 changes: 71 additions & 0 deletions src/frontend/src/optimizer/rule/apply_offset_rewriter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright 2023 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use itertools::Itertools;

use crate::expr::{CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, InputRef};
use crate::utils::ColIndexMapping;

/// Shared expression rewriter for the apply-transpose rules.
///
/// It does two things while walking an expression:
/// * a `CorrelatedInputRef` whose id equals `correlated_id` becomes an `InputRef`, with its index
///   translated through `index_mapping`;
/// * every plain `InputRef` has its index increased by `offset`.
///
/// It also remembers whether any matching `CorrelatedInputRef` was encountered.
pub struct ApplyOffsetRewriter {
    /// Only `CorrelatedInputRef`s with this id are converted; others pass through unchanged.
    correlated_id: CorrelatedId,
    /// Lookup applied to the index of a matching `CorrelatedInputRef`.
    index_mapping: ColIndexMapping,
    /// Added to the index of every plain `InputRef`.
    offset: usize,
    /// Becomes `true` once a matching `CorrelatedInputRef` has been rewritten.
    has_correlated_input_ref: bool,
}

impl ExprRewriter for ApplyOffsetRewriter {
    /// Converts a `CorrelatedInputRef` with the tracked `correlated_id` into an `InputRef` whose
    /// index comes from `index_mapping`, and flags that a correlated reference was found.
    /// A reference with a different id is returned as-is and does not touch the flag.
    fn rewrite_correlated_input_ref(
        &mut self,
        correlated_input_ref: CorrelatedInputRef,
    ) -> ExprImpl {
        // Guard clause: a reference with a different id is not rewritten.
        if correlated_input_ref.correlated_id() != self.correlated_id {
            return correlated_input_ref.into();
        }

        self.has_correlated_input_ref = true;
        let mapped_index = self.index_mapping.map(correlated_input_ref.index());
        InputRef::new(mapped_index, correlated_input_ref.return_type()).into()
    }

    /// Rebuilds a plain `InputRef` with its index moved up by `offset`; the type is unchanged.
    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
        let shifted_index = self.offset + input_ref.index();
        InputRef::new(shifted_index, input_ref.return_type()).into()
    }
}

impl ApplyOffsetRewriter {
    /// Builds a rewriter.
    ///
    /// * `offset` - added to the index of every plain `InputRef`.
    /// * `correlated_indices` - wrapped into a `ColIndexMapping` which is then inverted; the
    ///   inverted mapping is what matching `CorrelatedInputRef` indices are looked up in.
    /// * `correlated_id` - id of the `CorrelatedInputRef`s that should be converted.
    ///
    /// The "seen a correlated reference" flag starts out `false`.
    pub fn new(offset: usize, correlated_indices: &[usize], correlated_id: CorrelatedId) -> Self {
        let forward_targets = correlated_indices
            .iter()
            .map(|&index| Some(index))
            .collect_vec();
        let index_mapping = ColIndexMapping::new(forward_targets).inverse();

        Self {
            offset,
            index_mapping,
            has_correlated_input_ref: false,
            correlated_id,
        }
    }

    /// Reports whether a `CorrelatedInputRef` with the tracked id has been rewritten since
    /// construction or since the last [`reset_state`](Self::reset_state) call.
    pub fn has_correlated_input_ref(&self) -> bool {
        self.has_correlated_input_ref
    }

    /// Clears the "seen a correlated reference" flag so the rewriter can be reused for the
    /// next expression.
    pub fn reset_state(&mut self) {
        self.has_correlated_input_ref = false;
    }
}
46 changes: 4 additions & 42 deletions src/frontend/src/optimizer/rule/apply_project_transpose_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
use itertools::Itertools;
use risingwave_pb::plan_common::JoinType;

use super::{BoxedRule, Rule};
use crate::expr::{CorrelatedId, CorrelatedInputRef, Expr, ExprImpl, ExprRewriter, InputRef};
use super::{ApplyOffsetRewriter, BoxedRule, Rule};
use crate::expr::{ExprImpl, ExprRewriter, InputRef};
use crate::optimizer::plan_node::{LogicalApply, LogicalProject};
use crate::optimizer::PlanRef;
use crate::utils::ColIndexMapping;

/// Transpose `LogicalApply` and `LogicalProject`.
///
Expand Down Expand Up @@ -64,18 +63,8 @@ impl Rule for ApplyProjectTransposeRule {
let (proj_exprs, proj_input) = project.clone().decompose();

// replace correlated_input_ref in project exprs
let mut rewriter = Rewriter {
offset: left.schema().len(),
index_mapping: ColIndexMapping::new(
correlated_indices
.clone()
.into_iter()
.map(Some)
.collect_vec(),
)
.inverse(),
correlated_id,
};
let mut rewriter =
ApplyOffsetRewriter::new(left.schema().len(), &correlated_indices, correlated_id);

let new_proj_exprs: Vec<ExprImpl> = proj_exprs
.into_iter()
Expand Down Expand Up @@ -124,30 +113,3 @@ impl ExprRewriter for ApplyOnConditionRewriter {
}
}
}

/// Expression rewriter for `ApplyProjectTransposeRule`.
///
/// A `CorrelatedInputRef` carrying `correlated_id` is replaced by an `InputRef` whose index is
/// looked up in `index_mapping`; every plain `InputRef` has `offset` added to its index.
struct Rewriter {
    /// Only `CorrelatedInputRef`s with this id are converted; others pass through unchanged.
    correlated_id: CorrelatedId,
    /// Lookup applied to the index of a matching `CorrelatedInputRef`.
    index_mapping: ColIndexMapping,
    /// Added to the index of every plain `InputRef`.
    offset: usize,
}
impl ExprRewriter for Rewriter {
    /// Replaces a `CorrelatedInputRef` that carries our `correlated_id` with an `InputRef` whose
    /// index is taken from `index_mapping`. A reference with any other id is returned untouched.
    fn rewrite_correlated_input_ref(
        &mut self,
        correlated_input_ref: CorrelatedInputRef,
    ) -> ExprImpl {
        // Guard clause: a reference with a different id is not rewritten.
        if correlated_input_ref.correlated_id() != self.correlated_id {
            return correlated_input_ref.into();
        }

        let mapped_index = self.index_mapping.map(correlated_input_ref.index());
        InputRef::new(mapped_index, correlated_input_ref.return_type()).into()
    }

    /// Shifts a plain `InputRef` by `offset`, keeping its return type.
    fn rewrite_input_ref(&mut self, input_ref: InputRef) -> ExprImpl {
        let shifted_index = self.offset + input_ref.index();
        InputRef::new(shifted_index, input_ref.return_type()).into()
    }
}
3 changes: 3 additions & 0 deletions src/frontend/src/optimizer/rule/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ pub use avoid_exchange_share_rule::*;
mod min_max_on_index_rule;
pub use min_max_on_index_rule::*;

mod apply_offset_rewriter;
use apply_offset_rewriter::ApplyOffsetRewriter;

#[macro_export]
macro_rules! for_all_rules {
($macro:ident) => {
Expand Down

0 comments on commit eddb2fc

Please sign in to comment.