Skip to content

Commit

Permalink
Re-enable spark.rapids.shims-provider-override (#3497)
Browse files Browse the repository at this point in the history
- The class name provided can be an arbitrary implementation of
  SparkShimServiceProvider
- If the Shim is tested together with other shims and requires its own Parallel
  World, then it should follow the same conventions as existing Shims:
  the last component in its package name should be the root directory in
  the combined rapids-4-spark jar. If tested in isolation from other
  Shims, the standard jar layout can be used.

Signed-off-by: Gera Shegalov <gera@apache.org>
  • Loading branch information
gerashegalov authored Sep 16, 2021
1 parent e332f1c commit d6ab561
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,6 @@ class RapidsDriverPlugin extends DriverPlugin with Logging {
val sparkConf = pluginContext.conf
RapidsPluginUtils.fixupConfigs(sparkConf)
val conf = new RapidsConf(sparkConf)
if (conf.shimsProviderOverride.isDefined) { // TODO test it, probably not working yet
ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get)
}

if (GpuShuffleEnv.isRapidsShuffleAvailable &&
conf.shuffleTransportEarlyStart) {
Expand All @@ -174,9 +171,6 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
extraConf: java.util.Map[String, String]): Unit = {
try {
val conf = new RapidsConf(extraConf.asScala.toMap)
if (conf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get)
}

// Compare if the cudf version mentioned in the classpath is equal to the version which
// plugin expects. If there is a version mismatch, throw error. This check can be disabled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,14 @@ object RapidsConf {
"If you are using a custom Spark version such as Spark 3.0.1.0 then this can be used to " +
"specify the shims provider that matches the base Spark version of Spark 3.0.1, i.e.: " +
"com.nvidia.spark.rapids.shims.spark301.SparkShimServiceProvider. If you modified Spark " +
"then there is no guarantee the RAPIDS Accelerator will function properly.")
"then there is no guarantee the RAPIDS Accelerator will function properly." +
"When tested in a combined jar with other Shims, it's expected that the provided " +
"implementation follows the same convention as existing Spark shims. If its class" +
" name has the form com.nvidia.spark.rapids.shims.<shimId>.YourSparkShimServiceProvider. " +
"The last package name component, i.e., shimId, can be used in the combined jar as the root" +
" directory /shimId for any incompatible classes. When tested in isolation, no special " +
"jar root is required"
)
.stringConf
.createOptional

Expand Down
23 changes: 18 additions & 5 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import java.net.URL

import scala.collection.JavaConverters._

import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, SparkConf}
import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, SparkConf, SparkEnv}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
Expand Down Expand Up @@ -195,10 +195,22 @@ object ShimLoader extends Logging {
// ShimServiceProvider API definition is not shared via parent and we run
// into ClassCastExceptions. If we find a way to solve this then we can revert to ServiceLoader

// IMPORTANT don't use RapidsConf as it transitively references classes that must remain
// in parallel worlds
val shimServiceProviderOverrideClassName = Option(SparkEnv.get) // Spark-less RapidsConf.help
.flatMap(_.conf.getOption("spark.rapids.shims-provider-override"))
shimServiceProviderOverrideClassName.foreach { shimProviderClass =>
logWarning(s"Overriding Spark shims provider to $shimProviderClass. " +
"This may be an untested configuration!")
}

val serviceProviderListPath = SERVICE_LOADER_PREFIX + classOf[SparkShimServiceProvider].getName
val serviceProviderList = thisClassLoader.getResources(serviceProviderListPath)
.asScala.map(scala.io.Source.fromURL)
.flatMap(_.getLines())
val serviceProviderList = shimServiceProviderOverrideClassName
.map(clsName => Seq(clsName)).getOrElse {
thisClassLoader.getResources(serviceProviderListPath)
.asScala.map(scala.io.Source.fromURL)
.flatMap(_.getLines())
}

assert(serviceProviderList.nonEmpty, "Classpath should contain the resource for " +
serviceProviderListPath)
Expand All @@ -219,7 +231,8 @@ object ShimLoader extends Logging {
None
}
}.find { case (shimServiceProvider, _) =>
shimServiceProvider.matchesVersion(sparkVersion)
shimServiceProviderOverrideClassName.nonEmpty ||
shimServiceProvider.matchesVersion(sparkVersion)
}.map { case (inst, url) =>
shimURL = url
// this class will be loaded again by the real executor classloader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,6 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B

private val rapidsConf = new RapidsConf(conf)

// set the shim override if specified since the shuffle manager loads early
if (rapidsConf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(rapidsConf.shimsProviderOverride.get)
}

protected val wrapped = new SortShuffleManager(conf)

private[this] val transportEnabledMessage = if (!rapidsConf.shuffleTransportEnabled) {
Expand Down

0 comments on commit d6ab561

Please sign in to comment.