Skip to content

Commit

Permalink
Avoid replacing a hash join if build side is unsupported by the join …
Browse files Browse the repository at this point in the history
…type (#4463)

Signed-off-by: Jason Lowe <jlowe@nvidia.com>
  • Loading branch information
jlowe authored Jan 6, 2022
1 parent c4194bb commit f8fa0b5
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ class GpuBroadcastHashJoinMeta(
override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition

override def tagPlanForGpu(): Unit = {
GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition)
GpuHashJoin.tagJoin(this, join.joinType, buildSide, join.leftKeys, join.rightKeys,
join.condition)
val Seq(leftChild, rightChild) = childPlans
val buildSideMeta = buildSide match {
case GpuBuildLeft => leftChild
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,16 @@ class GpuShuffledHashJoinMeta(
join.rightKeys.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
val condition: Option[BaseExprMeta[_]] =
join.condition.map(GpuOverrides.wrapExpr(_, conf, Some(this)))
val buildSide: GpuBuildSide = GpuJoinUtils.getGpuBuildSide(join.buildSide)

override val childExprs: Seq[BaseExprMeta[_]] = leftKeys ++ rightKeys ++ condition

override val namedChildExprs: Map[String, Seq[BaseExprMeta[_]]] =
JoinTypeChecks.equiJoinMeta(leftKeys, rightKeys, condition)

override def tagPlanForGpu(): Unit = {
GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition)
GpuHashJoin.tagJoin(this, join.joinType, buildSide, join.leftKeys, join.rightKeys,
join.condition)
}

override def convertToGpu(): GpuExec = {
Expand All @@ -56,7 +58,7 @@ class GpuShuffledHashJoinMeta(
leftKeys.map(_.convertToGpu()),
rightKeys.map(_.convertToGpu()),
join.joinType,
GpuJoinUtils.getGpuBuildSide(join.buildSide),
buildSide,
None,
left,
right,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class GpuSortMergeJoinMeta(

override def tagPlanForGpu(): Unit = {
// Use conditions from Hash Join
GpuHashJoin.tagJoin(this, join.joinType, join.leftKeys, join.rightKeys, join.condition)
GpuHashJoin.tagJoin(this, join.joinType, buildSide, join.leftKeys, join.rightKeys,
join.condition)

if (!conf.enableReplaceSortMergeJoin) {
willNotWorkOnGpu(s"Not replacing sort merge join with hash join, " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ object GpuHashJoin extends Arm {
def tagJoin(
meta: RapidsMeta[_, _, _],
joinType: JoinType,
buildSide: GpuBuildSide,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
condition: Option[Expression]): Unit = {
Expand All @@ -122,6 +123,14 @@ object GpuHashJoin extends Arm {
case _ =>
meta.willNotWorkOnGpu(s"$joinType currently is not supported")
}

buildSide match {
case GpuBuildLeft if !canBuildLeft(joinType) =>
meta.willNotWorkOnGpu(s"$joinType does not support left-side build")
case GpuBuildRight if !canBuildRight(joinType) =>
meta.willNotWorkOnGpu(s"$joinType does not support right-side build")
case _ =>
}
}

/** Determine if this type of join supports using the right side of the join as the build side. */
Expand Down

0 comments on commit f8fa0b5

Please sign in to comment.