diff --git a/core/shared/src/test/scala/cats/effect/Runners.scala b/core/shared/src/test/scala/cats/effect/Runners.scala index ea223db83e..58055380b7 100644 --- a/core/shared/src/test/scala/cats/effect/Runners.scala +++ b/core/shared/src/test/scala/cats/effect/Runners.scala @@ -43,6 +43,7 @@ import scala.concurrent.{ TimeoutException } import scala.concurrent.duration._ +import scala.reflect.ClassTag import scala.util.Try import java.io.{ByteArrayOutputStream, PrintStream} @@ -272,6 +273,22 @@ trait Runners extends SpecificationLike with RunnersPlatform { outer => (oc eqv expected, s"${oc.show} !== ${expected.show}") } + // useful for tests in the `real` context + implicit class Assertions[A](fa: IO[A]) { + def mustFailWith[E <: Throwable: ClassTag] = + fa.attempt.flatMap { res => + IO { + res must beLike { + case Left(e) => e must haveClass[E] + } + } + } + + def mustEqual(a: A) = fa.flatMap { res => IO(res must beEqualTo(a)) } + } + + + def unsafeRun[A](ioa: IO[A])(implicit ticker: Ticker): Outcome[Option, Throwable, A] = try { var results: Outcome[Option, Throwable, A] = Outcome.Succeeded(None) diff --git a/core/shared/src/test/scala/cats/effect/std/CyclicBarrierSpec.scala b/core/shared/src/test/scala/cats/effect/std/CyclicBarrierSpec.scala index 8b516d476b..244d7f19ec 100644 --- a/core/shared/src/test/scala/cats/effect/std/CyclicBarrierSpec.scala +++ b/core/shared/src/test/scala/cats/effect/std/CyclicBarrierSpec.scala @@ -14,20 +14,14 @@ * limitations under the License. */ -/* - * These tests have been inspired by and adapted from `monix-catnap`'s `ConcurrentQueueSuite`, available at - * https://github.com/monix/monix/blob/series/3.x/monix-catnap/shared/src/test/scala/monix/catnap/ConcurrentQueueSuite.scala. - */ - package cats.effect package std import cats.implicits._ import cats.arrow.FunctionK -import org.specs2.specification.core.Fragments - import scala.concurrent.duration._ -import java.util.concurrent.TimeoutException + +import org.specs2.specification.core.Fragments class CyclicBarrierSpec extends BaseSpec { @@ -40,104 +34,47 @@ class CyclicBarrierSpec extends BaseSpec { private def cyclicBarrierTests( name: String, - constructor: Int => IO[CyclicBarrier[IO]]): Fragments = { + newBarrier: Int => IO[CyclicBarrier[IO]]): Fragments = { s"$name - raise an exception when constructed with a negative capacity" in real { - val test = IO.defer(constructor(-1)).attempt - test.flatMap { res => - IO { - res must beLike { - case Left(e) => e must haveClass[IllegalArgumentException] - } - } - } + IO.defer(newBarrier(-1)).mustFailWith[IllegalArgumentException] } s"$name - raise an exception when constructed with zero capacity" in real { - val test = IO.defer(constructor(0)).attempt - test.flatMap { res => - IO { - res must beLike { - case Left(e) => e must haveClass[IllegalArgumentException] - } - } - } + IO.defer(newBarrier(0)).mustFailWith[IllegalArgumentException] } - s"$name - remaining when contructed" in real { - for { - cb <- constructor(5) - awaiting <- cb.awaiting - _ <- IO(awaiting must beEqualTo(0)) - r <- cb.remaining - res <- IO(r must beEqualTo(5)) - } yield res + s"$name - await is blocking" in ticked { implicit ticker => + newBarrier(2).flatMap(_.await) must nonTerminate } - s"$name - await releases all fibers" in real { - for { - cb <- constructor(2) - f1 <- cb.await.start - f2 <- cb.await.start - r <- (f1.joinAndEmbedNever, f2.joinAndEmbedNever).tupled - awaiting <- cb.awaiting - _ <- IO(awaiting must beEqualTo(0)) - res <- IO(r must beEqualTo(((), ()))) - } yield res + s"$name - await is cancelable" in ticked { implicit ticker => + newBarrier(2).flatMap(_.await).timeoutTo(1.second, IO.unit) must completeAs(()) } - s"$name - await is blocking" in real { - for { - cb <- constructor(2) - r <- cb.await.timeout(5.millis).attempt - res <- IO(r must beLike { - case Left(e) => e must haveClass[TimeoutException] - }) - } yield res + s"$name - await releases all fibers" in real { + newBarrier(2).flatMap { barrier => + (barrier.await, barrier.await).parTupled.void.mustEqual(()) + } } - s"$name - await is cancelable" in real { - for { - cb <- constructor(2) - f <- cb.await.start - _ <- IO.sleep(1.milli) - _ <- f.cancel - r <- f.join - awaiting <- cb.awaiting - _ <- IO(awaiting must beEqualTo(0)) - res <- IO(r must beEqualTo(Outcome.Canceled())) - } yield res + s"$name - reset once full" in ticked { implicit ticker => + newBarrier(2).flatMap { barrier => + (barrier.await, barrier.await).parTupled >> + barrier.await + } must nonTerminate } - s"$name - reset once full" in real { - for { - cb <- constructor(2) - f1 <- cb.await.start - f2 <- cb.await.start - r <- (f1.joinAndEmbedNever, f2.joinAndEmbedNever).tupled - _ <- IO(r must beEqualTo(((), ()))) - //Should have reset at this point - awaiting <- cb.awaiting - _ <- IO(awaiting must beEqualTo(0)) - r <- cb.await.timeout(5.millis).attempt - res <- IO(r must beLike { - case Left(e) => e must haveClass[TimeoutException] - }) - } yield res + s"$name - clean up upon cancellation of await" in ticked { implicit ticker => + newBarrier(2).flatMap { barrier => + // This will time out, so count goes back to 2 + barrier.await.timeoutTo(1.second, IO.unit) >> + // Therefore count goes only down to 1 when this awaits, and will block again + barrier.await + } must nonTerminate } - s"$name - clean up upon cancellation of await" in real { - for { - cb <- constructor(2) - //This should time out and reduce the current capacity to 0 again - _ <- cb.await.timeout(5.millis).attempt - //Therefore the capacity should only be 1 when this awaits so will block again - r <- cb.await.timeout(5.millis).attempt - _ <- IO(r must beLike { - case Left(e) => e must haveClass[TimeoutException] - }) - awaiting <- cb.awaiting - res <- IO(awaiting must beEqualTo(0)) // - } yield res + s"$name - barrier of capacity 1 is a no op" in real { + newBarrier(1).flatMap(_.await).mustEqual(()) } /* @@ -148,13 +85,21 @@ class CyclicBarrierSpec extends BaseSpec { s"$name - race fiber cancel and barrier full" in real { val iterations = 100 - val run = for { - cb <- constructor(2) - f <- cb.await.start - _ <- IO.race(cb.await, f.cancel) - awaiting <- cb.awaiting - res <- IO(awaiting must beGreaterThanOrEqualTo(0)) - } yield res + val run = newBarrier(2) + .flatMap { barrier => + barrier.await.start.flatMap { fiber => + barrier.await.race(fiber.cancel).flatMap { + case Left(_) => + // without the epoch check in CyclicBarrier, + // a late cancelation would increment the count + // after the barrier has already reset, + // causing this code to never terminate (test times out) + (barrier.await, barrier.await).parTupled.void + case Right(_) => IO.unit + } + } + } + .mustEqual(()) List.fill(iterations)(run).reduce(_ >> _) } diff --git a/std/shared/src/main/scala/cats/effect/std/CyclicBarrier.scala b/std/shared/src/main/scala/cats/effect/std/CyclicBarrier.scala index 4aed746723..bb20826240 100644 --- a/std/shared/src/main/scala/cats/effect/std/CyclicBarrier.scala +++ b/std/shared/src/main/scala/cats/effect/std/CyclicBarrier.scala @@ -17,7 +17,7 @@ package cats.effect.std import cats.~> -import cats.effect.kernel.{Deferred, GenConcurrent, Ref} +import cats.effect.kernel.{Deferred, GenConcurrent} import cats.effect.kernel.syntax.all._ import cats.syntax.all._ @@ -39,16 +39,6 @@ abstract class CyclicBarrier[F[_]] { self => */ def await: F[Unit] - /* - * The number of fibers required to trip the barrier - */ - def remaining: F[Int] - - /* - * The number of fibers currently awaiting - */ - def awaiting: F[Int] - /** * Modifies the context in which this cyclic barrier is executed using the natural * transformation `f`. @@ -59,57 +49,45 @@ abstract class CyclicBarrier[F[_]] { self => def mapK[G[_]](f: F ~> G): CyclicBarrier[G] = new CyclicBarrier[G] { def await: G[Unit] = f(self.await) - def remaining: G[Int] = f(self.remaining) - def awaiting: G[Int] = f(self.awaiting) } } object CyclicBarrier { - def apply[F[_]](n: Int)(implicit F: GenConcurrent[F, _]): F[CyclicBarrier[F]] = - if (n < 1) + def apply[F[_]](capacity: Int)(implicit F: GenConcurrent[F, _]): F[CyclicBarrier[F]] = { + if (capacity < 1) throw new IllegalArgumentException( - s"Cyclic barrier constructed with capacity $n. Must be > 0") - else - for { - state <- State.initial[F] - ref <- F.ref(state) - } yield new ConcurrentCyclicBarrier(n, ref) - - private[std] class ConcurrentCyclicBarrier[F[_]](capacity: Int, state: Ref[F, State[F]])( - implicit F: GenConcurrent[F, _]) - extends CyclicBarrier[F] { - - val await: F[Unit] = - F.deferred[Unit].flatMap { newSignal => - F.uncancelable { poll => - state.modify { - case State(awaiting, epoch, signal) => - if (awaiting < capacity - 1) { - val cleanup = state.update { s => - if (epoch == s.epoch) - //The cyclic barrier hasn't been reset since the cancelled fiber start to await - s.copy(awaiting = s.awaiting - 1) - else s - } - - val nextState = State(awaiting + 1, epoch, signal) - (nextState, poll(signal.get).onCancel(cleanup)) - } else (State(0, epoch + 1, newSignal), signal.complete(()).void) - }.flatten - } + s"Cyclic barrier constructed with capacity $capacity. Must be > 0") + + case class State(awaiting: Int, epoch: Long, unblock: Deferred[F, Unit]) + + F.deferred[Unit].map(State(capacity, 0, _)).flatMap(F.ref).map { state => + new CyclicBarrier[F] { + val await: F[Unit] = + F.deferred[Unit].flatMap { gate => + F.uncancelable { poll => + state.modify { + case State(awaiting, epoch, unblock) => + val awaitingNow = awaiting - 1 + + if (awaitingNow == 0) + State(capacity, epoch + 1, gate) -> unblock.complete(()).void + else { + val newState = State(awaitingNow, epoch, unblock) + // reincrement count if this await gets canceled, + // but only if the barrier hasn't reset in the meantime + val cleanup = state.update { s => + if (s.epoch == epoch) s.copy(awaiting = s.awaiting + 1) + else s + } + + newState -> poll(unblock.get).onCancel(cleanup) + } + + }.flatten + } + } } - - val remaining: F[Int] = state.get.map(s => capacity - s.awaiting) - - val awaiting: F[Int] = state.get.map(_.awaiting) - - } - - private[std] case class State[F[_]](awaiting: Int, epoch: Long, signal: Deferred[F, Unit]) - - private[std] object State { - def initial[F[_]](implicit F: GenConcurrent[F, _]): F[State[F]] = - F.deferred[Unit].map { signal => State(0, 0, signal) } + } } }