Skip to content

Commit

Permalink
feat: support scalar function in FROM clause (#10317)
Browse files Browse the repository at this point in the history
  • Loading branch information
xxchan authored Jun 14, 2023
1 parent 9593d1b commit 5cf94c9
Show file tree
Hide file tree
Showing 11 changed files with 252 additions and 134 deletions.
2 changes: 1 addition & 1 deletion Makefile.toml
Original file line number Diff line number Diff line change
Expand Up @@ -910,7 +910,7 @@ if [ $# -gt 0 ]; then
ARGS=("$@")
echo "Applying clippy --fix for $@ (including dirty and staged files)"
cargo clippy ${ARGS[@]/#/--package risingwave_} ${RISINGWAVE_FEATURE_FLAGS} --fix --allow-dirty --allow-staged
cargo clippy ${ARGS[@]/#/--package risingwave_} --fix --allow-dirty --allow-staged
else
echo "Applying clippy --fix for all targets to all files (including dirty and staged files)"
echo "Tip: run $(tput setaf 4)./risedev cf {package_names}$(tput sgr0) to only check-fix those packages (e.g. frontend, meta)."
Expand Down
14 changes: 14 additions & 0 deletions e2e_test/batch/functions/func_in_from.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
query I
select abs.abs from abs(-1);
----
1

query I
select alias.alias from abs(-1) alias;
----
1

query I
select alias.col from abs(-1) alias(col);
----
1
18 changes: 18 additions & 0 deletions e2e_test/streaming/values.slt
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,21 @@ drop materialized view mv;

statement ok
drop table t;

statement ok
create materialized view mv as select * from abs(-1);

# TODO: support this
statement error not yet implemented: LogicalTableFunction::logical_rewrite_for_stream
create materialized view mv2 as select * from range(1,2);

statement ok
flush;

query IR
select * from mv;
----
1

statement ok
drop materialized view mv;
27 changes: 27 additions & 0 deletions src/frontend/planner_test/tests/testdata/input/expr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,30 @@
sql: select 1 / 0 t1;
expected_outputs:
- batch_error
# functions in FROM clause
- sql: |
select * from abs(-1);
expected_outputs:
- batch_plan
- stream_plan
- sql: |
select * from range(1,2);
expected_outputs:
- batch_plan
# TODO: support this
- stream_error
- sql: |
select * from max();
expected_outputs:
- binder_error
- name: Grafana issue-10134
sql: |
SELECT * FROM
generate_series(
array_lower(string_to_array(current_setting('search_path'),','),1),
array_upper(string_to_array(current_setting('search_path'),','),1)
) as i,
string_to_array(current_setting('search_path'),',') s
expected_outputs:
- batch_plan
- stream_error
32 changes: 32 additions & 0 deletions src/frontend/planner_test/tests/testdata/output/expr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -619,3 +619,35 @@
- name: const_eval of division by 0 error
sql: select 1 / 0 t1;
batch_error: 'Expr error: Division by zero'
- sql: |
select * from abs(-1);
batch_plan: |
BatchValues { rows: [[1:Int32]] }
stream_plan: |
StreamMaterialize { columns: [abs, _row_id(hidden)], stream_key: [_row_id], pk_columns: [_row_id], pk_conflict: "NoCheck", watermark_columns: [abs] }
└─StreamValues { rows: [[Abs(-1:Int32), 0:Int64]] }
- sql: |
select * from range(1,2);
batch_plan: |
BatchTableFunction { Range(1:Int32, 2:Int32) }
stream_error: |-
Feature is not yet implemented: LogicalTableFunction::logical_rewrite_for_stream
No tracking issue yet. Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml
- sql: |
select * from max();
binder_error: 'Invalid input syntax: aggregate functions are not allowed in FROM'
- name: Grafana issue-10134
sql: |
SELECT * FROM
generate_series(
array_lower(string_to_array(current_setting('search_path'),','),1),
array_upper(string_to_array(current_setting('search_path'),','),1)
) as i,
string_to_array(current_setting('search_path'),',') s
batch_plan: |
BatchNestedLoopJoin { type: Inner, predicate: true, output: all }
├─BatchTableFunction { GenerateSeries(1:Int32, 2:Int32) }
└─BatchValues { rows: [[ARRAY["$user", public]:List(Varchar)]] }
stream_error: |-
Feature is not yet implemented: LogicalTableFunction::logical_rewrite_for_stream
No tracking issue yet. Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml
1 change: 1 addition & 0 deletions src/frontend/src/binder/bind_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub enum Clause {
GroupBy,
Having,
Filter,
From,
}

/// A `BindContext` that is only visible if the `LATERAL` keyword
Expand Down
32 changes: 10 additions & 22 deletions src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use crate::expr::{
use crate::utils::Condition;

impl Binder {
pub(super) fn bind_function(&mut self, f: Function) -> Result<ExprImpl> {
pub(in crate::binder) fn bind_function(&mut self, f: Function) -> Result<ExprImpl> {
let function_name = match f.name.0.as_slice() {
[name] => name.real_value(),
[schema, name] => {
Expand Down Expand Up @@ -114,18 +114,10 @@ impl Binder {

// user defined function
// TODO: resolve schema name
if let Some(func) = self
.catalog
.first_valid_schema(
&self.db_name,
&self.search_path,
&self.auth_context.user_name,
)?
.get_function_by_name_args(
&function_name,
&inputs.iter().map(|arg| arg.return_type()).collect_vec(),
)
{
if let Some(func) = self.first_valid_schema()?.get_function_by_name_args(
&function_name,
&inputs.iter().map(|arg| arg.return_type()).collect_vec(),
) {
use crate::catalog::function_catalog::FunctionKind::*;
match &func.kind {
Scalar { .. } => return Ok(UserDefinedFunction::new(func.clone(), inputs).into()),
Expand Down Expand Up @@ -676,12 +668,7 @@ impl Binder {
}))),
("current_schema", guard_by_len(0, raw(|binder, _inputs| {
return Ok(binder
.catalog
.first_valid_schema(
&binder.db_name,
&binder.search_path,
&binder.auth_context.user_name,
)
.first_valid_schema()
.map(|schema| ExprImpl::literal_varchar(schema.name()))
.unwrap_or_else(|_| ExprImpl::literal_null(DataType::Varchar)));
}))),
Expand Down Expand Up @@ -909,7 +896,8 @@ impl Binder {
| Clause::Values
| Clause::GroupBy
| Clause::Having
| Clause::Filter => {
| Clause::Filter
| Clause::From => {
return Err(ErrorCode::InvalidInputSyntax(format!(
"window functions are not allowed in {}",
clause
Expand Down Expand Up @@ -950,7 +938,7 @@ impl Binder {
fn ensure_aggregate_allowed(&self) -> Result<()> {
if let Some(clause) = self.context.clause {
match clause {
Clause::Where | Clause::Values => {
Clause::Where | Clause::Values | Clause::From => {
return Err(ErrorCode::InvalidInputSyntax(format!(
"aggregate functions are not allowed in {}",
clause
Expand All @@ -973,7 +961,7 @@ impl Binder {
))
.into());
}
Clause::GroupBy | Clause::Having | Clause::Filter => {}
Clause::GroupBy | Clause::Having | Clause::Filter | Clause::From => {}
}
}
Ok(())
Expand Down
11 changes: 10 additions & 1 deletion src/frontend/src/binder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ pub use update::BoundUpdate;
pub use values::BoundValues;

use crate::catalog::catalog_service::CatalogReadGuard;
use crate::catalog::{TableId, ViewId};
use crate::catalog::schema_catalog::SchemaCatalog;
use crate::catalog::{CatalogResult, TableId, ViewId};
use crate::session::{AuthContext, SessionImpl};

pub type ShareId = usize;
Expand Down Expand Up @@ -350,6 +351,14 @@ impl Binder {
self.next_share_id += 1;
id
}

fn first_valid_schema(&self) -> CatalogResult<&SchemaCatalog> {
self.catalog.first_valid_schema(
&self.db_name,
&self.search_path,
&self.auth_context.user_name,
)
}
}

#[cfg(test)]
Expand Down
111 changes: 6 additions & 105 deletions src/frontend/src/binder/relation/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,22 @@

use std::collections::hash_map::Entry;
use std::ops::Deref;
use std::str::FromStr;

use itertools::Itertools;
use risingwave_common::catalog::{
Field, Schema, TableId, DEFAULT_SCHEMA_NAME, PG_CATALOG_SCHEMA_NAME,
RW_INTERNAL_TABLE_FUNCTION_NAME,
};
use risingwave_common::catalog::{Field, TableId, DEFAULT_SCHEMA_NAME};
use risingwave_common::error::{internal_error, ErrorCode, Result, RwError};
use risingwave_common::types::DataType;
use risingwave_sqlparser::ast::{
Expr as ParserExpr, FunctionArg, FunctionArgExpr, Ident, ObjectName, TableAlias, TableFactor,
};

use self::watermark::is_watermark_func;
use super::bind_context::ColumnBinding;
use super::statement::RewriteExprsRecursive;
use crate::binder::Binder;
use crate::catalog::function_catalog::FunctionKind;
use crate::catalog::system_catalog::pg_catalog::{
PG_GET_KEYWORDS_FUNC_NAME, PG_KEYWORDS_TABLE_NAME,
};
use crate::expr::{Expr, ExprImpl, InputRef, TableFunction, TableFunctionType};
use crate::expr::{ExprImpl, InputRef};

mod join;
mod share;
mod subquery;
mod table_function;
mod table_or_source;
mod watermark;
mod window_table_function;
Expand All @@ -63,7 +53,7 @@ pub enum Relation {
Subquery(Box<BoundSubquery>),
Join(Box<BoundJoin>),
WindowTableFunction(Box<BoundWindowTableFunction>),
TableFunction(Box<TableFunction>),
TableFunction(ExprImpl),
Watermark(Box<BoundWatermark>),
Share(Box<BoundShare>),
}
Expand All @@ -76,13 +66,7 @@ impl RewriteExprsRecursive for Relation {
Relation::WindowTableFunction(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::Watermark(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::Share(inner) => inner.rewrite_exprs_recursive(rewriter),
Relation::TableFunction(inner) => {
let new_args = std::mem::take(&mut inner.args)
.into_iter()
.map(|expr| rewriter.rewrite_expr(expr))
.collect();
inner.args = new_args;
}
Relation::TableFunction(inner) => *inner = rewriter.rewrite_expr(inner.take()),
_ => {}
}
}
Expand Down Expand Up @@ -405,90 +389,7 @@ impl Binder {
for_system_time_as_of_proctime,
} => self.bind_relation_by_name(name, alias, for_system_time_as_of_proctime),
TableFactor::TableFunction { name, alias, args } => {
let func_name = &name.0[0].real_value();
if func_name.eq_ignore_ascii_case(RW_INTERNAL_TABLE_FUNCTION_NAME) {
return self.bind_internal_table(args, alias);
}
if func_name.eq_ignore_ascii_case(PG_GET_KEYWORDS_FUNC_NAME)
|| name.real_value().eq_ignore_ascii_case(
format!("{}.{}", PG_CATALOG_SCHEMA_NAME, PG_GET_KEYWORDS_FUNC_NAME)
.as_str(),
)
{
return self.bind_relation_by_name_inner(
Some(PG_CATALOG_SCHEMA_NAME),
PG_KEYWORDS_TABLE_NAME,
alias,
false,
);
}
if let Ok(kind) = WindowTableFunctionKind::from_str(func_name) {
return Ok(Relation::WindowTableFunction(Box::new(
self.bind_window_table_function(alias, kind, args)?,
)));
}
if is_watermark_func(func_name) {
return Ok(Relation::Watermark(Box::new(
self.bind_watermark(alias, args)?,
)));
};

let args: Vec<ExprImpl> = args
.into_iter()
.map(|arg| self.bind_function_arg(arg))
.flatten_ok()
.try_collect()?;
let tf = if let Some(func) = self
.catalog
.first_valid_schema(
&self.db_name,
&self.search_path,
&self.auth_context.user_name,
)?
.get_function_by_name_args(
func_name,
&args.iter().map(|arg| arg.return_type()).collect_vec(),
)
&& matches!(func.kind, FunctionKind::Table { .. })
{
TableFunction::new_user_defined(func.clone(), args)
} else if let Ok(table_function_type) = TableFunctionType::from_str(func_name) {
TableFunction::new(table_function_type, args)?
} else {
return Err(ErrorCode::NotImplemented(
format!("unknown table function: {}", func_name),
1191.into(),
)
.into());
};
let columns = if let DataType::Struct(s) = tf.return_type() {
// If the table function returns a struct, it's fields can be accessed just
// like a table's columns.
let schema = Schema::from(&s);
schema.fields.into_iter().map(|f| (false, f)).collect_vec()
} else {
// If there is an table alias, we should use the alias as the table function's
// column name. If column aliases are also provided, they
// are handled in bind_table_to_context.
//
// Note: named return value should take precedence over table alias.
// But we don't support it yet.
// e.g.,
// ```
// > create function foo(ret out int) language sql as 'select 1';
// > select t.ret from foo() as t;
// ```
let col_name = if let Some(alias) = &alias {
alias.name.real_value()
} else {
tf.name()
};
vec![(false, Field::with_name(tf.return_type(), col_name))]
};

self.bind_table_to_context(columns, tf.name(), alias)?;

Ok(Relation::TableFunction(Box::new(tf)))
self.bind_table_function(name, alias, args)
}
TableFactor::Derived {
lateral,
Expand Down
Loading

0 comments on commit 5cf94c9

Please sign in to comment.