Optimize out bounds checking for joins when the gather map has only valid entries #3799

Merged
merged 1 commit into from Oct 13, 2021
@@ -18,11 +18,12 @@ package com.nvidia.spark.rapids

 import scala.collection.mutable

-import ai.rapids.cudf.{GatherMap, NvtxColor}
+import ai.rapids.cudf.{GatherMap, NvtxColor, OutOfBoundsPolicy}

 import org.apache.spark.TaskContext
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.vectorized.ColumnarBatch

 /**
@@ -280,7 +281,8 @@ abstract class SplittableJoinIterator(
   protected def makeGatherer(
       maps: Array[GatherMap],
       leftData: LazySpillableColumnarBatch,
-      rightData: LazySpillableColumnarBatch): Option[JoinGatherer] = {
+      rightData: LazySpillableColumnarBatch,
+      joinType: JoinType): Option[JoinGatherer] = {
     assert(maps.length > 0 && maps.length <= 2)
     try {
       val leftMap = maps.head
@@ -298,11 +300,37 @@
       val lazyLeftMap = LazySpillableGatherMap(leftMap, spillCallback, "left_map")
       val gatherer = rightMap match {
         case None =>
+          // When there isn't a `rightMap`, this is either a LeftSemi or a LeftAnti join.
+          // In these cases the map and the table both refer to the left side, and every
+          // entry in the map is a match on the left table, so no bounds check is needed.
          rightData.close()
-          JoinGatherer(lazyLeftMap, leftData)
+          JoinGatherer(lazyLeftMap, leftData, OutOfBoundsPolicy.DONT_CHECK)
         case Some(right) =>
+          // Inner joins -- manifest the intersection of the left and right sides. The
+          // gather maps contain exactly the rows that must be manifested, and every
+          // index is within bounds, so we can skip the bounds check.
+          //
+          // Left outer -- manifests all rows of the left table. The left gather map
+          // must contain valid indices, so we skip the check for the left side. The
+          // right side has to be checked, since we need to produce nulls (for the
+          // right) for those rows on the left side that have no match on the right.
+          //
+          // Right outer -- the mirror of left outer (skip the right bounds check,
+          // keep the left one).
+          //
+          // Full outer -- can produce nulls for any left or right rows that have no
+          // match in the opposite table, so both gather maps must be checked.
+          val leftOutOfBoundsPolicy = joinType match {
+            case _: InnerLike | LeftOuter => OutOfBoundsPolicy.DONT_CHECK
+            case _ => OutOfBoundsPolicy.NULLIFY
+          }
+          val rightOutOfBoundsPolicy = joinType match {
+            case _: InnerLike | RightOuter => OutOfBoundsPolicy.DONT_CHECK
+            case _ => OutOfBoundsPolicy.NULLIFY
+          }
           val lazyRightMap = LazySpillableGatherMap(right, spillCallback, "right_map")
-          JoinGatherer(lazyLeftMap, leftData, lazyRightMap, rightData)
+          JoinGatherer(lazyLeftMap, leftData, lazyRightMap, rightData,
+            leftOutOfBoundsPolicy, rightOutOfBoundsPolicy)
       }
       if (gatherer.isDone) {
         // Nothing matched...
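The comment block above encodes the whole optimization: bounds checks are skipped only on a side whose gather map is guaranteed to hold valid indices. As a minimal standalone sketch (not part of the PR; the `BoundsPolicy` object and `forJoin` name are hypothetical), the join-type-to-policy mapping for the two-map case can be read as a pure function:

```scala
import ai.rapids.cudf.OutOfBoundsPolicy
import org.apache.spark.sql.catalyst.plans._

object BoundsPolicy {
  // Returns the (left, right) out-of-bounds policies for a join type.
  // DONT_CHECK skips GPU-side bounds validation; NULLIFY turns any
  // out-of-range gather index into a null output row.
  def forJoin(joinType: JoinType): (OutOfBoundsPolicy, OutOfBoundsPolicy) = {
    val left = joinType match {
      case _: InnerLike | LeftOuter => OutOfBoundsPolicy.DONT_CHECK
      case _ => OutOfBoundsPolicy.NULLIFY
    }
    val right = joinType match {
      case _: InnerLike | RightOuter => OutOfBoundsPolicy.DONT_CHECK
      case _ => OutOfBoundsPolicy.NULLIFY
    }
    (left, right)
  }
}
```

For example, `BoundsPolicy.forJoin(FullOuter)` yields `(NULLIFY, NULLIFY)` while an inner join yields `(DONT_CHECK, DONT_CHECK)`, matching the four cases described in the comment.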
@@ -16,7 +16,7 @@

 package com.nvidia.spark.rapids

-import ai.rapids.cudf.{ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, Scalar, Table}
+import ai.rapids.cudf.{ColumnVector, ColumnView, DeviceMemoryBuffer, DType, GatherMap, NvtxColor, NvtxRange, OrderByArg, OutOfBoundsPolicy, Scalar, Table}

 import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, DateType, DecimalType, IntegerType, LongType, MapType, NullType, NumericType, StringType, StructType, TimestampType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -136,15 +136,18 @@ trait JoinGatherer extends LazySpillable with Arm {

 object JoinGatherer extends Arm {
   def apply(gatherMap: LazySpillableGatherMap,
-      inputData: LazySpillableColumnarBatch): JoinGatherer =
-    new JoinGathererImpl(gatherMap, inputData)
+      inputData: LazySpillableColumnarBatch,
+      outOfBoundsPolicy: OutOfBoundsPolicy): JoinGatherer =
+    new JoinGathererImpl(gatherMap, inputData, outOfBoundsPolicy)

   def apply(leftMap: LazySpillableGatherMap,
       leftData: LazySpillableColumnarBatch,
       rightMap: LazySpillableGatherMap,
-      rightData: LazySpillableColumnarBatch): JoinGatherer = {
-    val left = JoinGatherer(leftMap, leftData)
-    val right = JoinGatherer(rightMap, rightData)
+      rightData: LazySpillableColumnarBatch,
+      outOfBoundsPolicyLeft: OutOfBoundsPolicy,
+      outOfBoundsPolicyRight: OutOfBoundsPolicy): JoinGatherer = {
+    val left = JoinGatherer(leftMap, leftData, outOfBoundsPolicyLeft)
+    val right = JoinGatherer(rightMap, rightData, outOfBoundsPolicyRight)
     MultiJoinGather(left, right)
   }
@@ -499,7 +502,8 @@ object JoinGathererImpl {
  */
 class JoinGathererImpl(
     private val gatherMap: LazySpillableGatherMap,
-    private val data: LazySpillableColumnarBatch) extends JoinGatherer {
+    private val data: LazySpillableColumnarBatch,
+    boundsCheckPolicy: OutOfBoundsPolicy) extends JoinGatherer {

   assert(data.numCols > 0, "data with no columns should have been filtered out already")
@@ -536,7 +540,7 @@ class JoinGathererImpl(
     val ret = withResource(gatherMap.toColumnView(start, n)) { gatherView =>
       val batch = data.getBatch
       val gatheredTable = withResource(GpuColumnVector.from(batch)) { table =>
-        table.gather(gatherView)
+        table.gather(gatherView, boundsCheckPolicy)
       }
       withResource(gatheredTable) { gt =>
         GpuColumnVector.from(gt, GpuColumnVector.extractTypes(batch))
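The only behavioral change in this class is the extra argument to `gather`. A rough self-contained sketch of what that argument controls at the cudf level (assumes a GPU plus the cudf Java bindings at a version that includes `OutOfBoundsPolicy`, as this PR requires; the demo itself is not part of the PR):

```scala
import ai.rapids.cudf.{ColumnVector, OutOfBoundsPolicy, Table}

object GatherPolicyDemo {
  def main(args: Array[String]): Unit = {
    val data = ColumnVector.fromInts(10, 20, 30) // a 3-row column
    val table = new Table(data)
    // Index 5 is out of bounds for a 3-row table.
    val gatherMap = ColumnVector.fromInts(0, 2, 5)
    try {
      // NULLIFY: the row for index 5 comes back null instead of reading
      // out-of-range memory. DONT_CHECK skips that validation entirely,
      // which is only safe when the map is valid by construction.
      val gathered = table.gather(gatherMap, OutOfBoundsPolicy.NULLIFY)
      try {
        println(s"gathered ${gathered.getRowCount} rows") // 3 rows, last null
      } finally {
        gathered.close()
      }
    } finally {
      gatherMap.close()
      table.close()
      data.close()
    }
  }
}
```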
@@ -16,7 +16,7 @@

 package org.apache.spark.sql.rapids.execution

-import ai.rapids.cudf.{ast, GatherMap, NvtxColor, Table}
+import ai.rapids.cudf.{ast, GatherMap, NvtxColor, OutOfBoundsPolicy, Table}
 import com.nvidia.spark.rapids._
 import com.nvidia.spark.rapids.shims.v2.ShimBinaryExecNode
@@ -170,16 +170,21 @@ class CrossJoinIterator(
     val leftMap = LazySpillableGatherMap.leftCross(leftBatch.numRows, rightBatch.numRows)
     val rightMap = LazySpillableGatherMap.rightCross(leftBatch.numRows, rightBatch.numRows)

+    // Cross joins never need bounds checking because the gather maps are
+    // generated with div and mod over the left and right row counts, so every
+    // index is in range by construction; we specify `DONT_CHECK` for all.
     val joinGatherer = (leftBatch.numCols, rightBatch.numCols) match {
       case (_, 0) =>
         rightBatch.close()
         rightMap.close()
-        JoinGatherer(leftMap, leftBatch)
+        JoinGatherer(leftMap, leftBatch, OutOfBoundsPolicy.DONT_CHECK)
       case (0, _) =>
         leftBatch.close()
         leftMap.close()
-        JoinGatherer(rightMap, rightBatch)
-      case (_, _) => JoinGatherer(leftMap, leftBatch, rightMap, rightBatch)
+        JoinGatherer(rightMap, rightBatch, OutOfBoundsPolicy.DONT_CHECK)
+      case (_, _) =>
+        JoinGatherer(leftMap, leftBatch, rightMap, rightBatch,
+          OutOfBoundsPolicy.DONT_CHECK, OutOfBoundsPolicy.DONT_CHECK)
     }
     if (joinGatherer.isDone) {
       joinGatherer.close()
@@ -252,7 +257,7 @@ class ConditionalNestedLoopJoinIterator(
         case GpuBuildRight => (streamTable, streamBatch, builtTable, builtSpillOnly)
       }
       val maps = computeGatherMaps(leftTable, rightTable, numJoinRows)
-      makeGatherer(maps, leftBatch, rightBatch)
+      makeGatherer(maps, leftBatch, rightBatch, joinType)
     }
   }
 }
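The `DONT_CHECK` choice for cross joins follows directly from how the maps are built. A hypothetical CPU-side sketch of the div/mod argument (the real `leftCross`/`rightCross` maps are produced on the GPU; this only illustrates why every index lands in range):

```scala
object CrossJoinMapSketch {
  // For a cross join, output row i pairs left row (i / rightRows) with
  // right row (i % rightRows), so indices are in bounds by construction.
  def crossMaps(leftRows: Int, rightRows: Int): (Array[Int], Array[Int]) = {
    val n = leftRows * rightRows
    val leftMap = Array.tabulate(n)(_ / rightRows)  // always < leftRows
    val rightMap = Array.tabulate(n)(_ % rightRows) // always < rightRows
    (leftMap, rightMap)
  }

  def main(args: Array[String]): Unit = {
    val (l, r) = crossMaps(leftRows = 2, rightRows = 3)
    assert(l.forall(i => 0 <= i && i < 2))
    assert(r.forall(i => 0 <= i && i < 3))
    println(l.mkString(",")) // 0,0,0,1,1,1
    println(r.mkString(",")) // 0,1,2,0,1,2
  }
}
```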
@@ -295,7 +295,7 @@ class HashJoinIterator(
         throw new NotImplementedError(s"Join Type ${joinType.getClass} is not currently" +
           s" supported")
       }
-      makeGatherer(maps, leftData, rightData)
+      makeGatherer(maps, leftData, rightData, joinType)
     }
   }

Expand Down