From 85a0950d9bdbc27cb81c8af2c5f7c9646bfa1a96 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 8 Sep 2021 10:34:23 +0800 Subject: [PATCH 1/3] Add leafNodeDefaultParallelism support Signed-off-by: Bobby Wang --- .../com/nvidia/spark/rapids/SparkShims.scala | 2 + .../spark/rapids/basicPhysicalOperators.scala | 3 +- .../spark/rapids/shims/v2/Spark30XShims.scala | 5 +++ .../spark/rapids/shims/v2/Spark32XShims.scala | 9 +++- .../apache/spark/sql/Spark32XShimsUtils.scala | 25 +++++++++++ .../apache/spark/sql/GpuSparkPlanSuite.scala | 45 +++++++++++++++++++ 6 files changed, 86 insertions(+), 3 deletions(-) create mode 100644 sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala create mode 100644 tests/src/test/scala/org/apache/spark/sql/GpuSparkPlanSuite.scala diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala index 8c1b4135785..adcf217eefb 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/SparkShims.scala @@ -252,6 +252,8 @@ trait SparkShims { def isCustomReaderExec(x: SparkPlan): Boolean def aqeShuffleReaderExec: ExecRule[_ <: SparkPlan] + + def leafNodeDefaultParallelism(ss: SparkSession): Int } abstract class SparkCommonShims extends SparkShims { diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index 60efc47f9e1..096ef41ac4d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -378,7 +378,8 @@ case class GpuRangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range val start: Long = range.start val end: Long = range.end val step: Long = range.step - val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) + val numSlices: Int = range.numSlices.getOrElse(ShimLoader.getSparkShims + .leafNodeDefaultParallelism(ShimLoader.getSparkShims.sessionFromPlan(this))) val numElements: BigInt = range.numElements val isEmptyRange: Boolean = start == end || (start < end ^ 0 < step) diff --git a/sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala b/sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala index 24bb47e86f2..84e8ceffcf2 100644 --- a/sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala +++ b/sql-plugin/src/main/spark30+all/scala/com/nvidia/spark/rapids/shims/v2/Spark30XShims.scala @@ -96,4 +96,9 @@ trait Spark30XShims extends SparkShims { ExecChecks((TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_64 + TypeSig.ARRAY + TypeSig.STRUCT + TypeSig.MAP).nested(), TypeSig.all), (exec, conf, p, r) => new GpuCustomShuffleReaderMeta(exec, conf, p, r)) + + override def leafNodeDefaultParallelism(ss: SparkSession): Int = { + ss.sparkContext.defaultParallelism + } + } diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala index 659600d1810..3924083acc2 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala @@ -19,17 +19,17 @@ package com.nvidia.spark.rapids.shims.v2 import com.nvidia.spark.rapids._ import com.nvidia.spark.rapids.GpuOverrides.exec import com.nvidia.spark.rapids.shims._ - import org.apache.hadoop.fs.FileStatus import org.apache.parquet.schema.MessageType +import org.apache.spark.sql.Spark32XShimsUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.util.DateFormatter import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, QueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, QueryStageExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters @@ -131,6 +131,11 @@ trait Spark32XShims extends SparkShims { // we will need to change the API to pass these values in. enableAddPartitions = true, enableDropPartitions = false) + + override def leafNodeDefaultParallelism(ss: SparkSession): Int = { + Spark32XShimsUtils.leafNodeDefaultParallelism(ss) + } + } // TODO dedupe utils inside shims diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala new file mode 100644 index 00000000000..ae1fda0727d --- /dev/null +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.SparkSession + +object Spark32XShimsUtils { + def leafNodeDefaultParallelism(ss: SparkSession): Int = { + ss.leafNodeDefaultParallelism + } +} \ No newline at end of file diff --git a/tests/src/test/scala/org/apache/spark/sql/GpuSparkPlanSuite.scala b/tests/src/test/scala/org/apache/spark/sql/GpuSparkPlanSuite.scala new file mode 100644 index 00000000000..c9b73555d47 --- /dev/null +++ b/tests/src/test/scala/org/apache/spark/sql/GpuSparkPlanSuite.scala @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.{Locale, TimeZone} + +import com.nvidia.spark.rapids.{ShimLoader, SparkSessionHolder} +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.plans.logical.Range + +class GpuSparkPlanSuite extends FunSuite { + + test("leafNodeDefaultParallelism for GpuRangeExec") { + + val conf = new SparkConf() + .set("spark.sql.leafNodeDefaultParallelism", "7") + .set("spark.rapids.sql.enabled", "true") + + SparkSessionHolder.withSparkSession(conf, spark => { + val defaultSlice = ShimLoader.getSparkShims.leafNodeDefaultParallelism(spark) + val ds = new Dataset(spark, Range(0, 20, 1, None), Encoders.LONG) + val partitions = ds.rdd.getNumPartitions + assert(partitions == defaultSlice) + }) + + } + +} + From f652195c7df16a7864a7591249e35a93f81c13ff Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 8 Sep 2021 13:08:54 +0800 Subject: [PATCH 2/3] minor change --- .../com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala | 2 +- .../scala/org/apache/spark/sql/Spark32XShimsUtils.scala | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala index 3924083acc2..bf97aad0764 100644 --- a/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala +++ b/sql-plugin/src/main/spark320/scala/com/nvidia/spark/rapids/shims/v2/Spark32XShims.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.util.DateFormatter import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, AdaptiveSparkPlanExec, QueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, QueryStageExec} import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.ParquetFilters diff --git a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala index ae1fda0727d..f6d85d56060 100644 --- a/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala +++ b/sql-plugin/src/main/spark320/scala/org/apache/spark/sql/Spark32XShimsUtils.scala @@ -16,10 +16,11 @@ package org.apache.spark.sql -import org.apache.spark.sql.SparkSession - object Spark32XShimsUtils { + def leafNodeDefaultParallelism(ss: SparkSession): Int = { ss.leafNodeDefaultParallelism } -} \ No newline at end of file + +} + From 76d4b1597f8e493fb694dca63f4f949a2b4050d4 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Wed, 8 Sep 2021 16:30:43 +0800 Subject: [PATCH 3/3] resolve comments --- .../scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala index 096ef41ac4d..3633436747d 100644 --- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala +++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/basicPhysicalOperators.scala @@ -379,7 +379,7 @@ case class GpuRangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range val end: Long = range.end val step: Long = range.step val numSlices: Int = range.numSlices.getOrElse(ShimLoader.getSparkShims - .leafNodeDefaultParallelism(ShimLoader.getSparkShims.sessionFromPlan(this))) + .leafNodeDefaultParallelism(sparkSession)) val numElements: BigInt = range.numElements val isEmptyRange: Boolean = start == end || (start < end ^ 0 < step)