Implement per-shim parallel world jar classloader #3381

Merged · 21 commits · Sep 10, 2021
Changes from 2 commits
11 changes: 7 additions & 4 deletions sql-plugin/src/main/scala/com/nvidia/spark/SQLPlugin.scala
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,7 +16,7 @@

package com.nvidia.spark

import com.nvidia.spark.rapids.{RapidsDriverPlugin, RapidsExecutorPlugin}
import com.nvidia.spark.rapids.ShimLoader

import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, SparkPlugin}
import org.apache.spark.internal.Logging
@@ -26,6 +26,9 @@ import org.apache.spark.internal.Logging
* To enable this plugin, set the config "spark.plugins" to `com.nvidia.spark.SQLPlugin`
*/
class SQLPlugin extends SparkPlugin with Logging {
override def driverPlugin(): DriverPlugin = new RapidsDriverPlugin
override def executorPlugin(): ExecutorPlugin = new RapidsExecutorPlugin
override def driverPlugin(): DriverPlugin =
ShimLoader.newInstanceOf("com.nvidia.spark.rapids.RapidsDriverPlugin")

override def executorPlugin(): ExecutorPlugin =
ShimLoader.newInstanceOf("com.nvidia.spark.rapids.RapidsExecutorPlugin")
}
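
For context, the way the plugin is enabled does not change; only construction of the driver and executor plugins is now routed through ShimLoader so that those classes can be loaded from the shim-specific classloader instead of being referenced directly. A hedged usage sketch, assuming local mode and an illustrative jar path:

import org.apache.spark.sql.SparkSession

// Illustrative only: enable the plugin via spark.plugins as documented above.
val spark = SparkSession.builder()
  .master("local[*]")
  .config("spark.plugins", "com.nvidia.spark.SQLPlugin")
  .config("spark.jars", "/path/to/rapids-4-spark_2.12.jar")  // hypothetical path
  .getOrCreate()
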
33 changes: 13 additions & 20 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
@@ -16,20 +16,20 @@

package com.nvidia.spark.rapids

import java.util
import java.util.Properties
import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference}

import scala.collection.JavaConverters._
import scala.util.Try

import com.nvidia.spark.rapids.python.PythonWorkerSemaphore
import java.util

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin, PluginContext}
import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.sql.{DataFrame, SparkSessionExtensions}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
@@ -51,24 +51,6 @@ case class ColumnarOverrideRules() extends ColumnarRule with Logging {
override def postColumnarTransitions: Rule[SparkPlan] = overrideTransitions
}

/**
* Extension point to enable GPU SQL processing.
*/
class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging {
override def apply(extensions: SparkSessionExtensions): Unit = {
val pluginProps = RapidsPluginUtils.loadProps(RapidsPluginUtils.PLUGIN_PROPS_FILENAME)
logInfo(s"RAPIDS Accelerator build: $pluginProps")
val cudfProps = RapidsPluginUtils.loadProps(RapidsPluginUtils.CUDF_PROPS_FILENAME)
logInfo(s"cudf build: $cudfProps")
val pluginVersion = pluginProps.getProperty("version", "UNKNOWN")
val cudfVersion = cudfProps.getProperty("version", "UNKNOWN")
logWarning(s"RAPIDS Accelerator $pluginVersion using cudf $cudfVersion." +
s" To disable GPU support set `${RapidsConf.SQL_ENABLED}` to false")
extensions.injectColumnar(_ => ColumnarOverrideRules())
extensions.injectQueryStagePrepRule(_ => GpuQueryStagePrepOverrides())
}
}

object RapidsPluginUtils extends Logging {
val CUDF_PROPS_FILENAME = "cudf-java-version-info.properties"
val PLUGIN_PROPS_FILENAME = "rapids4spark-version-info.properties"
@@ -82,6 +64,17 @@ object RapidsPluginUtils extends Logging {
private val KRYO_REGISTRATOR_KEY = "spark.kryo.registrator"
private val KRYO_REGISTRATOR_NAME = classOf[GpuKryoRegistrator].getName

{
val pluginProps = loadProps(RapidsPluginUtils.PLUGIN_PROPS_FILENAME)
logInfo(s"RAPIDS Accelerator build: $pluginProps")
val cudfProps = loadProps(RapidsPluginUtils.CUDF_PROPS_FILENAME)
logInfo(s"cudf build: $cudfProps")
val pluginVersion = pluginProps.getProperty("version", "UNKNOWN")
val cudfVersion = cudfProps.getProperty("version", "UNKNOWN")
logWarning(s"RAPIDS Accelerator $pluginVersion using cudf $cudfVersion." +
s" To disable GPU support set `${RapidsConf.SQL_ENABLED}` to false")
}

def fixupConfigs(conf: SparkConf): Unit = {
// First add in the SQL executor plugin because that is what we need at a minimum
if (conf.contains(SQL_PLUGIN_CONF_KEY)) {
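
A note on the Plugin.scala change: the version banner moves out of SQLExecPlugin.apply into an initializer block on the RapidsPluginUtils object, so it is logged once when the object is first touched rather than only when the SQL extensions are installed. loadProps itself is not part of this diff; the following is a minimal sketch of a classpath properties loader with the shape the call sites above assume (name and behavior are inferred, not the plugin's actual implementation):

import java.util.Properties

object VersionPropsSketch {
  // Assumed behavior: read a properties resource bundled with the plugin jar,
  // e.g. rapids4spark-version-info.properties; return empty Properties if absent.
  def loadProps(resourceName: String): Properties = {
    val props = new Properties()
    val in = getClass.getClassLoader.getResourceAsStream(resourceName)
    if (in != null) {
      try props.load(in) finally in.close()
    }
    props
  }
}
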
@@ -0,0 +1,40 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.nvidia.spark.rapids

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}

/**
* Extension point to enable GPU SQL processing.
*/
class SQLExecPlugin extends (SparkSessionExtensions => Unit) with Logging {
override def apply(extensions: SparkSessionExtensions): Unit = {
extensions.injectColumnar(columnarOverrides)
extensions.injectQueryStagePrepRule(queryStagePrepOverrides)
}

private def columnarOverrides(sparkSession: SparkSession): ColumnarRule = {
ShimLoader.newInstanceOf("com.nvidia.spark.rapids.ColumnarOverrideRules")
gerashegalov marked this conversation as resolved.
Show resolved Hide resolved
}

private def queryStagePrepOverrides(sparkSession: SparkSession): Rule[SparkPlan] = {
ShimLoader.newInstanceOf("com.nvidia.spark.rapids.GpuQueryStagePrepOverrides")
}
}
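
The new SQLExecPlugin stays shim-agnostic: the two builder functions run only when a SparkSession is constructed, and they ask ShimLoader to load the rule classes by name through the parallel-world classloader. For reference, an extensions class of this shape is registered with Spark through spark.sql.extensions (in this plugin that wiring appears to be handled by RapidsPluginUtils.fixupConfigs); an illustrative sketch:

import org.apache.spark.sql.SparkSession

// Illustrative only: spark.sql.extensions takes class names of
// (SparkSessionExtensions => Unit) functions such as SQLExecPlugin.
val spark = SparkSession.builder()
  .config("spark.sql.extensions", "com.nvidia.spark.rapids.SQLExecPlugin")
  .getOrCreate()
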
203 changes: 175 additions & 28 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala
@@ -16,53 +16,168 @@

package com.nvidia.spark.rapids

import java.util.ServiceLoader
import java.net.URL

import scala.collection.JavaConverters._

import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION}
import org.apache.spark.{SPARK_BUILD_USER, SPARK_VERSION, SparkConf}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.rapids.VisibleShuffleManager
import org.apache.spark.util.{MutableURLClassLoader, ParentClassLoader}

/*
The plugin jar uses a non-standard class file layout. It consists of three types of areas:
1. a few publicly documented classes in the conventional layout at the top
2. a large fraction of classes whose bytecode is identical under all supported Spark versions
in spark3xx-common
3. a smaller fraction of classes that differ under one of the supported Spark versions

com/nvidia/spark/SQLPlugin.class

spark3xx-common/com/nvidia/spark/rapids/CastExprMeta.class

spark301/org/apache/spark/sql/rapids/GpuUnaryMinus.class
spark311/org/apache/spark/sql/rapids/GpuUnaryMinus.class
spark320/org/apache/spark/sql/rapids/GpuUnaryMinus.class

Each shim can see a consistent parallel world without conflicts by referencing
only one conflicting directory.

E.g., Spark 3.2.0 Shim will use

jar:file:/home/spark/rapids-4-spark_2.12-21.10.jar!/spark3xx-common/
jar:file:/home/spark/rapids-4-spark_2.12-21.10.jar!/spark320/

Spark 3.1.1 will use

jar:file:/home/spark/rapids-4-spark_2.12-21.10.jar!/spark3xx-common/
jar:file:/home/spark/rapids-4-spark_2.12-21.10.jar!/spark311/

Using these jar URLs allows referencing different bytecode produced from identical sources
by incompatible Scala / Spark dependencies.
*/
object ShimLoader extends Logging {
private var shimProviderClass: String = null
private var sparkShims: SparkShims = null
logDebug(s"ShimLoader object instance: ${this} loaded by ${getClass.getClassLoader}")
private val shimRootURL = {
val thisClassFile = getClass.getName.replace(".", "/") + ".class"
val url = getClass.getClassLoader.getResource(thisClassFile)
val urlStr = url.toString
val rootUrlStr = urlStr.substring(0, urlStr.length - thisClassFile.length)
new URL(rootUrlStr)
}

private val shimCommonURL = new URL(s"${shimRootURL.toString}spark3xx-common/")

@volatile private var shimProviderClass: String = _
@volatile private var sparkShims: SparkShims = _
@volatile private var shimURL: URL = _
@volatile private var pluginClassLoader: ClassLoader = _

// REPL-only logic
@volatile private var tmpClassLoader: MutableURLClassLoader = _

def shimId: String = shimIdFromPackageName(shimProviderClass)

def getRapidsShuffleManagerClass: String = {
findShimProvider()
s"com.nvidia.spark.rapids.$shimId.RapidsShuffleManager"
}

def getRapidsShuffleInternal: String = {
findShimProvider()
s"org.apache.spark.sql.rapids.shims.$shimId.RapidsShuffleInternalManager"
}

private def updateSparkClassLoader(): Unit = {
val contextClassLoader = Thread.currentThread().getContextClassLoader
Option(contextClassLoader).collect {
case mutable: MutableURLClassLoader => mutable
case replCL if replCL.getClass.getName == "org.apache.spark.repl.ExecutorClassLoader" =>
val parentLoaderField = replCL.getClass.getDeclaredMethod("parentLoader")
val parentLoader = parentLoaderField.invoke(replCL).asInstanceOf[ParentClassLoader]
parentLoader.getParent.asInstanceOf[MutableURLClassLoader]
}.foreach { mutable =>
// MutableURLClassloader dedupes for us
pluginClassLoader = contextClassLoader
mutable.addURL(shimURL)
mutable.addURL(shimCommonURL)
}
}

private def detectShimProvider(): SparkShimServiceProvider = {
private def getShimClassLoader(): ClassLoader = {
if (shimURL == null) {
findShimProvider()
}
if (pluginClassLoader == null) {
updateSparkClassLoader()
}
if (pluginClassLoader == null) {
if (tmpClassLoader == null) {
tmpClassLoader = new MutableURLClassLoader(Array(shimURL, shimCommonURL),
getClass.getClassLoader)
}
tmpClassLoader
} else {
pluginClassLoader
}
}

private def detectShimProvider(): String = {
val sparkVersion = getSparkVersion
logInfo(s"Loading shim for Spark version: $sparkVersion")

// This is not ideal, but pass the version in here because otherwise loaders that match the
// same version (3.0.1 Apache and 3.0.1 Databricks) would need to know how to differentiate.
val sparkShimLoaders = ServiceLoader.load(classOf[SparkShimServiceProvider])
.asScala.filter(_.matchesVersion(sparkVersion))
if (sparkShimLoaders.size > 1) {
throw new IllegalArgumentException(s"Multiple Spark Shim Loaders found: $sparkShimLoaders")
// Emulating service loader manually because we have a non-standard jar layout for classes
val serviceProviderListPath = "/META-INF/services/" + classOf[SparkShimServiceProvider].getName
val serviceProviderList = Option(getClass.getResource(serviceProviderListPath)).toSeq
.map(scala.io.Source.fromURL)
.flatMap(_.getLines())

val shimServiceProviderOpt = serviceProviderList.flatMap { shimServiceProviderStr =>
val mask = shimIdFromPackageName(shimServiceProviderStr)
try {
val shimURL = new java.net.URL(s"${shimRootURL.toString}$mask/")
val shimClassLoader = new MutableURLClassLoader(Array(shimURL, shimCommonURL),
getClass.getClassLoader)
// can't use ServiceLoader with parallel world layout
val shimClass = shimClassLoader.loadClass(shimServiceProviderStr)
Option(
(instantiateClass(shimClass).asInstanceOf[SparkShimServiceProvider], shimURL)
)
} catch {
case cnf: ClassNotFoundException =>
logWarning(cnf + ": Could not load the provider", cnf)
None
}
}.find { case (shimServiceProvider, _) =>
shimServiceProvider.matchesVersion(sparkVersion)
}.map { case (inst, url) =>
shimURL = url
// this class will be loaded again by the real executor classloader
inst.getClass.getName
}
logInfo(s"Found shims: $sparkShimLoaders")
val loader = sparkShimLoaders.headOption match {
case Some(loader) => loader
case None =>

shimServiceProviderOpt.getOrElse {
throw new IllegalArgumentException(s"Could not find Spark Shim Loader for $sparkVersion")
}
loader
}

private def findShimProvider(): SparkShimServiceProvider = {
private def shimIdFromPackageName(shimServiceProvider: SparkShimServiceProvider) = {
shimServiceProvider.getClass.getPackage.toString.split('.').last
}

private def shimIdFromPackageName(shimServiceProviderStr: String) = {
shimServiceProviderStr.split('.').takeRight(2).head
}

private def findShimProvider(): String = {
// TODO restore support for shim provider override
if (shimProviderClass == null) {
detectShimProvider()
} else {
logWarning(s"Overriding Spark shims provider to $shimProviderClass. " +
"This may be an untested configuration!")
val providerClass = Class.forName(shimProviderClass)
val constructor = providerClass.getConstructor()
constructor.newInstance().asInstanceOf[SparkShimServiceProvider]
shimProviderClass = detectShimProvider()
}
shimProviderClass
}

def getSparkShims: SparkShims = {
if (sparkShims == null) {
val provider = findShimProvider()
sparkShims = provider.buildShim
sparkShims = newInstanceOf[SparkShimServiceProvider](findShimProvider()).buildShim
}
sparkShims
}
@@ -76,7 +191,39 @@ object ShimLoader extends Logging {
}
}

// TODO broken right now, check if this can be supported with parallel worlds
// it implies the prerequisite of having such a class in the conventional root jar entry
// - or the necessity of an additional parameter for specifying the shim subdirectory
// - or enforcing the convention the class file parent directory is the shimid that is also
// a top entry e.g. /spark301/com/nvidia/test/shim/spark301/Spark301Shims.class
def setSparkShimProviderClass(classname: String): Unit = {
shimProviderClass = classname
}

def newInstanceOf[T](className: String): T = {
val loader = getShimClassLoader()
logDebug(s"Loading $className using $loader with the parent loader ${loader.getParent}")
instantiateClass(loader.loadClass(className)).asInstanceOf[T]
}

def newInternalShuffleManager(conf: SparkConf, isDriver: Boolean): VisibleShuffleManager = {
val shuffleClassLoader = getShimClassLoader()
val shuffleClassName =
s"org.apache.spark.sql.rapids.shims.${shimId}.RapidsShuffleInternalManager"
val shuffleClass = shuffleClassLoader.loadClass(shuffleClassName)
shuffleClass.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
.newInstance(conf, java.lang.Boolean.valueOf(isDriver))
.asInstanceOf[VisibleShuffleManager]
}

// avoid cached constructors
private def instantiateClass[T](cls: Class[T]): T = {
logDebug(s"Instantiate ${cls.getName} using classloader " + cls.getClassLoader)
cls.getClassLoader match {
case m: MutableURLClassLoader => logDebug("GERA_DEBUG urls " + m.getURLs.mkString("\n"))
case _ =>
}
val constructor = cls.getConstructor()
constructor.newInstance()
}
}
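
To tie the ShimLoader changes together, here is a hedged, standalone sketch of the lookup described in the header comment: a classloader whose roots are the spark3xx-common and shim-specific directories inside the plugin jar. URLClassLoader treats a base URL ending in '/' as a directory, so jar-internal paths resolve through the jar protocol handler; the PR itself uses Spark's MutableURLClassLoader the same way. The jar path, shim id, and class name below are illustrative.

import java.net.{URL, URLClassLoader}

object ParallelWorldSketch {
  def main(args: Array[String]): Unit = {
    val pluginJar = "file:/home/spark/rapids-4-spark_2.12-21.10.jar"
    val shimUrls = Array(
      new URL(s"jar:$pluginJar!/spark3xx-common/"),
      new URL(s"jar:$pluginJar!/spark311/"))
    // Shim-specific bytecode resolves against the jar-internal directories;
    // everything else falls through to the parent (conventional) classloader.
    val shimLoader = new URLClassLoader(shimUrls, getClass.getClassLoader)
    val cls = shimLoader.loadClass("org.apache.spark.sql.rapids.GpuUnaryMinus")
    println(s"${cls.getName} loaded by ${cls.getClassLoader}")
  }
}
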