Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fromCompletableFuture cancelation leak #3425

Merged
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,13 @@ private[effect] sealed class FiberMonitor(
private[this] val compute: WorkStealingThreadPool
) extends FiberMonitorShared {

private[this] final val Bags = FiberMonitor.Bags
private[this] final val BagReferences = FiberMonitor.BagReferences
private[this] final val BagReferences =
new ConcurrentLinkedQueue[WeakReference[WeakBag[Runnable]]]
private[this] final val Bags = ThreadLocal.withInitial { () =>
val bag = new WeakBag[Runnable]()
BagReferences.offer(new WeakReference(bag))
bag
}

private[this] val justFibers: PartialFunction[(Runnable, Trace), (IOFiber[_], Trace)] = {
case (fiber: IOFiber[_], trace) => fiber -> trace
Expand Down Expand Up @@ -214,16 +219,4 @@ private[effect] final class NoOpFiberMonitor extends FiberMonitor(null) {
override def liveFiberSnapshot(print: String => Unit): Unit = {}
}

private[effect] object FiberMonitor extends FiberMonitorCompanionPlatform {

private[FiberMonitor] final val BagReferences
: ConcurrentLinkedQueue[WeakReference[WeakBag[Runnable]]] =
new ConcurrentLinkedQueue()

private[FiberMonitor] final val Bags: ThreadLocal[WeakBag[Runnable]] =
ThreadLocal.withInitial { () =>
val bag = new WeakBag[Runnable]()
BagReferences.offer(new WeakReference(bag))
bag
}
}
private[effect] object FiberMonitor extends FiberMonitorCompanionPlatform
18 changes: 9 additions & 9 deletions kernel/jvm/src/main/scala/cats/effect/kernel/AsyncPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ private[kernel] trait AsyncPlatform[F[_]] extends Serializable { this: Async[F]
* @param fut
* The `java.util.concurrent.CompletableFuture` to suspend in `F[_]`
*/
def fromCompletableFuture[A](fut: F[CompletableFuture[A]]): F[A] = flatMap(fut) { cf =>
cont {
new Cont[F, A, A] {
def apply[G[_]](
implicit
G: MonadCancelThrow[G]): (Either[Throwable, A] => Unit, G[A], F ~> G) => G[A] = {
(resume, get, lift) =>
G.uncancelable { poll =>
def fromCompletableFuture[A](fut: F[CompletableFuture[A]]): F[A] = cont {
new Cont[F, A, A] {
def apply[G[_]](
implicit
G: MonadCancelThrow[G]): (Either[Throwable, A] => Unit, G[A], F ~> G) => G[A] = {
(resume, get, lift) =>
G.uncancelable { poll =>
G.flatMap(poll(lift(fut))) { cf =>
val go = delay {
cf.handle[Unit] {
case (a, null) => resume(Right(a))
Expand All @@ -57,7 +57,7 @@ private[kernel] trait AsyncPlatform[F[_]] extends Serializable { this: Async[F]

G.productR(lift(go))(await)
}
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
package cats.effect.kernel

import cats.effect.{BaseSpec, IO}
import cats.effect.testkit.TestControl
import cats.effect.unsafe.IORuntimeConfig

import scala.concurrent.duration._

import java.util.concurrent.{CancellationException, CompletableFuture}
import java.util.concurrent.atomic.AtomicBoolean

class AsyncPlatformSpec extends BaseSpec {

Expand All @@ -40,19 +43,30 @@ class AsyncPlatformSpec extends BaseSpec {
} yield ok
}

"backpressure on CompletableFuture cancelation" in ticked { implicit ticker =>
"backpressure on CompletableFuture cancelation" in real {
// a non-cancelable, never-completing CompletableFuture
def cf = new CompletableFuture[Unit] {
def mkcf() = new CompletableFuture[Unit] {
override def cancel(mayInterruptIfRunning: Boolean) = false
}

val io = for {
fiber <- IO.fromCompletableFuture(IO(cf)).start
_ <- smallDelay // time for the callback to be set-up
def go = for {
started <- IO(new AtomicBoolean)
fiber <- IO.fromCompletableFuture {
IO {
started.set(true)
mkcf()
}
}.start
_ <- IO.cede.whileM_(IO(!started.get))
_ <- fiber.cancel
} yield ()

io must nonTerminate
TestControl
.executeEmbed(go, IORuntimeConfig(1, 2))
.as(false)
.recover { case _: TestControl.NonTerminationException => true }
.replicateA(1000)
.map(_.forall(identity(_)))
}
}
}