Support null safe host column vector (#4052)
* Support null safe host column vector

Signed-off-by: Firestarman <firestarmanllc@gmail.com>
firestarman authored Nov 16, 2021
1 parent 1452166 commit 701e6e2
Showing 9 changed files with 290 additions and 46 deletions.
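
The new classes in this commit wrap host column vectors so that each "getXXX" call checks for nulls and returns the type's default value when the slot is null (see the javadoc in the diff below). A simplified, standalone sketch of that wrapping pattern, with hypothetical types rather than the real vector classes:

// Hypothetical sketch of the null-safe wrapper pattern used by the new
// RapidsNullSafeHostColumnVector* classes: every primitive getter checks isNullAt
// first and falls back to the type's default value, so the wrapped vector never
// sees a read of a null slot.
trait IntColumn {
  def isNullAt(row: Int): Boolean
  def getInt(row: Int): Int
}

class NullSafeIntColumn(underlying: IntColumn) extends IntColumn {
  override def isNullAt(row: Int): Boolean = underlying.isNullAt(row)
  // A null slot yields 0 (the default for ints) instead of delegating the read.
  override def getInt(row: Int): Int =
    if (isNullAt(row)) 0 else underlying.getInt(row)
}
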
@@ -49,6 +49,8 @@ case class GpuRowBasedScalaUDF(
}
}
}

override val checkNull: Boolean = false
}

object GpuScalaUDFMeta {
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, ScalaUDF}
import org.apache.spark.sql.rapids.{GpuRowBasedScalaUDFBase, ScalaUDFMetaBase}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}

/** Run a row-based UDF in a GPU operation */
case class GpuRowBasedScalaUDF(
@@ -74,6 +74,19 @@ case class GpuRowBasedScalaUDF(
(CatalystTypeConverters.createToScalaConverter(dataType), false)
}
}

/**
* Null checks are needed when the input contains array types with nullable elements.
* This is for `https://github.com/NVIDIA/spark-rapids/issues/3942`.
*/
override val checkNull: Boolean = children.exists(child => hasArrayWithNulls(child.dataType))

private def hasArrayWithNulls(dt: DataType): Boolean = dt match {
case ArrayType(et, hasNull) => hasNull || hasArrayWithNulls(et)
case MapType(kt, vt, _) => hasArrayWithNulls(kt) || hasArrayWithNulls(vt)
case StructType(fields) => fields.exists(f => hasArrayWithNulls(f.dataType))
case _ => false
}
}

object GpuScalaUDFMeta {
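
For illustration, the recursive hasArrayWithNulls check added above can be exercised standalone against a nested schema. A sketch (the function body mirrors the diff; the example schemas are hypothetical):

import org.apache.spark.sql.types._

def hasArrayWithNulls(dt: DataType): Boolean = dt match {
  case ArrayType(et, hasNull) => hasNull || hasArrayWithNulls(et)
  case MapType(kt, vt, _)     => hasArrayWithNulls(kt) || hasArrayWithNulls(vt)
  case StructType(fields)     => fields.exists(f => hasArrayWithNulls(f.dataType))
  case _                      => false
}

// A map whose value side is an array with nullable elements triggers the null checks ...
assert(hasArrayWithNulls(MapType(StringType, ArrayType(IntegerType, containsNull = true))))
// ... while an array of non-nullable ints does not.
assert(!hasArrayWithNulls(ArrayType(IntegerType, containsNull = false)))
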
@@ -1117,6 +1117,10 @@ public final RapidsHostColumnVector copyToHost() {
return new RapidsHostColumnVector(type, cudfCv.copyToHost());
}

public final RapidsNullSafeHostColumnVector copyToNullSafeHost() {
return new RapidsNullSafeHostColumnVector(copyToHost());
}

@Override
public final String toString() {
return getBase().toString();
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids;

/**
* A wrapper around a RapidsHostColumnVector that checks for nulls in each "getXXX" call
* and returns the type's default value when a null is read.
* Performance may not be good enough, so use it only when there is no other way.
*/
public final class RapidsNullSafeHostColumnVector extends RapidsNullSafeHostColumnVectorCore {
private final RapidsHostColumnVector rapidsHcv;

public RapidsNullSafeHostColumnVector(RapidsHostColumnVector rapidsHcv) {
super(rapidsHcv);
this.rapidsHcv = rapidsHcv;
}

public final RapidsNullSafeHostColumnVector incRefCount() {
// Just pass through the reference counting
rapidsHcv.incRefCount();
return this;
}

public final ai.rapids.cudf.HostColumnVector getBase() {
return rapidsHcv.getBase();
}
}
@@ -0,0 +1,198 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids;

import ai.rapids.cudf.HostColumnVectorCore;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

/**
* A wrapper around a RapidsHostColumnVectorCore that checks for nulls in each "getXXX" call
* and returns the type's default value when a null is read.
* Performance may not be good enough, so use it only when there is no other way.
*/
public class RapidsNullSafeHostColumnVectorCore extends ColumnVector {
private final RapidsHostColumnVectorCore rapidsHcvc;
private final RapidsNullSafeHostColumnVectorCore[] cachedChildren;

public RapidsNullSafeHostColumnVectorCore(RapidsHostColumnVectorCore hcvc) {
super(hcvc.dataType());
this.rapidsHcvc = hcvc;
if (type instanceof MapType) {
// Map is a special case where we cache 2 children because it really ends up being
// a list of structs in cuDF, so the list only has one child; the key and value
// columns are stored inside that struct.
cachedChildren = new RapidsNullSafeHostColumnVectorCore[2];
} else {
cachedChildren = new RapidsNullSafeHostColumnVectorCore[hcvc.getBase().getNumChildren()];
}
}

public ai.rapids.cudf.HostColumnVectorCore getBase() {
return rapidsHcvc.getBase();
}

@Override
public void close() {
for (int i = 0; i < cachedChildren.length; i++) {
RapidsNullSafeHostColumnVectorCore cv = cachedChildren[i];
if (cv != null) {
cv.close();
// avoid double closing this
cachedChildren[i] = null;
}
}
rapidsHcvc.close();
}

@Override
public boolean hasNull() {
return rapidsHcvc.hasNull();
}

@Override
public int numNulls() {
return rapidsHcvc.numNulls();
}

@Override
public boolean isNullAt(int rowId) {
return rapidsHcvc.isNullAt(rowId);
}

@Override
public boolean getBoolean(int rowId) {
return isNullAt(rowId) ? false : rapidsHcvc.getBoolean(rowId);
}

@Override
public byte getByte(int rowId) {
return isNullAt(rowId) ? 0 : rapidsHcvc.getByte(rowId);
}

@Override
public short getShort(int rowId) {
return isNullAt(rowId) ? 0 : rapidsHcvc.getShort(rowId);
}

@Override
public int getInt(int rowId) {
return isNullAt(rowId) ? 0 : rapidsHcvc.getInt(rowId);
}

@Override
public long getLong(int rowId) {
return isNullAt(rowId) ? 0L : rapidsHcvc.getLong(rowId);
}

@Override
public float getFloat(int rowId) {
return isNullAt(rowId) ? 0.0F : rapidsHcvc.getFloat(rowId);
}

@Override
public double getDouble(int rowId) {
return isNullAt(rowId) ? 0.0 : rapidsHcvc.getDouble(rowId);
}

@Override
public UTF8String getUTF8String(int rowId) {
return isNullAt(rowId) ? null : rapidsHcvc.getUTF8String(rowId);
}

@Override
public Decimal getDecimal(int rowId, int precision, int scale) {
return isNullAt(rowId) ? null : rapidsHcvc.getDecimal(rowId, precision, scale);
}

/** We also need to wrap the children for nested types: array, map, struct, etc. */

@Override
public ColumnarArray getArray(int rowId) {
if (isNullAt(rowId)) return null;
// Not null
if (cachedChildren[0] == null) {
// Cache the child data
ArrayType at = (ArrayType) type;
HostColumnVectorCore data = getBase().getChildColumnView(0);
cachedChildren[0] = new RapidsNullSafeHostColumnVectorCore(
new RapidsHostColumnVectorCore(at.elementType(), data));
}
RapidsNullSafeHostColumnVectorCore data = cachedChildren[0];
int startOffset = (int) getBase().getStartListOffset(rowId);
int endOffset = (int) getBase().getEndListOffset(rowId);
return new ColumnarArray(data, startOffset, endOffset - startOffset);
}

@Override
public ColumnarMap getMap(int ordinal) {
if (isNullAt(ordinal)) return null;
// Not null
if (cachedChildren[0] == null) {
// Cache the key/value, map is stored as list of struct (two children)
MapType mt = (MapType) type;
HostColumnVectorCore structHcv = getBase().getChildColumnView(0);
// keys and values
HostColumnVectorCore keyHcvCore = structHcv.getChildColumnView(0);
HostColumnVectorCore valueHcvCore = structHcv.getChildColumnView(1);

cachedChildren[0] = new RapidsNullSafeHostColumnVectorCore(
new RapidsHostColumnVectorCore(mt.keyType(), keyHcvCore));
cachedChildren[1] = new RapidsNullSafeHostColumnVectorCore(
new RapidsHostColumnVectorCore(mt.valueType(), valueHcvCore));
}
RapidsNullSafeHostColumnVectorCore keys = cachedChildren[0];
RapidsNullSafeHostColumnVectorCore values = cachedChildren[1];

int startOffset = (int) getBase().getStartListOffset(ordinal);
int endOffset = (int) getBase().getEndListOffset(ordinal);
return new ColumnarMap(keys, values, startOffset, endOffset - startOffset);
}

@Override
public ColumnVector getChild(int ordinal) {
if (cachedChildren[ordinal] == null) {
StructType st = (StructType) type;
StructField[] fields = st.fields();
for (int i = 0; i < fields.length; i++) {
HostColumnVectorCore tmp = getBase().getChildColumnView(i);
cachedChildren[i] = new RapidsNullSafeHostColumnVectorCore(
new RapidsHostColumnVectorCore(fields[i].dataType(), tmp));
}
}
return cachedChildren[ordinal];
}

@Override
public byte[] getBinary(int rowId) {
if (isNullAt(rowId)) return null;
// Not null
if (cachedChildren[0] == null) {
// Cache the child data
HostColumnVectorCore data = getBase().getChildColumnView(0);
cachedChildren[0] = new RapidsNullSafeHostColumnVectorCore(
new RapidsHostColumnVectorCore(DataTypes.ByteType, data));
}
RapidsNullSafeHostColumnVectorCore data = cachedChildren[0];
int startOffset = (int) getBase().getStartListOffset(rowId);
int endOffset = (int) getBase().getEndListOffset(rowId);
return new ColumnarArray(data, startOffset, endOffset - startOffset).toByteArray();
}
}
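
The getArray, getMap, and getBinary methods above all use the same list-offset arithmetic: row i's child values occupy the half-open range [getStartListOffset(i), getEndListOffset(i)) of the flattened child column. A standalone sketch of that arithmetic with hypothetical data (the real code wraps the slice in a ColumnarArray or ColumnarMap rather than a Scala array):

// Hypothetical offsets for a 3-row list column and its flattened child values.
val offsets     = Array(0, 3, 3, 5)          // offsets(i) .. offsets(i + 1) bounds row i
val childValues = Array(1, 2, 3, 4, 5)

def sliceForRow(row: Int): Array[Int] =
  childValues.slice(offsets(row), offsets(row + 1))

assert(sliceForRow(0).sameElements(Array(1, 2, 3)))  // row 0 has 3 elements
assert(sliceForRow(1).isEmpty)                       // row 1 is an empty list
assert(sliceForRow(2).sameElements(Array(4, 5)))     // row 2 has 2 elements
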
@@ -179,11 +179,18 @@ class ColumnarToRowIterator(batches: Iterator[ColumnarBatch],
numInputBatches: GpuMetric,
numOutputRows: GpuMetric,
opTime: GpuMetric,
streamTime: GpuMetric) extends Iterator[InternalRow] with Arm {
streamTime: GpuMetric,
nullSafe: Boolean = false) extends Iterator[InternalRow] with Arm {
// GPU batches read in must be closed by the receiver (us)
@transient private var cb: ColumnarBatch = null
private var it: java.util.Iterator[InternalRow] = null

private[this] lazy val toHost = if (nullSafe) {
(gpuCV: GpuColumnVector) => gpuCV.copyToNullSafeHost()
} else {
(gpuCV: GpuColumnVector) => gpuCV.copyToHost()
}

Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => closeCurrentBatch()))

private def closeCurrentBatch(): Unit = {
@@ -201,7 +208,7 @@ class ColumnarToRowIterator(batches: Iterator[ColumnarBatch],
devCb.foreach { devCb =>
withResource(new NvtxWithMetrics("ColumnarToRow: batch", NvtxColor.RED, opTime)) { _ =>
try {
cb = new ColumnarBatch(GpuColumnVector.extractColumns(devCb).map(_.copyToHost()),
cb = new ColumnarBatch(GpuColumnVector.extractColumns(devCb).map(toHost),
devCb.numRows())
it = cb.rowIterator()
// In order to match the numOutputRows metric in the generated code we update
@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids

import ai.rapids.cudf.{HostColumnVector, NvtxColor, NvtxRange}
import ai.rapids.cudf.{HostColumnVector, HostColumnVectorCore, NvtxColor, NvtxRange}
import com.nvidia.spark.RapidsUDF
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.shims.v2.ShimExpression
@@ -75,6 +75,10 @@ object GpuUserDefinedFunction {
// UDFs can support all types except UDT which does not have a clear columnar representation.
val udfTypeSig: TypeSig = (TypeSig.commonCudfTypes + TypeSig.DECIMAL_64 + TypeSig.NULL +
TypeSig.BINARY + TypeSig.CALENDAR + TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested()

/** (This will be initialized once per process) */
lazy val hostColumnAssertionEnabled: Boolean =
classOf[HostColumnVectorCore].desiredAssertionStatus()
}

/**
@@ -89,6 +93,9 @@ trait GpuRowBasedUserDefinedFunction extends GpuExpression
/** True if the UDF is deterministic */
val udfDeterministic: Boolean

/** True if the UDF needs null checks when converting input columns to rows */
val checkNull: Boolean

/** The row based function of the UDF. */
protected def evaluateRow(childrenRow: InternalRow): Any

@@ -110,12 +117,19 @@ trait GpuRowBasedUserDefinedFunction extends GpuExpression
val retType = GpuColumnVector.convertFrom(dataType, nullable)
val retRow = new GenericInternalRow(size = 1)
closeOnExcept(new HostColumnVector.ColumnBuilder(retType, batch.numRows)) { builder =>
/**
* This `nullSafe` is for https://github.com/NVIDIA/spark-rapids/issues/3942.
* More details can be found at
* https://github.com/NVIDIA/spark-rapids/pull/3997#issuecomment-957650846
*/
val nullSafe = checkNull && GpuUserDefinedFunction.hostColumnAssertionEnabled
new ColumnarToRowIterator(
Iterator.single(new ColumnarBatch(argCols.toArray, batch.numRows())),
NoopMetric,
NoopMetric,
NoopMetric,
NoopMetric).foreach { row =>
NoopMetric,
nullSafe).foreach { row =>
retRow.update(0, evaluateRow(row))
retConverter.append(retRow, 0, builder)
}
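
The hostColumnAssertionEnabled flag above gates the null-safe path on whether JVM assertions are enabled for cuDF's HostColumnVectorCore, which suggests the null-read failures behind issue #3942 only surface when assertions are on. A small sketch of the mechanism, using a JDK class so it runs without cuDF on the classpath:

// Class.desiredAssertionStatus() reports whether assertions would be enabled for a
// class when it is initialized (e.g. via -ea or a package-specific -ea: flag).
val assertionsEnabled = classOf[java.util.ArrayList[_]].desiredAssertionStatus()
// In the diff, the same call on ai.rapids.cudf.HostColumnVectorCore decides whether
// the more expensive null-safe host copy is actually needed.
println(s"assertions enabled: $assertionsEnabled")
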
@@ -534,21 +534,6 @@ abstract class ScalaUDFMetaBase(
willNotWorkOnGpu(s"neither $udfName implemented by $udfClass provides " +
s"a GPU implementation, nor the conf `${RapidsConf.ENABLE_CPU_BASED_UDF.key}` " +
s"is enabled")
} else if (opRapidsFunc.isEmpty && conf.isCpuBasedUDFEnabled
&& VersionUtils.isSpark311OrLater) {
// Fall back to CPU if the children contain array type with nulls,
// because of the issue as below.
// https://github.com/NVIDIA/spark-rapids/issues/3942
val hasArrayWithNulls = expr.children.exists { e =>
e.dataType match {
case ArrayType(_, containsNull) => containsNull
case _ => false
}
}
if (hasArrayWithNulls) {
willNotWorkOnGpu(s"support for array with nulls in an UDF input is disabled " +
s"temporarily for Spark 3.1.1+. UDF will run into an error for this case.")
}
}
}
