Skip to content

Commit

Permalink
feat(binder): Support bind args in Ordered-Set Agg (#10193)
Browse files Browse the repository at this point in the history
Co-authored-by: stonepage <40830455+st1page@users.noreply.github.com>
  • Loading branch information
Honeta and st1page authored Jun 8, 2023
1 parent 64ec79f commit 035aae7
Show file tree
Hide file tree
Showing 32 changed files with 279 additions and 32 deletions.
6 changes: 6 additions & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ message InputRef {
data.DataType type = 2;
}

message Constant {
data.Datum datum = 1;
data.DataType type = 2;
}

// The items which can occur in the select list of `ProjectSet` operator.
//
// When there are table functions in the SQL query `SELECT ...`, it will be planned as `ProjectSet`.
Expand Down Expand Up @@ -312,6 +317,7 @@ message AggCall {
bool distinct = 4;
repeated common.ColumnOrder order_by = 5;
ExprNode filter = 6;
repeated Constant direct_args = 7;
}

message WindowFrame {
Expand Down
1 change: 1 addition & 0 deletions src/batch/benches/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ fn create_agg_call(
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/batch/src/executor/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ mod tests {
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
};

let agg_prost = HashAggNode {
Expand Down Expand Up @@ -397,6 +398,7 @@ mod tests {
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
};

let agg_prost = HashAggNode {
Expand Down
5 changes: 5 additions & 0 deletions src/batch/src/executor/sort_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ mod tests {
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
};

let count_star = build_agg(AggCall::from_protobuf(&prost)?)?;
Expand Down Expand Up @@ -504,6 +505,7 @@ mod tests {
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
};

let count_star = build_agg(AggCall::from_protobuf(&prost)?)?;
Expand Down Expand Up @@ -618,6 +620,7 @@ mod tests {
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
};

let sum_agg = build_agg(AggCall::from_protobuf(&prost)?)?;
Expand Down Expand Up @@ -701,6 +704,7 @@ mod tests {
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
};

let sum_agg = build_agg(AggCall::from_protobuf(&prost)?)?;
Expand Down Expand Up @@ -810,6 +814,7 @@ mod tests {
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
};

let sum_agg = build_agg(AggCall::from_protobuf(&prost)?)?;
Expand Down
1 change: 1 addition & 0 deletions src/expr/benches/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ fn bench_expr(c: &mut Criterion) {
column_orders: vec![],
filter: None,
distinct: false,
direct_args: vec![],
}) {
Ok(agg) => agg,
Err(e) => {
Expand Down
3 changes: 3 additions & 0 deletions src/expr/src/agg/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ mod tests {
column_orders: vec![],
filter: None,
distinct: false,
direct_args: vec![],
})?;
let mut builder = return_type.create_array_builder(0);
agg.update_multi(&chunk, 0, chunk.cardinality()).await?;
Expand Down Expand Up @@ -97,6 +98,7 @@ mod tests {
column_orders: vec![],
filter: None,
distinct: false,
direct_args: vec![],
})?;
let mut builder = return_type.create_array_builder(0);
agg.output(&mut builder)?;
Expand Down Expand Up @@ -147,6 +149,7 @@ mod tests {
],
filter: None,
distinct: false,
direct_args: vec![],
})?;
let mut builder = return_type.create_array_builder(0);
agg.update_multi(&chunk, 0, chunk.cardinality()).await?;
Expand Down
23 changes: 22 additions & 1 deletion src/expr/src/agg/def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@

use std::sync::Arc;

use itertools::Itertools;
use parse_display::{Display, FromStr};
use risingwave_common::bail;
use risingwave_common::types::DataType;
use risingwave_common::util::sort_util::{ColumnOrder, OrderType};
use risingwave_common::util::value_encoding;
use risingwave_pb::expr::agg_call::PbType;
use risingwave_pb::expr::{PbAggCall, PbInputRef};

use crate::expr::{build_from_prost, ExpressionRef};
use crate::expr::{build_from_prost, ExpressionRef, LiteralExpression};
use crate::Result;

/// Represents an aggregation function.
Expand All @@ -44,6 +46,9 @@ pub struct AggCall {

/// Should deduplicate the input before aggregation.
pub distinct: bool,

/// Constant arguments.
pub direct_args: Vec<LiteralExpression>,
}

impl AggCall {
Expand All @@ -63,13 +68,29 @@ impl AggCall {
Some(ref pb_filter) => Some(Arc::from(build_from_prost(pb_filter)?)),
None => None,
};
let direct_args = agg_call
.direct_args
.iter()
.map(|arg| {
let data_type = DataType::from(arg.get_type().unwrap());
LiteralExpression::new(
data_type.clone(),
value_encoding::deserialize_datum(
arg.get_datum().unwrap().get_body().as_slice(),
&data_type,
)
.unwrap(),
)
})
.collect_vec();
Ok(AggCall {
kind: agg_kind,
args,
return_type: DataType::from(agg_call.get_return_type()?),
column_orders,
filter,
distinct: agg_call.distinct,
direct_args,
})
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/expr/src/agg/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ mod tests {
column_orders: vec![],
filter: None,
distinct: false,
direct_args: vec![],
})?;
agg_state
.update_multi(&input_chunk, 0, input_chunk.cardinality())
Expand Down Expand Up @@ -404,6 +405,7 @@ mod tests {
column_orders: vec![],
filter: None,
distinct: true,
direct_args: vec![],
})?;
agg_state
.update_multi(&input_chunk, 0, input_chunk.cardinality())
Expand Down
3 changes: 3 additions & 0 deletions src/expr/src/agg/string_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ mod tests {
column_orders: vec![],
filter: None,
distinct: false,
direct_args: vec![],
})?;
let mut builder = ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0));
agg.update_multi(&chunk, 0, chunk.cardinality()).await?;
Expand Down Expand Up @@ -81,6 +82,7 @@ mod tests {
column_orders: vec![],
filter: None,
distinct: false,
direct_args: vec![],
})?;
let mut builder = ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0));
agg.update_multi(&chunk, 0, chunk.cardinality()).await?;
Expand Down Expand Up @@ -113,6 +115,7 @@ mod tests {
],
filter: None,
distinct: false,
direct_args: vec![],
})?;
let mut builder = ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0));
agg.update_multi(&chunk, 0, chunk.cardinality()).await?;
Expand Down
2 changes: 1 addition & 1 deletion src/expr/src/expr/expr_literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::expr::Expression;
use crate::{ExprError, Result};

