-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ScalarUDF: Remove `supports_zero_argument` and avoid creating null array for empty args
#10193
Changes from all commits
eabfe68
7b529d9
03ec8b5
36e685e
7b04c0b
7c10382
864d197
5b51fb7
88d2a33
7c81776
bd4c65b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,10 +16,7 @@ | |
// under the License. | ||
|
||
use arrow::compute::kernels::numeric::add; | ||
use arrow_array::{ | ||
Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, UInt8Array, | ||
}; | ||
use arrow_schema::DataType::Float64; | ||
use arrow_array::{ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch}; | ||
use arrow_schema::{DataType, Field, Schema}; | ||
use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; | ||
use datafusion::prelude::*; | ||
|
@@ -36,9 +33,7 @@ use datafusion_expr::{ | |
create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, ExprSchemable, | ||
LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, | ||
}; | ||
use rand::{thread_rng, Rng}; | ||
use std::any::Any; | ||
use std::iter; | ||
use std::sync::Arc; | ||
|
||
/// test that casting happens on udfs. | ||
|
@@ -168,6 +163,48 @@ async fn scalar_udf() -> Result<()> { | |
Ok(()) | ||
} | ||
|
||
struct Simple0ArgsScalarUDF { | ||
name: String, | ||
signature: Signature, | ||
return_type: DataType, | ||
} | ||
|
||
impl std::fmt::Debug for Simple0ArgsScalarUDF { | ||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | ||
f.debug_struct("ScalarUDF") | ||
.field("name", &self.name) | ||
.field("signature", &self.signature) | ||
.field("fun", &"<FUNC>") | ||
.finish() | ||
} | ||
} | ||
|
||
impl ScalarUDFImpl for Simple0ArgsScalarUDF { | ||
fn as_any(&self) -> &dyn Any { | ||
self | ||
} | ||
|
||
fn name(&self) -> &str { | ||
&self.name | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(self.return_type.clone()) | ||
} | ||
|
||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
not_impl_err!("{} function does not accept arguments", self.name()) | ||
} | ||
|
||
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> { | ||
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100)))) | ||
} | ||
} | ||
|
||
#[tokio::test] | ||
async fn scalar_udf_zero_params() -> Result<()> { | ||
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); | ||
|
@@ -179,20 +216,14 @@ async fn scalar_udf_zero_params() -> Result<()> { | |
let ctx = SessionContext::new(); | ||
|
||
ctx.register_batch("t", batch)?; | ||
// create function just returns 100 regardless of input | ||
let myfunc = Arc::new(|_args: &[ColumnarValue]| { | ||
Ok(ColumnarValue::Array( | ||
Arc::new((0..1).map(|_| 100).collect::<Int32Array>()) as ArrayRef, | ||
)) | ||
}); | ||
|
||
ctx.register_udf(create_udf( | ||
"get_100", | ||
vec![], | ||
Arc::new(DataType::Int32), | ||
Volatility::Immutable, | ||
myfunc, | ||
)); | ||
let get_100_udf = Simple0ArgsScalarUDF { | ||
name: "get_100".to_string(), | ||
signature: Signature::exact(vec![], Volatility::Immutable), | ||
return_type: DataType::Int32, | ||
}; | ||
|
||
ctx.register_udf(ScalarUDF::from(get_100_udf)); | ||
|
||
let result = plan_and_collect(&ctx, "select get_100() a from t").await?; | ||
let expected = [ | ||
|
@@ -403,123 +434,6 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { | |
Ok(()) | ||
} | ||
|
||
#[derive(Debug)] | ||
pub struct RandomUDF { | ||
signature: Signature, | ||
} | ||
|
||
impl RandomUDF { | ||
pub fn new() -> Self { | ||
Self { | ||
signature: Signature::any(0, Volatility::Volatile), | ||
} | ||
} | ||
} | ||
|
||
impl ScalarUDFImpl for RandomUDF { | ||
fn as_any(&self) -> &dyn std::any::Any { | ||
self | ||
} | ||
|
||
fn name(&self) -> &str { | ||
"random_udf" | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
|
||
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(Float64) | ||
} | ||
|
||
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
let len: usize = match &args[0] { | ||
// This udf is always invoked with zero argument so its argument | ||
// is a null array indicating the batch size. | ||
ColumnarValue::Array(array) if array.data_type().is_null() => array.len(), | ||
_ => { | ||
return Err(datafusion::error::DataFusionError::Internal( | ||
"Invalid argument type".to_string(), | ||
)) | ||
} | ||
}; | ||
let mut rng = thread_rng(); | ||
let values = iter::repeat_with(|| rng.gen_range(0.1..1.0)).take(len); | ||
let array = Float64Array::from_iter_values(values); | ||
Ok(ColumnarValue::Array(Arc::new(array))) | ||
} | ||
} | ||
|
||
/// Ensure that a user defined function with zero argument will be invoked | ||
/// with a null array indicating the batch size. | ||
#[tokio::test] | ||
async fn test_user_defined_functions_zero_argument() -> Result<()> { | ||
let ctx = SessionContext::new(); | ||
|
||
let schema = Arc::new(Schema::new(vec![Field::new( | ||
"index", | ||
DataType::UInt8, | ||
false, | ||
)])); | ||
|
||
let batch = RecordBatch::try_new( | ||
schema, | ||
vec![Arc::new(UInt8Array::from_iter_values([1, 2, 3]))], | ||
)?; | ||
|
||
ctx.register_batch("data_table", batch)?; | ||
|
||
let random_normal_udf = ScalarUDF::from(RandomUDF::new()); | ||
ctx.register_udf(random_normal_udf); | ||
|
||
let result = plan_and_collect( | ||
&ctx, | ||
"SELECT random_udf() AS random_udf, random() AS native_random FROM data_table", | ||
) | ||
.await?; | ||
|
||
assert_eq!(result.len(), 1); | ||
let batch = &result[0]; | ||
let random_udf = batch | ||
.column(0) | ||
.as_any() | ||
.downcast_ref::<Float64Array>() | ||
.unwrap(); | ||
let native_random = batch | ||
.column(1) | ||
.as_any() | ||
.downcast_ref::<Float64Array>() | ||
.unwrap(); | ||
|
||
assert_eq!(random_udf.len(), native_random.len()); | ||
|
||
let mut previous = -1.0; | ||
for i in 0..random_udf.len() { | ||
assert!(random_udf.value(i) >= 0.0 && random_udf.value(i) < 1.0); | ||
assert!(random_udf.value(i) != previous); | ||
previous = random_udf.value(i); | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn deregister_udf() -> Result<()> { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This function is just moved; the test remains. |
||
let random_normal_udf = ScalarUDF::from(RandomUDF::new()); | ||
let ctx = SessionContext::new(); | ||
|
||
ctx.register_udf(random_normal_udf.clone()); | ||
|
||
assert!(ctx.udfs().contains("random_udf")); | ||
|
||
ctx.deregister_udf("random_udf"); | ||
|
||
assert!(!ctx.udfs().contains("random_udf")); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[derive(Debug)] | ||
struct CastToI64UDF { | ||
signature: Signature, | ||
|
@@ -615,6 +529,22 @@ async fn test_user_defined_functions_cast_to_i64() -> Result<()> { | |
Ok(()) | ||
} | ||
|
||
#[tokio::test] | ||
async fn deregister_udf() -> Result<()> { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 👍 |
||
let cast2i64 = ScalarUDF::from(CastToI64UDF::new()); | ||
let ctx = SessionContext::new(); | ||
|
||
ctx.register_udf(cast2i64.clone()); | ||
|
||
assert!(ctx.udfs().contains("cast_to_i64")); | ||
|
||
ctx.deregister_udf("cast_to_i64"); | ||
|
||
assert!(!ctx.udfs().contains("cast_to_i64")); | ||
|
||
Ok(()) | ||
} | ||
|
||
#[derive(Debug)] | ||
struct TakeUDF { | ||
signature: Signature, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,13 +16,11 @@ | |
// under the License. | ||
|
||
use std::any::Any; | ||
use std::sync::Arc; | ||
|
||
use arrow::array::Float64Array; | ||
use arrow::datatypes::DataType; | ||
use arrow::datatypes::DataType::Float64; | ||
|
||
use datafusion_common::{exec_err, Result}; | ||
use datafusion_common::{not_impl_err, Result, ScalarValue}; | ||
use datafusion_expr::{ColumnarValue, FuncMonotonicity, Volatility}; | ||
use datafusion_expr::{ScalarUDFImpl, Signature}; | ||
|
||
|
@@ -62,12 +60,14 @@ impl ScalarUDFImpl for PiFunc { | |
Ok(Float64) | ||
} | ||
|
||
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
if !matches!(&args[0], ColumnarValue::Array(_)) { | ||
return exec_err!("Expect pi function to take no param"); | ||
} | ||
let array = Float64Array::from_value(std::f64::consts::PI, 1); | ||
Ok(ColumnarValue::Array(Arc::new(array))) | ||
fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think the signature check is enough, so just ignore the args. |
||
not_impl_err!("{} function does not accept arguments", self.name()) | ||
} | ||
|
||
fn invoke_no_args(&self, _number_rows: usize) -> Result<ColumnarValue> { | ||
Ok(ColumnarValue::Scalar(ScalarValue::Float64(Some( | ||
std::f64::consts::PI, | ||
)))) | ||
} | ||
|
||
fn monotonicity(&self) -> Result<Option<FuncMonotonicity>> { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have `RandomFunc`; they are the same, so remove it.