Skip to content

Commit

Permalink
Updating documentation per PR review
Browse files Browse the repository at this point in the history
  • Loading branch information
timsaucer committed Sep 29, 2024
1 parent 3dfa330 commit 42092f7
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 70 deletions.
17 changes: 7 additions & 10 deletions docs/source/user-guide/common-operations/udf-and-udfa.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
.. specific language governing permissions and limitations
.. under the License.
User Defined Functions
User-Defined Functions
======================

DataFusion provides powerful expressions and functions, reducing the need for custom Python
Expand All @@ -24,15 +24,15 @@ functions. However you can still incorporate your own functions, i.e. User-Defin
Scalar Functions
----------------

When writing a user defined function that can operate on a row by row basis, these are called Scalar
When writing a user-defined function that can operate on a row by row basis, these are called Scalar
Functions. You can define your own scalar function by calling
:py:func:`~datafusion.udf.ScalarUDF.udf` .

The basic definition of a scalar UDF is a python function that takes one or more
`pyarrow <https://arrow.apache.org/docs/python/index.html>`_ arrays and returns a single array as
output. DataFusion scalar UDFs operate on an entire batch of record at a time, though the evaluation
of those records should be on a row by row basis. In the following example, we compute if the input
array contains null values.
output. DataFusion scalar UDFs operate on an entire batch of records at a time, though the
evaluation of those records should be on a row by row basis. In the following example, we compute
if the input array contains null values.

.. ipython:: python
Expand Down Expand Up @@ -76,10 +76,7 @@ converting to Python objects to do the evaluation.
from datafusion import udf, col
def is_null(array: pyarrow.Array) -> pyarrow.Array:
results = []
for value in array:
results.append(value.as_py() == None)
return pyarrow.array(results)
return pyarrow.array([value.as_py() is None for value in array])
is_null_arr = udf(is_null, [pyarrow.int64()], pyarrow.bool_(), 'stable')
Expand Down Expand Up @@ -169,7 +166,7 @@ There are three methods of evaluation of UDWFs.

Which methods you implement are based upon which of these options are set.

.. list-table:: Title
.. list-table::
:header-rows: 1

* - ``uses_window_frame``
Expand Down
2 changes: 1 addition & 1 deletion examples/python-udwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def supports_bounded_execution(self) -> bool:
return True

def get_range(self, idx: int, num_rows: int) -> tuple[int, int]:
# Ovrerride the default range of current row since uses_window_frame is False
# Override the default range of current row since uses_window_frame is False
# So for the purpose of this test we just smooth from the previous row to
# current.
if idx == 0:
Expand Down
2 changes: 1 addition & 1 deletion python/datafusion/tests/test_udwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def supports_bounded_execution(self) -> bool:
return True

def get_range(self, idx: int, num_rows: int) -> tuple[int, int]:
# Ovrerride the default range of current row since uses_window_frame is False
# Override the default range of current row since uses_window_frame is False
# So for the purpose of this test we just smooth from the previous row to
# current.
if idx == 0:
Expand Down
124 changes: 66 additions & 58 deletions python/datafusion/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

"""Provides the user defined functions for evaluation of dataframes."""
"""Provides the user-defined functions for evaluation of dataframes."""

from __future__ import annotations

Expand Down Expand Up @@ -76,7 +76,7 @@ def __str__(self):


class ScalarUDF:
"""Class for performing scalar user defined functions (UDF).
"""Class for performing scalar user-defined functions (UDF).
Scalar UDFs operate on a row by row basis. See also :py:class:`AggregateUDF` for
operating on a group of rows.
Expand All @@ -90,7 +90,7 @@ def __init__(
return_type: _R,
volatility: Volatility | str,
) -> None:
"""Instantiate a scalar user defined function (UDF).
"""Instantiate a scalar user-defined function (UDF).
See helper method :py:func:`udf` for argument details.
"""
Expand All @@ -115,7 +115,7 @@ def udf(
volatility: Volatility | str,
name: str | None = None,
) -> ScalarUDF:
"""Create a new User Defined Function.
"""Create a new User-Defined Function.
Args:
func: A callable python function.
Expand All @@ -127,7 +127,7 @@ def udf(
name: A descriptive name for the function.
Returns:
A user defined aggregate function, which can be used in either data
A user-defined aggregate function, which can be used in either data
aggregation or window function calls.
"""
if not callable(func):
Expand Down Expand Up @@ -172,7 +172,7 @@ def evaluate(self) -> pyarrow.Scalar:


