Skip to content

Commit

Permalink
fix(planner): stddev/var rewriter panic due to mix use of pre/pos…
Browse files Browse the repository at this point in the history
…t project index (risingwavelabs#9081)
  • Loading branch information
xiangjinwu authored Apr 10, 2023
1 parent a3fc14a commit fd38f3d
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 1 deletion.
10 changes: 10 additions & 0 deletions e2e_test/batch/aggregate/stddev_and_variance.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,13 @@ select stddev_pop(v), stddev_samp(v), var_pop(v), var_samp(v) from t

statement ok
drop table t

statement ok
create table t(v int, w float);

query R
select stddev_samp(v) from t group by w;
----

statement ok
drop table t;
17 changes: 17 additions & 0 deletions src/frontend/planner_test/tests/testdata/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1040,6 +1040,23 @@
└─StreamStatelessLocalSimpleAgg { aggs: [sum($expr1), sum(t.v1), count(t.v1)] }
└─StreamProject { exprs: [t.v1, (t.v1 * t.v1) as $expr1, t._row_id] }
└─StreamTableScan { table: t, columns: [t.v1, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: stddev_samp with other columns
sql: |
select count(''), stddev_samp(1);
logical_plan: |
LogicalProject { exprs: [count('':Varchar), Case((count(1:Int32) <= 1:Int64), null:Decimal::Float64, Pow(((sum($expr1)::Decimal - ((sum(1:Int32)::Decimal * sum(1:Int32)::Decimal) / count(1:Int32))) / (count(1:Int32) - 1:Int64))::Float64, 0.5:Float64)) as $expr2] }
└─LogicalAgg { aggs: [count('':Varchar), sum($expr1), sum(1:Int32), count(1:Int32)] }
└─LogicalProject { exprs: ['':Varchar, 1:Int32, (1:Int32 * 1:Int32) as $expr1] }
└─LogicalValues { rows: [[]], schema: Schema { fields: [] } }
- name: stddev_samp with group
sql: |
create table t(v int, w float);
select stddev_samp(v) from t group by w;
logical_plan: |
LogicalProject { exprs: [Case((count(t.v) <= 1:Int64), null:Decimal::Float64, Pow(((sum($expr1)::Decimal - ((sum(t.v)::Decimal * sum(t.v)::Decimal) / count(t.v))) / (count(t.v) - 1:Int64))::Float64, 0.5:Float64)) as $expr2] }
└─LogicalAgg { group_key: [t.w], aggs: [sum($expr1), sum(t.v), count(t.v)] }
└─LogicalProject { exprs: [t.w, t.v, (t.v * t.v) as $expr1] }
└─LogicalScan { table: t, columns: [t.v, t.w, t._row_id] }
- name: force two phase aggregation should succeed with UpstreamHashShard and SomeShard (batch only).
sql: |
SET QUERY_MODE TO DISTRIBUTED;
Expand Down
4 changes: 4 additions & 0 deletions src/frontend/src/optimizer/plan_node/generic/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ impl ProjectBuilder {
}
}

pub fn get_expr(&self, index: usize) -> Option<&ExprImpl> {
self.exprs.get(index)
}

pub fn expr_index(&self, expr: &ExprImpl) -> Option<usize> {
check_expr_type(expr).ok()?;
self.exprs_index.get(expr).copied()
Expand Down
3 changes: 2 additions & 1 deletion src/frontend/src/optimizer/plan_node/logical_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,12 +617,13 @@ impl LogicalAggBuilder {
// use pow(x, 0.5) to simulate
AggKind::StddevPop | AggKind::StddevSamp | AggKind::VarPop | AggKind::VarSamp => {
let input = inputs.iter().exactly_one().unwrap();
let pre_proj_input = self.input_proj_builder.get_expr(input.index).unwrap();

// first, we compute sum of squared as sum_sq
let squared_input_expr = ExprImpl::from(
FunctionCall::new(
ExprType::Multiply,
vec![ExprImpl::from(input.clone()), ExprImpl::from(input.clone())],
vec![pre_proj_input.clone(), pre_proj_input.clone()],
)
.unwrap(),
);
Expand Down

0 comments on commit fd38f3d

Please sign in to comment.