diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala index ec7ebbe92d72e..dc349c3e33251 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.connect.service -import java.util.UUID - import scala.collection.mutable import scala.jdk.CollectionConverters._ -import org.apache.spark.{SparkEnv, SparkSQLException} +import org.apache.spark.SparkEnv import org.apache.spark.connect.proto import org.apache.spark.internal.Logging import org.apache.spark.sql.Observation @@ -35,30 +33,19 @@ import org.apache.spark.util.SystemClock * Object used to hold the Spark Connect execution state. */ private[connect] class ExecuteHolder( + val executeKey: ExecuteKey, val request: proto.ExecutePlanRequest, val sessionHolder: SessionHolder) extends Logging { val session = sessionHolder.session - val operationId = if (request.hasOperationId) { - try { - UUID.fromString(request.getOperationId).toString - } catch { - case _: IllegalArgumentException => - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.FORMAT", - messageParameters = Map("handle" -> request.getOperationId)) - } - } else { - UUID.randomUUID().toString - } - /** * Tag that is set for this execution on SparkContext, via SparkContext.addJobTag. Used * (internally) for cancellation of the Spark Jobs ran by this execution. */ - val jobTag = ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, operationId) + val jobTag = + ExecuteJobTag(sessionHolder.userId, sessionHolder.sessionId, executeKey.operationId) /** * Tags set by Spark Connect client users via SparkSession.addTag. Used to identify and group @@ -278,7 +265,7 @@ private[connect] class ExecuteHolder( request = request, userId = sessionHolder.userId, sessionId = sessionHolder.sessionId, - operationId = operationId, + operationId = executeKey.operationId, jobTag = jobTag, sparkSessionTags = sparkSessionTags, reattachable = reattachable, @@ -289,7 +276,10 @@ private[connect] class ExecuteHolder( } /** Get key used by SparkConnectExecutionManager global tracker. */ - def key: ExecuteKey = ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId) + def key: ExecuteKey = executeKey + + /** Get the operation ID. */ + def operationId: String = key.operationId } /** Used to identify ExecuteHolder jobTag among SparkContext.SPARK_JOB_TAGS. */ diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala index 6681a5f509c6e..61b41f932199e 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectExecutionManager.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.connect.service -import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit} +import java.util.UUID +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -36,6 +37,24 @@ import org.apache.spark.util.ThreadUtils // Unique key identifying execution by combination of user, session and operation id case class ExecuteKey(userId: String, sessionId: String, operationId: String) +object ExecuteKey { + def apply(request: proto.ExecutePlanRequest, sessionHolder: SessionHolder): ExecuteKey = { + val operationId = if (request.hasOperationId) { + try { + UUID.fromString(request.getOperationId).toString + } catch { + case _: IllegalArgumentException => + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.FORMAT", + messageParameters = Map("handle" -> request.getOperationId)) + } + } else { + UUID.randomUUID().toString + } + ExecuteKey(sessionHolder.userId, sessionHolder.sessionId, operationId) + } +} + /** * Global tracker of all ExecuteHolder executions. * @@ -44,10 +63,9 @@ case class ExecuteKey(userId: String, sessionId: String, operationId: String) */ private[connect] class SparkConnectExecutionManager() extends Logging { - /** Hash table containing all current executions. Guarded by executionsLock. */ - @GuardedBy("executionsLock") - private val executions: mutable.HashMap[ExecuteKey, ExecuteHolder] = - new mutable.HashMap[ExecuteKey, ExecuteHolder]() + /** Concurrent hash table containing all the current executions. */ + private val executions: ConcurrentMap[ExecuteKey, ExecuteHolder] = + new ConcurrentHashMap[ExecuteKey, ExecuteHolder]() private val executionsLock = new Object /** Graveyard of tombstones of executions that were abandoned and removed. */ @@ -61,6 +79,7 @@ private[connect] class SparkConnectExecutionManager() extends Logging { private var lastExecutionTimeMs: Option[Long] = Some(System.currentTimeMillis()) /** Executor for the periodic maintenance */ + @GuardedBy("executionsLock") private var scheduledExecutor: Option[ScheduledExecutorService] = None /** @@ -76,27 +95,35 @@ private[connect] class SparkConnectExecutionManager() extends Logging { request.getUserContext.getUserId, request.getSessionId, previousSessionId) - val executeHolder = new ExecuteHolder(request, sessionHolder) + val executeKey = ExecuteKey(request, sessionHolder) + val executeHolder = executions.compute( + executeKey, + (executeKey, oldExecuteHolder) => { + // Check if the operation already exists, either in the active execution map, or in the + // graveyard of tombstones of executions that have been abandoned. The latter is to prevent + // double executions when the client retries, thinking it never reached the server, but in + // fact it did, and already got removed as abandoned. + if (oldExecuteHolder != null) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS", + messageParameters = Map("handle" -> executeKey.operationId)) + } + if (getAbandonedTombstone(executeKey).isDefined) { + throw new SparkSQLException( + errorClass = "INVALID_HANDLE.OPERATION_ABANDONED", + messageParameters = Map("handle" -> executeKey.operationId)) + } + new ExecuteHolder(executeKey, request, sessionHolder) + }) + + sessionHolder.addExecuteHolder(executeHolder) + executionsLock.synchronized { - // Check if the operation already exists, both in active executions, and in the graveyard - // of tombstones of executions that have been abandoned. - // The latter is to prevent double execution when a client retries execution, thinking it - // never reached the server, but in fact it did, and already got removed as abandoned. - if (executions.get(executeHolder.key).isDefined) { - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.OPERATION_ALREADY_EXISTS", - messageParameters = Map("handle" -> executeHolder.operationId)) - } - if (getAbandonedTombstone(executeHolder.key).isDefined) { - throw new SparkSQLException( - errorClass = "INVALID_HANDLE.OPERATION_ABANDONED", - messageParameters = Map("handle" -> executeHolder.operationId)) + if (!executions.isEmpty()) { + lastExecutionTimeMs = None } - sessionHolder.addExecuteHolder(executeHolder) - executions.put(executeHolder.key, executeHolder) - lastExecutionTimeMs = None - logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.") } + logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, executeHolder.key)} is created.") schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started. @@ -108,43 +135,50 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * execution if still running, free all resources. */ private[connect] def removeExecuteHolder(key: ExecuteKey, abandoned: Boolean = false): Unit = { - var executeHolder: Option[ExecuteHolder] = None + val executeHolder = executions.get(key) + + if (executeHolder == null) { + return + } + + // Put into abandonedTombstones before removing it from executions, so that the client ends up + // getting an INVALID_HANDLE.OPERATION_ABANDONED error on a retry. + if (abandoned) { + abandonedTombstones.put(key, executeHolder.getExecuteInfo) + } + + // Remove the execution from the map *after* putting it in abandonedTombstones. + executions.remove(key) + executeHolder.sessionHolder.removeExecuteHolder(executeHolder.operationId) + executionsLock.synchronized { - executeHolder = executions.remove(key) - executeHolder.foreach { e => - // Put into abandonedTombstones under lock, so that if it's accessed it will end up - // with INVALID_HANDLE.OPERATION_ABANDONED error. - if (abandoned) { - abandonedTombstones.put(key, e.getExecuteInfo) - } - e.sessionHolder.removeExecuteHolder(e.operationId) - } if (executions.isEmpty) { lastExecutionTimeMs = Some(System.currentTimeMillis()) } - logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.") } - // close the execution outside the lock - executeHolder.foreach { e => - e.close() - if (abandoned) { - // Update in abandonedTombstones: above it wasn't yet updated with closedTime etc. - abandonedTombstones.put(key, e.getExecuteInfo) - } + + logInfo(log"ExecuteHolder ${MDC(LogKeys.EXECUTE_KEY, key)} is removed.") + + executeHolder.close() + if (abandoned) { + // Update in abandonedTombstones: above it wasn't yet updated with closedTime etc. + abandonedTombstones.put(key, executeHolder.getExecuteInfo) } } private[connect] def getExecuteHolder(key: ExecuteKey): Option[ExecuteHolder] = { - executionsLock.synchronized { - executions.get(key) - } + Option(executions.get(key)) } private[connect] def removeAllExecutionsForSession(key: SessionKey): Unit = { - val sessionExecutionHolders = executionsLock.synchronized { - executions.filter(_._2.sessionHolder.key == key) - } - sessionExecutionHolders.foreach { case (_, executeHolder) => + var sessionExecutionHolders = mutable.ArrayBuffer[ExecuteHolder]() + executions.forEach((_, executeHolder) => { + if (executeHolder.sessionHolder.key == key) { + sessionExecutionHolders += executeHolder + } + }) + + sessionExecutionHolders.foreach { executeHolder => val info = executeHolder.getExecuteInfo logInfo( log"Execution ${MDC(LogKeys.EXECUTE_INFO, info)} removed in removeSessionExecutions.") @@ -161,11 +195,11 @@ private[connect] class SparkConnectExecutionManager() extends Logging { * If there are no executions, return Left with System.currentTimeMillis of last active * execution. Otherwise return Right with list of ExecuteInfo of all executions. */ - def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = executionsLock.synchronized { + def listActiveExecutions: Either[Long, Seq[ExecuteInfo]] = { if (executions.isEmpty) { Left(lastExecutionTimeMs.get) } else { - Right(executions.values.map(_.getExecuteInfo).toBuffer.toSeq) + Right(executions.values().asScala.map(_.getExecuteInfo).toBuffer.toSeq) } } @@ -177,16 +211,22 @@ private[connect] class SparkConnectExecutionManager() extends Logging { abandonedTombstones.asMap.asScala.values.toSeq } - private[connect] def shutdown(): Unit = executionsLock.synchronized { - scheduledExecutor.foreach { executor => - ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) + private[connect] def shutdown(): Unit = { + executionsLock.synchronized { + scheduledExecutor.foreach { executor => + ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) + } + scheduledExecutor = None } - scheduledExecutor = None + // note: this does not cleanly shut down the executions, but the server is shutting down. executions.clear() abandonedTombstones.invalidateAll() - if (lastExecutionTimeMs.isEmpty) { - lastExecutionTimeMs = Some(System.currentTimeMillis()) + + executionsLock.synchronized { + if (lastExecutionTimeMs.isEmpty) { + lastExecutionTimeMs = Some(System.currentTimeMillis()) + } } } @@ -225,19 +265,18 @@ private[connect] class SparkConnectExecutionManager() extends Logging { // Find any detached executions that expired and should be removed. val toRemove = new mutable.ArrayBuffer[ExecuteHolder]() - executionsLock.synchronized { - val nowMs = System.currentTimeMillis() + val nowMs = System.currentTimeMillis() - executions.values.foreach { executeHolder => - executeHolder.lastAttachedRpcTimeMs match { - case Some(detached) => - if (detached + timeout <= nowMs) { - toRemove += executeHolder - } - case _ => // execution is active - } + executions.forEach((_, executeHolder) => { + executeHolder.lastAttachedRpcTimeMs match { + case Some(detached) => + if (detached + timeout <= nowMs) { + toRemove += executeHolder + } + case _ => // execution is active } - } + }) + // .. and remove them. toRemove.foreach { executeHolder => val info = executeHolder.getExecuteInfo @@ -250,16 +289,16 @@ private[connect] class SparkConnectExecutionManager() extends Logging { } // For testing. - private[connect] def setAllRPCsDeadline(deadlineMs: Long) = executionsLock.synchronized { - executions.values.foreach(_.setGrpcResponseSendersDeadline(deadlineMs)) + private[connect] def setAllRPCsDeadline(deadlineMs: Long) = { + executions.values().asScala.foreach(_.setGrpcResponseSendersDeadline(deadlineMs)) } // For testing. - private[connect] def interruptAllRPCs() = executionsLock.synchronized { - executions.values.foreach(_.interruptGrpcResponseSenders()) + private[connect] def interruptAllRPCs() = { + executions.values().asScala.foreach(_.interruptGrpcResponseSenders()) } - private[connect] def listExecuteHolders: Seq[ExecuteHolder] = executionsLock.synchronized { - executions.values.toSeq + private[connect] def listExecuteHolders: Seq[ExecuteHolder] = { + executions.values().asScala.toSeq } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala index dbe8420eab03d..a9843e261fff8 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ExecuteEventsManagerSuite.scala @@ -374,7 +374,8 @@ class ExecuteEventsManagerSuite .setClientType(DEFAULT_CLIENT_TYPE) .build() - val executeHolder = new ExecuteHolder(executePlanRequest, sessionHolder) + val executeKey = ExecuteKey(executePlanRequest, sessionHolder) + val executeHolder = new ExecuteHolder(executeKey, executePlanRequest, sessionHolder) val eventsManager = ExecuteEventsManager(executeHolder, DEFAULT_CLOCK) eventsManager.status_(executeStatus)