Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ScalarUDF: Remove supports_zero_argument and avoid creating null array for empty args #10193

Merged
4 changes: 0 additions & 4 deletions datafusion/core/src/physical_optimizer/projection_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1402,7 +1402,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1471,7 +1470,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 3))),
Expand Down Expand Up @@ -1543,7 +1541,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d", 2))),
Expand Down Expand Up @@ -1612,7 +1609,6 @@ mod tests {
],
DataType::Int32,
None,
false,
)),
Arc::new(CaseExpr::try_new(
Some(Arc::new(Column::new("d_new", 3))),
Expand Down
202 changes: 66 additions & 136 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
// under the License.

use arrow::compute::kernels::numeric::add;
use arrow_array::{
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array,
};
use arrow_schema::DataType::Float64;
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState};
use datafusion::prelude::*;
Expand All @@ -36,9 +33,7 @@ use datafusion_expr::{
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable,
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use rand::{thread_rng, Rng};
use std::any::Any;
use std::iter;
use std::sync::Arc;

/// test that casting happens on udfs.
Expand Down Expand Up @@ -168,6 +163,48 @@ async fn scalar_udf() -> Result<()> {
Ok(())
}

struct Simple0ArgsScalarUDF {
name: String,
signature: Signature,
return_type: DataType,
}

impl std::fmt::Debug for Simple0ArgsScalarUDF {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("ScalarUDF")
.field("name", &self.name)
.field("signature", &self.signature)
.field("fun", &"<FUNC>")
.finish()
}
}

impl ScalarUDFImpl for Simple0ArgsScalarUDF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
&self.name
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(self.return_type.clone())
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
not_impl_err!("{} function does not accept arguments", self.name())
}

fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
}
}

#[tokio::test]
async fn scalar_udf_zero_params() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
Expand All @@ -179,20 +216,14 @@ async fn scalar_udf_zero_params() -> Result<()> {
let ctx = SessionContext::new();

ctx.register_batch("t", batch)?;
// create function just returns 100 regardless of inp
let myfunc = Arc::new(|_args: &[ColumnarValue]| {
Ok(ColumnarValue::Array(
Arc::new((0..1).map(|_| 100).collect::<Int32Array>()) as ArrayRef,
))
});

ctx.register_udf(create_udf(
"get_100",
vec![],
Arc::new(DataType::Int32),
Volatility::Immutable,
myfunc,
));
let get_100_udf = Simple0ArgsScalarUDF {
name: "get_100".to_string(),
signature: Signature::exact(vec![], Volatility::Immutable),
return_type: DataType::Int32,
};

ctx.register_udf(ScalarUDF::from(get_100_udf));

let result = plan_and_collect(&ctx, "select get_100() a from t").await?;
let expected = [
Expand Down Expand Up @@ -403,123 +434,6 @@ async fn test_user_defined_functions_with_alias() -> Result<()> {
Ok(())
}

#[derive(Debug)]
pub struct RandomUDF {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have RandomFunc, they are the same so remove it.

signature: Signature,
}

impl RandomUDF {
pub fn new() -> Self {
Self {
signature: Signature::any(0, Volatility::Volatile),
}
}
}

impl ScalarUDFImpl for RandomUDF {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"random_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
let len: usize = match &args[0] {
// This udf is always invoked with zero argument so its argument
// is a null array indicating the batch size.
ColumnarValue::Array(array) if array.data_type().is_null() => array.len(),
_ => {
return Err(datafusion::error::DataFusionError::Internal(
"Invalid argument type".to_string(),
))
}
};
let mut rng = thread_rng();
let values = iter::repeat_with(|| rng.gen_range(0.1..1.0)).take(len);
let array = Float64Array::from_iter_values(values);
Ok(ColumnarValue::Array(Arc::new(array)))
}
}

/// Ensure that a user defined function with zero argument will be invoked
/// with a null array indicating the batch size.
#[tokio::test]
async fn test_user_defined_functions_zero_argument() -> Result<()> {
let ctx = SessionContext::new();

let schema = Arc::new(Schema::new(vec![Field::new(
"index",
DataType::UInt8,
false,
)]));

let batch = RecordBatch::try_new(
schema,
vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))],
)?;

ctx.register_batch("data_table", batch)?;

let random_normal_udf = ScalarUDF::from(RandomUDF::new());
ctx.register_udf(random_normal_udf);

let result = plan_and_collect(
&ctx,
"SELECT random_udf() AS random_udf, random() AS native_random FROM data_table",
)
.await?;

assert_eq!(result.len(), 1);
let batch = &result[0];
let random_udf = batch
.column(0)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
let native_random = batch
.column(1)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();

assert_eq!(random_udf.len(), native_random.len());

