Skip to content

Commit

Permalink
Move more expression code to datafusion-expr (#2669)
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove authored Jun 1, 2022
1 parent 0806d85 commit 340bfb1
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 138 deletions.
136 changes: 1 addition & 135 deletions datafusion/core/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,139 +18,5 @@
//! This module provides an `Expr` enum for representing expressions
//! such as `col = 5` or `SUM(col)`. See examples on the [`Expr`] struct.

pub use super::Operator;
use crate::error::Result;
use arrow::datatypes::DataType;
pub use datafusion_common::{Column, ExprSchema};
pub use datafusion_expr::expr_fn::*;
use datafusion_expr::AccumulatorFunctionImplementation;
use datafusion_expr::BuiltinScalarFunction;
pub use datafusion_expr::Expr;
use datafusion_expr::StateTypeFunction;
pub use datafusion_expr::{lit, lit_timestamp_nano, Literal};
use datafusion_expr::{AggregateUDF, ScalarUDF};
use datafusion_expr::{
ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
};
use std::sync::Arc;

/// Combines an array of filter expressions into a single filter expression
/// consisting of the input filter expressions joined with logical AND.
/// Returns None if the filters array is empty.
pub fn combine_filters(filters: &[Expr]) -> Option<Expr> {
if filters.is_empty() {
return None;
}
let combined_filter = filters
.iter()
.skip(1)
.fold(filters[0].clone(), |acc, filter| and(acc, filter.clone()));
Some(combined_filter)
}

/// Recursively un-alias an expressions
#[inline]
pub fn unalias(expr: Expr) -> Expr {
match expr {
Expr::Alias(sub_expr, _) => unalias(*sub_expr),
_ => expr,
}
}

/// Creates a new UDF with a specific signature and specific return type.
/// This is a helper function to create a new UDF.
/// The function `create_udf` returns a subset of all possible `ScalarFunction`:
/// * the UDF has a fixed return type
/// * the UDF has a fixed signature (e.g. [f64, f64])
pub fn create_udf(
name: &str,
input_types: Vec<DataType>,
return_type: Arc<DataType>,
volatility: Volatility,
fun: ScalarFunctionImplementation,
) -> ScalarUDF {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
ScalarUDF::new(
name,
&Signature::exact(input_types, volatility),
&return_type,
&fun,
)
}

/// Creates a new UDAF with a specific signature, state type and return type.
/// The signature and state type must match the `Accumulator's implementation`.
#[allow(clippy::rc_buffer)]
pub fn create_udaf(
name: &str,
input_type: DataType,
return_type: Arc<DataType>,
volatility: Volatility,
accumulator: AccumulatorFunctionImplementation,
state_type: Arc<Vec<DataType>>,
) -> AggregateUDF {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
AggregateUDF::new(
name,
&Signature::exact(vec![input_type], volatility),
&return_type,
&accumulator,
&state_type,
)
}

/// Calls a named built in function
/// ```
/// use datafusion::logical_plan::*;
///
/// // create the expression sin(x) < 0.2
/// let expr = call_fn("sin", vec![col("x")]).unwrap().lt(lit(0.2));
/// ```
pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
match name.as_ref().parse::<BuiltinScalarFunction>() {
Ok(fun) => Ok(Expr::ScalarFunction { fun, args }),
Err(e) => Err(e),
}
}

#[cfg(test)]
mod tests {
use super::super::{col, lit};
use super::*;
use datafusion_expr::expr_fn::binary_expr;

#[test]
fn digest_function_definitions() {
if let Expr::ScalarFunction { fun, args } = digest(col("tableA.a"), lit("md5")) {
let name = BuiltinScalarFunction::Digest;
assert_eq!(name, fun);
assert_eq!(2, args.len());
} else {
unreachable!();
}
}

#[test]
fn combine_zero_filters() {
let result = combine_filters(&[]);
assert_eq!(result, None);
}

#[test]
fn combine_one_filter() {
let filter = binary_expr(col("c1"), Operator::Lt, lit(1));
let result = combine_filters(&[filter.clone()]);
assert_eq!(result, Some(filter));
}

#[test]
fn combine_multiple_filters() {
let filter1 = binary_expr(col("c1"), Operator::Lt, lit(1));
let filter2 = binary_expr(col("c2"), Operator::Lt, lit(2));
let filter3 = binary_expr(col("c3"), Operator::Lt, lit(3));
let result =
combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
}
}
pub use datafusion_expr::{expr_fn::*, lit, lit_timestamp_nano, Expr, Literal, Operator};
125 changes: 122 additions & 3 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@

