diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala index a306856efa33c..b0b74a36e187b 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManager.scala @@ -289,8 +289,10 @@ class SparkConnectSessionManager extends Logging { * Used for testing */ private[connect] def invalidateAllSessions(): Unit = { - periodicMaintenance(defaultInactiveTimeoutMs = 0L, ignoreCustomTimeout = true) - assert(sessionStore.isEmpty) + sessionStore.forEach((key, sessionHolder) => { + removeSessionHolder(key) + shutdownSessionHolder(sessionHolder) + }) closedSessionsCache.invalidateAll() } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index d6d137e6d91aa..5e88725691656 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -919,7 +919,8 @@ class SparkConnectServiceSuite } class MockSparkListener() extends SparkListener { val semaphoreStarted = new Semaphore(0) - var executeHolder = Option.empty[ExecuteHolder] + // Accessed by multiple threads in parallel. + @volatile var executeHolder = Option.empty[ExecuteHolder] override def onOtherEvent(event: SparkListenerEvent): Unit = { event match { case e: SparkListenerConnectOperationStarted =>