diff --git a/datafusion/core/src/logical_plan/expr.rs b/datafusion/core/src/logical_plan/expr.rs index f8d7f46c6729..fecb7ce46f7f 100644 --- a/datafusion/core/src/logical_plan/expr.rs +++ b/datafusion/core/src/logical_plan/expr.rs @@ -18,139 +18,5 @@ //! This module provides an `Expr` enum for representing expressions //! such as `col = 5` or `SUM(col)`. See examples on the [`Expr`] struct. -pub use super::Operator; -use crate::error::Result; -use arrow::datatypes::DataType; pub use datafusion_common::{Column, ExprSchema}; -pub use datafusion_expr::expr_fn::*; -use datafusion_expr::AccumulatorFunctionImplementation; -use datafusion_expr::BuiltinScalarFunction; -pub use datafusion_expr::Expr; -use datafusion_expr::StateTypeFunction; -pub use datafusion_expr::{lit, lit_timestamp_nano, Literal}; -use datafusion_expr::{AggregateUDF, ScalarUDF}; -use datafusion_expr::{ - ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility, -}; -use std::sync::Arc; - -/// Combines an array of filter expressions into a single filter expression -/// consisting of the input filter expressions joined with logical AND. -/// Returns None if the filters array is empty. -pub fn combine_filters(filters: &[Expr]) -> Option { - if filters.is_empty() { - return None; - } - let combined_filter = filters - .iter() - .skip(1) - .fold(filters[0].clone(), |acc, filter| and(acc, filter.clone())); - Some(combined_filter) -} - -/// Recursively un-alias an expressions -#[inline] -pub fn unalias(expr: Expr) -> Expr { - match expr { - Expr::Alias(sub_expr, _) => unalias(*sub_expr), - _ => expr, - } -} - -/// Creates a new UDF with a specific signature and specific return type. -/// This is a helper function to create a new UDF. -/// The function `create_udf` returns a subset of all possible `ScalarFunction`: -/// * the UDF has a fixed return type -/// * the UDF has a fixed signature (e.g. [f64, f64]) -pub fn create_udf( - name: &str, - input_types: Vec, - return_type: Arc, - volatility: Volatility, - fun: ScalarFunctionImplementation, -) -> ScalarUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - ScalarUDF::new( - name, - &Signature::exact(input_types, volatility), - &return_type, - &fun, - ) -} - -/// Creates a new UDAF with a specific signature, state type and return type. -/// The signature and state type must match the `Accumulator's implementation`. -#[allow(clippy::rc_buffer)] -pub fn create_udaf( - name: &str, - input_type: DataType, - return_type: Arc, - volatility: Volatility, - accumulator: AccumulatorFunctionImplementation, - state_type: Arc>, -) -> AggregateUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); - AggregateUDF::new( - name, - &Signature::exact(vec![input_type], volatility), - &return_type, - &accumulator, - &state_type, - ) -} - -/// Calls a named built in function -/// ``` -/// use datafusion::logical_plan::*; -/// -/// // create the expression sin(x) < 0.2 -/// let expr = call_fn("sin", vec![col("x")]).unwrap().lt(lit(0.2)); -/// ``` -pub fn call_fn(name: impl AsRef, args: Vec) -> Result { - match name.as_ref().parse::() { - Ok(fun) => Ok(Expr::ScalarFunction { fun, args }), - Err(e) => Err(e), - } -} - -#[cfg(test)] -mod tests { - use super::super::{col, lit}; - use super::*; - use datafusion_expr::expr_fn::binary_expr; - - #[test] - fn digest_function_definitions() { - if let Expr::ScalarFunction { fun, args } = digest(col("tableA.a"), lit("md5")) { - let name = BuiltinScalarFunction::Digest; - assert_eq!(name, fun); - assert_eq!(2, args.len()); - } else { - unreachable!(); - } - } - - #[test] - fn combine_zero_filters() { - let result = combine_filters(&[]); - assert_eq!(result, None); - } - - #[test] - fn combine_one_filter() { - let filter = binary_expr(col("c1"), Operator::Lt, lit(1)); - let result = combine_filters(&[filter.clone()]); - assert_eq!(result, Some(filter)); - } - - #[test] - fn combine_multiple_filters() { - let filter1 = binary_expr(col("c1"), Operator::Lt, lit(1)); - let filter2 = binary_expr(col("c2"), Operator::Lt, lit(2)); - let filter3 = binary_expr(col("c3"), Operator::Lt, lit(3)); - let result = - combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]); - assert_eq!(result, Some(and(and(filter1, filter2), filter3))); - } -} +pub use datafusion_expr::{expr_fn::*, lit, lit_timestamp_nano, Expr, Literal, Operator}; diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 9ed8c536bc4a..14eeb2c82551 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -17,9 +17,14 @@ //! Functions for creating logical expressions -use crate::conditional_expressions::CaseBuilder; -use crate::logical_plan::Subquery; -use crate::{aggregate_function, built_in_function, lit, Expr, LogicalPlan, Operator}; +use crate::{ + aggregate_function, built_in_function, conditional_expressions::CaseBuilder, lit, + logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF, + BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, + ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, +}; +use arrow::datatypes::DataType; +use datafusion_common::Result; use std::sync::Arc; /// Create a column expression based on a qualified or unqualified column name @@ -359,6 +364,86 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { CaseBuilder::new(None, vec![when], vec![then], None) } +/// Combines an array of filter expressions into a single filter expression +/// consisting of the input filter expressions joined with logical AND. +/// Returns None if the filters array is empty. +pub fn combine_filters(filters: &[Expr]) -> Option { + if filters.is_empty() { + return None; + } + let combined_filter = filters + .iter() + .skip(1) + .fold(filters[0].clone(), |acc, filter| and(acc, filter.clone())); + Some(combined_filter) +} + +/// Recursively un-alias an expressions +#[inline] +pub fn unalias(expr: Expr) -> Expr { + match expr { + Expr::Alias(sub_expr, _) => unalias(*sub_expr), + _ => expr, + } +} + +/// Creates a new UDF with a specific signature and specific return type. +/// This is a helper function to create a new UDF. +/// The function `create_udf` returns a subset of all possible `ScalarFunction`: +/// * the UDF has a fixed return type +/// * the UDF has a fixed signature (e.g. [f64, f64]) +pub fn create_udf( + name: &str, + input_types: Vec, + return_type: Arc, + volatility: Volatility, + fun: ScalarFunctionImplementation, +) -> ScalarUDF { + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); + ScalarUDF::new( + name, + &Signature::exact(input_types, volatility), + &return_type, + &fun, + ) +} + +/// Creates a new UDAF with a specific signature, state type and return type. +/// The signature and state type must match the `Accumulator's implementation`. +#[allow(clippy::rc_buffer)] +pub fn create_udaf( + name: &str, + input_type: DataType, + return_type: Arc, + volatility: Volatility, + accumulator: AccumulatorFunctionImplementation, + state_type: Arc>, +) -> AggregateUDF { + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); + let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); + AggregateUDF::new( + name, + &Signature::exact(vec![input_type], volatility), + &return_type, + &accumulator, + &state_type, + ) +} + +/// Calls a named built in function +/// ``` +/// use datafusion_expr::{col, lit, call_fn}; +/// +/// // create the expression sin(x) < 0.2 +/// let expr = call_fn("sin", vec![col("x")]).unwrap().lt(lit(0.2)); +/// ``` +pub fn call_fn(name: impl AsRef, args: Vec) -> Result { + match name.as_ref().parse::() { + Ok(fun) => Ok(Expr::ScalarFunction { fun, args }), + Err(e) => Err(e), + } +} + #[cfg(test)] mod test { use super::*; @@ -501,4 +586,38 @@ mod test { test_scalar_expr!(DatePart, date_part, part, date); test_scalar_expr!(DateTrunc, date_trunc, part, date); } + + #[test] + fn digest_function_definitions() { + if let Expr::ScalarFunction { fun, args } = digest(col("tableA.a"), lit("md5")) { + let name = BuiltinScalarFunction::Digest; + assert_eq!(name, fun); + assert_eq!(2, args.len()); + } else { + unreachable!(); + } + } + + #[test] + fn combine_zero_filters() { + let result = combine_filters(&[]); + assert_eq!(result, None); + } + + #[test] + fn combine_one_filter() { + let filter = binary_expr(col("c1"), Operator::Lt, lit(1)); + let result = combine_filters(&[filter.clone()]); + assert_eq!(result, Some(filter)); + } + + #[test] + fn combine_multiple_filters() { + let filter1 = binary_expr(col("c1"), Operator::Lt, lit(1)); + let filter2 = binary_expr(col("c2"), Operator::Lt, lit(2)); + let filter3 = binary_expr(col("c3"), Operator::Lt, lit(3)); + let result = + combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]); + assert_eq!(result, Some(and(and(filter1, filter2), filter3))); + } }