[OPPRO-31] Use Arrow C ABI to interchange data between C++ and Java (o…
zhztheplayer authored May 17, 2022
Parent: e789913 | Commit: 1bfbfce
Showing 31 changed files with 436 additions and 267 deletions.
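The heart of the change: instead of serializing every Arrow record batch into a byte[] and copying it across the JNI boundary (the old UnsafeRecordBatchSerializer path), the two runtimes now hand each other a pair of Arrow C Data Interface structs, ArrowSchema and ArrowArray, identified only by raw memory address, so the underlying buffers are shared rather than copied. The following standalone sketch is not part of the commit; it uses only the stock org.apache.arrow.c API that the changed files build on, and shows the export/import round trip at the plain-Arrow level:

```java
// Demo of the Arrow C Data Interface round trip (assumes the arrow-c-data
// and arrow-vector artifacts are on the classpath).
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;

import java.util.List;

public class CAbiRoundTrip {
  public static void main(String[] args) {
    try (BufferAllocator allocator = new RootAllocator()) {
      // Build a one-column batch to play the role of a ColumnarBatch.
      IntVector vector = new IntVector("id", allocator);
      vector.setSafe(0, 1);
      vector.setSafe(1, 2);
      vector.setValueCount(2);
      List<FieldVector> vectors = List.of(vector);
      try (VectorSchemaRoot root = new VectorSchemaRoot(vectors);
           // The two ABI structs; their memoryAddress() values are what the
           // JNI methods in this commit pass around as plain longs.
           ArrowSchema cSchema = ArrowSchema.allocateNew(allocator);
           ArrowArray cArray = ArrowArray.allocateNew(allocator)) {
        // Producer side: publish schema and buffer pointers into the structs.
        Data.exportVectorSchemaRoot(allocator, root, null, cArray, cSchema);
        // Consumer side: rebuild Java vectors over the same native buffers.
        // In the commit one side is C++; no batch bytes pass through the
        // JVM heap in either direction.
        try (VectorSchemaRoot imported =
                 Data.importVectorSchemaRoot(allocator, cArray, cSchema, null)) {
          System.out.println(imported.contentToTSVString());
        }
      }
    }
  }
}
```

The diffs below apply this pattern in three places: input iterators that feed batches from Java into the native engine, output iterators that surface native result batches to Spark, and the shuffle decompression JNI wrapper.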
ColumnarNativeIterator.java (ClickHouse backend):

```diff
@@ -18,21 +18,22 @@
 package io.glutenproject.execution;
 
 import io.glutenproject.vectorized.CHColumnVector;
+import io.glutenproject.vectorized.GeneralInIterator;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
 import java.util.Iterator;
 
-public class ColumnarNativeIterator extends AbstractColumnarNativeIterator {
+public class ColumnarNativeIterator extends GeneralInIterator implements Iterator<byte[]> {
 
   public ColumnarNativeIterator(Iterator<ColumnarBatch> delegated) {
     super(delegated);
   }
 
   @Override
   public byte[] next() {
-    ColumnarBatch dep_cb = nextBatch;
-    if (dep_cb.numRows() > 0) {
-      CHColumnVector col = (CHColumnVector) dep_cb.column(0);
+    ColumnarBatch nextBatch = nextColumnarBatch();
+    if (nextBatch.numRows() > 0) {
+      CHColumnVector col = (CHColumnVector) nextBatch.column(0);
       return longtoBytes(col.getBlockAddress());
     } else {
       throw new IllegalStateException();
```
BatchIterator.java (ClickHouse backend):

```diff
@@ -23,27 +23,25 @@
 
 import java.io.IOException;
 
-public class BatchIterator extends AbstractBatchIterator {
-  private native boolean nativeHasNext(long nativeHandler);
-  private native byte[] nativeNext(long nativeHandler);
-  private native long nativeCHNext(long nativeHandler);
-  private native void nativeClose(long nativeHandler);
-  private native MetricsObject nativeFetchMetrics(long nativeHandler);
-
-  public BatchIterator() throws IOException {}
+public class BatchIterator extends GeneralOutIterator {
+  private native boolean nativeHasNext(long nativeHandle);
+  private native byte[] nativeNext(long nativeHandle);
+  private native long nativeCHNext(long nativeHandle);
+  private native void nativeClose(long nativeHandle);
+  private native MetricsObject nativeFetchMetrics(long nativeHandle);
 
   public BatchIterator(long instance_id) throws IOException {
     super(instance_id);
   }
 
   @Override
   public boolean hasNextInternal() throws IOException {
-    return nativeHasNext(nativeHandler);
+    return nativeHasNext(handle);
   }
 
   @Override
   public ColumnarBatch nextInternal() throws IOException {
-    long block = nativeCHNext(nativeHandler);
+    long block = nativeCHNext(handle);
     CHNativeBlock nativeBlock = new CHNativeBlock(block);
     int cols = nativeBlock.numColumns();
     ColumnVector[] columnVectors = new ColumnVector[cols];
@@ -55,14 +53,11 @@ public ColumnarBatch nextInternal() throws IOException {
 
   @Override
   public MetricsObject getMetricsInternal() throws IOException, ClassNotFoundException {
-    return nativeFetchMetrics(nativeHandler);
+    return nativeFetchMetrics(handle);
   }
 
   @Override
   public void closeInternal() {
-    if (!closed) {
-      nativeClose(nativeHandler);
-      closed = true;
-    }
+    nativeClose(handle);
   }
 }
```
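GeneralOutIterator itself is added elsewhere in this commit and does not appear on this page. Below is a plausible minimal shape, inferred from how BatchIterator uses it above (a protected handle passed via the constructor, *Internal template methods, and the close-once guard hoisted out of closeInternal()); the name and details here are assumptions, not the committed code:

```java
import org.apache.spark.sql.vectorized.ColumnarBatch;

import java.io.IOException;
import java.io.Serializable;
import java.util.concurrent.atomic.AtomicBoolean;

// Sketch only; the real GeneralOutIterator may differ (for one, the metrics
// accessor is elided here).
public abstract class GeneralOutIteratorSketch implements AutoCloseable, Serializable {
  protected final long handle;  // id of the native iterator instance
  private final AtomicBoolean closed = new AtomicBoolean(false);

  protected GeneralOutIteratorSketch(long handle) {
    this.handle = handle;
  }

  public final boolean hasNext() throws IOException {
    return hasNextInternal();
  }

  public final ColumnarBatch next() throws IOException {
    return nextInternal();
  }

  @Override
  public final void close() {
    // Runs closeInternal() at most once, which is why the subclasses on
    // this page no longer keep their own 'closed' flag.
    if (closed.compareAndSet(false, true)) {
      closeInternal();
    }
  }

  protected abstract boolean hasNextInternal() throws IOException;
  protected abstract ColumnarBatch nextInternal() throws IOException;
  protected abstract void closeInternal();
}
```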
CHIteratorApi.scala:

```diff
@@ -24,7 +24,6 @@ import io.glutenproject.execution._
 import io.glutenproject.substrait.plan.PlanNode
 import io.glutenproject.substrait.rel.{ExtensionTableBuilder, LocalFilesBuilder}
 import io.glutenproject.vectorized.{ExpressionEvaluatorJniWrapper, _}
-import java.util
 import org.apache.spark.{InterruptibleIterator, SparkConf, TaskContext}
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -157,11 +156,11 @@ class CHIteratorApi extends IIteratorApi {
       outputAttributes: Seq[Attribute],
       context: TaskContext,
       jarList: Seq[String]): Iterator[ColumnarBatch] = {
-    var resIter : AbstractBatchIterator = null
+    var resIter : GeneralOutIterator = null
     if (loadNative) {
       // TODO: 'jarList' is kept for codegen
       val transKernel = new ExpressionEvaluator(jarList.asJava)
-      val inBatchIters = new java.util.ArrayList[AbstractColumnarNativeIterator]()
+      val inBatchIters = new java.util.ArrayList[GeneralInIterator]()
       resIter = transKernel.createKernelWithBatchIterator(
         inputPartition.substraitPlan, inBatchIters)
       TaskContext.get().addTaskCompletionListener[Unit] { _ => resIter.close() }
@@ -206,14 +205,14 @@ class CHIteratorApi extends IIteratorApi {
       pipelineTime: SQLMetric,
       buildRelationBatchHolder: Seq[ColumnarBatch],
       dependentKernels: Seq[ExpressionEvaluator],
-      dependentKernelIterators: Seq[AbstractBatchIterator]
+      dependentKernelIterators: Seq[GeneralOutIterator]
   ): Iterator[ColumnarBatch] = {
     var build_elapse: Long = 0
     var eval_elapse: Long = 0
     GlutenConfig.getConf
     val transKernel = new ExpressionEvaluator()
     val columnarNativeIterator =
-      new util.ArrayList[AbstractColumnarNativeIterator](inputIterators.map { iter =>
+      new java.util.ArrayList[GeneralInIterator](inputIterators.map { iter =>
         new ColumnarNativeIterator(iter.asJava)
       }.asJava)
     // we need to complete dependency RDD's firstly
@@ -271,9 +270,9 @@ class CHIteratorApi extends IIteratorApi {
    * @return
    */
   override def genBatchIterator(wsPlan: Array[Byte],
-                                iterList: Seq[AbstractColumnarNativeIterator],
+                                iterList: Seq[GeneralInIterator],
                                 jniWrapper: ExpressionEvaluatorJniWrapper
-                               ): AbstractBatchIterator = {
+                               ): GeneralOutIterator = {
     val batchIteratorInstance = jniWrapper.nativeCreateKernelWithIterator(
       0L, wsPlan, iterList.toArray);
     new BatchIterator(batchIteratorInstance)
```
JNI test (package io.glutenproject.jni):

```diff
@@ -1,6 +1,6 @@
 package io.glutenproject.jni;
 
-import io.glutenproject.execution.AbstractColumnarNativeIterator;
+import io.glutenproject.vectorized.GeneralInIterator;
 import io.glutenproject.row.RowIterator;
 import io.glutenproject.row.SparkRowInfo;
 import io.glutenproject.vectorized.ExpressionEvaluator;
@@ -37,7 +37,7 @@ public void testInitNative() throws Exception {
 
     Plan substraitPlan = Plan.parseFrom(dis);
 
-    ArrayList<AbstractColumnarNativeIterator> inBatchIters = new ArrayList<AbstractColumnarNativeIterator>();
+    ArrayList<GeneralInIterator> inBatchIters = new ArrayList<GeneralInIterator>();
     RowIterator resIter = transKernel.createKernelWithRowIterator(substraitPlan.toByteArray(),
         inBatchIters);
     while (resIter.hasNext()) {
```
ShuffleDecompressionJniWrapper.java:

```diff
@@ -34,13 +34,15 @@ public ShuffleDecompressionJniWrapper() throws IOException {
    */
   public native long make(byte[] schemaBuf) throws RuntimeException;
 
-  public native byte[] decompress(
+  public native boolean decompress(
       long schemaHolderId,
       String compressionCodec,
       int numRows,
       long[] bufAddrs,
       long[] bufSizes,
-      long[] bufMask)
+      long[] bufMask,
+      long cSchema,
+      long cArray)
       throws RuntimeException;
 
   /**
```
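decompress() previously handed back the decompressed batch as serialized bytes; it now writes the batch into two caller-allocated C ABI structs and returns whether a batch was produced. A hedged sketch of what a caller might look like; the wrapper's package placement and every parameter except the allocator are assumed from the surrounding shuffle reader, which this page does not show:

```java
// Assumed to live in the same package as ShuffleDecompressionJniWrapper.
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;

public final class DecompressCallSketch {
  private DecompressCallSketch() {}

  public static VectorSchemaRoot decompressOne(
      ShuffleDecompressionJniWrapper jniWrapper,
      BufferAllocator allocator,
      long schemaHolderId,
      String compressionCodec,
      int numRows,
      long[] bufAddrs,
      long[] bufSizes,
      long[] bufMask) {
    // Java owns the struct storage, so try-with-resources reclaims it.
    try (ArrowSchema cSchema = ArrowSchema.allocateNew(allocator);
         ArrowArray cArray = ArrowArray.allocateNew(allocator)) {
      boolean produced = jniWrapper.decompress(
          schemaHolderId, compressionCodec, numRows, bufAddrs, bufSizes,
          bufMask, cSchema.memoryAddress(), cArray.memoryAddress());
      if (!produced) {
        return null;
      }
      // Vectors are rebuilt directly over the decompressed native buffers;
      // the old serialized byte[] round trip is gone.
      return Data.importVectorSchemaRoot(allocator, cArray, cSchema, null);
    }
  }
}
```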
ColumnarNativeIterator.java, renamed to VeloxInIterator.java (Velox backend):

```diff
@@ -15,33 +15,26 @@
  * limitations under the License.
  */
 
-package io.glutenproject.execution;
+package io.glutenproject.vectorized;
 
-import io.glutenproject.expression.ArrowConverterUtils;
-import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer;
-import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import io.glutenproject.utils.ArrowAbiUtil;
+import org.apache.arrow.c.ArrowArray;
+import org.apache.arrow.c.ArrowSchema;
 import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
 import java.util.Iterator;
 
-public class ColumnarNativeIterator extends AbstractColumnarNativeIterator {
+public class VeloxInIterator extends GeneralInIterator {
 
-  public ColumnarNativeIterator(Iterator<ColumnarBatch> delegated) {
+  public VeloxInIterator(Iterator<ColumnarBatch> delegated) {
     super(delegated);
   }
 
   @Override
-  public byte[] next() {
-    ColumnarBatch dep_cb = nextBatch;
-    if (dep_cb.numRows() > 0) {
-      ArrowRecordBatch dep_rb = ArrowConverterUtils.createArrowRecordBatch(dep_cb);
-      return serialize(dep_rb);
-    } else {
-      throw new IllegalStateException();
-    }
-  }
-
-  private byte[] serialize(ArrowRecordBatch batch) {
-    return UnsafeRecordBatchSerializer.serializeUnsafe(batch);
+  public void next(long cSchemaAddress, long cArrayAddress) {
+    final ColumnarBatch batch = nextColumnarBatch();
+    final ArrowSchema cSchema = ArrowSchema.wrap(cSchemaAddress);
+    final ArrowArray cArray = ArrowArray.wrap(cArrayAddress);
+    ArrowAbiUtil.exportFromSparkColumnarBatch(SparkMemoryUtils.contextAllocator(), batch, cSchema, cArray);
   }
 }
```
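Note the ownership split in next() above: ArrowSchema.wrap() and ArrowArray.wrap() only view struct storage that the native caller allocated and passed in as addresses, so the Java side allocates nothing and frees nothing. ArrowAbiUtil is added elsewhere in this commit; assuming it bottoms out in a stock org.apache.arrow.c.Data export, the Java-to-native half reduces to roughly this sketch:

```java
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;

public final class ExportSketch {
  private ExportSketch() {}

  // Export a batch (already materialized as a VectorSchemaRoot) into two
  // ABI structs owned by the native caller. Only addresses cross JNI.
  public static void exportTo(BufferAllocator allocator, VectorSchemaRoot root,
                              long cSchemaAddress, long cArrayAddress) {
    // wrap() does not allocate -- the native side owns the struct storage
    // and stays responsible for releasing it.
    ArrowSchema cSchema = ArrowSchema.wrap(cSchemaAddress);
    ArrowArray cArray = ArrowArray.wrap(cArrayAddress);
    Data.exportVectorSchemaRoot(allocator, root, null, cArray, cSchema);
  }
}
```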
BatchIterator.java, renamed to VeloxOutIterator.java (Velox backend):

```diff
@@ -17,59 +17,51 @@
 
 package io.glutenproject.vectorized;
 
-import org.apache.arrow.dataset.jni.UnsafeRecordBatchSerializer;
+import io.glutenproject.utils.ArrowAbiUtil;
+import org.apache.arrow.c.ArrowArray;
+import org.apache.arrow.c.ArrowSchema;
 import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
 import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils;
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.vectorized.ColumnVector;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
 import java.io.IOException;
 import java.io.Serializable;
 
-public class BatchIterator extends AbstractBatchIterator {
-  private native boolean nativeHasNext(long nativeHandler);
-  private native byte[] nativeNext(long nativeHandler);
-  private native long nativeCHNext(long nativeHandler);
-  private native void nativeClose(long nativeHandler);
-  private native MetricsObject nativeFetchMetrics(long nativeHandler);
+public class VeloxOutIterator extends GeneralOutIterator {
 
-  public BatchIterator() throws IOException {}
+  private native boolean nativeHasNext(long nativeHandle);
+  private native boolean nativeNext(long nativeHandle, long cSchema, long cArray);
+  private native long nativeCHNext(long nativeHandle);
+  private native void nativeClose(long nativeHandle);
+  private native MetricsObject nativeFetchMetrics(long nativeHandle);
 
-  public BatchIterator(long instance_id) throws IOException {
+  public VeloxOutIterator(long instance_id) throws IOException {
     super(instance_id);
   }
 
   @Override
   public boolean hasNextInternal() throws IOException {
-    return nativeHasNext(nativeHandler);
+    return nativeHasNext(handle);
   }
 
   @Override
-  public ArrowRecordBatch nextInternal() throws IOException {
-    BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
-    if (nativeHandler == 0) {
-      return null;
-    }
-    byte[] serializedRecordBatch = nativeNext(nativeHandler);
-    if (serializedRecordBatch == null) {
-      return null;
-    }
-    return UnsafeRecordBatchSerializer.deserializeUnsafe(allocator,
-      serializedRecordBatch);
+  public ColumnarBatch nextInternal() throws IOException {
+    final BufferAllocator allocator = SparkMemoryUtils.contextAllocator();
+    try (final ArrowArray cArray = ArrowArray.allocateNew(allocator);
+        final ArrowSchema cSchema = ArrowSchema.allocateNew(allocator)) {
+      if (!nativeNext(handle, cSchema.memoryAddress(), cArray.memoryAddress())) {
+        return null; // stream ended
+      }
+      return ArrowAbiUtil.importToSparkColumnarBatch(allocator, cSchema, cArray);
+    }
   }
 
   @Override
   public MetricsObject getMetricsInternal() throws IOException, ClassNotFoundException {
-    return nativeFetchMetrics(nativeHandler);
+    return nativeFetchMetrics(handle);
   }
 
   @Override
   public void closeInternal() {
-    if (!closed) {
-      nativeClose(nativeHandler);
-      closed = true;
-    }
+    nativeClose(handle);
   }
 }
```
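The output direction inverts the ownership: Java allocates the two structs (so try-with-resources above can reclaim them), the native producer fills in schema and buffers, and a false return replaces the old null-byte[] end-of-stream signal. The same protocol at the plain-Arrow level, with the nativeNext JNI call stood in by an interface since the native side is not shown on this page:

```java
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;

public final class ImportSketch {
  private ImportSketch() {}

  // Stand-in for nativeNext(handle, cSchema, cArray): fills the structs at
  // the given addresses and reports whether a batch was produced.
  public interface NativeProducer {
    boolean fill(long cSchemaAddress, long cArrayAddress);
  }

  public static VectorSchemaRoot nextBatch(BufferAllocator allocator,
                                           NativeProducer producer) {
    try (ArrowArray cArray = ArrowArray.allocateNew(allocator);
         ArrowSchema cSchema = ArrowSchema.allocateNew(allocator)) {
      if (!producer.fill(cSchema.memoryAddress(), cArray.memoryAddress())) {
        return null; // stream ended
      }
      // Takes over the buffers published by the producer; closing the two
      // structs afterwards frees only the struct storage itself.
      return Data.importVectorSchemaRoot(allocator, cArray, cSchema, null);
    }
  }
}
```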
VeloxIteratorApi.scala:

```diff
@@ -203,13 +203,13 @@ class VeloxIteratorApi extends IIteratorApi with Logging {
       context: TaskContext,
       jarList: Seq[String]): Iterator[ColumnarBatch] = {
     import org.apache.spark.sql.util.OASPackageBridge._
-    var inputSchema: Schema = null
-    var outputSchema: Schema = null
-    var resIter: AbstractBatchIterator = null
+    var inputSchema : Schema = null
+    var outputSchema : Schema = null
+    var resIter : GeneralOutIterator = null
     if (loadNative) {
       // TODO: 'jarList' is kept for codegen
       val transKernel = new ExpressionEvaluator(jarList.asJava)
-      val inBatchIters = new java.util.ArrayList[AbstractColumnarNativeIterator]()
+      val inBatchIters = new java.util.ArrayList[GeneralInIterator]()
       outputSchema = ArrowConverterUtils.toArrowSchema(outputAttributes)
       resIter =
         transKernel.createKernelWithBatchIterator(inputPartition.substraitPlan, inBatchIters)
@@ -230,17 +230,7 @@ class VeloxIteratorApi extends IIteratorApi with Logging {
       if (!hasNext) {
         throw new java.util.NoSuchElementException("End of stream")
       }
-      val rb = resIter.next().asInstanceOf[ArrowRecordBatch]
-      if (rb == null) {
-        val resultStructType = ArrowUtils.fromArrowSchema(outputSchema)
-        val resultColumnVectors =
-          ArrowWritableColumnVector.allocateColumns(0, resultStructType).toArray
-        return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
-      }
-      val outputNumRows = rb.getLength
-      val output = ArrowConverterUtils.fromArrowRecordBatch(outputSchema, rb)
-      ArrowConverterUtils.releaseArrowRecordBatch(rb)
-      val cb = new ColumnarBatch(output.map(v => v.asInstanceOf[ColumnVector]), outputNumRows)
+      val cb = resIter.next()
       val bytes: Long = cb match {
         case batch: ColumnarBatch =>
           (0 until batch.numCols()).map { i =>
@@ -282,7 +272,7 @@ class VeloxIteratorApi extends IIteratorApi with Logging {
       pipelineTime: SQLMetric,
       buildRelationBatchHolder: Seq[ColumnarBatch],
       dependentKernels: Seq[ExpressionEvaluator],
-      dependentKernelIterators: Seq[AbstractBatchIterator]): Iterator[ColumnarBatch] = {
+      dependentKernelIterators: Seq[GeneralOutIterator]): Iterator[ColumnarBatch] = {
 
     ExecutorManager.tryTaskSet(numaBindingInfo)
 
@@ -300,8 +290,8 @@ class VeloxIteratorApi extends IIteratorApi with Logging {
     val beforeBuild = System.nanoTime()
     val transKernel = new ExpressionEvaluator(jarList.asJava)
     val columnarNativeIterator =
-      new util.ArrayList[AbstractColumnarNativeIterator](inputIterators.map { iter =>
-        new ColumnarNativeIterator(iter.asJava)
+      new util.ArrayList[GeneralInIterator](inputIterators.map { iter =>
+        new VeloxInIterator(iter.asJava)
       }.asJava)
     val nativeResultIterator =
       transKernel.createKernelWithBatchIterator(rootNode, columnarNativeIterator)
@@ -321,18 +311,8 @@ class VeloxIteratorApi extends IIteratorApi with Logging {
 
       override def next(): ColumnarBatch = {
         val beforeEval = System.nanoTime()
-        val recordBatch = nativeResultIterator.next.asInstanceOf[ArrowRecordBatch]
-        if (recordBatch == null) {
-          evalElapse += System.nanoTime() - beforeEval
-          val resultColumnVectors =
-            ArrowWritableColumnVector.allocateColumns(0, schema)
-          return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
-        }
-        val recordBatchSchema = ArrowConverterUtils.toArrowSchema(outputAttributes)
-        val columns = ArrowConverterUtils.fromArrowRecordBatch(recordBatchSchema, recordBatch)
-        ArrowConverterUtils.releaseArrowRecordBatch(recordBatch)
         evalElapse += System.nanoTime() - beforeEval
-        new ColumnarBatch(columns.map(v => v.asInstanceOf[ColumnVector]), recordBatch.getLength)
+        val cb = nativeResultIterator.next
+        cb
       }
     }
@@ -351,8 +331,8 @@ class VeloxIteratorApi extends IIteratorApi with Logging {
    * @return
    */
   override def genColumnarNativeIterator(
-      delegated: Iterator[ColumnarBatch]): ColumnarNativeIterator = {
-    new ColumnarNativeIterator(delegated.asJava)
+      delegated: Iterator[ColumnarBatch]): VeloxInIterator = {
+    new VeloxInIterator(delegated.asJava)
   }
 
   /**
@@ -362,13 +342,13 @@ class VeloxIteratorApi extends IIteratorApi with Logging {
    */
   override def genBatchIterator(
       wsPlan: Array[Byte],
-      iterList: Seq[AbstractColumnarNativeIterator],
-      jniWrapper: ExpressionEvaluatorJniWrapper): AbstractBatchIterator = {
+      iterList: Seq[GeneralInIterator],
+      jniWrapper: ExpressionEvaluatorJniWrapper): GeneralOutIterator = {
     val memoryPool = SparkMemoryUtils.contextMemoryPool()
     val poolId = memoryPool.getNativeInstanceId
     val batchIteratorInstance =
       jniWrapper.nativeCreateKernelWithIterator(poolId, wsPlan, iterList.toArray)
-    new BatchIterator(batchIteratorInstance)
+    new VeloxOutIterator(batchIteratorInstance)
   }
 
   /**
```
(Diffs for the remaining changed files are not shown on this page.)