Skip to content

Commit

Permalink
Add Java bindings for string literal support in AST (#13072)
Browse files Browse the repository at this point in the history
Depends on #13061
Add Java bindings for string scalar support in AST
Add unit test for string comparison - column vs column, column vs literal.

Authors:
  - Karthikeyan (https://github.com/karthikeyann)

Approvers:
  - Jason Lowe (https://github.com/jlowe)
  - MithunR (https://github.com/mythrocks)

URL: #13072
  • Loading branch information
karthikeyann authored Apr 25, 2023
1 parent 7bbd5ee commit 6ebdce9
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 10 deletions.
14 changes: 13 additions & 1 deletion java/src/main/java/ai/rapids/cudf/ast/Literal.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -200,6 +200,18 @@ public static Literal ofDurationFromLong(DType type, Long value) {
return ofDurationFromLong(type, value.longValue());
}

/** Construct a string literal with the specified value or null. */
public static Literal ofString(String value) {
if (value == null) {
return ofNull(DType.STRING);
}
byte[] stringBytes = value.getBytes();
byte[] serializedValue = new byte[stringBytes.length + Integer.BYTES];
ByteBuffer.wrap(serializedValue).order(ByteOrder.nativeOrder()).putInt(stringBytes.length);
System.arraycopy(stringBytes, 0, serializedValue, Integer.BYTES, stringBytes.length);
return new Literal(DType.STRING, serializedValue);
}

Literal(DType type, byte[] serializedValue) {
this.type = type;
this.serializedValue = serializedValue;
Expand Down
45 changes: 37 additions & 8 deletions java/src/main/native/src/CompiledExpression.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,6 +15,7 @@
*/

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

Expand Down Expand Up @@ -56,12 +57,20 @@ class jni_serialized_ast {

/** Read a multi-byte value from the serialized AST data buffer */
template <typename T> T read() {
check_for_eof(sizeof(T));
// use memcpy since data may be misaligned
T result;
memcpy(reinterpret_cast<jbyte *>(&result), data_ptr, sizeof(T));
data_ptr += sizeof(T);
return result;
if constexpr (std::is_same_v<T, std::string>) {
auto const size = read<cudf::size_type>();
check_for_eof(size);
auto const result = std::string(reinterpret_cast<char const *>(data_ptr), size);
data_ptr += size;
return result;
} else {
check_for_eof(sizeof(T));
// use memcpy since data may be misaligned
T result;
memcpy(reinterpret_cast<jbyte *>(&result), data_ptr, sizeof(T));
data_ptr += sizeof(T);
return result;
}
}

/** Decode a libcudf data type from the serialized AST data buffer */
Expand Down Expand Up @@ -254,9 +263,29 @@ struct make_literal {
std::move(scalar_ptr));
}

/** Construct an AST literal from a string value */
template <typename T, std::enable_if_t<std::is_same_v<T, cudf::string_view>> * = nullptr>
cudf::ast::literal &operator()(cudf::data_type dtype, bool is_valid,
cudf::jni::ast::compiled_expr &compiled_expr,
jni_serialized_ast &jni_ast) {
std::unique_ptr<cudf::scalar> scalar_ptr = [&]() {
if (is_valid) {
std::string val = jni_ast.read<std::string>();
return std::make_unique<cudf::string_scalar>(val, is_valid);
} else {
return std::make_unique<cudf::string_scalar>(rmm::device_buffer{}, is_valid);
}
}();

auto &str_scalar = static_cast<cudf::string_scalar &>(*scalar_ptr);
return compiled_expr.add_literal(std::make_unique<cudf::ast::literal>(str_scalar),
std::move(scalar_ptr));
}

/** Default functor implementation to catch type dispatch errors */
template <typename T, std::enable_if_t<!cudf::is_numeric<T>() && !cudf::is_timestamp<T>() &&
!cudf::is_duration<T>()> * = nullptr>
!cudf::is_duration<T>() &&
!std::is_same_v<T, cudf::string_view>> * = nullptr>
cudf::ast::literal &operator()(cudf::data_type dtype, bool is_valid,
cudf::jni::ast::compiled_expr &compiled_expr,
jni_serialized_ast &jni_ast) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -476,6 +476,69 @@ void testBinaryComparisonOperationTransform(BinaryOperator op, Integer[] in1, In
}
}

private static Stream<Arguments> createStringLiteralComparisonParams() {
String[] in1 = new String[] {"a", "bb", null, "ccc", "dddd"};
String in2 = "ccc";
return Stream.of(
// nulls compare as equal by default
Arguments.of(BinaryOperator.NULL_EQUAL, in1, in2, Arrays.asList(false, false, false, true, false)),
Arguments.of(BinaryOperator.NOT_EQUAL, in1, in2, mapArray(in1, (a) -> !a.equals(in2))),
Arguments.of(BinaryOperator.LESS, in1, in2, mapArray(in1, (a) -> a.compareTo(in2) < 0)),
Arguments.of(BinaryOperator.GREATER, in1, in2, mapArray(in1, (a) -> a.compareTo(in2) > 0)),
Arguments.of(BinaryOperator.LESS_EQUAL, in1, in2, mapArray(in1, (a) -> a.compareTo(in2) <= 0)),
Arguments.of(BinaryOperator.GREATER_EQUAL, in1, in2, mapArray(in1, (a) -> a.compareTo(in2) >= 0)),
// null literal
Arguments.of(BinaryOperator.NULL_EQUAL, in1, null, Arrays.asList(false, false, true, false, false)),
Arguments.of(BinaryOperator.NOT_EQUAL, in1, null, Arrays.asList(null, null, null, null, null)),
Arguments.of(BinaryOperator.LESS, in1, null, Arrays.asList(null, null, null, null, null)));
}

@ParameterizedTest
@MethodSource("createStringLiteralComparisonParams")
void testStringLiteralComparison(BinaryOperator op, String[] in1, String in2,
List<Boolean> expectedValues) {
Literal lit = Literal.ofString(in2);
BinaryOperation expr = new BinaryOperation(op,
new ColumnReference(0),
lit);
try (Table t = new Table.TestBuilder().column(in1).build();
CompiledExpression compiledExpr = expr.compile();
ColumnVector actual = compiledExpr.computeColumn(t);
ColumnVector expected = ColumnVector.fromBoxedBooleans(
expectedValues.toArray(new Boolean[0]))) {
assertColumnsAreEqual(expected, actual);
}
}

private static Stream<Arguments> createBinaryComparisonOperationStringParams() {
String[] in1 = new String[] {"a", "bb", null, "ccc", "dddd"};
String[] in2 = new String[] {"aa", "b", null, "ccc", "ddd"};
return Stream.of(
// nulls compare as equal by default
Arguments.of(BinaryOperator.NULL_EQUAL, in1, in2, Arrays.asList(false, false, true, true, false)),
Arguments.of(BinaryOperator.NOT_EQUAL, in1, in2, mapArray(in1, in2, (a, b) -> !a.equals(b))),
Arguments.of(BinaryOperator.LESS, in1, in2, mapArray(in1, in2, (a, b) -> a.compareTo(b) < 0)),
Arguments.of(BinaryOperator.GREATER, in1, in2, mapArray(in1, in2, (a, b) -> a.compareTo(b) > 0)),
Arguments.of(BinaryOperator.LESS_EQUAL, in1, in2, mapArray(in1, in2, (a, b) -> a.compareTo(b) <= 0)),
Arguments.of(BinaryOperator.GREATER_EQUAL, in1, in2, mapArray(in1, in2, (a, b) -> a.compareTo(b) >= 0)));
}

@ParameterizedTest
@MethodSource("createBinaryComparisonOperationStringParams")
void testBinaryComparisonOperationStringTransform(BinaryOperator op, String[] in1, String[] in2,
List<Boolean> expectedValues) {
BinaryOperation expr = new BinaryOperation(op,
new ColumnReference(0),
new ColumnReference(1));
try (Table t = new Table.TestBuilder().column(in1).column(in2).build();
CompiledExpression compiledExpr = expr.compile();
ColumnVector actual = compiledExpr.computeColumn(t);
ColumnVector expected = ColumnVector.fromBoxedBooleans(
expectedValues.toArray(new Boolean[0]))) {
assertColumnsAreEqual(expected, actual);
}
}

private static Stream<Arguments> createBinaryBitwiseOperationParams() {
Integer[] in1 = new Integer[] { -5, 4, null, 2, -3 };
Integer[] in2 = new Integer[] { 123, -456, null, 0, -3 };
Expand Down

0 comments on commit 6ebdce9

Please sign in to comment.