diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c0a0134fedfb1..0f24101447364 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -2262,12 +2262,7 @@ mod tests { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/sql/udf.rs index 86ff6ebac2282..5aa3ab3bc0a81 100644 --- a/datafusion/core/tests/sql/udf.rs +++ b/datafusion/core/tests/sql/udf.rs @@ -237,12 +237,7 @@ async fn simple_udaf() -> Result<()> { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index e9d155d5d2f1f..747ce4ade86c2 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -893,12 +893,7 @@ mod test { vec![DataType::Float64], Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }), + Arc::new(|_| Ok(Box::::default())), Arc::new(vec![DataType::UInt64, DataType::Float64]), ); let udaf = Expr::AggregateUDF(expr::AggregateUDF::new( @@ -919,12 +914,8 @@ mod test { Arc::new(move |_| Ok(Arc::new(DataType::Float64))); let state_type: StateTypeFunction = Arc::new(move |_| Ok(Arc::new(vec![DataType::UInt64, DataType::Float64]))); - let accumulator: AccumulatorFactoryFunction = Arc::new(|_| { - Ok(Box::new(AvgAccumulator::try_new( - &DataType::Float64, - &DataType::Float64, - )?)) - }); + let accumulator: AccumulatorFactoryFunction = + Arc::new(|_| Ok(Box::::default())); let my_avg = AggregateUDF::new( "MY_AVG", &Signature::uniform(1, vec![DataType::Float64], Volatility::Immutable), diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index 932081c4eb0f4..efc7c3d4134fc 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -21,17 +21,13 @@ use arrow::array::{AsArray, PrimitiveBuilder}; use log::debug; use std::any::Any; -use std::convert::TryFrom; use std::sync::Arc; use crate::aggregate::groups_accumulator::accumulate::NullState; -use crate::aggregate::sum; -use crate::aggregate::sum::sum_batch; -use crate::aggregate::utils::calculate_result_decimal_for_avg; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; use crate::{AggregateExpr, GroupsAccumulator, PhysicalExpr}; -use arrow::compute; +use arrow::compute::{cast, sum}; use arrow::datatypes::{DataType, Decimal128Type, Float64Type, UInt64Type}; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -40,9 +36,7 @@ use arrow::{ use arrow_array::{ Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray, }; -use datafusion_common::{ - downcast_value, internal_err, not_impl_err, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::Accumulator; use super::groups_accumulator::EmitTo; @@ -106,11 +100,29 @@ impl AggregateExpr for Avg { } fn create_accumulator(&self) -> Result> { - Ok(Box::new(AvgAccumulator::try_new( - // avg is f64 or decimal - &self.sum_data_type, - &self.rt_data_type, - )?)) + use DataType::*; + // instantiate specialized accumulator based for the type + match (&self.sum_data_type, &self.rt_data_type) { + (Float64, Float64) => { + Ok(Box::new(AvgAccumulator::new(self.pre_cast_to_sum_type))) + } + ( + Decimal128(sum_precision, sum_scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(DecimalAvgAccumulator { + sum: None, + count: 0, + sum_scale: *sum_scale, + sum_precision: *sum_precision, + target_precision: *target_precision, + target_scale: *target_scale, + })), + _ => not_impl_err!( + "AvgGroupsAccumulator for ({} --> {})", + self.sum_data_type, + self.rt_data_type + ), + } } fn state_fields(&self) -> Result> { @@ -141,10 +153,7 @@ impl AggregateExpr for Avg { } fn create_sliding_accumulator(&self) -> Result> { - Ok(Box::new(AvgAccumulator::try_new( - &self.sum_data_type, - &self.rt_data_type, - )?)) + self.create_accumulator() } fn groups_accumulator_supported(&self) -> bool { @@ -208,97 +217,164 @@ impl PartialEq for Avg { } /// An accumulator to compute the average -#[derive(Debug)] +#[derive(Debug, Default)] pub struct AvgAccumulator { - // sum is used for null - sum: ScalarValue, - sum_data_type: DataType, - return_data_type: DataType, + sum: Option, count: u64, + cast_input: bool, } impl AvgAccumulator { - /// Creates a new `AvgAccumulator` - pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result { - Ok(Self { - sum: ScalarValue::try_from(datatype)?, - sum_data_type: datatype.clone(), - return_data_type: return_data_type.clone(), - count: 0, + /// Create a new [`AvgAccumulator`] + /// + /// If `cast_input` is `true` this will automatically cast input to `f64` + pub fn new(cast_input: bool) -> Self { + Self { + cast_input, + ..Default::default() + } + } + + fn cast_input(&self, input: &ArrayRef) -> Result { + Ok(match self.cast_input { + true => cast(input, &DataType::Float64)?, + false => input.clone(), }) } } impl Accumulator for AvgAccumulator { fn state(&self) -> Result> { - Ok(vec![ScalarValue::from(self.count), self.sum.clone()]) + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::Float64(self.sum), + ]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; + let values = self.cast_input(&values[0])?; + let values = values.as_primitive::(); self.count += (values.len() - values.null_count()) as u64; - self.sum = self - .sum - .add(&sum::sum_batch(values, &self.sum_data_type)?)?; + if let Some(x) = sum(values) { + let v = self.sum.get_or_insert(0.); + *v += x; + } Ok(()) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; + let values = self.cast_input(&values[0])?; + let values = values.as_primitive::(); self.count -= (values.len() - values.null_count()) as u64; - let delta = sum_batch(values, &self.sum.get_datatype())?; - self.sum = self.sum.sub(&delta)?; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap() - x); + } Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = downcast_value!(states[0], UInt64Array); // counts are summed - self.count += compute::sum(counts).unwrap_or(0); + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); // sums are summed - self.sum = self - .sum - .add(&sum::sum_batch(&states[1], &self.sum_data_type)?)?; + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert(0.); + *v = *v + x; + } Ok(()) } fn evaluate(&self) -> Result { - match self.sum { - ScalarValue::Float64(e) => { - Ok(ScalarValue::Float64(e.map(|f| f / self.count as f64))) - } - ScalarValue::Decimal128(value, _, scale) => { - match value { - None => match &self.return_data_type { - DataType::Decimal128(p, s) => { - Ok(ScalarValue::Decimal128(None, *p, *s)) - } - other => internal_err!( - "Error returned data type in AvgAccumulator {other:?}" - ), - }, - Some(value) => { - // now the sum_type and return type is not the same, need to convert the sum type to return type - calculate_result_decimal_for_avg( - value, - self.count as i128, - scale, - &self.return_data_type, - ) - } - } - } - _ => internal_err!("Sum should be f64 or decimal128 on average"), + Ok(ScalarValue::Float64( + self.sum.map(|f| f / self.count as f64), + )) + } + fn supports_retract_batch(&self) -> bool { + true + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + +/// An accumulator to compute the average for decimals +#[derive(Debug)] +struct DecimalAvgAccumulator { + sum: Option, + count: u64, + sum_scale: i8, + sum_precision: u8, + target_precision: u8, + target_scale: i8, +} + +impl Accumulator for DecimalAvgAccumulator { + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + + self.count += (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + let v = self.sum.get_or_insert(0); + *v += x; + } + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count -= (values.len() - values.null_count()) as u64; + if let Some(x) = sum(values) { + self.sum = Some(self.sum.unwrap() - x); } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // counts are summed + self.count += sum(states[0].as_primitive::()).unwrap_or_default(); + + // sums are summed + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert(0); + *v = *v + x; + } + Ok(()) + } + + fn evaluate(&self) -> Result { + let v = self + .sum + .map(|v| { + Decimal128Averager::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )? + .avg(v, self.count as _) + }) + .transpose()?; + + Ok(ScalarValue::Decimal128( + v, + self.target_precision, + self.target_scale, + )) } fn supports_retract_batch(&self) -> bool { true } fn size(&self) -> usize { - std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size() + std::mem::size_of_val(self) } } @@ -493,70 +569,77 @@ mod tests { use super::*; use crate::expressions::col; use crate::expressions::tests::aggregate; - use crate::generic_test_op; use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::*}; - use datafusion_common::Result; + use datafusion_expr::aggregate_function::sum_type_of_avg; + use datafusion_expr::type_coercion::aggregates::avg_return_type; + + fn test_with_pre_cast(array: ArrayRef, expected: ScalarValue) { + let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), true)]); + let dt = array.data_type().clone(); + let sum_type = sum_type_of_avg(&[dt.clone()]).unwrap(); + let rt = avg_return_type(&dt).unwrap(); + let cast = sum_type != dt; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![array]).unwrap(); + + let agg = Arc::new(Avg::new_with_pre_cast( + col("a", &schema).unwrap(), + "bla".to_string(), + sum_type, + rt, + cast, + )); + let actual = aggregate(&batch, agg).unwrap(); + assert_eq!(expected, actual); + } #[test] - fn avg_decimal() -> Result<()> { + fn avg_decimal() { // test agg let array: ArrayRef = Arc::new( (1..7) .map(Some) .collect::() - .with_precision_and_scale(10, 0)?, + .with_precision_and_scale(10, 0) + .unwrap(), ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Avg, - ScalarValue::Decimal128(Some(35000), 14, 4) - ) + test_with_pre_cast(array, ScalarValue::Decimal128(Some(35000), 14, 4)); } #[test] - fn avg_decimal_with_nulls() -> Result<()> { + fn avg_decimal_with_nulls() { let array: ArrayRef = Arc::new( (1..6) .map(|i| if i == 2 { None } else { Some(i) }) .collect::() - .with_precision_and_scale(10, 0)?, + .with_precision_and_scale(10, 0) + .unwrap(), ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Avg, - ScalarValue::Decimal128(Some(32500), 14, 4) - ) + test_with_pre_cast(array, ScalarValue::Decimal128(Some(32500), 14, 4)); } #[test] - fn avg_decimal_all_nulls() -> Result<()> { + fn avg_decimal_all_nulls() { // test agg let array: ArrayRef = Arc::new( std::iter::repeat::>(None) .take(6) .collect::() - .with_precision_and_scale(10, 0)?, + .with_precision_and_scale(10, 0) + .unwrap(), ); - generic_test_op!( - array, - DataType::Decimal128(10, 0), - Avg, - ScalarValue::Decimal128(None, 14, 4) - ) + test_with_pre_cast(array, ScalarValue::Decimal128(None, 14, 4)); } #[test] - fn avg_i32() -> Result<()> { + fn avg_i32() { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); - generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3_f64)) + test_with_pre_cast(a, ScalarValue::from(3_f64)); } #[test] - fn avg_i32_with_nulls() -> Result<()> { + fn avg_i32_with_nulls() { let a: ArrayRef = Arc::new(Int32Array::from(vec![ Some(1), None, @@ -564,33 +647,33 @@ mod tests { Some(4), Some(5), ])); - generic_test_op!(a, DataType::Int32, Avg, ScalarValue::from(3.25f64)) + test_with_pre_cast(a, ScalarValue::from(3.25_f64)); } #[test] - fn avg_i32_all_nulls() -> Result<()> { + fn avg_i32_all_nulls() { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - generic_test_op!(a, DataType::Int32, Avg, ScalarValue::Float64(None)) + test_with_pre_cast(a, ScalarValue::Float64(None)); } #[test] - fn avg_u32() -> Result<()> { + fn avg_u32() { let a: ArrayRef = Arc::new(UInt32Array::from(vec![1_u32, 2_u32, 3_u32, 4_u32, 5_u32])); - generic_test_op!(a, DataType::UInt32, Avg, ScalarValue::from(3.0f64)) + test_with_pre_cast(a, ScalarValue::from(3.0f64)); } #[test] - fn avg_f32() -> Result<()> { + fn avg_f32() { let a: ArrayRef = Arc::new(Float32Array::from(vec![1_f32, 2_f32, 3_f32, 4_f32, 5_f32])); - generic_test_op!(a, DataType::Float32, Avg, ScalarValue::from(3_f64)) + test_with_pre_cast(a, ScalarValue::from(3.0f64)); } #[test] - fn avg_f64() -> Result<()> { + fn avg_f64() { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64, 4_f64, 5_f64])); - generic_test_op!(a, DataType::Float64, Avg, ScalarValue::from(3_f64)) + test_with_pre_cast(a, ScalarValue::from(3.0f64)); } } diff --git a/datafusion/physical-expr/src/aggregate/utils.rs b/datafusion/physical-expr/src/aggregate/utils.rs index e86eb1dc1fc51..463d8fec189c1 100644 --- a/datafusion/physical-expr/src/aggregate/utils.rs +++ b/datafusion/physical-expr/src/aggregate/utils.rs @@ -23,8 +23,7 @@ use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PREC use arrow_array::cast::AsArray; use arrow_array::types::Decimal128Type; use arrow_schema::{DataType, Field}; -use datafusion_common::internal_err; -use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use std::any::Any; use std::sync::Arc; @@ -118,34 +117,6 @@ impl Decimal128Averager { } } -/// Returns `sum`/`count` for decimal values, detecting and reporting overflow. -/// -/// * sum: stored as Decimal128 with `sum_scale` scale -/// * count: stored as a i128 (*NOT* a Decimal128 value) -/// * sum_scale: the scale of `sum` -/// * target_type: the output decimal type -pub fn calculate_result_decimal_for_avg( - sum: i128, - count: i128, - sum_scale: i8, - target_type: &DataType, -) -> Result { - match target_type { - DataType::Decimal128(target_precision, target_scale) => { - let new_value = - Decimal128Averager::try_new(sum_scale, *target_precision, *target_scale)? - .avg(sum, count)?; - - Ok(ScalarValue::Decimal128( - Some(new_value), - *target_precision, - *target_scale, - )) - } - other => internal_err!("Invalid target type in AvgAccumulator {other:?}"), - } -} - /// Adjust array type metadata if needed /// /// Since `Decimal128Arrays` created from `Vec` have