diff --git a/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java new file mode 100644 index 000000000..a492ce887 --- /dev/null +++ b/common/src/main/java/org/apache/comet/CometArrowStreamWriter.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.compression.NoCompressionCodec; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; + +/** + * A custom `ArrowStreamWriter` that allows writing batches from different root to the same stream. + * Arrow `ArrowStreamWriter` cannot change the root after initialization. + */ +public class CometArrowStreamWriter extends ArrowStreamWriter { + public CometArrowStreamWriter( + VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out) { + super(root, provider, out); + } + + public void writeMoreBatch(VectorSchemaRoot root) throws IOException { + VectorUnloader unloader = + new VectorUnloader( + root, /*includeNullCount*/ true, NoCompressionCodec.INSTANCE, /*alignBuffers*/ true); + + try (ArrowRecordBatch batch = unloader.getRecordBatch()) { + writeRecordBatch(batch); + } + } +} diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 1682295f7..cc726e3e8 100644 --- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -20,6 +20,7 @@ package org.apache.comet.vector import java.io.OutputStream +import java.nio.channels.Channels import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,10 +29,11 @@ import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, CDataDictiona import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector._ import org.apache.arrow.vector.dictionary.DictionaryProvider -import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.SparkException import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.comet.CometArrowStreamWriter + class NativeUtil { private val allocator = new RootAllocator(Long.MaxValue) private val dictionaryProvider: CDataDictionaryProvider = new CDataDictionaryProvider @@ -46,29 +48,27 @@ class NativeUtil { * the output stream */ def serializeBatches(batches: Iterator[ColumnarBatch], out: OutputStream): Long = { - var schemaRoot: Option[VectorSchemaRoot] = None - var writer: Option[ArrowStreamWriter] = None + var writer: Option[CometArrowStreamWriter] = None var rowCount = 0 batches.foreach { batch => val (fieldVectors, batchProviderOpt) = getBatchFieldVectors(batch) - val root = schemaRoot.getOrElse(new VectorSchemaRoot(fieldVectors.asJava)) + val root = new VectorSchemaRoot(fieldVectors.asJava) val provider = batchProviderOpt.getOrElse(dictionaryProvider) if (writer.isEmpty) { - writer = Some(new ArrowStreamWriter(root, provider, out)) + writer = Some(new CometArrowStreamWriter(root, provider, Channels.newChannel(out))) writer.get.start() + writer.get.writeBatch() + } else { + writer.get.writeMoreBatch(root) } - writer.get.writeBatch() root.clear() - schemaRoot = Some(root) - rowCount += batch.numRows() } writer.map(_.end()) - schemaRoot.map(_.close()) rowCount } diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala index e8dba93e7..304c3ce77 100644 --- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala +++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala @@ -23,7 +23,7 @@ import java.nio.channels.ReadableByteChannel import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.vector.StreamReader +import org.apache.comet.vector._ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] { @@ -36,6 +36,13 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna return true } + // Release the previous batch. + // If it is not released, when closing the reader, arrow library will complain about + // memory leak. + if (currentBatch != null) { + currentBatch.close() + } + batch = nextBatch() if (batch.isEmpty) { return false @@ -50,13 +57,6 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna val nextBatch = batch.get - // Release the previous batch. - // If it is not released, when closing the reader, arrow library will complain about - // memory leak. - if (currentBatch != null) { - currentBatch.close() - } - currentBatch = nextBatch batch = None currentBatch diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index 1380ee91a..fcbf42f5b 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -31,14 +31,14 @@ import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometNativeShuffle} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -356,6 +356,27 @@ class CometSparkSessionExtensions op } + case op: BroadcastHashJoinExec + if isCometOperatorEnabled(conf, "broadcast_hash_join") && + op.children.forall(isCometNative(_)) => + val newOp = transform1(op) + newOp match { + case Some(nativeOp) => + CometBroadcastHashJoinExec( + nativeOp, + op, + op.leftKeys, + op.rightKeys, + op.joinType, + op.condition, + op.buildSide, + op.left, + op.right, + SerializedPlan(None)) + case None => + op + } + case op: SortMergeJoinExec if isCometOperatorEnabled(conf, "sort_merge_join") && op.children.forall(isCometNative(_)) => @@ -411,6 +432,16 @@ class CometSparkSessionExtensions u } + // For AQE broadcast stage on a Comet broadcast exchange + case s @ BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => + val newOp = transform1(s) + newOp match { + case Some(nativeOp) => + CometSinkPlaceHolder(nativeOp, s, s) + case None => + s + } + case b: BroadcastExchangeExec if isCometNative(b.child) && isCometOperatorEnabled(conf, "broadcastExchangeExec") && isCometBroadCastEnabled(conf) => diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index bf2510b25..b98c4388e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -29,14 +29,14 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, NormalizeNaNAndZero} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, SinglePartition} import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils -import org.apache.spark.sql.comet.{CometSinkPlaceHolder, DecimalPrecision} +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometSinkPlaceHolder, DecimalPrecision} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1915,7 +1915,16 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { } } - case join: ShuffledHashJoinExec if isCometOperatorEnabled(op.conf, "hash_join") => + case join: HashJoin => + // `HashJoin` has only two implementations in Spark, but we check the type of the join to + // make sure we are handling the correct join type. + if (!(isCometOperatorEnabled(op.conf, "hash_join") && + join.isInstanceOf[ShuffledHashJoinExec]) && + !(isCometOperatorEnabled(op.conf, "broadcast_hash_join") && + join.isInstanceOf[BroadcastHashJoinExec])) { + return None + } + if (join.buildSide == BuildRight) { // DataFusion HashJoin assumes build side is always left. // TODO: support BuildRight @@ -2063,6 +2072,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde { case ShuffleQueryStageExec(_, _: CometShuffleExchangeExec, _) => true case ShuffleQueryStageExec(_, ReusedExchangeExec(_, _: CometShuffleExchangeExec), _) => true case _: TakeOrderedAndProjectExec => true + case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true case _: BroadcastExchangeExec => true case _ => false } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index f115b2ad9..24f9f3279 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -258,8 +258,7 @@ class CometBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometBatchPartition] - - partition.value.value.flatMap(CometExec.decodeBatches(_)).toIterator + partition.value.value.toIterator.flatMap(CometExec.decodeBatches) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index fb300a303..84734a175 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.comet.execution.shuffle.{ArrowReaderIterator, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{BinaryExecNode, ColumnarToRowExec, ExecSubqueryExpression, ExplainUtils, LeafExecNode, ScalarSubquery, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf @@ -269,7 +269,8 @@ abstract class CometNativeExec extends CometExec { plan match { case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec | _: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: CometUnionExec | - _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec => + _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: ReusedExchangeExec | + _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec => func(plan) case _: CometPlan => // Other Comet operators, continue to traverse the tree. @@ -622,7 +623,44 @@ case class CometHashJoinExec( } override def hashCode(): Int = - Objects.hashCode(leftKeys, rightKeys, condition, left, right) + Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right) +} + +case class CometBroadcastHashJoinExec( + override val nativeOp: Operator, + override val originalPlan: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + condition: Option[Expression], + buildSide: BuildSide, + override val left: SparkPlan, + override val right: SparkPlan, + override val serializedPlanOpt: SerializedPlan) + extends CometBinaryExec { + override def withNewChildrenInternal(newLeft: SparkPlan, newRight: SparkPlan): SparkPlan = + this.copy(left = newLeft, right = newRight) + + override def stringArgs: Iterator[Any] = + Iterator(leftKeys, rightKeys, joinType, condition, left, right) + + override def equals(obj: Any): Boolean = { + obj match { + case other: CometBroadcastHashJoinExec => + this.leftKeys == other.leftKeys && + this.rightKeys == other.rightKeys && + this.condition == other.condition && + this.buildSide == other.buildSide && + this.left == other.left && + this.right == other.right && + this.serializedPlanOpt == other.serializedPlanOpt + case _ => + false + } + } + + override def hashCode(): Int = + Objects.hashCode(leftKeys, rightKeys, condition, buildSide, left, right) } case class CometSortMergeJoinExec( diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index a64ec8749..6f479e3bb 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -23,9 +23,11 @@ import org.scalactic.source.Position import org.scalatest.Tag import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec} import org.apache.spark.sql.internal.SQLConf import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus class CometJoinSuite extends CometTestBase { @@ -38,6 +40,68 @@ class CometJoinSuite extends CometTestBase { } } + test("Broadcast HashJoin without join filter") { + assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") + withSQLConf( + CometConf.COMET_BATCH_SIZE.key -> "100", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + "spark.comet.exec.broadcast.enabled" -> "true", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join: build left + val df1 = + sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator( + df1, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + + // Right join: build left + val df2 = + sql("SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b ON tbl_a._2 = tbl_b._1") + checkSparkAnswerAndOperator( + df2, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + } + } + } + } + + test("Broadcast HashJoin with join filter") { + assume(isSpark34Plus, "ChunkedByteBuffer is not serializable before Spark 3.4+") + withSQLConf( + CometConf.COMET_BATCH_SIZE.key -> "100", + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + "spark.comet.exec.broadcast.enabled" -> "true", + "spark.sql.join.forceApplyShuffledHashJoin" -> "true", + SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withParquetTable((0 until 1000).map(i => (i, i % 5)), "tbl_a") { + withParquetTable((0 until 1000).map(i => (i % 10, i + 2)), "tbl_b") { + // Inner join: build left + val df1 = + sql( + "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator( + df1, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + + // Right join: build left + val df2 = + sql( + "SELECT /*+ BROADCAST(tbl_a) */ * FROM tbl_a RIGHT JOIN tbl_b " + + "ON tbl_a._2 = tbl_b._1 AND tbl_a._1 > tbl_b._2") + checkSparkAnswerAndOperator( + df2, + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + } + } + } + } + test("HashJoin without join filter") { withSQLConf( SQLConf.PREFER_SORTMERGEJOIN.key -> "false",