diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md index aae451add9e75..eecb63d3be656 100644 --- a/datafusion-examples/README.md +++ b/datafusion-examples/README.md @@ -62,6 +62,7 @@ cargo run --example csv_sql - [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined Scalar Function (UDF) - [`advanced_udf.rs`](examples/advanced_udf.rs): Define and invoke a more complicated User Defined Scalar Function (UDF) - [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF) +- [`advanced_udaf.rs`](examples/advanced_udaf.rs): Define and invoke a more complicated User Defined Aggregate Function (UDAF) - [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF) - [`advanced_udwf.rs`](examples/advanced_udwf.rs): Define and invoke a more complicated User Defined Window Function (UDWF) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs new file mode 100644 index 0000000000000..393c27678f8c0 --- /dev/null +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -0,0 +1,227 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{ArrayRef, Float32Array}, + record_batch::RecordBatch, +}; +use datafusion::error::Result; +use datafusion::prelude::*; +use datafusion_common::{cast::as_float64_array, ScalarValue}; +use datafusion_expr::{Accumulator, AggregateUDF, AggregateUDFImpl, Signature}; + +/// This example shows how to use the full AggregateUDFImpl API to implement a user +/// defined aggregate function. As in the `simple_udaf.rs` example, this struct implements +/// a function `accumulator` that returns the `Accumulator` instance. +/// +/// To do so, we must implement the `AggregateUDFImpl` trait. +struct GeoMeanUdf { + signature: Signature, +} + +impl GeoMeanUdf { + /// Create a new instance of the GeoMeanUdf struct + fn new() -> Self { + Self { + signature: Signature::exact( + // this function will always take one arguments of type f64 + vec![DataType::Float64], + // this function is deterministic and will always return the same + // result for the same input + Volatility::Immutable, + ), + } + } +} + +impl AggregateUDFImpl for GeoMeanUdf { + /// We implement as_any so that we can downcast the AggregateUDFImpl trait object + fn as_any(&self) -> &dyn Any { + self + } + + /// Return the name of this function + fn name(&self) -> &str { + "geo_mean" + } + + /// Return the "signature" of this function -- namely that types of arguments it will take + fn signature(&self) -> &Signature { + &self.signature + } + + /// What is the type of value that will be returned by this function. + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + /// This is the accumulator factory; DataFusion uses it to create new accumulators. + fn accumulator(&self, _arg: &DataType) -> Result> { + Ok(Box::new(GeometricMean::new())) + } + + /// This is the description of the state. accumulator's state() must match the types here. + fn state_type(&self, _return_type: &DataType) -> Result> { + Ok(vec![DataType::Float64, DataType::UInt32]) + } +} + +/// A UDAF has state across multiple rows, and thus we require a `struct` with that state. +#[derive(Debug)] +struct GeometricMean { + n: u32, + prod: f64, +} + +impl GeometricMean { + // how the struct is initialized + pub fn new() -> Self { + GeometricMean { n: 0, prod: 1.0 } + } +} + +// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions +// to use them. +impl Accumulator for GeometricMean { + // This function serializes our state to `ScalarValue`, which DataFusion uses + // to pass this state between execution stages. + // Note that this can be arbitrary data. + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.prod), + ScalarValue::from(self.n), + ]) + } + + // DataFusion expects this function to return the final value of this aggregator. + // in this case, this is the formula of the geometric mean + fn evaluate(&self) -> Result { + let value = self.prod.powf(1.0 / self.n as f64); + Ok(ScalarValue::from(value)) + } + + // DataFusion calls this function to update the accumulator's state for a batch + // of inputs rows. In this case the product is updated with values from the first column + // and the count is updated based on the row count + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + let arr = &values[0]; + (0..arr.len()).try_for_each(|index| { + let v = ScalarValue::try_from_array(arr, index)?; + + if let ScalarValue::Float64(Some(value)) = v { + self.prod *= value; + self.n += 1; + } else { + unreachable!("") + } + Ok(()) + }) + } + + // Optimization hint: this trait also supports `update_batch` and `merge_batch`, + // that can be used to perform these operations on arrays instead of single values. + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + let arr = &states[0]; + (0..arr.len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>()?; + if let (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) = + (&v[0], &v[1]) + { + self.prod *= prod; + self.n += n; + } else { + unreachable!("") + } + Ok(()) + }) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + +// create local session context with an in-memory table +fn create_context() -> Result { + use datafusion::arrow::datatypes::{Field, Schema}; + use datafusion::datasource::MemTable; + // define a schema. + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Float32Array::from(vec![64.0]))], + )?; + + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context()?; + + // create the AggregateUDF + let geometric_mean = AggregateUDF::from(GeoMeanUdf::new()); + ctx.register_udaf(geometric_mean.clone()); + + let sql_df = ctx.sql("SELECT geo_mean(a) FROM t").await?; + sql_df.show().await?; + + // get a DataFrame from the context + // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. + let df = ctx.table("t").await?; + + // perform the aggregation + let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?; + + // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature. + + // execute the query + let results = df.collect().await?; + + // downcast the array to the expected type + let result = as_float64_array(results[0].column(0))?; + + // verify that the calculation is correct + assert!((result.value(0) - 8.0).abs() < f64::EPSILON); + println!("The geometric mean of [2,4,8,64] is {}", result.value(0)); + + Ok(()) +} diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index fb0ecd02c6b09..e3406e8bd3a0c 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -36,8 +36,7 @@ use datafusion::{ assert_batches_eq, error::Result, logical_expr::{ - AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, - StateTypeFunction, TypeSignature, Volatility, + AccumulatorFactoryFunction, AggregateUDF, Signature, TypeSignature, Volatility, }, physical_plan::Accumulator, prelude::SessionContext, @@ -46,7 +45,7 @@ use datafusion::{ use datafusion_common::{ assert_contains, cast::as_primitive_array, exec_err, DataFusionError, }; -use datafusion_expr::create_udaf; +use datafusion_expr::{create_udaf, SimpleAggregateUDF}; use datafusion_physical_expr::expressions::AvgAccumulator; /// Test to show the contents of the setup @@ -408,26 +407,27 @@ impl TimeSum { fn register(ctx: &mut SessionContext, test_state: Arc, name: &str) { let timestamp_type = DataType::Timestamp(TimeUnit::Nanosecond, None); + let input_type = vec![timestamp_type.clone()]; // Returns the same type as its input - let return_type = Arc::new(timestamp_type.clone()); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::clone(&return_type))); + let return_type = timestamp_type.clone(); - let state_type = Arc::new(vec![timestamp_type.clone()]); - let state_type: StateTypeFunction = - Arc::new(move |_| Ok(Arc::clone(&state_type))); + let state_type = vec![timestamp_type.clone()]; let volatility = Volatility::Immutable; - let signature = Signature::exact(vec![timestamp_type], volatility); - let captured_state = Arc::clone(&test_state); let accumulator: AccumulatorFactoryFunction = Arc::new(move |_| Ok(Box::new(Self::new(Arc::clone(&captured_state))))); - let time_sum = - AggregateUDF::new(name, &signature, &return_type, &accumulator, &state_type); + let time_sum = AggregateUDF::from(SimpleAggregateUDF::new( + name, + input_type, + return_type, + volatility, + accumulator, + state_type, + )); // register the selector as "time_sum" ctx.register_udaf(time_sum) @@ -510,11 +510,8 @@ impl FirstSelector { } fn register(ctx: &mut SessionContext) { - let return_type = Arc::new(Self::output_datatype()); - let state_type = Arc::new(Self::state_datatypes()); - - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); + let return_type = Self::output_datatype(); + let state_type = Self::state_datatypes(); // Possible input signatures let signatures = vec![TypeSignature::Exact(Self::input_datatypes())]; @@ -526,13 +523,13 @@ impl FirstSelector { let name = "first"; - let first = AggregateUDF::new( - name, - &Signature::one_of(signatures, volatility), - &return_type, - &accumulator, - &state_type, - ); + let first = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( + name.to_string(), + Signature::one_of(signatures, volatility), + return_type, + accumulator, + state_type, + )); // register the selector as "first" ctx.register_udaf(first) diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 7b3f65248586f..8d78726f127bf 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -26,9 +26,9 @@ use crate::{ aggregate_function, built_in_function, conditional_expressions::CaseBuilder, logical_plan::Subquery, AccumulatorFactoryFunction, AggregateUDF, BuiltinScalarFunction, Expr, LogicalPlan, Operator, ReturnTypeFunction, - ScalarFunctionImplementation, ScalarUDF, Signature, StateTypeFunction, Volatility, + ScalarFunctionImplementation, ScalarUDF, Signature, Volatility, }; -use crate::{ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; +use crate::{AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowUDF, WindowUDFImpl}; use arrow::datatypes::DataType; use datafusion_common::{Column, Result}; use std::any::Any; @@ -1037,15 +1037,92 @@ pub fn create_udaf( accumulator: AccumulatorFactoryFunction, state_type: Arc>, ) -> AggregateUDF { - let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); - let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); - AggregateUDF::new( + let return_type = Arc::try_unwrap(return_type).unwrap_or_else(|t| t.as_ref().clone()); + let state_type = Arc::try_unwrap(state_type).unwrap_or_else(|t| t.as_ref().clone()); + AggregateUDF::from(SimpleAggregateUDF::new( name, - &Signature::exact(input_type, volatility), - &return_type, - &accumulator, - &state_type, - ) + input_type, + return_type, + volatility, + accumulator, + state_type, + )) +} + +/// Implements [`AggregateUDFImpl`] for functions that have a single signature and +/// return type. +pub struct SimpleAggregateUDF { + name: String, + signature: Signature, + return_type: DataType, + accumulator: AccumulatorFactoryFunction, + state_type: Vec, +} + +impl SimpleAggregateUDF { + /// Create a new `AggregateUDFImpl` from a name, input types, return type, state type and + /// implementation. Implementing [`AggregateUDFImpl`] allows more flexibility + pub fn new( + name: impl Into, + input_type: Vec, + return_type: DataType, + volatility: Volatility, + accumulator: AccumulatorFactoryFunction, + state_type: Vec, + ) -> Self { + let name = name.into(); + let signature = Signature::exact(input_type, volatility); + Self { + name, + signature, + return_type, + accumulator, + state_type, + } + } + + pub fn new_with_signature( + name: impl Into, + signature: Signature, + return_type: DataType, + accumulator: AccumulatorFactoryFunction, + state_type: Vec, + ) -> Self { + let name = name.into(); + Self { + name, + signature, + return_type, + accumulator, + state_type, + } + } +} + +impl AggregateUDFImpl for SimpleAggregateUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn accumulator(&self, arg: &DataType) -> Result> { + (self.accumulator)(arg) + } + + fn state_type(&self, _return_type: &DataType) -> Result> { + Ok(self.state_type.clone()) + } } /// Creates a new UDWF with a specific signature, state type and return type. diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 077681d217257..0d431f10c4324 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -80,7 +80,7 @@ pub use signature::{ FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD, }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::AggregateUDF; +pub use udaf::{AggregateUDF, AggregateUDFImpl}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index cfbca4ab1337a..5dab4a474b301 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -23,6 +23,7 @@ use crate::{ }; use arrow::datatypes::DataType; use datafusion_common::Result; +use std::any::Any; use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; @@ -86,6 +87,10 @@ impl std::hash::Hash for AggregateUDF { impl AggregateUDF { /// Create a new AggregateUDF + /// + /// See [`AggregateUDFImpl`] for a more convenient way to create a + /// `AggregateUDF` using trait objects + #[deprecated(since = "34.0.0", note = "please implement AggregateUDFImpl instead")] pub fn new( name: &str, signature: &Signature, @@ -102,6 +107,39 @@ impl AggregateUDF { } } + /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object + /// + /// Note this is the same as using the `From` impl (`AggregateUDF::from`) + pub fn new_from_impl(fun: F) -> AggregateUDF + where + F: AggregateUDFImpl + Send + Sync + 'static, + { + let arc_fun = Arc::new(fun); + let captured_self = arc_fun.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { + let return_type = captured_self.return_type(arg_types)?; + Ok(Arc::new(return_type)) + }); + + let captured_self = arc_fun.clone(); + let accumulator: AccumulatorFactoryFunction = + Arc::new(move |arg| captured_self.accumulator(arg)); + + let captured_self = arc_fun.clone(); + let state_type: StateTypeFunction = Arc::new(move |return_type| { + let state_type = captured_self.state_type(return_type)?; + Ok(Arc::new(state_type)) + }); + + Self { + name: arc_fun.name().to_string(), + signature: arc_fun.signature().clone(), + return_type: return_type.clone(), + accumulator, + state_type, + } + } + /// creates an [`Expr`] that calls the aggregate function. /// /// This utility allows using the UDAF without requiring access to @@ -147,3 +185,89 @@ impl AggregateUDF { Ok(Arc::try_unwrap(res).unwrap_or_else(|res| res.as_ref().clone())) } } + +impl From for AggregateUDF +where + F: AggregateUDFImpl + Send + Sync + 'static, +{ + fn from(fun: F) -> Self { + Self::new_from_impl(fun) + } +} + +/// Trait for implementing [`AggregateUDF`]. +/// +/// This trait exposes the full API for implementing user defined aggregate functions and +/// can be used to implement any function. +/// +/// See [`advanced_udaf.rs`] for a full example with complete implementation and +/// [`AggregateUDF`] for other available options. +/// +/// +/// [`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs +/// # Basic Example +/// ``` +/// # use std::any::Any; +/// # use arrow::datatypes::DataType; +/// # use datafusion_common::{DataFusionError, plan_err, Result}; +/// # use datafusion_expr::{col, ColumnarValue, Signature, Volatility}; +/// # use datafusion_expr::{AggregateUDFImpl, AggregateUDF, Accumulator}; +/// struct GeoMeanUdf { +/// signature: Signature +/// }; +/// +/// impl GeoMeanUdf { +/// fn new() -> Self { +/// Self { +/// signature: Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable) +/// } +/// } +/// } +/// +/// /// Implement the AggregateUDFImpl trait for GeoMeanUdf +/// impl AggregateUDFImpl for GeoMeanUdf { +/// fn as_any(&self) -> &dyn Any { self } +/// fn name(&self) -> &str { "geo_mean" } +/// fn signature(&self) -> &Signature { &self.signature } +/// fn return_type(&self, args: &[DataType]) -> Result { +/// if !matches!(args.get(0), Some(&DataType::Float64)) { +/// return plan_err!("add_one only accepts Float64 arguments"); +/// } +/// Ok(DataType::Float64) +/// } +/// // This is the accumulator factory; DataFusion uses it to create new accumulators. +/// fn accumulator(&self, _arg: &DataType) -> Result> { unimplemented!() } +/// fn state_type(&self, _return_type: &DataType) -> Result> { +/// Ok(vec![DataType::Float64, DataType::UInt32]) +/// } +/// } +/// +/// // Create a new AggregateUDF from the implementation +/// let geometric_mean = AggregateUDF::from(GeoMeanUdf::new()); +/// +/// // Call the function `geo_mean(col)` +/// let expr = geometric_mean.call(vec![col("a")]); +/// ``` +pub trait AggregateUDFImpl { + /// Returns this object as an [`Any`] trait object + fn as_any(&self) -> &dyn Any; + + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns the function's [`Signature`] for information about what input + /// types are accepted and the function's Volatility. + fn signature(&self) -> &Signature; + + /// What [`DataType`] will be returned by this function, given the types of + /// the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; + + /// This is the accumulator factory [`AccumulatorFactoryFunction`]; + /// DataFusion uses it to create new accumulators. + fn accumulator(&self, arg: &DataType) -> Result>; + + /// This is the description of the state. + /// accumulator's state() must match the types here. + fn state_type(&self, return_type: &DataType) -> Result>; +} diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4d54dad996703..c025ae7aa8faf 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -751,13 +751,13 @@ mod test { use datafusion_expr::{ cast, col, concat, concat_ws, create_udaf, is_true, AccumulatorFactoryFunction, AggregateFunction, AggregateUDF, BinaryExpr, BuiltinScalarFunction, Case, - ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, StateTypeFunction, - Subquery, + ColumnarValue, ExprSchemable, Filter, Operator, ScalarUDFImpl, + SimpleAggregateUDF, Subquery, }; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ReturnTypeFunction, ScalarUDF, Signature, Volatility, + Expr, LogicalPlan, ScalarUDF, Signature, Volatility, }; use datafusion_physical_expr::expressions::AvgAccumulator; @@ -902,19 +902,17 @@ mod test { #[test] fn agg_udaf_invalid_input() -> Result<()> { let empty = empty(); - let return_type: ReturnTypeFunction = - Arc::new(move |_| Ok(Arc::new(DataType::Float64))); - let state_type: StateTypeFunction = - Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64]))); + let return_type = DataType::Float64; + let state_type = vec![DataType::UInt64, DataType::Float64]; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::::default())); - let my_avg = AggregateUDF::new( - "MY_AVG", - &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), - &return_type, - &accumulator, - &state_type, - ); + let my_avg = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( + "MY_AVG".to_string(), + Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), + return_type, + accumulator, + state_type, + )); let udaf = Expr::AggregateFunction(expr::AggregateFunction::new_udf( Arc::new(my_avg), vec![lit("10")], diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1e089257c61ad..0b5a201c7dff9 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -780,8 +780,7 @@ mod test { avg, col, lit, logical_plan::builder::LogicalPlanBuilder, sum, }; use datafusion_expr::{ - grouping_set, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, - Signature, StateTypeFunction, Volatility, + grouping_set, AggregateUDF, AggregateUDFImpl, Signature, Volatility, }; use crate::optimizer::OptimizerContext; @@ -899,23 +898,56 @@ mod test { #[test] fn aggregate() -> Result<()> { + struct InnerAggregateUDF { + signature: Signature, + } + + impl InnerAggregateUDF { + fn new() -> Self { + Self { + signature: Signature::exact( + vec![DataType::UInt32], + Volatility::Stable, + ), + } + } + } + + impl AggregateUDFImpl for InnerAggregateUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "my_agg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + assert_eq!(arg_types, &[DataType::UInt32]); + Ok(DataType::UInt32) + } + + fn accumulator( + &self, + _arg: &DataType, + ) -> Result> { + unimplemented!() + } + + fn state_type(&self, _return_type: &DataType) -> Result> { + unimplemented!() + } + } + let table_scan = test_table_scan()?; - let return_type: ReturnTypeFunction = Arc::new(|inputs| { - assert_eq!(inputs, &[DataType::UInt32]); - Ok(Arc::new(DataType::UInt32)) - }); - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| unimplemented!()); - let state_type: StateTypeFunction = Arc::new(|_| unimplemented!()); let udf_agg = |inner: Expr| { Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( - Arc::new(AggregateUDF::new( - "my_agg", - &Signature::exact(vec![DataType::UInt32], Volatility::Stable), - &return_type, - &accumulator, - &state_type, - )), + Arc::new(AggregateUDF::from(InnerAggregateUDF::new())), vec![inner], false, None, diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 27ac5d122f83f..6502f34f67f42 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -73,8 +73,8 @@ use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{FileTypeWriterOptions, Result}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ReturnTypeFunction, Signature, - StateTypeFunction, WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, }; use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; @@ -374,18 +374,24 @@ fn roundtrip_aggregate_udaf() -> Result<()> { } } - let rt_func: ReturnTypeFunction = Arc::new(move |_| Ok(Arc::new(DataType::Int64))); + let return_type = DataType::Int64; let accumulator: AccumulatorFactoryFunction = Arc::new(|_| Ok(Box::new(Example))); - let st_func: StateTypeFunction = - Arc::new(move |_| Ok(Arc::new(vec![DataType::Int64]))); - - let udaf = AggregateUDF::new( - "example", - &Signature::exact(vec![DataType::Int64], Volatility::Immutable), - &rt_func, - &accumulator, - &st_func, - ); + let state_type = vec![DataType::Int64]; + + // let udaf = AggregateUDF::new( + // "example", + // &Signature::exact(vec![DataType::Int64], Volatility::Immutable), + // &rt_func, + // &accumulator, + // &st_func, + // ); + let udaf = AggregateUDF::from(SimpleAggregateUDF::new_with_signature( + "example".to_string(), + Signature::exact(vec![DataType::Int64], Volatility::Immutable), + return_type, + accumulator, + state_type, + )); let ctx = SessionContext::new(); ctx.register_udaf(udaf.clone()); diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index 1f687f978f30e..baa70f8fe5562 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -398,7 +398,8 @@ impl Accumulator for GeometricMean { ### registering an Aggregate UDF -To register a Aggreate UDF, you need to wrap the function implementation in a `AggregateUDF` struct and then register it with the `SessionContext`. DataFusion provides the `create_udaf` helper functions to make this easier. +To register a Aggreate UDF, you need to wrap the function implementation in a [`AggregateUDF`] struct and then register it with the `SessionContext`. DataFusion provides the [`create_udaf`] helper functions to make this easier. +There is a lower level API with more functionality but is more complex, that is documented in [`advanced_udaf.rs`]. ```rust use datafusion::logical_expr::{Volatility, create_udaf}; @@ -420,6 +421,9 @@ let geometric_mean = create_udaf( Arc::new(vec![DataType::Float64, DataType::UInt32]), ); ``` +[`aggregateudf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/struct.AggregateUDF.html +[`create_udaf`]: https://docs.rs/datafusion/latest/datafusion/logical_expr/fn.create_udaf.html +[`advanced_udaf.rs`]: https://github.com/apache/arrow-datafusion/blob/main/datafusion-examples/examples/advanced_udaf.rs The `create_udaf` has six arguments to check: