diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py index ad7f728b..e89c5715 100644 --- a/python/datafusion/tests/test_dataframe.py +++ b/python/datafusion/tests/test_dataframe.py @@ -29,7 +29,6 @@ WindowFrame, column, literal, - udf, ) from datafusion.expr import Window @@ -236,21 +235,6 @@ def test_unnest_without_nulls(nested_df): assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9]) -def test_udf(df): - # is_null is a pa function over arrays - is_null = udf( - lambda x: x.is_null(), - [pa.int64()], - pa.bool_(), - volatility="immutable", - ) - - df = df.select(is_null(column("a"))) - result = df.collect()[0].column(0) - - assert result == pa.array([False, False, False]) - - def test_join(): ctx = SessionContext() diff --git a/python/datafusion/tests/test_udaf.py b/python/datafusion/tests/test_udaf.py index 6f2525b0..43460809 100644 --- a/python/datafusion/tests/test_udaf.py +++ b/python/datafusion/tests/test_udaf.py @@ -21,14 +21,14 @@ import pyarrow.compute as pc import pytest -from datafusion import Accumulator, column, udaf, udf +from datafusion import Accumulator, column, udaf class Summarize(Accumulator): """Interface of a user-defined accumulation.""" - def __init__(self): - self._sum = pa.scalar(0.0) + def __init__(self, initial_value: float = 0.0): + self._sum = pa.scalar(initial_value) def state(self) -> List[pa.Scalar]: return [self._sum] @@ -97,7 +97,7 @@ def test_errors(df): df.collect() -def test_aggregate(df): +def test_udaf_aggregate(df): summarize = udaf( Summarize, pa.float64(), @@ -106,13 +106,40 @@ def test_aggregate(df): volatility="immutable", ) - df = df.aggregate([], [summarize(column("a"))]) + df1 = df.aggregate([], [summarize(column("a"))]) # execute and collect the first (and only) batch - result = df.collect()[0] + result = df1.collect()[0] assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) + df2 = df.aggregate([], [summarize(column("a"))]) + + # Run a second time to ensure the state is properly reset + result = df2.collect()[0] + + assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) + + +def test_udaf_aggregate_with_arguments(df): + bias = 10.0 + + summarize = udaf( + Summarize, + pa.float64(), + pa.float64(), + [pa.float64()], + volatility="immutable", + arguments=[bias], + ) + + df1 = df.aggregate([], [summarize(column("a"))]) + + # execute and collect the first (and only) batch + result = df1.collect()[0] + + assert result.column(0) == pa.array([bias + 1.0 + 2.0 + 3.0]) + def test_group_by(df): summarize = udaf( @@ -146,20 +173,3 @@ def test_register_udaf(ctx, df) -> None: df_result = ctx.sql("select summarize(b) from test_table") assert df_result.collect()[0][0][0].as_py() == 14.0 - - -def test_register_udf(ctx, df) -> None: - is_null = udf( - lambda x: x.is_null(), - [pa.float64()], - pa.bool_(), - volatility="immutable", - name="is_null", - ) - - ctx.register_udf(is_null) - - df_result = ctx.sql("select is_null(a) from test_table") - result = df_result.collect()[0].column(0) - - assert result == pa.array([False, False, False]) diff --git a/python/datafusion/tests/test_udf.py b/python/datafusion/tests/test_udf.py new file mode 100644 index 00000000..568a66db --- /dev/null +++ b/python/datafusion/tests/test_udf.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from datafusion import udf, column
+import pyarrow as pa
+import pytest
+
+
+@pytest.fixture
+def df(ctx):
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 4, 6])],
+        names=["a", "b"],
+    )
+    return ctx.create_dataframe([[batch]], name="test_table")
+
+
+def test_udf(df):
+    # is_null is a pa function over arrays
+    is_null = udf(
+        lambda x: x.is_null(),
+        [pa.int64()],
+        pa.bool_(),
+        volatility="immutable",
+    )
+
+    df = df.select(is_null(column("a")))
+    result = df.collect()[0].column(0)
+
+    assert result == pa.array([False, False, False])
+
+
+def test_register_udf(ctx, df) -> None:
+    is_null = udf(
+        lambda x: x.is_null(),
+        [pa.float64()],
+        pa.bool_(),
+        volatility="immutable",
+        name="is_null",
+    )
+
+    ctx.register_udf(is_null)
+
+    df_result = ctx.sql("select is_null(a) from test_table")
+    result = df_result.collect()[0].column(0)
+
+    assert result == pa.array([False, False, False])
+
+
+class OverThresholdUDF:
+    def __init__(self, threshold: int = 0) -> None:
+        self.threshold = threshold
+
+    def __call__(self, values: pa.Array) -> pa.Array:
+        return pa.array(v.as_py() >= self.threshold for v in values)
+
+
+def test_udf_with_parameters(df) -> None:
+    udf_no_param = udf(
+        OverThresholdUDF(),
+        pa.int64(),
+        pa.bool_(),
+        volatility="immutable",
+    )
+
+    df1 = df.select(udf_no_param(column("a")))
+    result = df1.collect()[0].column(0)
+
+    assert result == pa.array([True, True, True])
+
+    udf_with_param = udf(
+        OverThresholdUDF(2),
+        pa.int64(),
+        pa.bool_(),
+        volatility="immutable",
+    )
+
+    df2 = df.select(udf_with_param(column("a")))
+    result = df2.collect()[0].column(0)
+
+    assert result == pa.array([False, True, True])
diff --git a/python/datafusion/tests/test_udwf.py b/python/datafusion/tests/test_udwf.py
index 67c0979f..67966eea 100644
--- a/python/datafusion/tests/test_udwf.py
+++ b/python/datafusion/tests/test_udwf.py
@@ -24,7 +24,7 @@
 class ExponentialSmoothDefault(WindowEvaluator):
-    def __init__(self, alpha: float) -> None:
+    def __init__(self, alpha: float = 0.8) -> None:
         self.alpha = alpha
 
     def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array:
@@ -183,7 +183,7 @@ def df():
 def test_udwf_errors(df):
     with pytest.raises(TypeError):
         udwf(
-            NotSubclassOfWindowEvaluator(),
+            NotSubclassOfWindowEvaluator,
             pa.float64(),
             pa.float64(),
             volatility="immutable",
@@ -191,38 +191,50 @@ def test_udwf_errors(df):
 
 
 smooth_default = udwf(
-    ExponentialSmoothDefault(0.9),
+    ExponentialSmoothDefault,
+    pa.float64(),
+    pa.float64(),
+    volatility="immutable",
+    arguments=[0.9],
+)
+
+smooth_no_arguments = udwf(
+    ExponentialSmoothDefault,
     pa.float64(),
     pa.float64(),
     volatility="immutable",
 )
 
 smooth_bounded = udwf(
-    ExponentialSmoothBounded(0.9),
+    ExponentialSmoothBounded,
     pa.float64(),
     pa.float64(),
     volatility="immutable",
+    arguments=[0.9],
 )
 
 smooth_rank = udwf(
-    ExponentialSmoothRank(0.9),
+    ExponentialSmoothRank,
     pa.utf8(),
     pa.float64(),
     volatility="immutable",
+    arguments=[0.9],
 )
 
 smooth_frame = udwf(
-    ExponentialSmoothFrame(0.9),
+    ExponentialSmoothFrame,
     pa.float64(),
     pa.float64(),
     volatility="immutable",
+    arguments=[0.9],
 )
 
 smooth_two_col = udwf(
-    SmoothTwoColumn(0.9),
+    SmoothTwoColumn,
     [pa.int64(), pa.int64()],
     pa.float64(),
     volatility="immutable",
+    arguments=[0.9],
 )
 
 data_test_udwf_functions = [
@@ -231,6 +243,11 @@ def test_udwf_errors(df):
     (
         "default_udwf",
         smooth_default(column("a")),
         [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889],
     ),
+    (
+        "default_udwf_no_arguments",
+        smooth_no_arguments(column("a")),
+        [0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75],
+    ),
     (
         "default_udwf_partitioned",
         smooth_default(column("a")).partition_by(column("c")).build(),
diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py
index bb7a9086..88e9a193 100644
--- a/python/datafusion/udf.py
+++ b/python/datafusion/udf.py
@@ -21,9 +21,9 @@
 import datafusion._internal as df_internal
 from datafusion.expr import Expr
-from typing import Callable, TYPE_CHECKING, TypeVar
+from typing import Callable, TYPE_CHECKING, TypeVar, Type
 from abc import ABCMeta, abstractmethod
-from typing import List
+from typing import List, Any, Optional
 from enum import Enum
 import pyarrow
@@ -86,7 +86,7 @@ def __init__(
         self,
         name: str | None,
         func: Callable[..., _R],
-        input_types: list[pyarrow.DataType],
+        input_types: pyarrow.DataType | list[pyarrow.DataType],
         return_type: _R,
         volatility: Volatility | str,
     ) -> None:
@@ -94,6 +94,8 @@ def __init__(
 
         See helper method :py:func:`udf` for argument details.
         """
+        if isinstance(input_types, pyarrow.DataType):
+            input_types = [input_types]
         self._udf = df_internal.ScalarUDF(
             name, func, input_types, return_type, str(volatility)
         )
@@ -104,8 +106,8 @@ def __call__(self, *args: Expr) -> Expr:
         This function is not typically called by an end user. These calls will
         occur during the evaluation of the dataframe.
         """
-        args = [arg.expr for arg in args]
-        return Expr(self._udf.__call__(*args))
+        args_raw = [arg.expr for arg in args]
+        return Expr(self._udf.__call__(*args_raw))
 
     @staticmethod
     def udf(
@@ -133,7 +135,10 @@ def udf(
         if not callable(func):
             raise TypeError("`func` argument must be callable")
         if name is None:
-            name = func.__qualname__.lower()
+            if hasattr(func, "__qualname__"):
+                name = func.__qualname__.lower()
+            else:
+                name = func.__class__.__name__.lower()
         return ScalarUDF(
             name=name,
             func=func,
@@ -167,10 +172,6 @@ def evaluate(self) -> pyarrow.Scalar:
         pass
 
 
-if TYPE_CHECKING:
-    _A = TypeVar("_A", bound=(Callable[..., _R], Accumulator))
-
-
 class AggregateUDF:
-    """Class for performing scalar user-defined functions (UDF).
+    """Class for performing aggregate user-defined functions (UDAF).
 
@@ -181,11 +182,12 @@ class AggregateUDF:
     def __init__(
         self,
         name: str | None,
-        accumulator: _A,
+        accumulator: Type[Accumulator],
         input_types: list[pyarrow.DataType],
-        return_type: _R,
+        return_type: pyarrow.DataType,
         state_type: list[pyarrow.DataType],
         volatility: Volatility | str,
+        arguments: list[Any],
     ) -> None:
         """Instantiate a user-defined aggregate function (UDAF).
 
@@ -193,7 +195,13 @@ def __init__(
         descriptions.
         """
         self._udaf = df_internal.AggregateUDF(
-            name, accumulator, input_types, return_type, state_type, str(volatility)
+            name,
+            accumulator,
+            input_types,
+            return_type,
+            state_type,
+            str(volatility),
+            arguments,
         )
 
     def __call__(self, *args: Expr) -> Expr:
@@ -202,16 +210,17 @@ def __call__(self, *args: Expr) -> Expr:
 
         This function is not typically called by an end user. These calls will
         occur during the evaluation of the dataframe.
""" - args = [arg.expr for arg in args] - return Expr(self._udaf.__call__(*args)) + args_raw = [arg.expr for arg in args] + return Expr(self._udaf.__call__(*args_raw)) @staticmethod def udaf( - accum: _A, - input_types: list[pyarrow.DataType], - return_type: _R, + accum: Type[Accumulator], + input_types: pyarrow.DataType | list[pyarrow.DataType], + return_type: pyarrow.DataType, state_type: list[pyarrow.DataType], volatility: Volatility | str, + arguments: Optional[list[Any]] = None, name: str | None = None, ) -> AggregateUDF: """Create a new User-Defined Aggregate Function. @@ -224,6 +233,7 @@ def udaf( return_type: The data type of the return value. state_type: The data types of the intermediate accumulation. volatility: See :py:class:`Volatility` for allowed values. + arguments: A list of arguments to pass in to the __init__ method for accum. name: A descriptive name for the function. Returns: @@ -236,8 +246,9 @@ def udaf( ) if name is None: name = accum.__qualname__.lower() - if isinstance(input_types, pyarrow.lib.DataType): + if isinstance(input_types, pyarrow.DataType): input_types = [input_types] + arguments = [] if arguments is None else arguments return AggregateUDF( name=name, accumulator=accum, @@ -245,6 +256,7 @@ def udaf( return_type=return_type, state_type=state_type, volatility=volatility, + arguments=arguments, ) @@ -422,10 +434,11 @@ class WindowUDF: def __init__( self, name: str | None, - func: WindowEvaluator, + func: Type[WindowEvaluator], input_types: list[pyarrow.DataType], return_type: pyarrow.DataType, volatility: Volatility | str, + arguments: list[Any], ) -> None: """Instantiate a user-defined window function (UDWF). @@ -433,7 +446,7 @@ def __init__( descriptions. """ self._udwf = df_internal.WindowUDF( - name, func, input_types, return_type, str(volatility) + name, func, input_types, return_type, str(volatility), arguments ) def __call__(self, *args: Expr) -> Expr: @@ -447,10 +460,11 @@ def __call__(self, *args: Expr) -> Expr: @staticmethod def udwf( - func: WindowEvaluator, + func: Type[WindowEvaluator], input_types: pyarrow.DataType | list[pyarrow.DataType], return_type: pyarrow.DataType, volatility: Volatility | str, + arguments: Optional[list[Any]] = None, name: str | None = None, ) -> WindowUDF: """Create a new User-Defined Window Function. @@ -460,12 +474,13 @@ def udwf( input_types: The data types of the arguments to ``func``. return_type: The data type of the return value. volatility: See :py:class:`Volatility` for allowed values. + arguments: A list of arguments to pass in to the __init__ method for accum. name: A descriptive name for the function. Returns: A user-defined window function. 
""" - if not isinstance(func, WindowEvaluator): + if not issubclass(func, WindowEvaluator): raise TypeError( "`func` must implement the abstract base class WindowEvaluator" ) @@ -473,10 +488,12 @@ def udwf( name = func.__class__.__qualname__.lower() if isinstance(input_types, pyarrow.DataType): input_types = [input_types] + arguments = [] if arguments is None else arguments return WindowUDF( name=name, func=func, input_types=input_types, return_type=return_type, volatility=volatility, + arguments=arguments, ) diff --git a/src/udaf.rs b/src/udaf.rs index a6aa59ac..b9db47a8 100644 --- a/src/udaf.rs +++ b/src/udaf.rs @@ -128,11 +128,15 @@ impl Accumulator for RustAccumulator { } } -pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction { +pub fn to_rust_accumulator( + accum: PyObject, + arguments: Vec, +) -> AccumulatorFactoryFunction { Arc::new(move |_| -> Result> { let accum = Python::with_gil(|py| { + let py_args = PyTuple::new_bound(py, arguments.iter()); accum - .call0(py) + .call1(py, py_args) .map_err(|e| DataFusionError::Execution(format!("{e}"))) })?; Ok(Box::new(RustAccumulator::new(accum))) @@ -149,7 +153,7 @@ pub struct PyAggregateUDF { #[pymethods] impl PyAggregateUDF { #[new] - #[pyo3(signature=(name, accumulator, input_type, return_type, state_type, volatility))] + #[pyo3(signature=(name, accumulator, input_type, return_type, state_type, volatility, arguments))] fn new( name: &str, accumulator: PyObject, @@ -157,13 +161,14 @@ impl PyAggregateUDF { return_type: PyArrowType, state_type: PyArrowType>, volatility: &str, + arguments: Vec, ) -> PyResult { let function = create_udaf( name, input_type.0, Arc::new(return_type.0), parse_volatility(volatility)?, - to_rust_accumulator(accumulator), + to_rust_accumulator(accumulator, arguments), Arc::new(state_type.0), ); Ok(Self { function }) diff --git a/src/udwf.rs b/src/udwf.rs index 31cc5e60..68ef6620 100644 --- a/src/udwf.rs +++ b/src/udwf.rs @@ -195,9 +195,17 @@ impl PartitionEvaluator for RustPartitionEvaluator { } } -pub fn to_rust_partition_evaluator(evaluator: PyObject) -> PartitionEvaluatorFactory { +pub fn to_rust_partition_evaluator( + evaluator: PyObject, + arguments: Vec, +) -> PartitionEvaluatorFactory { Arc::new(move || -> Result> { - let evaluator = Python::with_gil(|py| evaluator.clone_ref(py)); + let evaluator = Python::with_gil(|py| { + let py_args = PyTuple::new_bound(py, arguments.iter()); + evaluator + .call1(py, py_args) + .map_err(|e| DataFusionError::Execution(e.to_string())) + })?; Ok(Box::new(RustPartitionEvaluator::new(evaluator))) }) } @@ -212,13 +220,14 @@ pub struct PyWindowUDF { #[pymethods] impl PyWindowUDF { #[new] - #[pyo3(signature=(name, evaluator, input_types, return_type, volatility))] + #[pyo3(signature=(name, evaluator, input_types, return_type, volatility, arguments))] fn new( name: &str, evaluator: PyObject, input_types: Vec>, return_type: PyArrowType, volatility: &str, + arguments: Vec, ) -> PyResult { let return_type = return_type.0; let input_types = input_types.into_iter().map(|t| t.0).collect(); @@ -228,7 +237,7 @@ impl PyWindowUDF { input_types, return_type, parse_volatility(volatility)?, - to_rust_partition_evaluator(evaluator), + to_rust_partition_evaluator(evaluator, arguments), )); Ok(Self { function }) }