class AggregateUDF:
"""Class for performing scalar user defined functions (UDF).
"""Class for performing scalar user-defined functions (UDF).
Aggregate UDFs operate on a group of rows and return a single value. See
also :py:class:`ScalarUDF` for operating on a row by row basis.
Expand All @@ -187,7 +187,7 @@ def __init__(
state_type: list[pyarrow.DataType],
volatility: Volatility | str,
) -> None:
"""Instantiate a user defined aggregate function (UDAF).
"""Instantiate a user-defined aggregate function (UDAF).
See :py:func:`udaf` for a convenience function and argument
descriptions.
Expand All @@ -214,7 +214,7 @@ def udaf(
volatility: Volatility | str,
name: str | None = None,
) -> AggregateUDF:
"""Create a new User Defined Aggregate Function.
"""Create a new User-Defined Aggregate Function.
The accumulator function must be callable and implement :py:class:`Accumulator`.
Expand All @@ -227,7 +227,7 @@ def udaf(
name: A descriptive name for the function.
Returns:
A user defined aggregate function, which can be used in either data
A user-defined aggregate function, which can be used in either data
aggregation or window function calls.
"""
if not issubclass(accum, Accumulator):
Expand All @@ -249,16 +249,21 @@ def udaf(


class WindowEvaluator(metaclass=ABCMeta):
"""Evaluator class for user defined window functions (UDWF).
"""Evaluator class for user-defined window functions (UDWF).
It is up to the user to decide which evaluate function is appropriate.
|``uses_window_frame``|``supports_bounded_execution``|``include_rank``|function_to_implement|
|---|---|----|----|
|False (default) |False (default) |False (default) | ``evaluate_all`` |
|False |True |False | ``evaluate`` |
|False |True/False |True | ``evaluate_all_with_rank`` |
|True |True/False |True/False | ``evaluate`` |
+------------------------+--------------------------------+------------------+---------------------------+
| ``uses_window_frame`` | ``supports_bounded_execution`` | ``include_rank`` | function_to_implement |
+========================+================================+==================+===========================+
| False (default) | False (default) | False (default) | ``evaluate_all`` |
+------------------------+--------------------------------+------------------+---------------------------+
| False | True | False | ``evaluate`` |
+------------------------+--------------------------------+------------------+---------------------------+
| False | True/False | True | ``evaluate_all_with_rank``|
+------------------------+--------------------------------+------------------+---------------------------+
| True | True/False | True/False | ``evaluate`` |
+------------------------+--------------------------------+------------------+---------------------------+
""" # noqa: W505

def memoize(self) -> None:
Expand Down Expand Up @@ -299,41 +304,43 @@ def is_causal(self) -> bool:
def evaluate_all(self, values: list[pyarrow.Array], num_rows: int) -> pyarrow.Array:
"""Evaluate a window function on an entire input partition.
This function is called once per input *partition* for window
functions that *do not use* values from the window frame,
such as `ROW_NUMBER`, `RANK`, `DENSE_RANK`, `PERCENT_RANK`,
`CUME_DIST`, `LEAD`, `LAG`).
This function is called once per input *partition* for window functions that
*do not use* values from the window frame, such as
:py:func:`~datafusion.functions.row_number`, :py:func:`~datafusion.functions.rank`,
:py:func:`~datafusion.functions.dense_rank`, :py:func:`~datafusion.functions.percent_rank`,
:py:func:`~datafusion.functions.cume_dist`, :py:func:`~datafusion.functions.lead`,
and :py:func:`~datafusion.functions.lag`.
It produces the result of all rows in a single pass. It
expects to receive the entire partition as the `value` and
expects to receive the entire partition as the ``value`` and
must produce an output column with one output row for every
input row.
`num_rows` is required to correctly compute the output in case
`values.len() == 0`
``num_rows`` is required to correctly compute the output in case
``len(values) == 0``
Implementing this function is an optimization: certain window
Implementing this function is an optimization. Certain window
functions are not affected by the window frame definition or
the query doesn't have a frame, and `evaluate` skips the
the query doesn't have a frame, and ``evaluate`` skips the
(costly) window frame boundary calculation and the overhead of
calling `evaluate` for each output row.
calling ``evaluate`` for each output row.
For example, the `LAG` built in window function does not use
the values of its window frame (it can be computed in one shot
on the entire partition with `Self::evaluate_all` regardless of the
window defined in the `OVER` clause)
on the entire partition with ``Self::evaluate_all`` regardless of the
window defined in the ``OVER`` clause)
```sql
lag(x, 1) OVER (ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
```
.. code-block:: text
However, `avg()` computes the average in the window and thus
does use its window frame
lag(x, 1) OVER (ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
```sql
avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
```
"""
However, ``avg()`` computes the average in the window and thus
does use its window frame.
.. code-block:: text
avg(x) OVER (PARTITION BY y ORDER BY z ROWS BETWEEN 2 PRECEDING AND 3 FOLLOWING)
""" # noqa: W505
pass

def evaluate(
Expand Down Expand Up @@ -361,27 +368,28 @@ def evaluate_all_with_rank(
"""Called for window functions that only need the rank of a row.
Evaluate the partition evaluator against the partition using
the row ranks. For example, `RANK(col)` produces
```text
col | rank
--- + ----
A | 1
A | 1
C | 3
D | 4
D | 5
```
the row ranks. For example, ``rank(col("a"))`` produces
.. code-block:: text
a | rank
- + ----
A | 1
A | 1
C | 3
D | 4
D | 4
For this case, `num_rows` would be `5` and the
`ranks_in_partition` would be called with
```text
[
(0,1),
(2,2),
(3,4),
]
.. code-block:: text
[
(0,1),
(2,2),
(3,4),
]
The user must implement this method if ``include_rank`` returns True.
"""
Expand All @@ -405,7 +413,7 @@ def include_rank(self) -> bool:


class WindowUDF:
"""Class for performing window user defined functions (UDF).
"""Class for performing window user-defined functions (UDF).
Window UDFs operate on a partition of rows. See
also :py:class:`ScalarUDF` for operating on a row by row basis.
Expand All @@ -419,7 +427,7 @@ def __init__(
return_type: pyarrow.DataType,
volatility: Volatility | str,
) -> None:
"""Instantiate a user defined window function (UDWF).
"""Instantiate a user-defined window function (UDWF).
See :py:func:`udwf` for a convenience function and argument
descriptions.
Expand All @@ -445,7 +453,7 @@ def udwf(
volatility: Volatility | str,
name: str | None = None,
) -> WindowUDF:
"""Create a new User Defined Window Function.
"""Create a new User-Defined Window Function.
Args:
func: The python function.
Expand All @@ -455,7 +463,7 @@ def udwf(
name: A descriptive name for the function.
Returns:
A user defined window function.
A user-defined window function.
"""
if not isinstance(func, WindowEvaluator):
raise TypeError(
Expand Down

0 comments on commit 42092f7

Please sign in to comment.