Skip to content

Commit

Permalink
add code docs
Browse files Browse the repository at this point in the history
  • Loading branch information
wangrunji0408 committed Jun 9, 2023
1 parent 5389f16 commit c39adfd
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,40 @@
import java.util.Iterator;
import java.util.function.Function;

/**
* Base class for a batch-processing user-defined function.
*/
abstract class UserDefinedFunctionBatch {
protected Schema inputSchema;
protected Schema outputSchema;
protected BufferAllocator allocator;

public Schema getInputSchema() {
/**
* Get the input schema of the function.
*/
Schema getInputSchema() {
return inputSchema;
}

public Schema getOutputSchema() {
/**
* Get the output schema of the function.
*/
Schema getOutputSchema() {
return outputSchema;
}

/**
* Evaluate the function by processing a batch of input data.
*
* @param batch the input data batch to process
* @return an iterator over the output data batches
*/
abstract Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch);
}

/**
* Batch-processing wrapper over a user-defined scalar function.
*/
class ScalarFunctionBatch extends UserDefinedFunctionBatch {
ScalarFunction function;
Method method;
Expand Down Expand Up @@ -66,6 +84,9 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {

}

/**
* Batch-processing wrapper over a user-defined table function.
*/
class TableFunctionBatch extends UserDefinedFunctionBatch {
TableFunction<?> function;
Method method;
Expand Down Expand Up @@ -131,8 +152,14 @@ Iterator<VectorSchemaRoot> evalBatch(VectorSchemaRoot batch) {
}
}

/**
* Utility class for reflection.
*/
class Reflection {
static Method getEvalMethod(Object obj) {
/**
* Get the method named <code>eval</code>.
*/
static Method getEvalMethod(UserDefinedFunction obj) {
var methods = new ArrayList<Method>();
for (Method method : obj.getClass().getDeclaredMethods()) {
if (method.getName().equals("eval")) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,46 @@
package com.risingwave.functions;

/**
* Base class for a user-defined scalar function. A user-defined scalar function
* maps zero, one, or multiple scalar values to a new scalar value.
*
* <p>
* The behavior of a {@link ScalarFunction} can be defined by implementing a
* custom evaluation method. An evaluation method must be declared publicly and
* named <code>eval</code>. Multiple overloaded methods named <code>eval</code>
* are not supported yet.
*
* <p>
* By default, input and output data types are automatically extracted using
* reflection.
*
* <p>
* The following examples show how to specify a scalar function:
*
* <pre>
* {@code
* // a function that accepts two INT arguments and computes a sum
* class SumFunction extends ScalarFunction {
* public Integer eval(Integer a, Integer b) {
* return a + b;
* }
* }
*
* // a function that returns a struct type
* class StructFunction extends ScalarFunction {
* public static class KeyValue {
* public String key;
* public int value;
* }
*
* public KeyValue eval(int a) {
* KeyValue kv = new KeyValue();
* kv.key = a.toString();
* kv.value = a;
* return kv;
* }
* }
* }</pre>
*/
public abstract class ScalarFunction extends UserDefinedFunction {
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,59 @@
import java.util.ArrayList;
import java.util.List;

/**
* Base class for a user-defined table function. A user-defined table function
* maps zero, one, or multiple scalar values to zero, one, or multiple rows (or
* structured types). If an output record consists of only one field, the
* structured record can be omitted, and a scalar value can be emitted that will
* be implicitly wrapped into a row by the runtime.
*
* <p>
* The behavior of a {@link TableFunction} can be defined by implementing a
* custom evaluation method. An evaluation method must be declared publicly, not
* static, and named <code>eval</code>. Multiple overloaded methods named
* <code>eval</code> are not supported yet.
*
* <p>
* By default, input and output data types are automatically extracted using
* reflection. This includes the generic argument {@code T} of the class for
* determining an output data type. Input arguments are derived from one or more
* {@code eval()} methods.
*
* <p>
* The following examples show how to specify a table function:
*
* <pre>
* {@code
* // a function that accepts an INT arguments and emits the range from 0 to the
* // given number.
* class Series extends TableFunction<Integer> {
* public void eval(int x) {
* for (int i = 0; i < n; i++) {
* collect(i);
* }
* }
* }
*
* // a function that accepts an String arguments and emits the words of the
* // given string.
* class Split extends TableFunction<Split.Row> {
* public static class Row {
* public String word;
* public int length;
* }
*
* public void eval(String str) {
* for (var s : str.split(" ")) {
* Row row = new Row();
* row.word = s;
* row.length = s.length();
* collect(row);
* }
* }
* }
* }</pre>
*/
public abstract class TableFunction<T> extends UserDefinedFunction {

// Collector used to emit rows.
Expand Down
42 changes: 35 additions & 7 deletions src/udf/java/src/main/java/com/risingwave/functions/TypeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
import java.util.stream.Collectors;

class TypeUtils {
public static Field stringToField(String typeStr) {
/**
* Convert a string to an Arrow type.
*/
static Field stringToField(String typeStr) {
typeStr = typeStr.toUpperCase();
if (typeStr.equals("BOOLEAN") || typeStr.equals("BOOL")) {
return Field.nullable("", new ArrowType.Bool());
Expand Down Expand Up @@ -61,7 +64,14 @@ public static Field stringToField(String typeStr) {
}
}

public static Field classToField(Class<?> param, String name) {
/**
* Convert a Java class to an Arrow type.
*
* @param param The Java class.
* @param name The name of the field.
* @return The Arrow type.
*/
static Field classToField(Class<?> param, String name) {
if (param == Boolean.class || param == boolean.class) {
return Field.nullable(name, new ArrowType.Bool());
} else if (param == Short.class || param == short.class) {
Expand Down Expand Up @@ -95,20 +105,29 @@ public static Field classToField(Class<?> param, String name) {
}
}

public static Schema methodToInputSchema(Method method) {
/**
* Get the input schema from a Java method.
*/
static Schema methodToInputSchema(Method method) {
var fields = new ArrayList<Field>();
for (var param : method.getParameters()) {
fields.add(classToField(param.getType(), param.getName()));
}
return new Schema(fields);
}

public static Schema methodToOutputSchema(Method method) {
/**
* Get the output schema of a scalar function from a Java method.
*/
static Schema methodToOutputSchema(Method method) {
var type = method.getReturnType();
return new Schema(Arrays.asList(classToField(type, "")));
}

public static Schema tableFunctionToOutputSchema(Class<?> type) {
/**
* Get the output schema of a table function from a Java class.
*/
static Schema tableFunctionToOutputSchema(Class<?> type) {
var parameterizedType = (ParameterizedType) type.getGenericSuperclass();
var typeArguments = parameterizedType.getActualTypeArguments();
type = (Class<?>) typeArguments[0];
Expand All @@ -117,12 +136,18 @@ public static Schema tableFunctionToOutputSchema(Class<?> type) {
return new Schema(Arrays.asList(row_index, classToField(type, "")));
}

public static FieldVector createVector(Field field, BufferAllocator allocator, Object[] values) {
/**
* Create an Arrow vector from an array of values.
*/
static FieldVector createVector(Field field, BufferAllocator allocator, Object[] values) {
var vector = field.createVector(allocator);
fillVector(vector, values);
return vector;
}

/**
* Fill an Arrow vector with an array of values.
*/
static void fillVector(FieldVector fieldVector, Object[] values) {
if (fieldVector instanceof SmallIntVector) {
var vector = (SmallIntVector) fieldVector;
Expand Down Expand Up @@ -218,7 +243,10 @@ static void fillVector(FieldVector fieldVector, Object[] values) {
fieldVector.setValueCount(values.length);
}

// Returns a function that converts the object to the correct type.
/**
* Return a function that converts the object get from input array to the
* correct type.
*/
static Function<Object, Object> processFunc(Field field) {
if (field.getType() instanceof ArrowType.Utf8) {
// object is org.apache.arrow.vector.util.Text
Expand Down
23 changes: 23 additions & 0 deletions src/udf/java/src/main/java/com/risingwave/functions/UdfServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A server that exposes user-defined functions over Apache Arrow Flight.
*/
public class UdfServer implements AutoCloseable {

private FlightServer server;
Expand All @@ -32,24 +35,44 @@ public UdfServer(String host, int port) {
this.producer).build();
}

/**
* Add a user-defined function to the server.
*
* @param name the name of the function
* @param udf the function to add
* @throws IllegalArgumentException if a function with the same name already
* exists
*/
public void addFunction(String name, UserDefinedFunction udf) throws IllegalArgumentException {
logger.info("added function: " + name);
this.producer.addFunction(name, udf);
}

/**
* Start the server.
*/
public void start() throws IOException {
this.server.start();
logger.info("listening on " + this.server.getLocation().toSocketAddress());
}

/**
* Get the port the server is listening on.
*/
public int getPort() {
return this.server.getPort();
}

/**
* Wait for the server to terminate.
*/
public void awaitTermination() throws InterruptedException {
this.server.awaitTermination();
}

/**
* Close the server.
*/
public void close() throws InterruptedException {
this.server.close();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
package com.risingwave.functions;

/**
* Base class for all user-defined functions.
*
* @see ScalarFunction
* @see TableFunction
*/
public abstract class UserDefinedFunction {
}

0 comments on commit c39adfd

Please sign in to comment.