Skip to content

Commit

Permalink
improve UDF docs
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <wangrunji0408@163.com>
  • Loading branch information
wangrunji0408 committed Mar 16, 2023
1 parent 25a4809 commit de71c9e
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 26 deletions.
20 changes: 0 additions & 20 deletions src/udf/README.md

This file was deleted.

75 changes: 75 additions & 0 deletions src/udf/python/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# RisingWave Python API

This library provides a Python API for creating user-defined functions (UDF) in RisingWave.

Currently, RisingWave supports user-defined functions implemented as external functions.
Users need to define functions using the API provided by this library, and then start a Python process as a UDF server.
RisingWave calls the function remotely by accessing the UDF server at a given address.

## Installation

```sh
pip install risingwave
```

## Usage

Define functions in a Python file:

```python
# udf.py
from risingwave.udf import udf, udtf, UdfServer

# Define a scalar function
@udf(input_types=['INT', 'INT'], result_type='INT')
def gcd(x, y):
while y != 0:
(x, y) = (y, x % y)
return x

# Define a table function
@udtf(input_types='INT', result_types='INT')
def series(n):
for i in range(n):
yield i

# Start a UDF server
if __name__ == '__main__':
server = UdfServer(location="0.0.0.0:8815")
server.add_function(gcd)
server.add_function(series)
server.serve()
```

Start the UDF server:

```sh
python3 udf.py
```

To create functions in RisingWave, use the following syntax:

```sql
create function <name> ( <arg_type>[, ...] )
[ returns <ret_type> | returns table ( <column_name> <column_type> [, ...] ) ]
language python as <name_defined_in_server>
using link '<udf_server_address>';
```

- The `language` parameter must be set to `python`.
- The `as` parameter specifies the function name defined in the UDF server.
- The `link` parameter specifies the address of the UDF server.

For example:

```sql
create function gcd(int, int) returns int
language python as gcd using link 'http://localhost:8815';

create function series(int) returns table (x int)
language python as series using link 'http://localhost:8815';

select gcd(25, 15);

select * from series(10);
```
2 changes: 1 addition & 1 deletion src/udf/python/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def series2(n: int) -> Iterator[tuple[int, str]]:


if __name__ == '__main__':
server = UdfServer()
server = UdfServer(location="0.0.0.0:8815")
server.add_function(random_int)
server.add_function(gcd)
server.add_function(gcd3)
Expand Down
46 changes: 41 additions & 5 deletions src/udf/python/risingwave/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,21 @@ def udf(input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataType
result_type: Union[str, pa.DataType],
name: Optional[str] = None,) -> Union[Callable, UserDefinedFunction]:
"""
Annotation for creating a user-defined function.
Annotation for creating a user-defined scalar function.
Parameters:
- input_types: A list of strings or Arrow data types that specifies the input data types.
- result_type: A string or an Arrow data type that specifies the return value type.
- name: An optional string specifying the function name. If not provided, the original name will be used.
Example:
```
@udf(input_types=['INT', 'INT'], result_type='INT')
def gcd(x, y):
while y != 0:
(x, y) = (y, x % y)
return x
```
"""

return lambda f: UserDefinedScalarFunctionWrapper(f, input_types, result_type, name)
Expand All @@ -130,20 +144,42 @@ def udtf(input_types: Union[List[Union[str, pa.DataType]], Union[str, pa.DataTyp
name: Optional[str] = None,) -> Union[Callable, UserDefinedFunction]:
"""
Annotation for creating a user-defined table function.
Parameters:
- input_types: A list of strings or Arrow data types that specifies the input data types.
- result_types A list of strings or Arrow data types that specifies the return value types.
- name: An optional string specifying the function name. If not provided, the original name will be used.
Example:
```
@udtf(input_types='INT', result_types='INT')
def series(n):
for i in range(n):
yield i
```
"""

return lambda f: UserDefinedTableFunctionWrapper(f, input_types, result_types, name)


class UdfServer(pa.flight.FlightServerBase):
"""
UDF server based on Apache Arrow Flight protocol.
Reference: https://arrow.apache.org/cookbook/py/flight.html#simple-parquet-storage-service-with-arrow-flight
A server that provides user-defined functions to clients.
Example:
```
server = UdfServer(location="0.0.0.0:8815")
server.add_function(my_udf)
server.serve()
```
"""
# UDF server based on Apache Arrow Flight protocol.
# Reference: https://arrow.apache.org/cookbook/py/flight.html#simple-parquet-storage-service-with-arrow-flight

_functions: Dict[str, UserDefinedFunction]

def __init__(self, location="grpc://0.0.0.0:8815", **kwargs):
super(UdfServer, self).__init__(location, **kwargs)
def __init__(self, location="0.0.0.0:8815", **kwargs):
super(UdfServer, self).__init__('grpc://' + location, **kwargs)
self._functions = {}

def get_flight_info(self, context, descriptor):
Expand Down
15 changes: 15 additions & 0 deletions src/udf/python/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from setuptools import find_packages, setup

setup(
name="risingwave",
version="0.0.1",
author="RisingWave Labs",
description="RisingWave Python API",
url="https://github.com/risingwavelabs/risingwave",
packages=find_packages(),
classifiers=[
"Programming Language :: Python",
"License :: OSI Approved :: Apache Software License"
],
python_requires=">=3.10",
)

0 comments on commit de71c9e

Please sign in to comment.