diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index cea72c3cb5e6..9db7635d99a0 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -386,18 +386,23 @@ impl AggregateFunction { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::ApproxPercentileCont => { + let mut variants = + Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1)); // Accept any numeric value paired with a float64 percentile - let with_tdigest_size = NUMERICS.iter().map(|t| { - TypeSignature::Exact(vec![t.clone(), DataType::Float64, t.clone()]) - }); - Signature::one_of( - NUMERICS - .iter() - .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64])) - .chain(with_tdigest_size) - .collect(), - Volatility::Immutable, - ) + for num in NUMERICS { + variants + .push(TypeSignature::Exact(vec![num.clone(), DataType::Float64])); + // Additionally accept an integer number of centroids for T-Digest + for int in INTEGERS { + variants.push(TypeSignature::Exact(vec![ + num.clone(), + DataType::Float64, + int.clone(), + ])) + } + } + + Signature::one_of(variants, Volatility::Immutable) } AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of( // Accept any numeric value paired with a float64 percentile diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 7128b575978a..56bb5c9b69c4 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -223,7 +223,7 @@ pub fn coerce_types( | AggregateFunction::RegrSXX | AggregateFunction::RegrSYY | AggregateFunction::RegrSXY => { - let valid_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat(); + let valid_types = [NUMERICS.to_vec(), vec![Null]].concat(); let input_types_valid = // number of input already checked before valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]); if !input_types_valid { @@ -243,15 +243,15 @@ pub fn coerce_types( input_types[0] ); } - if input_types.len() == 3 && !is_integer_arg_type(&input_types[2]) { + if input_types.len() == 3 && !input_types[2].is_integer() { return plan_err!( "The percentile sample points count for {:?} must be integer, not {:?}.", agg_fun, input_types[2] ); } let mut result = input_types.to_vec(); - if can_coerce_from(&DataType::Float64, &input_types[1]) { - result[1] = DataType::Float64; + if can_coerce_from(&Float64, &input_types[1]) { + result[1] = Float64; } else { return plan_err!( "Could not coerce the percent argument for {:?} to Float64. Was {:?}.", @@ -275,7 +275,7 @@ pub fn coerce_types( input_types[1] ); } - if !matches!(input_types[2], DataType::Float64) { + if !matches!(input_types[2], Float64) { return plan_err!( "The percentile argument for {:?} must be Float64, not {:?}.", agg_fun, @@ -560,17 +560,7 @@ pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { } pub fn is_integer_arg_type(arg_type: &DataType) -> bool { - matches!( - arg_type, - DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - ) + arg_type.is_integer() } /// Return `true` if `arg_type` is of a [`DataType`] that the diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 3821279fed0f..8c4e907e6734 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -963,7 +963,7 @@ mod test { } #[test] - fn agg_function_invalid_input() -> Result<()> { + fn agg_function_invalid_input_avg() -> Result<()> { let empty = empty(); let fun: AggregateFunction = AggregateFunction::Avg; let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( @@ -984,6 +984,30 @@ mod test { Ok(()) } + #[test] + fn agg_function_invalid_input_percentile() { + let empty = empty(); + let fun: AggregateFunction = AggregateFunction::ApproxPercentileCont; + let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new( + fun, + vec![lit(0.95), lit(42.0), lit(100.0)], + false, + None, + None, + )); + + let err = Projection::try_new(vec![agg_expr], empty) + .err() + .unwrap() + .strip_backtrace(); + + let prefix = "Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'. You might need to add explicit type casts.\n\tCandidate functions:"; + assert!(!err + .strip_prefix(prefix) + .unwrap() + .contains("APPROX_PERCENTILE_CONT(Float64, Float64, Float64)")); + } + #[test] fn binary_op_date32_op_interval() -> Result<()> { //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640") diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index aa512f6e2600..50cdebd054a7 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -95,6 +95,9 @@ SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100 statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Float64\)'\. You might need to add explicit type casts\. SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100 +statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64, Float64\)'\. You might need to add explicit type casts\. +SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100 + # array agg can use order by query ? SELECT array_agg(c13 ORDER BY c13)