test: Enable Comet shuffle in Spark SQL tests #210

Merged (5 commits, Mar 26, 2024)
dev/diffs/3.4.2.diff (202 changes: 188 additions & 14 deletions)
@@ -234,6 +234,20 @@ index 56e9520fdab..917932336df 100644
spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
spark.range(100).write.saveAsTable(s"$dbName.$table2Name")

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 9ddb4abe98b..1bebe99f1cc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -3311,7 +3311,8 @@ class DataFrameSuite extends QueryTest
assert(df2.isLocal)
}

- test("SPARK-35886: PromotePrecision should be subexpr replaced") {
+ test("SPARK-35886: PromotePrecision should be subexpr replaced",
+ IgnoreComet("TODO: fix Comet for this test")) {
Member: This will be forgotten easily. Can we create a ticket for that?

withTable("tbl") {
sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
index f33432ddb6f..fe9f74ff8f1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala
@@ -365,7 +379,7 @@ index 00000000000..4b31bea33de
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 5125708be32..e274a497996 100644
index 5125708be32..a1f1ae90796 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
@@ -376,18 +390,30 @@ index 5125708be32..e274a497996 100644
import org.apache.spark.sql.execution.{BinaryExecNode, FilterExec, ProjectExec, SortExec, SparkPlan, WholeStageCodegenExec}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.exchange.{ShuffleExchangeExec, ShuffleExchangeLike}
@@ -1371,7 +1372,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
@@ -1369,9 +1370,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
}
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 3)
// No extra sort on left side before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 5)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 5)
}

// Test output ordering is not preserved
@@ -1382,7 +1383,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
@@ -1380,9 +1384,12 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan
val selectExpr = "/*+ BROADCAST(left_t) */ k1 as k0"
val plan = sql(getJoinQuery(selectExpr, joinType)).queryExecution.executedPlan
assert(collect(plan) { case _: BroadcastNestedLoopJoinExec => true }.size === 1)
assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
- assert(collect(plan) { case _: SortMergeJoinExec => true }.size === 3)
+ assert(collect(plan) {
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
+ }.size === 3)
// Have sort on left side before last sort merge join
- assert(collect(plan) { case _: SortExec => true }.size === 6)
+ assert(collect(plan) { case _: SortExec | _: CometSortExec => true }.size === 6)
@@ -408,18 +434,19 @@ index b5b34922694..a72403780c4 100644
protected val baseResourcePath = {
// use the same way as `SQLQueryTestSuite` to get the resource path
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 3cfda19134a..afcfba37c6f 100644
index 3cfda19134a..278bb1060c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer
@@ -21,6 +21,8 @@ import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Join, LogicalPlan, Project, Sort, Union}
+import org.apache.spark.sql.comet.CometScanExec
+import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecution}
import org.apache.spark.sql.execution.datasources.FileScanRDD
@@ -1543,6 +1544,12 @@ class SubquerySuite extends QueryTest
@@ -1543,6 +1545,12 @@ class SubquerySuite extends QueryTest
fs.inputRDDs().forall(
_.asInstanceOf[FileScanRDD].filePartitions.forall(
_.files.forall(_.urlEncodedPath.contains("p=0"))))
@@ -432,6 +459,14 @@ index 3cfda19134a..afcfba37c6f 100644
case _ => false
})
}
@@ -2109,6 +2117,7 @@ class SubquerySuite extends QueryTest
df.collect()
val exchanges = collect(df.queryExecution.executedPlan) {
case s: ShuffleExchangeExec => s
+ case s: CometShuffleExchangeExec => s
}
assert(exchanges.size === 1)
}
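
The change above shows the recurring pattern in this patch: once Comet replaces a physical operator, plan assertions have to match the Comet node alongside the vanilla Spark node, otherwise the expected count is no longer met. A self-contained sketch of that idiom; the helper name is illustrative, and plain TreeNode.collect, unlike the suites' collect from AdaptiveSparkPlanHelper, does not descend into adaptive subplans:

    import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

    // Count shuffle exchanges regardless of whether Spark or Comet planned them.
    def countShuffleExchanges(plan: SparkPlan): Int =
      plan.collect {
        case s: ShuffleExchangeExec => s
        case s: CometShuffleExchangeExec => s
      }.size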
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
index cfc8b2cc845..c6fcfd7bd08 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/FileDataSourceV2FallBackSuite.scala
@@ -541,6 +576,28 @@ index ac710c32296..37746bd470d 100644
val projection = Seq.tabulate(columnNum)(i => s"c$i + c$i as newC$i")
val df = spark.read.parquet(path).selectExpr(projection: _*)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 593bd7bb4ba..be1b82d0030 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListe
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy}
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.comet._
import org.apache.spark.sql.execution.{CollectLimitExec, LocalTableScanExec, PartialReducerPartitionSpec, QueryExecution, ReusedSubqueryExec, ShuffledRowRDD, SortExec, SparkPlan, SparkPlanInfo, UnionExec}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.command.DataWritingCommandExec
@@ -116,6 +117,9 @@ class AdaptiveQueryExecSuite
private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
collect(plan) {
case j: SortMergeJoinExec => j
+ case j: CometSortMergeJoinExec =>
+ assert(j.originalPlan.isInstanceOf[SortMergeJoinExec])
+ j.originalPlan.asInstanceOf[SortMergeJoinExec]
}
}

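Aside (not from the patch): Comet operators wrap the Spark operator they replace and keep it reachable through originalPlan, which is what lets findTopLevelSortMergeJoin above keep returning SortMergeJoinExec instances. The same unwrapping idiom as a standalone sketch; the helper name and the exact location of the Comet class beyond the import org.apache.spark.sql.comet._ shown above are assumptions:

    import org.apache.spark.sql.comet.CometSortMergeJoinExec
    import org.apache.spark.sql.execution.SparkPlan
    import org.apache.spark.sql.execution.joins.SortMergeJoinExec

    // Normalize a node back to the Spark join the test expects, if possible.
    def unwrapSortMergeJoin(node: SparkPlan): Option[SortMergeJoinExec] = node match {
      case j: SortMergeJoinExec => Some(j)
      case c: CometSortMergeJoinExec => c.originalPlan match {
        case j: SortMergeJoinExec => Some(j)
        case _ => None
      }
      case _ => None
    }
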
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index bd9c79e5b96..ab7584e768e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -935,7 +992,7 @@ index d083cac48ff..3c11bcde807 100644
import testImplicits._

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 266bb343526..f393606997c 100644
index 266bb343526..b33bb677f0d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -1051,7 +1108,16 @@ index 266bb343526..f393606997c 100644
checkAnswer(aggDF, df1.groupBy("j").agg(max("k")))
}
}
@@ -1031,10 +1060,16 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
@@ -1026,15 +1055,24 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
expectedNumShuffles: Int,
expectedCoalescedNumBuckets: Option[Int]): Unit = {
val plan = sql(query).queryExecution.executedPlan
- val shuffles = plan.collect { case s: ShuffleExchangeExec => s }
+ val shuffles = plan.collect {
+ case s: ShuffleExchangeExec => s
+ case s: CometShuffleExchangeExec => s
+ }
assert(shuffles.length == expectedNumShuffles)

val scans = plan.collect {
case f: FileSourceScanExec if f.optionalNumCoalescedBuckets.isDefined => f
@@ -1139,6 +1205,94 @@ index 75f440caefc..36b1146bc3a 100644
}.headOption.getOrElse {
fail(s"No FileScan in query\n${df.queryExecution}")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
index b597a244710..b2e8be41065 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateDistributionSuite.scala
@@ -21,6 +21,7 @@ import java.io.File

import org.apache.commons.io.FileUtils

+import org.apache.spark.sql.IgnoreComet
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, MemoryStream}
import org.apache.spark.sql.internal.SQLConf
@@ -91,7 +92,7 @@ class FlatMapGroupsWithStateDistributionSuite extends StreamTest
}

test("SPARK-38204: flatMapGroupsWithState should require StatefulOpClusteredDistribution " +
- "from children - without initial state") {
+ "from children - without initial state", IgnoreComet("TODO: fix Comet for this test")) {
Member: And this.

// function will return -1 on timeout and returns count of the state otherwise
val stateFunc =
(key: (String, String), values: Iterator[(String, String, Long)],
@@ -243,7 +244,8 @@ class FlatMapGroupsWithStateDistributionSuite extends StreamTest
}

test("SPARK-38204: flatMapGroupsWithState should require ClusteredDistribution " +
- "from children if the query starts from checkpoint in 3.2.x - without initial state") {
+ "from children if the query starts from checkpoint in 3.2.x - without initial state",
+ IgnoreComet("TODO: fix Comet for this test")) {
Member: ditto

// function will return -1 on timeout and returns count of the state otherwise
val stateFunc =
(key: (String, String), values: Iterator[(String, String, Long)],
@@ -335,7 +337,8 @@ class FlatMapGroupsWithStateDistributionSuite extends StreamTest
}

test("SPARK-38204: flatMapGroupsWithState should require ClusteredDistribution " +
- "from children if the query starts from checkpoint in prior to 3.2") {
+ "from children if the query starts from checkpoint in prior to 3.2",
+ IgnoreComet("TODO: fix Comet for this test")) {
Member: ditto.

// function will return -1 on timeout and returns count of the state otherwise
val stateFunc =
(key: (String, String), values: Iterator[(String, String, Long)],
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index 6aa7d0945c7..38523536154 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException

import org.apache.spark.SparkException
import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction
-import org.apache.spark.sql.{DataFrame, Encoder}
+import org.apache.spark.sql.{DataFrame, Encoder, IgnoreCometSuite}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsWithState
@@ -46,8 +46,9 @@ case class RunningCount(count: Long)

case class Result(key: Long, count: Int)

+// TODO: fix Comet to enable this suite
Member: Hmm, does this whole suite fail?

@SlowSQLTest
-class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest {
+class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with IgnoreCometSuite {

import testImplicits._

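Aside (not part of the patch): IgnoreCometSuite, mixed in above, is the suite-level counterpart of the IgnoreComet tag, skipping every test in the suite while Comet is enabled. A rough sketch of such a trait; the real one lives in the Comet test helpers, and the environment variable name is an assumption:

    import org.scalactic.source.Position
    import org.scalatest.Tag

    import org.apache.spark.SparkFunSuite

    // Turn every test of the mixing suite into an ignored test when Comet is enabled (sketch only).
    trait IgnoreCometSuite extends SparkFunSuite {
      override protected def test(name: String, tags: Tag*)(body: => Any)(implicit pos: Position): Unit = {
        if (sys.env.get("ENABLE_COMET").contains("true")) ignore(name, tags: _*)(body)
        else super.test(name, tags: _*)(body)
      }
    }
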
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala
index 2a2a83d35e1..e3b7b290b3e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateWithInitialStateSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.streaming

import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Dataset, KeyValueGroupedDataset}
+import org.apache.spark.sql.{AnalysisException, Dataset, IgnoreComet, KeyValueGroupedDataset}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Update
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper
@@ -253,7 +253,8 @@ class FlatMapGroupsWithStateWithInitialStateSuite extends StateStoreMetricsTest
assert(e.message.contains(expectedError))
}

- test("flatMapGroupsWithState - initial state - initial state has flatMapGroupsWithState") {
+ test("flatMapGroupsWithState - initial state - initial state has flatMapGroupsWithState",
+ IgnoreComet("TODO: fix Comet for this test")) {
val initialStateDS = Seq(("keyInStateAndData", new RunningCount(1))).toDS()
val initialState: KeyValueGroupedDataset[String, RunningCount] =
initialStateDS.groupByKey(_._1).mapValues(_._2)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
index abe606ad9c1..2d930b64cca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamTableAPISuite.scala
@@ -1221,10 +1375,10 @@ index dd55fcfe42c..cc18147d17a 100644

spark.internalCreateDataFrame(withoutFilters.execute(), schema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
index ed2e309fa07..3767d4e7ca4 100644
index ed2e309fa07..4cfe0093da7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala
@@ -74,6 +74,18 @@ trait SharedSparkSessionBase
@@ -74,6 +74,21 @@ trait SharedSparkSessionBase
// this rule may potentially block testing of other optimization rules such as
// ConstantPropagation etc.
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
Expand All @@ -1238,6 +1392,9 @@ index ed2e309fa07..3767d4e7ca4 100644
+ conf
+ .set("spark.comet.exec.enabled", "true")
+ .set("spark.comet.exec.all.enabled", "true")
+ .set("spark.shuffle.manager",
+ "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ .set("spark.comet.exec.shuffle.enabled", "true")
Contributor: What about Comet columnar shuffle? I think we are prioritizing that as the first-class shuffle support.

Member (author): It is still pending the work to make it the default. Once it becomes the default, all the tests will automatically switch to use it.

Contributor: I see, that makes sense.

+ }
+ }
conf.set(
@@ -1265,11 +1422,25 @@ index 52abd248f3a..7a199931a08 100644
case h: HiveTableScanExec => h.partitionPruningPred.collect {
case d: DynamicPruningExpression => d.child
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 1966e1e64fd..cde97a0aafe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -656,7 +656,8 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
Row(3, 4, 4, 3, null) :: Nil)
}

- test("single distinct multiple columns set") {
+ test("single distinct multiple columns set",
+ IgnoreComet("TODO: fix Comet for this test")) {
checkAnswer(
spark.sql(
"""
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 07361cfdce9..545b3184c23 100644
index 07361cfdce9..c5d94c92e32 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -55,25 +55,43 @@ object TestHive
@@ -55,25 +55,46 @@ object TestHive
new SparkContext(
System.getProperty("spark.sql.test.master", "local[1]"),
"TestSQLContext",
Expand Down Expand Up @@ -1322,6 +1493,9 @@ index 07361cfdce9..545b3184c23 100644
+ conf
+ .set("spark.comet.exec.enabled", "true")
+ .set("spark.comet.exec.all.enabled", "true")
+ .set("spark.shuffle.manager",
+ "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
+ .set("spark.comet.exec.shuffle.enabled", "true")
+ }
+ }

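Aside (not part of the patch): the same switches the shared test sessions set above, written out for a standalone SparkSession. A sketch only: it assumes the Comet jars and session extensions are already on the classpath and uses just the keys visible in the hunks above:

    import org.apache.spark.sql.SparkSession

    // Enable Comet execution and Comet-managed shuffle, mirroring the test configuration.
    val spark = SparkSession.builder()
      .config("spark.comet.exec.enabled", "true")
      .config("spark.comet.exec.all.enabled", "true")
      .config("spark.comet.exec.shuffle.enabled", "true")
      .config("spark.shuffle.manager",
        "org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager")
      .getOrCreate()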