Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle self-cancellation when running fibers to avoid deadlocks #1484

Merged
merged 18 commits into from
Dec 18, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/js/src/main/scala/cats/effect/IOApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package cats.effect

import scala.concurrent.CancellationException
import scala.concurrent.duration._
import scala.scalajs.js

Expand Down Expand Up @@ -46,7 +47,7 @@ trait IOApp {
.raceOutcome[ExitCode, Nothing](run(argList), keepAlive)
.flatMap {
case Left(Outcome.Canceled()) =>
IO.raiseError(new RuntimeException("IOApp main fiber canceled"))
IO.raiseError(new CancellationException("IOApp main fiber was canceled"))
case Left(Outcome.Errored(t)) => IO.raiseError(t)
case Left(Outcome.Succeeded(code)) => code
case Right(Outcome.Errored(t)) => IO.raiseError(t)
Expand Down
27 changes: 14 additions & 13 deletions core/jvm/src/main/scala/cats/effect/IOApp.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package cats.effect

import scala.concurrent.CancellationException

import java.util.concurrent.CountDownLatch

trait IOApp {
Expand All @@ -35,20 +37,19 @@ trait IOApp {
val ioa = run(args.toList)

val fiber =
ioa
.onCancel(IO {
error = new RuntimeException("IOApp main fiber canceled")
ioa.unsafeRunFiber(
{
error = new CancellationException("IOApp main fiber was canceled")
latch.countDown()
},
{ t =>
error = t
latch.countDown()
},
{ a =>
result = a
latch.countDown()
})
.unsafeRunFiber(
{ t =>
error = t
latch.countDown()
},
{ a =>
result = a
latch.countDown()
})(runtime)
})(runtime)

def handleShutdown(): Unit = {
if (latch.getCount() > 0) {
Expand Down
13 changes: 4 additions & 9 deletions core/jvm/src/main/scala/cats/effect/IOPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,10 @@ abstract private[effect] class IOPlatform[+A] { self: IO[A] =>
var results: Either[Throwable, A] = null
val latch = new CountDownLatch(1)

unsafeRunFiber(
{ t =>
results = Left(t)
latch.countDown()
},
{ a =>
results = Right(a)
latch.countDown()
})
unsafeRunAsync { r =>
results = r
latch.countDown()
}

if (latch.await(limit.toNanos, TimeUnit.NANOSECONDS)) {
results.fold(throw _, a => Some(a))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ private[effect] final class WorkStealingThreadPool(
}

// `unsafeRunFiber(true)` will enqueue the fiber, no need to do it manually
IO(runnable.run()).unsafeRunFiber(reportFailure, _ => ())(self)
IO(runnable.run()).unsafeRunFiber((), reportFailure, _ => ())(self)
()
}
}
Expand Down
31 changes: 26 additions & 5 deletions core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package cats.effect
import cats.{
Applicative,
Eval,
Id,
Monoid,
Now,
Parallel,
Expand All @@ -33,7 +34,13 @@ import cats.effect.instances.spawn
import cats.effect.std.Console

import scala.annotation.unchecked.uncheckedVariance
import scala.concurrent.{ExecutionContext, Future, Promise, TimeoutException}
import scala.concurrent.{
CancellationException,
ExecutionContext,
Future,
Promise,
TimeoutException
}
import scala.concurrent.duration._
import scala.util.{Failure, Success, Try}

Expand Down Expand Up @@ -194,7 +201,19 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {

def unsafeRunAsync(cb: Either[Throwable, A] => Unit)(
implicit runtime: unsafe.IORuntime): Unit = {
unsafeRunFiber(t => cb(Left(t)), a => cb(Right(a)))
unsafeRunFiber(
cb(Left(new CancellationException("Root fiber was canceled"))),
t => cb(Left(t)),
a => cb(Right(a)))
()
}

def unsafeRunAsyncOutcome(cb: Outcome[Id, Throwable, A @uncheckedVariance] => Unit)(
implicit runtime: unsafe.IORuntime): Unit = {
unsafeRunFiber(
cb(Outcome.canceled),
t => cb(Outcome.errored(t)),
a => cb(Outcome.succeeded(a: Id[A])))
()
}

Expand All @@ -212,14 +231,16 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {
p.future
}

private[effect] def unsafeRunFiber(failure: Throwable => Unit, success: A => Unit)(
implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = {
private[effect] def unsafeRunFiber(
canceled: => Unit,
failure: Throwable => Unit,
success: A => Unit)(implicit runtime: unsafe.IORuntime): IOFiber[A @uncheckedVariance] = {

val fiber = new IOFiber[A](
0,
oc =>
oc.fold(
(),
canceled,
{ t =>
runtime.fiberErrorCbs.remove(failure)
failure(t)
Expand Down
18 changes: 9 additions & 9 deletions core/shared/src/test/scala/cats/effect/IOSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ class IOSpec extends IOPlatformSpecification with Discipline with ScalaCheck wit

"preserve monad right identity on uncancelable" in ticked { implicit ticker =>
val fa = IO.uncancelable(_ => IO.canceled)
fa.flatMap(IO.pure(_)) must nonTerminate
fa must nonTerminate
fa.flatMap(IO.pure(_)) must selfCancel
fa must selfCancel
}

}
Expand Down Expand Up @@ -621,11 +621,11 @@ class IOSpec extends IOPlatformSpecification with Discipline with ScalaCheck wit

"cancel flatMap continuations following a canceled uncancelable block" in ticked {
implicit ticker =>
IO.uncancelable(_ => IO.canceled).flatMap(_ => IO.pure(())) must nonTerminate
IO.uncancelable(_ => IO.canceled).flatMap(_ => IO.pure(())) must selfCancel
}

"cancel map continuations following a canceled uncancelable block" in ticked {
implicit ticker => IO.uncancelable(_ => IO.canceled).map(_ => ()) must nonTerminate
implicit ticker => IO.uncancelable(_ => IO.canceled).map(_ => ()) must selfCancel
}

"sequence onCancel when canceled before registration" in ticked { implicit ticker =>
Expand All @@ -634,7 +634,7 @@ class IOSpec extends IOPlatformSpecification with Discipline with ScalaCheck wit
IO.canceled >> poll(IO.unit).onCancel(IO { passed = true })
}

test must nonTerminate
test must selfCancel
passed must beTrue
}

Expand All @@ -644,15 +644,15 @@ class IOSpec extends IOPlatformSpecification with Discipline with ScalaCheck wit
IO.canceled >> poll(IO.unit) >> IO { passed = false }
}

test must nonTerminate
test must selfCancel
passed must beTrue
}

"not invoke onCancel when previously canceled within uncancelable" in ticked {
implicit ticker =>
var failed = false
IO.uncancelable(_ =>
IO.canceled >> IO.unit.onCancel(IO { failed = true })) must nonTerminate
IO.canceled >> IO.unit.onCancel(IO { failed = true })) must selfCancel
failed must beFalse
}

Expand Down Expand Up @@ -712,7 +712,7 @@ class IOSpec extends IOPlatformSpecification with Discipline with ScalaCheck wit
poll(poll(IO.unit) >> IO.canceled) >> IO { passed = false }
}

test must nonTerminate
test must selfCancel
passed must beTrue
}

Expand Down Expand Up @@ -834,7 +834,7 @@ class IOSpec extends IOPlatformSpecification with Discipline with ScalaCheck wit

IO.canceled
.guarantee(IO { inner = true })
.guarantee(IO { outer = true }) must nonTerminate
.guarantee(IO { outer = true }) must selfCancel

inner must beTrue
outer must beTrue
Expand Down
14 changes: 9 additions & 5 deletions core/shared/src/test/scala/cats/effect/Runners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package cats.effect

import cats.{Applicative, Eq, Id, Order, Show}
import cats.{~>, Applicative, Eq, Id, Order, Show}
import cats.effect.testkit.{
AsyncGenerators,
GenK,
Expand Down Expand Up @@ -274,6 +274,9 @@ trait Runners extends SpecificationLike with RunnersPlatform { outer =>
def nonTerminate(implicit ticker: Ticker): Matcher[IO[Unit]] =
tickTo[Unit](Outcome.Succeeded(None))

def selfCancel(implicit ticker: Ticker): Matcher[IO[Unit]] =
tickTo[Unit](Outcome.Canceled())

def beCanceledSync: Matcher[SyncIO[Unit]] =
(ioa: SyncIO[Unit]) => unsafeRunSync(ioa) eqv Outcome.canceled

Expand All @@ -297,14 +300,15 @@ trait Runners extends SpecificationLike with RunnersPlatform { outer =>
def mustEqual(a: A) = fa.flatMap { res => IO(res must beEqualTo(a)) }
}

private val someK: Id ~> Option =
new ~>[Id, Option] { def apply[A](a: A) = a.some }

def unsafeRun[A](ioa: IO[A])(implicit ticker: Ticker): Outcome[Option, Throwable, A] =
try {
var results: Outcome[Option, Throwable, A] = Outcome.Succeeded(None)

ioa.unsafeRunAsync {
case Left(t) => results = Outcome.Errored(t)
case Right(a) => results = Outcome.Succeeded(Some(a))
}(unsafe.IORuntime(ticker.ctx, ticker.ctx, scheduler, () => ()))
ioa.unsafeRunAsyncOutcome { oc => results = oc.mapK(someK) }(
unsafe.IORuntime(ticker.ctx, ticker.ctx, scheduler, () => ()))

ticker.ctx.tickAll(1.days)

Expand Down