From 51ff8a2a1163e66c62a1534a06c52db2a04e30a6 Mon Sep 17 00:00:00 2001 From: Emil Ejbyfeldt Date: Thu, 13 Jun 2024 14:11:34 +0200 Subject: [PATCH 1/3] Move Regr_* functions to use UDAF Closes #10883 and is part of #8708 --- datafusion/expr/src/aggregate_function.rs | 56 +------- .../expr/src/type_coercion/aggregates.rs | 21 --- datafusion/functions-aggregate/src/lib.rs | 19 +++ datafusion/functions-aggregate/src/macros.rs | 14 +- .../src}/regr.rs | 127 +++++++++++------- .../physical-expr/src/aggregate/build_in.rs | 78 ----------- datafusion/physical-expr/src/aggregate/mod.rs | 1 - .../physical-expr/src/expressions/mod.rs | 1 - datafusion/proto/proto/datafusion.proto | 18 +-- datafusion/proto/src/generated/pbjson.rs | 18 --- datafusion/proto/src/generated/prost.rs | 36 ++--- .../proto/src/logical_plan/from_proto.rs | 9 -- datafusion/proto/src/logical_plan/to_proto.rs | 24 ---- .../proto/src/physical_plan/to_proto.rs | 14 +- datafusion/sqllogictest/test_files/errors.slt | 4 +- 15 files changed, 134 insertions(+), 306 deletions(-) rename datafusion/{physical-expr/src/aggregate => functions-aggregate/src}/regr.rs (84%) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index e3d2e6555d5c..b476ca5e0c43 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -47,24 +47,6 @@ pub enum AggregateFunction { NthValue, /// Correlation Correlation, - /// Slope from linear regression - RegrSlope, - /// Intercept from linear regression - RegrIntercept, - /// Number of input rows in which both expressions are not null - RegrCount, - /// R-squared value from linear regression - RegrR2, - /// Average of the independent variable - RegrAvgx, - /// Average of the dependent variable - RegrAvgy, - /// Sum of squares of the independent variable - RegrSXX, - /// Sum of squares of the dependent variable - RegrSYY, - /// Sum of products of pairs of numbers - RegrSXY, /// Approximate continuous percentile function ApproxPercentileCont, /// Approximate continuous percentile function with weight @@ -96,15 +78,6 @@ impl AggregateFunction { ArrayAgg => "ARRAY_AGG", NthValue => "NTH_VALUE", Correlation => "CORR", - RegrSlope => "REGR_SLOPE", - RegrIntercept => "REGR_INTERCEPT", - RegrCount => "REGR_COUNT", - RegrR2 => "REGR_R2", - RegrAvgx => "REGR_AVGX", - RegrAvgy => "REGR_AVGY", - RegrSXX => "REGR_SXX", - RegrSYY => "REGR_SYY", - RegrSXY => "REGR_SXY", ApproxPercentileCont => "APPROX_PERCENTILE_CONT", ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Grouping => "GROUPING", @@ -144,15 +117,6 @@ impl FromStr for AggregateFunction { "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, - "regr_slope" => AggregateFunction::RegrSlope, - "regr_intercept" => AggregateFunction::RegrIntercept, - "regr_count" => AggregateFunction::RegrCount, - "regr_r2" => AggregateFunction::RegrR2, - "regr_avgx" => AggregateFunction::RegrAvgx, - "regr_avgy" => AggregateFunction::RegrAvgy, - "regr_sxx" => AggregateFunction::RegrSXX, - "regr_syy" => AggregateFunction::RegrSYY, - "regr_sxy" => AggregateFunction::RegrSXY, // approximate "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, "approx_percentile_cont_with_weight" => { @@ -205,15 +169,6 @@ impl AggregateFunction { AggregateFunction::Correlation => { correlation_return_type(&coerced_data_types[0]) } - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => Ok(DataType::Float64), AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new( "item", @@ -278,16 +233,7 @@ impl AggregateFunction { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), - AggregateFunction::Correlation - | AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { + AggregateFunction::Correlation => { Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::ApproxPercentileCont => { diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index ab7deaff9885..d845f50b6fac 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -159,27 +159,6 @@ pub fn coerce_types( } Ok(vec![Float64, Float64]) } - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY => { - let valid_types = [NUMERICS.to_vec(), vec![Null]].concat(); - let input_types_valid = // number of input already checked before - valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]); - if !input_types_valid { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(vec![Float64, Float64]) - } AggregateFunction::ApproxPercentileCont => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { return plan_err!( diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 56fc1305bb59..fabe15e416f4 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -61,6 +61,7 @@ pub mod covariance; pub mod first_last; pub mod hyperloglog; pub mod median; +pub mod regr; pub mod stddev; pub mod sum; pub mod variance; @@ -85,6 +86,15 @@ pub mod expr_fn { pub use super::first_last::first_value; pub use super::first_last::last_value; pub use super::median::median; + pub use super::regr::regr_avgx; + pub use super::regr::regr_avgy; + pub use super::regr::regr_count; + pub use super::regr::regr_intercept; + pub use super::regr::regr_r2; + pub use super::regr::regr_slope; + pub use super::regr::regr_sxx; + pub use super::regr::regr_sxy; + pub use super::regr::regr_syy; pub use super::stddev::stddev; pub use super::stddev::stddev_pop; pub use super::sum::sum; @@ -102,6 +112,15 @@ pub fn all_default_aggregate_functions() -> Vec> { covariance::covar_pop_udaf(), median::median_udaf(), count::count_udaf(), + regr::regr_slope_udaf(), + regr::regr_intercept_udaf(), + regr::regr_count_udaf(), + regr::regr_r2_udaf(), + regr::regr_avgx_udaf(), + regr::regr_avgy_udaf(), + regr::regr_sxx_udaf(), + regr::regr_syy_udaf(), + regr::regr_sxy_udaf(), variance::var_samp_udaf(), variance::var_pop_udaf(), stddev::stddev_udaf(), diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 75bb9dc54719..cae72cf35223 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -32,8 +32,8 @@ // specific language governing permissions and limitations // under the License. -macro_rules! make_udaf_expr_and_func { - ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { +macro_rules! make_udaf_expr { + ($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { // "fluent expr_fn" style function #[doc = $DOC] pub fn $EXPR_FN( @@ -48,7 +48,12 @@ macro_rules! make_udaf_expr_and_func { None, )) } + }; +} +macro_rules! make_udaf_expr_and_func { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN); create_func!($UDAF, $AGGREGATE_UDF_FN); }; ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { @@ -73,6 +78,9 @@ macro_rules! make_udaf_expr_and_func { macro_rules! create_func { ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { + create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default()); + }; + ($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => { paste::paste! { /// Singleton instance of [$UDAF], ensures the UDAF is only created once /// named STATIC_$(UDAF). For example `STATIC_FirstValue` @@ -86,7 +94,7 @@ macro_rules! create_func { pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc { [< STATIC_ $UDAF >] .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default())) + std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE)) }) .clone() } diff --git a/datafusion/physical-expr/src/aggregate/regr.rs b/datafusion/functions-aggregate/src/regr.rs similarity index 84% rename from datafusion/physical-expr/src/aggregate/regr.rs rename to datafusion/functions-aggregate/src/regr.rs index 36e7b7c9b3e4..8d04ae87157d 100644 --- a/datafusion/physical-expr/src/aggregate/regr.rs +++ b/datafusion/functions-aggregate/src/regr.rs @@ -18,9 +18,8 @@ //! Defines physical expressions that can evaluated at runtime during query execution use std::any::Any; -use std::sync::Arc; +use std::fmt::Debug; -use crate::{AggregateExpr, PhysicalExpr}; use arrow::array::Float64Array; use arrow::{ array::{ArrayRef, UInt64Array}, @@ -28,13 +27,56 @@ use arrow::{ datatypes::DataType, datatypes::Field, }; -use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue}; +use datafusion_common::{downcast_value, plan_err, unwrap_or_internal_err, ScalarValue}; use datafusion_common::{DataFusionError, Result}; -use datafusion_expr::Accumulator; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; + +macro_rules! make_regr_udaf_expr_and_func { + ($EXPR_FN:ident, $AGGREGATE_UDF_FN:ident, $REGR_TYPE:expr) => { + make_udaf_expr!($EXPR_FN, expr_y expr_x, concat!("Compute a linear regression of type [", stringify!($REGR_TYPE), "]"), $AGGREGATE_UDF_FN); + create_func!($EXPR_FN, $AGGREGATE_UDF_FN, Regr::new($REGR_TYPE, stringify!($EXPR_FN))); + } +} + +make_regr_udaf_expr_and_func!(regr_slope, regr_slope_udaf, RegrType::Slope); +make_regr_udaf_expr_and_func!(regr_intercept, regr_intercept_udaf, RegrType::Intercept); +make_regr_udaf_expr_and_func!(regr_count, regr_count_udaf, RegrType::Count); +make_regr_udaf_expr_and_func!(regr_r2, regr_r2_udaf, RegrType::R2); +make_regr_udaf_expr_and_func!(regr_avgx, regr_avgx_udaf, RegrType::AvgX); +make_regr_udaf_expr_and_func!(regr_avgy, regr_avgy_udaf, RegrType::AvgY); +make_regr_udaf_expr_and_func!(regr_sxx, regr_sxx_udaf, RegrType::SXX); +make_regr_udaf_expr_and_func!(regr_syy, regr_syy_udaf, RegrType::SYY); +make_regr_udaf_expr_and_func!(regr_sxy, regr_sxy_udaf, RegrType::SXY); + +pub struct Regr { + signature: Signature, + regr_type: RegrType, + func_name: &'static str, +} -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; +impl Debug for Regr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("regr") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} +impl Regr { + pub fn new(regr_type: RegrType, func_name: &'static str) -> Self { + Self { + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + regr_type, + func_name, + } + } +} + +/* #[derive(Debug)] pub struct Regr { name: String, @@ -48,6 +90,7 @@ impl Regr { self.regr_type.clone() } } +*/ #[derive(Debug, Clone)] #[allow(clippy::upper_case_acronyms)] @@ -92,86 +135,75 @@ pub enum RegrType { SXY, } -impl Regr { - pub fn new( - expr_y: Arc, - expr_x: Arc, - name: impl Into, - regr_type: RegrType, - return_type: DataType, - ) -> Self { - // the result of regr_slope only support FLOAT64 data type. - assert!(matches!(return_type, DataType::Float64)); - Self { - name: name.into(), - regr_type, - expr_y, - expr_x, - } - } -} - -impl AggregateExpr for Regr { +impl AggregateUDFImpl for Regr { fn as_any(&self) -> &dyn Any { self } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) + fn name(&self) -> &str { + self.func_name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Covariance requires numeric input types"); + } + + Ok(DataType::Float64) } - fn create_accumulator(&self) -> Result> { + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn create_sliding_accumulator(&self) -> Result> { + fn create_sliding_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { Ok(Box::new(RegrAccumulator::try_new(&self.regr_type)?)) } - fn state_fields(&self) -> Result> { + fn state_fields(&self, args: StateFieldsArgs) -> Result> { Ok(vec![ Field::new( - format_state_name(&self.name, "count"), + format_state_name(args.name, "count"), DataType::UInt64, true, ), Field::new( - format_state_name(&self.name, "mean_x"), + format_state_name(args.name, "mean_x"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "mean_y"), + format_state_name(args.name, "mean_y"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "m2_x"), + format_state_name(args.name, "m2_x"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "m2_y"), + format_state_name(args.name, "m2_y"), DataType::Float64, true, ), Field::new( - format_state_name(&self.name, "algo_const"), + format_state_name(args.name, "algo_const"), DataType::Float64, true, ), ]) } - - fn expressions(&self) -> Vec> { - vec![self.expr_y.clone(), self.expr_x.clone()] - } - - fn name(&self) -> &str { - &self.name - } } +/* impl PartialEq for Regr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) @@ -184,6 +216,7 @@ impl PartialEq for Regr { .unwrap_or(false) } } +*/ /// `RegrAccumulator` is used to compute linear regression aggregate functions /// by maintaining statistics needed to compute them in an online fashion. @@ -305,6 +338,10 @@ impl Accumulator for RegrAccumulator { Ok(()) } + fn supports_retract_batch(&self) -> bool { + true + } + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values_y = &cast(&values[0], &DataType::Float64)?; let values_x = &cast(&values[1], &DataType::Float64)?; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index aee7bca3b88f..23a2216c0463 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -34,7 +34,6 @@ use datafusion_common::{exec_err, internal_err, not_impl_err, Result}; use datafusion_expr::AggregateFunction; use crate::aggregate::average::Avg; -use crate::aggregate::regr::RegrType; use crate::expressions::{self, Literal}; use crate::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; /// Create a physical aggregation expression. @@ -158,83 +157,6 @@ pub fn create_aggregate_expr( (AggregateFunction::Correlation, true) => { return not_impl_err!("CORR(DISTINCT) aggregations are not available"); } - (AggregateFunction::RegrSlope, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Slope, - data_type, - )), - (AggregateFunction::RegrIntercept, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Intercept, - data_type, - )), - (AggregateFunction::RegrCount, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::Count, - data_type, - )), - (AggregateFunction::RegrR2, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::R2, - data_type, - )), - (AggregateFunction::RegrAvgx, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgX, - data_type, - )), - (AggregateFunction::RegrAvgy, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::AvgY, - data_type, - )), - (AggregateFunction::RegrSXX, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXX, - data_type, - )), - (AggregateFunction::RegrSYY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SYY, - data_type, - )), - (AggregateFunction::RegrSXY, false) => Arc::new(expressions::Regr::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - RegrType::SXY, - data_type, - )), - ( - AggregateFunction::RegrSlope - | AggregateFunction::RegrIntercept - | AggregateFunction::RegrCount - | AggregateFunction::RegrR2 - | AggregateFunction::RegrAvgx - | AggregateFunction::RegrAvgy - | AggregateFunction::RegrSXX - | AggregateFunction::RegrSYY - | AggregateFunction::RegrSXY, - true, - ) => { - return not_impl_err!("{}(DISTINCT) aggregations are not available", fun); - } (AggregateFunction::ApproxPercentileCont, false) => { if input_phy_exprs.len() == 2 { Arc::new(expressions::ApproxPercentileCont::new( diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 01105c8559c9..9079a81e6241 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -33,7 +33,6 @@ pub(crate) mod string_agg; #[macro_use] pub(crate) mod min_max; pub(crate) mod groups_accumulator; -pub(crate) mod regr; pub(crate) mod stats; pub(crate) mod stddev; pub(crate) mod variance; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 123ada6d7c86..beba25740501 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -50,7 +50,6 @@ pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::min_max::{Max, MaxAccumulator, Min, MinAccumulator}; pub use crate::aggregate::nth_value::NthValueAgg; -pub use crate::aggregate::regr::{Regr, RegrType}; pub use crate::aggregate::stats::StatsType; pub use crate::aggregate::string_agg::StringAgg; pub use crate::window::cume_dist::{cume_dist, CumeDist}; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2bb3ec793d7f..822eada7675f 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -496,15 +496,15 @@ enum AggregateFunction { BIT_XOR = 21; BOOL_AND = 22; BOOL_OR = 23; - REGR_SLOPE = 26; - REGR_INTERCEPT = 27; - REGR_COUNT = 28; - REGR_R2 = 29; - REGR_AVGX = 30; - REGR_AVGY = 31; - REGR_SXX = 32; - REGR_SYY = 33; - REGR_SXY = 34; + // REGR_SLOPE = 26; + // REGR_INTERCEPT = 27; + // REGR_COUNT = 28; + // REGR_R2 = 29; + // REGR_AVGX = 30; + // REGR_AVGY = 31; + // REGR_SXX = 32; + // REGR_SYY = 33; + // REGR_SXY = 34; STRING_AGG = 35; NTH_VALUE_AGG = 36; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 59b7861a6ef1..69014267bec6 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -546,15 +546,6 @@ impl serde::Serialize for AggregateFunction { Self::BitXor => "BIT_XOR", Self::BoolAnd => "BOOL_AND", Self::BoolOr => "BOOL_OR", - Self::RegrSlope => "REGR_SLOPE", - Self::RegrIntercept => "REGR_INTERCEPT", - Self::RegrCount => "REGR_COUNT", - Self::RegrR2 => "REGR_R2", - Self::RegrAvgx => "REGR_AVGX", - Self::RegrAvgy => "REGR_AVGY", - Self::RegrSxx => "REGR_SXX", - Self::RegrSyy => "REGR_SYY", - Self::RegrSxy => "REGR_SXY", Self::StringAgg => "STRING_AGG", Self::NthValueAgg => "NTH_VALUE_AGG", }; @@ -647,15 +638,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR" => Ok(AggregateFunction::BitXor), "BOOL_AND" => Ok(AggregateFunction::BoolAnd), "BOOL_OR" => Ok(AggregateFunction::BoolOr), - "REGR_SLOPE" => Ok(AggregateFunction::RegrSlope), - "REGR_INTERCEPT" => Ok(AggregateFunction::RegrIntercept), - "REGR_COUNT" => Ok(AggregateFunction::RegrCount), - "REGR_R2" => Ok(AggregateFunction::RegrR2), - "REGR_AVGX" => Ok(AggregateFunction::RegrAvgx), - "REGR_AVGY" => Ok(AggregateFunction::RegrAvgy), - "REGR_SXX" => Ok(AggregateFunction::RegrSxx), - "REGR_SYY" => Ok(AggregateFunction::RegrSyy), - "REGR_SXY" => Ok(AggregateFunction::RegrSxy), "STRING_AGG" => Ok(AggregateFunction::StringAgg), "NTH_VALUE_AGG" => Ok(AggregateFunction::NthValueAgg), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0861c287fcfa..048362cf2db4 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1950,15 +1950,15 @@ pub enum AggregateFunction { BitXor = 21, BoolAnd = 22, BoolOr = 23, - RegrSlope = 26, - RegrIntercept = 27, - RegrCount = 28, - RegrR2 = 29, - RegrAvgx = 30, - RegrAvgy = 31, - RegrSxx = 32, - RegrSyy = 33, - RegrSxy = 34, + // RegrSlope = 26, + // RegrIntercept = 27, + // RegrCount = 28, + // RegrR2 = 29, + // RegrAvgx = 30, + // RegrAvgy = 31, + // RegrSxx = 32, + // RegrSyy = 33, + // RegrSxy = 34, StringAgg = 35, NthValueAgg = 36, } @@ -1985,15 +1985,6 @@ impl AggregateFunction { AggregateFunction::BitXor => "BIT_XOR", AggregateFunction::BoolAnd => "BOOL_AND", AggregateFunction::BoolOr => "BOOL_OR", - AggregateFunction::RegrSlope => "REGR_SLOPE", - AggregateFunction::RegrIntercept => "REGR_INTERCEPT", - AggregateFunction::RegrCount => "REGR_COUNT", - AggregateFunction::RegrR2 => "REGR_R2", - AggregateFunction::RegrAvgx => "REGR_AVGX", - AggregateFunction::RegrAvgy => "REGR_AVGY", - AggregateFunction::RegrSxx => "REGR_SXX", - AggregateFunction::RegrSyy => "REGR_SYY", - AggregateFunction::RegrSxy => "REGR_SXY", AggregateFunction::StringAgg => "STRING_AGG", AggregateFunction::NthValueAgg => "NTH_VALUE_AGG", } @@ -2017,15 +2008,6 @@ impl AggregateFunction { "BIT_XOR" => Some(Self::BitXor), "BOOL_AND" => Some(Self::BoolAnd), "BOOL_OR" => Some(Self::BoolOr), - "REGR_SLOPE" => Some(Self::RegrSlope), - "REGR_INTERCEPT" => Some(Self::RegrIntercept), - "REGR_COUNT" => Some(Self::RegrCount), - "REGR_R2" => Some(Self::RegrR2), - "REGR_AVGX" => Some(Self::RegrAvgx), - "REGR_AVGY" => Some(Self::RegrAvgy), - "REGR_SXX" => Some(Self::RegrSxx), - "REGR_SYY" => Some(Self::RegrSyy), - "REGR_SXY" => Some(Self::RegrSxy), "STRING_AGG" => Some(Self::StringAgg), "NTH_VALUE_AGG" => Some(Self::NthValueAgg), _ => None, diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 2ad40d883fe6..4bceab44cc70 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -148,15 +148,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::Count => Self::Count, protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Correlation => Self::Correlation, - protobuf::AggregateFunction::RegrSlope => Self::RegrSlope, - protobuf::AggregateFunction::RegrIntercept => Self::RegrIntercept, - protobuf::AggregateFunction::RegrCount => Self::RegrCount, - protobuf::AggregateFunction::RegrR2 => Self::RegrR2, - protobuf::AggregateFunction::RegrAvgx => Self::RegrAvgx, - protobuf::AggregateFunction::RegrAvgy => Self::RegrAvgy, - protobuf::AggregateFunction::RegrSxx => Self::RegrSXX, - protobuf::AggregateFunction::RegrSyy => Self::RegrSYY, - protobuf::AggregateFunction::RegrSxy => Self::RegrSXY, protobuf::AggregateFunction::ApproxPercentileCont => { Self::ApproxPercentileCont } diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6a275ed7a1b8..e189bde3187d 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -119,15 +119,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::Count => Self::Count, AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Correlation => Self::Correlation, - AggregateFunction::RegrSlope => Self::RegrSlope, - AggregateFunction::RegrIntercept => Self::RegrIntercept, - AggregateFunction::RegrCount => Self::RegrCount, - AggregateFunction::RegrR2 => Self::RegrR2, - AggregateFunction::RegrAvgx => Self::RegrAvgx, - AggregateFunction::RegrAvgy => Self::RegrAvgy, - AggregateFunction::RegrSXX => Self::RegrSxx, - AggregateFunction::RegrSYY => Self::RegrSyy, - AggregateFunction::RegrSXY => Self::RegrSxy, AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont, AggregateFunction::ApproxPercentileContWithWeight => { Self::ApproxPercentileContWithWeight @@ -410,21 +401,6 @@ pub fn serialize_expr( AggregateFunction::Correlation => { protobuf::AggregateFunction::Correlation } - AggregateFunction::RegrSlope => { - protobuf::AggregateFunction::RegrSlope - } - AggregateFunction::RegrIntercept => { - protobuf::AggregateFunction::RegrIntercept - } - AggregateFunction::RegrR2 => protobuf::AggregateFunction::RegrR2, - AggregateFunction::RegrAvgx => protobuf::AggregateFunction::RegrAvgx, - AggregateFunction::RegrAvgy => protobuf::AggregateFunction::RegrAvgy, - AggregateFunction::RegrCount => { - protobuf::AggregateFunction::RegrCount - } - AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, - AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, - AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e25447b023d8..68b7a1bf1283 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -27,7 +27,7 @@ use datafusion::physical_plan::expressions::{ BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, Regr, RegrType, RowNumber, StringAgg, + OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; @@ -270,18 +270,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Avg } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::Correlation - } else if let Some(regr_expr) = aggr_expr.downcast_ref::() { - match regr_expr.get_regr_type() { - RegrType::Slope => protobuf::AggregateFunction::RegrSlope, - RegrType::Intercept => protobuf::AggregateFunction::RegrIntercept, - RegrType::Count => protobuf::AggregateFunction::RegrCount, - RegrType::R2 => protobuf::AggregateFunction::RegrR2, - RegrType::AvgX => protobuf::AggregateFunction::RegrAvgx, - RegrType::AvgY => protobuf::AggregateFunction::RegrAvgy, - RegrType::SXX => protobuf::AggregateFunction::RegrSxx, - RegrType::SYY => protobuf::AggregateFunction::RegrSyy, - RegrType::SXY => protobuf::AggregateFunction::RegrSxy, - } } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::ApproxPercentileCont } else if aggr_expr diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index c7b9808c249d..e44f8243f1a7 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -112,11 +112,11 @@ statement error DataFusion error: Error during planning: No function matches the select avg(c1, c12) from aggregate_test_100; # AggregateFunction with wrong argument type -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Int64, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int64, Utf8\] to the signature Uniform\(2, \[Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64\]\) failed\. and No function matches the given name and argument types 'regr_slope\(Int64, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tregr_slope\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) select regr_slope(1, '2'); # WindowFunction using AggregateFunction wrong signature -statement error DataFusion error: Error during planning: No function matches the given name and argument types 'REGR_SLOPE\(Float32, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tREGR_SLOPE\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Float32, Utf8\] to the signature Uniform\(2, \[Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64\]\) failed\. and No function matches the given name and argument types 'regr_slope\(Float32, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tregr_slope\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) select c9, regr_slope(c11, '2') over () as min1 From 549c383e8eb41c9915af68025100549fc121a7aa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 13 Jun 2024 09:38:12 -0400 Subject: [PATCH 2/3] Format and regen --- datafusion/proto/src/generated/pbjson.rs | 9 --------- datafusion/proto/src/generated/prost.rs | 18 +++++++++--------- datafusion/proto/src/physical_plan/to_proto.rs | 4 ++-- 3 files changed, 11 insertions(+), 20 deletions(-) diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 69014267bec6..351f6aaa2bf1 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -573,15 +573,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "BIT_XOR", "BOOL_AND", "BOOL_OR", - "REGR_SLOPE", - "REGR_INTERCEPT", - "REGR_COUNT", - "REGR_R2", - "REGR_AVGX", - "REGR_AVGY", - "REGR_SXX", - "REGR_SYY", - "REGR_SXY", "STRING_AGG", "NTH_VALUE_AGG", ]; diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 048362cf2db4..9ec8251f1d4c 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1950,15 +1950,15 @@ pub enum AggregateFunction { BitXor = 21, BoolAnd = 22, BoolOr = 23, - // RegrSlope = 26, - // RegrIntercept = 27, - // RegrCount = 28, - // RegrR2 = 29, - // RegrAvgx = 30, - // RegrAvgy = 31, - // RegrSxx = 32, - // RegrSyy = 33, - // RegrSxy = 34, + /// REGR_SLOPE = 26; + /// REGR_INTERCEPT = 27; + /// REGR_COUNT = 28; + /// REGR_R2 = 29; + /// REGR_AVGX = 30; + /// REGR_AVGY = 31; + /// REGR_SXX = 32; + /// REGR_SYY = 33; + /// REGR_SXY = 34; StringAgg = 35, NthValueAgg = 36, } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 68b7a1bf1283..ef462ac94b9a 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -27,8 +27,8 @@ use datafusion::physical_plan::expressions::{ BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, CumeDist, DistinctArrayAgg, DistinctBitXor, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, - OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, - TryCastExpr, WindowShift, + OrderSensitiveArrayAgg, Rank, RankType, RowNumber, StringAgg, TryCastExpr, + WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; From c984b05ccd214435513ef4ba756a4f08b5f1b41e Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 13 Jun 2024 10:56:40 -0400 Subject: [PATCH 3/3] tweak error check --- datafusion/sqllogictest/test_files/errors.slt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/sqllogictest/test_files/errors.slt b/datafusion/sqllogictest/test_files/errors.slt index e44f8243f1a7..d51c69496d46 100644 --- a/datafusion/sqllogictest/test_files/errors.slt +++ b/datafusion/sqllogictest/test_files/errors.slt @@ -112,11 +112,11 @@ statement error DataFusion error: Error during planning: No function matches the select avg(c1, c12) from aggregate_test_100; # AggregateFunction with wrong argument type -statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Int64, Utf8\] to the signature Uniform\(2, \[Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64\]\) failed\. and No function matches the given name and argument types 'regr_slope\(Int64, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tregr_slope\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +statement error Coercion select regr_slope(1, '2'); # WindowFunction using AggregateFunction wrong signature -statement error DataFusion error: Error during planning: Error during planning: Coercion from \[Float32, Utf8\] to the signature Uniform\(2, \[Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64\]\) failed\. and No function matches the given name and argument types 'regr_slope\(Float32, Utf8\)'\. You might need to add explicit type casts\.\n\tCandidate functions:\n\tregr_slope\(Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64, Int8/Int16/Int32/Int64/UInt8/UInt16/UInt32/UInt64/Float32/Float64\) +statement error Coercion select c9, regr_slope(c11, '2') over () as min1