From 90ee868d9a168f7d38b190796d89cbd40159a988 Mon Sep 17 00:00:00 2001 From: Xinjing Hu Date: Wed, 14 Jun 2023 19:00:27 +0800 Subject: [PATCH 1/4] feat(expr, agg): support `PERCENTILE_CONT`, `PERCENTILE_DISC` and `MODE` aggregation (#10252) Signed-off-by: Richard Chien Co-authored-by: Richard Chien Co-authored-by: Noel Kwan <47273164+kwannoel@users.noreply.github.com> --- .../batch/aggregate/ordered_set_agg.slt.part | 33 ++++ src/expr/src/agg/mod.rs | 3 + src/expr/src/agg/mode.rs | 126 +++++++++++++++ src/expr/src/agg/percentile_cont.rs | 132 ++++++++++++++++ src/expr/src/agg/percentile_disc.rs | 143 ++++++++++++++++++ .../tests/testdata/input/agg.yaml | 25 +++ .../tests/testdata/output/agg.yaml | 43 +++++- src/frontend/src/binder/expr/function.rs | 39 ++++- src/frontend/src/expr/agg_call.rs | 16 +- .../src/optimizer/plan_node/logical_agg.rs | 18 +++ src/tests/sqlsmith/src/sql_gen/types.rs | 3 + 11 files changed, 563 insertions(+), 18 deletions(-) create mode 100644 e2e_test/batch/aggregate/ordered_set_agg.slt.part create mode 100644 src/expr/src/agg/mode.rs create mode 100644 src/expr/src/agg/percentile_cont.rs create mode 100644 src/expr/src/agg/percentile_disc.rs diff --git a/e2e_test/batch/aggregate/ordered_set_agg.slt.part b/e2e_test/batch/aggregate/ordered_set_agg.slt.part new file mode 100644 index 000000000000..6cf42db84313 --- /dev/null +++ b/e2e_test/batch/aggregate/ordered_set_agg.slt.part @@ -0,0 +1,33 @@ +statement error +select p, percentile_cont(p) within group (order by x::float8) +from generate_series(1,5) x, + (values (0::float8),(0.1),(0.25),(0.4),(0.5),(0.6),(0.75),(0.9),(1)) v(p) +group by p order by p; + +statement error +select percentile_cont(array[0,1,0.25,0.75,0.5,1,0.3,0.32,0.35,0.38,0.4]) within group (order by x) +from generate_series(1,6) x; + +statement error +select percentile_disc(array[0.25,0.5,0.75]) within group (order by x) +from unnest('{fred,jim,fred,jack,jill,fred,jill,jim,jim,sheila,jim,sheila}'::text[]) u(x); + +statement error +select pg_collation_for(percentile_disc(1) within group (order by x collate "POSIX")) + from (values ('fred'),('jim')) v(x); + +query RR +select + percentile_cont(0.5) within group (order by a), + percentile_disc(0.5) within group (order by a) +from (values(1::float8),(3),(5),(7)) t(a); +---- +4 3 + +query RR +select + percentile_cont(0.25) within group (order by a), + percentile_disc(0.5) within group (order by a) +from (values(1::float8),(3),(5),(7)) t(a); +---- +2.5 3 diff --git a/src/expr/src/agg/mod.rs b/src/expr/src/agg/mod.rs index c03aa597a906..b0d7722d240e 100644 --- a/src/expr/src/agg/mod.rs +++ b/src/expr/src/agg/mod.rs @@ -28,6 +28,9 @@ mod array_agg; mod count_star; mod general; mod jsonb_agg; +mod mode; +mod percentile_cont; +mod percentile_disc; mod string_agg; // wrappers diff --git a/src/expr/src/agg/mode.rs b/src/expr/src/agg/mode.rs new file mode 100644 index 000000000000..297e169dc93c --- /dev/null +++ b/src/expr/src/agg/mode.rs @@ -0,0 +1,126 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::array::*; +use risingwave_common::estimate_size::EstimateSize; +use risingwave_common::types::*; +use risingwave_expr_macro::build_aggregate; + +use super::Aggregator; +use crate::agg::AggCall; +use crate::Result; + +#[build_aggregate("mode(*) -> *")] +fn build(agg: AggCall) -> Result> { + Ok(Box::new(Mode::new(agg.return_type))) +} + +/// Computes the mode, the most frequent value of the aggregated argument (arbitrarily choosing the +/// first one if there are multiple equally-frequent values). The aggregated argument must be of a +/// sortable type. +/// +/// ```slt +/// query I +/// select mode() within group (order by unnest) from unnest(array[1]); +/// ---- +/// 1 +/// +/// query I +/// select mode() within group (order by unnest) from unnest(array[1,2,2,3,3,4,4,4]); +/// ---- +/// 4 +/// +/// query R +/// select mode() within group (order by unnest) from unnest(array[0.1,0.2,0.2,0.4,0.4,0.3,0.3,0.4]); +/// ---- +/// 0.4 +/// +/// query R +/// select mode() within group (order by unnest) from unnest(array[1,2,2,3,3,4,4,4,3]); +/// ---- +/// 3 +/// +/// query T +/// select mode() within group (order by unnest) from unnest(array['1','2','2','3','3','4','4','4','3']); +/// ---- +/// 3 +/// +/// query I +/// select mode() within group (order by unnest) from unnest(array[]::int[]); +/// ---- +/// NULL +/// ``` +#[derive(Clone, EstimateSize)] +pub struct Mode { + return_type: DataType, + cur_mode: Datum, + cur_mode_freq: usize, + cur_item: Datum, + cur_item_freq: usize, +} + +impl Mode { + pub fn new(return_type: DataType) -> Self { + Self { + return_type, + cur_mode: None, + cur_mode_freq: 0, + cur_item: None, + cur_item_freq: 0, + } + } + + fn add_datum(&mut self, datum_ref: DatumRef<'_>) { + let datum = datum_ref.to_owned_datum(); + if datum.is_some() && self.cur_item == datum { + self.cur_item_freq += 1; + } else if datum.is_some() { + self.cur_item = datum; + self.cur_item_freq = 1; + } + if self.cur_item_freq > self.cur_mode_freq { + self.cur_mode = self.cur_item.clone(); + self.cur_mode_freq = self.cur_item_freq; + } + } +} + +#[async_trait::async_trait] +impl Aggregator for Mode { + fn return_type(&self) -> DataType { + self.return_type.clone() + } + + async fn update_multi( + &mut self, + input: &DataChunk, + start_row_id: usize, + end_row_id: usize, + ) -> Result<()> { + let array = input.column_at(0); + for row_id in start_row_id..end_row_id { + self.add_datum(array.value_at(row_id)); + } + Ok(()) + } + + fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> { + builder.append(self.cur_mode.clone()); + Ok(()) + } + + fn estimated_size(&self) -> usize { + EstimateSize::estimated_size(self) + } +} diff --git a/src/expr/src/agg/percentile_cont.rs b/src/expr/src/agg/percentile_cont.rs new file mode 100644 index 000000000000..1e88557712c5 --- /dev/null +++ b/src/expr/src/agg/percentile_cont.rs @@ -0,0 +1,132 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::array::*; +use risingwave_common::estimate_size::EstimateSize; +use risingwave_common::types::*; +use risingwave_expr_macro::build_aggregate; + +use super::Aggregator; +use crate::agg::AggCall; +use crate::Result; + +/// Computes the continuous percentile, a value corresponding to the specified fraction within the +/// ordered set of aggregated argument values. This will interpolate between adjacent input items if +/// needed. +/// +/// ```slt +/// statement ok +/// create table t(x int, y bigint, z real, w double, v varchar); +/// +/// statement ok +/// insert into t values(1,10,100,1000,'10000'),(2,20,200,2000,'20000'),(3,30,300,3000,'30000'); +/// +/// query R +/// select percentile_cont(0.45) within group (order by x desc) from t; +/// ---- +/// 2.1 +/// +/// query R +/// select percentile_cont(0.45) within group (order by y desc) from t; +/// ---- +/// 21 +/// +/// query R +/// select percentile_cont(0.45) within group (order by z desc) from t; +/// ---- +/// 210 +/// +/// query R +/// select percentile_cont(0.45) within group (order by w desc) from t; +/// ---- +/// 2100 +/// +/// query R +/// select percentile_cont(NULL) within group (order by w desc) from t; +/// ---- +/// NULL +/// +/// statement ok +/// drop table t; +/// ``` +#[build_aggregate("percentile_cont(float64) -> float64")] +fn build(agg: AggCall) -> Result> { + let fraction: Option = agg.direct_args[0] + .literal() + .map(|x| (*x.as_float64()).into()); + Ok(Box::new(PercentileCont::new(fraction))) +} + +#[derive(Clone, EstimateSize)] +pub struct PercentileCont { + fractions: Option, + data: Vec, +} + +impl PercentileCont { + pub fn new(fractions: Option) -> Self { + Self { + fractions, + data: vec![], + } + } + + fn add_datum(&mut self, datum_ref: DatumRef<'_>) { + if let Some(datum) = datum_ref.to_owned_datum() { + self.data.push((*datum.as_float64()).into()); + } + } +} + +#[async_trait::async_trait] +impl Aggregator for PercentileCont { + fn return_type(&self) -> DataType { + DataType::Float64 + } + + async fn update_multi( + &mut self, + input: &DataChunk, + start_row_id: usize, + end_row_id: usize, + ) -> Result<()> { + let array = input.column_at(0); + for row_id in start_row_id..end_row_id { + self.add_datum(array.value_at(row_id)); + } + Ok(()) + } + + fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> { + if let Some(fractions) = self.fractions && !self.data.is_empty() { + let rn = fractions * (self.data.len() - 1) as f64; + let crn = f64::ceil(rn); + let frn = f64::floor(rn); + let result = if crn == frn { + self.data[crn as usize] + } else { + (crn - rn) * self.data[frn as usize] + + (rn - frn) * self.data[crn as usize] + }; + builder.append(Some(ScalarImpl::Float64(result.into()))); + } else { + builder.append(Datum::None); + } + Ok(()) + } + + fn estimated_size(&self) -> usize { + EstimateSize::estimated_size(self) + } +} diff --git a/src/expr/src/agg/percentile_disc.rs b/src/expr/src/agg/percentile_disc.rs new file mode 100644 index 000000000000..a8ab7ccb0fe6 --- /dev/null +++ b/src/expr/src/agg/percentile_disc.rs @@ -0,0 +1,143 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use risingwave_common::array::*; +use risingwave_common::estimate_size::EstimateSize; +use risingwave_common::types::*; +use risingwave_expr_macro::build_aggregate; + +use super::Aggregator; +use crate::agg::AggCall; +use crate::Result; + +/// Computes the discrete percentile, the first value within the ordered set of aggregated argument +/// values whose position in the ordering equals or exceeds the specified fraction. The aggregated +/// argument must be of a sortable type. +/// +/// ```slt +/// statement ok +/// create table t(x int, y bigint, z real, w double, v varchar); +/// +/// statement ok +/// insert into t values(1,10,100,1000,'10000'),(2,20,200,2000,'20000'),(3,30,300,3000,'30000'); +/// +/// query R +/// select percentile_disc(0) within group (order by x) from t; +/// ---- +/// 1 +/// +/// query R +/// select percentile_disc(0.33) within group (order by y) from t; +/// ---- +/// 10 +/// +/// query R +/// select percentile_disc(0.34) within group (order by z) from t; +/// ---- +/// 200 +/// +/// query R +/// select percentile_disc(0.67) within group (order by w) from t +/// ---- +/// 3000 +/// +/// query R +/// select percentile_disc(1) within group (order by v) from t; +/// ---- +/// 30000 +/// +/// query R +/// select percentile_disc(NULL) within group (order by w) from t; +/// ---- +/// NULL +/// +/// statement ok +/// drop table t; +/// ``` +#[build_aggregate("percentile_disc(*) -> *")] +fn build(agg: AggCall) -> Result> { + let fraction: Option = agg.direct_args[0] + .literal() + .map(|x| (*x.as_float64()).into()); + Ok(Box::new(PercentileDisc::new(fraction, agg.return_type))) +} + +#[derive(Clone)] +pub struct PercentileDisc { + fractions: Option, + return_type: DataType, + data: Vec, +} + +impl EstimateSize for PercentileDisc { + fn estimated_heap_size(&self) -> usize { + self.data + .iter() + .fold(0, |acc, x| acc + x.estimated_heap_size()) + } +} + +impl PercentileDisc { + pub fn new(fractions: Option, return_type: DataType) -> Self { + Self { + fractions, + return_type, + data: vec![], + } + } + + fn add_datum(&mut self, datum_ref: DatumRef<'_>) { + if let Some(datum) = datum_ref.to_owned_datum() { + self.data.push(datum); + } + } +} + +#[async_trait::async_trait] +impl Aggregator for PercentileDisc { + fn return_type(&self) -> DataType { + self.return_type.clone() + } + + async fn update_multi( + &mut self, + input: &DataChunk, + start_row_id: usize, + end_row_id: usize, + ) -> Result<()> { + let array = input.column_at(0); + for row_id in start_row_id..end_row_id { + self.add_datum(array.value_at(row_id)); + } + Ok(()) + } + + fn output(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> { + if let Some(fractions) = self.fractions && !self.data.is_empty() { + let rn = fractions * self.data.len() as f64; + if fractions == 0.0 { + builder.append(Some(self.data[0].clone())); + } else { + builder.append(Some(self.data[f64::ceil(rn) as usize - 1].clone())); + } + } else { + builder.append(Datum::None); + } + Ok(()) + } + + fn estimated_size(&self) -> usize { + EstimateSize::estimated_size(self) + } +} diff --git a/src/frontend/planner_test/tests/testdata/input/agg.yaml b/src/frontend/planner_test/tests/testdata/input/agg.yaml index 6836385aa9ad..71f50323179c 100644 --- a/src/frontend/planner_test/tests/testdata/input/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/input/agg.yaml @@ -815,13 +815,38 @@ select percentile_cont('abc') within group (order by y) from t; expected_outputs: - binder_error +- sql: | + create table t (x int, y int); + select percentile_cont(1.3) within group (order by y) from t; + expected_outputs: + - binder_error - sql: | create table t (x int, y int); select percentile_cont(0, 0) within group (order by y) from t; expected_outputs: - binder_error +- sql: | + create table t (x int, y varchar); + select percentile_cont(0) within group (order by y) from t; + expected_outputs: + - binder_error - sql: | create table t (x int, y int); select percentile_cont(0) within group (order by y desc) from t; expected_outputs: - batch_plan +- sql: | + create table t (x int, y varchar); + select percentile_disc(1) within group (order by y desc) from t; + expected_outputs: + - batch_plan +- sql: | + create table t (x int, y varchar); + select mode() within group (order by y desc) from t; + expected_outputs: + - batch_plan +- sql: | + create table t (x int, y varchar); + select mode(1) within group (order by y desc) from t; + expected_outputs: + - binder_error \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/output/agg.yaml b/src/frontend/planner_test/tests/testdata/output/agg.yaml index 0f7c52f181a3..a0777d5f906b 100644 --- a/src/frontend/planner_test/tests/testdata/output/agg.yaml +++ b/src/frontend/planner_test/tests/testdata/output/agg.yaml @@ -1428,7 +1428,15 @@ Bind error: failed to bind expression: percentile_cont('abc') Caused by: - Invalid input syntax: arg in percentile_cont must be double precision + Invalid input syntax: arg in percentile_cont must be float64 +- sql: | + create table t (x int, y int); + select percentile_cont(1.3) within group (order by y) from t; + binder_error: |- + Bind error: failed to bind expression: percentile_cont(1.3) + + Caused by: + Invalid input syntax: arg in percentile_cont must between 0 and 1 - sql: | create table t (x int, y int); select percentile_cont(0, 0) within group (order by y) from t; @@ -1437,10 +1445,41 @@ Caused by: Invalid input syntax: only one arg is expected in percentile_cont +- sql: | + create table t (x int, y varchar); + select percentile_cont(0) within group (order by y) from t; + binder_error: |- + Bind error: failed to bind expression: percentile_cont(0) + + Caused by: + Bind error: cannot cast type "varchar" to "double precision" in Implicit context - sql: | create table t (x int, y int); select percentile_cont(0) within group (order by y desc) from t; batch_plan: | - BatchSimpleAgg { aggs: [percentile_cont(t.y order_by(t.y DESC))] } + BatchSimpleAgg { aggs: [percentile_cont($expr1 order_by(t.y DESC))] } + └─BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [t.y::Float64 as $expr1, t.y] } + └─BatchScan { table: t, columns: [t.y], distribution: SomeShard } +- sql: | + create table t (x int, y varchar); + select percentile_disc(1) within group (order by y desc) from t; + batch_plan: | + BatchSimpleAgg { aggs: [percentile_disc(t.y order_by(t.y DESC))] } + └─BatchExchange { order: [], dist: Single } + └─BatchScan { table: t, columns: [t.y], distribution: SomeShard } +- sql: | + create table t (x int, y varchar); + select mode() within group (order by y desc) from t; + batch_plan: | + BatchSimpleAgg { aggs: [mode(t.y order_by(t.y DESC))] } └─BatchExchange { order: [], dist: Single } └─BatchScan { table: t, columns: [t.y], distribution: SomeShard } +- sql: | + create table t (x int, y varchar); + select mode(1) within group (order by y desc) from t; + binder_error: |- + Bind error: failed to bind expression: mode(1) + + Caused by: + Invalid input syntax: no arguments are expected in mode agg diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 121f2f99fbc4..c1b089560bef 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -159,8 +159,14 @@ impl Binder { )) .into()); } + if kind == AggKind::Mode && !f.args.is_empty() { + return Err(ErrorCode::InvalidInputSyntax( + "no arguments are expected in mode agg".to_string(), + ) + .into()); + } self.ensure_aggregate_allowed()?; - let inputs: Vec = if f.within_group.is_some() { + let mut inputs: Vec = if f.within_group.is_some() { f.within_group .iter() .map(|x| self.bind_function_expr_arg(FunctionArgExpr::Expr(x.expr.clone()))) @@ -173,6 +179,15 @@ impl Binder { .flatten_ok() .try_collect()? }; + if kind == AggKind::PercentileCont { + inputs[0] = inputs + .iter() + .exactly_one() + .unwrap() + .clone() + .cast_implicit(DataType::Float64)?; + } + if f.distinct { match &kind { AggKind::Count if inputs.is_empty() => { @@ -280,13 +295,23 @@ impl Binder { .cast_implicit(DataType::Float64)? .fold_const() { - Ok::<_, RwError>(vec![Literal::new(casted, DataType::Float64)]) + if casted + .clone() + .is_some_and(|x| !(0.0..=1.0).contains(&Into::::into(*x.as_float64()))) + { + Err(ErrorCode::InvalidInputSyntax(format!( + "arg in {} must between 0 and 1", + kind + )) + .into()) + } else { + Ok::<_, RwError>(vec![Literal::new(casted, DataType::Float64)]) + } } else { - Err(ErrorCode::InvalidInputSyntax(format!( - "arg in {} must be double precision", - kind - )) - .into()) + Err( + ErrorCode::InvalidInputSyntax(format!("arg in {} must be float64", kind)) + .into(), + ) } } else { Ok(vec![]) diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index daf8d30859e8..6760376de1d8 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -69,14 +69,7 @@ impl AggCall { // XXX: some special cases that can not be handled by signature map. // may return list or struct type - ( - AggKind::Min - | AggKind::Max - | AggKind::FirstValue - | AggKind::PercentileDisc - | AggKind::Mode, - [input], - ) => input.clone(), + (AggKind::Min | AggKind::Max | AggKind::FirstValue, [input]) => input.clone(), (AggKind::ArrayAgg, [input]) => List(Box::new(input.clone())), // functions that are rewritten in the frontend and don't exist in the expr crate (AggKind::Avg, [input]) => match input { @@ -93,7 +86,12 @@ impl AggCall { Float32 | Float64 | Int256 => Float64, _ => return Err(err()), }, - (AggKind::PercentileCont, _) => Float64, + // Ordered-Set Aggregation + (AggKind::PercentileCont, [input]) => match input { + Float64 => Float64, + _ => return Err(err()), + }, + (AggKind::PercentileDisc | AggKind::Mode, [input]) => input.clone(), // other functions are handled by signature map _ => { diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 6cbbe5dccd3c..c1a563310af3 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -1062,6 +1062,24 @@ fn new_stream_hash_agg(logical: Agg, vnode_col_idx: Option) -> S impl ToStream for LogicalAgg { fn to_stream(&self, ctx: &mut ToStreamContext) -> Result { + for agg_call in self.agg_calls() { + if matches!( + agg_call.agg_kind, + AggKind::BitAnd + | AggKind::BitOr + | AggKind::BoolAnd + | AggKind::BoolOr + | AggKind::PercentileCont + | AggKind::PercentileDisc + | AggKind::Mode + ) { + return Err(ErrorCode::NotImplemented( + format!("{} aggregation in materialized view", agg_call.agg_kind), + None.into(), + ) + .into()); + } + } let eowc = ctx.emit_on_window_close(); let stream_input = self.input().to_stream(ctx)?; diff --git a/src/tests/sqlsmith/src/sql_gen/types.rs b/src/tests/sqlsmith/src/sql_gen/types.rs index b53209e6e243..9f0af5976569 100644 --- a/src/tests/sqlsmith/src/sql_gen/types.rs +++ b/src/tests/sqlsmith/src/sql_gen/types.rs @@ -219,6 +219,9 @@ pub(crate) static AGG_FUNC_TABLE: LazyLock>> = AggKind::BitOr, AggKind::BoolAnd, AggKind::BoolOr, + AggKind::PercentileCont, + AggKind::PercentileDisc, + AggKind::Mode, ] .contains(&func.func) }) From 5b382397b114d42010696a2a592c1096614d38c4 Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 14 Jun 2023 13:20:21 +0200 Subject: [PATCH 2/4] fix: replace ouroboros with self_cell (#10316) --- Cargo.lock | 43 ++++---------------- src/expr/Cargo.toml | 2 +- src/expr/src/expr/expr_to_char_const_tmpl.rs | 4 +- src/expr/src/vector_op/to_char.rs | 27 ++++++------ src/expr/src/vector_op/to_timestamp.rs | 2 +- 5 files changed, 24 insertions(+), 54 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f66e42c71d30..861b3bfdab7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,12 +2,6 @@ # It is not intended for manual editing. version = 3 -[[package]] -name = "Inflector" -version = "0.11.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" - [[package]] name = "addr2line" version = "0.19.0" @@ -62,12 +56,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "aliasable" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" - [[package]] name = "android_system_properties" version = "0.1.5" @@ -4340,29 +4328,6 @@ version = "6.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ceedf44fb00f2d1984b0bc98102627ce622e083e49a5bacdb3e514fa4238e267" -[[package]] -name = "ouroboros" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1358bd1558bd2a083fed428ffeda486fbfb323e698cdda7794259d592ca72db" -dependencies = [ - "aliasable", - "ouroboros_macro", -] - -[[package]] -name = "ouroboros_macro" -version = "0.15.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f7d21ccd03305a674437ee1248f3ab5d4b1db095cf1caf49f1713ddf61956b7" -dependencies = [ - "Inflector", - "proc-macro-error", - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "output_vt100" version = "0.1.3" @@ -6072,7 +6037,6 @@ dependencies = [ "madsim-tokio", "md5", "num-traits", - "ouroboros", "parse-display", "paste", "regex", @@ -6081,6 +6045,7 @@ dependencies = [ "risingwave_pb", "risingwave_udf", "rust_decimal", + "self_cell", "serde_json", "sha1", "sha2", @@ -7010,6 +6975,12 @@ dependencies = [ "libc", ] +[[package]] +name = "self_cell" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c309e515543e67811222dbc9e3dd7e1056279b782e1dacffe4242b718734fb6" + [[package]] name = "semver" version = "1.0.17" diff --git a/src/expr/Cargo.toml b/src/expr/Cargo.toml index 9bd5fa2ecaf6..5dbe7d3d77ea 100644 --- a/src/expr/Cargo.toml +++ b/src/expr/Cargo.toml @@ -34,7 +34,6 @@ hex = "0.4.3" itertools = "0.10" md5 = "0.7.0" num-traits = "0.2" -ouroboros = "0.15" parse-display = "0.6" paste = "1" regex = "1" @@ -43,6 +42,7 @@ risingwave_expr_macro = { path = "macro" } risingwave_pb = { path = "../prost" } risingwave_udf = { path = "../udf" } rust_decimal = { version = "1", features = ["db-postgres", "maths"] } +self_cell = "1.0.0" serde_json = "1" sha1 = "0.10.5" sha2 = "0.10.6" diff --git a/src/expr/src/expr/expr_to_char_const_tmpl.rs b/src/expr/src/expr/expr_to_char_const_tmpl.rs index 5b4165ff84b6..4164dfed4405 100644 --- a/src/expr/src/expr/expr_to_char_const_tmpl.rs +++ b/src/expr/src/expr/expr_to_char_const_tmpl.rs @@ -56,7 +56,7 @@ impl Expression for ExprToCharConstTmpl { let mut writer = output.writer().begin(); let fmt = data .0 - .format_with_items(self.ctx.chrono_pattern.borrow_items().iter()); + .format_with_items(self.ctx.chrono_pattern.borrow_dependent().iter()); write!(writer, "{fmt}").unwrap(); writer.finish(); } else { @@ -72,7 +72,7 @@ impl Expression for ExprToCharConstTmpl { Ok(if let Some(ScalarImpl::Timestamp(data)) = data { Some( data.0 - .format_with_items(self.ctx.chrono_pattern.borrow_items().iter()) + .format_with_items(self.ctx.chrono_pattern.borrow_dependent().iter()) .to_string() .into(), ) diff --git a/src/expr/src/vector_op/to_char.rs b/src/expr/src/vector_op/to_char.rs index d6dccec50945..fd332cc7a767 100644 --- a/src/expr/src/vector_op/to_char.rs +++ b/src/expr/src/vector_op/to_char.rs @@ -17,22 +17,23 @@ use std::sync::LazyLock; use aho_corasick::{AhoCorasick, AhoCorasickBuilder}; use chrono::format::StrftimeItems; -use ouroboros::self_referencing; use risingwave_common::types::Timestamp; use static_assertions::const_assert_eq; -#[self_referencing] -pub struct ChronoPattern { - pub(crate) tmpl: String, - #[borrows(tmpl)] - #[covariant] - pub(crate) items: Vec>, +type Pattern<'a> = Vec>; + +self_cell::self_cell! { + pub struct ChronoPattern { + owner: String, + #[covariant] + dependent: Pattern, + } } impl Debug for ChronoPattern { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("ChronoPattern") - .field("tmpl", self.borrow_tmpl()) + .field("tmpl", self.borrow_owner()) .finish() } } @@ -67,16 +68,14 @@ pub fn compile_pattern_to_chrono(tmpl: &str) -> ChronoPattern { true }); tracing::debug!(tmpl, chrono_tmpl, "compile_pattern_to_chrono"); - ChronoPatternBuilder { - tmpl: chrono_tmpl, - items_builder: |tmpl| StrftimeItems::new(tmpl).collect::>(), - } - .build() + ChronoPattern::new(chrono_tmpl, |tmpl| { + StrftimeItems::new(tmpl).collect::>() + }) } // #[function("to_char(timestamp, varchar) -> varchar")] pub fn to_char_timestamp(data: Timestamp, tmpl: &str, writer: &mut dyn Write) { let pattern = compile_pattern_to_chrono(tmpl); - let format = data.0.format_with_items(pattern.borrow_items().iter()); + let format = data.0.format_with_items(pattern.borrow_dependent().iter()); write!(writer, "{}", format).unwrap(); } diff --git a/src/expr/src/vector_op/to_timestamp.rs b/src/expr/src/vector_op/to_timestamp.rs index 732d28ce5a7f..c9a9f0a4a30b 100644 --- a/src/expr/src/vector_op/to_timestamp.rs +++ b/src/expr/src/vector_op/to_timestamp.rs @@ -22,7 +22,7 @@ use crate::Result; #[inline(always)] pub fn to_timestamp_const_tmpl(s: &str, tmpl: &ChronoPattern) -> Result { let mut parsed = Parsed::new(); - chrono::format::parse(&mut parsed, s, tmpl.borrow_items().iter())?; + chrono::format::parse(&mut parsed, s, tmpl.borrow_dependent().iter())?; // chrono will only assign the default value for seconds/nanoseconds fields, and raise an error // for other ones. We should specify the default value manually. From 9593d1b61c765f0085532edc1ef23467bab84c30 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tesla=20Zhang=E2=80=AE?= Date: Wed, 14 Jun 2023 08:40:29 -0400 Subject: [PATCH 3/4] refactor(plan_node_fmt): 4 more impls for Distill (#10296) --- .../src/optimizer/plan_node/batch_seq_scan.rs | 15 +---- .../src/optimizer/plan_node/generic/scan.rs | 23 +++++++ .../src/optimizer/plan_node/logical_now.rs | 18 ++++- .../plan_node/logical_project_set.rs | 18 +---- .../src/optimizer/plan_node/logical_scan.rs | 67 +++++++++++++++---- .../src/optimizer/plan_node/stream_now.rs | 18 ++++- 6 files changed, 115 insertions(+), 44 deletions(-) diff --git a/src/frontend/src/optimizer/plan_node/batch_seq_scan.rs b/src/frontend/src/optimizer/plan_node/batch_seq_scan.rs index 64810ab5b27d..27de0411b55a 100644 --- a/src/frontend/src/optimizer/plan_node/batch_seq_scan.rs +++ b/src/frontend/src/optimizer/plan_node/batch_seq_scan.rs @@ -183,19 +183,8 @@ impl Distill for BatchSeqScan { fn distill<'a>(&self) -> Pretty<'a> { let verbose = self.base.ctx.is_explain_verbose(); let mut vec = Vec::with_capacity(4); - vec.push(("table", Pretty::display(&self.logical.table_name))); - vec.push(( - "columns", - Pretty::Array( - match verbose { - true => self.logical.column_names_with_table_prefix(), - false => self.logical.column_names(), - } - .into_iter() - .map(Pretty::from) - .collect(), - ), - )); + vec.push(("table", Pretty::from(self.logical.table_name.clone()))); + vec.push(("columns", self.logical.columns_pretty(verbose))); if !self.scan_ranges.is_empty() { let range_strs = self.scan_ranges_as_strs(verbose); diff --git a/src/frontend/src/optimizer/plan_node/generic/scan.rs b/src/frontend/src/optimizer/plan_node/generic/scan.rs index e9cf440c4f3c..930bf4dc35ca 100644 --- a/src/frontend/src/optimizer/plan_node/generic/scan.rs +++ b/src/frontend/src/optimizer/plan_node/generic/scan.rs @@ -17,6 +17,7 @@ use std::rc::Rc; use educe::Educe; use fixedbitset::FixedBitSet; +use pretty_xmlish::Pretty; use risingwave_common::catalog::{ColumnDesc, Field, Schema, TableDesc}; use risingwave_common::util::column_index_mapping::ColIndexMapping; use risingwave_common::util::sort_util::ColumnOrder; @@ -264,6 +265,28 @@ impl Scan { ctx, } } + + pub(crate) fn columns_pretty<'a>(&self, verbose: bool) -> Pretty<'a> { + Pretty::Array( + match verbose { + true => self.column_names_with_table_prefix(), + false => self.column_names(), + } + .into_iter() + .map(Pretty::from) + .collect(), + ) + } + + pub(crate) fn fields_pretty_schema(&self) -> Schema { + let fields = self + .table_desc + .columns + .iter() + .map(|col| Field::from_with_table_name_prefix(col, &self.table_name)) + .collect(); + Schema { fields } + } } impl GenericPlanNode for Scan { diff --git a/src/frontend/src/optimizer/plan_node/logical_now.rs b/src/frontend/src/optimizer/plan_node/logical_now.rs index 9314da6461b7..983788ca0544 100644 --- a/src/frontend/src/optimizer/plan_node/logical_now.rs +++ b/src/frontend/src/optimizer/plan_node/logical_now.rs @@ -15,12 +15,13 @@ use std::fmt; use itertools::Itertools; +use pretty_xmlish::Pretty; use risingwave_common::bail; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::error::Result; use risingwave_common::types::DataType; -use super::utils::IndicesDisplay; +use super::utils::{Distill, IndicesDisplay}; use super::{ ColPrunable, ColumnPruningContext, ExprRewritable, LogicalFilter, PlanBase, PlanRef, PredicatePushdown, RewriteStreamContext, StreamNow, ToBatch, ToStream, ToStreamContext, @@ -47,6 +48,21 @@ impl LogicalNow { } } +impl Distill for LogicalNow { + fn distill<'a>(&self) -> Pretty<'a> { + let vec = if self.base.ctx.is_explain_verbose() { + let disp = Pretty::debug(&IndicesDisplay { + indices: &(0..self.schema().fields.len()).collect_vec(), + input_schema: self.schema(), + }); + vec![("output", disp)] + } else { + vec![] + }; + + Pretty::childless_record("LogicalNow", vec) + } +} impl fmt::Display for LogicalNow { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let verbose = self.base.ctx.is_explain_verbose(); diff --git a/src/frontend/src/optimizer/plan_node/logical_project_set.rs b/src/frontend/src/optimizer/plan_node/logical_project_set.rs index 8379f56478ae..7b0e51843c8a 100644 --- a/src/frontend/src/optimizer/plan_node/logical_project_set.rs +++ b/src/frontend/src/optimizer/plan_node/logical_project_set.rs @@ -12,13 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::fmt; - use fixedbitset::FixedBitSet; use itertools::Itertools; use risingwave_common::error::Result; use risingwave_common::types::DataType; +use super::utils::impl_distill_by_unit; use super::{ gen_filter_and_pushdown, generic, BatchProjectSet, ColPrunable, ExprRewritable, LogicalProject, PlanBase, PlanRef, PlanTreeNodeUnary, PredicatePushdown, StreamProjectSet, ToBatch, ToStream, @@ -187,13 +186,6 @@ impl LogicalProjectSet { pub fn select_list(&self) -> &Vec { &self.core.select_list } - - pub(super) fn fmt_with_name(&self, f: &mut fmt::Formatter<'_>, name: &str) -> fmt::Result { - let _verbose = self.base.ctx.is_explain_verbose(); - // TODO: add verbose display like Project - - self.core.fmt_with_name(f, name) - } } impl PlanTreeNodeUnary for LogicalProjectSet { @@ -225,12 +217,8 @@ impl PlanTreeNodeUnary for LogicalProjectSet { } impl_plan_tree_node_for_unary! {LogicalProjectSet} - -impl fmt::Display for LogicalProjectSet { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.fmt_with_name(f, "LogicalProjectSet") - } -} +impl_distill_by_unit!(LogicalProjectSet, core, "LogicalProjectSet"); +// TODO: add verbose display like Project impl ColPrunable for LogicalProjectSet { fn prune_col(&self, required_cols: &[usize], ctx: &mut ColumnPruningContext) -> PlanRef { diff --git a/src/frontend/src/optimizer/plan_node/logical_scan.rs b/src/frontend/src/optimizer/plan_node/logical_scan.rs index 7e64d3a1e2a7..729ac16f823e 100644 --- a/src/frontend/src/optimizer/plan_node/logical_scan.rs +++ b/src/frontend/src/optimizer/plan_node/logical_scan.rs @@ -18,11 +18,13 @@ use std::rc::Rc; use fixedbitset::FixedBitSet; use itertools::Itertools; -use risingwave_common::catalog::{ColumnDesc, Field, Schema, TableDesc}; +use pretty_xmlish::Pretty; +use risingwave_common::catalog::{ColumnDesc, TableDesc}; use risingwave_common::error::{ErrorCode, Result, RwError}; use risingwave_common::util::sort_util::ColumnOrder; use super::generic::{GenericPlanNode, GenericPlanRef}; +use super::utils::Distill; use super::{ generic, BatchFilter, BatchProject, ColPrunable, ExprRewritable, PlanBase, PlanRef, PredicatePushdown, StreamTableScan, ToBatch, ToStream, @@ -287,6 +289,52 @@ impl LogicalScan { impl_plan_tree_node_for_leaf! {LogicalScan} +impl Distill for LogicalScan { + fn distill<'a>(&self) -> Pretty<'a> { + let verbose = self.base.ctx.is_explain_verbose(); + let mut vec = Vec::with_capacity(5); + vec.push(("table", Pretty::from(self.table_name().to_owned()))); + let key_is_columns = + self.predicate().always_true() || self.output_col_idx() == self.required_col_idx(); + let key = if key_is_columns { + "columns" + } else { + "output_columns" + }; + vec.push((key, self.core.columns_pretty(verbose))); + if !key_is_columns { + vec.push(( + "required_columns", + Pretty::Array( + self.required_col_idx() + .iter() + .map(|i| { + let col_name = &self.table_desc().columns[*i].name; + Pretty::from(if verbose { + format!("{}.{}", self.table_name(), col_name) + } else { + col_name.to_string() + }) + }) + .collect(), + ), + )); + } + + if !self.predicate().always_true() { + let input_schema = self.core.fields_pretty_schema(); + vec.push(( + "predicate", + Pretty::display(&ConditionDisplay { + condition: self.predicate(), + input_schema: &input_schema, + }), + )) + } + + Pretty::childless_record("LogicalScan", vec) + } +} impl fmt::Display for LogicalScan { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let verbose = self.base.ctx.is_explain_verbose(); @@ -314,26 +362,17 @@ impl fmt::Display for LogicalScan { ", output_columns: [{}], required_columns: [{}]", output_col_names, self.required_col_idx().iter().format_with(", ", |i, f| { + let col_name = &self.table_desc().columns[*i].name; if verbose { - f(&format_args!( - "{}.{}", - self.table_name(), - self.table_desc().columns[*i].name - )) + f(&format_args!("{}.{}", self.table_name(), col_name)) } else { - f(&format_args!("{}", self.table_desc().columns[*i].name)) + f(&format_args!("{}", col_name)) } }) )?; } - let fields = self - .table_desc() - .columns - .iter() - .map(|col| Field::from_with_table_name_prefix(col, self.table_name())) - .collect_vec(); - let input_schema = Schema { fields }; + let input_schema = self.core.fields_pretty_schema(); write!( f, ", predicate: {} }}", diff --git a/src/frontend/src/optimizer/plan_node/stream_now.rs b/src/frontend/src/optimizer/plan_node/stream_now.rs index de0c0aee8501..1c4ba6b8ae93 100644 --- a/src/frontend/src/optimizer/plan_node/stream_now.rs +++ b/src/frontend/src/optimizer/plan_node/stream_now.rs @@ -16,6 +16,7 @@ use std::fmt; use fixedbitset::FixedBitSet; use itertools::Itertools; +use pretty_xmlish::Pretty; use risingwave_common::catalog::{Field, Schema}; use risingwave_common::types::DataType; use risingwave_pb::stream_plan::stream_node::NodeBody; @@ -23,7 +24,7 @@ use risingwave_pb::stream_plan::NowNode; use super::generic::GenericPlanRef; use super::stream::StreamPlanRef; -use super::utils::{formatter_debug_plan_node, IndicesDisplay, TableCatalogBuilder}; +use super::utils::{formatter_debug_plan_node, Distill, IndicesDisplay, TableCatalogBuilder}; use super::{ExprRewritable, LogicalNow, PlanBase, StreamNode}; use crate::optimizer::property::{Distribution, FunctionalDependencySet}; use crate::stream_fragmenter::BuildFragmentGraphState; @@ -58,6 +59,21 @@ impl StreamNow { } } +impl Distill for StreamNow { + fn distill<'a>(&self) -> Pretty<'a> { + let vec = if self.base.ctx.is_explain_verbose() { + let disp = Pretty::debug(&IndicesDisplay { + indices: &(0..self.schema().fields.len()).collect_vec(), + input_schema: self.schema(), + }); + vec![("output", disp)] + } else { + vec![] + }; + + Pretty::childless_record("StreamNow", vec) + } +} impl fmt::Display for StreamNow { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let verbose = self.base.ctx.is_explain_verbose(); From 5cf94c934698a906711a1f421ef5a785c5e1b377 Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 14 Jun 2023 16:18:53 +0200 Subject: [PATCH 4/4] feat: support scalar function in FROM clause (#10317) --- Makefile.toml | 2 +- e2e_test/batch/functions/func_in_from.part | 14 +++ e2e_test/streaming/values.slt | 18 +++ .../tests/testdata/input/expr.yaml | 27 ++++ .../tests/testdata/output/expr.yaml | 32 +++++ src/frontend/src/binder/bind_context.rs | 1 + src/frontend/src/binder/expr/function.rs | 32 ++--- src/frontend/src/binder/mod.rs | 11 +- src/frontend/src/binder/relation/mod.rs | 111 +---------------- .../src/binder/relation/table_function.rs | 117 ++++++++++++++++++ src/frontend/src/planner/relation.rs | 21 +++- 11 files changed, 252 insertions(+), 134 deletions(-) create mode 100644 e2e_test/batch/functions/func_in_from.part create mode 100644 src/frontend/src/binder/relation/table_function.rs diff --git a/Makefile.toml b/Makefile.toml index 9d5ba0e71d82..9780bda9d287 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -910,7 +910,7 @@ if [ $# -gt 0 ]; then ARGS=("$@") echo "Applying clippy --fix for $@ (including dirty and staged files)" - cargo clippy ${ARGS[@]/#/--package risingwave_} ${RISINGWAVE_FEATURE_FLAGS} --fix --allow-dirty --allow-staged + cargo clippy ${ARGS[@]/#/--package risingwave_} --fix --allow-dirty --allow-staged else echo "Applying clippy --fix for all targets to all files (including dirty and staged files)" echo "Tip: run $(tput setaf 4)./risedev cf {package_names}$(tput sgr0) to only check-fix those packages (e.g. frontend, meta)." diff --git a/e2e_test/batch/functions/func_in_from.part b/e2e_test/batch/functions/func_in_from.part new file mode 100644 index 000000000000..a8c85180468f --- /dev/null +++ b/e2e_test/batch/functions/func_in_from.part @@ -0,0 +1,14 @@ +query I +select abs.abs from abs(-1); +---- +1 + +query I +select alias.alias from abs(-1) alias; +---- +1 + +query I +select alias.col from abs(-1) alias(col); +---- +1 diff --git a/e2e_test/streaming/values.slt b/e2e_test/streaming/values.slt index c07ec20edc82..74b07f8d4cca 100644 --- a/e2e_test/streaming/values.slt +++ b/e2e_test/streaming/values.slt @@ -35,3 +35,21 @@ drop materialized view mv; statement ok drop table t; + +statement ok +create materialized view mv as select * from abs(-1); + +# TODO: support this +statement error not yet implemented: LogicalTableFunction::logical_rewrite_for_stream +create materialized view mv2 as select * from range(1,2); + +statement ok +flush; + +query IR +select * from mv; +---- +1 + +statement ok +drop materialized view mv; diff --git a/src/frontend/planner_test/tests/testdata/input/expr.yaml b/src/frontend/planner_test/tests/testdata/input/expr.yaml index de237ade68d8..be7531bc2344 100644 --- a/src/frontend/planner_test/tests/testdata/input/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/input/expr.yaml @@ -423,3 +423,30 @@ sql: select 1 / 0 t1; expected_outputs: - batch_error +# functions in FROM clause +- sql: | + select * from abs(-1); + expected_outputs: + - batch_plan + - stream_plan +- sql: | + select * from range(1,2); + expected_outputs: + - batch_plan + # TODO: support this + - stream_error +- sql: | + select * from max(); + expected_outputs: + - binder_error +- name: Grafana issue-10134 + sql: | + SELECT * FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s + expected_outputs: + - batch_plan + - stream_error \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/output/expr.yaml b/src/frontend/planner_test/tests/testdata/output/expr.yaml index 1a74f9767bbc..c388e4026456 100644 --- a/src/frontend/planner_test/tests/testdata/output/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/expr.yaml @@ -619,3 +619,35 @@ - name: const_eval of division by 0 error sql: select 1 / 0 t1; batch_error: 'Expr error: Division by zero' +- sql: | + select * from abs(-1); + batch_plan: | + BatchValues { rows: [[1:Int32]] } + stream_plan: | + StreamMaterialize { columns: [abs, _row_id(hidden)], stream_key: [_row_id], pk_columns: [_row_id], pk_conflict: "NoCheck", watermark_columns: [abs] } + └─StreamValues { rows: [[Abs(-1:Int32), 0:Int64]] } +- sql: | + select * from range(1,2); + batch_plan: | + BatchTableFunction { Range(1:Int32, 2:Int32) } + stream_error: |- + Feature is not yet implemented: LogicalTableFunction::logical_rewrite_for_stream + No tracking issue yet. Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml +- sql: | + select * from max(); + binder_error: 'Invalid input syntax: aggregate functions are not allowed in FROM' +- name: Grafana issue-10134 + sql: | + SELECT * FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s + batch_plan: | + BatchNestedLoopJoin { type: Inner, predicate: true, output: all } + ├─BatchTableFunction { GenerateSeries(1:Int32, 2:Int32) } + └─BatchValues { rows: [[ARRAY["$user", public]:List(Varchar)]] } + stream_error: |- + Feature is not yet implemented: LogicalTableFunction::logical_rewrite_for_stream + No tracking issue yet. Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml diff --git a/src/frontend/src/binder/bind_context.rs b/src/frontend/src/binder/bind_context.rs index 838012c8ecac..5101a1b2f0d0 100644 --- a/src/frontend/src/binder/bind_context.rs +++ b/src/frontend/src/binder/bind_context.rs @@ -52,6 +52,7 @@ pub enum Clause { GroupBy, Having, Filter, + From, } /// A `BindContext` that is only visible if the `LATERAL` keyword diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index c1b089560bef..309b18d912c5 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -43,7 +43,7 @@ use crate::expr::{ use crate::utils::Condition; impl Binder { - pub(super) fn bind_function(&mut self, f: Function) -> Result { + pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result { let function_name = match f.name.0.as_slice() { [name] => name.real_value(), [schema, name] => { @@ -114,18 +114,10 @@ impl Binder { // user defined function // TODO: resolve schema name - if let Some(func) = self - .catalog - .first_valid_schema( - &self.db_name, - &self.search_path, - &self.auth_context.user_name, - )? - .get_function_by_name_args( - &function_name, - &inputs.iter().map(|arg| arg.return_type()).collect_vec(), - ) - { + if let Some(func) = self.first_valid_schema()?.get_function_by_name_args( + &function_name, + &inputs.iter().map(|arg| arg.return_type()).collect_vec(), + ) { use crate::catalog::function_catalog::FunctionKind::*; match &func.kind { Scalar { .. } => return Ok(UserDefinedFunction::new(func.clone(), inputs).into()), @@ -676,12 +668,7 @@ impl Binder { }))), ("current_schema", guard_by_len(0, raw(|binder, _inputs| { return Ok(binder - .catalog - .first_valid_schema( - &binder.db_name, - &binder.search_path, - &binder.auth_context.user_name, - ) + .first_valid_schema() .map(|schema| ExprImpl::literal_varchar(schema.name())) .unwrap_or_else(|_| ExprImpl::literal_null(DataType::Varchar))); }))), @@ -909,7 +896,8 @@ impl Binder { | Clause::Values | Clause::GroupBy | Clause::Having - | Clause::Filter => { + | Clause::Filter + | Clause::From => { return Err(ErrorCode::InvalidInputSyntax(format!( "window functions are not allowed in {}", clause @@ -950,7 +938,7 @@ impl Binder { fn ensure_aggregate_allowed(&self) -> Result<()> { if let Some(clause) = self.context.clause { match clause { - Clause::Where | Clause::Values => { + Clause::Where | Clause::Values | Clause::From => { return Err(ErrorCode::InvalidInputSyntax(format!( "aggregate functions are not allowed in {}", clause @@ -973,7 +961,7 @@ impl Binder { )) .into()); } - Clause::GroupBy | Clause::Having | Clause::Filter => {} + Clause::GroupBy | Clause::Having | Clause::Filter | Clause::From => {} } } Ok(()) diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index f560981d5d56..2b831b97f9e8 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -55,7 +55,8 @@ pub use update::BoundUpdate; pub use values::BoundValues; use crate::catalog::catalog_service::CatalogReadGuard; -use crate::catalog::{TableId, ViewId}; +use crate::catalog::schema_catalog::SchemaCatalog; +use crate::catalog::{CatalogResult, TableId, ViewId}; use crate::session::{AuthContext, SessionImpl}; pub type ShareId = usize; @@ -350,6 +351,14 @@ impl Binder { self.next_share_id += 1; id } + + fn first_valid_schema(&self) -> CatalogResult<&SchemaCatalog> { + self.catalog.first_valid_schema( + &self.db_name, + &self.search_path, + &self.auth_context.user_name, + ) + } } #[cfg(test)] diff --git a/src/frontend/src/binder/relation/mod.rs b/src/frontend/src/binder/relation/mod.rs index e0c8dc46a40a..c0ca60712b69 100644 --- a/src/frontend/src/binder/relation/mod.rs +++ b/src/frontend/src/binder/relation/mod.rs @@ -14,32 +14,22 @@ use std::collections::hash_map::Entry; use std::ops::Deref; -use std::str::FromStr; -use itertools::Itertools; -use risingwave_common::catalog::{ - Field, Schema, TableId, DEFAULT_SCHEMA_NAME, PG_CATALOG_SCHEMA_NAME, - RW_INTERNAL_TABLE_FUNCTION_NAME, -}; +use risingwave_common::catalog::{Field, TableId, DEFAULT_SCHEMA_NAME}; use risingwave_common::error::{internal_error, ErrorCode, Result, RwError}; -use risingwave_common::types::DataType; use risingwave_sqlparser::ast::{ Expr as ParserExpr, FunctionArg, FunctionArgExpr, Ident, ObjectName, TableAlias, TableFactor, }; -use self::watermark::is_watermark_func; use super::bind_context::ColumnBinding; use super::statement::RewriteExprsRecursive; use crate::binder::Binder; -use crate::catalog::function_catalog::FunctionKind; -use crate::catalog::system_catalog::pg_catalog::{ - PG_GET_KEYWORDS_FUNC_NAME, PG_KEYWORDS_TABLE_NAME, -}; -use crate::expr::{Expr, ExprImpl, InputRef, TableFunction, TableFunctionType}; +use crate::expr::{ExprImpl, InputRef}; mod join; mod share; mod subquery; +mod table_function; mod table_or_source; mod watermark; mod window_table_function; @@ -63,7 +53,7 @@ pub enum Relation { Subquery(Box), Join(Box), WindowTableFunction(Box), - TableFunction(Box), + TableFunction(ExprImpl), Watermark(Box), Share(Box), } @@ -76,13 +66,7 @@ impl RewriteExprsRecursive for Relation { Relation::WindowTableFunction(inner) => inner.rewrite_exprs_recursive(rewriter), Relation::Watermark(inner) => inner.rewrite_exprs_recursive(rewriter), Relation::Share(inner) => inner.rewrite_exprs_recursive(rewriter), - Relation::TableFunction(inner) => { - let new_args = std::mem::take(&mut inner.args) - .into_iter() - .map(|expr| rewriter.rewrite_expr(expr)) - .collect(); - inner.args = new_args; - } + Relation::TableFunction(inner) => *inner = rewriter.rewrite_expr(inner.take()), _ => {} } } @@ -405,90 +389,7 @@ impl Binder { for_system_time_as_of_proctime, } => self.bind_relation_by_name(name, alias, for_system_time_as_of_proctime), TableFactor::TableFunction { name, alias, args } => { - let func_name = &name.0[0].real_value(); - if func_name.eq_ignore_ascii_case(RW_INTERNAL_TABLE_FUNCTION_NAME) { - return self.bind_internal_table(args, alias); - } - if func_name.eq_ignore_ascii_case(PG_GET_KEYWORDS_FUNC_NAME) - || name.real_value().eq_ignore_ascii_case( - format!("{}.{}", PG_CATALOG_SCHEMA_NAME, PG_GET_KEYWORDS_FUNC_NAME) - .as_str(), - ) - { - return self.bind_relation_by_name_inner( - Some(PG_CATALOG_SCHEMA_NAME), - PG_KEYWORDS_TABLE_NAME, - alias, - false, - ); - } - if let Ok(kind) = WindowTableFunctionKind::from_str(func_name) { - return Ok(Relation::WindowTableFunction(Box::new( - self.bind_window_table_function(alias, kind, args)?, - ))); - } - if is_watermark_func(func_name) { - return Ok(Relation::Watermark(Box::new( - self.bind_watermark(alias, args)?, - ))); - }; - - let args: Vec = args - .into_iter() - .map(|arg| self.bind_function_arg(arg)) - .flatten_ok() - .try_collect()?; - let tf = if let Some(func) = self - .catalog - .first_valid_schema( - &self.db_name, - &self.search_path, - &self.auth_context.user_name, - )? - .get_function_by_name_args( - func_name, - &args.iter().map(|arg| arg.return_type()).collect_vec(), - ) - && matches!(func.kind, FunctionKind::Table { .. }) - { - TableFunction::new_user_defined(func.clone(), args) - } else if let Ok(table_function_type) = TableFunctionType::from_str(func_name) { - TableFunction::new(table_function_type, args)? - } else { - return Err(ErrorCode::NotImplemented( - format!("unknown table function: {}", func_name), - 1191.into(), - ) - .into()); - }; - let columns = if let DataType::Struct(s) = tf.return_type() { - // If the table function returns a struct, it's fields can be accessed just - // like a table's columns. - let schema = Schema::from(&s); - schema.fields.into_iter().map(|f| (false, f)).collect_vec() - } else { - // If there is an table alias, we should use the alias as the table function's - // column name. If column aliases are also provided, they - // are handled in bind_table_to_context. - // - // Note: named return value should take precedence over table alias. - // But we don't support it yet. - // e.g., - // ``` - // > create function foo(ret out int) language sql as 'select 1'; - // > select t.ret from foo() as t; - // ``` - let col_name = if let Some(alias) = &alias { - alias.name.real_value() - } else { - tf.name() - }; - vec![(false, Field::with_name(tf.return_type(), col_name))] - }; - - self.bind_table_to_context(columns, tf.name(), alias)?; - - Ok(Relation::TableFunction(Box::new(tf))) + self.bind_table_function(name, alias, args) } TableFactor::Derived { lateral, diff --git a/src/frontend/src/binder/relation/table_function.rs b/src/frontend/src/binder/relation/table_function.rs new file mode 100644 index 000000000000..1be11687fb1c --- /dev/null +++ b/src/frontend/src/binder/relation/table_function.rs @@ -0,0 +1,117 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::str::FromStr; + +use itertools::Itertools; +use risingwave_common::catalog::{ + Field, Schema, PG_CATALOG_SCHEMA_NAME, RW_INTERNAL_TABLE_FUNCTION_NAME, +}; +use risingwave_common::types::DataType; +use risingwave_sqlparser::ast::{Function, FunctionArg, ObjectName, TableAlias}; + +use super::watermark::is_watermark_func; +use super::{Binder, Relation, Result, WindowTableFunctionKind}; +use crate::binder::bind_context::Clause; +use crate::catalog::system_catalog::pg_catalog::{ + PG_GET_KEYWORDS_FUNC_NAME, PG_KEYWORDS_TABLE_NAME, +}; +use crate::expr::Expr; + +impl Binder { + /// Binds a table function AST, which is a function call in a relation position. + /// + /// Besides [`TableFunction`] expr, it can also be other things like window table functions, or + /// scalar functions. + pub(super) fn bind_table_function( + &mut self, + name: ObjectName, + alias: Option, + args: Vec, + ) -> Result { + let func_name = &name.0[0].real_value(); + // internal/system table functions + { + if func_name.eq_ignore_ascii_case(RW_INTERNAL_TABLE_FUNCTION_NAME) { + return self.bind_internal_table(args, alias); + } + if func_name.eq_ignore_ascii_case(PG_GET_KEYWORDS_FUNC_NAME) + || name.real_value().eq_ignore_ascii_case( + format!("{}.{}", PG_CATALOG_SCHEMA_NAME, PG_GET_KEYWORDS_FUNC_NAME).as_str(), + ) + { + return self.bind_relation_by_name_inner( + Some(PG_CATALOG_SCHEMA_NAME), + PG_KEYWORDS_TABLE_NAME, + alias, + false, + ); + } + } + // window table functions (tumble/hop) + if let Ok(kind) = WindowTableFunctionKind::from_str(func_name) { + return Ok(Relation::WindowTableFunction(Box::new( + self.bind_window_table_function(alias, kind, args)?, + ))); + } + // watermark + if is_watermark_func(func_name) { + return Ok(Relation::Watermark(Box::new( + self.bind_watermark(alias, args)?, + ))); + }; + + let mut clause = Some(Clause::From); + std::mem::swap(&mut self.context.clause, &mut clause); + let func = self.bind_function(Function { + name, + args, + over: None, + distinct: false, + order_by: vec![], + filter: None, + within_group: None, + })?; + self.context.clause = clause; + + let columns = if let DataType::Struct(s) = func.return_type() { + // If the table function returns a struct, it's fields can be accessed just + // like a table's columns. + let schema = Schema::from(&s); + schema.fields.into_iter().map(|f| (false, f)).collect_vec() + } else { + // If there is an table alias, we should use the alias as the table function's + // column name. If column aliases are also provided, they + // are handled in bind_table_to_context. + // + // Note: named return value should take precedence over table alias. + // But we don't support it yet. + // e.g., + // ``` + // > create function foo(ret out int) language sql as 'select 1'; + // > select t.ret from foo() as t; + // ``` + let col_name = if let Some(alias) = &alias { + alias.name.real_value() + } else { + func_name.clone() + }; + vec![(false, Field::with_name(func.return_type(), col_name))] + }; + + self.bind_table_to_context(columns, func_name.clone(), alias)?; + + Ok(Relation::TableFunction(func)) + } +} diff --git a/src/frontend/src/planner/relation.rs b/src/frontend/src/planner/relation.rs index 2d23e74f8b74..ad3a6279d76a 100644 --- a/src/frontend/src/planner/relation.rs +++ b/src/frontend/src/planner/relation.rs @@ -15,6 +15,7 @@ use std::rc::Rc; use itertools::Itertools; +use risingwave_common::catalog::{Field, Schema}; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::{DataType, Interval, ScalarImpl}; @@ -22,10 +23,10 @@ use crate::binder::{ BoundBaseTable, BoundJoin, BoundShare, BoundSource, BoundSystemTable, BoundWatermark, BoundWindowTableFunction, Relation, WindowTableFunctionKind, }; -use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, TableFunction}; +use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall, InputRef}; use crate::optimizer::plan_node::{ LogicalHopWindow, LogicalJoin, LogicalProject, LogicalScan, LogicalShare, LogicalSource, - LogicalTableFunction, PlanRef, + LogicalTableFunction, LogicalValues, PlanRef, }; use crate::planner::Planner; @@ -42,7 +43,7 @@ impl Planner { Relation::Join(join) => self.plan_join(*join), Relation::WindowTableFunction(tf) => self.plan_window_table_function(*tf), Relation::Source(s) => self.plan_source(*s), - Relation::TableFunction(tf) => self.plan_table_function(*tf), + Relation::TableFunction(tf) => self.plan_table_function(tf), Relation::Watermark(tf) => self.plan_watermark(*tf), Relation::Share(share) => self.plan_share(*share), } @@ -115,8 +116,18 @@ impl Planner { } } - pub(super) fn plan_table_function(&mut self, table_function: TableFunction) -> Result { - Ok(LogicalTableFunction::new(table_function, self.ctx()).into()) + pub(super) fn plan_table_function(&mut self, table_function: ExprImpl) -> Result { + // TODO: maybe we can unify LogicalTableFunction with LogicalValues + match table_function { + ExprImpl::TableFunction(tf) => Ok(LogicalTableFunction::new(*tf, self.ctx()).into()), + expr => { + let schema = Schema { + // TODO: should be named + fields: vec![Field::unnamed(expr.return_type())], + }; + Ok(LogicalValues::create(vec![vec![expr]], schema, self.ctx())) + } + } } pub(super) fn plan_share(&mut self, share: BoundShare) -> Result {