diff --git a/src/udf/java/src/main/java/com/risingwave/functions/FunctionWrapper.java b/src/udf/java/src/main/java/com/risingwave/functions/FunctionWrapper.java index 687411ec6b0f0..5577076829955 100644 --- a/src/udf/java/src/main/java/com/risingwave/functions/FunctionWrapper.java +++ b/src/udf/java/src/main/java/com/risingwave/functions/FunctionWrapper.java @@ -11,22 +11,40 @@ import java.util.Iterator; import java.util.function.Function; +/** + * Base class for a batch-processing user-defined function. + */ abstract class UserDefinedFunctionBatch { protected Schema inputSchema; protected Schema outputSchema; protected BufferAllocator allocator; - public Schema getInputSchema() { + /** + * Get the input schema of the function. + */ + Schema getInputSchema() { return inputSchema; } - public Schema getOutputSchema() { + /** + * Get the output schema of the function. + */ + Schema getOutputSchema() { return outputSchema; } + /** + * Evaluate the function by processing a batch of input data. + * + * @param batch the input data batch to process + * @return an iterator over the output data batches + */ abstract Iterator evalBatch(VectorSchemaRoot batch); } +/** + * Batch-processing wrapper over a user-defined scalar function. + */ class ScalarFunctionBatch extends UserDefinedFunctionBatch { ScalarFunction function; Method method; @@ -66,6 +84,9 @@ Iterator evalBatch(VectorSchemaRoot batch) { } +/** + * Batch-processing wrapper over a user-defined table function. + */ class TableFunctionBatch extends UserDefinedFunctionBatch { TableFunction function; Method method; @@ -131,8 +152,14 @@ Iterator evalBatch(VectorSchemaRoot batch) { } } +/** + * Utility class for reflection. + */ class Reflection { - static Method getEvalMethod(Object obj) { + /** + * Get the method named eval. + */ + static Method getEvalMethod(UserDefinedFunction obj) { var methods = new ArrayList(); for (Method method : obj.getClass().getDeclaredMethods()) { if (method.getName().equals("eval")) { diff --git a/src/udf/java/src/main/java/com/risingwave/functions/ScalarFunction.java b/src/udf/java/src/main/java/com/risingwave/functions/ScalarFunction.java index c06bd8743bd4a..69b1f24859f8f 100644 --- a/src/udf/java/src/main/java/com/risingwave/functions/ScalarFunction.java +++ b/src/udf/java/src/main/java/com/risingwave/functions/ScalarFunction.java @@ -1,4 +1,46 @@ package com.risingwave.functions; +/** + * Base class for a user-defined scalar function. A user-defined scalar function + * maps zero, one, or multiple scalar values to a new scalar value. + * + *

+ * The behavior of a {@link ScalarFunction} can be defined by implementing a + * custom evaluation method. An evaluation method must be declared publicly and + * named eval. Multiple overloaded methods named eval + * are not supported yet. + * + *

+ * By default, input and output data types are automatically extracted using + * reflection. + * + *

+ * The following examples show how to specify a scalar function: + * + *

+ * {@code
+ * // a function that accepts two INT arguments and computes a sum
+ * class SumFunction extends ScalarFunction {
+ *     public Integer eval(Integer a, Integer b) {
+ *         return a + b;
+ *     }
+ * }
+ * 
+ * // a function that returns a struct type
+ * class StructFunction extends ScalarFunction {
+ *     public static class KeyValue {
+ *         public String key;
+ *         public int value;
+ *     }
+ * 
+ *     public KeyValue eval(int a) {
+ *         KeyValue kv = new KeyValue();
+ *         kv.key = a.toString();
+ *         kv.value = a;
+ *         return kv;
+ *     }
+ * }
+ * }
+ */ public abstract class ScalarFunction extends UserDefinedFunction { } diff --git a/src/udf/java/src/main/java/com/risingwave/functions/TableFunction.java b/src/udf/java/src/main/java/com/risingwave/functions/TableFunction.java index 8abd714d215ca..a399201d0c261 100644 --- a/src/udf/java/src/main/java/com/risingwave/functions/TableFunction.java +++ b/src/udf/java/src/main/java/com/risingwave/functions/TableFunction.java @@ -3,24 +3,77 @@ import java.util.ArrayList; import java.util.List; +/** + * Base class for a user-defined table function. A user-defined table function + * maps zero, one, or multiple scalar values to zero, one, or multiple rows (or + * structured types). If an output record consists of only one field, the + * structured record can be omitted, and a scalar value can be emitted that will + * be implicitly wrapped into a row by the runtime. + * + *

+ * The behavior of a {@link TableFunction} can be defined by implementing a + * custom evaluation method. An evaluation method must be declared publicly, not + * static, and named eval. Multiple overloaded methods named + * eval are not supported yet. + * + *

+ * By default, input and output data types are automatically extracted using + * reflection. This includes the generic argument {@code T} of the class for + * determining an output data type. Input arguments are derived from one or more + * {@code eval()} methods. + * + *

+ * The following examples show how to specify a table function: + * + *

+ * {@code
+ * // a function that accepts an INT arguments and emits the range from 0 to the
+ * // given number.
+ * class Series extends TableFunction {
+ *     public void eval(int x) {
+ *         for (int i = 0; i < n; i++) {
+ *             collect(i);
+ *         }
+ *     }
+ * }
+ * 
+ * // a function that accepts an String arguments and emits the words of the
+ * // given string.
+ * class Split extends TableFunction {
+ *     public static class Row {
+ *         public String word;
+ *         public int length;
+ *     }
+ * 
+ *     public void eval(String str) {
+ *         for (var s : str.split(" ")) {
+ *             Row row = new Row();
+ *             row.word = s;
+ *             row.length = s.length();
+ *             collect(row);
+ *         }
+ *     }
+ * }
+ * }
+ */ public abstract class TableFunction extends UserDefinedFunction { - // Collector used to emit rows. + /** Collector used to emit rows. */ private transient List rows = new ArrayList<>(); - // Takes all emitted rows. - public final Object[] take() { + /** Takes all emitted rows. */ + final Object[] take() { var result = this.rows.toArray(); this.rows.clear(); return result; } - // Returns the number of emitted rows. - public final int size() { + /** Returns the number of emitted rows. */ + final int size() { return this.rows.size(); } - // Emits an (implicit or explicit) output row. + /** Emits an output row. */ protected final void collect(T row) { this.rows.add(row); } diff --git a/src/udf/java/src/main/java/com/risingwave/functions/TypeUtils.java b/src/udf/java/src/main/java/com/risingwave/functions/TypeUtils.java index e4c7cfbe242ed..40e5ec1e3fc85 100644 --- a/src/udf/java/src/main/java/com/risingwave/functions/TypeUtils.java +++ b/src/udf/java/src/main/java/com/risingwave/functions/TypeUtils.java @@ -16,7 +16,10 @@ import java.util.stream.Collectors; class TypeUtils { - public static Field stringToField(String typeStr) { + /** + * Convert a string to an Arrow type. + */ + static Field stringToField(String typeStr) { typeStr = typeStr.toUpperCase(); if (typeStr.equals("BOOLEAN") || typeStr.equals("BOOL")) { return Field.nullable("", new ArrowType.Bool()); @@ -61,7 +64,14 @@ public static Field stringToField(String typeStr) { } } - public static Field classToField(Class param, String name) { + /** + * Convert a Java class to an Arrow type. + * + * @param param The Java class. + * @param name The name of the field. + * @return The Arrow type. + */ + static Field classToField(Class param, String name) { if (param == Boolean.class || param == boolean.class) { return Field.nullable(name, new ArrowType.Bool()); } else if (param == Short.class || param == short.class) { @@ -95,7 +105,10 @@ public static Field classToField(Class param, String name) { } } - public static Schema methodToInputSchema(Method method) { + /** + * Get the input schema from a Java method. + */ + static Schema methodToInputSchema(Method method) { var fields = new ArrayList(); for (var param : method.getParameters()) { fields.add(classToField(param.getType(), param.getName())); @@ -103,12 +116,18 @@ public static Schema methodToInputSchema(Method method) { return new Schema(fields); } - public static Schema methodToOutputSchema(Method method) { + /** + * Get the output schema of a scalar function from a Java method. + */ + static Schema methodToOutputSchema(Method method) { var type = method.getReturnType(); return new Schema(Arrays.asList(classToField(type, ""))); } - public static Schema tableFunctionToOutputSchema(Class type) { + /** + * Get the output schema of a table function from a Java class. + */ + static Schema tableFunctionToOutputSchema(Class type) { var parameterizedType = (ParameterizedType) type.getGenericSuperclass(); var typeArguments = parameterizedType.getActualTypeArguments(); type = (Class) typeArguments[0]; @@ -117,12 +136,18 @@ public static Schema tableFunctionToOutputSchema(Class type) { return new Schema(Arrays.asList(row_index, classToField(type, ""))); } - public static FieldVector createVector(Field field, BufferAllocator allocator, Object[] values) { + /** + * Create an Arrow vector from an array of values. + */ + static FieldVector createVector(Field field, BufferAllocator allocator, Object[] values) { var vector = field.createVector(allocator); fillVector(vector, values); return vector; } + /** + * Fill an Arrow vector with an array of values. + */ static void fillVector(FieldVector fieldVector, Object[] values) { if (fieldVector instanceof SmallIntVector) { var vector = (SmallIntVector) fieldVector; @@ -218,7 +243,10 @@ static void fillVector(FieldVector fieldVector, Object[] values) { fieldVector.setValueCount(values.length); } - // Returns a function that converts the object to the correct type. + /** + * Return a function that converts the object get from input array to the + * correct type. + */ static Function processFunc(Field field) { if (field.getType() instanceof ArrowType.Utf8) { // object is org.apache.arrow.vector.util.Text diff --git a/src/udf/java/src/main/java/com/risingwave/functions/UdfServer.java b/src/udf/java/src/main/java/com/risingwave/functions/UdfServer.java index 061333479c1e7..0ad1024c1b3fd 100644 --- a/src/udf/java/src/main/java/com/risingwave/functions/UdfServer.java +++ b/src/udf/java/src/main/java/com/risingwave/functions/UdfServer.java @@ -16,6 +16,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +/** + * A server that exposes user-defined functions over Apache Arrow Flight. + */ public class UdfServer implements AutoCloseable { private FlightServer server; @@ -32,24 +35,44 @@ public UdfServer(String host, int port) { this.producer).build(); } + /** + * Add a user-defined function to the server. + * + * @param name the name of the function + * @param udf the function to add + * @throws IllegalArgumentException if a function with the same name already + * exists + */ public void addFunction(String name, UserDefinedFunction udf) throws IllegalArgumentException { logger.info("added function: " + name); this.producer.addFunction(name, udf); } + /** + * Start the server. + */ public void start() throws IOException { this.server.start(); logger.info("listening on " + this.server.getLocation().toSocketAddress()); } + /** + * Get the port the server is listening on. + */ public int getPort() { return this.server.getPort(); } + /** + * Wait for the server to terminate. + */ public void awaitTermination() throws InterruptedException { this.server.awaitTermination(); } + /** + * Close the server. + */ public void close() throws InterruptedException { this.server.close(); } diff --git a/src/udf/java/src/main/java/com/risingwave/functions/UserDefinedFunction.java b/src/udf/java/src/main/java/com/risingwave/functions/UserDefinedFunction.java index dab835d681274..241b6b459935a 100644 --- a/src/udf/java/src/main/java/com/risingwave/functions/UserDefinedFunction.java +++ b/src/udf/java/src/main/java/com/risingwave/functions/UserDefinedFunction.java @@ -1,4 +1,10 @@ package com.risingwave.functions; +/** + * Base class for all user-defined functions. + * + * @see ScalarFunction + * @see TableFunction + */ public abstract class UserDefinedFunction { }