Skip to content

Commit

Permalink
[SPARK-49544][CONNECT] Replace coarse-locking in SparkConnectExecutionManager with ConcurrentMap
Browse files Browse the repository at this point in the history

### What changes were proposed in this pull request?

Replace the coarse-locking mechanism implemented in SparkConnectExecutionManager with ConcurrentMap in order to ameliorate lock contention.

### Why are the changes needed?

When there are too many threads (e.g., ~10K threads on a 4-core node), OS scheduling may cause priority inversion that leads to serious performance problems, such as a 1000s delay when reattaching to an execute holder.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing test cases.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#48034 from changgyoopark-db/SPARK-49544.

Authored-by: Changgyoo Park <changgyoo.park@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
changgyoopark-db authored and HyukjinKwon committed Sep 11, 2024
1 parent 3cb8d6e commit b466f32
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@

package org.apache.spark.sql.connect.service

import java.util.UUID

import scala.collection.mutable
import scala.jdk.CollectionConverters._

import org.apache.spark.{SparkEnv, SparkSQLException}
import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Observation
Expand All @@ -35,30 +33,19 @@ import org.apache.spark.util.SystemClock
* Object used to hold the Spark Connect execution state.
*/
private[connect] class ExecuteHolder(
val executeKey: ExecuteKey,
val request: proto.ExecutePlanRequest,
val sessionHolder: SessionHolder)
extends Logging {

val session = sessionHolder.session

val operationId = if (request.hasOperationId) {
try {
UUID.fromString(request.getOperationId).toString
} catch {
case _: IllegalArgumentException =>
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.FORMAT",
messageParameters = Map("handle" -> request.getOperationId))
}
} else {
UUID.randomUUID().toString
}

/**
* Tag that is set for this execution on SparkContext, via SparkContext.addJobTag. Used
* (internally) for cancellation of the Spark Jobs run by this execution.
*/
val jobTag = ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, operationId)
val jobTag =
ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, executeKey.operationId)

/**
* Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group
Expand Down Expand Up @@ -278,7 +265,7 @@ private[connect] class ExecuteHolder(
request = request,
userId = sessionHolder.userId,
sessionId = sessionHolder.sessionId,
operationId = operationId,
operationId = executeKey.operationId,
jobTag = jobTag,
sparkSessionTags = sparkSessionTags,
reattachable = reattachable,
Expand All @@ -289,7 +276,10 @@ private[connect] class ExecuteHolder(
}

/** Get key used by SparkConnectExecutionManager global tracker. */
def key: ExecuteKey = ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId)
def key: ExecuteKey = executeKey

/** Get the operation ID. */
def operationId: String = key.operationId
}

/** Used to identify ExecuteHolder jobTag among SparkContext.SPARK_JOB_TAGS. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

package org.apache.spark.sql.connect.service

import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import java.util.UUID
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable
Expand All @@ -36,6 +37,24 @@ import org.apache.spark.util.ThreadUtils
// Unique key identifying execution by combination of user, session and operation id
case class ExecuteKey(userId: String, sessionId: String, operationId: String)

/**
 * Companion factory that derives an [[ExecuteKey]] from an incoming execute request.
 */
object ExecuteKey {
/**
 * Builds the key of an execution from the request and the session it belongs to.
 *
 * The operation id is taken from the request when one is supplied; it must parse as a UUID and
 * is stored in the string form produced by `UUID.toString`. When the request carries no
 * operation id, a random UUID is generated instead.
 *
 * @param request the execute plan request, optionally carrying a client-provided operation id
 * @param sessionHolder the session holder that supplies the user id and the session id
 * @return the key combining user id, session id and operation id
 * @throws SparkSQLException with error class INVALID_HANDLE.FORMAT if the supplied operation id
 *                           is not a valid UUID string
 */
def apply(request: proto.ExecutePlanRequest, sessionHolder: SessionHolder): ExecuteKey = {
val operationId = if (request.hasOperationId) {
try {
// Round-trip through UUID both to validate the client-provided id and to normalize it.
UUID.fromString(request.getOperationId).toString
} catch {
case _: IllegalArgumentException =>
// Report the malformed handle exactly as the client sent it.
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.FORMAT",
messageParameters = Map("handle" -> request.getOperationId))
}
} else {
// No id supplied by the client: generate one for this execution.
UUID.randomUUID().toString
}
ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId)
}
}

