Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Join worker threads on pool shutdown #3794

Merged
merged 3 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ class WorkStealingBenchmark {
"io-blocker",
60.seconds,
false,
1.second,
SleepSystem,
_.printStackTrace())

Expand Down
3 changes: 2 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -898,7 +898,8 @@ lazy val tests: CrossProject = crossProject(JSPlatform, JVMPlatform, NativePlatf
scalacOptions ~= { _.filterNot(_.startsWith("-P:scalajs:mapSourceURI")) }
)
.jvmSettings(
Test / fork := true
fork := true,
Test / javaOptions += s"-Dsbt.classpath=${(Test / fullClasspath).value.map(_.data.getAbsolutePath).mkString(File.pathSeparator)}"
)
.nativeSettings(
Compile / mainClass := Some("catseffect.examples.NativeRunner")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type
runtimeBlockingExpiration,
reportFailure,
false,
1.second,
SleepSystem
)
(pool, shutdown)
Expand All @@ -77,6 +78,7 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type
runtimeBlockingExpiration: Duration = 60.seconds,
reportFailure: Throwable => Unit = _.printStackTrace(),
blockedThreadDetectionEnabled: Boolean = false,
shutdownTimeout: Duration = 1.second,
pollingSystem: PollingSystem = SelectorSystem())
: (WorkStealingThreadPool[_], pollingSystem.Api, () => Unit) = {
val threadPool =
Expand All @@ -86,8 +88,10 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type
blockerThreadPrefix,
runtimeBlockingExpiration,
blockedThreadDetectionEnabled && (threads > 1),
shutdownTimeout,
pollingSystem,
reportFailure)
reportFailure
)

val unregisterMBeans =
if (isStackTracing) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ private[effect] final class WorkStealingThreadPool[P](
private[unsafe] val blockerThreadPrefix: String, // prefix for the name of worker threads currently in a blocking region
private[unsafe] val runtimeBlockingExpiration: Duration,
private[unsafe] val blockedThreadDetectionEnabled: Boolean,
shutdownTimeout: Duration,
system: PollingSystem.WithPoller[P],
reportFailure0: Throwable => Unit
) extends ExecutionContextExecutor
Expand Down Expand Up @@ -691,39 +692,72 @@ private[effect] final class WorkStealingThreadPool[P](
def shutdown(): Unit = {
// Clear the interrupt flag.
val interruptCalling = Thread.interrupted()
val currentThread = Thread.currentThread()

// Execute the shutdown logic only once.
if (done.compareAndSet(false, true)) {
// Send an interrupt signal to each of the worker threads.
workerThreadPublisher.get()

// Note: while loops and mutable variables are used throughout this method
// to avoid allocations of objects, since this method is expected to be
// executed mostly in situations where the thread pool is shutting down in
// the face of unhandled exceptions or as part of the whole JVM exiting.

workerThreadPublisher.get()

// Send an interrupt signal to each of the worker threads.
var i = 0
while (i < threadCount) {
workerThreads(i).interrupt()
system.closePoller(pollers(i))
val workerThread = workerThreads(i)
if (workerThread ne currentThread) {
workerThread.interrupt()
}
i += 1
}

system.close()
i = 0
var joinTimeout = shutdownTimeout match {
case Duration.Inf => Long.MaxValue
case d => d.toNanos
}
while (i < threadCount && joinTimeout > 0) {
val workerThread = workerThreads(i)
if (workerThread ne currentThread) {
val now = System.nanoTime()
workerThread.join(joinTimeout / 1000000, (joinTimeout % 1000000).toInt)
val elapsed = System.nanoTime() - now
joinTimeout -= elapsed
}
i += 1
}

// Clear the interrupt flag.
Thread.interrupted()
i = 0
var allClosed = true
while (i < threadCount) {
val workerThread = workerThreads(i)
// only close the poller if it is safe to do so, leak otherwise ...
if ((workerThread eq currentThread) || !workerThread.isAlive()) {
system.closePoller(pollers(i))
} else {
allClosed = false
}
i += 1
}

if (allClosed) {
system.close()
}

var t: WorkerThread[P] = null
while ({
t = cachedThreads.pollFirst()
t ne null
}) {
t.interrupt()
// don't bother joining, cached threads are not doing anything interesting
}

// Drain the external queue.
externalQueue.clear()
if (interruptCalling) Thread.currentThread().interrupt()
if (interruptCalling) currentThread.interrupt()
}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/jvm/src/main/scala/catseffect/examplesplatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ package examples {
super.runtimeConfig.copy(shutdownHookTimeout = Duration.Zero)

val run: IO[Unit] =
IO(System.exit(0)).uncancelable
IO.blocking(System.exit(0)).uncancelable
Comment on lines 38 to +39
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an example of a deadlock that came up in our test suite. Actually, System.exit(...) is technically blocking, because it blocks waiting for shutdown hooks to complete.

If we run it on the compute pool, then it gets stuck waiting for the runtime to shutdown, which is stuck waiting for this thread to complete, ... deadlock.

}

object FatalErrorUnsafeRun extends IOApp {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,7 @@ trait IOPlatformSpecification { self: BaseSpec with ScalaCheck =>
runtimeBlockingExpiration = 3.seconds,
reportFailure0 = _.printStackTrace(),
blockedThreadDetectionEnabled = false,
shutdownTimeout = 1.second,
system = SleepSystem
)

Expand Down
Loading