Skip to content

Commit

Permalink
Merge pull request #1417 from SystemFw/cyclic-barrier-2
Browse files Browse the repository at this point in the history
Cyclic Barrier follow up
  • Loading branch information
djspiewak authored Nov 14, 2020
2 parents 07eb53e + 2057722 commit c7f8046
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 153 deletions.
17 changes: 17 additions & 0 deletions core/shared/src/test/scala/cats/effect/Runners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ import scala.concurrent.{
TimeoutException
}
import scala.concurrent.duration._
import scala.reflect.ClassTag
import scala.util.Try

import java.io.{ByteArrayOutputStream, PrintStream}
Expand Down Expand Up @@ -272,6 +273,22 @@ trait Runners extends SpecificationLike with RunnersPlatform { outer =>
(oc eqv expected, s"${oc.show} !== ${expected.show}")
}

// useful for tests in the `real` context
implicit class Assertions[A](fa: IO[A]) {
def mustFailWith[E <: Throwable: ClassTag] =
fa.attempt.flatMap { res =>
IO {
res must beLike {
case Left(e) => e must haveClass[E]
}
}
}

def mustEqual(a: A) = fa.flatMap { res => IO(res must beEqualTo(a)) }
}



def unsafeRun[A](ioa: IO[A])(implicit ticker: Ticker): Outcome[Option, Throwable, A] =
try {
var results: Outcome[Option, Throwable, A] = Outcome.Succeeded(None)
Expand Down
139 changes: 42 additions & 97 deletions core/shared/src/test/scala/cats/effect/std/CyclicBarrierSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,14 @@
* limitations under the License.
*/

/*
* These tests have been inspired by and adapted from `monix-catnap`'s `ConcurrentQueueSuite`, available at
* https://github.com/monix/monix/blob/series/3.x/monix-catnap/shared/src/test/scala/monix/catnap/ConcurrentQueueSuite.scala.
*/

package cats.effect
package std

import cats.implicits._
import cats.arrow.FunctionK
import org.specs2.specification.core.Fragments

import scala.concurrent.duration._
import java.util.concurrent.TimeoutException

import org.specs2.specification.core.Fragments

class CyclicBarrierSpec extends BaseSpec {

Expand All @@ -40,104 +34,47 @@ class CyclicBarrierSpec extends BaseSpec {

private def cyclicBarrierTests(
name: String,
constructor: Int => IO[CyclicBarrier[IO]]): Fragments = {
newBarrier: Int => IO[CyclicBarrier[IO]]): Fragments = {
s"$name - raise an exception when constructed with a negative capacity" in real {
val test = IO.defer(constructor(-1)).attempt
test.flatMap { res =>
IO {
res must beLike {
case Left(e) => e must haveClass[IllegalArgumentException]
}
}
}
IO.defer(newBarrier(-1)).mustFailWith[IllegalArgumentException]
}

s"$name - raise an exception when constructed with zero capacity" in real {
val test = IO.defer(constructor(0)).attempt
test.flatMap { res =>
IO {
res must beLike {
case Left(e) => e must haveClass[IllegalArgumentException]
}
}
}
IO.defer(newBarrier(0)).mustFailWith[IllegalArgumentException]
}

s"$name - remaining when contructed" in real {
for {
cb <- constructor(5)
awaiting <- cb.awaiting
_ <- IO(awaiting must beEqualTo(0))
r <- cb.remaining
res <- IO(r must beEqualTo(5))
} yield res
s"$name - await is blocking" in ticked { implicit ticker =>
newBarrier(2).flatMap(_.await) must nonTerminate
}

s"$name - await releases all fibers" in real {
for {
cb <- constructor(2)
f1 <- cb.await.start
f2 <- cb.await.start
r <- (f1.joinAndEmbedNever, f2.joinAndEmbedNever).tupled
awaiting <- cb.awaiting
_ <- IO(awaiting must beEqualTo(0))
res <- IO(r must beEqualTo(((), ())))
} yield res
s"$name - await is cancelable" in ticked { implicit ticker =>
newBarrier(2).flatMap(_.await).timeoutTo(1.second, IO.unit) must completeAs(())
}

s"$name - await is blocking" in real {
for {
cb <- constructor(2)
r <- cb.await.timeout(5.millis).attempt
res <- IO(r must beLike {
case Left(e) => e must haveClass[TimeoutException]
})
} yield res
s"$name - await releases all fibers" in real {
newBarrier(2).flatMap { barrier =>
(barrier.await, barrier.await).parTupled.void.mustEqual(())
}
}

s"$name - await is cancelable" in real {
for {
cb <- constructor(2)
f <- cb.await.start
_ <- IO.sleep(1.milli)
_ <- f.cancel
r <- f.join
awaiting <- cb.awaiting
_ <- IO(awaiting must beEqualTo(0))
res <- IO(r must beEqualTo(Outcome.Canceled()))
} yield res
s"$name - reset once full" in ticked { implicit ticker =>
newBarrier(2).flatMap { barrier =>
(barrier.await, barrier.await).parTupled >>
barrier.await
} must nonTerminate
}

s"$name - reset once full" in real {
for {
cb <- constructor(2)
f1 <- cb.await.start
f2 <- cb.await.start
r <- (f1.joinAndEmbedNever, f2.joinAndEmbedNever).tupled
_ <- IO(r must beEqualTo(((), ())))
//Should have reset at this point
awaiting <- cb.awaiting
_ <- IO(awaiting must beEqualTo(0))
r <- cb.await.timeout(5.millis).attempt
res <- IO(r must beLike {
case Left(e) => e must haveClass[TimeoutException]
})
} yield res
s"$name - clean up upon cancellation of await" in ticked { implicit ticker =>
newBarrier(2).flatMap { barrier =>
// This will time out, so count goes back to 2
barrier.await.timeoutTo(1.second, IO.unit) >>
// Therefore count goes only down to 1 when this awaits, and will block again
barrier.await
} must nonTerminate
}

s"$name - clean up upon cancellation of await" in real {
for {
cb <- constructor(2)
//This should time out and reduce the current capacity to 0 again
_ <- cb.await.timeout(5.millis).attempt
//Therefore the capacity should only be 1 when this awaits so will block again
r <- cb.await.timeout(5.millis).attempt
_ <- IO(r must beLike {
case Left(e) => e must haveClass[TimeoutException]
})
awaiting <- cb.awaiting
res <- IO(awaiting must beEqualTo(0)) //
} yield res
s"$name - barrier of capacity 1 is a no op" in real {
newBarrier(1).flatMap(_.await).mustEqual(())
}

/*
Expand All @@ -148,13 +85,21 @@ class CyclicBarrierSpec extends BaseSpec {
s"$name - race fiber cancel and barrier full" in real {
val iterations = 100

val run = for {
cb <- constructor(2)
f <- cb.await.start
_ <- IO.race(cb.await, f.cancel)
awaiting <- cb.awaiting
res <- IO(awaiting must beGreaterThanOrEqualTo(0))
} yield res
val run = newBarrier(2)
.flatMap { barrier =>
barrier.await.start.flatMap { fiber =>
barrier.await.race(fiber.cancel).flatMap {
case Left(_) =>
// without the epoch check in CyclicBarrier,
// a late cancelation would increment the count
// after the barrier has already reset,
// causing this code to never terminate (test times out)
(barrier.await, barrier.await).parTupled.void
case Right(_) => IO.unit
}
}
}
.mustEqual(())

List.fill(iterations)(run).reduce(_ >> _)
}
Expand Down
90 changes: 34 additions & 56 deletions std/shared/src/main/scala/cats/effect/std/CyclicBarrier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package cats.effect.std

import cats.~>
import cats.effect.kernel.{Deferred, GenConcurrent, Ref}
import cats.effect.kernel.{Deferred, GenConcurrent}
import cats.effect.kernel.syntax.all._
import cats.syntax.all._

Expand All @@ -39,16 +39,6 @@ abstract class CyclicBarrier[F[_]] { self =>
*/
def await: F[Unit]

/*
* The number of fibers required to trip the barrier
*/
def remaining: F[Int]

/*
* The number of fibers currently awaiting
*/
def awaiting: F[Int]

/**
* Modifies the context in which this cyclic barrier is executed using the natural
* transformation `f`.
Expand All @@ -59,57 +49,45 @@ abstract class CyclicBarrier[F[_]] { self =>
def mapK[G[_]](f: F ~> G): CyclicBarrier[G] =
new CyclicBarrier[G] {
def await: G[Unit] = f(self.await)
def remaining: G[Int] = f(self.remaining)
def awaiting: G[Int] = f(self.awaiting)
}

}

object CyclicBarrier {
def apply[F[_]](n: Int)(implicit F: GenConcurrent[F, _]): F[CyclicBarrier[F]] =
if (n < 1)
def apply[F[_]](capacity: Int)(implicit F: GenConcurrent[F, _]): F[CyclicBarrier[F]] = {
if (capacity < 1)
throw new IllegalArgumentException(
s"Cyclic barrier constructed with capacity $n. Must be > 0")
else
for {
state <- State.initial[F]
ref <- F.ref(state)
} yield new ConcurrentCyclicBarrier(n, ref)

private[std] class ConcurrentCyclicBarrier[F[_]](capacity: Int, state: Ref[F, State[F]])(
implicit F: GenConcurrent[F, _])
extends CyclicBarrier[F] {

val await: F[Unit] =
F.deferred[Unit].flatMap { newSignal =>
F.uncancelable { poll =>
state.modify {
case State(awaiting, epoch, signal) =>
if (awaiting < capacity - 1) {
val cleanup = state.update { s =>
if (epoch == s.epoch)
//The cyclic barrier hasn't been reset since the cancelled fiber start to await
s.copy(awaiting = s.awaiting - 1)
else s
}

val nextState = State(awaiting + 1, epoch, signal)
(nextState, poll(signal.get).onCancel(cleanup))
} else (State(0, epoch + 1, newSignal), signal.complete(()).void)
}.flatten
}
s"Cyclic barrier constructed with capacity $capacity. Must be > 0")

case class State(awaiting: Int, epoch: Long, unblock: Deferred[F, Unit])

F.deferred[Unit].map(State(capacity, 0, _)).flatMap(F.ref).map { state =>
new CyclicBarrier[F] {
val await: F[Unit] =
F.deferred[Unit].flatMap { gate =>
F.uncancelable { poll =>
state.modify {
case State(awaiting, epoch, unblock) =>
val awaitingNow = awaiting - 1

if (awaitingNow == 0)
State(capacity, epoch + 1, gate) -> unblock.complete(()).void
else {
val newState = State(awaitingNow, epoch, unblock)
// reincrement count if this await gets canceled,
// but only if the barrier hasn't reset in the meantime
val cleanup = state.update { s =>
if (s.epoch == epoch) s.copy(awaiting = s.awaiting + 1)
else s
}

newState -> poll(unblock.get).onCancel(cleanup)
}

}.flatten
}
}
}

val remaining: F[Int] = state.get.map(s => capacity - s.awaiting)

val awaiting: F[Int] = state.get.map(_.awaiting)

}

private[std] case class State[F[_]](awaiting: Int, epoch: Long, signal: Deferred[F, Unit])

private[std] object State {
def initial[F[_]](implicit F: GenConcurrent[F, _]): F[State[F]] =
F.deferred[Unit].map { signal => State(0, 0, signal) }
}
}
}

0 comments on commit c7f8046

Please sign in to comment.