/**
* Global tracker of all ExecuteHolder executions.
*
Expand All @@ -44,10 +63,9 @@ case class ExecuteKey(userId: String, sessionId: String, operationId: String)
*/
private[connect] class SparkConnectExecutionManager() extends Logging {

/** Hash table containing all current executions. Guarded by executionsLock. */
@GuardedBy("executionsLock")
private val executions: mutable.HashMap[ExecuteKey, ExecuteHolder] =
new mutable.HashMap[ExecuteKey, ExecuteHolder]()
/** Concurrent hash table containing all the current executions. */
private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] =
new ConcurrentHashMap[ExecuteKey, ExecuteHolder]()
private val executionsLock = new Object

/** Graveyard of tombstones of executions that were abandoned and removed. */
Expand All @@ -61,6 +79,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
private var lastExecutionTimeMs: Option[Long] = Some(System.currentTimeMillis())

/** Executor for the periodic maintenance */
@GuardedBy("executionsLock")
private var scheduledExecutor: Option[ScheduledExecutorService] = None

/**
Expand All @@ -76,27 +95,35 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
request.getUserContext.getUserId,
request.getSessionId,
previousSessionId)
val executeHolder = new ExecuteHolder(request, sessionHolder)
val executeKey = ExecuteKey(request, sessionHolder)
val executeHolder = executions.compute(
executeKey,
(executeKey, oldExecuteHolder) => {
// Check if the operation already exists, either in the active execution map, or in the
// graveyard of tombstones of executions that have been abandoned. The latter is to prevent
// double executions when the client retries, thinking it never reached the server, but in
// fact it did, and already got removed as abandoned.
if (oldExecuteHolder != null) {
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
messageParameters = Map("handle" -> executeKey.operationId))
}
if (getAbandonedTombstone(executeKey).isDefined) {
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
messageParameters = Map("handle" -> executeKey.operationId))
}
new ExecuteHolder(executeKey, request, sessionHolder)
})

sessionHolder.addExecuteHolder(executeHolder)

executionsLock.synchronized {
// Check if the operation already exists, both in active executions, and in the graveyard
// of tombstones of executions that have been abandoned.
// The latter is to prevent double execution when a client retries execution, thinking it
// never reached the server, but in fact it did, and already got removed as abandoned.
if (executions.get(executeHolder.key).isDefined) {
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS",
messageParameters = Map("handle" -> executeHolder.operationId))
}
if (getAbandonedTombstone(executeHolder.key).isDefined) {
throw new SparkSQLException(
errorClass = "INVALID_HANDLE.OPERATION_ABANDONED",
messageParameters = Map("handle" -> executeHolder.operationId))
if (!executions.isEmpty()) {
lastExecutionTimeMs = None
}
sessionHolder.addExecuteHolder(executeHolder)
executions.put(executeHolder.key, executeHolder)
lastExecutionTimeMs = None
logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.")
}
logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.")

schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started.

Expand All @@ -108,43 +135,50 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
* execution if still running, free all resources.
*/
private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean = false): Unit = {
var executeHolder: Option[ExecuteHolder] = None
val executeHolder = executions.get(key)

if (executeHolder == null) {
return
}

// Put into abandonedTombstones before removing it from executions, so that the client ends up
// getting an INVALID_HANDLE.OPERATION_ABANDONED error on a retry.
if (abandoned) {
abandonedTombstones.put(key, executeHolder.getExecuteInfo)
}

// Remove the execution from the map *after* putting it in abandonedTombstones.
executions.remove(key)
executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId)

executionsLock.synchronized {
executeHolder = executions.remove(key)
executeHolder.foreach { e =>
// Put into abandonedTombstones under lock, so that if it's accessed it will end up
// with INVALID_HANDLE.OPERATION_ABANDONED error.
if (abandoned) {
abandonedTombstones.put(key, e.getExecuteInfo)
}
e.sessionHolder.removeExecuteHolder(e.operationId)
}
if (executions.isEmpty) {
lastExecutionTimeMs = Some(System.currentTimeMillis())
}
logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")
}
// close the execution outside the lock
executeHolder.foreach { e =>
e.close()
if (abandoned) {
// Update in abandonedTombstones: above it wasn't yet updated with closedTime etc.
abandonedTombstones.put(key, e.getExecuteInfo)
}

logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.")

executeHolder.close()
if (abandoned) {
// Update in abandonedTombstones: above it wasn't yet updated with closedTime etc.
abandonedTombstones.put(key, executeHolder.getExecuteInfo)
}
}

