From 035aae7c27fc8739c50fbd6d286ef4167ce44926 Mon Sep 17 00:00:00 2001 From: Xinjing Hu Date: Thu, 8 Jun 2023 19:47:54 +0800 Subject: [PATCH] feat(binder): Support bind args in Ordered-Set Agg (#10193) Co-authored-by: stonepage <40830455+st1page@users.noreply.github.com> --- proto/expr.proto | 6 ++ src/batch/benches/hash_agg.rs | 1 + src/batch/src/executor/hash_agg.rs | 2 + src/batch/src/executor/sort_agg.rs | 5 + src/expr/benches/expr.rs | 1 + src/expr/src/agg/array_agg.rs | 3 + src/expr/src/agg/def.rs | 23 +++- src/expr/src/agg/general.rs | 2 + src/expr/src/agg/string_agg.rs | 3 + src/expr/src/expr/expr_literal.rs | 2 +- .../tests/testdata/input/agg.yaml | 20 ++++ .../tests/testdata/output/agg.yaml | 31 ++++++ src/frontend/src/binder/expr/function.rs | 100 ++++++++++++++---- src/frontend/src/expr/agg_call.rs | 27 ++++- src/frontend/src/expr/expr_rewriter.rs | 4 +- src/frontend/src/expr/mod.rs | 1 + src/frontend/src/expr/order_by_expr.rs | 4 +- .../src/optimizer/plan_node/generic/agg.rs | 18 +++- .../src/optimizer/plan_node/logical_agg.rs | 17 ++- .../optimizer/rule/min_max_on_index_rule.rs | 2 + .../rule/over_window_to_agg_and_join_rule.rs | 1 + src/meta/src/stream/test_fragmenter.rs | 1 + src/sqlparser/tests/testdata/select.yaml | 2 + src/stream/benches/stream_hash_agg.rs | 7 ++ .../src/executor/aggregation/distinct.rs | 1 + src/stream/src/executor/aggregation/minput.rs | 3 + src/stream/src/executor/aggregation/value.rs | 2 + src/stream/src/executor/integration_tests.rs | 5 + .../executor/over_window/state/aggregate.rs | 1 + src/stream/src/executor/simple_agg.rs | 4 + .../src/executor/stateless_simple_agg.rs | 4 + .../tests/integration_tests/hash_agg.rs | 8 ++ 32 files changed, 279 insertions(+), 32 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index 4b6c5959fa3f..4ea5b076b13e 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -240,6 +240,11 @@ message InputRef { data.DataType type = 2; } +message Constant { + data.Datum datum = 1; + data.DataType type = 2; +} + // The items which can occur in the select list of `ProjectSet` operator. // // When there are table functions in the SQL query `SELECT ...`, it will be planned as `ProjectSet`. @@ -312,6 +317,7 @@ message AggCall { bool distinct = 4; repeated common.ColumnOrder order_by = 5; ExprNode filter = 6; + repeated Constant direct_args = 7; } message WindowFrame { diff --git a/src/batch/benches/hash_agg.rs b/src/batch/benches/hash_agg.rs index 791d8a940ae0..beaf6eb97b47 100644 --- a/src/batch/benches/hash_agg.rs +++ b/src/batch/benches/hash_agg.rs @@ -47,6 +47,7 @@ fn create_agg_call( distinct: false, order_by: vec![], filter: None, + direct_args: vec![], } } diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index 0a3d70e29c7a..32b3ed4bb4ec 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -331,6 +331,7 @@ mod tests { distinct: false, order_by: vec![], filter: None, + direct_args: vec![], }; let agg_prost = HashAggNode { @@ -397,6 +398,7 @@ mod tests { distinct: false, order_by: vec![], filter: None, + direct_args: vec![], }; let agg_prost = HashAggNode { diff --git a/src/batch/src/executor/sort_agg.rs b/src/batch/src/executor/sort_agg.rs index 8d52fe969b4e..8ecca361ac71 100644 --- a/src/batch/src/executor/sort_agg.rs +++ b/src/batch/src/executor/sort_agg.rs @@ -412,6 +412,7 @@ mod tests { distinct: false, order_by: vec![], filter: None, + direct_args: vec![], }; let count_star = build_agg(AggCall::from_protobuf(&prost)?)?; @@ -504,6 +505,7 @@ mod tests { distinct: false, order_by: vec![], filter: None, + direct_args: vec![], }; let count_star = build_agg(AggCall::from_protobuf(&prost)?)?; @@ -618,6 +620,7 @@ mod tests { distinct: false, order_by: vec![], filter: None, + direct_args: vec![], }; let sum_agg = build_agg(AggCall::from_protobuf(&prost)?)?; @@ -701,6 +704,7 @@ mod tests { distinct: false, order_by: vec![], filter: None, + direct_args: vec![], }; let sum_agg = build_agg(AggCall::from_protobuf(&prost)?)?; @@ -810,6 +814,7 @@ mod tests { distinct: false, order_by: vec![], filter: None, + direct_args: vec![], }; let sum_agg = build_agg(AggCall::from_protobuf(&prost)?)?; diff --git a/src/expr/benches/expr.rs b/src/expr/benches/expr.rs index f2913758e3cc..2928bf62c17c 100644 --- a/src/expr/benches/expr.rs +++ b/src/expr/benches/expr.rs @@ -292,6 +292,7 @@ fn bench_expr(c: &mut Criterion) { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }) { Ok(agg) => agg, Err(e) => { diff --git a/src/expr/src/agg/array_agg.rs b/src/expr/src/agg/array_agg.rs index 748a1ab41677..1548e0052b5c 100644 --- a/src/expr/src/agg/array_agg.rs +++ b/src/expr/src/agg/array_agg.rs @@ -66,6 +66,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], })?; let mut builder = return_type.create_array_builder(0); agg.update_multi(&chunk, 0, chunk.cardinality()).await?; @@ -97,6 +98,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], })?; let mut builder = return_type.create_array_builder(0); agg.output(&mut builder)?; @@ -147,6 +149,7 @@ mod tests { ], filter: None, distinct: false, + direct_args: vec![], })?; let mut builder = return_type.create_array_builder(0); agg.update_multi(&chunk, 0, chunk.cardinality()).await?; diff --git a/src/expr/src/agg/def.rs b/src/expr/src/agg/def.rs index 1f9bc6941e76..02678493fa5d 100644 --- a/src/expr/src/agg/def.rs +++ b/src/expr/src/agg/def.rs @@ -16,14 +16,16 @@ use std::sync::Arc; +use itertools::Itertools; use parse_display::{Display, FromStr}; use risingwave_common::bail; use risingwave_common::types::DataType; use risingwave_common::util::sort_util::{ColumnOrder, OrderType}; +use risingwave_common::util::value_encoding; use risingwave_pb::expr::agg_call::PbType; use risingwave_pb::expr::{PbAggCall, PbInputRef}; -use crate::expr::{build_from_prost, ExpressionRef}; +use crate::expr::{build_from_prost, ExpressionRef, LiteralExpression}; use crate::Result; /// Represents an aggregation function. @@ -44,6 +46,9 @@ pub struct AggCall { /// Should deduplicate the input before aggregation. pub distinct: bool, + + /// Constant arguments. + pub direct_args: Vec, } impl AggCall { @@ -63,6 +68,21 @@ impl AggCall { Some(ref pb_filter) => Some(Arc::from(build_from_prost(pb_filter)?)), None => None, }; + let direct_args = agg_call + .direct_args + .iter() + .map(|arg| { + let data_type = DataType::from(arg.get_type().unwrap()); + LiteralExpression::new( + data_type.clone(), + value_encoding::deserialize_datum( + arg.get_datum().unwrap().get_body().as_slice(), + &data_type, + ) + .unwrap(), + ) + }) + .collect_vec(); Ok(AggCall { kind: agg_kind, args, @@ -70,6 +90,7 @@ impl AggCall { column_orders, filter, distinct: agg_call.distinct, + direct_args, }) } } diff --git a/src/expr/src/agg/general.rs b/src/expr/src/agg/general.rs index 30f4b9fa3f4a..4f518135da1a 100644 --- a/src/expr/src/agg/general.rs +++ b/src/expr/src/agg/general.rs @@ -214,6 +214,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], })?; agg_state .update_multi(&input_chunk, 0, input_chunk.cardinality()) @@ -404,6 +405,7 @@ mod tests { column_orders: vec![], filter: None, distinct: true, + direct_args: vec![], })?; agg_state .update_multi(&input_chunk, 0, input_chunk.cardinality()) diff --git a/src/expr/src/agg/string_agg.rs b/src/expr/src/agg/string_agg.rs index c2179848147a..70effa22dda3 100644 --- a/src/expr/src/agg/string_agg.rs +++ b/src/expr/src/agg/string_agg.rs @@ -53,6 +53,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], })?; let mut builder = ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)); agg.update_multi(&chunk, 0, chunk.cardinality()).await?; @@ -81,6 +82,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], })?; let mut builder = ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)); agg.update_multi(&chunk, 0, chunk.cardinality()).await?; @@ -113,6 +115,7 @@ mod tests { ], filter: None, distinct: false, + direct_args: vec![], })?; let mut builder = ArrayBuilderImpl::Utf8(Utf8ArrayBuilder::new(0)); agg.update_multi(&chunk, 0, chunk.cardinality()).await?; diff --git a/src/expr/src/expr/expr_literal.rs b/src/expr/src/expr/expr_literal.rs index 07b287342ebd..40eb1fb82441 100644 --- a/src/expr/src/expr/expr_literal.rs +++ b/src/expr/src/expr/expr_literal.rs @@ -25,7 +25,7 @@ use crate::expr::Expression; use crate::{ExprError, Result}; /// A literal expression. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct LiteralExpression { return_type: DataType, literal: Datum, diff --git a/src/frontend/planner_test/tests/testdata/input/agg.yaml b/src/frontend/planner_test/tests/testdata/input/agg.yaml index 05adc9fedf74..890aee5c210c 100644 --- a/src/frontend/planner_test/tests/testdata/input/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/input/agg.yaml @@ -723,3 +723,23 @@ select a, count(*) cnt from t group by a order by a desc; expected_outputs: - batch_plan +- sql: | + create table t (x int, y int); + select percentile_cont(x) within group (order by y) from t; + expected_outputs: + - binder_error +- sql: | + create table t (x int, y int); + select percentile_cont('abc') within group (order by y) from t; + expected_outputs: + - binder_error +- sql: | + create table t (x int, y int); + select percentile_cont(0, 0) within group (order by y) from t; + expected_outputs: + - binder_error +- sql: | + create table t (x int, y int); + select percentile_cont(0) within group (order by y desc) from t; + expected_outputs: + - batch_plan diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index bcccc87203f2..3bc154b2e5bc 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -1311,3 +1311,34 @@ └─BatchHashAgg { group_key: [t.a], aggs: [count] } └─BatchExchange { order: [], dist: HashShard(t.a) } └─BatchScan { table: t, columns: [t.a], distribution: SomeShard } +- sql: | + create table t (x int, y int); + select percentile_cont(x) within group (order by y) from t; + binder_error: |- + Bind error: failed to bind expression: percentile_cont(x) + + Caused by: + Invalid input syntax: arg in percentile_cont must be constant +- sql: | + create table t (x int, y int); + select percentile_cont('abc') within group (order by y) from t; + binder_error: |- + Bind error: failed to bind expression: percentile_cont('abc') + + Caused by: + Invalid input syntax: arg in percentile_cont must be double precision +- sql: | + create table t (x int, y int); + select percentile_cont(0, 0) within group (order by y) from t; + binder_error: |- + Bind error: failed to bind expression: percentile_cont(0, 0) + + Caused by: + Invalid input syntax: only one arg is expected in percentile_cont +- sql: | + create table t (x int, y int); + select percentile_cont(0) within group (order by y desc) from t; + batch_plan: | + BatchSimpleAgg { aggs: [percentile_cont(t.y order_by(t.y DESC))] } + └─BatchExchange { order: [], dist: Single } + └─BatchScan { table: t, columns: [t.y], distribution: SomeShard } diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index b46037a12553..966a1cb2cfd1 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -21,7 +21,7 @@ use bk_tree::{metrics, BKTree}; use itertools::Itertools; use risingwave_common::array::ListValue; use risingwave_common::catalog::PG_CATALOG_SCHEMA_NAME; -use risingwave_common::error::{ErrorCode, Result}; +use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::session_config::USER_NAME_WILD_CARD; use risingwave_common::types::DataType; use risingwave_common::{GIT_SHA, RW_VERSION}; @@ -141,12 +141,18 @@ impl Binder { } pub(super) fn bind_agg(&mut self, mut f: Function, kind: AggKind) -> Result { - if f.within_group.is_some() - && !matches!( - kind, - AggKind::PercentileCont | AggKind::PercentileDisc | AggKind::Mode - ) - { + if matches!( + kind, + AggKind::PercentileCont | AggKind::PercentileDisc | AggKind::Mode + ) { + if f.within_group.is_none() { + return Err(ErrorCode::InvalidInputSyntax(format!( + "within group is expected for the {}", + kind + )) + .into()); + } + } else if f.within_group.is_some() { return Err(ErrorCode::InvalidInputSyntax(format!( "within group is disallowed for the {}", kind @@ -154,12 +160,19 @@ impl Binder { .into()); } self.ensure_aggregate_allowed()?; - let inputs: Vec = f - .args - .into_iter() - .map(|arg| self.bind_function_arg(arg)) - .flatten_ok() - .try_collect()?; + let inputs: Vec = if f.within_group.is_some() { + f.within_group + .iter() + .map(|x| self.bind_function_expr_arg(FunctionArgExpr::Expr(x.expr.clone()))) + .flatten_ok() + .try_collect()? + } else { + f.args + .iter() + .map(|arg| self.bind_function_arg(arg.clone())) + .flatten_ok() + .try_collect()? + }; if f.distinct { match &kind { AggKind::Count if inputs.is_empty() => { @@ -230,14 +243,61 @@ impl Binder { ) .into()); } - let order_by = OrderBy::new( - f.order_by - .into_iter() - .map(|e| self.bind_order_by_expr(e)) - .try_collect()?, - ); + let order_by = if f.within_group.is_some() { + if !f.order_by.is_empty() { + return Err(ErrorCode::InvalidInputSyntax(format!( + "order_by clause outside of within group is disallowed in {}", + kind + )) + .into()); + } + OrderBy::new( + f.within_group + .iter() + .map(|x| self.bind_order_by_expr(*x.clone())) + .try_collect()?, + ) + } else { + OrderBy::new( + f.order_by + .into_iter() + .map(|e| self.bind_order_by_expr(e)) + .try_collect()?, + ) + }; + let direct_args = if matches!(kind, AggKind::PercentileCont | AggKind::PercentileDisc) { + let args = + self.bind_function_arg(f.args.into_iter().exactly_one().map_err(|_| { + ErrorCode::InvalidInputSyntax(format!("only one arg is expected in {}", kind)) + })?)?; + if args.len() != 1 || args[0].clone().as_literal().is_none() { + Err( + ErrorCode::InvalidInputSyntax(format!("arg in {} must be constant", kind)) + .into(), + ) + } else if let Ok(casted) = args[0] + .clone() + .cast_implicit(DataType::Float64)? + .fold_const() + { + Ok::<_, RwError>(vec![Literal::new(casted, DataType::Float64)]) + } else { + Err(ErrorCode::InvalidInputSyntax(format!( + "arg in {} must be double precision", + kind + )) + .into()) + } + } else { + Ok(vec![]) + }?; Ok(ExprImpl::AggCall(Box::new(AggCall::new( - kind, inputs, f.distinct, order_by, filter, + kind, + inputs, + f.distinct, + order_by, + filter, + direct_args, )?))) } diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index 8ef51b7911d2..daf8d30859e8 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -18,7 +18,7 @@ use risingwave_common::types::DataType; use risingwave_expr::agg::AggKind; use risingwave_expr::sig::agg::AGG_FUNC_SIG_MAP; -use super::{Expr, ExprImpl, OrderBy}; +use super::{Expr, ExprImpl, Literal, OrderBy}; use crate::utils::Condition; #[derive(Clone, Eq, PartialEq, Hash)] @@ -29,6 +29,7 @@ pub struct AggCall { distinct: bool, order_by: OrderBy, filter: Condition, + direct_args: Vec, } impl std::fmt::Debug for AggCall { @@ -68,7 +69,14 @@ impl AggCall { // XXX: some special cases that can not be handled by signature map. // may return list or struct type - (AggKind::Min | AggKind::Max | AggKind::FirstValue, [input]) => input.clone(), + ( + AggKind::Min + | AggKind::Max + | AggKind::FirstValue + | AggKind::PercentileDisc + | AggKind::Mode, + [input], + ) => input.clone(), (AggKind::ArrayAgg, [input]) => List(Box::new(input.clone())), // functions that are rewritten in the frontend and don't exist in the expr crate (AggKind::Avg, [input]) => match input { @@ -85,6 +93,7 @@ impl AggCall { Float32 | Float64 | Int256 => Float64, _ => return Err(err()), }, + (AggKind::PercentileCont, _) => Float64, // other functions are handled by signature map _ => { @@ -105,6 +114,7 @@ impl AggCall { distinct: bool, order_by: OrderBy, filter: Condition, + direct_args: Vec, ) -> Result { let data_types = inputs.iter().map(ExprImpl::return_type).collect_vec(); let return_type = Self::infer_return_type(agg_kind, &data_types)?; @@ -115,16 +125,27 @@ impl AggCall { distinct, order_by, filter, + direct_args, }) } - pub fn decompose(self) -> (AggKind, Vec, bool, OrderBy, Condition) { + pub fn decompose( + self, + ) -> ( + AggKind, + Vec, + bool, + OrderBy, + Condition, + Vec, + ) { ( self.agg_kind, self.inputs, self.distinct, self.order_by, self.filter, + self.direct_args, ) } diff --git a/src/frontend/src/expr/expr_rewriter.rs b/src/frontend/src/expr/expr_rewriter.rs index 23456858bc73..b75e5ce7ba79 100644 --- a/src/frontend/src/expr/expr_rewriter.rs +++ b/src/frontend/src/expr/expr_rewriter.rs @@ -46,14 +46,14 @@ pub trait ExprRewriter { FunctionCall::new_unchecked(func_type, inputs, ret).into() } fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl { - let (func_type, inputs, distinct, order_by, filter) = agg_call.decompose(); + let (func_type, inputs, distinct, order_by, filter, direct_args) = agg_call.decompose(); let inputs = inputs .into_iter() .map(|expr| self.rewrite_expr(expr)) .collect(); let order_by = order_by.rewrite_expr(self); let filter = filter.rewrite_expr(self); - AggCall::new(func_type, inputs, distinct, order_by, filter) + AggCall::new(func_type, inputs, distinct, order_by, filter, direct_args) .unwrap() .into() } diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index e614be03fb2d..5015d8e257a3 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -181,6 +181,7 @@ impl ExprImpl { false, OrderBy::any(), Condition::true_cond(), + vec![], ) .unwrap() .into() diff --git a/src/frontend/src/expr/order_by_expr.rs b/src/frontend/src/expr/order_by_expr.rs index 5bb76899086d..48790ceb3933 100644 --- a/src/frontend/src/expr/order_by_expr.rs +++ b/src/frontend/src/expr/order_by_expr.rs @@ -22,7 +22,7 @@ use crate::expr::{ExprImpl, ExprMutator, ExprRewriter, ExprVisitor}; /// A sort expression in the `ORDER BY` clause. /// /// See also [`bind_order_by_expr`](`crate::binder::Binder::bind_order_by_expr`). -#[derive(Clone, Eq, PartialEq, Hash)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] pub struct OrderByExpr { pub expr: ExprImpl, pub order_type: OrderType, @@ -36,7 +36,7 @@ impl Display for OrderByExpr { } /// See [`OrderByExpr`]. -#[derive(Clone, Eq, PartialEq, Hash)] +#[derive(Clone, Eq, PartialEq, Hash, Debug)] pub struct OrderBy { pub sort_exprs: Vec, } diff --git a/src/frontend/src/optimizer/plan_node/generic/agg.rs b/src/frontend/src/optimizer/plan_node/generic/agg.rs index 22eeedfc0f37..dfa373c9cf39 100644 --- a/src/frontend/src/optimizer/plan_node/generic/agg.rs +++ b/src/frontend/src/optimizer/plan_node/generic/agg.rs @@ -21,13 +21,15 @@ use pretty_xmlish::{Pretty, StrAssocArr}; use risingwave_common::catalog::{Field, FieldDisplay, Schema}; use risingwave_common::types::DataType; use risingwave_common::util::sort_util::{ColumnOrder, ColumnOrderDisplay, OrderType}; +use risingwave_common::util::value_encoding; use risingwave_expr::agg::AggKind; -use risingwave_pb::expr::PbAggCall; +use risingwave_pb::data::PbDatum; +use risingwave_pb::expr::{PbAggCall, PbConstant}; use risingwave_pb::stream_plan::{agg_call_state, AggCallState as AggCallStatePb}; use super::super::utils::TableCatalogBuilder; use super::{impl_distill_unit_from_fields, stream, GenericPlanNode, GenericPlanRef}; -use crate::expr::{Expr, ExprRewriter, InputRef, InputRefDisplay}; +use crate::expr::{Expr, ExprRewriter, InputRef, InputRefDisplay, Literal}; use crate::optimizer::optimizer_context::OptimizerContextRef; use crate::optimizer::plan_node::batch::BatchPlanRef; use crate::optimizer::property::{Distribution, FunctionalDependencySet, RequiredDist}; @@ -669,6 +671,7 @@ pub struct PlanAggCall { /// Selective aggregation: only the input rows for which /// `filter` evaluates to `true` will be fed to the aggregate function. pub filter: Condition, + pub direct_args: Vec, } impl fmt::Debug for PlanAggCall { @@ -729,6 +732,16 @@ impl PlanAggCall { distinct: self.distinct, order_by: self.order_by.iter().map(ColumnOrder::to_protobuf).collect(), filter: self.filter.as_expr_unless_true().map(|x| x.to_expr_proto()), + direct_args: self + .direct_args + .iter() + .map(|x| PbConstant { + datum: Some(PbDatum { + body: value_encoding::serialize_datum(x.get_data()), + }), + r#type: Some(x.return_type().to_protobuf()), + }) + .collect(), } } @@ -775,6 +788,7 @@ impl PlanAggCall { distinct: false, order_by: vec![], filter: Condition::true_cond(), + direct_args: vec![], } } diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 8aaa4fa6ab0f..6cbbe5dccd3c 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -419,7 +419,8 @@ impl LogicalAggBuilder { agg_call: AggCall, ) -> std::result::Result { let return_type = agg_call.return_type(); - let (agg_kind, inputs, mut distinct, mut order_by, filter) = agg_call.decompose(); + let (agg_kind, inputs, mut distinct, mut order_by, filter, direct_args) = + agg_call.decompose(); match &agg_kind { AggKind::Min | AggKind::Max => { distinct = false; @@ -487,6 +488,7 @@ impl LogicalAggBuilder { distinct, order_by: order_by.clone(), filter: filter.clone(), + direct_args: direct_args.clone(), }); let left = ExprImpl::from(left_ref).cast_explicit(return_type).unwrap(); @@ -499,6 +501,7 @@ impl LogicalAggBuilder { distinct, order_by, filter, + direct_args, }); Ok(ExprImpl::from( @@ -546,6 +549,7 @@ impl LogicalAggBuilder { distinct, order_by: order_by.clone(), filter: filter.clone(), + direct_args: direct_args.clone(), })) .cast_explicit(return_type.clone()) .unwrap(); @@ -561,6 +565,7 @@ impl LogicalAggBuilder { distinct, order_by: order_by.clone(), filter: filter.clone(), + direct_args: direct_args.clone(), })) .cast_explicit(return_type.clone()) .unwrap(); @@ -576,6 +581,7 @@ impl LogicalAggBuilder { distinct, order_by, filter, + direct_args, })); // we start with variance @@ -675,6 +681,7 @@ impl LogicalAggBuilder { distinct, order_by, filter, + direct_args, }) .into()), } @@ -1186,6 +1193,7 @@ mod tests { false, OrderBy::any(), Condition::true_cond(), + vec![], ) .unwrap(); let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()]; @@ -1211,6 +1219,7 @@ mod tests { false, OrderBy::any(), Condition::true_cond(), + vec![], ) .unwrap(); let max_v3 = AggCall::new( @@ -1219,6 +1228,7 @@ mod tests { false, OrderBy::any(), Condition::true_cond(), + vec![], ) .unwrap(); let func_call = @@ -1259,6 +1269,7 @@ mod tests { false, OrderBy::any(), Condition::true_cond(), + vec![], ) .unwrap(); let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()]; @@ -1293,6 +1304,7 @@ mod tests { distinct: false, order_by: vec![], filter: Condition::true_cond(), + direct_args: vec![], }; Agg::new( vec![agg_call], @@ -1417,6 +1429,7 @@ mod tests { distinct: false, order_by: vec![], filter: Condition::true_cond(), + direct_args: vec![], }; let agg: PlanRef = Agg::new( vec![agg_call], @@ -1487,6 +1500,7 @@ mod tests { distinct: false, order_by: vec![], filter: Condition::true_cond(), + direct_args: vec![], }, PlanAggCall { agg_kind: AggKind::Max, @@ -1495,6 +1509,7 @@ mod tests { distinct: false, order_by: vec![], filter: Condition::true_cond(), + direct_args: vec![], }, ]; let agg: PlanRef = diff --git a/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs b/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs index 24667e8e46f4..29f4bad286fb 100644 --- a/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs +++ b/src/frontend/src/optimizer/rule/min_max_on_index_rule.rs @@ -124,6 +124,7 @@ impl MinMaxOnIndexRule { filter: Condition { conjunctions: vec![], }, + direct_args: vec![], }], FixedBitSet::new(), topn.into(), @@ -193,6 +194,7 @@ impl MinMaxOnIndexRule { filter: Condition { conjunctions: vec![], }, + direct_args: vec![], }], FixedBitSet::new(), topn.into(), diff --git a/src/frontend/src/optimizer/rule/over_window_to_agg_and_join_rule.rs b/src/frontend/src/optimizer/rule/over_window_to_agg_and_join_rule.rs index c3f4b8857824..48b3f67cce76 100644 --- a/src/frontend/src/optimizer/rule/over_window_to_agg_and_join_rule.rs +++ b/src/frontend/src/optimizer/rule/over_window_to_agg_and_join_rule.rs @@ -58,6 +58,7 @@ impl Rule for OverWindowToAggAndJoinRule { false, OrderBy::any(), Condition::true_cond(), + vec![], ) .ok()?; select_exprs.push(agg_call.into()); diff --git a/src/meta/src/stream/test_fragmenter.rs b/src/meta/src/stream/test_fragmenter.rs index 6b721196b45f..46b8de238f20 100644 --- a/src/meta/src/stream/test_fragmenter.rs +++ b/src/meta/src/stream/test_fragmenter.rs @@ -72,6 +72,7 @@ fn make_sum_aggcall(idx: u32) -> AggCall { distinct: false, order_by: vec![], filter: None, + direct_args: vec![], } } diff --git a/src/sqlparser/tests/testdata/select.yaml b/src/sqlparser/tests/testdata/select.yaml index 991adcfdd975..c8aed6d1d8ca 100644 --- a/src/sqlparser/tests/testdata/select.yaml +++ b/src/sqlparser/tests/testdata/select.yaml @@ -82,3 +82,5 @@ - input: select percentile_cont(0.3) within group (order by x desc) from unnest(array[1,2,4,5,10]) as x formatted_sql: SELECT percentile_cont(0.3) FROM unnest(ARRAY[1, 2, 4, 5, 10]) AS x formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "percentile_cont", quote_style: None }]), args: [Unnamed(Expr(Value(Number("0.3"))))], over: None, distinct: false, order_by: [], filter: None, within_group: Some(OrderByExpr { expr: Identifier(Ident { value: "x", quote_style: None }), asc: Some(false), nulls_first: None }) }))], from: [TableWithJoins { relation: TableFunction { name: ObjectName([Ident { value: "unnest", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "x", quote_style: None }, columns: [] }), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("4")), Value(Number("5")), Value(Number("10"))], named: true })))] }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' +- input: select percentile_cont(0.3) within group (order by x, y desc) from t + error_msg: 'sql parser error: only one arg in order by is expected here' diff --git a/src/stream/benches/stream_hash_agg.rs b/src/stream/benches/stream_hash_agg.rs index 7a56861f0199..be3927bb0659 100644 --- a/src/stream/benches/stream_hash_agg.rs +++ b/src/stream/benches/stream_hash_agg.rs @@ -92,6 +92,7 @@ fn setup_bench_hash_agg(store: S) -> BoxedExecutor { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Count, @@ -100,6 +101,7 @@ fn setup_bench_hash_agg(store: S) -> BoxedExecutor { column_orders: vec![], filter: Some(build_from_pretty("(less_than:boolean $2:int8 10000:int8)").into()), distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Count, @@ -108,6 +110,7 @@ fn setup_bench_hash_agg(store: S) -> BoxedExecutor { column_orders: vec![], filter: Some(build_from_pretty("(and:boolean (greater_than_or_equal:boolean $2:int8 10000:int8) (less_than:boolean $2:int8 100000:int8))").into()), distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Count, @@ -116,6 +119,7 @@ fn setup_bench_hash_agg(store: S) -> BoxedExecutor { column_orders: vec![], filter: Some(build_from_pretty("(greater_than_or_equal:boolean $2:int8 100000:int8)").into()), distinct: false, + direct_args: vec![], }, // FIXME(kwannoel): Can ignore for now, since it is low cost in q17 (blackhole). // It does not work can't diagnose root cause yet. @@ -152,6 +156,7 @@ fn setup_bench_hash_agg(store: S) -> BoxedExecutor { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, // avg (count) AggCall { @@ -161,6 +166,7 @@ fn setup_bench_hash_agg(store: S) -> BoxedExecutor { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Sum, @@ -169,6 +175,7 @@ fn setup_bench_hash_agg(store: S) -> BoxedExecutor { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, ]; diff --git a/src/stream/src/executor/aggregation/distinct.rs b/src/stream/src/executor/aggregation/distinct.rs index ea6081b3dcfe..bc91af436b88 100644 --- a/src/stream/src/executor/aggregation/distinct.rs +++ b/src/stream/src/executor/aggregation/distinct.rs @@ -335,6 +335,7 @@ mod tests { column_orders: vec![], filter: None, + direct_args: vec![], } } diff --git a/src/stream/src/executor/aggregation/minput.rs b/src/stream/src/executor/aggregation/minput.rs index c38965e2e8ec..fbb3bbd26393 100644 --- a/src/stream/src/executor/aggregation/minput.rs +++ b/src/stream/src/executor/aggregation/minput.rs @@ -297,6 +297,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], } } @@ -993,6 +994,7 @@ mod tests { ], filter: None, distinct: false, + direct_args: vec![], }; let group_key = None; @@ -1094,6 +1096,7 @@ mod tests { ], filter: None, distinct: false, + direct_args: vec![], }; let group_key = None; diff --git a/src/stream/src/executor/aggregation/value.rs b/src/stream/src/executor/aggregation/value.rs index aa1121eac7ba..327adf6c5c4a 100644 --- a/src/stream/src/executor/aggregation/value.rs +++ b/src/stream/src/executor/aggregation/value.rs @@ -95,6 +95,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], } } @@ -137,6 +138,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], } } diff --git a/src/stream/src/executor/integration_tests.rs b/src/stream/src/executor/integration_tests.rs index 57cf9fe15a6c..e43e3ec3c8ba 100644 --- a/src/stream/src/executor/integration_tests.rs +++ b/src/stream/src/executor/integration_tests.rs @@ -60,6 +60,7 @@ async fn test_merger_sum_aggr() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Sum, @@ -68,6 +69,7 @@ async fn test_merger_sum_aggr() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, ], vec![], @@ -157,6 +159,7 @@ async fn test_merger_sum_aggr() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Sum, @@ -165,6 +168,7 @@ async fn test_merger_sum_aggr() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Count, // as row count, index: 2 @@ -173,6 +177,7 @@ async fn test_merger_sum_aggr() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, ], 2, // row_count_index diff --git a/src/stream/src/executor/over_window/state/aggregate.rs b/src/stream/src/executor/over_window/state/aggregate.rs index 2f38fa9b932e..911ecf7062ac 100644 --- a/src/stream/src/executor/over_window/state/aggregate.rs +++ b/src/stream/src/executor/over_window/state/aggregate.rs @@ -56,6 +56,7 @@ impl AggregateState { filter: None, // TODO(rc): support distinct on window function call? PG doesn't support it either. distinct: false, + direct_args: vec![], }; Ok(Self { agg_call, diff --git a/src/stream/src/executor/simple_agg.rs b/src/stream/src/executor/simple_agg.rs index 92b5d0f0b3d6..c7e1dd7d2437 100644 --- a/src/stream/src/executor/simple_agg.rs +++ b/src/stream/src/executor/simple_agg.rs @@ -397,6 +397,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Sum, @@ -405,6 +406,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Sum, @@ -413,6 +415,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Min, @@ -421,6 +424,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, ]; diff --git a/src/stream/src/executor/stateless_simple_agg.rs b/src/stream/src/executor/stateless_simple_agg.rs index 15340863b252..79db71c1ed9a 100644 --- a/src/stream/src/executor/stateless_simple_agg.rs +++ b/src/stream/src/executor/stateless_simple_agg.rs @@ -206,6 +206,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }]; let simple_agg = Box::new( @@ -263,6 +264,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Sum, @@ -271,6 +273,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Sum, @@ -279,6 +282,7 @@ mod tests { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, ]; diff --git a/src/stream/tests/integration_tests/hash_agg.rs b/src/stream/tests/integration_tests/hash_agg.rs index ee70940436eb..202b64917ea5 100644 --- a/src/stream/tests/integration_tests/hash_agg.rs +++ b/src/stream/tests/integration_tests/hash_agg.rs @@ -39,6 +39,7 @@ async fn test_hash_agg_count_sum() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Sum, @@ -47,6 +48,7 @@ async fn test_hash_agg_count_sum() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, // This is local hash aggregation, so we add another sum state AggCall { @@ -56,6 +58,7 @@ async fn test_hash_agg_count_sum() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, ]; @@ -140,6 +143,7 @@ async fn test_hash_agg_min() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Min, @@ -148,6 +152,7 @@ async fn test_hash_agg_min() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, ]; @@ -229,6 +234,7 @@ async fn test_hash_agg_min_append_only() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, AggCall { kind: AggKind::Min, @@ -237,6 +243,7 @@ async fn test_hash_agg_min_append_only() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }, ]; @@ -319,6 +326,7 @@ async fn test_hash_agg_emit_on_window_close() { column_orders: vec![], filter: None, distinct: false, + direct_args: vec![], }]; let create_executor = || async {