From 5cf94c934698a906711a1f421ef5a785c5e1b377 Mon Sep 17 00:00:00 2001 From: xxchan Date: Wed, 14 Jun 2023 16:18:53 +0200 Subject: [PATCH] feat: support scalar function in FROM clause (#10317) --- Makefile.toml | 2 +- e2e_test/batch/functions/func_in_from.part | 14 +++ e2e_test/streaming/values.slt | 18 +++ .../tests/testdata/input/expr.yaml | 27 ++++ .../tests/testdata/output/expr.yaml | 32 +++++ src/frontend/src/binder/bind_context.rs | 1 + src/frontend/src/binder/expr/function.rs | 32 ++--- src/frontend/src/binder/mod.rs | 11 +- src/frontend/src/binder/relation/mod.rs | 111 +---------------- .../src/binder/relation/table_function.rs | 117 ++++++++++++++++++ src/frontend/src/planner/relation.rs | 21 +++- 11 files changed, 252 insertions(+), 134 deletions(-) create mode 100644 e2e_test/batch/functions/func_in_from.part create mode 100644 src/frontend/src/binder/relation/table_function.rs diff --git a/Makefile.toml b/Makefile.toml index 9d5ba0e71d828..9780bda9d2873 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -910,7 +910,7 @@ if [ $# -gt 0 ]; then ARGS=("$@") echo "Applying clippy --fix for $@ (including dirty and staged files)" - cargo clippy ${ARGS[@]/#/--package risingwave_} ${RISINGWAVE_FEATURE_FLAGS} --fix --allow-dirty --allow-staged + cargo clippy ${ARGS[@]/#/--package risingwave_} --fix --allow-dirty --allow-staged else echo "Applying clippy --fix for all targets to all files (including dirty and staged files)" echo "Tip: run $(tput setaf 4)./risedev cf {package_names}$(tput sgr0) to only check-fix those packages (e.g. frontend, meta)." 
diff --git a/e2e_test/batch/functions/func_in_from.part b/e2e_test/batch/functions/func_in_from.part new file mode 100644 index 0000000000000..a8c85180468f6 --- /dev/null +++ b/e2e_test/batch/functions/func_in_from.part @@ -0,0 +1,14 @@ +query I +select abs.abs from abs(-1); +---- +1 + +query I +select alias.alias from abs(-1) alias; +---- +1 + +query I +select alias.col from abs(-1) alias(col); +---- +1 diff --git a/e2e_test/streaming/values.slt b/e2e_test/streaming/values.slt index c07ec20edc826..74b07f8d4ccad 100644 --- a/e2e_test/streaming/values.slt +++ b/e2e_test/streaming/values.slt @@ -35,3 +35,21 @@ drop materialized view mv; statement ok drop table t; + +statement ok +create materialized view mv as select * from abs(-1); + +# TODO: support this +statement error not yet implemented: LogicalTableFunction::logical_rewrite_for_stream +create materialized view mv2 as select * from range(1,2); + +statement ok +flush; + +query IR +select * from mv; +---- +1 + +statement ok +drop materialized view mv; diff --git a/src/frontend/planner_test/tests/testdata/input/expr.yaml b/src/frontend/planner_test/tests/testdata/input/expr.yaml index de237ade68d8e..be7531bc2344e 100644 --- a/src/frontend/planner_test/tests/testdata/input/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/input/expr.yaml @@ -423,3 +423,30 @@ sql: select 1 / 0 t1; expected_outputs: - batch_error +# functions in FROM clause +- sql: | + select * from abs(-1); + expected_outputs: + - batch_plan + - stream_plan +- sql: | + select * from range(1,2); + expected_outputs: + - batch_plan + # TODO: support this + - stream_error +- sql: | + select * from max(); + expected_outputs: + - binder_error +- name: Grafana issue-10134 + sql: | + SELECT * FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s + expected_outputs: + - 
batch_plan + - stream_error \ No newline at end of file diff --git a/src/frontend/planner_test/tests/testdata/output/expr.yaml b/src/frontend/planner_test/tests/testdata/output/expr.yaml index 1a74f9767bbc1..c388e40264562 100644 --- a/src/frontend/planner_test/tests/testdata/output/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/expr.yaml @@ -619,3 +619,35 @@ - name: const_eval of division by 0 error sql: select 1 / 0 t1; batch_error: 'Expr error: Division by zero' +- sql: | + select * from abs(-1); + batch_plan: | + BatchValues { rows: [[1:Int32]] } + stream_plan: | + StreamMaterialize { columns: [abs, _row_id(hidden)], stream_key: [_row_id], pk_columns: [_row_id], pk_conflict: "NoCheck", watermark_columns: [abs] } + └─StreamValues { rows: [[Abs(-1:Int32), 0:Int64]] } +- sql: | + select * from range(1,2); + batch_plan: | + BatchTableFunction { Range(1:Int32, 2:Int32) } + stream_error: |- + Feature is not yet implemented: LogicalTableFunction::logical_rewrite_for_stream + No tracking issue yet. Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml +- sql: | + select * from max(); + binder_error: 'Invalid input syntax: aggregate functions are not allowed in FROM' +- name: Grafana issue-10134 + sql: | + SELECT * FROM + generate_series( + array_lower(string_to_array(current_setting('search_path'),','),1), + array_upper(string_to_array(current_setting('search_path'),','),1) + ) as i, + string_to_array(current_setting('search_path'),',') s + batch_plan: | + BatchNestedLoopJoin { type: Inner, predicate: true, output: all } + ├─BatchTableFunction { GenerateSeries(1:Int32, 2:Int32) } + └─BatchValues { rows: [[ARRAY["$user", public]:List(Varchar)]] } + stream_error: |- + Feature is not yet implemented: LogicalTableFunction::logical_rewrite_for_stream + No tracking issue yet. 
Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml diff --git a/src/frontend/src/binder/bind_context.rs b/src/frontend/src/binder/bind_context.rs index 838012c8ecac0..5101a1b2f0d08 100644 --- a/src/frontend/src/binder/bind_context.rs +++ b/src/frontend/src/binder/bind_context.rs @@ -52,6 +52,7 @@ pub enum Clause { GroupBy, Having, Filter, + From, } /// A `BindContext` that is only visible if the `LATERAL` keyword diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index c1b089560bef4..309b18d912c54 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -43,7 +43,7 @@ use crate::expr::{ use crate::utils::Condition; impl Binder { - pub(super) fn bind_function(&mut self, f: Function) -> Result { + pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result { let function_name = match f.name.0.as_slice() { [name] => name.real_value(), [schema, name] => { @@ -114,18 +114,10 @@ impl Binder { // user defined function // TODO: resolve schema name - if let Some(func) = self - .catalog - .first_valid_schema( - &self.db_name, - &self.search_path, - &self.auth_context.user_name, - )? - .get_function_by_name_args( - &function_name, - &inputs.iter().map(|arg| arg.return_type()).collect_vec(), - ) - { + if let Some(func) = self.first_valid_schema()?.get_function_by_name_args( + &function_name, + &inputs.iter().map(|arg| arg.return_type()).collect_vec(), + ) { use crate::catalog::function_catalog::FunctionKind::*; match &func.kind { Scalar { .. 
} => return Ok(UserDefinedFunction::new(func.clone(), inputs).into()), @@ -676,12 +668,7 @@ impl Binder { }))), ("current_schema", guard_by_len(0, raw(|binder, _inputs| { return Ok(binder - .catalog - .first_valid_schema( - &binder.db_name, - &binder.search_path, - &binder.auth_context.user_name, - ) + .first_valid_schema() .map(|schema| ExprImpl::literal_varchar(schema.name())) .unwrap_or_else(|_| ExprImpl::literal_null(DataType::Varchar))); }))), @@ -909,7 +896,8 @@ impl Binder { | Clause::Values | Clause::GroupBy | Clause::Having - | Clause::Filter => { + | Clause::Filter + | Clause::From => { return Err(ErrorCode::InvalidInputSyntax(format!( "window functions are not allowed in {}", clause @@ -950,7 +938,7 @@ impl Binder { fn ensure_aggregate_allowed(&self) -> Result<()> { if let Some(clause) = self.context.clause { match clause { - Clause::Where | Clause::Values => { + Clause::Where | Clause::Values | Clause::From => { return Err(ErrorCode::InvalidInputSyntax(format!( "aggregate functions are not allowed in {}", clause @@ -973,7 +961,7 @@ impl Binder { )) .into()); } - Clause::GroupBy | Clause::Having | Clause::Filter => {} + Clause::GroupBy | Clause::Having | Clause::Filter | Clause::From => {} } } Ok(()) diff --git a/src/frontend/src/binder/mod.rs b/src/frontend/src/binder/mod.rs index f560981d5d569..2b831b97f9e87 100644 --- a/src/frontend/src/binder/mod.rs +++ b/src/frontend/src/binder/mod.rs @@ -55,7 +55,8 @@ pub use update::BoundUpdate; pub use values::BoundValues; use crate::catalog::catalog_service::CatalogReadGuard; -use crate::catalog::{TableId, ViewId}; +use crate::catalog::schema_catalog::SchemaCatalog; +use crate::catalog::{CatalogResult, TableId, ViewId}; use crate::session::{AuthContext, SessionImpl}; pub type ShareId = usize; @@ -350,6 +351,14 @@ impl Binder { self.next_share_id += 1; id } + + fn first_valid_schema(&self) -> CatalogResult<&SchemaCatalog> { + self.catalog.first_valid_schema( + &self.db_name, + &self.search_path, + 
&self.auth_context.user_name, + ) + } } #[cfg(test)] diff --git a/src/frontend/src/binder/relation/mod.rs b/src/frontend/src/binder/relation/mod.rs index e0c8dc46a40a0..c0ca60712b69d 100644 --- a/src/frontend/src/binder/relation/mod.rs +++ b/src/frontend/src/binder/relation/mod.rs @@ -14,32 +14,22 @@ use std::collections::hash_map::Entry; use std::ops::Deref; -use std::str::FromStr; -use itertools::Itertools; -use risingwave_common::catalog::{ - Field, Schema, TableId, DEFAULT_SCHEMA_NAME, PG_CATALOG_SCHEMA_NAME, - RW_INTERNAL_TABLE_FUNCTION_NAME, -}; +use risingwave_common::catalog::{Field, TableId, DEFAULT_SCHEMA_NAME}; use risingwave_common::error::{internal_error, ErrorCode, Result, RwError}; -use risingwave_common::types::DataType; use risingwave_sqlparser::ast::{ Expr as ParserExpr, FunctionArg, FunctionArgExpr, Ident, ObjectName, TableAlias, TableFactor, }; -use self::watermark::is_watermark_func; use super::bind_context::ColumnBinding; use super::statement::RewriteExprsRecursive; use crate::binder::Binder; -use crate::catalog::function_catalog::FunctionKind; -use crate::catalog::system_catalog::pg_catalog::{ - PG_GET_KEYWORDS_FUNC_NAME, PG_KEYWORDS_TABLE_NAME, -}; -use crate::expr::{Expr, ExprImpl, InputRef, TableFunction, TableFunctionType}; +use crate::expr::{ExprImpl, InputRef}; mod join; mod share; mod subquery; +mod table_function; mod table_or_source; mod watermark; mod window_table_function; @@ -63,7 +53,7 @@ pub enum Relation { Subquery(Box), Join(Box), WindowTableFunction(Box), - TableFunction(Box), + TableFunction(ExprImpl), Watermark(Box), Share(Box), } @@ -76,13 +66,7 @@ impl RewriteExprsRecursive for Relation { Relation::WindowTableFunction(inner) => inner.rewrite_exprs_recursive(rewriter), Relation::Watermark(inner) => inner.rewrite_exprs_recursive(rewriter), Relation::Share(inner) => inner.rewrite_exprs_recursive(rewriter), - Relation::TableFunction(inner) => { - let new_args = std::mem::take(&mut inner.args) - .into_iter() - .map(|expr| 
rewriter.rewrite_expr(expr)) - .collect(); - inner.args = new_args; - } + Relation::TableFunction(inner) => *inner = rewriter.rewrite_expr(inner.take()), _ => {} } } @@ -405,90 +389,7 @@ impl Binder { for_system_time_as_of_proctime, } => self.bind_relation_by_name(name, alias, for_system_time_as_of_proctime), TableFactor::TableFunction { name, alias, args } => { - let func_name = &name.0[0].real_value(); - if func_name.eq_ignore_ascii_case(RW_INTERNAL_TABLE_FUNCTION_NAME) { - return self.bind_internal_table(args, alias); - } - if func_name.eq_ignore_ascii_case(PG_GET_KEYWORDS_FUNC_NAME) - || name.real_value().eq_ignore_ascii_case( - format!("{}.{}", PG_CATALOG_SCHEMA_NAME, PG_GET_KEYWORDS_FUNC_NAME) - .as_str(), - ) - { - return self.bind_relation_by_name_inner( - Some(PG_CATALOG_SCHEMA_NAME), - PG_KEYWORDS_TABLE_NAME, - alias, - false, - ); - } - if let Ok(kind) = WindowTableFunctionKind::from_str(func_name) { - return Ok(Relation::WindowTableFunction(Box::new( - self.bind_window_table_function(alias, kind, args)?, - ))); - } - if is_watermark_func(func_name) { - return Ok(Relation::Watermark(Box::new( - self.bind_watermark(alias, args)?, - ))); - }; - - let args: Vec = args - .into_iter() - .map(|arg| self.bind_function_arg(arg)) - .flatten_ok() - .try_collect()?; - let tf = if let Some(func) = self - .catalog - .first_valid_schema( - &self.db_name, - &self.search_path, - &self.auth_context.user_name, - )? - .get_function_by_name_args( - func_name, - &args.iter().map(|arg| arg.return_type()).collect_vec(), - ) - && matches!(func.kind, FunctionKind::Table { .. }) - { - TableFunction::new_user_defined(func.clone(), args) - } else if let Ok(table_function_type) = TableFunctionType::from_str(func_name) { - TableFunction::new(table_function_type, args)? 
- } else { - return Err(ErrorCode::NotImplemented( - format!("unknown table function: {}", func_name), - 1191.into(), - ) - .into()); - }; - let columns = if let DataType::Struct(s) = tf.return_type() { - // If the table function returns a struct, it's fields can be accessed just - // like a table's columns. - let schema = Schema::from(&s); - schema.fields.into_iter().map(|f| (false, f)).collect_vec() - } else { - // If there is an table alias, we should use the alias as the table function's - // column name. If column aliases are also provided, they - // are handled in bind_table_to_context. - // - // Note: named return value should take precedence over table alias. - // But we don't support it yet. - // e.g., - // ``` - // > create function foo(ret out int) language sql as 'select 1'; - // > select t.ret from foo() as t; - // ``` - let col_name = if let Some(alias) = &alias { - alias.name.real_value() - } else { - tf.name() - }; - vec![(false, Field::with_name(tf.return_type(), col_name))] - }; - - self.bind_table_to_context(columns, tf.name(), alias)?; - - Ok(Relation::TableFunction(Box::new(tf))) + self.bind_table_function(name, alias, args) } TableFactor::Derived { lateral, diff --git a/src/frontend/src/binder/relation/table_function.rs b/src/frontend/src/binder/relation/table_function.rs new file mode 100644 index 0000000000000..1be11687fb1c8 --- /dev/null +++ b/src/frontend/src/binder/relation/table_function.rs @@ -0,0 +1,117 @@ +// Copyright 2023 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +use std::str::FromStr; + +use itertools::Itertools; +use risingwave_common::catalog::{ + Field, Schema, PG_CATALOG_SCHEMA_NAME, RW_INTERNAL_TABLE_FUNCTION_NAME, +}; +use risingwave_common::types::DataType; +use risingwave_sqlparser::ast::{Function, FunctionArg, ObjectName, TableAlias}; + +use super::watermark::is_watermark_func; +use super::{Binder, Relation, Result, WindowTableFunctionKind}; +use crate::binder::bind_context::Clause; +use crate::catalog::system_catalog::pg_catalog::{ + PG_GET_KEYWORDS_FUNC_NAME, PG_KEYWORDS_TABLE_NAME, +}; +use crate::expr::Expr; + +impl Binder { + /// Binds a table function AST, which is a function call in a relation position. + /// + /// Besides [`TableFunction`] expr, it can also be other things like window table functions, or + /// scalar functions. The previous clause is restored even if binding the call fails. + pub(super) fn bind_table_function( + &mut self, + name: ObjectName, + alias: Option<TableAlias>, + args: Vec<FunctionArg>, + ) -> Result<Relation> { + let func_name = &name.0[0].real_value(); + // internal/system table functions + { + if func_name.eq_ignore_ascii_case(RW_INTERNAL_TABLE_FUNCTION_NAME) { + return self.bind_internal_table(args, alias); + } + if func_name.eq_ignore_ascii_case(PG_GET_KEYWORDS_FUNC_NAME) + || name.real_value().eq_ignore_ascii_case( + format!("{}.{}", PG_CATALOG_SCHEMA_NAME, PG_GET_KEYWORDS_FUNC_NAME).as_str(), + ) + { + return self.bind_relation_by_name_inner( + Some(PG_CATALOG_SCHEMA_NAME), + PG_KEYWORDS_TABLE_NAME, + alias, + false, + ); + } + } + // window table functions (tumble/hop) + if let Ok(kind) = WindowTableFunctionKind::from_str(func_name) { + return Ok(Relation::WindowTableFunction(Box::new( + self.bind_window_table_function(alias, kind, args)?, + ))); + } + // watermark + if is_watermark_func(func_name) { + return Ok(Relation::Watermark(Box::new( + self.bind_watermark(alias, args)?, + ))); + }; + + let clause = self.context.clause.replace(Clause::From);
+ let func = self.bind_function(Function { + name, + args, + over: None, + distinct: false, + order_by: vec![], + filter: None, + within_group: None, + }); + self.context.clause = clause; // restore before `?` can return early + let func = func?; + + let columns = if let DataType::Struct(s) = func.return_type() { + // If the table function returns a struct, its fields can be accessed just + // like a table's columns. + let schema = Schema::from(&s); + schema.fields.into_iter().map(|f| (false, f)).collect_vec() + } else { + // If there is a table alias, we should use the alias as the table function's + // column name. If column aliases are also provided, they + // are handled in bind_table_to_context. + // + // Note: named return value should take precedence over table alias. + // But we don't support it yet. + // e.g., + // ``` + // > create function foo(ret out int) language sql as 'select 1'; + // > select t.ret from foo() as t; + // ``` + let col_name = if let Some(alias) = &alias { + alias.name.real_value() + } else { + func_name.clone() + }; + vec![(false, Field::with_name(func.return_type(), col_name))] + }; + + self.bind_table_to_context(columns, func_name.clone(), alias)?; + + Ok(Relation::TableFunction(func)) + } +} diff --git a/src/frontend/src/planner/relation.rs b/src/frontend/src/planner/relation.rs index 2d23e74f8b74d..ad3a6279d76ae 100644 --- a/src/frontend/src/planner/relation.rs +++ b/src/frontend/src/planner/relation.rs @@ -15,6 +15,7 @@ use std::rc::Rc; use itertools::Itertools; +use risingwave_common::catalog::{Field, Schema}; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::{DataType, Interval, ScalarImpl}; @@ -22,10 +23,10 @@ use crate::binder::{ BoundBaseTable, BoundJoin, BoundShare, BoundSource, BoundSystemTable, BoundWatermark, BoundWindowTableFunction, Relation, WindowTableFunctionKind, }; -use crate::expr::{ExprImpl, ExprType, FunctionCall, InputRef, TableFunction}; +use crate::expr::{Expr, ExprImpl, ExprType, FunctionCall,
InputRef}; use crate::optimizer::plan_node::{ LogicalHopWindow, LogicalJoin, LogicalProject, LogicalScan, LogicalShare, LogicalSource, - LogicalTableFunction, PlanRef, + LogicalTableFunction, LogicalValues, PlanRef, }; use crate::planner::Planner; @@ -42,7 +43,7 @@ impl Planner { Relation::Join(join) => self.plan_join(*join), Relation::WindowTableFunction(tf) => self.plan_window_table_function(*tf), Relation::Source(s) => self.plan_source(*s), - Relation::TableFunction(tf) => self.plan_table_function(*tf), + Relation::TableFunction(tf) => self.plan_table_function(tf), Relation::Watermark(tf) => self.plan_watermark(*tf), Relation::Share(share) => self.plan_share(*share), } @@ -115,8 +116,18 @@ impl Planner { } } - pub(super) fn plan_table_function(&mut self, table_function: TableFunction) -> Result { - Ok(LogicalTableFunction::new(table_function, self.ctx()).into()) + pub(super) fn plan_table_function(&mut self, table_function: ExprImpl) -> Result { + // TODO: maybe we can unify LogicalTableFunction with LogicalValues + match table_function { + ExprImpl::TableFunction(tf) => Ok(LogicalTableFunction::new(*tf, self.ctx()).into()), + expr => { + let schema = Schema { + // TODO: should be named + fields: vec![Field::unnamed(expr.return_type())], + }; + Ok(LogicalValues::create(vec![vec![expr]], schema, self.ctx())) + } + } } pub(super) fn plan_share(&mut self, share: BoundShare) -> Result {