From ea9367f371884d8650bca0890dcb87b638b23ae8 Mon Sep 17 00:00:00 2001 From: discord9 <55937128+discord9@users.noreply.github.com> Date: Tue, 9 Apr 2024 10:32:02 +0800 Subject: [PATCH] refactor(flow): func spec api&use Error not EvalError in mfp (#3657) * refactor: func's specialization& use Error not EvalError * docs: some pub item * chore: typo * docs: add comments for every pub item * chore: per review * chore: per reveiw&derive Copy * chore: per review&test for binary fn spec * docs: comment explain how binary func spec works * chore: minor style change * fix: Error not EvalError --- Cargo.lock | 1 + src/flow/Cargo.toml | 1 + src/flow/clippy.toml | 3 + src/flow/src/adapter/error.rs | 27 +- src/flow/src/compute/render.rs | 2 +- src/flow/src/expr.rs | 1 + src/flow/src/expr/error.rs | 5 +- src/flow/src/expr/func.rs | 436 +++++++++++++++++++++++++++- src/flow/src/expr/id.rs | 2 + src/flow/src/expr/linear.rs | 44 +-- src/flow/src/expr/relation.rs | 5 +- src/flow/src/expr/relation/accum.rs | 20 +- src/flow/src/expr/relation/func.rs | 350 ++++++++-------------- src/flow/src/expr/scalar.rs | 55 +++- src/flow/src/expr/signature.rs | 67 +++++ src/flow/src/lib.rs | 6 + src/flow/src/plan.rs | 3 + src/flow/src/plan/reduce.rs | 3 + src/flow/src/repr.rs | 20 ++ src/flow/src/repr/relation.rs | 2 + src/flow/src/utils.rs | 13 + 21 files changed, 785 insertions(+), 281 deletions(-) create mode 100644 src/flow/clippy.toml create mode 100644 src/flow/src/expr/signature.rs diff --git a/Cargo.lock b/Cargo.lock index 95b98d170add..a33fd9623359 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3421,6 +3421,7 @@ dependencies = [ "servers", "smallvec", "snafu", + "strum 0.25.0", "tokio", "tonic 0.10.2", ] diff --git a/src/flow/Cargo.toml b/src/flow/Cargo.toml index d0eed66643bc..611d06b934ba 100644 --- a/src/flow/Cargo.toml +++ b/src/flow/Cargo.toml @@ -25,6 +25,7 @@ serde.workspace = true servers.workspace = true smallvec.workspace = true snafu.workspace = true +strum.workspace = true 
tokio.workspace = true tonic.workspace = true diff --git a/src/flow/clippy.toml b/src/flow/clippy.toml new file mode 100644 index 000000000000..5a9ebd2a5bc2 --- /dev/null +++ b/src/flow/clippy.toml @@ -0,0 +1,3 @@ +# Whether to only check for missing documentation in items visible within the current crate. For example, pub(crate) items. (default: false) +# This is a config for clippy::missing_docs_in_private_items +missing-docs-in-crate-items = true diff --git a/src/flow/src/adapter/error.rs b/src/flow/src/adapter/error.rs index b99c8a7007d2..ea5ea39f1356 100644 --- a/src/flow/src/adapter/error.rs +++ b/src/flow/src/adapter/error.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Error definition for flow module + use std::any::Any; use common_macro::stack_trace_debug; @@ -25,6 +27,7 @@ use snafu::{Location, Snafu}; use crate::expr::EvalError; +/// This error is used to represent all possible errors that can occur in the flow module. #[derive(Snafu)] #[snafu(visibility(pub))] #[stack_trace_debug] @@ -54,8 +57,25 @@ pub enum Error { #[snafu(display("No protobuf type for value: {value}"))] NoProtoType { value: Value, location: Location }, + + #[snafu(display("Not implement in flow: {reason}"))] + NotImplemented { reason: String, location: Location }, + + #[snafu(display("Flow plan error: {reason}"))] + Plan { reason: String, location: Location }, + + #[snafu(display("Unsupported temporal filter: {reason}"))] + UnsupportedTemporalFilter { reason: String, location: Location }, + + #[snafu(display("Datatypes error: {source} with extra message: {extra}"))] + Datatypes { + source: datatypes::Error, + extra: String, + location: Location, + }, } +/// Result type for flow module pub type Result = std::result::Result; impl ErrorExt for Error { @@ -64,8 +84,13 @@ impl ErrorExt for Error { Self::Eval { .. } | &Self::JoinTask { .. } => StatusCode::Internal, &Self::TableAlreadyExist { .. 
} => StatusCode::TableAlreadyExists, Self::TableNotFound { .. } => StatusCode::TableNotFound, - &Self::InvalidQuery { .. } => StatusCode::PlanQuery, + &Self::InvalidQuery { .. } | &Self::Plan { .. } | &Self::Datatypes { .. } => { + StatusCode::PlanQuery + } Self::NoProtoType { .. } => StatusCode::Unexpected, + &Self::NotImplemented { .. } | Self::UnsupportedTemporalFilter { .. } => { + StatusCode::Unsupported + } } } diff --git a/src/flow/src/compute/render.rs b/src/flow/src/compute/render.rs index 2effabad5cf7..708297d56f54 100644 --- a/src/flow/src/compute/render.rs +++ b/src/flow/src/compute/render.rs @@ -185,7 +185,7 @@ impl<'referred, 'df> Context<'referred, 'df> { let arrange_handler_inner = ArrangeHandler::from(arrange); // This closure capture following variables: - let mfp_plan = MfpPlan::create_from(mfp).context(EvalSnafu)?; + let mfp_plan = MfpPlan::create_from(mfp)?; let now = self.compute_state.current_time_ref(); let err_collector = self.err_collector.clone(); diff --git a/src/flow/src/expr.rs b/src/flow/src/expr.rs index d54dfa4b9f69..4550234b4e2e 100644 --- a/src/flow/src/expr.rs +++ b/src/flow/src/expr.rs @@ -20,6 +20,7 @@ mod id; mod linear; mod relation; mod scalar; +mod signature; pub(crate) use error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu}; pub(crate) use func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}; diff --git a/src/flow/src/expr/error.rs b/src/flow/src/expr/error.rs index 9de189231670..5c4480e749c5 100644 --- a/src/flow/src/expr/error.rs +++ b/src/flow/src/expr/error.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Error handling for expression evaluation. 
+ use std::any::Any; use common_macro::stack_trace_debug; @@ -59,9 +61,6 @@ pub enum EvalError { #[snafu(display("Optimize error: {reason}"))] Optimize { reason: String, location: Location }, - #[snafu(display("Unsupported temporal filter: {reason}"))] - UnsupportedTemporalFilter { reason: String, location: Location }, - #[snafu(display("Overflowed during evaluation"))] Overflow { location: Location }, } diff --git a/src/flow/src/expr/func.rs b/src/flow/src/expr/func.rs index b7f0ba698ae2..92fc20bc9f88 100644 --- a/src/flow/src/expr/func.rs +++ b/src/flow/src/expr/func.rs @@ -12,19 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! This module contains the definition of functions that can be used in expressions. + +use std::collections::HashMap; +use std::sync::OnceLock; + use common_time::DateTime; use datatypes::data_type::ConcreteDataType; use datatypes::types::cast; use datatypes::types::cast::CastOption; use datatypes::value::Value; -use hydroflow::bincode::Error; use serde::{Deserialize, Serialize}; -use snafu::ResultExt; +use smallvec::smallvec; +use snafu::{ensure, OptionExt, ResultExt}; +use strum::{EnumIter, IntoEnumIterator}; +use crate::adapter::error::{Error, InvalidQuerySnafu, PlanSnafu}; use crate::expr::error::{ CastValueSnafu, DivisionByZeroSnafu, EvalError, InternalSnafu, TryFromValueSnafu, TypeMismatchSnafu, }; +use crate::expr::signature::{GenericFn, Signature}; use crate::expr::{InvalidArgumentSnafu, ScalarExpr}; use crate::repr::{value_to_internal_ts, Row}; @@ -36,6 +44,38 @@ pub enum UnmaterializableFunc { CurrentSchema, } +impl UnmaterializableFunc { + /// Return the signature of the function + pub fn signature(&self) -> Signature { + match self { + Self::Now => Signature { + input: smallvec![], + output: ConcreteDataType::datetime_datatype(), + generic_fn: GenericFn::Now, + }, + Self::CurrentSchema => Signature { + input: smallvec![], + output: 
ConcreteDataType::string_datatype(), + generic_fn: GenericFn::CurrentSchema, + }, + } + } + + /// Create a UnmaterializableFunc from a string of the function name + pub fn from_str(name: &str) -> Result { + match name { + "now" => Ok(Self::Now), + "current_schema" => Ok(Self::CurrentSchema), + _ => InvalidQuerySnafu { + reason: format!("Unknown unmaterializable function: {}", name), + } + .fail(), + } + } +} + +/// UnaryFunc is a function that takes one argument. Also notice this enum doesn't contain function arguments, +/// because the arguments are stored in the expression. (except `cast` function, which requires a type argument) #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)] pub enum UnaryFunc { Not, @@ -47,6 +87,68 @@ pub enum UnaryFunc { } impl UnaryFunc { + /// Return the signature of the function + pub fn signature(&self) -> Signature { + match self { + Self::IsNull => Signature { + input: smallvec![ConcreteDataType::null_datatype()], + output: ConcreteDataType::boolean_datatype(), + generic_fn: GenericFn::IsNull, + }, + Self::Not | Self::IsTrue | Self::IsFalse => Signature { + input: smallvec![ConcreteDataType::boolean_datatype()], + output: ConcreteDataType::boolean_datatype(), + generic_fn: match self { + Self::Not => GenericFn::Not, + Self::IsTrue => GenericFn::IsTrue, + Self::IsFalse => GenericFn::IsFalse, + _ => unreachable!(), + }, + }, + Self::StepTimestamp => Signature { + input: smallvec![ConcreteDataType::datetime_datatype()], + output: ConcreteDataType::datetime_datatype(), + generic_fn: GenericFn::StepTimestamp, + }, + Self::Cast(to) => Signature { + input: smallvec![ConcreteDataType::null_datatype()], + output: to.clone(), + generic_fn: GenericFn::Cast, + }, + } + } + + /// Create a UnaryFunc from a string of the function name and given argument type(optional) + pub fn from_str_and_type( + name: &str, + arg_type: Option, + ) -> Result { + match name { + "not" => Ok(Self::Not), + "is_null" => 
Ok(Self::IsNull), + "is_true" => Ok(Self::IsTrue), + "is_false" => Ok(Self::IsFalse), + "step_timestamp" => Ok(Self::StepTimestamp), + "cast" => { + let arg_type = arg_type.with_context(|| InvalidQuerySnafu { + reason: "cast function requires a type argument".to_string(), + })?; + Ok(UnaryFunc::Cast(arg_type)) + } + _ => InvalidQuerySnafu { + reason: format!("Unknown unary function: {}", name), + } + .fail(), + } + } + + /// Evaluate the function with given values and expression + /// + /// # Arguments + /// + /// - `values`: The values to be used in the evaluation + /// + /// - `expr`: The expression to be evaluated and use as argument, will extract the value from the `values` and evaluate the expression pub fn eval(&self, values: &[Value], expr: &ScalarExpr) -> Result { let arg = expr.eval(values)?; match self { @@ -109,8 +211,13 @@ impl UnaryFunc { } } +/// BinaryFunc is a function that takes two arguments. +/// Also notice this enum doesn't contain function arguments, since the arguments are stored in the expression. 
+/// /// TODO(discord9): support more binary functions for more types -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)] +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash, EnumIter, +)] pub enum BinaryFunc { Eq, NotEq, @@ -158,7 +265,223 @@ pub enum BinaryFunc { ModUInt64, } +/// Generate binary function signature based on the function and the input types +/// The user can provide custom signature for some functions in the form of a regular match arm, +/// and the rest will be generated according to the provided list of functions like this: +/// ```ignore +/// AddInt16=>(int16_datatype,Add), +/// ``` +/// which expand to: +/// ```ignore, rust +/// Self::AddInt16 => Signature { +/// input: smallvec![ +/// ConcreteDataType::int16_datatype(), +/// ConcreteDataType::int16_datatype(), +/// ], +/// output: ConcreteDataType::int16_datatype(), +/// generic_fn: GenericFn::Add, +/// }, +/// ```` +macro_rules! generate_binary_signature { + ($value:ident, { $($user_arm:tt)* }, + [ $( + $auto_arm:ident=>($con_type:ident,$generic:ident) + ),* + ]) => { + match $value { + $($user_arm)*, + $( + Self::$auto_arm => Signature { + input: smallvec![ + ConcreteDataType::$con_type(), + ConcreteDataType::$con_type(), + ], + output: ConcreteDataType::$con_type(), + generic_fn: GenericFn::$generic, + }, + )* + } + }; +} + +static SPECIALIZATION: OnceLock> = + OnceLock::new(); + impl BinaryFunc { + /// Use null type to ref to any type + pub fn signature(&self) -> Signature { + generate_binary_signature!(self, { + Self::Eq | Self::NotEq | Self::Lt | Self::Lte | Self::Gt | Self::Gte => Signature { + input: smallvec![ + ConcreteDataType::null_datatype(), + ConcreteDataType::null_datatype() + ], + output: ConcreteDataType::null_datatype(), + generic_fn: match self { + Self::Eq => GenericFn::Eq, + Self::NotEq => GenericFn::NotEq, + Self::Lt => GenericFn::Lt, + Self::Lte => GenericFn::Lte, + Self::Gt => GenericFn::Gt, 
+ Self::Gte => GenericFn::Gte, + _ => unreachable!(), + }, + } + }, + [ + AddInt16=>(int16_datatype,Add), + AddInt32=>(int32_datatype,Add), + AddInt64=>(int64_datatype,Add), + AddUInt16=>(uint16_datatype,Add), + AddUInt32=>(uint32_datatype,Add), + AddUInt64=>(uint64_datatype,Add), + AddFloat32=>(float32_datatype,Add), + AddFloat64=>(float64_datatype,Add), + SubInt16=>(int16_datatype,Sub), + SubInt32=>(int32_datatype,Sub), + SubInt64=>(int64_datatype,Sub), + SubUInt16=>(uint16_datatype,Sub), + SubUInt32=>(uint32_datatype,Sub), + SubUInt64=>(uint64_datatype,Sub), + SubFloat32=>(float32_datatype,Sub), + SubFloat64=>(float64_datatype,Sub), + MulInt16=>(int16_datatype,Mul), + MulInt32=>(int32_datatype,Mul), + MulInt64=>(int64_datatype,Mul), + MulUInt16=>(uint16_datatype,Mul), + MulUInt32=>(uint32_datatype,Mul), + MulUInt64=>(uint64_datatype,Mul), + MulFloat32=>(float32_datatype,Mul), + MulFloat64=>(float64_datatype,Mul), + DivInt16=>(int16_datatype,Div), + DivInt32=>(int32_datatype,Div), + DivInt64=>(int64_datatype,Div), + DivUInt16=>(uint16_datatype,Div), + DivUInt32=>(uint32_datatype,Div), + DivUInt64=>(uint64_datatype,Div), + DivFloat32=>(float32_datatype,Div), + DivFloat64=>(float64_datatype,Div), + ModInt16=>(int16_datatype,Mod), + ModInt32=>(int32_datatype,Mod), + ModInt64=>(int64_datatype,Mod), + ModUInt16=>(uint16_datatype,Mod), + ModUInt32=>(uint32_datatype,Mod), + ModUInt64=>(uint64_datatype,Mod) + ] + ) + } + + /// Get the specialization of the binary function based on the generic function and the input type + pub fn specialization(generic: GenericFn, input_type: ConcreteDataType) -> Result { + let rule = SPECIALIZATION.get_or_init(|| { + let mut spec = HashMap::new(); + for func in BinaryFunc::iter() { + let sig = func.signature(); + spec.insert((sig.generic_fn, sig.input[0].clone()), func); + } + spec + }); + rule.get(&(generic, input_type.clone())) + .cloned() + .with_context(|| InvalidQuerySnafu { + reason: format!( + "No specialization found for binary 
function {:?} with input type {:?}", + generic, input_type + ), + }) + } + + /// choose the appropriate specialization based on the input types + /// + /// will try it best to extract from `arg_types` and `arg_exprs` to get the input types + /// if `arg_types` is not enough, it will try to extract from `arg_exprs` if `arg_exprs` is literal with known type + pub fn from_str_expr_and_type( + name: &str, + arg_exprs: &[ScalarExpr], + arg_types: &[Option], + ) -> Result { + // get first arg type and make sure if both is some, they are the same + let generic_fn = { + match name { + "eq" => GenericFn::Eq, + "not_eq" => GenericFn::NotEq, + "lt" => GenericFn::Lt, + "lte" => GenericFn::Lte, + "gt" => GenericFn::Gt, + "gte" => GenericFn::Gte, + "add" => GenericFn::Add, + "sub" => GenericFn::Sub, + "mul" => GenericFn::Mul, + "div" => GenericFn::Div, + "mod" => GenericFn::Mod, + _ => { + return InvalidQuerySnafu { + reason: format!("Unknown binary function: {}", name), + } + .fail(); + } + } + }; + let need_type = matches!( + generic_fn, + GenericFn::Add | GenericFn::Sub | GenericFn::Mul | GenericFn::Div | GenericFn::Mod + ); + + ensure!( + arg_exprs.len() == 2 && arg_types.len() == 2, + PlanSnafu { + reason: "Binary function requires exactly 2 arguments".to_string() + } + ); + + let arg_type = match (arg_types[0].as_ref(), arg_types[1].as_ref()) { + (Some(t1), Some(t2)) => { + ensure!( + t1 == t2, + InvalidQuerySnafu { + reason: format!( + "Binary function {} requires both arguments to have the same type", + name + ), + } + ); + Some(t1.clone()) + } + (Some(t), None) | (None, Some(t)) => Some(t.clone()), + _ => arg_exprs[0] + .as_literal() + .map(|lit| lit.data_type()) + .or_else(|| arg_exprs[1].as_literal().map(|lit| lit.data_type())), + }; + + ensure!( + !need_type || arg_type.is_some(), + InvalidQuerySnafu { + reason: format!( + "Binary function {} requires at least one argument with known type", + name + ), + } + ); + + let spec_fn = Self::specialization( + generic_fn, + 
arg_type + .clone() + .unwrap_or(ConcreteDataType::null_datatype()), + )?; + Ok(spec_fn) + } + + /// Evaluate the function with given values and expression + /// + /// # Arguments + /// + /// - `values`: The values to be used in the evaluation + /// + /// - `expr1`: The first arg to be evaluated, will extract the value from the `values` and evaluate the expression + /// + /// - `expr2`: The second arg to be evaluated pub fn eval( &self, values: &[Value], @@ -222,7 +545,7 @@ impl BinaryFunc { /// Reverse the comparison operator, i.e. `a < b` becomes `b > a`, /// equal and not equal are unchanged. - pub fn reverse_compare(&self) -> Result { + pub fn reverse_compare(&self) -> Result { let ret = match &self { BinaryFunc::Eq => BinaryFunc::Eq, BinaryFunc::NotEq => BinaryFunc::NotEq, @@ -231,7 +554,7 @@ impl BinaryFunc { BinaryFunc::Gt => BinaryFunc::Lt, BinaryFunc::Gte => BinaryFunc::Lte, _ => { - return InternalSnafu { + return InvalidQuerySnafu { reason: format!("Expect a comparison operator, found {:?}", self), } .fail(); @@ -241,13 +564,44 @@ impl BinaryFunc { } } -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)] +/// VariadicFunc is a function that takes a variable number of arguments. 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)] pub enum VariadicFunc { And, Or, } impl VariadicFunc { + /// Return the signature of the function + pub fn signature(&self) -> Signature { + Signature { + input: smallvec![ConcreteDataType::boolean_datatype()], + output: ConcreteDataType::boolean_datatype(), + generic_fn: match self { + Self::And => GenericFn::And, + Self::Or => GenericFn::Or, + }, + } + } + + /// Create a VariadicFunc from a string of the function name and given argument types(optional) + pub fn from_str_and_types( + name: &str, + arg_types: &[Option], + ) -> Result { + // TODO: future variadic funcs to be added might need to check arg_types + let _ = arg_types; + match name { + "and" => Ok(Self::And), + "or" => Ok(Self::Or), + _ => InvalidQuerySnafu { + reason: format!("Unknown variadic function: {}", name), + } + .fail(), + } + } + + /// Evaluate the function with given values and expressions pub fn eval(&self, values: &[Value], exprs: &[ScalarExpr]) -> Result { match self { VariadicFunc::And => and(values, exprs), @@ -387,3 +741,73 @@ fn test_num_ops() { let res = or(&values, &exprs).unwrap(); assert_eq!(res, Value::from(true)); } + +/// test if the binary function specialization works +/// whether from direct type or from the expression that is literal +#[test] +fn test_binary_func_spec() { + assert_eq!( + BinaryFunc::from_str_expr_and_type( + "add", + &[ScalarExpr::Column(0), ScalarExpr::Column(0)], + &[ + Some(ConcreteDataType::int32_datatype()), + Some(ConcreteDataType::int32_datatype()) + ] + ) + .unwrap(), + BinaryFunc::AddInt32 + ); + + assert_eq!( + BinaryFunc::from_str_expr_and_type( + "add", + &[ScalarExpr::Column(0), ScalarExpr::Column(0)], + &[Some(ConcreteDataType::int32_datatype()), None] + ) + .unwrap(), + BinaryFunc::AddInt32 + ); + + assert_eq!( + BinaryFunc::from_str_expr_and_type( + "add", + &[ScalarExpr::Column(0), ScalarExpr::Column(0)], + &[Some(ConcreteDataType::int32_datatype()), 
None] + ) + .unwrap(), + BinaryFunc::AddInt32 + ); + + assert_eq!( + BinaryFunc::from_str_expr_and_type( + "add", + &[ScalarExpr::Column(0), ScalarExpr::Column(0)], + &[Some(ConcreteDataType::int32_datatype()), None] + ) + .unwrap(), + BinaryFunc::AddInt32 + ); + + assert_eq!( + BinaryFunc::from_str_expr_and_type( + "add", + &[ + ScalarExpr::Literal(Value::from(1i32), ConcreteDataType::int32_datatype()), + ScalarExpr::Column(0) + ], + &[None, None] + ) + .unwrap(), + BinaryFunc::AddInt32 + ); + + matches!( + BinaryFunc::from_str_expr_and_type( + "add", + &[ScalarExpr::Column(0), ScalarExpr::Column(0)], + &[None, None] + ), + Err(Error::InvalidQuery { .. }) + ); +} diff --git a/src/flow/src/expr/id.rs b/src/flow/src/expr/id.rs index 6a3c50e9b7c6..9b098f05b333 100644 --- a/src/flow/src/expr/id.rs +++ b/src/flow/src/expr/id.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! `Id` is used to identify a dataflow component in plan like `Plan::Get{id: Id}`, this could be a source of data for an arrangement. + use serde::{Deserialize, Serialize}; /// Global id's scope is in Current Worker, and is cross-dataflow diff --git a/src/flow/src/expr/linear.rs b/src/flow/src/expr/linear.rs index d4a0ef5eda89..37e6d1df8720 100644 --- a/src/flow/src/expr/linear.rs +++ b/src/flow/src/expr/linear.rs @@ -12,12 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{BTreeMap, BTreeSet}; +//! define MapFilterProject which is a compound operator that can be applied row-by-row. 
+ +use std::collections::{BTreeMap, BTreeSet, VecDeque}; use datatypes::value::Value; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use snafu::{ensure, OptionExt}; +use crate::adapter::error::{Error, InvalidQuerySnafu}; use crate::expr::error::EvalError; use crate::expr::{Id, InvalidArgumentSnafu, LocalId, ScalarExpr}; use crate::repr::{self, value_to_internal_ts, Diff, Row}; @@ -89,7 +93,7 @@ impl MapFilterProject { /// followed by the other. /// Note that the arguments are in the opposite order /// from how function composition is usually written in mathematics. - pub fn compose(before: Self, after: Self) -> Result { + pub fn compose(before: Self, after: Self) -> Result { let (m, f, p) = after.into_map_filter_project(); before.map(m)?.filter(f)?.project(p) } @@ -131,7 +135,7 @@ impl MapFilterProject { /// new_project -->|0| col-2 /// new_project -->|1| col-1 /// ``` - pub fn project(mut self, columns: I) -> Result + pub fn project(mut self, columns: I) -> Result where I: IntoIterator + std::fmt::Debug, { @@ -140,7 +144,7 @@ impl MapFilterProject { .map(|c| self.projection.get(c).cloned().ok_or(c)) .collect::, _>>() .map_err(|c| { - InvalidArgumentSnafu { + InvalidQuerySnafu { reason: format!( "column index {} out of range, expected at most {} columns", c, @@ -178,7 +182,7 @@ impl MapFilterProject { /// filter -->|0| col-1 /// filter --> |1| col-2 /// ``` - pub fn filter(mut self, predicates: I) -> Result + pub fn filter(mut self, predicates: I) -> Result where I: IntoIterator, { @@ -193,7 +197,7 @@ impl MapFilterProject { let cur_row_len = self.input_arity + self.expressions.len(); ensure!( *c < cur_row_len, - InvalidArgumentSnafu { + InvalidQuerySnafu { reason: format!( "column index {} out of range, expected at most {} columns", c, cur_row_len @@ -250,7 +254,7 @@ impl MapFilterProject { /// map -->|1|col-2 /// map -->|2|col-0 /// ``` - pub fn map(mut self, expressions: I) -> Result + pub fn map(mut self, expressions: I) -> Result where I: 
IntoIterator, { @@ -264,7 +268,7 @@ impl MapFilterProject { let current_row_len = self.input_arity + self.expressions.len(); ensure!( c < current_row_len, - InvalidArgumentSnafu { + InvalidQuerySnafu { reason: format!( "column index {} out of range, expected at most {} columns", c, current_row_len @@ -303,6 +307,12 @@ impl MapFilterProject { } impl MapFilterProject { + /// Convert the `MapFilterProject` into a safe evaluation plan. Marking it safe to evaluate. + pub fn into_safe(self) -> SafeMfpPlan { + SafeMfpPlan { mfp: self } + } + + /// Optimize the `MapFilterProject` in place. pub fn optimize(&mut self) { // TODO(discord9): optimize } @@ -311,7 +321,7 @@ impl MapFilterProject { /// /// The main behavior is extract temporal predicates, which cannot be evaluated /// using the standard machinery. - pub fn into_plan(self) -> Result { + pub fn into_plan(self) -> Result { MfpPlan::create_from(self) } @@ -354,13 +364,13 @@ impl MapFilterProject { &mut self, mut shuffle: BTreeMap, new_input_arity: usize, - ) -> Result<(), EvalError> { + ) -> Result<(), Error> { // check shuffle is valid let demand = self.demand(); for d in demand { ensure!( shuffle.contains_key(&d), - InvalidArgumentSnafu { + InvalidQuerySnafu { reason: format!( "Demanded column {} is not in shuffle's keys: {:?}", d, @@ -371,7 +381,7 @@ impl MapFilterProject { } ensure!( shuffle.len() <= new_input_arity, - InvalidArgumentSnafu { + InvalidQuerySnafu { reason: format!( "shuffle's length {} is greater than new_input_arity {}", shuffle.len(), @@ -397,7 +407,7 @@ impl MapFilterProject { for proj in project.iter_mut() { ensure!( shuffle[proj] < new_row_len, - InvalidArgumentSnafu { + InvalidQuerySnafu { reason: format!( "shuffled column index {} out of range, expected at most {} columns", shuffle[proj], new_row_len @@ -422,11 +432,7 @@ pub struct SafeMfpPlan { impl SafeMfpPlan { /// See [`MapFilterProject::permute`]. 
- pub fn permute( - &mut self, - map: BTreeMap, - new_arity: usize, - ) -> Result<(), EvalError> { + pub fn permute(&mut self, map: BTreeMap, new_arity: usize) -> Result<(), Error> { self.mfp.permute(map, new_arity) } @@ -544,7 +550,7 @@ pub struct MfpPlan { impl MfpPlan { /// find `now` in `predicates` and put them into lower/upper temporal bounds for temporal filter to use - pub fn create_from(mut mfp: MapFilterProject) -> Result { + pub fn create_from(mut mfp: MapFilterProject) -> Result { let mut lower_bounds = Vec::new(); let mut upper_bounds = Vec::new(); diff --git a/src/flow/src/expr/relation.rs b/src/flow/src/expr/relation.rs index db82c75425f4..a873c267b1a5 100644 --- a/src/flow/src/expr/relation.rs +++ b/src/flow/src/expr/relation.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Describes an aggregation function and it's input expression. + pub(crate) use func::AggregateFunc; use serde::{Deserialize, Serialize}; @@ -26,7 +28,8 @@ pub struct AggregateExpr { /// Names the aggregation function. pub func: AggregateFunc, /// An expression which extracts from each row the input to `func`. - /// TODO(discord9): currently unused, it only used in generate KeyValPlan from AggregateExpr + /// TODO(discord9): currently unused in render phase(because AccumulablePlan remember each Aggr Expr's input/output column), + /// so it only used in generate KeyValPlan from AggregateExpr pub expr: ScalarExpr, /// Should the aggregation be applied only to distinct results in each group. 
#[serde(default)] diff --git a/src/flow/src/expr/relation/accum.rs b/src/flow/src/expr/relation/accum.rs index 06df89eb8ee9..ed898d0c4d41 100644 --- a/src/flow/src/expr/relation/accum.rs +++ b/src/flow/src/expr/relation/accum.rs @@ -31,7 +31,7 @@ use serde::{Deserialize, Serialize}; use snafu::ensure; use crate::expr::error::{InternalSnafu, OverflowSnafu, TryFromValueSnafu, TypeMismatchSnafu}; -use crate::expr::relation::func::GenericFn; +use crate::expr::signature::GenericFn; use crate::expr::{AggregateFunc, EvalError}; use crate::repr::Diff; @@ -221,7 +221,7 @@ impl Accumulator for SimpleNumber { (f, v) => { let expected_datatype = f.signature().input; return Err(TypeMismatchSnafu { - expected: expected_datatype, + expected: expected_datatype[0].clone(), actual: v.data_type(), } .build())?; @@ -258,7 +258,6 @@ impl Accumulator for SimpleNumber { } /// Accumulates float values for sum over floating numbers. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] - pub struct Float { /// Accumulates non-special float values, i.e. not NaN, +inf, -inf. /// accum will be set to zero if `non_nulls` is zero. 
@@ -341,7 +340,7 @@ impl Accumulator for Float { (f, v) => { let expected_datatype = f.signature().input; return Err(TypeMismatchSnafu { - expected: expected_datatype, + expected: expected_datatype[0].clone(), actual: v.data_type(), } .build())?; @@ -445,25 +444,27 @@ impl Accumulator for OrdValue { // if aggr_fn is count, the incoming value type doesn't matter in type checking // otherwise, type need to be the same or value can be null let check_type_aggr_fn_and_arg_value = - ty_eq_without_precision(value.data_type(), aggr_fn.signature().input) + ty_eq_without_precision(value.data_type(), aggr_fn.signature().input[0].clone()) || matches!(aggr_fn, AggregateFunc::Count) || value.is_null(); let check_type_aggr_fn_and_self_val = self .val .as_ref() - .map(|zelf| ty_eq_without_precision(zelf.data_type(), aggr_fn.signature().input)) + .map(|zelf| { + ty_eq_without_precision(zelf.data_type(), aggr_fn.signature().input[0].clone()) + }) .unwrap_or(true) || matches!(aggr_fn, AggregateFunc::Count); if !check_type_aggr_fn_and_arg_value { return Err(TypeMismatchSnafu { - expected: aggr_fn.signature().input, + expected: aggr_fn.signature().input[0].clone(), actual: value.data_type(), } .build()); } else if !check_type_aggr_fn_and_self_val { return Err(TypeMismatchSnafu { - expected: aggr_fn.signature().input, + expected: aggr_fn.signature().input[0].clone(), actual: self .val .as_ref() @@ -548,6 +549,7 @@ pub enum Accum { } impl Accum { + /// create a new accumulator from given aggregate function pub fn new_accum(aggr_fn: &AggregateFunc) -> Result { Ok(match aggr_fn { AggregateFunc::Any @@ -590,6 +592,8 @@ impl Accum { } }) } + + /// try to convert a vector of value into given aggregate function's accumulator pub fn try_into_accum(aggr_fn: &AggregateFunc, state: Vec) -> Result { match aggr_fn { AggregateFunc::Any diff --git a/src/flow/src/expr/relation/func.rs b/src/flow/src/expr/relation/func.rs index 91fa686428e0..a0b049b1a2b3 100644 --- a/src/flow/src/expr/relation/func.rs 
+++ b/src/flow/src/expr/relation/func.rs @@ -12,13 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; +use std::sync::OnceLock; + use common_time::{Date, DateTime}; use datatypes::prelude::ConcreteDataType; use datatypes::value::{OrderedF32, OrderedF64, Value}; use serde::{Deserialize, Serialize}; +use smallvec::smallvec; +use snafu::OptionExt; +use strum::{EnumIter, IntoEnumIterator}; +use crate::adapter::error::{Error, InvalidQuerySnafu}; use crate::expr::error::{EvalError, TryFromValueSnafu, TypeMismatchSnafu}; use crate::expr::relation::accum::{Accum, Accumulator}; +use crate::expr::signature::{GenericFn, Signature}; use crate::repr::Diff; /// Aggregate functions that can be applied to a group of rows. @@ -32,7 +40,7 @@ use crate::repr::Diff; /// `count()->i64` /// /// `min/max(T)->T` -#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)] +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash, EnumIter)] pub enum AggregateFunc { MaxInt16, MaxInt32, @@ -83,14 +91,17 @@ pub enum AggregateFunc { } impl AggregateFunc { + /// if this function is a `max` pub fn is_max(&self) -> bool { self.signature().generic_fn == GenericFn::Max } + /// if this function is a `min` pub fn is_min(&self) -> bool { self.signature().generic_fn == GenericFn::Min } + /// if this function is a `sum` pub fn is_sum(&self) -> bool { self.signature().generic_fn == GenericFn::Sum } @@ -119,242 +130,125 @@ impl AggregateFunc { } } -pub struct Signature { - pub input: ConcreteDataType, - pub output: ConcreteDataType, - pub generic_fn: GenericFn, +macro_rules! 
generate_signature { + ($value:ident, { $($user_arm:tt)* }, + [ $( + $auto_arm:ident=>($con_type:ident,$generic:ident) + ),* + ]) => { + match $value { + $($user_arm)*, + $( + Self::$auto_arm => Signature { + input: smallvec![ + ConcreteDataType::$con_type(), + ConcreteDataType::$con_type(), + ], + output: ConcreteDataType::$con_type(), + generic_fn: GenericFn::$generic, + }, + )* + } + }; } -#[derive(Debug, PartialEq, Eq)] -pub enum GenericFn { - Max, - Min, - Sum, - Count, - Any, - All, -} +static SPECIALIZATION: OnceLock> = + OnceLock::new(); impl AggregateFunc { + /// Create a `AggregateFunc` from a string of the function name and given argument type(optional) + /// given an None type will be treated as null type, + /// which in turn for AggregateFunc like `Count` will be treated as any type + pub fn from_str_and_type( + name: &str, + arg_type: Option, + ) -> Result { + let rule = SPECIALIZATION.get_or_init(|| { + let mut spec = HashMap::new(); + for func in Self::iter() { + let sig = func.signature(); + spec.insert((sig.generic_fn, sig.input[0].clone()), func); + } + spec + }); + + let generic_fn = match name { + "max" => GenericFn::Max, + "min" => GenericFn::Min, + "sum" => GenericFn::Sum, + "count" => GenericFn::Count, + "any" => GenericFn::Any, + "all" => GenericFn::All, + _ => { + return InvalidQuerySnafu { + reason: format!("Unknown binary function: {}", name), + } + .fail(); + } + }; + let input_type = arg_type.unwrap_or_else(ConcreteDataType::null_datatype); + rule.get(&(generic_fn, input_type.clone())) + .cloned() + .with_context(|| InvalidQuerySnafu { + reason: format!( + "No specialization found for binary function {:?} with input type {:?}", + generic_fn, input_type + ), + }) + } + /// all concrete datatypes with precision types will be returned with largest possible variant /// as a exception, count have a signature of `null -> i64`, but it's actually `anytype -> i64` pub fn signature(&self) -> Signature { - match self { - AggregateFunc::MaxInt16 
=> Signature { - input: ConcreteDataType::int16_datatype(), - output: ConcreteDataType::int16_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxInt32 => Signature { - input: ConcreteDataType::int32_datatype(), - output: ConcreteDataType::int32_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxInt64 => Signature { - input: ConcreteDataType::int64_datatype(), - output: ConcreteDataType::int64_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxUInt16 => Signature { - input: ConcreteDataType::uint16_datatype(), - output: ConcreteDataType::uint16_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxUInt32 => Signature { - input: ConcreteDataType::uint32_datatype(), - output: ConcreteDataType::uint32_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxUInt64 => Signature { - input: ConcreteDataType::uint64_datatype(), - output: ConcreteDataType::uint64_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxFloat32 => Signature { - input: ConcreteDataType::float32_datatype(), - output: ConcreteDataType::float32_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxFloat64 => Signature { - input: ConcreteDataType::float64_datatype(), - output: ConcreteDataType::float64_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxBool => Signature { - input: ConcreteDataType::boolean_datatype(), - output: ConcreteDataType::boolean_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxString => Signature { - input: ConcreteDataType::string_datatype(), - output: ConcreteDataType::string_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxDate => Signature { - input: ConcreteDataType::date_datatype(), - output: ConcreteDataType::date_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxDateTime => Signature { - input: ConcreteDataType::datetime_datatype(), - output: ConcreteDataType::datetime_datatype(), - generic_fn: 
GenericFn::Max, - }, - AggregateFunc::MaxTimestamp => Signature { - input: ConcreteDataType::timestamp_second_datatype(), - output: ConcreteDataType::timestamp_second_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxTime => Signature { - input: ConcreteDataType::time_second_datatype(), - output: ConcreteDataType::time_second_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxDuration => Signature { - input: ConcreteDataType::duration_second_datatype(), - output: ConcreteDataType::duration_second_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MaxInterval => Signature { - input: ConcreteDataType::interval_year_month_datatype(), - output: ConcreteDataType::interval_year_month_datatype(), - generic_fn: GenericFn::Max, - }, - AggregateFunc::MinInt16 => Signature { - input: ConcreteDataType::int16_datatype(), - output: ConcreteDataType::int16_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinInt32 => Signature { - input: ConcreteDataType::int32_datatype(), - output: ConcreteDataType::int32_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinInt64 => Signature { - input: ConcreteDataType::int64_datatype(), - output: ConcreteDataType::int64_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinUInt16 => Signature { - input: ConcreteDataType::uint16_datatype(), - output: ConcreteDataType::uint16_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinUInt32 => Signature { - input: ConcreteDataType::uint32_datatype(), - output: ConcreteDataType::uint32_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinUInt64 => Signature { - input: ConcreteDataType::uint64_datatype(), - output: ConcreteDataType::uint64_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinFloat32 => Signature { - input: ConcreteDataType::float32_datatype(), - output: ConcreteDataType::float32_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinFloat64 => 
Signature { - input: ConcreteDataType::float64_datatype(), - output: ConcreteDataType::float64_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinBool => Signature { - input: ConcreteDataType::boolean_datatype(), - output: ConcreteDataType::boolean_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinString => Signature { - input: ConcreteDataType::string_datatype(), - output: ConcreteDataType::string_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinDate => Signature { - input: ConcreteDataType::date_datatype(), - output: ConcreteDataType::date_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinDateTime => Signature { - input: ConcreteDataType::datetime_datatype(), - output: ConcreteDataType::datetime_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinTimestamp => Signature { - input: ConcreteDataType::timestamp_second_datatype(), - output: ConcreteDataType::timestamp_second_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinTime => Signature { - input: ConcreteDataType::time_second_datatype(), - output: ConcreteDataType::time_second_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinDuration => Signature { - input: ConcreteDataType::duration_second_datatype(), - output: ConcreteDataType::duration_second_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::MinInterval => Signature { - input: ConcreteDataType::interval_year_month_datatype(), - output: ConcreteDataType::interval_year_month_datatype(), - generic_fn: GenericFn::Min, - }, - AggregateFunc::SumInt16 => Signature { - input: ConcreteDataType::int16_datatype(), - output: ConcreteDataType::int16_datatype(), - generic_fn: GenericFn::Sum, - }, - AggregateFunc::SumInt32 => Signature { - input: ConcreteDataType::int32_datatype(), - output: ConcreteDataType::int32_datatype(), - generic_fn: GenericFn::Sum, - }, - AggregateFunc::SumInt64 => Signature { - input: 
ConcreteDataType::int64_datatype(), - output: ConcreteDataType::int64_datatype(), - generic_fn: GenericFn::Sum, - }, - AggregateFunc::SumUInt16 => Signature { - input: ConcreteDataType::uint16_datatype(), - output: ConcreteDataType::uint16_datatype(), - generic_fn: GenericFn::Sum, - }, - AggregateFunc::SumUInt32 => Signature { - input: ConcreteDataType::uint32_datatype(), - output: ConcreteDataType::uint32_datatype(), - generic_fn: GenericFn::Sum, - }, - AggregateFunc::SumUInt64 => Signature { - input: ConcreteDataType::uint64_datatype(), - output: ConcreteDataType::uint64_datatype(), - generic_fn: GenericFn::Sum, - }, - AggregateFunc::SumFloat32 => Signature { - input: ConcreteDataType::float32_datatype(), - output: ConcreteDataType::float32_datatype(), - generic_fn: GenericFn::Sum, - }, - AggregateFunc::SumFloat64 => Signature { - input: ConcreteDataType::float64_datatype(), - output: ConcreteDataType::float64_datatype(), - generic_fn: GenericFn::Sum, - }, + generate_signature!(self, { AggregateFunc::Count => Signature { - input: ConcreteDataType::null_datatype(), + input: smallvec![ConcreteDataType::null_datatype()], output: ConcreteDataType::int64_datatype(), generic_fn: GenericFn::Count, - }, - AggregateFunc::Any => Signature { - input: ConcreteDataType::boolean_datatype(), - output: ConcreteDataType::boolean_datatype(), - generic_fn: GenericFn::Any, - }, - AggregateFunc::All => Signature { - input: ConcreteDataType::boolean_datatype(), - output: ConcreteDataType::boolean_datatype(), - generic_fn: GenericFn::All, - }, - } + } + },[ + MaxInt16 => (int16_datatype, Max), + MaxInt32 => (int32_datatype, Max), + MaxInt64 => (int64_datatype, Max), + MaxUInt16 => (uint16_datatype, Max), + MaxUInt32 => (uint32_datatype, Max), + MaxUInt64 => (uint64_datatype, Max), + MaxFloat32 => (float32_datatype, Max), + MaxFloat64 => (float64_datatype, Max), + MaxBool => (boolean_datatype, Max), + MaxString => (string_datatype, Max), + MaxDate => (date_datatype, Max), + MaxDateTime 
=> (datetime_datatype, Max), + MaxTimestamp => (timestamp_second_datatype, Max), + MaxTime => (time_second_datatype, Max), + MaxDuration => (duration_second_datatype, Max), + MaxInterval => (interval_year_month_datatype, Max), + MinInt16 => (int16_datatype, Min), + MinInt32 => (int32_datatype, Min), + MinInt64 => (int64_datatype, Min), + MinUInt16 => (uint16_datatype, Min), + MinUInt32 => (uint32_datatype, Min), + MinUInt64 => (uint64_datatype, Min), + MinFloat32 => (float32_datatype, Min), + MinFloat64 => (float64_datatype, Min), + MinBool => (boolean_datatype, Min), + MinString => (string_datatype, Min), + MinDate => (date_datatype, Min), + MinDateTime => (datetime_datatype, Min), + MinTimestamp => (timestamp_second_datatype, Min), + MinTime => (time_second_datatype, Min), + MinDuration => (duration_second_datatype, Min), + MinInterval => (interval_year_month_datatype, Min), + SumInt16 => (int16_datatype, Sum), + SumInt32 => (int32_datatype, Sum), + SumInt64 => (int64_datatype, Sum), + SumUInt16 => (uint16_datatype, Sum), + SumUInt32 => (uint32_datatype, Sum), + SumUInt64 => (uint64_datatype, Sum), + SumFloat32 => (float32_datatype, Sum), + SumFloat64 => (float64_datatype, Sum), + Any => (boolean_datatype, Any), + All => (boolean_datatype, All) + ]) } } diff --git a/src/flow/src/expr/scalar.rs b/src/flow/src/expr/scalar.rs index 1bffdebd71f2..772bb06a4a90 100644 --- a/src/flow/src/expr/scalar.rs +++ b/src/flow/src/expr/scalar.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Scalar expressions. 
+ use std::collections::{BTreeMap, BTreeSet}; use datatypes::prelude::ConcreteDataType; @@ -19,9 +21,8 @@ use datatypes::value::Value; use serde::{Deserialize, Serialize}; use snafu::ensure; -use crate::expr::error::{ - EvalError, InvalidArgumentSnafu, OptimizeSnafu, UnsupportedTemporalFilterSnafu, -}; +use crate::adapter::error::{Error, InvalidQuerySnafu, UnsupportedTemporalFilterSnafu}; +use crate::expr::error::{EvalError, InvalidArgumentSnafu, OptimizeSnafu}; use crate::expr::func::{BinaryFunc, UnaryFunc, UnmaterializableFunc, VariadicFunc}; /// A scalar expression, which can be evaluated to a value. @@ -64,6 +65,7 @@ pub enum ScalarExpr { } impl ScalarExpr { + /// Call a unary function on this expression. pub fn call_unary(self, func: UnaryFunc) -> Self { ScalarExpr::CallUnary { func, @@ -71,6 +73,7 @@ impl ScalarExpr { } } + /// Call a binary function on this expression and another. pub fn call_binary(self, other: Self, func: BinaryFunc) -> Self { ScalarExpr::CallBinary { func, @@ -79,6 +82,7 @@ impl ScalarExpr { } } + /// Eval this expression with the given values. pub fn eval(&self, values: &[Value]) -> Result { match self { ScalarExpr::Column(index) => Ok(values[*index].clone()), @@ -106,13 +110,13 @@ impl ScalarExpr { /// This method is applicable even when `permutation` is not a /// strict permutation, and it only needs to have entries for /// each column referenced in `self`. 
- pub fn permute(&mut self, permutation: &[usize]) -> Result<(), EvalError> { + pub fn permute(&mut self, permutation: &[usize]) -> Result<(), Error> { // check first so that we don't end up with a partially permuted expression ensure!( self.get_all_ref_columns() .into_iter() .all(|i| i < permutation.len()), - InvalidArgumentSnafu { + InvalidQuerySnafu { reason: format!( "permutation {:?} is not a valid permutation for expression {:?}", permutation, self @@ -134,12 +138,12 @@ impl ScalarExpr { /// This method is applicable even when `permutation` is not a /// strict permutation, and it only needs to have entries for /// each column referenced in `self`. - pub fn permute_map(&mut self, permutation: &BTreeMap) -> Result<(), EvalError> { + pub fn permute_map(&mut self, permutation: &BTreeMap) -> Result<(), Error> { // check first so that we don't end up with a partially permuted expression ensure!( self.get_all_ref_columns() .is_subset(&permutation.keys().cloned().collect()), - InvalidArgumentSnafu { + InvalidQuerySnafu { reason: format!( "permutation {:?} is not a valid permutation for expression {:?}", permutation, self @@ -168,6 +172,21 @@ impl ScalarExpr { support } + /// Return true if the expression is a column reference. + pub fn is_column(&self) -> bool { + matches!(self, ScalarExpr::Column(_)) + } + + /// Cast the expression to a column reference if it is one. + pub fn as_column(&self) -> Option { + if let ScalarExpr::Column(i) = self { + Some(*i) + } else { + None + } + } + + /// Cast the expression to a literal if it is one. pub fn as_literal(&self) -> Option { if let ScalarExpr::Literal(lit, _column_type) = self { Some(lit.clone()) @@ -176,34 +195,42 @@ impl ScalarExpr { } } + /// Return true if the expression is a literal. pub fn is_literal(&self) -> bool { matches!(self, ScalarExpr::Literal(..)) } + /// Return true if the expression is a literal true. 
pub fn is_literal_true(&self) -> bool { Some(Value::Boolean(true)) == self.as_literal() } + /// Return true if the expression is a literal false. pub fn is_literal_false(&self) -> bool { Some(Value::Boolean(false)) == self.as_literal() } + /// Return true if the expression is a literal null. pub fn is_literal_null(&self) -> bool { Some(Value::Null) == self.as_literal() } + /// Build a literal null pub fn literal_null() -> Self { ScalarExpr::Literal(Value::Null, ConcreteDataType::null_datatype()) } + /// Build a literal from value and type pub fn literal(res: Value, typ: ConcreteDataType) -> Self { ScalarExpr::Literal(res, typ) } + /// Build a literal false pub fn literal_false() -> Self { ScalarExpr::Literal(Value::Boolean(false), ConcreteDataType::boolean_datatype()) } + /// Build a literal true pub fn literal_true() -> Self { ScalarExpr::Literal(Value::Boolean(true), ConcreteDataType::boolean_datatype()) } @@ -246,17 +273,17 @@ impl ScalarExpr { } } - fn visit_mut_post_nolimit(&mut self, f: &mut F) -> Result<(), EvalError> + fn visit_mut_post_nolimit(&mut self, f: &mut F) -> Result<(), Error> where - F: FnMut(&mut Self) -> Result<(), EvalError>, + F: FnMut(&mut Self) -> Result<(), Error>, { self.visit_mut_children(|e: &mut Self| e.visit_mut_post_nolimit(f))?; f(self) } - fn visit_mut_children(&mut self, mut f: F) -> Result<(), EvalError> + fn visit_mut_children(&mut self, mut f: F) -> Result<(), Error> where - F: FnMut(&mut Self) -> Result<(), EvalError>, + F: FnMut(&mut Self) -> Result<(), Error>, { match self { ScalarExpr::Column(_) @@ -302,7 +329,7 @@ impl ScalarExpr { /// /// false for lower bound, true for upper bound /// TODO(discord9): allow simple transform like `now() + a < b` to `now() < b - a` - pub fn extract_bound(&self) -> Result<(Option, Option), EvalError> { + pub fn extract_bound(&self) -> Result<(Option, Option), Error> { let unsupported_err = |msg: &str| { UnsupportedTemporalFilterSnafu { reason: msg.to_string(), @@ -437,11 +464,11 @@ mod test 
{ let mut expr = ScalarExpr::Column(4); let permutation = vec![1, 2, 3]; let res = expr.permute(&permutation); - assert!(matches!(res, Err(EvalError::InvalidArgument { .. }))); + assert!(matches!(res, Err(Error::InvalidQuery { .. }))); let mut expr = ScalarExpr::Column(0); let permute_map = BTreeMap::from([(1, 2), (3, 4)]); let res = expr.permute_map(&permute_map); - assert!(matches!(res, Err(EvalError::InvalidArgument { .. }))); + assert!(matches!(res, Err(Error::InvalidQuery { .. }))); } } diff --git a/src/flow/src/expr/signature.rs b/src/flow/src/expr/signature.rs new file mode 100644 index 000000000000..a7615502a520 --- /dev/null +++ b/src/flow/src/expr/signature.rs @@ -0,0 +1,67 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Function signature, useful for type checking and function resolution. 
+ +use datatypes::data_type::ConcreteDataType; +use serde::{Deserialize, Serialize}; +use smallvec::SmallVec; + +/// Function signature +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)] +pub struct Signature { + /// the input types, usually no greater than two input args + pub input: SmallVec<[ConcreteDataType; 2]>, + /// Output type + pub output: ConcreteDataType, + /// Generic function + pub generic_fn: GenericFn, +} + +/// Generic function category +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Deserialize, Serialize, Hash)] +pub enum GenericFn { + // aggregate func + Max, + Min, + Sum, + Count, + Any, + All, + // unary func + Not, + IsNull, + IsTrue, + IsFalse, + StepTimestamp, + Cast, + // binary func + Eq, + NotEq, + Lt, + Lte, + Gt, + Gte, + Add, + Sub, + Mul, + Div, + Mod, + // variadic func + And, + Or, + // unmaterialized func + Now, + CurrentSchema, +} diff --git a/src/flow/src/lib.rs b/src/flow/src/lib.rs index 1e232b68e15c..53c995462fb4 100644 --- a/src/flow/src/lib.rs +++ b/src/flow/src/lib.rs @@ -12,8 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! This crate manages dataflow in Greptime, including adapter, expr, plan, repr and utils. +//! It can transform substrait plan into its own plan and execute it. +//! It also contains definition of expression, adapter and plan, and internal state management. + #![allow(dead_code)] #![allow(unused_imports)] +#![warn(missing_docs)] +#![warn(clippy::missing_docs_in_private_items)] // allow unused for now because it should be use later mod adapter; mod compute; diff --git a/src/flow/src/plan.rs b/src/flow/src/plan.rs index 837d7f207a11..b4ed48a1c71b 100644 --- a/src/flow/src/plan.rs +++ b/src/flow/src/plan.rs @@ -27,14 +27,17 @@ use crate::expr::{ use crate::plan::join::JoinPlan; use crate::repr::{DiffRow, RelationType}; +/// A plan for a dataflow component.
But with type to indicate the output type of the relation. #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] pub struct TypedPlan { /// output type of the relation pub typ: RelationType, + /// The untyped plan. pub plan: Plan, } /// TODO(discord9): support `TableFunc`(by define FlatMap that map 1 to n) +/// Plan describe how to transform data in dataflow #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] pub enum Plan { /// A constant collection of rows. diff --git a/src/flow/src/plan/reduce.rs b/src/flow/src/plan/reduce.rs index 52dd3a509d50..09dc44b37f10 100644 --- a/src/flow/src/plan/reduce.rs +++ b/src/flow/src/plan/reduce.rs @@ -16,9 +16,12 @@ use serde::{Deserialize, Serialize}; use crate::expr::{AggregateExpr, Id, LocalId, MapFilterProject, SafeMfpPlan, ScalarExpr}; +/// Describe how to extract key-value pair from a `Row` #[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Deserialize, Serialize)] pub struct KeyValPlan { + /// Extract key from row pub key_plan: SafeMfpPlan, + /// Extract value from row pub val_plan: SafeMfpPlan, } diff --git a/src/flow/src/repr.rs b/src/flow/src/repr.rs index 91ad6c1c38b6..5239869c39bf 100644 --- a/src/flow/src/repr.rs +++ b/src/flow/src/repr.rs @@ -48,6 +48,7 @@ pub type Duration = i64; /// Default type for a repr of changes to a collection. pub type DiffRow = (Row, Timestamp, Diff); +/// Row with key-value pair, timestamp and diff pub type KeyValDiffRow = ((Row, Row), Timestamp, Diff); /// Convert a value that is or can be converted to Datetime to internal timestamp @@ -93,22 +94,31 @@ pub fn value_to_internal_ts(value: Value) -> Result { /// i.e. 
more compact like raw u8 of \[tag0, value0, tag1, value1, ...\] #[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Default, Serialize, Deserialize)] pub struct Row { + /// The inner vector of values pub inner: Vec, } impl Row { + /// Create an empty row pub fn empty() -> Self { Self { inner: vec![] } } + + /// Create a row from a vector of values pub fn new(row: Vec) -> Self { Self { inner: row } } + + /// Get the value at the given index pub fn get(&self, idx: usize) -> Option<&Value> { self.inner.get(idx) } + + /// Clear the row pub fn clear(&mut self) { self.inner.clear(); } + /// clear and return the inner vector /// /// useful if you want to reuse the vector as a buffer @@ -116,6 +126,7 @@ impl Row { self.inner.clear(); &mut self.inner } + /// pack an iterator of values into a row pub fn pack(iter: I) -> Row where @@ -125,22 +136,31 @@ impl Row { inner: iter.into_iter().collect(), } } + /// unpack a row into a vector of values pub fn unpack(self) -> Vec { self.inner } + + /// extend the row with values from an iterator pub fn extend(&mut self, iter: I) where I: IntoIterator, { self.inner.extend(iter); } + + /// Creates a consuming iterator, that is, one that moves each value out of the `Row` (from start to end). The `Row` cannot be used after calling this. pub fn into_iter(self) -> impl Iterator { self.inner.into_iter() } + + /// Returns an iterator over the slice. pub fn iter(&self) -> impl Iterator { self.inner.iter() } + + /// Returns the number of elements in the row, also known as its 'length'. pub fn len(&self) -> usize { self.inner.len() } diff --git a/src/flow/src/repr/relation.rs b/src/flow/src/repr/relation.rs index 2f98d8f1db2a..c1f7bfdc8853 100644 --- a/src/flow/src/repr/relation.rs +++ b/src/flow/src/repr/relation.rs @@ -21,6 +21,7 @@ use crate::adapter::error::{InvalidQuerySnafu, Result}; /// a set of column indices that are "keys" for the collection.
#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize, Hash)] pub struct Key { + /// indicates which columns form the key pub column_indices: Vec, } @@ -122,6 +123,7 @@ impl RelationType { self } + /// Adds new keys for the relation. Also sorts the key indices. pub fn with_keys(mut self, keys: Vec>) -> Self { for key in keys { self = self.with_key(key) } diff --git a/src/flow/src/utils.rs b/src/flow/src/utils.rs index c2dc15dad4f1..8623f3fada78 100644 --- a/src/flow/src/utils.rs +++ b/src/flow/src/utils.rs @@ -25,7 +25,10 @@ use crate::expr::error::InternalSnafu; use crate::expr::{EvalError, ScalarExpr}; use crate::repr::{value_to_internal_ts, Diff, DiffRow, Duration, KeyValDiffRow, Row, Timestamp}; +/// A batch of updates, arranged by key pub type Batch = BTreeMap>; + +/// A spine of batches, arranged by timestamp pub type Spine = BTreeMap; /// Determine when should a key expire according to it's event timestamp in key, @@ -136,6 +139,7 @@ pub struct Arrangement { } impl Arrangement { + /// create a new empty arrangement pub fn new() -> Self { Self { spine: Default::default(), @@ -453,19 +457,26 @@ pub struct ArrangeHandler { inner: Arc>, } impl ArrangeHandler { + /// create a new handler from arrangement pub fn from(arr: Arrangement) -> Self { Self { inner: Arc::new(RwLock::new(arr)), } } + + /// write lock the arrangement pub fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, Arrangement> { self.inner.blocking_write() } + + /// read lock the arrangement pub fn read(&self) -> tokio::sync::RwLockReadGuard<'_, Arrangement> { self.inner.blocking_read() } /// clone the handler, but only keep the future updates + /// + /// it's a cheap operation, since it's `Arc-ed` and only clones the `Arc` pub fn clone_future_only(&self) -> Option { if self.read().is_written { return None; } @@ -478,6 +489,8 @@ impl ArrangeHandler { /// clone the handler, but keep all updates /// prevent illegal clone after the arrange have been written, /// because that will cause
loss of data before clone + /// + /// it's a cheap operation, since it's `Arc-ed` and only clones the `Arc` pub fn clone_full_arrange(&self) -> Option { if self.read().is_written { return None;