diff --git a/e2e_test/udf/python.slt b/e2e_test/udf/python.slt index aa9ed2f5fa23..71654ac8ad64 100644 --- a/e2e_test/udf/python.slt +++ b/e2e_test/udf/python.slt @@ -24,14 +24,18 @@ create function gcd(int, int, int) returns int language python as gcd3 using lin statement error exists create function gcd(int, int) returns int language python as gcd using link 'http://localhost:8815'; +# Create a function that returns multiple columns. +statement ok +create function extract_tcp_info(bytea) returns struct<src_ip varchar, dst_ip varchar, src_port smallint, dst_port smallint> +language python as extract_tcp_info using link 'http://localhost:8815'; + # Create a table function. statement ok create function series(int) returns table (x int) language python as series using link 'http://localhost:8815'; # Create a table function that returns multiple columns. statement ok -create function extract_tcp_info(bytea) returns table (src_ip varchar, dst_ip varchar, src_port smallint, dst_port smallint) -language python as extract_tcp_info using link 'http://localhost:8815'; +create function series2(int) returns table (x int, y varchar) language python as series2 using link 'http://localhost:8815'; query I select int_42(); @@ -57,6 +61,20 @@ select series(5); 3 4 +query IT +select * from series2(3); +---- +0 #0 +1 #1 +2 #2 + +query T +select series2(3); +---- +(0,#0) +(1,#1) +(2,#2) + # TODO: support argument implicit cast for UDF # e.g. 
extract_tcp_info(E'\\x45'); @@ -65,11 +83,6 @@ select extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d97 ---- (192.168.0.14,192.168.0.1,861,8374) -query TTII -select * from extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: bytea); ----- -192.168.0.14 192.168.0.1 861 8374 - query TTII select (extract_tcp_info(E'\\x45000034a8a8400040065b8ac0a8000ec0a80001035d20b6d971b900000000080020200493310000020405b4' :: BYTEA)).*; ---- diff --git a/e2e_test/udf/test.py b/e2e_test/udf/test.py index d6e72e4e9c99..3d9cba52d36b 100644 --- a/e2e_test/udf/test.py +++ b/e2e_test/udf/test.py @@ -24,19 +24,25 @@ def gcd3(x: int, y: int, z: int) -> int: return gcd(gcd(x, y), z) +@udf(input_types=['BINARY'], result_type='STRUCT') +def extract_tcp_info(tcp_packet: bytes): + src_addr, dst_addr = struct.unpack('!4s4s', tcp_packet[12:20]) + src_port, dst_port = struct.unpack('!HH', tcp_packet[20:24]) + src_addr = socket.inet_ntoa(src_addr) + dst_addr = socket.inet_ntoa(dst_addr) + return src_addr, dst_addr, src_port, dst_port + + @udtf(input_types='INT', result_types='INT') def series(n: int) -> Iterator[int]: for i in range(n): yield i -@udtf(input_types=['BINARY'], result_types=['VARCHAR', 'VARCHAR', 'SMALLINT', 'SMALLINT']) -def extract_tcp_info(tcp_packet: bytes) -> Iterator: - src_addr, dst_addr = struct.unpack('!4s4s', tcp_packet[12:20]) - src_port, dst_port = struct.unpack('!HH', tcp_packet[20:24]) - src_addr = socket.inet_ntoa(src_addr) - dst_addr = socket.inet_ntoa(dst_addr) - yield src_addr, dst_addr, src_port, dst_port +@udtf(input_types='INT', result_types=['INT', 'VARCHAR']) +def series2(n: int) -> Iterator[tuple[int, str]]: + for i in range(n): + yield i, f'#{i}' if __name__ == '__main__': @@ -45,5 +51,6 @@ def extract_tcp_info(tcp_packet: bytes) -> Iterator: server.add_function(gcd) server.add_function(gcd3) server.add_function(series) + server.add_function(series2) 
server.add_function(extract_tcp_info) server.serve() diff --git a/src/udf/python/example.py b/src/udf/python/example.py index 07884db09054..9b57315b6880 100644 --- a/src/udf/python/example.py +++ b/src/udf/python/example.py @@ -34,13 +34,13 @@ def series2(n: int) -> Iterator[tuple[int, str]]: yield i, str(i) -@udtf(input_types=['BINARY'], result_types=['VARCHAR', 'VARCHAR', 'SMALLINT', 'SMALLINT']) -def extract_tcp_info(tcp_packet: bytes) -> Iterator: +@udf(input_types=['BINARY'], result_type='STRUCT<src_ip VARCHAR, dst_ip VARCHAR, src_port SMALLINT, dst_port SMALLINT>') +def extract_tcp_info(tcp_packet: bytes): src_addr, dst_addr = struct.unpack('!4s4s', tcp_packet[12:20]) src_port, dst_port = struct.unpack('!HH', tcp_packet[20:24]) src_addr = socket.inet_ntoa(src_addr) dst_addr = socket.inet_ntoa(dst_addr) - yield src_addr, dst_addr, src_port, dst_port + return src_addr, dst_addr, src_port, dst_port if __name__ == '__main__': diff --git a/src/udf/python/risingwave/udf.py b/src/udf/python/risingwave/udf.py index cb73b5365228..c2d617d1068c 100644 --- a/src/udf/python/risingwave/udf.py +++ b/src/udf/python/risingwave/udf.py @@ -266,5 +266,15 @@ def _string_to_data_type(type_str: str): return pa.string() case 'BINARY' | 'VARBINARY': return pa.binary() - case _: - raise ValueError(f'Unsupported type: {type_str}') + + # extract 'STRUCT' + if type_str.startswith('STRUCT'): + type_str = type_str[6:].strip('<>') + fields = [] + for field in type_str.split(','): + field = field.strip() + name, type_str = field.split(' ') + fields.append(pa.field(name, _string_to_data_type(type_str))) + return pa.struct(fields) + + raise ValueError(f'Unsupported type: {type_str}')