Skip to content

Commit

Permalink
support returning struct type for scalar functions
Browse files Browse the repository at this point in the history
Signed-off-by: Runji Wang <wangrunji0408@163.com>
  • Loading branch information
wangrunji0408 committed Mar 21, 2023
1 parent b374e19 commit c26fcf2
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 19 deletions.
27 changes: 20 additions & 7 deletions e2e_test/udf/python.slt
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,18 @@ create function gcd(int, int, int) returns int language python as gcd3 using lin
statement error exists
create function gcd(int, int) returns int language python as gcd using link 'http://localhost:8815';

# Create a function that returns multiple columns.
statement ok
create function extract_tcp_info(bytea) returns struct<src_ip varchar, dst_ip varchar, src_port smallint, dst_port smallint>
language python as extract_tcp_info using link 'http://localhost:8815';

# Create a table function.
statement ok
create function series(int) returns table (x int) language python as series using link 'http://localhost:8815';

# Create a table function that returns multiple columns.
statement ok
create function extract_tcp_info(bytea) returns table (src_ip varchar, dst_ip varchar, src_port smallint, dst_port smallint)
language python as extract_tcp_info using link 'http://localhost:8815';
create function series2(int) returns table (x int, y varchar) language python as series2 using link 'http://localhost:8815';

query I
select int_42();
Expand All @@ -57,6 +61,20 @@ select series(5);
3
4

query IT
select * from series2(3);
----
0 #0
1 #1
2 #2

query T
select series2(3);
----
(0,#0)
(1,#1)
(2,#2)

# TODO: support argument implicit cast for UDF
# e.g. extract_tcp_info(E'\\x45');

Expand All @@ -65,11 +83,6 @@ select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d97
----
(192.168.0.14,192.168.0.1,861,8374)

query TTII
select * from extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: bytea);
----
192.168.0.14 192.168.0.1 861 8374

query TTII
select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: BYTEA)).*;
----
Expand Down
21 changes: 14 additions & 7 deletions e2e_test/udf/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,25 @@ def gcd3(x: int, y: int, z: int) -> int:
return gcd(gcd(x, y), z)


@udf(input_types=['BINARY'], result_type='STRUCT<src_ip VARCHAR, dst_ip VARCHAR, src_port SMALLINT, dst_port SMALLINT>')
def extract_tcp_info(tcp_packet: bytes):
src_addr, dst_addr = struct.unpack('!4s4s', tcp_packet[12:20])
src_port, dst_port = struct.unpack('!HH', tcp_packet[20:24])
src_addr = socket.inet_ntoa(src_addr)
dst_addr = socket.inet_ntoa(dst_addr)
return src_addr, dst_addr, src_port, dst_port


@udtf(input_types='INT', result_types='INT')
def series(n: int) -> Iterator[int]:
for i in range(n):
yield i


@udtf(input_types=['BINARY'], result_types=['VARCHAR', 'VARCHAR', 'SMALLINT', 'SMALLINT'])
def extract_tcp_info(tcp_packet: bytes) -> Iterator:
src_addr, dst_addr = struct.unpack('!4s4s', tcp_packet[12:20])
src_port, dst_port = struct.unpack('!HH', tcp_packet[20:24])
src_addr = socket.inet_ntoa(src_addr)
dst_addr = socket.inet_ntoa(dst_addr)
yield src_addr, dst_addr, src_port, dst_port
@udtf(input_types='INT', result_types=['INT', 'VARCHAR'])
def series2(n: int) -> Iterator[tuple[int, str]]:
for i in range(n):
yield i, f'#{i}'


if __name__ == '__main__':
Expand All @@ -45,5 +51,6 @@ def extract_tcp_info(tcp_packet: bytes) -> Iterator:
server.add_function(gcd)
server.add_function(gcd3)
server.add_function(series)
server.add_function(series2)
server.add_function(extract_tcp_info)
server.serve()
6 changes: 3 additions & 3 deletions src/udf/python/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ def series2(n: int) -> Iterator[tuple[int, str]]:
yield i, str(i)


@udtf(input_types=['BINARY'], result_types=['VARCHAR', 'VARCHAR', 'SMALLINT', 'SMALLINT'])
def extract_tcp_info(tcp_packet: bytes) -> Iterator:
@udf(input_types=['BINARY'], result_type='STRUCT<src_ip VARCHAR, dst_ip VARCHAR, src_port SMALLINT, dst_port SMALLINT>')
def extract_tcp_info(tcp_packet: bytes):
src_addr, dst_addr = struct.unpack('!4s4s', tcp_packet[12:20])
src_port, dst_port = struct.unpack('!HH', tcp_packet[20:24])
src_addr = socket.inet_ntoa(src_addr)
dst_addr = socket.inet_ntoa(dst_addr)
yield src_addr, dst_addr, src_port, dst_port
return src_addr, dst_addr, src_port, dst_port


if __name__ == '__main__':
Expand Down
14 changes: 12 additions & 2 deletions src/udf/python/risingwave/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,5 +266,15 @@ def _string_to_data_type(type_str: str):
return pa.string()
case 'BINARY' | 'VARBINARY':
return pa.binary()
case _:
raise ValueError(f'Unsupported type: {type_str}')

# extract 'STRUCT<a INT, b VARCHAR, ...>'
if type_str.startswith('STRUCT'):
type_str = type_str[6:].strip('<>')
fields = []
for field in type_str.split(','):
field = field.strip()
name, type_str = field.split(' ')
fields.append(pa.field(name, _string_to_data_type(type_str)))
return pa.struct(fields)

raise ValueError(f'Unsupported type: {type_str}')

0 comments on commit c26fcf2

Please sign in to comment.