Skip to content

Commit

Permalink
Merge pull request #3425 from armanbilge/bug/completable-future-cance…
Browse files Browse the repository at this point in the history
…llation-leak
  • Loading branch information
djspiewak authored Feb 17, 2023
2 parents 6820f01 + 8d80978 commit 33269d3
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,13 @@ private[effect] sealed class FiberMonitor(
private[this] val compute: WorkStealingThreadPool
) extends FiberMonitorShared {

private[this] final val Bags = FiberMonitor.Bags
private[this] final val BagReferences = FiberMonitor.BagReferences
private[this] final val BagReferences =
new ConcurrentLinkedQueue[WeakReference[WeakBag[Runnable]]]
private[this] final val Bags = ThreadLocal.withInitial { () =>
val bag = new WeakBag[Runnable]()
BagReferences.offer(new WeakReference(bag))
bag
}

private[this] val justFibers: PartialFunction[(Runnable, Trace), (IOFiber[_], Trace)] = {
case (fiber: IOFiber[_], trace) => fiber -> trace
Expand Down Expand Up @@ -214,16 +219,4 @@ private[effect] final class NoOpFiberMonitor extends FiberMonitor(null) {
override def liveFiberSnapshot(print: String => Unit): Unit = {}
}

private[effect] object FiberMonitor extends FiberMonitorCompanionPlatform {

private[FiberMonitor] final val BagReferences
: ConcurrentLinkedQueue[WeakReference[WeakBag[Runnable]]] =
new ConcurrentLinkedQueue()

private[FiberMonitor] final val Bags: ThreadLocal[WeakBag[Runnable]] =
ThreadLocal.withInitial { () =>
val bag = new WeakBag[Runnable]()
BagReferences.offer(new WeakReference(bag))
bag
}
}
private[effect] object FiberMonitor extends FiberMonitorCompanionPlatform
18 changes: 9 additions & 9 deletions kernel/jvm/src/main/scala/cats/effect/kernel/AsyncPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ private[kernel] trait AsyncPlatform[F[_]] extends Serializable { this: Async[F]
* @param fut
* The `java.util.concurrent.CompletableFuture` to suspend in `F[_]`
*/
def fromCompletableFuture[A](fut: F[CompletableFuture[A]]): F[A] = flatMap(fut) { cf =>
cont {
new Cont[F, A, A] {
def apply[G[_]](
implicit
G: MonadCancelThrow[G]): (Either[Throwable, A] => Unit, G[A], F ~> G) => G[A] = {
(resume, get, lift) =>
G.uncancelable { poll =>
def fromCompletableFuture[A](fut: F[CompletableFuture[A]]): F[A] = cont {
new Cont[F, A, A] {
def apply[G[_]](
implicit
G: MonadCancelThrow[G]): (Either[Throwable, A] => Unit, G[A], F ~> G) => G[A] = {
(resume, get, lift) =>
G.uncancelable { poll =>
G.flatMap(poll(lift(fut))) { cf =>
val go = delay {
cf.handle[Unit] {
case (a, null) => resume(Right(a))
Expand All @@ -57,7 +57,7 @@ private[kernel] trait AsyncPlatform[F[_]] extends Serializable { this: Async[F]

G.productR(lift(go))(await)
}
}
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
package cats.effect.kernel

import cats.effect.{BaseSpec, IO}
import cats.effect.testkit.TestControl
import cats.effect.unsafe.IORuntimeConfig

import scala.concurrent.duration._

import java.util.concurrent.{CancellationException, CompletableFuture}
import java.util.concurrent.atomic.AtomicBoolean

class AsyncPlatformSpec extends BaseSpec {

Expand All @@ -40,19 +43,30 @@ class AsyncPlatformSpec extends BaseSpec {
} yield ok
}

"backpressure on CompletableFuture cancelation" in ticked { implicit ticker =>
"backpressure on CompletableFuture cancelation" in real {
// a non-cancelable, never-completing CompletableFuture
def cf = new CompletableFuture[Unit] {
def mkcf() = new CompletableFuture[Unit] {
override def cancel(mayInterruptIfRunning: Boolean) = false
}

val io = for {
fiber <- IO.fromCompletableFuture(IO(cf)).start
_ <- smallDelay // time for the callback to be set-up
def go = for {
started <- IO(new AtomicBoolean)
fiber <- IO.fromCompletableFuture {
IO {
started.set(true)
mkcf()
}
}.start
_ <- IO.cede.whileM_(IO(!started.get))
_ <- fiber.cancel
} yield ()

io must nonTerminate
TestControl
.executeEmbed(go, IORuntimeConfig(1, 2))
.as(false)
.recover { case _: TestControl.NonTerminationException => true }
.replicateA(1000)
.map(_.forall(identity(_)))
}
}
}

0 comments on commit 33269d3

Please sign in to comment.