private[connect] def getExecuteHolder(key: ExecuteKey): Option[ExecuteHolder] = {
executionsLock.synchronized {
executions.get(key)
}
Option(executions.get(key))
}

private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = {
val sessionExecutionHolders = executionsLock.synchronized {
executions.filter(_._2.sessionHolder.key == key)
}
sessionExecutionHolders.foreach { case (_, executeHolder) =>
var sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]()
executions.forEach((_, executeHolder) => {
if (executeHolder.sessionHolder.key == key) {
sessionExecutionHolders += executeHolder
}
})

sessionExecutionHolders.foreach { executeHolder =>
val info = executeHolder.getExecuteInfo
logInfo(
log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.")
Expand All @@ -161,11 +195,11 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
* If there are no executions, return Left with System.currentTimeMillis of last active
* execution. Otherwise return Right with list of ExecuteInfo of all executions.
*/
def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = executionsLock.synchronized {
def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = {
if (executions.isEmpty) {
Left(lastExecutionTimeMs.get)
} else {
Right(executions.values.map(_.getExecuteInfo).toBuffer.toSeq)
Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq)
}
}

Expand All @@ -177,16 +211,22 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
abandonedTombstones.asMap.asScala.values.toSeq
}

private[connect] def shutdown(): Unit = executionsLock.synchronized {
scheduledExecutor.foreach { executor =>
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
private[connect] def shutdown(): Unit = {
executionsLock.synchronized {
scheduledExecutor.foreach { executor =>
ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
}
scheduledExecutor = None
}
scheduledExecutor = None

// note: this does not cleanly shut down the executions, but the server is shutting down.
executions.clear()
abandonedTombstones.invalidateAll()
if (lastExecutionTimeMs.isEmpty) {
lastExecutionTimeMs = Some(System.currentTimeMillis())

executionsLock.synchronized {
if (lastExecutionTimeMs.isEmpty) {
lastExecutionTimeMs = Some(System.currentTimeMillis())
}
}
}

Expand Down Expand Up @@ -225,19 +265,18 @@ private[connect] class SparkConnectExecutionManager() extends Logging {

// Find any detached executions that expired and should be removed.
val toRemove = new mutable.ArrayBuffer[ExecuteHolder]()
executionsLock.synchronized {
val nowMs = System.currentTimeMillis()
val nowMs = System.currentTimeMillis()

executions.values.foreach { executeHolder =>
executeHolder.lastAttachedRpcTimeMs match {
case Some(detached) =>
if (detached + timeout <= nowMs) {
toRemove += executeHolder
}
case _ => // execution is active
}
executions.forEach((_, executeHolder) => {
executeHolder.lastAttachedRpcTimeMs match {
case Some(detached) =>
if (detached + timeout <= nowMs) {
toRemove += executeHolder
}
case _ => // execution is active
}
}
})

// .. and remove them.
toRemove.foreach { executeHolder =>
val info = executeHolder.getExecuteInfo
Expand All @@ -250,16 +289,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging {
}

// For testing.
private[connect] def setAllRPCsDeadline(deadlineMs: Long) = executionsLock.synchronized {
executions.values.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
private[connect] def setAllRPCsDeadline(deadlineMs: Long) = {
executions.values().asScala.foreach(_.setGrpcResponseSendersDeadline(deadlineMs))
}

// For testing.
private[connect] def interruptAllRPCs() = executionsLock.synchronized {
executions.values.foreach(_.interruptGrpcResponseSenders())
private[connect] def interruptAllRPCs() = {
executions.values().asScala.foreach(_.interruptGrpcResponseSenders())
}

private[connect] def listExecuteHolders: Seq[ExecuteHolder] = executionsLock.synchronized {
executions.values.toSeq
private[connect] def listExecuteHolders: Seq[ExecuteHolder] = {
executions.values().asScala.toSeq
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ class ExecuteEventsManagerSuite
.setClientType(DEFAULT_CLIENT_TYPE)
.build()

val executeHolder = new ExecuteHolder(executePlanRequest, sessionHolder)
val executeKey = ExecuteKey(executePlanRequest, sessionHolder)
val executeHolder = new ExecuteHolder(executeKey, executePlanRequest, sessionHolder)

val eventsManager = ExecuteEventsManager(executeHolder, DEFAULT_CLOCK)
eventsManager.status_(executeStatus)
Expand Down

0 comments on commit b466f32

Please sign in to comment.