Transition to v2 shims [Databricks] (#4857)
* 301 compiling

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Merged the shims

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* db build errors

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* renamed folder

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* modify binary-dedupe.sh to reflect the new package and remove call to ShimLoader

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* some more changes to return the static classname for shuffle managers

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* init shim when getting shufflemanager

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* getVersion changes

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* add hypot back

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed buildShim

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* clean up

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed package v2

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* reference the correct package

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed duplicate versions of RapidsShuffleManager

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* addressed review comments

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* fix db build

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Revert "fix db build"

This reverts commit cf5bd29.

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Revert "addressed review comments"

This reverts commit a2a6b7e.

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Revert "removed duplicate versions of RapidsShuffleManager"

This reverts commit 980eba3.

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed the non-existent folder

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed unused import

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* reverted shuffle manager and internal manager change

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* revert spark2diffs changes

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Fix 301db build

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* removed reference of ShimLoader.getSparkShims from doc

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* Revert 312db build fix

Will be fixed as a part of a separate patch

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* merge conflicts

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* fix db build

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* fix 301db

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* fixed 304

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* fixed 330 build errors

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

* fixed imports

Signed-off-by: Raza Jafri <rjafri@nvidia.com>

Co-authored-by: Raza Jafri <rjafri@nvidia.com>
razajafri and razajafri authored Mar 8, 2022
1 parent a0aeaba commit 9f727bd
Showing 278 changed files with 698 additions and 789 deletions.
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ import scala.reflect.api
import scala.reflect.runtime.universe._

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.SparkShimImpl

import org.apache.spark.internal.Logging

@@ -70,7 +71,7 @@ object ApiValidation extends Logging {
var printNewline = false

val sparkToShimMap = Map("3.0.1" -> "spark301", "3.1.1" -> "spark311")
val sparkVersion = ShimLoader.getSparkShims.getSparkShimVersion.toString
val sparkVersion = SparkShimImpl.getSparkShimVersion.toString
val shimVersion = sparkToShimMap(sparkVersion)

gpuKeys.foreach { e =>
6 changes: 3 additions & 3 deletions docs/dev/shims.md
@@ -26,9 +26,9 @@ In the following we provide recipes for typical scenarios addressed by the Shim
It's among the easiest issues to resolve. We define a method in SparkShims
trait covering a superset of parameters from all versions and call it
```
ShimLoader.getSparkShims.methodWithDiscrepancies(p_1, ..., p_n)
SparkShimImpl.methodWithDiscrepancies(p_1, ..., p_n)
```
instead of referencing it directly. Shim implementations are in charge of dispatching it further
instead of referencing it directly. Shim implementations (SparkShimImpl) are in charge of dispatching it further
to correct version-dependent methods. Moreover, unlike in the below sections
conflicts between versions are easily avoided by using different package or class names
for conflicting Shim implementations.
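
The recipe in this hunk can be pictured with a minimal Scala sketch. Only the names `SparkShims`, `SparkShimImpl`, and `methodWithDiscrepancies` come from the doc text; the parameter list, the `legacyMethod` helper, and the `ShimDemo` entry point are hypothetical and chosen purely for illustration. It assumes one `SparkShimImpl` object is compiled per Spark version, so callers reference it statically rather than going through `ShimLoader.getSparkShims`.

```scala
// Illustrative sketch only: not the plugin's actual trait or signatures.
trait SparkShims {
  // Covers the superset of parameters across versions. `ansiEnabled` is a
  // hypothetical parameter that only some Spark versions would accept.
  def methodWithDiscrepancies(input: String, ansiEnabled: Boolean): String
}

// Assumption: each Spark version's source directory supplies its own
// SparkShimImpl, so exactly one definition is on the classpath per build.
object SparkShimImpl extends SparkShims {
  override def methodWithDiscrepancies(input: String, ansiEnabled: Boolean): String =
    // This hypothetical version's API has no ANSI flag, so the shim drops it
    // and dispatches to the version-specific method.
    legacyMethod(input)

  // Hypothetical stand-in for a version-dependent upstream method.
  private def legacyMethod(input: String): String = input.toUpperCase
}

object ShimDemo extends App {
  // Call sites use the static object directly.
  println(SparkShimImpl.methodWithDiscrepancies("abc", ansiEnabled = true))
}
```

Because the call site names a concrete object, the compiler resolves the dispatch at build time for the Spark version being compiled against, which is the effect the rename from `ShimLoader.getSparkShims` to `SparkShimImpl` appears to aim for.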
@@ -40,7 +40,7 @@ Upstream base classes we derive from might be incompatible in the sense that one
requires us to implement/override the method `M` whereas the other prohibits it by marking
the base implementation `final`, E.g. `org.apache.spark.sql.catalyst.trees.TreeNode` changes
between Spark 3.1.x and Spark 3.2.x. So instead of deriving from such classes directly we
inject an intermediate trait e.g. `com.nvidia.spark.rapids.shims.v2.ShimExpression` that
inject an intermediate trait e.g. `com.nvidia.spark.rapids.shims.ShimExpression` that
has a varying source code depending on the Spark version we compile against to overcome this
issue as you can see e.g., comparing TreeNode:
1. [ShimExpression For 3.0.x and 3.1.x](https://github.com/NVIDIA/spark-rapids/blob/main/sql-plugin/src/main/post320-treenode/scala/com/nvidia/spark/rapids/shims/v2/TreeNode.scala#L23)
19 changes: 9 additions & 10 deletions pom.xml
@@ -115,6 +115,7 @@
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/301/scala</source>
<source>${project.basedir}/src/main/301until304/scala</source>
<source>${project.basedir}/src/main/301until310-all/scala</source>
<source>${project.basedir}/src/main/301until310-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
@@ -164,6 +165,7 @@
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/302/scala</source>
<source>${project.basedir}/src/main/301until304/scala</source>
<source>${project.basedir}/src/main/301until310-all/scala</source>
<source>${project.basedir}/src/main/301until310-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
@@ -222,6 +224,7 @@
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/303/scala</source>
<source>${project.basedir}/src/main/301until304/scala</source>
<source>${project.basedir}/src/main/301until310-all/scala</source>
<source>${project.basedir}/src/main/301until310-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
@@ -327,7 +330,7 @@
<configuration>
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/311/scala</source>
<source>${project.basedir}/src/main/311-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
<source>${project.basedir}/src/main/301until320-noncdh/scala</source>
<source>${project.basedir}/src/main/301until320-nondb/scala</source>
@@ -337,7 +340,6 @@
<source>${project.basedir}/src/main/311until320-all/scala</source>
<source>${project.basedir}/src/main/311until320-noncdh/scala</source>
<source>${project.basedir}/src/main/311until320-nondb/scala</source>
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/pre320-treenode/scala</source>
</sources>
</configuration>
@@ -464,7 +466,6 @@
<source>${project.basedir}/src/main/311+-all/scala</source>
<source>${project.basedir}/src/main/311until320-noncdh/scala</source>
<source>${project.basedir}/src/main/31xdb/scala</source>
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/post320-treenode/scala</source>
</sources>
</configuration>
@@ -509,7 +510,7 @@
<configuration>
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/312/scala</source>
<source>${project.basedir}/src/main/312-nondb/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
<source>${project.basedir}/src/main/301until320-noncdh/scala</source>
<source>${project.basedir}/src/main/301until320-nondb/scala</source>
@@ -519,7 +520,6 @@
<source>${project.basedir}/src/main/311until320-all/scala</source>
<source>${project.basedir}/src/main/311until320-noncdh/scala</source>
<source>${project.basedir}/src/main/311until320-nondb/scala</source>
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/pre320-treenode/scala</source>
</sources>
</configuration>
@@ -577,7 +577,6 @@
<source>${project.basedir}/src/main/311until320-all/scala</source>
<source>${project.basedir}/src/main/311until320-noncdh/scala</source>
<source>${project.basedir}/src/main/311until320-nondb/scala</source>
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/pre320-treenode/scala</source>
</sources>
</configuration>
@@ -629,7 +628,6 @@
<source>${project.basedir}/src/main/301until330-all/scala</source>
<source>${project.basedir}/src/main/311+-all/scala</source>
<source>${project.basedir}/src/main/311+-nondb/scala</source>
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/320/scala</source>
<source>${project.basedir}/src/main/320+/scala</source>
<source>${project.basedir}/src/main/320until330-all/scala</source>
@@ -693,7 +691,6 @@
<source>${project.basedir}/src/main/301until330-all/scala</source>
<source>${project.basedir}/src/main/311+-all/scala</source>
<source>${project.basedir}/src/main/311+-nondb/scala</source>
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/320+/scala</source>
<source>${project.basedir}/src/main/320until330-all/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
@@ -757,9 +754,9 @@
<source>${project.basedir}/src/main/301until330-all/scala</source>
<source>${project.basedir}/src/main/311+-all/scala</source>
<source>${project.basedir}/src/main/311+-nondb/scala</source>
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/320+/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
<source>${project.basedir}/src/main/322+/scala</source>
<source>${project.basedir}/src/main/320until330-all/scala</source>
<source>${project.basedir}/src/main/post320-treenode/scala</source>
</sources>
@@ -822,6 +819,7 @@
<source>${project.basedir}/src/main/311+-nondb/scala</source>
<source>${project.basedir}/src/main/320+/scala</source>
<source>${project.basedir}/src/main/321+/scala</source>
<source>${project.basedir}/src/main/322+/scala</source>
<source>${project.basedir}/src/main/330+/scala</source>
<source>${project.basedir}/src/main/post320-treenode/scala</source>
</sources>
@@ -879,6 +877,8 @@
<configuration>
<sources>
<source>${project.basedir}/src/main/301+-nondb/scala</source>
<source>${project.basedir}/src/main/311-nondb/scala</source>
<source>${project.basedir}/src/main/311cdh/scala</source>
<source>${project.basedir}/src/main/301until320-all/scala</source>
<source>${project.basedir}/src/main/301until320-nondb/scala</source>
<source>${project.basedir}/src/main/301until330-all/scala</source>
@@ -887,7 +887,6 @@
<source>${project.basedir}/src/main/311cdh/scala</source>
<source>${project.basedir}/src/main/311until320-all/scala</source>
<source>${project.basedir}/src/main/311until320-nondb/scala</source>
<source>${project.basedir}/src/main/311until330-all/scala</source>
<source>${project.basedir}/src/main/pre320-treenode/scala</source>
</sources>
</configuration>

@@ -106,7 +106,7 @@ final class CastExprMeta[INPUT <: Cast](
// NOOP for anything prior to 3.2.0
case (_: StringType, dt:DecimalType) =>
// Spark 2.x: removed check for
// !ShimLoader.getSparkShims.isCastingStringToNegDecimalScaleSupported
// !SparkShimImpl.isCastingStringToNegDecimalScaleSupported
// this dealt with handling a bug fix that is only in newer versions of Spark
// (https://issues.apache.org/jira/browse/SPARK-37451)
// Since we don't know what version of Spark 3 they will be using

@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
import scala.util.control.NonFatal

import com.nvidia.spark.rapids.RapidsConf.{SUPPRESS_PLANNING_FAILURE, TEST_CONF}
import com.nvidia.spark.rapids.shims.v2._
import com.nvidia.spark.rapids.shims._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -1397,7 +1397,7 @@ object GpuOverrides extends Logging {
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[ToUnixTimestamp](a, conf, p, r) {
override def shouldFallbackOnAnsiTimestamp: Boolean = false
// ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp
// SparkShimImpl.shouldFallbackOnAnsiTimestamp
}),
expr[UnixTimestamp](
"Returns the UNIX timestamp of current or specified time",
@@ -1410,7 +1410,7 @@ object GpuOverrides extends Logging {
TypeSig.STRING)),
(a, conf, p, r) => new UnixTimeExprMeta[UnixTimestamp](a, conf, p, r) {
override def shouldFallbackOnAnsiTimestamp: Boolean = false
// ShimLoader.getSparkShims.shouldFallbackOnAnsiTimestamp
// SparkShimImpl.shouldFallbackOnAnsiTimestamp

}),
expr[Hour](
@@ -2865,8 +2865,8 @@ object GpuOverrides extends Logging {
TypeSig.ARRAY + TypeSig.DECIMAL_128).nested(), TypeSig.all),
(sample, conf, p, r) => new GpuSampleExecMeta(sample, conf, p, r) {}
),
// ShimLoader.getSparkShims.aqeShuffleReaderExec,
// ShimLoader.getSparkShims.neverReplaceShowCurrentNamespaceCommand,
// SparkShimImpl.aqeShuffleReaderExec,
// SparkShimImpl.neverReplaceShowCurrentNamespaceCommand,
neverReplaceExec[ExecutedCommandExec]("Table metadata operation")
).collect { case r if r != null => (r.getClassFor.asSubclass(classOf[SparkPlan]), r) }.toMap

@@ -2955,7 +2955,7 @@ object GpuOverrides extends Logging {
// case c2r: ColumnarToRowExec => prepareExplainOnly(c2r.child)
case re: ReusedExchangeExec => prepareExplainOnly(re.child)
// case aqe: AdaptiveSparkPlanExec =>
// prepareExplainOnly(ShimLoader.getSparkShims.getAdaptiveInputPlan(aqe))
// prepareExplainOnly(SparkShimImpl.getAdaptiveInputPlan(aqe))
case sub: SubqueryExec => prepareExplainOnly(sub.child)
}
planAfter

@@ -79,7 +79,7 @@ object GpuParquetFileFormat {
// they set when they get to 3.x. The default in 3.x is EXCEPTION which would be good
// for us.
/*
ShimLoader.getSparkShims.int96ParquetRebaseWrite(sqlConf) match {
SparkShimImpl.int96ParquetRebaseWrite(sqlConf) match {
case "EXCEPTION" =>
case "CORRECTED" =>
case "LEGACY" =>
@@ -90,7 +90,7 @@ object GpuParquetFileFormat {
meta.willNotWorkOnGpu(s"$other is not a supported rebase mode for int96")
}
ShimLoader.getSparkShims.parquetRebaseWrite(sqlConf) match {
SparkShimImpl.parquetRebaseWrite(sqlConf) match {
case "EXCEPTION" => //Good
case "CORRECTED" => //Good
case "LEGACY" =>

@@ -97,31 +97,31 @@ object GpuParquetScanBase {

// Spark 2.x doesn't support the rebase mode
/*
sqlConf.get(ShimLoader.getSparkShims.int96ParquetRebaseReadKey) match {
sqlConf.get(SparkShimImpl.int96ParquetRebaseReadKey) match {
case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is EXCEPTION")
s"${SparkShimImpl.int96ParquetRebaseReadKey} is EXCEPTION")
}
case "CORRECTED" => // Good
case "LEGACY" => // really is EXCEPTION for us...
if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${ShimLoader.getSparkShims.int96ParquetRebaseReadKey} is LEGACY")
s"${SparkShimImpl.int96ParquetRebaseReadKey} is LEGACY")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")
}
sqlConf.get(ShimLoader.getSparkShims.parquetRebaseReadKey) match {
sqlConf.get(SparkShimImpl.parquetRebaseReadKey) match {
case "EXCEPTION" => if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is EXCEPTION")
s"${SparkShimImpl.parquetRebaseReadKey} is EXCEPTION")
}
case "CORRECTED" => // Good
case "LEGACY" => // really is EXCEPTION for us...
if (schemaMightNeedNestedRebase) {
meta.willNotWorkOnGpu("Nested timestamp and date values are not supported when " +
s"${ShimLoader.getSparkShims.parquetRebaseReadKey} is LEGACY")
s"${SparkShimImpl.parquetRebaseReadKey} is LEGACY")
}
case other =>
meta.willNotWorkOnGpu(s"$other is not a supported read rebase mode")

@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids

import com.nvidia.spark.rapids.shims.v2.GpuCSVScan
import com.nvidia.spark.rapids.shims.GpuCSVScan

import org.apache.spark.sql.execution.FileSourceScanExec


@@ -16,7 +16,7 @@

package com.nvidia.spark.rapids

import com.nvidia.spark.rapids.shims.v2._
import com.nvidia.spark.rapids.shims._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._

@@ -19,7 +19,7 @@ package com.nvidia.spark.rapids
import java.io.{File, FileOutputStream}
import java.time.ZoneId

import com.nvidia.spark.rapids.shims.v2.TypeSigUtil
import com.nvidia.spark.rapids.shims.TypeSigUtil

import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnaryExpression, WindowSpecDefinition}

@@ -14,10 +14,10 @@
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.v2
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._
import com.nvidia.spark.rapids.shims.v2._
import com.nvidia.spark.rapids.shims._

import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.rapids.execution.{GpuHashJoin, JoinTypeChecks}

@@ -14,7 +14,7 @@
* limitations under the License.
*/

package com.nvidia.spark.rapids.shims.v2
package com.nvidia.spark.rapids.shims

import java.nio.charset.StandardCharsets
