From a0fccbf886346fde5dfbda136149ec98bbd6e952 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 6 May 2024 17:46:06 +0800 Subject: [PATCH] Move `Covariance` (Sample) `covar` / `covar_samp` to be a User Defined Aggregate Function (#10372) * introduce CovarianceSample Signed-off-by: jayzhan211 * rewrite macro Signed-off-by: jayzhan211 * rm old statstype Signed-off-by: jayzhan211 * register Signed-off-by: jayzhan211 * state field Signed-off-by: jayzhan211 * rm builtin Signed-off-by: jayzhan211 * addres comments Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/core/src/physical_planner.rs | 1 + datafusion/expr/src/aggregate_function.rs | 11 +- .../expr/src/type_coercion/aggregates.rs | 2 +- .../functions-aggregate/src/covariance.rs | 318 ++++++++++++++++++ .../functions-aggregate/src/first_last.rs | 4 +- datafusion/functions-aggregate/src/lib.rs | 7 +- datafusion/functions-aggregate/src/macros.rs | 68 ++-- .../physical-expr-common/src/aggregate/mod.rs | 1 + .../src/aggregate/stats.rs | 26 ++ .../physical-expr/src/aggregate/build_in.rs | 154 +-------- .../physical-expr/src/aggregate/covariance.rs | 174 ---------- .../physical-expr/src/aggregate/stats.rs | 9 +- .../physical-expr/src/expressions/mod.rs | 2 +- datafusion/proto/proto/datafusion.proto | 2 +- datafusion/proto/src/generated/pbjson.rs | 3 - datafusion/proto/src/generated/prost.rs | 4 +- .../proto/src/logical_plan/from_proto.rs | 1 - datafusion/proto/src/logical_plan/to_proto.rs | 4 - .../proto/src/physical_plan/to_proto.rs | 14 +- .../tests/cases/roundtrip_logical_plan.rs | 2 + .../sqllogictest/test_files/functions.slt | 2 +- 21 files changed, 418 insertions(+), 391 deletions(-) create mode 100644 datafusion/functions-aggregate/src/covariance.rs create mode 100644 datafusion/physical-expr-common/src/aggregate/stats.rs diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 391ded84eab9..dfcda553af7d 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -1901,6 +1901,7 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( let ignore_nulls = null_treatment .unwrap_or(sqlparser::ast::NullTreatment::RespectNulls) == NullTreatment::IgnoreNulls; + let (agg_expr, filter, order_by) = match func_def { AggregateFunctionDefinition::BuiltIn(fun) => { let physical_sort_exprs = match order_by { diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 3dc9c3a01c15..af8a682eff58 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -63,8 +63,6 @@ pub enum AggregateFunction { Stddev, /// Standard Deviation (Population) StddevPop, - /// Covariance (Sample) - Covariance, /// Covariance (Population) CovariancePop, /// Correlation @@ -128,7 +126,6 @@ impl AggregateFunction { VariancePop => "VAR_POP", Stddev => "STDDEV", StddevPop => "STDDEV_POP", - Covariance => "COVAR", CovariancePop => "COVAR_POP", Correlation => "CORR", RegrSlope => "REGR_SLOPE", @@ -184,9 +181,7 @@ impl FromStr for AggregateFunction { "string_agg" => AggregateFunction::StringAgg, // statistical "corr" => AggregateFunction::Correlation, - "covar" => AggregateFunction::Covariance, "covar_pop" => AggregateFunction::CovariancePop, - "covar_samp" => AggregateFunction::Covariance, "stddev" => AggregateFunction::Stddev, "stddev_pop" => AggregateFunction::StddevPop, "stddev_samp" => AggregateFunction::Stddev, @@ -260,9 +255,6 @@ impl AggregateFunction { AggregateFunction::VariancePop => { variance_return_type(&coerced_data_types[0]) } - AggregateFunction::Covariance => { - covariance_return_type(&coerced_data_types[0]) - } AggregateFunction::CovariancePop => { covariance_return_type(&coerced_data_types[0]) } @@ -357,8 +349,7 @@ impl AggregateFunction { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), - AggregateFunction::Covariance - | AggregateFunction::CovariancePop + AggregateFunction::CovariancePop | AggregateFunction::Correlation | AggregateFunction::RegrSlope | AggregateFunction::RegrIntercept diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 5ffdc8f94753..39726d7d0e62 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -183,7 +183,7 @@ pub fn coerce_types( } Ok(vec![Float64, Float64]) } - AggregateFunction::Covariance | AggregateFunction::CovariancePop => { + AggregateFunction::CovariancePop => { if !is_covariance_support_arg_type(&input_types[0]) { return plan_err!( "The function {:?} does not support inputs of type {:?}.", diff --git a/datafusion/functions-aggregate/src/covariance.rs b/datafusion/functions-aggregate/src/covariance.rs new file mode 100644 index 000000000000..130b193996b6 --- /dev/null +++ b/datafusion/functions-aggregate/src/covariance.rs @@ -0,0 +1,318 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`CovarianceSample`]: covariance sample aggregations. + +use std::fmt::Debug; + +use arrow::{ + array::{ArrayRef, Float64Array, UInt64Array}, + compute::kernels::cast, + datatypes::{DataType, Field}, +}; + +use datafusion_common::{ + downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result, + ScalarValue, +}; +use datafusion_expr::{ + function::AccumulatorArgs, type_coercion::aggregates::NUMERICS, + utils::format_state_name, Accumulator, AggregateUDFImpl, Signature, Volatility, +}; +use datafusion_physical_expr_common::aggregate::stats::StatsType; + +make_udaf_expr_and_func!( + CovarianceSample, + covar_samp, + y x, + "Computes the sample covariance.", + covar_samp_udaf +); + +pub struct CovarianceSample { + signature: Signature, + aliases: Vec, +} + +impl Debug for CovarianceSample { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("CovarianceSample") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for CovarianceSample { + fn default() -> Self { + Self::new() + } +} + +impl CovarianceSample { + pub fn new() -> Self { + Self { + aliases: vec![String::from("covar")], + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for CovarianceSample { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "covar_samp" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("Covariance requires numeric input types"); + } + + Ok(DataType::Float64) + } + + fn state_fields( + &self, + name: &str, + _value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + Ok(vec![ + Field::new(format_state_name(name, "count"), DataType::UInt64, true), + Field::new(format_state_name(name, "mean1"), DataType::Float64, true), + Field::new(format_state_name(name, "mean2"), DataType::Float64, true), + Field::new( + format_state_name(name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?)) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } +} + +/// An accumulator to compute covariance +/// The algorithm used is an online implementation and numerically stable. It is derived from the following paper +/// for calculating variance: +/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products". +/// Technometrics. 4 (3): 419–420. doi:10.2307/1266577. JSTOR 1266577. +/// +/// The algorithm has been analyzed here: +/// Ling, Robert F. (1974). "Comparison of Several Algorithms for Computing Sample Means and Variances". +/// Journal of the American Statistical Association. 69 (348): 859–866. doi:10.2307/2286154. JSTOR 2286154. +/// +/// Though it is not covered in the original paper but is based on the same idea, as a result the algorithm is online, +/// parallelizable and numerically stable. + +#[derive(Debug)] +pub struct CovarianceAccumulator { + algo_const: f64, + mean1: f64, + mean2: f64, + count: u64, + stats_type: StatsType, +} + +impl CovarianceAccumulator { + /// Creates a new `CovarianceAccumulator` + pub fn try_new(s_type: StatsType) -> Result { + Ok(Self { + algo_const: 0_f64, + mean1: 0_f64, + mean2: 0_f64, + count: 0_u64, + stats_type: s_type, + }) + } + + pub fn get_count(&self) -> u64 { + self.count + } + + pub fn get_mean1(&self) -> f64 { + self.mean1 + } + + pub fn get_mean2(&self) -> f64 { + self.mean2 + } + + pub fn get_algo_const(&self) -> f64 { + self.algo_const + } +} + +impl Accumulator for CovarianceAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean1), + ScalarValue::from(self.mean2), + ScalarValue::from(self.algo_const), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values1 = &cast(&values[0], &DataType::Float64)?; + let values2 = &cast(&values[1], &DataType::Float64)?; + + let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); + let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); + + for i in 0..values1.len() { + let value1 = if values1.is_valid(i) { + arr1.next() + } else { + None + }; + let value2 = if values2.is_valid(i) { + arr2.next() + } else { + None + }; + + if value1.is_none() || value2.is_none() { + continue; + } + + let value1 = unwrap_or_internal_err!(value1); + let value2 = unwrap_or_internal_err!(value2); + let new_count = self.count + 1; + let delta1 = value1 - self.mean1; + let new_mean1 = delta1 / new_count as f64 + self.mean1; + let delta2 = value2 - self.mean2; + let new_mean2 = delta2 / new_count as f64 + self.mean2; + let new_c = delta1 * (value2 - new_mean2) + self.algo_const; + + self.count += 1; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values1 = &cast(&values[0], &DataType::Float64)?; + let values2 = &cast(&values[1], &DataType::Float64)?; + let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); + let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); + + for i in 0..values1.len() { + let value1 = if values1.is_valid(i) { + arr1.next() + } else { + None + }; + let value2 = if values2.is_valid(i) { + arr2.next() + } else { + None + }; + + if value1.is_none() || value2.is_none() { + continue; + } + + let value1 = unwrap_or_internal_err!(value1); + let value2 = unwrap_or_internal_err!(value2); + + let new_count = self.count - 1; + let delta1 = self.mean1 - value1; + let new_mean1 = delta1 / new_count as f64 + self.mean1; + let delta2 = self.mean2 - value2; + let new_mean2 = delta2 / new_count as f64 + self.mean2; + let new_c = self.algo_const - delta1 * (new_mean2 - value2); + + self.count -= 1; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], UInt64Array); + let means1 = downcast_value!(states[1], Float64Array); + let means2 = downcast_value!(states[2], Float64Array); + let cs = downcast_value!(states[3], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_u64 { + continue; + } + let new_count = self.count + c; + let new_mean1 = self.mean1 * self.count as f64 / new_count as f64 + + means1.value(i) * c as f64 / new_count as f64; + let new_mean2 = self.mean2 * self.count as f64 / new_count as f64 + + means2.value(i) * c as f64 / new_count as f64; + let delta1 = self.mean1 - means1.value(i); + let delta2 = self.mean2 - means2.value(i); + let new_c = self.algo_const + + cs.value(i) + + delta1 * delta2 * self.count as f64 * c as f64 / new_count as f64; + + self.count = new_count; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0 { + self.count - 1 + } else { + self.count + } + } + }; + + if count == 0 { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(self.algo_const / count as f64))) + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} diff --git a/datafusion/functions-aggregate/src/first_last.rs b/datafusion/functions-aggregate/src/first_last.rs index 8dc4cee87a3b..e3b685e90376 100644 --- a/datafusion/functions-aggregate/src/first_last.rs +++ b/datafusion/functions-aggregate/src/first_last.rs @@ -39,12 +39,12 @@ use datafusion_physical_expr_common::expressions; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; use datafusion_physical_expr_common::utils::reverse_order_bys; -use sqlparser::ast::NullTreatment; + use std::any::Any; use std::fmt::Debug; use std::sync::Arc; -make_udaf_function!( +make_udaf_expr_and_func!( FirstValue, first_value, "Returns the first value in a group of values.", diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index 8016b76889f7..d4e4d3a5f328 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -55,6 +55,7 @@ #[macro_use] pub mod macros; +pub mod covariance; pub mod first_last; use datafusion_common::Result; @@ -65,12 +66,16 @@ use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::covariance::covar_samp; pub use super::first_last::first_value; } /// Registers all enabled packages with a [`FunctionRegistry`] pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { - let functions: Vec> = vec![first_last::first_value_udaf()]; + let functions: Vec> = vec![ + first_last::first_value_udaf(), + covariance::covar_samp_udaf(), + ]; functions.into_iter().try_for_each(|udf| { let existing_udaf = registry.register_udaf(udf)?; diff --git a/datafusion/functions-aggregate/src/macros.rs b/datafusion/functions-aggregate/src/macros.rs index 04f9fecb8b19..27fc623a182b 100644 --- a/datafusion/functions-aggregate/src/macros.rs +++ b/datafusion/functions-aggregate/src/macros.rs @@ -15,33 +15,59 @@ // specific language governing permissions and limitations // under the License. -macro_rules! make_udaf_function { +macro_rules! make_udaf_expr_and_func { + ($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN( + $($arg: datafusion_expr::Expr,)* + distinct: bool, + filter: Option>, + order_by: Option>, + null_treatment: Option + ) -> datafusion_expr::Expr { + datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + $AGGREGATE_UDF_FN(), + vec![$($arg),*], + distinct, + filter, + order_by, + null_treatment, + )) + } + create_func!($UDAF, $AGGREGATE_UDF_FN); + }; ($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => { - paste::paste! { - // "fluent expr_fn" style function - #[doc = $DOC] - pub fn $EXPR_FN( - args: Vec, - distinct: bool, - filter: Option>, - order_by: Option>, - null_treatment: Option - ) -> Expr { - Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( - $AGGREGATE_UDF_FN(), - args, - distinct, - filter, - order_by, - null_treatment, - )) - } + // "fluent expr_fn" style function + #[doc = $DOC] + pub fn $EXPR_FN( + args: Vec, + distinct: bool, + filter: Option>, + order_by: Option>, + null_treatment: Option + ) -> datafusion_expr::Expr { + datafusion_expr::Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + $AGGREGATE_UDF_FN(), + args, + distinct, + filter, + order_by, + null_treatment, + )) + } + create_func!($UDAF, $AGGREGATE_UDF_FN); + }; +} +macro_rules! create_func { + ($UDAF:ty, $AGGREGATE_UDF_FN:ident) => { + paste::paste! { /// Singleton instance of [$UDAF], ensures the UDAF is only created once /// named STATIC_$(UDAF). For example `STATIC_FirstValue` #[allow(non_upper_case_globals)] static [< STATIC_ $UDAF >]: std::sync::OnceLock> = - std::sync::OnceLock::new(); + std::sync::OnceLock::new(); /// AggregateFunction that returns a [AggregateUDF] for [$UDAF] /// diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 448af634176a..d2e3414fbfce 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +pub mod stats; pub mod utils; use arrow::datatypes::{DataType, Field, Schema}; diff --git a/datafusion/physical-expr-common/src/aggregate/stats.rs b/datafusion/physical-expr-common/src/aggregate/stats.rs new file mode 100644 index 000000000000..6a11ebe36c5f --- /dev/null +++ b/datafusion/physical-expr-common/src/aggregate/stats.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// TODO: Move this to functions-aggregate module +/// Enum used for differentiating population and sample for statistical functions +#[derive(Debug, Clone, Copy)] +pub enum StatsType { + /// Population + Population, + /// Sample + Sample, +} diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 57ed35b0b761..36af875473be 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -181,15 +181,6 @@ pub fn create_aggregate_expr( (AggregateFunction::VariancePop, true) => { return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available"); } - (AggregateFunction::Covariance, false) => Arc::new(expressions::Covariance::new( - input_phy_exprs[0].clone(), - input_phy_exprs[1].clone(), - name, - data_type, - )), - (AggregateFunction::Covariance, true) => { - return not_impl_err!("COVAR(DISTINCT) aggregations are not available"); - } (AggregateFunction::CovariancePop, false) => { Arc::new(expressions::CovariancePop::new( input_phy_exprs[0].clone(), @@ -428,8 +419,8 @@ mod tests { use crate::expressions::{ try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, - BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Correlation, Count, Covariance, - DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, + BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, + Max, Min, Stddev, Sum, Variance, }; use super::*; @@ -950,147 +941,6 @@ mod tests { Ok(()) } - #[test] - fn test_covar_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Covariance]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = Schema::new(vec![ - Field::new("c1", data_type.clone(), true), - Field::new("c2", data_type.clone(), true), - ]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema) - .unwrap(), - ), - Arc::new( - expressions::Column::new_with_schema("c2", &input_schema) - .unwrap(), - ), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..2], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Covariance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_covar_pop_expr() -> Result<()> { - let funcs = vec![AggregateFunction::CovariancePop]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = Schema::new(vec![ - Field::new("c1", data_type.clone(), true), - Field::new("c2", data_type.clone(), true), - ]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema) - .unwrap(), - ), - Arc::new( - expressions::Column::new_with_schema("c2", &input_schema) - .unwrap(), - ), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..2], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Covariance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - - #[test] - fn test_corr_expr() -> Result<()> { - let funcs = vec![AggregateFunction::Correlation]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = Schema::new(vec![ - Field::new("c1", data_type.clone(), true), - Field::new("c2", data_type.clone(), true), - ]); - let input_phy_exprs: Vec> = vec![ - Arc::new( - expressions::Column::new_with_schema("c1", &input_schema) - .unwrap(), - ), - Arc::new( - expressions::Column::new_with_schema("c2", &input_schema) - .unwrap(), - ), - ]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..2], - &input_schema, - "c1", - )?; - if fun == AggregateFunction::Covariance { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", DataType::Float64, true), - result_agg_phy_exprs.field().unwrap() - ) - } - } - } - Ok(()) - } - #[test] fn test_median_expr() -> Result<()> { let funcs = vec![AggregateFunction::ApproxMedian]; diff --git a/datafusion/physical-expr/src/aggregate/covariance.rs b/datafusion/physical-expr/src/aggregate/covariance.rs index ba9bdbc8aee3..272f1d8be2b5 100644 --- a/datafusion/physical-expr/src/aggregate/covariance.rs +++ b/datafusion/physical-expr/src/aggregate/covariance.rs @@ -36,14 +36,6 @@ use crate::aggregate::stats::StatsType; use crate::aggregate::utils::down_cast_any_ref; use crate::expressions::format_state_name; -/// COVAR and COVAR_SAMP aggregate expression -#[derive(Debug)] -pub struct Covariance { - name: String, - expr1: Arc, - expr2: Arc, -} - /// COVAR_POP aggregate expression #[derive(Debug)] pub struct CovariancePop { @@ -52,83 +44,6 @@ pub struct CovariancePop { expr2: Arc, } -impl Covariance { - /// Create a new COVAR aggregate function - pub fn new( - expr1: Arc, - expr2: Arc, - name: impl Into, - data_type: DataType, - ) -> Self { - // the result of covariance just support FLOAT64 data type. - assert!(matches!(data_type, DataType::Float64)); - Self { - name: name.into(), - expr1, - expr2, - } - } -} - -impl AggregateExpr for Covariance { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Float64, true)) - } - - fn create_accumulator(&self) -> Result> { - Ok(Box::new(CovarianceAccumulator::try_new(StatsType::Sample)?)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![ - Field::new( - format_state_name(&self.name, "count"), - DataType::UInt64, - true, - ), - Field::new( - format_state_name(&self.name, "mean1"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "mean2"), - DataType::Float64, - true, - ), - Field::new( - format_state_name(&self.name, "algo_const"), - DataType::Float64, - true, - ), - ]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr1.clone(), self.expr2.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for Covariance { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name && self.expr1.eq(&x.expr1) && self.expr2.eq(&x.expr2) - }) - .unwrap_or(false) - } -} - impl CovariancePop { /// Create a new COVAR_POP aggregate function pub fn new( @@ -429,36 +344,6 @@ mod tests { ) } - #[test] - fn covariance_f64_2() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Covariance, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn covariance_f64_4() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![4.1_f64, 5_f64, 6_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Covariance, - ScalarValue::from(0.9033333333333335_f64) - ) - } - #[test] fn covariance_f64_5() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1.1_f64, 2_f64, 3_f64])); @@ -580,50 +465,6 @@ mod tests { ) } - #[test] - fn covariance_i32_with_nulls_3() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(1), - None, - Some(2), - None, - Some(3), - None, - ])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(4), - Some(9), - Some(5), - Some(8), - Some(6), - None, - ])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Covariance, - ScalarValue::from(1_f64) - ) - } - - #[test] - fn covariance_i32_all_nulls() -> Result<()> { - let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - let b: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); - - generic_test_op2!( - a, - b, - DataType::Int32, - DataType::Int32, - Covariance, - ScalarValue::Float64(None) - ) - } - #[test] fn covariance_pop_i32_all_nulls() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![None, None])); @@ -639,21 +480,6 @@ mod tests { ) } - #[test] - fn covariance_1_input() -> Result<()> { - let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); - let b: ArrayRef = Arc::new(Float64Array::from(vec![2_f64])); - - generic_test_op2!( - a, - b, - DataType::Float64, - DataType::Float64, - Covariance, - ScalarValue::Float64(None) - ) - } - #[test] fn covariance_pop_1_input() -> Result<()> { let a: ArrayRef = Arc::new(Float64Array::from(vec![1_f64])); diff --git a/datafusion/physical-expr/src/aggregate/stats.rs b/datafusion/physical-expr/src/aggregate/stats.rs index 98baaccffe81..d9338f5a962f 100644 --- a/datafusion/physical-expr/src/aggregate/stats.rs +++ b/datafusion/physical-expr/src/aggregate/stats.rs @@ -15,11 +15,4 @@ // specific language governing permissions and limitations // under the License. -/// Enum used for differentiating population and sample for statistical functions -#[derive(Debug, Clone, Copy)] -pub enum StatsType { - /// Population - Population, - /// Sample - Sample, -} +pub use datafusion_physical_expr_common::aggregate::stats::StatsType; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 55ebd9ed8c44..0cd2ac2c9e42 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -52,7 +52,7 @@ pub use crate::aggregate::build_in::create_aggregate_expr; pub use crate::aggregate::correlation::Correlation; pub use crate::aggregate::count::Count; pub use crate::aggregate::count_distinct::DistinctCount; -pub use crate::aggregate::covariance::{Covariance, CovariancePop}; +pub use crate::aggregate::covariance::CovariancePop; pub use crate::aggregate::grouping::Grouping; pub use crate::aggregate::median::Median; pub use crate::aggregate::min_max::{Max, Min}; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 9e4ea8e712ed..c057ab8acda7 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -548,7 +548,7 @@ enum AggregateFunction { ARRAY_AGG = 6; VARIANCE = 7; VARIANCE_POP = 8; - COVARIANCE = 9; + // COVARIANCE = 9; COVARIANCE_POP = 10; STDDEV = 11; STDDEV_POP = 12; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index b5779d25c6e2..994703c5fcfb 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -430,7 +430,6 @@ impl serde::Serialize for AggregateFunction { Self::ArrayAgg => "ARRAY_AGG", Self::Variance => "VARIANCE", Self::VariancePop => "VARIANCE_POP", - Self::Covariance => "COVARIANCE", Self::CovariancePop => "COVARIANCE_POP", Self::Stddev => "STDDEV", Self::StddevPop => "STDDEV_POP", @@ -478,7 +477,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG", "VARIANCE", "VARIANCE_POP", - "COVARIANCE", "COVARIANCE_POP", "STDDEV", "STDDEV_POP", @@ -555,7 +553,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "VARIANCE" => Ok(AggregateFunction::Variance), "VARIANCE_POP" => Ok(AggregateFunction::VariancePop), - "COVARIANCE" => Ok(AggregateFunction::Covariance), "COVARIANCE_POP" => Ok(AggregateFunction::CovariancePop), "STDDEV" => Ok(AggregateFunction::Stddev), "STDDEV_POP" => Ok(AggregateFunction::StddevPop), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index c822ac13013c..fc23a9ea05f7 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2834,7 +2834,7 @@ pub enum AggregateFunction { ArrayAgg = 6, Variance = 7, VariancePop = 8, - Covariance = 9, + /// COVARIANCE = 9; CovariancePop = 10, Stddev = 11, StddevPop = 12, @@ -2881,7 +2881,6 @@ impl AggregateFunction { AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Variance => "VARIANCE", AggregateFunction::VariancePop => "VARIANCE_POP", - AggregateFunction::Covariance => "COVARIANCE", AggregateFunction::CovariancePop => "COVARIANCE_POP", AggregateFunction::Stddev => "STDDEV", AggregateFunction::StddevPop => "STDDEV_POP", @@ -2925,7 +2924,6 @@ impl AggregateFunction { "ARRAY_AGG" => Some(Self::ArrayAgg), "VARIANCE" => Some(Self::Variance), "VARIANCE_POP" => Some(Self::VariancePop), - "COVARIANCE" => Some(Self::Covariance), "COVARIANCE_POP" => Some(Self::CovariancePop), "STDDEV" => Some(Self::Stddev), "STDDEV_POP" => Some(Self::StddevPop), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 83b232da9d21..35d4c6409bc1 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -428,7 +428,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ArrayAgg => Self::ArrayAgg, protobuf::AggregateFunction::Variance => Self::Variance, protobuf::AggregateFunction::VariancePop => Self::VariancePop, - protobuf::AggregateFunction::Covariance => Self::Covariance, protobuf::AggregateFunction::CovariancePop => Self::CovariancePop, protobuf::AggregateFunction::Stddev => Self::Stddev, protobuf::AggregateFunction::StddevPop => Self::StddevPop, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index b2236847ace8..dcec2a3b8595 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -369,7 +369,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ArrayAgg => Self::ArrayAgg, AggregateFunction::Variance => Self::Variance, AggregateFunction::VariancePop => Self::VariancePop, - AggregateFunction::Covariance => Self::Covariance, AggregateFunction::CovariancePop => Self::CovariancePop, AggregateFunction::Stddev => Self::Stddev, AggregateFunction::StddevPop => Self::StddevPop, @@ -674,9 +673,6 @@ pub fn serialize_expr( AggregateFunction::VariancePop => { protobuf::AggregateFunction::VariancePop } - AggregateFunction::Covariance => { - protobuf::AggregateFunction::Covariance - } AggregateFunction::CovariancePop => { protobuf::AggregateFunction::CovariancePop } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index c7df6ebf5828..a0a0ee72054b 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -25,12 +25,12 @@ use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ ApproxDistinct, ApproxMedian, ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, - CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, - DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min, - NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, - RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, - Variance, VariancePop, WindowShift, + CastExpr, Column, Correlation, Count, CovariancePop, CumeDist, DistinctArrayAgg, + DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, InListExpr, + IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min, NegativeExpr, + NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, + RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, Variance, + VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -292,8 +292,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { protobuf::AggregateFunction::Variance } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::VariancePop - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::Covariance } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::CovariancePop } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 65985f86801e..3800b672b5e2 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -30,6 +30,7 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::functions_aggregate::covariance::covar_samp; use datafusion::functions_aggregate::expr_fn::first_value; use datafusion::prelude::*; use datafusion::test_util::{TestTableFactory, TestTableProvider}; @@ -614,6 +615,7 @@ async fn roundtrip_expr_api() -> Result<()> { ), array_replace_all(make_array(vec![lit(1), lit(2), lit(3)]), lit(2), lit(4)), first_value(vec![lit(1)], false, None, None, None), + covar_samp(lit(1.5), lit(2.2), false, None, None, None), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index bc8f6a268703..d03b33d0c8e5 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -495,7 +495,7 @@ statement error Did you mean 'STDDEV'? SELECT STDEV(v1) from test; # Aggregate function -statement error Did you mean 'COVAR'? +statement error DataFusion error: Error during planning: Invalid function 'covaria'.\nDid you mean 'covar'? SELECT COVARIA(1,1); # Window function