Fix fake memory leaks in some test cases [databricks] #5955

Merged: 5 commits, Jul 8, 2022
@@ -27,11 +27,6 @@ class MortgageSparkSuite extends FunSuite {
/**
* This is intentionally a def rather than a val so that scalatest uses the correct value (from
* this class or the derived class) when registering tests.
*
* @note You are likely to see device/host leaks from this test when using the
* RAPIDS Shuffle Manager. The reason for that is a race between cuDF's MemoryCleaner
* and the SparkContext shutdown. Because of this, cached shuffle buffers may not get
* cleaned (on shuffle unregister) when the MemoryCleaner exits.
*/
def adaptiveQueryEnabled = false

26 changes: 25 additions & 1 deletion sql-plugin/src/main/scala/com/nvidia/spark/rapids/Plugin.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.{Map => MutableMap}
import scala.util.Try
import scala.util.matching.Regex

import ai.rapids.cudf.{CudaException, CudaFatalException, CudfException}
import ai.rapids.cudf.{CudaException, CudaFatalException, CudfException, MemoryCleaner}
import com.nvidia.spark.rapids.python.PythonWorkerSemaphore

import org.apache.spark.{ExceptionFailure, SparkConf, SparkContext, TaskFailedReason}
@@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStag
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.internal.StaticSQLConf
import org.apache.spark.sql.rapids.GpuShuffleEnv
import org.apache.spark.sql.rapids.execution.TrampolineUtil
import org.apache.spark.sql.util.QueryExecutionListener

class PluginException(msg: String) extends RuntimeException(msg)
@@ -209,6 +210,9 @@ class RapidsExecutorPlugin extends ExecutorPlugin with Logging {
pluginContext: PluginContext,
extraConf: java.util.Map[String, String]): Unit = {
try {
// If configured, re-register the leak-checking hook.
reRegisterCheckLeakHook()

val conf = new RapidsConf(extraConf.asScala.toMap)

// Compare if the cudf version mentioned in the classpath is equal to the version which
@@ -259,6 +263,26 @@ }
}
}

/**
* Re-register the leak-checking hook if it is configured.
*/
private def reRegisterCheckLeakHook(): Unit = {
// DEFAULT_SHUTDOWN_THREAD in MemoryCleaner is responsible for checking leaks at shutdown
// time. It expects all other shutdown hooks to have finished before the check runs,
// because those hooks are the ones that close the remaining resources.

if (MemoryCleaner.configuredDefaultShutdownHook) {
// Shutdown hooks are executed concurrently by the JVM, with no guarantee on execution order.
// See the doc of `Runtime.addShutdownHook`.
// We have to wait for Spark's hooks to finish, or a false leak will be detected.
// See issue: https://github.com/NVIDIA/spark-rapids/issues/5854
//
// So the hook is re-registered through Spark's `ShutdownHookManager`, which runs hooks
// by priority. Priority 20 is low enough that it runs after Spark's own hooks.
TrampolineUtil.addShutdownHook(20, MemoryCleaner.removeDefaultShutdownHook())
}
}

private def checkCudfVersion(conf: RapidsConf): Unit = {
try {
val pluginProps = RapidsPluginUtils.loadProps(RapidsPluginUtils.PLUGIN_PROPS_FILENAME)
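To make the ordering problem concrete, here is a minimal, self-contained sketch of the race described in the comments above. It is illustration only and not part of this PR; the object name, thread, and messages are invented. Hooks registered with Runtime.addShutdownHook start concurrently, so a leak check can run before the hook that frees resources and report a false leak; taking the check hook back and invoking it from a single, ordered hook makes the order deterministic, which is what the plugin change does via Spark's ShutdownHookManager.

// Hypothetical illustration only; names are invented for this sketch.
object ShutdownOrderingSketch {
  def main(args: Array[String]): Unit = {
    // A "leak check" hook, analogous to MemoryCleaner's default shutdown hook.
    val leakCheck = new Thread(() => println("leak check: expects all resources to be freed"))
    Runtime.getRuntime.addShutdownHook(leakCheck)

    // Left alone, leakCheck may start before (or while) a resource-releasing hook runs,
    // and would report buffers that are about to be freed as leaked.

    // Fix: remove the check hook and run it explicitly after the cleanup work,
    // so the ordering becomes deterministic.
    Runtime.getRuntime.removeShutdownHook(leakCheck)
    Runtime.getRuntime.addShutdownHook(new Thread(() => {
      println("release cached shuffle buffers and other resources")
      leakCheck.run() // the check now observes a clean state
    }))
  }
}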
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.rapids.shims.SparkUpgradeExceptionShims
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
import org.apache.spark.util.{ShutdownHookManager, Utils}

object TrampolineUtil {
def doExecuteBroadcast[T](child: SparkPlan): Broadcast[T] = child.doExecuteBroadcast()
@@ -152,4 +152,9 @@ object TrampolineUtil {

/** Remove the task context for the current thread */
def unsetTaskContext(): Unit = TaskContext.unset()

/** Add a shutdown hook with the given priority. */
def addShutdownHook(priority: Int, runnable: Runnable): AnyRef = {
ShutdownHookManager.addShutdownHook(priority)(() => runnable.run())
}
}
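For reference, a hedged usage sketch of the new TrampolineUtil.addShutdownHook helper; the caller object, the message, and the reuse of priority 20 below are assumptions for illustration. Spark's ShutdownHookManager runs hooks from highest to lowest priority, so a priority-20 hook fires only after Spark's own hooks, such as the SparkContext shutdown hook, have finished.

import org.apache.spark.sql.rapids.execution.TrampolineUtil

// Assumed caller code, not from this PR.
object ShutdownHookUsageSketch {
  def register(): AnyRef = {
    TrampolineUtil.addShutdownHook(20, () => {
      // Work that must observe a fully shut-down Spark application goes here,
      // e.g. the cuDF leak check re-registered in reRegisterCheckLeakHook above.
      println("runs after Spark's own shutdown hooks")
    })
  }
}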