From a93418501fc269a2400ccff18c6290c6c032d2c4 Mon Sep 17 00:00:00 2001 From: Dylan Date: Thu, 14 Sep 2023 17:31:04 +0800 Subject: [PATCH] fix(optimizer): relax scan predicate pull up mapping inverse restriction (#12308) --- .../testdata/input/lateral_subquery.yaml | 12 +++++++ .../testdata/output/lateral_subquery.yaml | 35 +++++++++++++++++++ .../src/optimizer/plan_node/logical_scan.rs | 21 +++++++---- 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/frontend/planner_test/tests/testdata/input/lateral_subquery.yaml b/src/frontend/planner_test/tests/testdata/input/lateral_subquery.yaml index e1215c828cc7f..d17cb92dc8577 100644 --- a/src/frontend/planner_test/tests/testdata/input/lateral_subquery.yaml +++ b/src/frontend/planner_test/tests/testdata/input/lateral_subquery.yaml @@ -92,3 +92,15 @@ expected_outputs: - batch_plan - stream_plan +- name: https://github.com/risingwavelabs/risingwave/issues/12298 + sql: | + create table t1(c varchar, n varchar, id varchar, d varchar); + create table t2(c varchar, p varchar, id varchar, d varchar); + select array_agg(t1.n order by path_idx) from t1 + join t2 + on t1.c = 'abc' + and t2.c = 'abc' + cross join unnest((case when t2.p <> '' then (string_to_array(trim(t2.p, ','), ',') || t2.d) else ARRAY[t2.d] end)) WITH ORDINALITY AS path_cols(path_val, path_idx) + where path_val = t1.id; + expected_outputs: + - stream_plan diff --git a/src/frontend/planner_test/tests/testdata/output/lateral_subquery.yaml b/src/frontend/planner_test/tests/testdata/output/lateral_subquery.yaml index b67362ea5da4f..2f72bc6d4f4a2 100644 --- a/src/frontend/planner_test/tests/testdata/output/lateral_subquery.yaml +++ b/src/frontend/planner_test/tests/testdata/output/lateral_subquery.yaml @@ -173,3 +173,38 @@ └─StreamHashAgg { group_key: [t.arr], aggs: [count] } └─StreamExchange { dist: HashShard(t.arr) } └─StreamTableScan { table: t, columns: [t.arr, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) } +- name: https://github.com/risingwavelabs/risingwave/issues/12298 + sql: | + create table t1(c varchar, n varchar, id varchar, d varchar); + create table t2(c varchar, p varchar, id varchar, d varchar); + select array_agg(t1.n order by path_idx) from t1 + join t2 + on t1.c = 'abc' + and t2.c = 'abc' + cross join unnest((case when t2.p <> '' then (string_to_array(trim(t2.p, ','), ',') || t2.d) else ARRAY[t2.d] end)) WITH ORDINALITY AS path_cols(path_val, path_idx) + where path_val = t1.id; + stream_plan: |- + StreamMaterialize { columns: [array_agg], stream_key: [], pk_columns: [], pk_conflict: NoCheck } + └─StreamProject { exprs: [array_agg(t1.n order_by($expr1 ASC))] } + └─StreamSimpleAgg { aggs: [array_agg(t1.n order_by($expr1 ASC)), count] } + └─StreamExchange { dist: Single } + └─StreamProject { exprs: [t1.n, (projected_row_id + 1:Int64) as $expr1, t1._row_id, t2.p, t2.p, t2.d, t2.d, projected_row_id, t1.id, t2._row_id] } + └─StreamHashJoin { type: Inner, predicate: t2.p IS NOT DISTINCT FROM t2.p AND t2.p IS NOT DISTINCT FROM t2.p AND t2.d IS NOT DISTINCT FROM t2.d AND t2.d IS NOT DISTINCT FROM t2.d, output: [t1.n, t1.id, projected_row_id, t2.p, t2.p, t2.d, t2.d, Unnest(Case(($1 <> '':Varchar), ArrayAppend(StringToArray(Trim($1, ',':Varchar), ',':Varchar), $3), Array($3))), t2.p, t2.d, t1._row_id, t2._row_id] } + ├─StreamExchange { dist: HashShard(t2.p, t2.d) } + │ └─StreamHashJoin { type: Inner, predicate: t1.id = Unnest(Case(($1 <> '':Varchar), ArrayAppend(StringToArray(Trim($1, ',':Varchar), ',':Varchar), $3), Array($3))), output: [t1.n, 
t1.id, projected_row_id, t2.p, t2.p, t2.d, t2.d, Unnest(Case(($1 <> '':Varchar), ArrayAppend(StringToArray(Trim($1, ',':Varchar), ',':Varchar), $3), Array($3))), t1._row_id] } + │ ├─StreamExchange { dist: HashShard(t1.id) } + │ │ └─StreamProject { exprs: [t1.n, t1.id, t1._row_id] } + │ │ └─StreamFilter { predicate: (t1.c = 'abc':Varchar) } + │ │ └─StreamTableScan { table: t1, columns: [t1.n, t1.id, t1._row_id, t1.c], pk: [t1._row_id], dist: UpstreamHashShard(t1._row_id) } + │ └─StreamExchange { dist: HashShard(Unnest(Case(($1 <> '':Varchar), ArrayAppend(StringToArray(Trim($1, ',':Varchar), ',':Varchar), $3), Array($3)))) } + │ └─StreamProjectSet { select_list: [$0, $1, $2, $3, Unnest(Case(($1 <> '':Varchar), ArrayAppend(StringToArray(Trim($1, ',':Varchar), ',':Varchar), $3), Array($3)))] } + │ └─StreamProject { exprs: [t2.p, t2.p, t2.d, t2.d] } + │ └─StreamHashAgg { group_key: [t2.p, t2.p, t2.d, t2.d], aggs: [count] } + │ └─StreamExchange { dist: HashShard(t2.p, t2.p, t2.d, t2.d) } + │ └─StreamProject { exprs: [t2.p, t2.p, t2.d, t2.d, t2._row_id] } + │ └─StreamFilter { predicate: (t2.c = 'abc':Varchar) } + │ └─StreamTableScan { table: t2, columns: [t2.p, t2.p, t2.d, t2.d, t2._row_id, t2.c], pk: [t2._row_id], dist: UpstreamHashShard(t2._row_id) } + └─StreamExchange { dist: HashShard(t2.p, t2.d) } + └─StreamProject { exprs: [t2.p, t2.d, t2._row_id] } + └─StreamFilter { predicate: (t2.c = 'abc':Varchar) } + └─StreamTableScan { table: t2, columns: [t2.p, t2.d, t2._row_id, t2.c], pk: [t2._row_id], dist: UpstreamHashShard(t2._row_id) } diff --git a/src/frontend/src/optimizer/plan_node/logical_scan.rs b/src/frontend/src/optimizer/plan_node/logical_scan.rs index d7574abed7b29..e671f7412c661 100644 --- a/src/frontend/src/optimizer/plan_node/logical_scan.rs +++ b/src/frontend/src/optimizer/plan_node/logical_scan.rs @@ -232,13 +232,20 @@ impl LogicalScan { return (self.core.clone(), Condition::true_cond(), None); } - let mut mapping = ColIndexMapping::with_target_size( - self.required_col_idx().iter().map(|i| Some(*i)).collect(), - self.table_desc().columns.len(), - ) - .inverse() - .expect("must be invertible"); - predicate = predicate.rewrite_expr(&mut mapping); + let mut inverse_mapping = { + let mapping = ColIndexMapping::with_target_size( + self.required_col_idx().iter().map(|i| Some(*i)).collect(), + self.table_desc().columns.len(), + ); + // Since `required_col_idx` mapping is not invertible, we need to inverse manually. + let mut inverse_map = vec![None; mapping.target_size()]; + for (src, dst) in mapping.mapping_pairs() { + inverse_map[dst] = Some(src); + } + ColIndexMapping::with_target_size(inverse_map, mapping.source_size()) + }; + + predicate = predicate.rewrite_expr(&mut inverse_mapping); let scan_without_predicate = generic::Scan::new( self.table_name().to_string(),
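
Why the relaxation is needed: the plan above scans duplicated columns (t2.p and t2.d each appear twice in the StreamTableScan), so `required_col_idx` can contain the same table column more than once. The forward mapping built from it is then not injective, `ColIndexMapping::inverse()` has no strict inverse to return, and the removed `.expect("must be invertible")` would panic, which appears to be the failure behind the linked issue #12298.

The sketch below is a minimal, self-contained illustration of the change. `ToyColIndexMapping` is a hypothetical stand-in that only mimics the `ColIndexMapping` methods visible in this diff (`with_target_size`, `mapping_pairs`, `source_size`, `target_size`); it is not the real frontend API. `inverse` models the old strict behavior, `inverse_relaxed` models the manual inversion added by this patch.

```rust
// Hypothetical, simplified stand-in for ColIndexMapping (not the real
// risingwave_frontend type); it maps source index -> Option<target index>.
struct ToyColIndexMapping {
    map: Vec<Option<usize>>,
    target_size: usize,
}

impl ToyColIndexMapping {
    fn with_target_size(map: Vec<Option<usize>>, target_size: usize) -> Self {
        Self { map, target_size }
    }

    fn source_size(&self) -> usize {
        self.map.len()
    }

    fn target_size(&self) -> usize {
        self.target_size
    }

    /// All (source, target) pairs that are actually mapped.
    fn mapping_pairs(&self) -> impl Iterator<Item = (usize, usize)> + '_ {
        self.map
            .iter()
            .copied()
            .enumerate()
            .filter_map(|(src, dst)| dst.map(|dst| (src, dst)))
    }

    /// Strict inverse: fails as soon as two sources share a target.
    /// This is the case the removed `.expect("must be invertible")` tripped over.
    fn inverse(&self) -> Option<Self> {
        let mut inverse_map = vec![None; self.target_size()];
        for (src, dst) in self.mapping_pairs() {
            if inverse_map[dst].is_some() {
                return None; // duplicated target column -> not invertible
            }
            inverse_map[dst] = Some(src);
        }
        Some(Self::with_target_size(inverse_map, self.source_size()))
    }

    /// Relaxed inverse as in the patched predicate pull-up path:
    /// duplicates no longer fail, the last source written for a target wins.
    fn inverse_relaxed(&self) -> Self {
        let mut inverse_map = vec![None; self.target_size()];
        for (src, dst) in self.mapping_pairs() {
            inverse_map[dst] = Some(src);
        }
        Self::with_target_size(inverse_map, self.source_size())
    }
}

fn main() {
    // required_col_idx = [1, 1, 3, 3] over a 4-column table: columns 1 and 3 are
    // each scanned twice, mirroring the duplicated t2.p / t2.d columns above.
    let mapping =
        ToyColIndexMapping::with_target_size(vec![Some(1), Some(1), Some(3), Some(3)], 4);

    // Old behavior: no strict inverse exists, so the `.expect(...)` would panic.
    assert!(mapping.inverse().is_none());

    // New behavior: build the inverse manually; later occurrences overwrite earlier ones.
    let inverse = mapping.inverse_relaxed();
    assert_eq!(inverse.map, vec![None, Some(1), None, Some(3)]);
}
```

Letting the last duplicate win is acceptable for rewriting the pulled-up predicate, since the duplicated positions all read the same underlying table column, so the rewritten expression is equivalent whichever occurrence is chosen.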