Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-enable spark.rapids.shims-provider-override #3497

Merged
merged 1 commit into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,6 @@ class RapidsDriverPlugin extends DriverPlugin with Logging {
val sparkConf = pluginContext.conf
RapidsPluginUtils.fixupConfigs(sparkConf)
val conf = new RapidsConf(sparkConf)
if (conf.shimsProviderOverride.isDefined) { // TODO test it, probably not working yet
ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get)
}

if (GpuShuffleEnv.isRapidsShuffleAvailable &&
conf.shuffleTransportEarlyStart) {
Expand All @@ -174,9 +171,6 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
extraConf: java.util.Map[String, String]): Unit = {
try {
val conf = new RapidsConf(extraConf.asScala.toMap)
if (conf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(conf.shimsProviderOverride.get)
}

// Compare if the cudf version mentioned in the classpath is equal to the version which
// plugin expects. If there is a version mismatch, throw error. This check can be disabled
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1163,7 +1163,14 @@ object RapidsConf {
"If you are using a custom Spark version such as Spark 3.0.1.0 then this can be used to " +
"specify the shims provider that matches the base Spark version of Spark 3.0.1, i.e.: " +
"com.nvidia.spark.rapids.shims.spark301.SparkShimServiceProvider. If you modified Spark " +
"then there is no guarantee the RAPIDS Accelerator will function properly.")
"then there is no guarantee the RAPIDS Accelerator will function properly. " +
"When tested in a combined jar with other Shims, it's expected that the provided " +
"implementation follows the same convention as existing Spark shims, i.e., its class" +
" name has the form com.nvidia.spark.rapids.shims.<shimId>.YourSparkShimServiceProvider. " +
"The last package name component, i.e., shimId, can be used in the combined jar as the root" +
" directory /shimId for any incompatible classes. When tested in isolation, no special " +
"jar root is required."
)
.stringConf
.createOptional

Expand Down
24 changes: 18 additions & 6 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@ import java.net.URL

import scala.collection.JavaConverters._

import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, SparkConf}
import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, SparkConf, SparkEnv}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
import org.apache.spark.sql.rapids.VisibleShuffleManager
import org.apache.spark.util.{MutableURLClassLoader, ParentClassLoader}

/*
Expand Down Expand Up @@ -173,10 +172,22 @@ object ShimLoader extends Logging {
// ShimServiceProvider API definition is not shared via parent and we run
// into ClassCastExceptions. If we find a way to solve this then we can revert to ServiceLoader

// IMPORTANT don't use RapidsConf as it transitively references classes that must remain
// in parallel worlds
val shimServiceProviderOverrideClassName = Option(SparkEnv.get) // Spark-less RapidsConf.help
.flatMap(_.conf.getOption("spark.rapids.shims-provider-override"))
shimServiceProviderOverrideClassName.foreach { shimProviderClass =>
logWarning(s"Overriding Spark shims provider to $shimProviderClass. " +
"This may be an untested configuration!")
}

val serviceProviderListPath = SERVICE_LOADER_PREFIX + classOf[SparkShimServiceProvider].getName
val serviceProviderList = thisClassLoader.getResources(serviceProviderListPath)
.asScala.map(scala.io.Source.fromURL)
.flatMap(_.getLines())
val serviceProviderList = shimServiceProviderOverrideClassName
.map(clsName => Seq(clsName)).getOrElse {
thisClassLoader.getResources(serviceProviderListPath)
.asScala.map(scala.io.Source.fromURL)
.flatMap(_.getLines())
}

assert(serviceProviderList.nonEmpty, "Classpath should contain the resource for " +
serviceProviderListPath)
Expand All @@ -197,7 +208,8 @@ object ShimLoader extends Logging {
None
}
}.find { case (shimServiceProvider, _) =>
shimServiceProvider.matchesVersion(sparkVersion)
shimServiceProviderOverrideClassName.nonEmpty ||
shimServiceProvider.matchesVersion(sparkVersion)
}.map { case (inst, url) =>
shimURL = url
// this class will be loaded again by the real executor classloader
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,6 @@ abstract class RapidsShuffleInternalManagerBase(conf: SparkConf, val isDriver: B

private val rapidsConf = new RapidsConf(conf)

// set the shim override if specified since the shuffle manager loads early
if (rapidsConf.shimsProviderOverride.isDefined) {
ShimLoader.setSparkShimProviderClass(rapidsConf.shimsProviderOverride.get)
}

protected val wrapped = new SortShuffleManager(conf)

private[this] val transportEnabledMessage = if (!rapidsConf.shuffleTransportEnabled) {
Expand Down