/// A literal expression.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct LiteralExpression {
return_type: DataType,
literal: Datum,
Expand Down
20 changes: 20 additions & 0 deletions src/frontend/planner_test/tests/testdata/input/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -723,3 +723,23 @@
select a, count(*) cnt from t group by a order by a desc;
expected_outputs:
- batch_plan
- sql: |
create table t (x int, y int);
select percentile_cont(x) within group (order by y) from t;
expected_outputs:
- binder_error
- sql: |
create table t (x int, y int);
select percentile_cont('abc') within group (order by y) from t;
expected_outputs:
- binder_error
- sql: |
create table t (x int, y int);
select percentile_cont(0, 0) within group (order by y) from t;
expected_outputs:
- binder_error
- sql: |
create table t (x int, y int);
select percentile_cont(0) within group (order by y desc) from t;
expected_outputs:
- batch_plan
31 changes: 31 additions & 0 deletions src/frontend/planner_test/tests/testdata/output/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1311,3 +1311,34 @@
└─BatchHashAgg { group_key: [t.a], aggs: [count] }
└─BatchExchange { order: [], dist: HashShard(t.a) }
└─BatchScan { table: t, columns: [t.a], distribution: SomeShard }
- sql: |
create table t (x int, y int);
select percentile_cont(x) within group (order by y) from t;
binder_error: |-
Bind error: failed to bind expression: percentile_cont(x)
Caused by:
Invalid input syntax: arg in percentile_cont must be constant
- sql: |
create table t (x int, y int);
select percentile_cont('abc') within group (order by y) from t;
binder_error: |-
Bind error: failed to bind expression: percentile_cont('abc')
Caused by:
Invalid input syntax: arg in percentile_cont must be double precision
- sql: |
create table t (x int, y int);
select percentile_cont(0, 0) within group (order by y) from t;
binder_error: |-
Bind error: failed to bind expression: percentile_cont(0, 0)
Caused by:
Invalid input syntax: only one arg is expected in percentile_cont
- sql: |
create table t (x int, y int);
select percentile_cont(0) within group (order by y desc) from t;
batch_plan: |
BatchSimpleAgg { aggs: [percentile_cont(t.y order_by(t.y DESC))] }
└─BatchExchange { order: [], dist: Single }
└─BatchScan { table: t, columns: [t.y], distribution: SomeShard }
Loading

0 comments on commit 035aae7

Please sign in to comment.