feat: Support Broadcast HashJoin (apache#211)
* feat: Support HashJoin

* Add comment

* Clean up test

* Fix join filter

* Fix clippy

* Use consistent function with sort merge join

* Add note about left semi and left anti joins

* feat: Support BroadcastHashJoin

* Move tests

* Remove unused import

* Add function to parse join parameters

* Remove duplicate code

* For review
viirya authored and wangyum committed Mar 28, 2024
1 parent df6571b commit 1540b56
Showing 8 changed files with 221 additions and 28 deletions.
51 changes: 51 additions & 0 deletions common/src/main/java/org/apache/comet/CometArrowStreamWriter.java
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet;

import java.io.IOException;
import java.nio.channels.WritableByteChannel;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;

/**
* A custom `ArrowStreamWriter` that allows writing batches from different roots to the same stream.
* Arrow's `ArrowStreamWriter` cannot change the root after initialization.
*/
public class CometArrowStreamWriter extends ArrowStreamWriter {
public CometArrowStreamWriter(
VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) {
super(root, provider, out);
}

public void writeMoreBatch(VectorSchemaRoot root) throws IOException {
VectorUnloader unloader =
new VectorUnloader(
root, /*includeNullCount*/ true, NoCompressionCodec.INSTANCE, /*alignBuffers*/ true);

try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
writeRecordBatch(batch);
}
}
}
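
The new writer is meant to be driven the way NativeUtil.serializeBatches (next file) drives it: the first root initializes the stream and writes the schema, and batches from later roots are appended with writeMoreBatch. A minimal usage sketch follows; the writeRoots helper and its wrapper object are hypothetical and only illustrate the call sequence:

import java.io.OutputStream
import java.nio.channels.Channels

import org.apache.arrow.vector.VectorSchemaRoot
import org.apache.arrow.vector.dictionary.DictionaryProvider

import org.apache.comet.CometArrowStreamWriter

// Hypothetical example, not part of this commit.
object CometArrowStreamWriterExample {
  def writeRoots(
      roots: Iterator[VectorSchemaRoot],
      provider: DictionaryProvider,
      out: OutputStream): Unit = {
    var writer: Option[CometArrowStreamWriter] = None
    roots.foreach { root =>
      writer match {
        case None =>
          // The first root initializes the stream: start() writes the schema,
          // writeBatch() writes the root's current batch.
          val w = new CometArrowStreamWriter(root, provider, Channels.newChannel(out))
          w.start()
          w.writeBatch()
          writer = Some(w)
        case Some(w) =>
          // Later batches may come from a different root; writeMoreBatch unloads
          // that root and appends its record batch to the same stream.
          w.writeMoreBatch(root)
      }
      root.clear()
    }
    writer.foreach(_.end())
  }
}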
18 changes: 9 additions & 9 deletions common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -20,6 +20,7 @@
package org.apache.comet.vector

import java.io.OutputStream
import java.nio.channels.Channels

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -28,10 +29,11 @@ import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictiona
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.dictionary.DictionaryProvider
import org.apache.arrow.vector.ipc.ArrowStreamWriter
import org.apache.spark.SparkException
import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.CometArrowStreamWriter

class NativeUtil {
private val allocator = new RootAllocator(Long.MaxValue)
private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider
@@ -46,29 +48,27 @@ class NativeUtil {
* the output stream
*/
def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = {
var schemaRoot: Option[VectorSchemaRoot] = None
var writer: Option[ArrowStreamWriter] = None
var writer: Option[CometArrowStreamWriter] = None
var rowCount = 0

batches.foreach { batch =>
val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch)
val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava))
val root = new VectorSchemaRoot(fieldVectors.asJava)
val provider = batchProviderOpt.getOrElse(dictionaryProvider)

if (writer.isEmpty) {
writer = Some(new ArrowStreamWriter(root, provider, out))
writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out)))
writer.get.start()
writer.get.writeBatch()
} else {
writer.get.writeMoreBatch(root)
}
writer.get.writeBatch()

root.clear()
schemaRoot = Some(root)

rowCount += batch.numRows()
}

writer.map(_.end())
schemaRoot.map(_.close())

rowCount
}
@@ -23,7 +23,7 @@ import java.nio.channels.ReadableByteChannel

import org.apache.spark.sql.vectorized.ColumnarBatch

import org.apache.comet.vector.StreamReader
import org.apache.comet.vector._

class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {

@@ -36,6 +36,13 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
return true
}

// Release the previous batch.
// If it is not released, the Arrow library will complain about a memory leak
// when the reader is closed.
if (currentBatch != null) {
currentBatch.close()
}

batch = nextBatch()
if (batch.isEmpty) {
return false
@@ -50,13 +57,6 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna

val nextBatch = batch.get

// Release the previous batch.
// If it is not released, when closing the reader, arrow library will complain about
// memory leak.
if (currentBatch != null) {
currentBatch.close()
}

currentBatch = nextBatch
batch = None
currentBatch
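
With this change, hasNext closes the previously returned batch before it fetches the next one, so a caller must be done with each batch before advancing the iterator. A minimal consumption sketch; consume is a hypothetical driver added only for illustration:

import java.nio.channels.ReadableByteChannel

import org.apache.spark.sql.comet.execution.shuffle.ArrowReaderIterator
import org.apache.spark.sql.vectorized.ColumnarBatch

// Hypothetical example, not part of this commit.
object ArrowReaderIteratorExample {
  def consume(channel: ReadableByteChannel)(process: ColumnarBatch => Unit): Unit = {
    val iter = new ArrowReaderIterator(channel)
    while (iter.hasNext) {
      val batch = iter.next()
      // The batch is released the next time hasNext fetches, so it must be fully
      // consumed (or copied) before the loop advances.
      process(batch)
    }
  }
}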
@@ -31,14 +31,14 @@ import org.apache.spark.sql.comet._
import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

@@ -356,6 +356,27 @@ class CometSparkSessionExtensions
op
}

case op: BroadcastHashJoinExec
if isCometOperatorEnabled(conf, "broadcast_hash_join") &&
op.children.forall(isCometNative(_)) =>
val newOp = transform1(op)
newOp match {
case Some(nativeOp) =>
CometBroadcastHashJoinExec(
nativeOp,
op,
op.leftKeys,
op.rightKeys,
op.joinType,
op.condition,
op.buildSide,
op.left,
op.right,
SerializedPlan(None))
case None =>
op
}

case op: SortMergeJoinExec
if isCometOperatorEnabled(conf, "sort_merge_join") &&
op.children.forall(isCometNative(_)) =>
@@ -411,6 +432,16 @@
u
}

// For AQE broadcast stage on a Comet broadcast exchange
case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) =>
val newOp = transform1(s)
newOp match {
case Some(nativeOp) =>
CometSinkPlaceHolder(nativeOp, s, s)
case None =>
s
}

case b: BroadcastExchangeExec
if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") &&
isCometBroadCastEnabled(conf) =>
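
The BroadcastHashJoinExec case added above only fires when isCometOperatorEnabled(conf, "broadcast_hash_join") is true and all of the join's children are already Comet-native. A minimal sketch of enabling it in a session (in addition to registering the Comet extensions); the per-operator key name is an assumption based on the spark.comet.exec.<operator>.enabled pattern implied by the isCometOperatorEnabled check and is not shown in this diff:

import org.apache.spark.sql.SparkSession

// Sketch only: the broadcast_hash_join key name is assumed from the
// isCometOperatorEnabled(conf, "broadcast_hash_join") check and may differ.
val spark = SparkSession
  .builder()
  .appName("comet-broadcast-hash-join")
  .config("spark.comet.enabled", "true")
  .config("spark.comet.exec.enabled", "true")
  .config("spark.comet.exec.broadcast_hash_join.enabled", "true")
  .getOrCreate()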
18 changes: 14 additions & 4 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,14 +29,14 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition}
import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometSinkPlaceHolder, DecimalPrecision}
import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec
import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -1915,7 +1915,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
}
}

case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, "hash_join") =>
case join: HashJoin =>
// `HashJoin` has only two implementations in Spark. Check the concrete type so that each
// implementation is gated by its own Comet config flag.
if (!(isCometOperatorEnabled(op.conf, "hash_join") &&
join.isInstanceOf[ShuffledHashJoinExec]) &&
!(isCometOperatorEnabled(op.conf, "broadcast_hash_join") &&
join.isInstanceOf[BroadcastHashJoinExec])) {
return None
}

if (join.buildSide == BuildRight) {
// DataFusion HashJoin assumes build side is always left.
// TODO: support BuildRight
@@ -2063,6 +2072,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde {
case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true
case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true
case _: TakeOrderedAndProjectExec => true
case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true
case _: BroadcastExchangeExec => true
case _ => false
}
@@ -258,8 +258,7 @@ class CometBatchRDD(

override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
val partition = split.asInstanceOf[CometBatchPartition]

partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator
partition.value.value.toIterator.flatMap(CometExec.decodeBatches)
}
}

44 changes: 41 additions & 3 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec}
import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
@@ -269,7 +269,8 @@ abstract class CometNativeExec extends CometExec {
plan match {
case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec |
_: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: CometUnionExec |
_: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec =>
_: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec |
_: CometBroadcastExchangeExec | _: BroadcastQueryStageExec =>
func(plan)
case _: CometPlan =>
// Other Comet operators, continue to traverse the tree.
@@ -622,7 +623,44 @@ case class CometHashJoinExec(
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, left, right)
Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
}

case class CometBroadcastHashJoinExec(
override val nativeOp: Operator,
override val originalPlan: SparkPlan,
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
joinType: JoinType,
condition: Option[Expression],
buildSide: BuildSide,
override val left: SparkPlan,
override val right: SparkPlan,
override val serializedPlanOpt: SerializedPlan)
extends CometBinaryExec {
override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan =
this.copy(left = newLeft, right = newRight)

override def stringArgs: Iterator[Any] =
Iterator(leftKeys, rightKeys, joinType, condition, left, right)

override def equals(obj: Any): Boolean = {
obj match {
case other: CometBroadcastHashJoinExec =>
this.leftKeys == other.leftKeys &&
this.rightKeys == other.rightKeys &&
this.condition == other.condition &&
this.buildSide == other.buildSide &&
this.left == other.left &&
this.right == other.right &&
this.serializedPlanOpt == other.serializedPlanOpt
case _ =>
false
}
}

override def hashCode(): Int =
Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right)
}

case class CometSortMergeJoinExec(