let mut previous = -1.0;
for i in 0..random_udf.len() {
assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0);
assert!(random_udf.value(i) != previous);
previous = random_udf.value(i);
}

Ok(())
}

#[tokio::test]
async fn deregister_udf() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function is just moved , the test remains

let random_normal_udf = ScalarUDF::from(RandomUDF::new());
let ctx = SessionContext::new();

ctx.register_udf(random_normal_udf.clone());

assert!(ctx.udfs().contains("random_udf"));

ctx.deregister_udf("random_udf");

assert!(!ctx.udfs().contains("random_udf"));

Ok(())
}

#[derive(Debug)]
struct CastToI64UDF {
signature: Signature,
Expand Down Expand Up @@ -615,6 +529,22 @@ async fn test_user_defined_functions_cast_to_i64() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn deregister_udf() -> Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

let cast2i64 = ScalarUDF::from(CastToI64UDF::new());
let ctx = SessionContext::new();

ctx.register_udf(cast2i64.clone());

assert!(ctx.udfs().contains("cast_to_i64"));

ctx.deregister_udf("cast_to_i64");

assert!(!ctx.udfs().contains("cast_to_i64"));

Ok(())
}

#[derive(Debug)]
struct TakeUDF {
signature: Signature,
Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pub fn data_types(
);
}
}

let valid_types = get_valid_types(&signature.type_signature, current_types)?;

if valid_types
Expand Down
29 changes: 23 additions & 6 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use crate::{
ScalarFunctionImplementation, Signature,
};
use arrow::datatypes::DataType;
use datafusion_common::{ExprSchema, Result};
use datafusion_common::{not_impl_err, ExprSchema, Result};
use std::any::Any;
use std::fmt;
use std::fmt::Debug;
Expand Down Expand Up @@ -180,6 +180,13 @@ impl ScalarUDF {
self.inner.invoke(args)
}

/// Invoke the function without `args` but number of rows, returning the appropriate result.
///
/// See [`ScalarUDFImpl::invoke_no_args`] for more details.
pub fn invoke_no_args(&self, number_rows: usize) -> Result<ColumnarValue> {
self.inner.invoke_no_args(number_rows)
}

/// Returns a `ScalarFunctionImplementation` that can invoke the function
/// during execution
pub fn fun(&self) -> ScalarFunctionImplementation {
Expand Down Expand Up @@ -322,10 +329,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// The function will be invoked passed with the slice of [`ColumnarValue`]
/// (either scalar or array).
///
/// # Zero Argument Functions
/// If the function has zero parameters (e.g. `now()`) it will be passed a
/// single element slice which is a a null array to indicate the batch's row
/// count (so the function can know the resulting array size).
/// If the function does not take any arguments, please use [invoke_no_args]
/// instead and return [not_impl_err] for this function.
///
///
/// # Performance
///
Expand All @@ -335,7 +341,18 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
///
/// [`ColumnarValue::values_to_arrays`] can be used to convert the arguments
/// to arrays, which will likely be simpler code, but be slower.
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue>;
///
/// [invoke_no_args]: ScalarUDFImpl::invoke_no_args
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue>;

/// Invoke the function without `args`, instead the number of rows are provided,
/// returning the appropriate result.
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
not_impl_err!(
"Function {} does not implement invoke_no_args but called",
self.name()
)
}

/// Returns any aliases (alternate names) for this function.
///
Expand Down
4 changes: 4 additions & 0 deletions datafusion/functions-array/src/make_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ impl ScalarUDFImpl for MakeArray {
make_scalar_function(make_array_inner)(args)
}

fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
make_scalar_function(make_array_inner)(&[])
}

fn aliases(&self) -> &[String] {
&self.aliases
}
Expand Down
18 changes: 9 additions & 9 deletions datafusion/functions/src/math/pi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@
// under the License.

use std::any::Any;
use std::sync::Arc;

use arrow::array::Float64Array;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Float64;

use datafusion_common::{exec_err, Result};
use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_expr::{ColumnarValue, FuncMonotonicity, Volatility};
use datafusion_expr::{ScalarUDFImpl, Signature};

Expand Down Expand Up @@ -62,12 +60,14 @@ impl ScalarUDFImpl for PiFunc {
Ok(Float64)
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
if !matches!(&args[0], ColumnarValue::Array(_)) {
return exec_err!("Expect pi function to take no param");
}
let array = Float64Array::from_value(std::f64::consts::PI, 1);
Ok(ColumnarValue::Array(Arc::new(array)))
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think signature check is enough, so just ignore args

not_impl_err!("{} function does not accept arguments", self.name())
}

fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some(
std::f64::consts::PI,
))))
}

fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> {
Expand Down
Loading