//! Functions for creating logical expressions

use crate::conditional_expressions::CaseBuilder;
use crate::logical_plan::Subquery;
use crate::{aggregate_function, built_in_function, lit, Expr, LogicalPlan, Operator};
use crate::{
aggregate_function, built_in_function, conditional_expressions::CaseBuilder, lit,
logical_plan::Subquery, AccumulatorFunctionImplementation, AggregateUDF,
BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::sync::Arc;

/// Create a column expression based on a qualified or unqualified column name
Expand Down Expand Up @@ -359,6 +364,86 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder {
CaseBuilder::new(None, vec![when], vec![then], None)
}

/// Combines an array of filter expressions into a single filter expression
/// consisting of the input filter expressions joined with logical AND.
/// Returns None if the filters array is empty.
pub fn combine_filters(filters: &[Expr]) -> Option<Expr> {
if filters.is_empty() {
return None;
}
let combined_filter = filters
.iter()
.skip(1)
.fold(filters[0].clone(), |acc, filter| and(acc, filter.clone()));
Some(combined_filter)
}

/// Recursively un-alias an expressions
#[inline]
pub fn unalias(expr: Expr) -> Expr {
match expr {
Expr::Alias(sub_expr, _) => unalias(*sub_expr),
_ => expr,
}
}

/// Creates a new UDF with a specific signature and specific return type.
/// This is a helper function to create a new UDF.
/// The function `create_udf` returns a subset of all possible `ScalarFunction`:
/// * the UDF has a fixed return type
/// * the UDF has a fixed signature (e.g. [f64, f64])
pub fn create_udf(
name: &str,
input_types: Vec<DataType>,
return_type: Arc<DataType>,
volatility: Volatility,
fun: ScalarFunctionImplementation,
) -> ScalarUDF {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
ScalarUDF::new(
name,
&Signature::exact(input_types, volatility),
&return_type,
&fun,
)
}

/// Creates a new UDAF with a specific signature, state type and return type.
/// The signature and state type must match the `Accumulator's implementation`.
#[allow(clippy::rc_buffer)]
pub fn create_udaf(
name: &str,
input_type: DataType,
return_type: Arc<DataType>,
volatility: Volatility,
accumulator: AccumulatorFunctionImplementation,
state_type: Arc<Vec<DataType>>,
) -> AggregateUDF {
let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone()));
AggregateUDF::new(
name,
&Signature::exact(vec![input_type], volatility),
&return_type,
&accumulator,
&state_type,
)
}

/// Calls a named built in function
/// ```
/// use datafusion_expr::{col, lit, call_fn};
///
/// // create the expression sin(x) < 0.2
/// let expr = call_fn("sin", vec![col("x")]).unwrap().lt(lit(0.2));
/// ```
pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
match name.as_ref().parse::<BuiltinScalarFunction>() {
Ok(fun) => Ok(Expr::ScalarFunction { fun, args }),
Err(e) => Err(e),
}
}

#[cfg(test)]
mod test {
use super::*;
Expand Down Expand Up @@ -501,4 +586,38 @@ mod test {
test_scalar_expr!(DatePart, date_part, part, date);
test_scalar_expr!(DateTrunc, date_trunc, part, date);
}

#[test]
fn digest_function_definitions() {
if let Expr::ScalarFunction { fun, args } = digest(col("tableA.a"), lit("md5")) {
let name = BuiltinScalarFunction::Digest;
assert_eq!(name, fun);
assert_eq!(2, args.len());
} else {
unreachable!();
}
}

#[test]
fn combine_zero_filters() {
let result = combine_filters(&[]);
assert_eq!(result, None);
}

#[test]
fn combine_one_filter() {
let filter = binary_expr(col("c1"), Operator::Lt, lit(1));
let result = combine_filters(&[filter.clone()]);
assert_eq!(result, Some(filter));
}

#[test]
fn combine_multiple_filters() {
let filter1 = binary_expr(col("c1"), Operator::Lt, lit(1));
let filter2 = binary_expr(col("c2"), Operator::Lt, lit(2));
let filter3 = binary_expr(col("c3"), Operator::Lt, lit(3));
let result =
combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
}
}

0 comments on commit 340bfb1

Please sign in to comment.