diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 4e35c0a690590..e7e58bf84362e 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -20,7 +20,6 @@ use arrow::datatypes::{ DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, }; -use std::ops::Deref; use datafusion_common::{internal_err, plan_err, Result}; @@ -142,22 +141,6 @@ pub fn check_arg_count( Ok(()) } -pub fn get_min_max_result_type(input_types: &[DataType]) -> Result> { - // make sure that the input types only has one element. - assert_eq!(input_types.len(), 1); - // min and max support the dictionary data type - // unpack the dictionary to get the value - match &input_types[0] { - DataType::Dictionary(_, dict_value_type) => { - // TODO add checker, if the value type is complex data type - Ok(vec![dict_value_type.deref().clone()]) - } - // TODO add checker for datatype which min and max supported - // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function - _ => Ok(input_types.to_vec()), - } -} - /// function return type of a sum pub fn sum_return_type(arg_type: &DataType) -> Result { match arg_type { @@ -327,14 +310,6 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result Result<()> { - let data_type = - DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); - let result = get_min_max_result_type(&[data_type])?; - assert_eq!(result, vec![DataType::Int32]); - Ok(()) - } #[test] fn test_variance_return_data_type() -> Result<()> { diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 0fac608b44de0..b54cd181a0cbf 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -112,9 +112,7 @@ pub mod expr_fn { pub use super::grouping::grouping; pub use super::median::median; pub use super::min_max::max; - pub use super::min_max::max_distinct; pub use super::min_max::min; - pub use super::min_max::min_distinct; pub use super::regr::regr_avgx; pub use super::regr::regr_avgy; pub use super::regr::regr_count; diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 6048458ba06ea..4d743983411dc 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -2,7 +2,6 @@ // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // @@ -34,13 +33,14 @@ // under the License. use arrow::array::{ - ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, Decimal128Array, - Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, IntervalYearMonthArray, - LargeBinaryArray, LargeStringArray, StringArray, Time32MillisecondArray, - Time32SecondArray, Time64MicrosecondArray, Time64NanosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + ArrayRef, BinaryArray, BinaryViewArray, BooleanArray, Date32Array, Date64Array, + Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array, Int32Array, + Int64Array, Int8Array, IntervalDayTimeArray, IntervalMonthDayNanoArray, + IntervalYearMonthArray, LargeBinaryArray, LargeStringArray, StringArray, + StringViewArray, Time32MillisecondArray, Time32SecondArray, Time64MicrosecondArray, + Time64NanosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }; use arrow::compute; use arrow::datatypes::{ @@ -60,84 +60,25 @@ use arrow::datatypes::{ }; use datafusion_common::ScalarValue; +use datafusion_expr::GroupsAccumulator; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, }; -use datafusion_expr::{type_coercion, Expr, GroupsAccumulator}; - -pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8]; - -pub static SIGNED_INTEGERS: &[DataType] = &[ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, -]; - -pub static UNSIGNED_INTEGERS: &[DataType] = &[ - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, -]; - -pub static INTEGERS: &[DataType] = &[ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, -]; - -pub static NUMERICS: &[DataType] = &[ - DataType::Int8, - DataType::Int16, - DataType::Int32, - DataType::Int64, - DataType::UInt8, - DataType::UInt16, - DataType::UInt32, - DataType::UInt64, - DataType::Float32, - DataType::Float64, -]; - -pub static TIMESTAMPS: &[DataType] = &[ - DataType::Timestamp(TimeUnit::Second, None), - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Timestamp(TimeUnit::Microsecond, None), - DataType::Timestamp(TimeUnit::Nanosecond, None), -]; - -pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64]; - -pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary]; - -pub static TIMES: &[DataType] = &[ - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Millisecond), - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Nanosecond), -]; - -pub static TIMES_INTERVALS: &[DataType] = &[ - DataType::Interval(IntervalUnit::DayTime), - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::MonthDayNano), -]; - -// Min/max aggregation can take Dictionary encode input but always produces unpacked -// (aka non Dictionary) output. We need to adjust the output data type to reflect this. -// The reason min/max aggregate produces unpacked output because there is only one -// min/max value per group; there is no needs to keep them Dictionary encode -fn min_max_aggregate_data_type(input_type: &DataType) -> &DataType { - if let DataType::Dictionary(_, value_type) = input_type { - value_type - } else { - input_type +use std::ops::Deref; + +fn get_min_max_result_type(input_types: &[DataType]) -> Result> { + // make sure that the input types only has one element. + assert_eq!(input_types.len(), 1); + // min and max support the dictionary data type + // unpack the dictionary to get the value + match &input_types[0] { + DataType::Dictionary(_, dict_value_type) => { + // TODO add checker, if the value type is complex data type + Ok(vec![dict_value_type.deref().clone()]) + } + // TODO add checker for datatype which min and max supported + // For example, the `Struct` and `Map` type are not supported in the MIN and MAX function + _ => Ok(input_types.to_vec()), } } @@ -213,19 +154,11 @@ impl AggregateUDFImpl for Max { } fn return_type(&self, arg_types: &[DataType]) -> Result { - type_coercion::aggregates::get_min_max_result_type(arg_types)? - .into_iter() - .next() - .ok_or_else(|| { - DataFusionError::Internal( - "Expected at one input type for MAX aggregate function".to_string(), - ) - }) + Ok(arg_types[0].to_owned()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let data_type = &min_max_aggregate_data_type(acc_args.data_type); - Ok(Box::new(MaxAccumulator::try_new(data_type)?)) + Ok(Box::new(MaxAccumulator::try_new(acc_args.data_type)?)) } fn aliases(&self) -> &[String] { @@ -234,9 +167,8 @@ impl AggregateUDFImpl for Max { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; - let data_type = min_max_aggregate_data_type(args.data_type); matches!( - data_type, + args.data_type, Int8 | Int16 | Int32 | Int64 @@ -262,7 +194,7 @@ impl AggregateUDFImpl for Max { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = min_max_aggregate_data_type(args.data_type); + let data_type = args.data_type; match data_type { Int8 => instantiate_max_accumulator!(data_type, i8, Int8Type), Int16 => instantiate_max_accumulator!(data_type, i16, Int16Type), @@ -323,8 +255,7 @@ impl AggregateUDFImpl for Max { &self, args: AccumulatorArgs, ) -> Result> { - let data_type = min_max_aggregate_data_type(args.data_type); - Ok(Box::new(SlidingMaxAccumulator::try_new(data_type)?)) + Ok(Box::new(SlidingMaxAccumulator::try_new(args.data_type)?)) } fn is_descending(&self) -> Option { @@ -335,7 +266,7 @@ impl AggregateUDFImpl for Max { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - type_coercion::aggregates::get_min_max_result_type(arg_types) + get_min_max_result_type(arg_types) } fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Identical @@ -508,6 +439,14 @@ fn min_batch(values: &ArrayRef) -> Result { DataType::LargeUtf8 => { typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, min_string) } + DataType::Utf8View => { + typed_min_max_batch_string!( + values, + StringViewArray, + Utf8View, + min_string_view + ) + } DataType::Boolean => { typed_min_max_batch!(values, BooleanArray, Boolean, min_boolean) } @@ -522,6 +461,14 @@ fn min_batch(values: &ArrayRef) -> Result { min_binary ) } + DataType::BinaryView => { + typed_min_max_batch_binary!( + &values, + BinaryViewArray, + BinaryView, + min_binary_view + ) + } _ => min_max_batch!(values, min), }) } @@ -535,12 +482,28 @@ fn max_batch(values: &ArrayRef) -> Result { DataType::LargeUtf8 => { typed_min_max_batch_string!(values, LargeStringArray, LargeUtf8, max_string) } + DataType::Utf8View => { + typed_min_max_batch_string!( + values, + StringViewArray, + Utf8View, + max_string_view + ) + } DataType::Boolean => { typed_min_max_batch!(values, BooleanArray, Boolean, max_boolean) } DataType::Binary => { typed_min_max_batch_binary!(&values, BinaryArray, Binary, max_binary) } + DataType::BinaryView => { + typed_min_max_batch_binary!( + &values, + BinaryViewArray, + BinaryView, + max_binary_view + ) + } DataType::LargeBinary => { typed_min_max_batch_binary!( &values, @@ -683,12 +646,18 @@ macro_rules! min_max { (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) } + (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { + typed_min_max_string!(lhs, rhs, Utf8View, $OP) + } (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { typed_min_max_string!(lhs, rhs, Binary, $OP) } (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { typed_min_max_string!(lhs, rhs, LargeBinary, $OP) } + (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { + typed_min_max_string!(lhs, rhs, BinaryView, $OP) + } (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) } @@ -835,10 +804,6 @@ impl MaxAccumulator { } impl Accumulator for MaxAccumulator { - fn state(&mut self) -> Result> { - Ok(vec![self.evaluate()?]) - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &max_batch(values)?; @@ -852,6 +817,9 @@ impl Accumulator for MaxAccumulator { self.update_batch(states) } + fn state(&mut self) -> Result> { + Ok(vec![self.evaluate()?]) + } fn evaluate(&mut self) -> Result { Ok(self.max.clone()) } @@ -955,19 +923,11 @@ impl AggregateUDFImpl for Min { } fn return_type(&self, arg_types: &[DataType]) -> Result { - type_coercion::aggregates::get_min_max_result_type(arg_types)? - .into_iter() - .next() - .ok_or_else(|| { - DataFusionError::Internal( - "Expected at one input type for MIN aggregate function".to_string(), - ) - }) + Ok(arg_types[0].to_owned()) } fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { - let data_type = min_max_aggregate_data_type(acc_args.data_type); - Ok(Box::new(MinAccumulator::try_new(data_type)?)) + Ok(Box::new(MinAccumulator::try_new(acc_args.data_type)?)) } fn aliases(&self) -> &[String] { @@ -976,9 +936,8 @@ impl AggregateUDFImpl for Min { fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { use DataType::*; - let data_type = min_max_aggregate_data_type(args.data_type); matches!( - data_type, + args.data_type, Int8 | Int16 | Int32 | Int64 @@ -1004,7 +963,7 @@ impl AggregateUDFImpl for Min { ) -> Result> { use DataType::*; use TimeUnit::*; - let data_type = min_max_aggregate_data_type(args.data_type); + let data_type = args.data_type; match data_type { Int8 => instantiate_min_accumulator!(data_type, i8, Int8Type), Int16 => instantiate_min_accumulator!(data_type, i16, Int16Type), @@ -1065,8 +1024,7 @@ impl AggregateUDFImpl for Min { &self, args: AccumulatorArgs, ) -> Result> { - let data_type = min_max_aggregate_data_type(args.data_type); - Ok(Box::new(SlidingMinAccumulator::try_new(data_type)?)) + Ok(Box::new(SlidingMinAccumulator::try_new(args.data_type)?)) } fn is_descending(&self) -> Option { @@ -1078,7 +1036,7 @@ impl AggregateUDFImpl for Min { } fn coerce_types(&self, arg_types: &[DataType]) -> Result> { - type_coercion::aggregates::get_min_max_result_type(arg_types) + get_min_max_result_type(arg_types) } fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { @@ -1455,28 +1413,6 @@ make_udaf_expr_and_func!( min_udaf ); -pub fn max_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( - max_udaf(), - vec![expr], - true, - None, - None, - None, - )) -} - -pub fn min_distinct(expr: Expr) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( - min_udaf(), - vec![expr], - true, - None, - None, - None, - )) -} - #[cfg(test)] mod tests { use super::*; @@ -1703,4 +1639,13 @@ mod tests { } } } + + #[test] + fn test_get_min_max_return_type_coerce_dictionary() -> Result<()> { + let data_type = + DataType::Dictionary(Box::new(DataType::Utf8), Box::new(DataType::Int32)); + let result = get_min_max_result_type(&[data_type])?; + assert_eq!(result, vec![DataType::Int32]); + Ok(()) + } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 51a34b4861b04..67815d346f531 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -393,10 +393,8 @@ mod tests { use crate::test::*; use arrow::datatypes::DataType; - // TODO: stubs or real functions - use datafusion_expr::test::function_stub::sum; + use datafusion_expr::test::function_stub::{max, min, sum}; use datafusion_expr::{col, lit, out_ref_col, scalar_subquery, Between}; - use datafusion_functions_aggregate::expr_fn::{max, min}; /// Test multiple correlated subqueries #[test] diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index 26b8014500fc8..86a4bc9c3d26d 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -291,11 +291,33 @@ mod tests { use datafusion_expr::ExprFunctionExt; use datafusion_expr::{lit, logical_plan::builder::LogicalPlanBuilder}; use datafusion_functions_aggregate::count::count_udaf; - use datafusion_functions_aggregate::expr_fn::{ - count, count_distinct, max, max_distinct, min, sum, - }; + use datafusion_functions_aggregate::expr_fn::{count, count_distinct, max, min, sum}; + use datafusion_functions_aggregate::min_max::max_udaf; + use datafusion_functions_aggregate::min_max::min_udaf; use datafusion_functions_aggregate::sum::sum_udaf; + fn max_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + max_udaf(), + vec![expr], + true, + None, + None, + None, + )) + } + + fn min_distinct(expr: Expr) -> Expr { + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + min_udaf(), + vec![expr], + true, + None, + None, + None, + )) + } + fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq_display_indent( Arc::new(SingleDistinctToGroupBy::new()), diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index a4334fce52cdc..24a9bf2ad2ab5 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -910,17 +910,10 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { )), input, )?); - let aggr_expr = create_aggregate_expr( - &max_udaf(), - &[udf_expr.clone()], - &[], - &[], - &[], - &schema, - "max", - false, - false, - )?; + let aggr_expr = AggregateExprBuilder::new(&max_udaf(), &[udf_expr.clone()]) + .schema(schema) + .name("max") + .build()?; let window = Arc::new(WindowAggExec::try_new( vec![Arc::new(PlainAggregateWindowExpr::new(