Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ApproxPercentileCont signature #8825

Merged
merged 2 commits into from
Jan 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,23 @@ impl AggregateFunction {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => {
let mut variants =
Vec::with_capacity(NUMERICS.len() * (INTEGERS.len() + 1));
// Accept any numeric value paired with a float64 percentile
let with_tdigest_size = NUMERICS.iter().map(|t| {
TypeSignature::Exact(vec![t.clone(), DataType::Float64, t.clone()])
});
Signature::one_of(
NUMERICS
.iter()
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
.chain(with_tdigest_size)
.collect(),
Volatility::Immutable,
)
for num in NUMERICS {
variants
.push(TypeSignature::Exact(vec![num.clone(), DataType::Float64]));
// Additionally accept an integer number of centroids for T-Digest
for int in INTEGERS {
variants.push(TypeSignature::Exact(vec![
num.clone(),
DataType::Float64,
int.clone(),
]))
}
}

Signature::one_of(variants, Volatility::Immutable)
}
AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of(
// Accept any numeric value paired with a float64 percentile
Expand Down
22 changes: 6 additions & 16 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ pub fn coerce_types(
| AggregateFunction::RegrSXX
| AggregateFunction::RegrSYY
| AggregateFunction::RegrSXY => {
let valid_types = [NUMERICS.to_vec(), vec![DataType::Null]].concat();
let valid_types = [NUMERICS.to_vec(), vec![Null]].concat();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You didn't do this, but it is unfortunate there are special coercion rules for AggregateFunctions that are not handled by the more general purpose Function cocercion rules.

let input_types_valid = // number of input already checked before
valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]);
if !input_types_valid {
Expand All @@ -243,15 +243,15 @@ pub fn coerce_types(
input_types[0]
);
}
if input_types.len() == 3 && !is_integer_arg_type(&input_types[2]) {
if input_types.len() == 3 && !input_types[2].is_integer() {
return plan_err!(
"The percentile sample points count for {:?} must be integer, not {:?}.",
agg_fun, input_types[2]
);
}
let mut result = input_types.to_vec();
if can_coerce_from(&DataType::Float64, &input_types[1]) {
result[1] = DataType::Float64;
if can_coerce_from(&Float64, &input_types[1]) {
result[1] = Float64;
} else {
return plan_err!(
"Could not coerce the percent argument for {:?} to Float64. Was {:?}.",
Expand All @@ -275,7 +275,7 @@ pub fn coerce_types(
input_types[1]
);
}
if !matches!(input_types[2], DataType::Float64) {
if !matches!(input_types[2], Float64) {
return plan_err!(
"The percentile argument for {:?} must be Float64, not {:?}.",
agg_fun,
Expand Down Expand Up @@ -560,17 +560,7 @@ pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
}

pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
)
arg_type.is_integer()
}

/// Return `true` if `arg_type` is of a [`DataType`] that the
Expand Down
26 changes: 25 additions & 1 deletion datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ mod test {
}

#[test]
fn agg_function_invalid_input() -> Result<()> {
fn agg_function_invalid_input_avg() -> Result<()> {
let empty = empty();
let fun: AggregateFunction = AggregateFunction::Avg;
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
Expand All @@ -984,6 +984,30 @@ mod test {
Ok(())
}

#[test]
fn agg_function_invalid_input_percentile() {
let empty = empty();
let fun: AggregateFunction = AggregateFunction::ApproxPercentileCont;
let agg_expr = Expr::AggregateFunction(expr::AggregateFunction::new(
fun,
vec![lit(0.95), lit(42.0), lit(100.0)],
false,
None,
None,
));

let err = Projection::try_new(vec![agg_expr], empty)
.err()
.unwrap()
.strip_backtrace();

let prefix = "Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT(Float64, Float64, Float64)'. You might need to add explicit type casts.\n\tCandidate functions:";
assert!(!err
.strip_prefix(prefix)
.unwrap()
.contains("APPROX_PERCENTILE_CONT(Float64, Float64, Float64)"));
}

#[test]
fn binary_op_date32_op_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
Expand Down
3 changes: 3 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100
statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Int16, Float64, Float64\)'\. You might need to add explicit type casts\.
SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100

statement error DataFusion error: Error during planning: No function matches the given name and argument types 'APPROX_PERCENTILE_CONT\(Float64, Float64, Float64\)'\. You might need to add explicit type casts\.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't this test pass on main too? Maybe the problem is that the sqllogictest framework doesn't display the hints 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's the issue. Perhaps I should write a unit test instead? 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you could that would be great (otherwise I worry we we may break this again during a refactor or something and not realize it)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a unit test 👍

SELECT approx_percentile_cont(c12, 0.95, 111.1) FROM aggregate_test_100

# array agg can use order by
query ?
SELECT array_agg(c13 ORDER BY c13)
Expand Down