[SPARK-49548][CONNECT] Replace coarse-locking in SparkConnectSessionManager with ConcurrentMap

### What changes were proposed in this pull request?

Replace the coarse locking in SparkConnectSessionManager with a ConcurrentMap in order to minimise lock contention when there are many sessions.
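A minimal sketch of the before/after locking pattern (hypothetical standalone names, not the actual SparkConnectSessionManager code):

```scala
import java.util.concurrent.ConcurrentHashMap

object LockingSketch {
  final case class SessionKey(userId: String, sessionId: String)
  final class SessionHolder

  // Before: every get-or-create serializes on one lock, even for distinct keys.
  private val lock = new Object
  private val lockedStore = scala.collection.mutable.HashMap[SessionKey, SessionHolder]()
  def getOrCreateLocked(key: SessionKey): SessionHolder = lock.synchronized {
    lockedStore.getOrElseUpdate(key, new SessionHolder)
  }

  // After: ConcurrentHashMap only locks the affected bin, so get-or-create
  // calls for distinct keys proceed in parallel.
  private val store = new ConcurrentHashMap[SessionKey, SessionHolder]()
  def getOrCreate(key: SessionKey): SessionHolder =
    store.computeIfAbsent(key, _ => new SessionHolder)
}
```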

### Why are the changes needed?

This is a spin-off from apache#48034: that PR addresses the many-execution case, whereas this one addresses the many-session case.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing test cases.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#48036 from changgyoopark-db/SPARK-49548.

Authored-by: Changgyoo Park <changgyoo.park@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
changgyoopark-db authored and HyukjinKwon committed Sep 11, 2024
1 parent 07f5b2c commit 3cb8d6e
Showing 1 changed file (SparkConnectSessionManager.scala) with 49 additions and 50 deletions.
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.connect.service
 
 import java.util.UUID
-import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
@@ -42,8 +42,8 @@ class SparkConnectSessionManager extends Logging {
 
   private val sessionsLock = new Object
 
-  @GuardedBy("sessionsLock")
-  private val sessionStore = mutable.HashMap[SessionKey, SessionHolder]()
+  private val sessionStore: ConcurrentMap[SessionKey, SessionHolder] =
+    new ConcurrentHashMap[SessionKey, SessionHolder]()
 
   private val closedSessionsCache =
     CacheBuilder
@@ -52,6 +52,7 @@ class SparkConnectSessionManager extends Logging {
       .build[SessionKey, SessionHolderInfo]()
 
   /** Executor for the periodic maintenance */
+  @GuardedBy("sessionsLock")
   private var scheduledExecutor: Option[ScheduledExecutorService] = None
 
   private def validateSessionId(
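For context, scheduledExecutor keeps its @GuardedBy("sessionsLock") annotation because the maintenance thread is started lazily. A minimal sketch of that check-then-start pattern, assuming a simplified schedulePeriodicChecks shape rather than the exact Spark code:

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}

object MaintenanceSketch {
  private val lock = new Object
  private var scheduledExecutor: Option[ScheduledExecutorService] = None

  // The lock makes the check-then-start step atomic, so at most one
  // maintenance scheduler is ever created no matter how many threads race here.
  def schedulePeriodicChecks(intervalMs: Long)(task: () => Unit): Unit = lock.synchronized {
    if (scheduledExecutor.isEmpty) {
      val executor = Executors.newSingleThreadScheduledExecutor()
      executor.scheduleAtFixedRate(() => task(), intervalMs, intervalMs, TimeUnit.MILLISECONDS)
      scheduledExecutor = Some(executor)
    }
  }
}
```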
@@ -121,43 +122,39 @@ class SparkConnectSessionManager extends Logging {
   private def getSession(key: SessionKey, default: Option[() => SessionHolder]): SessionHolder = {
     schedulePeriodicChecks() // Starts the maintenance thread if it hasn't started yet.
 
-    sessionsLock.synchronized {
-      // try to get existing session from store
-      val sessionOpt = sessionStore.get(key)
-      // create using default if missing
-      val session = sessionOpt match {
-        case Some(s) => s
-        case None =>
-          default match {
-            case Some(callable) =>
-              val session = callable()
-              sessionStore.put(key, session)
-              session
-            case None =>
-              null
-          }
-      }
-      // record access time before returning
-      session match {
-        case null =>
-          null
-        case s: SessionHolder =>
-          s.updateAccessTime()
-          s
-      }
-    }
+    // Get the existing session from the store or create a new one.
+    val session = default match {
+      case Some(callable) =>
+        sessionStore.computeIfAbsent(key, _ => callable())
+      case None =>
+        sessionStore.get(key)
+    }
+
+    // Record the access time before returning the session holder.
+    if (session != null) {
+      session.updateAccessTime()
+    }
+
+    session
   }
 
   // Removes session from sessionStore and returns it.
   private def removeSessionHolder(key: SessionKey): Option[SessionHolder] = {
     var sessionHolder: Option[SessionHolder] = None
-    sessionsLock.synchronized {
-      sessionHolder = sessionStore.remove(key)
-      sessionHolder.foreach { s =>
-        // Put into closedSessionsCache, so that it cannot get accidentally recreated
-        // by getOrCreateIsolatedSession.
-        closedSessionsCache.put(s.key, s.getSessionHolderInfo)
-      }
-    }
+
+    // The session holder should remain in the session store until it is added to the closed session
+    // cache, because of a subtle data race: a new session with the same key can be created if the
+    // closed session cache does not contain the key right after the key has been removed from the
+    // session store.
+    sessionHolder = Option(sessionStore.get(key))
+
+    sessionHolder.foreach { s =>
+      // Put into closedSessionsCache to prevent the same session from being recreated by
+      // getOrCreateIsolatedSession.
+      closedSessionsCache.put(s.key, s.getSessionHolderInfo)
+
+      // Then, remove the session holder from the session store.
+      sessionStore.remove(key)
+    }
     sessionHolder
   }
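The removal order above is deliberate: the session is published to closedSessionsCache before it leaves the live store, so a concurrent getOrCreateIsolatedSession can never observe the key as absent from both. A condensed sketch of that ordering, with plain maps standing in for the real holder and cache types:

```scala
import java.util.concurrent.ConcurrentHashMap

object RemoveOrderingSketch {
  val live = new ConcurrentHashMap[String, String]()   // stands in for sessionStore
  val closed = new ConcurrentHashMap[String, String]() // stands in for closedSessionsCache

  def remove(key: String): Option[String] = {
    val holder = Option(live.get(key)) // keep the entry visible while publishing
    holder.foreach { h =>
      closed.put(key, h) // 1) tombstone the key first
      live.remove(key)   // 2) only then remove it from the live store
    }
    holder
  }
}
```

If the two steps ran in the opposite order, a creator thread could find the key missing from the live store, check the closed cache before the tombstone lands, and resurrect a session that is in the middle of being closed.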
@@ -176,21 +173,24 @@ class SparkConnectSessionManager extends Logging {
     sessionHolder.foreach(shutdownSessionHolder(_))
   }
 
-  private[connect] def shutdown(): Unit = sessionsLock.synchronized {
-    scheduledExecutor.foreach { executor =>
-      ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
-    }
-    scheduledExecutor = None
+  private[connect] def shutdown(): Unit = {
+    sessionsLock.synchronized {
+      scheduledExecutor.foreach { executor =>
+        ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES))
+      }
+      scheduledExecutor = None
+    }
 
     // note: this does not cleanly shut down the sessions, but the server is shutting down.
     sessionStore.clear()
     closedSessionsCache.invalidateAll()
   }
 
-  def listActiveSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized {
-    sessionStore.values.map(_.getSessionHolderInfo).toSeq
+  def listActiveSessions: Seq[SessionHolderInfo] = {
+    sessionStore.values().asScala.map(_.getSessionHolderInfo).toSeq
   }
 
-  def listClosedSessions: Seq[SessionHolderInfo] = sessionsLock.synchronized {
+  def listClosedSessions: Seq[SessionHolderInfo] = {
     closedSessionsCache.asMap.asScala.values.toSeq
   }
 
@@ -246,18 +246,17 @@ class SparkConnectSessionManager extends Logging {
       timeoutMs != -1 && info.lastAccessTimeMs + timeoutMs <= nowMs
     }
 
-    sessionsLock.synchronized {
-      val nowMs = System.currentTimeMillis()
-      sessionStore.values.foreach { sessionHolder =>
-        if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
-          toRemove += sessionHolder
-        }
-      }
-    }
+    val nowMs = System.currentTimeMillis()
+    sessionStore.forEach((_, sessionHolder) => {
+      if (shouldExpire(sessionHolder.getSessionHolderInfo, nowMs)) {
+        toRemove += sessionHolder
+      }
+    })
 
     // .. and remove them.
     toRemove.foreach { sessionHolder =>
       // This doesn't use closeSession to be able to do the extra last chance check under lock.
-      val removedSession = sessionsLock.synchronized {
+      val removedSession = {
         // Last chance - check expiration time and remove under lock if expired.
         val info = sessionHolder.getSessionHolderInfo
         if (shouldExpire(info, System.currentTimeMillis())) {
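The maintenance path above scans without any lock (ConcurrentHashMap.forEach is weakly consistent, so it never blocks writers and never throws ConcurrentModificationException) and then re-checks expiry per candidate just before removal, since a session may have been accessed after the scan saw it. A condensed sketch of that scan-then-recheck shape, with hypothetical Info/sweep names rather than the real Spark types:

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable

final case class Info(lastAccessMs: Long)

def sweep(store: ConcurrentHashMap[String, Info], timeoutMs: Long): Unit = {
  val toRemove = mutable.ArrayBuffer[String]()
  val now = System.currentTimeMillis()
  // Pass 1: lock-free, weakly consistent scan for expired candidates.
  store.forEach((key, info) => if (info.lastAccessMs + timeoutMs <= now) toRemove += key)
  // Pass 2: re-check each candidate right before removing it, since its
  // access time may have been refreshed after the scan.
  toRemove.foreach { key =>
    val info = store.get(key)
    if (info != null && info.lastAccessMs + timeoutMs <= System.currentTimeMillis()) {
      store.remove(key)
    }
  }
}
```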
@@ -309,7 +308,7 @@ class SparkConnectSessionManager extends Logging {
   /**
    * Used for testing
    */
-  private[connect] def invalidateAllSessions(): Unit = sessionsLock.synchronized {
+  private[connect] def invalidateAllSessions(): Unit = {
     periodicMaintenance(defaultInactiveTimeoutMs = 0L, ignoreCustomTimeout = true)
     assert(sessionStore.isEmpty)
     closedSessionsCache.invalidateAll()
