Use Spark's Utils.getContextOrSparkClassLoader to load Shims [databricks] #5646

Merged
4 changes: 0 additions & 4 deletions CONTRIBUTING.md
@@ -105,10 +105,6 @@ You can build against different versions of the CUDA Toolkit by using one of th
We support JDK8 as our main JDK version. However, it's possible to build and run with more modern
JDK versions as well. To this end set `JAVA_HOME` in the environment to your JDK root directory.

With JDK9+, you need to disable the default classloader manipulation option and set
spark.rapids.force.caller.classloader=false in your Spark application configuration. There are, however,
known issues with it, e.g. see #5513.

At the time of this writing, the most robust way to run the RAPIDS Accelerator is from a jar dedicated to
a single Spark version. To this end please use a single shim and specify `-DallowConventionalDistJar=true`

1 change: 1 addition & 0 deletions dist/unshimmed-common-from-spark311.txt
@@ -28,6 +28,7 @@ com/nvidia/spark/rapids/SparkShimVersion*
com/nvidia/spark/rapids/SparkShims*
com/nvidia/spark/udf/Plugin*
org/apache/spark/sql/rapids/ProxyRapidsShuffleInternalManagerBase*
org/apache/spark/sql/rapids/execution/Unshimmed*
org/apache/spark/sql/rapids/RapidsShuffleManagerLike*
rapids/*.py
rapids4spark-version-info.properties
3 changes: 1 addition & 2 deletions jenkins/databricks/test.sh
@@ -82,8 +82,7 @@ ICEBERG_SPARK_VER=$(echo $BASE_SPARK_VER | cut -d. -f1,2)
# Classloader config is here to work around classloader issues with
# --packages in distributed setups, should be fixed by
# https://github.com/NVIDIA/spark-rapids/pull/5646
ICEBERG_CONFS="--conf spark.rapids.force.caller.classloader=false \
--packages org.apache.iceberg:iceberg-spark-runtime-${ICEBERG_SPARK_VER}_2.12:${ICEBERG_VERSION} \
ICEBERG_CONFS="--packages org.apache.iceberg:iceberg-spark-runtime-${ICEBERG_SPARK_VER}_2.12:${ICEBERG_VERSION} \
--conf spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions \
--conf spark.sql.catalog.spark_catalog=org.apache.iceberg.spark.SparkSessionCatalog \
--conf spark.sql.catalog.spark_catalog.type=hadoop \
8 changes: 0 additions & 8 deletions jenkins/spark-premerge-build.sh
@@ -85,14 +85,6 @@ mvn_verify() {

# Triggering here until we change the jenkins file
rapids_shuffle_smoke_test

# non-caller classloader smoke test in pseudo-distributed
# standalone cluster
echo "Running test_cartesian_join_special_case_count with spark.rapids.force.caller.classloader=false"
PYSP_TEST_spark_rapids_force_caller_classloader=false \
NUM_LOCAL_EXECS=1 \
TEST_PARALLEL=0 \
./integration_tests/run_pyspark_from_build.sh -k 'test_cartesian_join_special_case_count[100]'
}

rapids_shuffle_smoke_test() {
1 change: 0 additions & 1 deletion jenkins/spark-tests.sh
@@ -183,7 +183,6 @@ run_iceberg_tests() {
# --packages in distributed setups, should be fixed by
# https://github.com/NVIDIA/spark-rapids/pull/5646
SPARK_SUBMIT_FLAGS="$BASE_SPARK_SUBMIT_ARGS $SEQ_CONF \
--conf spark.rapids.force.caller.classloader=false \
--packages org.apache.iceberg:iceberg-spark-runtime-${ICEBERG_SPARK_VER}_2.12:${ICEBERG_VERSION} \
--conf spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions \
--conf spark.sql.catalog.spark_catalog=org.apache.iceberg.spark.SparkSessionCatalog \
@@ -1479,14 +1479,6 @@ object RapidsConf {
.booleanConf
.createWithDefault(true)

val FORCE_SHIMCALLER_CLASSLOADER = conf("spark.rapids.force.caller.classloader")
.doc("Option to statically add shim's parallel world classloader URLs to " +
"the classloader of the ShimLoader class, typically Bootstrap classloader. This option" +
" uses reflection with setAccessible true on a classloader that is not created by Spark.")
.internal()
.booleanConf
.createWithDefault(value = true)

Review comment (Collaborator): Does anything else in spark2 use this?

Reply (Collaborator, author): This is merely due to spark2 cloning the RapidsConf source. We verified there is no actual use of spark.rapids.force.caller.classloader in the Spark 2 explain support.

val SPARK_GPU_RESOURCE_NAME = conf("spark.rapids.gpu.resourceName")
.doc("The name of the Spark resource that represents a GPU that you want the plugin to use " +
"if using custom resources with Spark.")
@@ -3953,7 +3953,7 @@ object GpuOverrides extends Logging {
.withPsNote(TypeEnum.STRUCT, "Round-robin partitioning is not supported for nested " +
s"structs if ${SQLConf.SORT_BEFORE_REPARTITION.key} is true")
.withPsNote(
Seq(TypeEnum.ARRAY, TypeEnum.MAP),
Seq(TypeEnum.ARRAY, TypeEnum.MAP),
"Round-robin partitioning is not supported if " +
s"${SQLConf.SORT_BEFORE_REPARTITION.key} is true"),
TypeSig.all),
@@ -4017,7 +4017,7 @@
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 +
TypeSig.MAP + TypeSig.ARRAY + TypeSig.STRUCT)
.nested()
.withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP),
.withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP),
"not allowed for grouping expressions")
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array or Map as child"),
@@ -4033,7 +4033,7 @@
.nested()
.withPsNote(TypeEnum.BINARY, "only allowed when aggregate buffers can be " +
"converted between CPU and GPU")
.withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP),
.withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP),
"not allowed for grouping expressions")
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array or Map as child"),
@@ -4051,7 +4051,7 @@
.nested()
.withPsNote(TypeEnum.BINARY, "only allowed when aggregate buffers can be " +
"converted between CPU and GPU")
.withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP),
.withPsNote(Seq(TypeEnum.ARRAY, TypeEnum.MAP),
"not allowed for grouping expressions")
.withPsNote(TypeEnum.STRUCT,
"not allowed for grouping expressions if containing Array or Map as child"),
@@ -1497,14 +1497,6 @@ object RapidsConf {
.booleanConf
.createWithDefault(true)

val FORCE_SHIMCALLER_CLASSLOADER = conf("spark.rapids.force.caller.classloader")
.doc("Option to statically add shim's parallel world classloader URLs to " +
"the classloader of the ShimLoader class, typically Bootstrap classloader. This option" +
" uses reflection with setAccessible true on a classloader that is not created by Spark.")
.internal()
.booleanConf
.createWithDefault(value = true)

val SPARK_GPU_RESOURCE_NAME = conf("spark.rapids.gpu.resourceName")
.doc("The name of the Spark resource that represents a GPU that you want the plugin to use " +
"if using custom resources with Spark.")
66 changes: 5 additions & 61 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/ShimLoader.scala
@@ -23,7 +23,7 @@ import com.nvidia.spark.rapids.iceberg.IcebergProvider
import org.apache.commons.lang3.reflect.MethodUtils
import scala.annotation.tailrec
import scala.collection.JavaConverters._
import scala.util.{Failure, Success, Try}
import scala.util.Try

import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL, SPARK_REVISION, SPARK_VERSION, SparkConf, SparkEnv}
import org.apache.spark.api.plugin.{DriverPlugin, ExecutorPlugin}
Expand All @@ -32,6 +32,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan}
import org.apache.spark.sql.rapids.execution.UnshimmedTrampolineUtil
import org.apache.spark.util.MutableURLClassLoader

/*
@@ -129,52 +130,6 @@ object ShimLoader extends Logging {
s"org.apache.spark.sql.rapids.shims.$shimId.RapidsShuffleInternalManager"
}

private def serializerClassloader(): Option[ClassLoader] = {
// Hypothesis: serializer is the most universal way to intercept classloaders

// https://github.com/apache/spark/blob/master/core/src/main/scala/
// org/apache/spark/serializer/JavaSerializer.scala#L147

// https://github.com/apache/spark/blob/master/core/src/main/scala/
// org/apache/spark/serializer/KryoSerializer.scala#L134

Option(SparkEnv.get)
.flatMap {
case env if !env.conf.get("spark.rapids.force.caller.classloader",
true.toString).toBoolean => Option(env.serializer)
case _ if (conventionalSingleShimJarDetected) => None
case _ =>
logInfo("Forcing shim caller classloader update (default behavior). " +
"If it causes issues with userClassPathFirst, set " +
"spark.rapids.force.caller.classloader to false!")
None
}
.flatMap { serializer =>
logInfo("Looking for a mutable classloader (defaultClassLoader) in SparkEnv.serializer " +
serializer)
// scalac generates accessor methods
val serdeClassLoader = MethodUtils
.invokeMethod(serializer, true, "defaultClassLoader")
.asInstanceOf[Option[ClassLoader]]
.getOrElse {
val threadContextClassLoader = Thread.currentThread().getContextClassLoader
logInfo(s"No defaultClassLoader found in $serializer, falling back " +
s"on Thread context classloader: " + threadContextClassLoader)
threadContextClassLoader
}

logInfo("Extracted Spark classloader from SparkEnv.serializer " + serdeClassLoader)
findURLClassLoader(serdeClassLoader)
}.orElse {
val shimLoaderCallerCl = getClass.getClassLoader
if (!conventionalSingleShimJarDetected) {
logInfo("Falling back on ShimLoader caller's classloader " + shimLoaderCallerCl)
}
Option(shimLoaderCallerCl)
}
}


@tailrec
private def findURLClassLoader(classLoader: ClassLoader): Option[ClassLoader] = {
// walk up the classloader hierarchy until we hit a classloader we can mutate
@@ -214,23 +169,12 @@
}

private def updateSparkClassLoader(): Unit = {
// TODO propose a proper addClassPathURL API to Spark similar to addJar but
// accepting non-file-based URI
serializerClassloader().foreach { urlAddable =>
findURLClassLoader(UnshimmedTrampolineUtil.sparkClassLoader).foreach { urlAddable =>
urlsForSparkClassLoader.foreach { url =>
if (!conventionalSingleShimJarDetected) {
logInfo(s"Updating spark classloader $urlAddable with the URLs: " +
urlsForSparkClassLoader.mkString(", "))
Try(MethodUtils.invokeMethod(urlAddable, true, "addURL", url))
.recoverWith {
case nsm: NoSuchMethodException =>
logWarning("JDK8+ detected, consider setting " +
"spark.rapids.force.caller.classloader to false as a workaround")
logDebug(s"JDK8+ detected by catching ${nsm}", nsm)
Success(Unit)
case t => Failure(t)
}.get

MethodUtils.invokeMethod(urlAddable, true, "addURL", url)
logInfo(s"Spark classLoader $urlAddable updated successfully")
urlAddable match {
case urlCl: java.net.URLClassLoader =>
Expand All @@ -240,7 +184,7 @@ object ShimLoader extends Logging {
s"classloader $urlCl although addURL succeeded, maybe pushed up to the " +
s"parent classloader ${urlCl.getParent}")
}
case _ => ()
case _ => Unit
}
}
}
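The body of `findURLClassLoader` is collapsed in the view above; for context, a minimal sketch of the walk it performs follows. This is an illustration, not the plugin's exact code: the `MutableURLClassLoader` and `URLClassLoader` cases are inferred from the imports and log messages visible in this diff.

```scala
import java.net.URLClassLoader

import scala.annotation.tailrec

import org.apache.spark.util.MutableURLClassLoader

object ClassLoaderWalkSketch {
  // Walk up the classloader hierarchy until we find a loader whose URL set can be mutated.
  @tailrec
  def findMutableLoader(cl: ClassLoader): Option[ClassLoader] = cl match {
    case null =>
      None
    case mutable: MutableURLClassLoader =>
      // Spark's driver/executor classloader: addURL is public here.
      Some(mutable)
    case urlCl: URLClassLoader =>
      // addURL is protected, so reflection (e.g. commons-lang3 MethodUtils) is needed to call it.
      Some(urlCl)
    case other =>
      findMutableLoader(other.getParent)
  }
}
```

`updateSparkClassLoader` then adds the shim's parallel-world URLs to whichever loader the walk returns, starting from the classloader provided by `UnshimmedTrampolineUtil` instead of the serializer-based heuristic removed above.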
@@ -0,0 +1,23 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.rapids.execution

object UnshimmedTrampolineUtil {
def sparkClassLoader: ClassLoader = {
org.apache.spark.util.Utils.getContextOrSparkClassLoader
}
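The trampoline sits in an `org.apache.spark` package so it can reach Spark-internal APIs such as the `private[spark]` `Utils` object. For reference, the lookup `Utils.getContextOrSparkClassLoader` performs roughly amounts to the following simplified sketch (not Spark's source verbatim):

```scala
object ContextOrSparkClassLoaderSketch {
  // Prefer the current thread's context classloader; fall back to the loader that
  // loaded this class (for Spark's Utils, the loader that loaded Spark itself).
  def getContextOrSparkClassLoader: ClassLoader =
    Option(Thread.currentThread().getContextClassLoader)
      .getOrElse(getClass.getClassLoader)
}
```

Because this hands back a classloader that Spark itself manages (on executors it is typically the mutable classloader that also serves `--packages` jars), ShimLoader no longer needs the serializer-based lookup or the `spark.rapids.force.caller.classloader` escape hatch removed in this PR.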
}