Skip to content

Commit

Permalink
Add customizable equality and hash functions to UDFs
Browse files Browse the repository at this point in the history
  • Loading branch information
joroKr21 committed Jul 10, 2024
1 parent 146b679 commit 31a170b
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 33 deletions.
56 changes: 43 additions & 13 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@

//! [`AggregateUDF`]: User Defined Aggregate Functions

use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;
use std::vec;

use arrow::datatypes::{DataType, Field};
use sqlparser::ast::NullTreatment;

use datafusion_common::{exec_err, not_impl_err, plan_err, Result};

use crate::expr::AggregateFunction;
use crate::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
Expand All @@ -26,13 +37,6 @@ use crate::utils::format_state_name;
use crate::utils::AggregateOrderSensitivity;
use crate::{Accumulator, Expr};
use crate::{AccumulatorFactoryFunction, ReturnTypeFunction, Signature};
use arrow::datatypes::{DataType, Field};
use datafusion_common::{exec_err, not_impl_err, plan_err, Result};
use sqlparser::ast::NullTreatment;
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
use std::vec;

/// Logical representation of a user-defined [aggregate function] (UDAF).
///
Expand Down Expand Up @@ -72,20 +76,19 @@ pub struct AggregateUDF {

impl PartialEq for AggregateUDF {
fn eq(&self, other: &Self) -> bool {
self.name() == other.name() && self.signature() == other.signature()
self.inner.equals(other.inner.as_ref()) || other.inner.equals(self.inner.as_ref())
}
}

impl Eq for AggregateUDF {}

impl std::hash::Hash for AggregateUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name().hash(state);
self.signature().hash(state);
impl Hash for AggregateUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
self.inner.hash_value().hash(state)
}
}

impl std::fmt::Display for AggregateUDF {
impl fmt::Display for AggregateUDF {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
    // The displayed form is just the function's name.
    f.write_fmt(format_args!("{}", self.name()))
}
Expand Down Expand Up @@ -507,6 +510,21 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
// Default body: the argument types are unused; the result comes from
// `not_impl_err!`, whose message names this function.
not_impl_err!("Function {} does not implement coerce_types", self.name())
}

/// Reports whether `self` and `other` should be treated as the same
/// aggregate UDF.
///
/// Implementors may override this to customize equality. The provided
/// implementation requires both the name and the signature to match.
fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
    // Names are checked first; signatures are only compared when they agree.
    if self.name() != other.name() {
        return false;
    }
    self.signature() == other.signature()
}

/// Produces the 64-bit hash code used for this aggregate UDF.
///
/// Implementors may override this to customize hashing. The provided
/// implementation feeds the name and then the signature into a fresh
/// [`DefaultHasher`] and returns the finished value.
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.name().hash(&mut state);
    self.signature().hash(&mut state);
    state.finish()
}
}

pub enum ReversedUDAF {
Expand Down Expand Up @@ -562,6 +580,18 @@ impl AggregateUDFImpl for AliasedAggregateUDFImpl {
fn aliases(&self) -> &[String] {
// Borrowed view of the `aliases` field stored on this wrapper.
&self.aliases
}

fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {
    // Equality is decided by the wrapped function: when `other` is also an
    // alias wrapper, compare against what it wraps; otherwise compare
    // against `other` directly.
    match other.as_any().downcast_ref::<AliasedAggregateUDFImpl>() {
        Some(aliased) => self.inner.equals(aliased.inner.as_ref()),
        None => self.inner.equals(other),
    }
}

fn hash_value(&self) -> u64 {
    // The hash code is taken unchanged from the wrapped function.
    self.inner.as_ref().hash_value()
}
}

/// Implementation of [`AggregateUDFImpl`] that wraps the function style pointers
Expand Down
45 changes: 36 additions & 9 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@

use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Arc;

use arrow::datatypes::DataType;

use datafusion_common::{not_impl_err, ExprSchema, Result};

use crate::expr::create_name;
use crate::interval_arithmetic::Interval;
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
Expand All @@ -29,9 +34,6 @@ use crate::{
ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature,
};

use arrow::datatypes::DataType;
use datafusion_common::{not_impl_err, ExprSchema, Result};

/// Logical representation of a Scalar User Defined Function.
///
/// A scalar function produces a single row output for each row of input. This
Expand Down Expand Up @@ -59,16 +61,15 @@ pub struct ScalarUDF {

impl PartialEq for ScalarUDF {
fn eq(&self, other: &Self) -> bool {
self.name() == other.name() && self.signature() == other.signature()
self.inner.equals(other.inner.as_ref()) || other.inner.equals(self.inner.as_ref())
}
}

impl Eq for ScalarUDF {}

impl std::hash::Hash for ScalarUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name().hash(state);
self.signature().hash(state);
impl Hash for ScalarUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
self.inner.hash_value().hash(state)
}
}

Expand Down Expand Up @@ -540,6 +541,21 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
// Default body: the argument types are unused; the result comes from
// `not_impl_err!`, whose message names this function.
not_impl_err!("Function {} does not implement coerce_types", self.name())
}

