feat(optimizer): Add StreamProjectMergeRule (risingwavelabs#8753)
Co-authored-by: lmatz <lmatz823@gmail.com>
kwannoel and lmatz authored Mar 24, 2023
1 parent 380e104 commit 98ea76a
Showing 10 changed files with 405 additions and 286 deletions.
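The StreamProjectMergeRule implementation itself lives in one of the eight files not rendered in this excerpt; what is shown below is the test-harness plumbing and the plan changes the rule produces. Conceptually, the rule collapses a StreamProject whose input is another StreamProject into a single projection, substituting the inner project's expressions for the outer project's input references. Below is a minimal, self-contained sketch of that composition using toy types; `Expr`, `substitute`, and `merge_projects` are illustrative stand-ins, not RisingWave's actual `ExprImpl` or plan-node API.

// Toy sketch: merging Project(outer) over Project(inner) into one Project
// by substituting the inner expressions into the outer's input references.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    InputRef(usize),           // column i of the child operator's output
    Mul(Box<Expr>, Box<Expr>), // one representative non-trivial expression
}

// Rewrite `expr` to read directly from the inner project's input:
// every InputRef(i) becomes the inner project's i-th expression.
fn substitute(expr: &Expr, inner: &[Expr]) -> Expr {
    match expr {
        Expr::InputRef(i) => inner[*i].clone(),
        Expr::Mul(a, b) => Expr::Mul(
            Box::new(substitute(a, inner)),
            Box::new(substitute(b, inner)),
        ),
    }
}

fn merge_projects(outer: &[Expr], inner: &[Expr]) -> Vec<Expr> {
    outer.iter().map(|e| substitute(e, inner)).collect()
}

fn main() {
    use Expr::*;
    // Mirrors the first agg.yaml hunk below: the inner project computes
    // (t.a * t.b) as $expr1; the outer forwards columns 0..=2 and would
    // additionally compute Vnode(t._row_id), omitted from the toy Expr.
    let inner = vec![
        InputRef(0), // t.a
        InputRef(1), // t.b
        Mul(Box::new(InputRef(0)), Box::new(InputRef(1))), // $expr1
    ];
    let outer = vec![InputRef(0), InputRef(1), InputRef(2)];
    let merged = merge_projects(&outer, &inner);
    assert_eq!(merged, inner); // the composed project computes (t.a * t.b) itself
    println!("{merged:?}");
}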
41 changes: 38 additions & 3 deletions src/frontend/planner_test/src/lib.rs
@@ -19,7 +19,7 @@
 mod resolve_id;

-use std::collections::BTreeMap;
+use std::collections::{BTreeMap, HashMap};
 use std::path::Path;
 use std::sync::Arc;

@@ -33,7 +33,7 @@ use risingwave_frontend::session::SessionImpl;
 use risingwave_frontend::test_utils::{create_proto_file, get_explain_output, LocalFrontend};
 use risingwave_frontend::{
     build_graph, explain_stream_graph, Binder, Explain, FrontendOpts, OptimizerContext,
-    OptimizerContextRef, PlanRef, Planner,
+    OptimizerContextRef, PlanRef, Planner, WithOptions,
 };
 use risingwave_sqlparser::ast::{ExplainOptions, ObjectName, Statement};
 use risingwave_sqlparser::parser::Parser;
@@ -83,6 +83,10 @@ pub struct TestCase {
     /// Batch plan for local execution `.gen_batch_local_plan()`
     pub batch_local_plan: Option<String>,

+    /// Create sink plan (assumes blackhole sink)
+    /// TODO: Other sinks
+    pub sink_plan: Option<String>,
+
     /// Create MV plan `.gen_create_mv_plan()`
     pub stream_plan: Option<String>,

@@ -152,6 +156,9 @@ pub struct TestCaseResult {
     /// Batch plan for local execution `.gen_batch_local_plan()`
     pub batch_local_plan: Option<String>,

+    /// Generate sink plan
+    pub sink_plan: Option<String>,
+
     /// Create MV plan `.gen_create_mv_plan()`
     pub stream_plan: Option<String>,

@@ -176,6 +183,9 @@ pub struct TestCaseResult {
     /// Error of `.gen_stream_plan()`
     pub stream_error: Option<String>,

+    /// Error of `.gen_sink_plan()`
+    pub sink_error: Option<String>,
+
     /// The result of an `EXPLAIN` statement.
     ///
     /// This field is used when `sql` is an `EXPLAIN` statement.
@@ -209,6 +219,7 @@ impl TestCaseResult {
             batch_plan: self.batch_plan,
             batch_local_plan: self.batch_local_plan,
             stream_plan: self.stream_plan,
+            sink_plan: self.sink_plan,
             batch_plan_proto: self.batch_plan_proto,
             planner_error: self.planner_error,
             optimizer_error: self.optimizer_error,
@@ -640,6 +651,30 @@ impl TestCase {
             }
         }

+        'sink: {
+            if self.sink_plan.is_some() {
+                let sink_name = "sink_test";
+                let mut options = HashMap::new();
+                options.insert("connector".to_string(), "blackhole".to_string());
+                options.insert("type".to_string(), "append-only".to_string());
+                let options = WithOptions::new(options);
+                match logical_plan.gen_sink_plan(
+                    sink_name.to_string(),
+                    format!("CREATE SINK {sink_name} AS {}", stmt),
+                    options,
+                ) {
+                    Ok(sink_plan) => {
+                        ret.sink_plan = Some(explain_plan(&sink_plan.into()));
+                        break 'sink;
+                    }
+                    Err(err) => {
+                        ret.sink_error = Some(err.to_string());
+                        break 'sink;
+                    }
+                }
+            }
+        }
+
         Ok(ret)
     }
 }
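A note on the added block: both `break 'sink` statements are the final statements on their paths, so the labeled block is not strictly required; it keeps the control flow explicit and leaves room to append further steps (for instance, the other sink types mentioned in the TODO) without restructuring the match.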
@@ -696,7 +731,7 @@ fn check_result(expected: &TestCase, actual: &TestCaseResult) -> Result<()> {
         &expected.explain_output,
         &actual.explain_output,
     )?;
-
+    check_option_plan_eq("sink_plan", &expected.sink_plan, &actual.sink_plan)?;
     Ok(())
 }

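With this plumbing in place, a planner test can pin down the sink plan alongside the other plan variants. A hypothetical testdata entry follows; the `sink_plan` key and the blackhole/append-only defaults come from this commit, while the SQL and the expected plan text (including the exact rendering of the sink node) are illustrative guesses:

- sql: |
    create table t (v1 int);
    select v1 from t;
  sink_plan: |
    StreamSink { type: append-only, columns: [v1, t._row_id(hidden)] }
    └─StreamTableScan { table: t, columns: [t.v1, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }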
25 changes: 11 additions & 14 deletions src/frontend/planner_test/tests/testdata/agg.yaml
@@ -577,9 +577,8 @@
       └─StreamGlobalSimpleAgg { aggs: [max(max($expr1) filter((t.a < t.b) AND ((t.a + t.b) < 100:Int32) AND ((t.a * t.b) <> ((t.a + t.b) - 1:Int32)))), count] }
         └─StreamExchange { dist: Single }
           └─StreamHashAgg { group_key: [$expr2], aggs: [max($expr1) filter((t.a < t.b) AND ((t.a + t.b) < 100:Int32) AND ((t.a * t.b) <> ((t.a + t.b) - 1:Int32))), count] }
-            └─StreamProject { exprs: [t.a, t.b, $expr1, t._row_id, Vnode(t._row_id) as $expr2] }
-              └─StreamProject { exprs: [t.a, t.b, (t.a * t.b) as $expr1, t._row_id] }
-                └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
+            └─StreamProject { exprs: [t.a, t.b, (t.a * t.b) as $expr1, t._row_id, Vnode(t._row_id) as $expr2] }
+              └─StreamTableScan { table: t, columns: [t.a, t.b, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
 - name: avg filter clause + group by
   sql: |
     create table t(a int, b int);
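The hunk above shows the rule's effect directly: the adjacent StreamProject pair under StreamHashAgg is composed into a single projection, with `(t.a * t.b) as $expr1` now computed in the same operator as `Vnode(t._row_id) as $expr2`, shortening the pipeline by one operator. The two hunks below follow the same pattern.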
@@ -1139,9 +1138,8 @@
     └─StreamExchange { dist: HashShard(lineitem.l_commitdate) }
       └─StreamHashAgg { group_key: [lineitem.l_commitdate, $expr1], aggs: [max(lineitem.l_commitdate), count] }
         └─StreamProject { exprs: [lineitem.l_tax, lineitem.l_shipinstruct, lineitem.l_orderkey, lineitem.l_commitdate, Vnode(lineitem.l_orderkey) as $expr1] }
-          └─StreamProject { exprs: [lineitem.l_tax, lineitem.l_shipinstruct, lineitem.l_orderkey, lineitem.l_commitdate] }
-            └─StreamHashAgg { group_key: [lineitem.l_tax, lineitem.l_shipinstruct, lineitem.l_orderkey, lineitem.l_commitdate], aggs: [count] }
-              └─StreamTableScan { table: lineitem, columns: [lineitem.l_orderkey, lineitem.l_tax, lineitem.l_commitdate, lineitem.l_shipinstruct], pk: [lineitem.l_orderkey], dist: UpstreamHashShard(lineitem.l_orderkey) }
+          └─StreamHashAgg { group_key: [lineitem.l_tax, lineitem.l_shipinstruct, lineitem.l_orderkey, lineitem.l_commitdate], aggs: [count] }
+            └─StreamTableScan { table: lineitem, columns: [lineitem.l_orderkey, lineitem.l_tax, lineitem.l_commitdate, lineitem.l_shipinstruct], pk: [lineitem.l_orderkey], dist: UpstreamHashShard(lineitem.l_orderkey) }
 - name: two phase agg on hop window input should use two phase agg
   sql: |
     SET QUERY_MODE TO DISTRIBUTED;
@@ -1180,14 +1178,13 @@
     └─StreamExchange { dist: HashShard(window_start) }
       └─StreamHashAgg { group_key: [window_start, $expr2], aggs: [max(sum0(count)), count] }
         └─StreamProject { exprs: [bid.auction, window_start, sum0(count), Vnode(bid.auction, window_start) as $expr2] }
-          └─StreamProject { exprs: [bid.auction, window_start, sum0(count)] }
-            └─StreamHashAgg { group_key: [bid.auction, window_start], aggs: [sum0(count), count] }
-              └─StreamExchange { dist: HashShard(bid.auction, window_start) }
-                └─StreamHashAgg { group_key: [bid.auction, window_start, $expr1], aggs: [count] }
-                  └─StreamProject { exprs: [bid.auction, window_start, bid._row_id, Vnode(bid._row_id) as $expr1] }
-                    └─StreamHopWindow { time_col: bid.date_time, slide: 00:00:02, size: 00:00:10, output: [bid.auction, window_start, bid._row_id] }
-                      └─StreamFilter { predicate: IsNotNull(bid.date_time) }
-                        └─StreamTableScan { table: bid, columns: [bid.date_time, bid.auction, bid._row_id], pk: [bid._row_id], dist: UpstreamHashShard(bid._row_id) }
+          └─StreamHashAgg { group_key: [bid.auction, window_start], aggs: [sum0(count), count] }
+            └─StreamExchange { dist: HashShard(bid.auction, window_start) }
+              └─StreamHashAgg { group_key: [bid.auction, window_start, $expr1], aggs: [count] }
+                └─StreamProject { exprs: [bid.auction, window_start, bid._row_id, Vnode(bid._row_id) as $expr1] }
+                  └─StreamHopWindow { time_col: bid.date_time, slide: 00:00:02, size: 00:00:10, output: [bid.auction, window_start, bid._row_id] }
+                    └─StreamFilter { predicate: IsNotNull(bid.date_time) }
+                      └─StreamTableScan { table: bid, columns: [bid.date_time, bid.auction, bid._row_id], pk: [bid._row_id], dist: UpstreamHashShard(bid._row_id) }
 - name: two phase agg with stream SomeShard (via index) but pk satisfies output dist should use shuffle agg
   sql: |
     SET QUERY_MODE TO DISTRIBUTED;
(Diffs for the remaining 8 changed files are not rendered in this excerpt.)