/// Reports whether `self` and `other` should be treated as the same
/// scalar UDF.
///
/// Implementors may override this to customize equality. The provided
/// implementation requires both the name and the signature to match.
fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
    // Names are checked first; signatures are only compared when they agree.
    if self.name() != other.name() {
        return false;
    }
    self.signature() == other.signature()
}

/// Produces the 64-bit hash code used for this scalar UDF.
///
/// Implementors may override this to customize hashing. The provided
/// implementation feeds the name and then the signature into a fresh
/// [`DefaultHasher`] and returns the finished value.
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.name().hash(&mut state);
    self.signature().hash(&mut state);
    state.finish()
}
}

/// ScalarUDF that adds an alias to the underlying function. It is better to
Expand All @@ -557,7 +573,6 @@ impl AliasedScalarUDFImpl {
) -> Self {
let mut aliases = inner.aliases().to_vec();
aliases.extend(new_aliases.into_iter().map(|s| s.to_string()));

Self { inner, aliases }
}
}
Expand All @@ -575,6 +590,18 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
self.inner.signature()
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
    // Equality is decided by the wrapped function: when `other` is also an
    // alias wrapper, compare against what it wraps; otherwise compare
    // against `other` directly.
    match other.as_any().downcast_ref::<AliasedScalarUDFImpl>() {
        Some(aliased) => self.inner.equals(aliased.inner.as_ref()),
        None => self.inner.equals(other),
    }
}

fn hash_value(&self) -> u64 {
    // The hash code is taken unchanged from the wrapped function.
    self.inner.as_ref().hash_value()
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
    // Forwarded unchanged to the wrapped function.
    self.inner.as_ref().return_type(arg_types)
}
Expand Down
52 changes: 41 additions & 11 deletions datafusion/expr/src/udwf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@

//! [`WindowUDF`]: User Defined Window Functions

use crate::{
function::WindowFunctionSimplification, Expr, PartitionEvaluator,
PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
};
use arrow::datatypes::DataType;
use datafusion_common::Result;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::{
any::Any,
fmt::{self, Debug, Display, Formatter},
sync::Arc,
};

use arrow::datatypes::DataType;

use datafusion_common::Result;

use crate::{
function::WindowFunctionSimplification, Expr, PartitionEvaluator,
PartitionEvaluatorFactory, ReturnTypeFunction, Signature, WindowFrame,
};

/// Logical representation of a user-defined window function (UDWF)
/// A UDWF is different from a UDF in that it is stateful across batches.
///
Expand Down Expand Up @@ -62,16 +66,15 @@ impl Display for WindowUDF {

impl PartialEq for WindowUDF {
fn eq(&self, other: &Self) -> bool {
self.name() == other.name() && self.signature() == other.signature()
self.inner.equals(other.inner.as_ref()) || other.inner.equals(self.inner.as_ref())
}
}

impl Eq for WindowUDF {}

impl std::hash::Hash for WindowUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name().hash(state);
self.signature().hash(state);
impl Hash for WindowUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
self.inner.hash_value().hash(state)
}
}

Expand Down Expand Up @@ -296,6 +299,21 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
fn simplify(&self) -> Option<WindowFunctionSimplification> {
// Default body: no simplification is supplied.
None
}

/// Reports whether `self` and `other` should be treated as the same
/// window UDF.
///
/// Implementors may override this to customize equality. The provided
/// implementation requires both the name and the signature to match.
fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
    // Names are checked first; signatures are only compared when they agree.
    if self.name() != other.name() {
        return false;
    }
    self.signature() == other.signature()
}

/// Produces the 64-bit hash code used for this window UDF.
///
/// Implementors may override this to customize hashing. The provided
/// implementation feeds the name and then the signature into a fresh
/// [`DefaultHasher`] and returns the finished value.
fn hash_value(&self) -> u64 {
    let mut state = DefaultHasher::new();
    self.name().hash(&mut state);
    self.signature().hash(&mut state);
    state.finish()
}
}

/// WindowUDF that adds an alias to the underlying function. It is better to
Expand Down Expand Up @@ -342,6 +360,18 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
fn aliases(&self) -> &[String] {
// Borrowed view of the `aliases` field stored on this wrapper.
&self.aliases
}

fn equals(&self, other: &dyn WindowUDFImpl) -> bool {
    // Equality is decided by the wrapped function: when `other` is also an
    // alias wrapper, compare against what it wraps; otherwise compare
    // against `other` directly.
    match other.as_any().downcast_ref::<AliasedWindowUDFImpl>() {
        Some(aliased) => self.inner.equals(aliased.inner.as_ref()),
        None => self.inner.equals(other),
    }
}

fn hash_value(&self) -> u64 {
    // The hash code is taken unchanged from the wrapped function.
    self.inner.as_ref().hash_value()
}
}

/// Implementation of [`WindowUDFImpl`] that wraps the function style pointers
Expand Down

0 comments on commit 31a170b

Please sign in to comment.