-
Notifications
You must be signed in to change notification settings - Fork 529
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cyclic Barrier follow up #1417
Cyclic Barrier follow up #1417
Changes from 13 commits
f84cde7
f20a2f6
efd03c5
3185b67
fa4b5b2
a6d1d7f
5c9a86a
ef66621
09fbc59
7e3cfb6
08067ce
7cf1d7c
f4102e4
3a0feac
501d568
2057722
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,20 +14,15 @@ | |
* limitations under the License. | ||
*/ | ||
|
||
/* | ||
* These tests have been inspired by and adapted from `monix-catnap`'s `ConcurrentQueueSuite`, available at | ||
* https://github.com/monix/monix/blob/series/3.x/monix-catnap/shared/src/test/scala/monix/catnap/ConcurrentQueueSuite.scala. | ||
*/ | ||
|
||
package cats.effect | ||
package std | ||
|
||
import cats.implicits._ | ||
import cats.arrow.FunctionK | ||
import org.specs2.specification.core.Fragments | ||
|
||
import scala.concurrent.duration._ | ||
import java.util.concurrent.TimeoutException | ||
|
||
import org.specs2.specification.core.Fragments | ||
import scala.reflect.ClassTag | ||
|
||
class CyclicBarrierSpec extends BaseSpec { | ||
|
||
|
@@ -38,106 +33,62 @@ class CyclicBarrierSpec extends BaseSpec { | |
CyclicBarrier.apply[IO](_).map(_.mapK(FunctionK.id))) | ||
} | ||
|
||
private def cyclicBarrierTests( | ||
name: String, | ||
constructor: Int => IO[CyclicBarrier[IO]]): Fragments = { | ||
s"$name - raise an exception when constructed with a negative capacity" in real { | ||
val test = IO.defer(constructor(-1)).attempt | ||
test.flatMap { res => | ||
implicit class Assertions[A](fa: IO[A]) { | ||
def mustFailWith[E <: Throwable: ClassTag] = | ||
fa.attempt.flatMap { res => | ||
IO { | ||
res must beLike { | ||
case Left(e) => e must haveClass[IllegalArgumentException] | ||
case Left(e) => e must haveClass[E] | ||
} | ||
} | ||
} | ||
|
||
def mustEqual(a: A) = fa.flatMap { res => IO(res must beEqualTo(a)) } | ||
} | ||
|
||
private def cyclicBarrierTests( | ||
name: String, | ||
newBarrier: Int => IO[CyclicBarrier[IO]]): Fragments = { | ||
s"$name - raise an exception when constructed with a negative capacity" in real { | ||
IO.defer(newBarrier(-1)).mustFailWith[IllegalArgumentException] | ||
} | ||
|
||
s"$name - raise an exception when constructed with zero capacity" in real { | ||
val test = IO.defer(constructor(0)).attempt | ||
test.flatMap { res => | ||
IO { | ||
res must beLike { | ||
case Left(e) => e must haveClass[IllegalArgumentException] | ||
} | ||
} | ||
} | ||
IO.defer(newBarrier(0)).mustFailWith[IllegalArgumentException] | ||
} | ||
|
||
s"$name - remaining when contructed" in real { | ||
for { | ||
cb <- constructor(5) | ||
awaiting <- cb.awaiting | ||
_ <- IO(awaiting must beEqualTo(0)) | ||
r <- cb.remaining | ||
res <- IO(r must beEqualTo(5)) | ||
} yield res | ||
s"$name - await is blocking" in ticked { implicit ticker => | ||
newBarrier(2).flatMap(_.await) must nonTerminate | ||
} | ||
|
||
s"$name - await releases all fibers" in real { | ||
for { | ||
cb <- constructor(2) | ||
f1 <- cb.await.start | ||
f2 <- cb.await.start | ||
r <- (f1.joinAndEmbedNever, f2.joinAndEmbedNever).tupled | ||
awaiting <- cb.awaiting | ||
_ <- IO(awaiting must beEqualTo(0)) | ||
res <- IO(r must beEqualTo(((), ()))) | ||
} yield res | ||
s"$name - await is cancelable" in ticked { implicit ticker => | ||
newBarrier(2).flatMap(_.await).timeoutTo(1.second, IO.unit) must completeAs(()) | ||
} | ||
|
||
s"$name - await is blocking" in real { | ||
for { | ||
cb <- constructor(2) | ||
r <- cb.await.timeout(5.millis).attempt | ||
res <- IO(r must beLike { | ||
case Left(e) => e must haveClass[TimeoutException] | ||
}) | ||
} yield res | ||
s"$name - await releases all fibers" in real { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. some tests use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The ones that use ticked need to assert nontermination or cancelation, and ticked let's you do that much more easily. I used real whenever I could though There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nontermination makes sense, but how come cancellation is an issue?
i thought this was a good candidate for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh, is it because cancellation is more or less nondeterministic on the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, and |
||
newBarrier(2).flatMap { barrier => | ||
(barrier.await, barrier.await).parTupled.void.mustEqual(()) | ||
} | ||
} | ||
|
||
s"$name - await is cancelable" in real { | ||
for { | ||
cb <- constructor(2) | ||
f <- cb.await.start | ||
_ <- IO.sleep(1.milli) | ||
_ <- f.cancel | ||
r <- f.join | ||
awaiting <- cb.awaiting | ||
_ <- IO(awaiting must beEqualTo(0)) | ||
res <- IO(r must beEqualTo(Outcome.Canceled())) | ||
} yield res | ||
s"$name - reset once full" in ticked { implicit ticker => | ||
newBarrier(2).flatMap { barrier => | ||
(barrier.await, barrier.await).parTupled >> | ||
barrier.await | ||
} must nonTerminate | ||
} | ||
|
||
s"$name - reset once full" in real { | ||
for { | ||
cb <- constructor(2) | ||
f1 <- cb.await.start | ||
f2 <- cb.await.start | ||
r <- (f1.joinAndEmbedNever, f2.joinAndEmbedNever).tupled | ||
_ <- IO(r must beEqualTo(((), ()))) | ||
//Should have reset at this point | ||
awaiting <- cb.awaiting | ||
_ <- IO(awaiting must beEqualTo(0)) | ||
r <- cb.await.timeout(5.millis).attempt | ||
res <- IO(r must beLike { | ||
case Left(e) => e must haveClass[TimeoutException] | ||
}) | ||
} yield res | ||
s"$name - clean up upon cancellation of await" in ticked { implicit ticker => | ||
newBarrier(2).flatMap { barrier => | ||
// This will time out, so count goes back to 2 | ||
barrier.await.timeoutTo(1.second, IO.unit) >> | ||
// Therefore count goes only down to 1 when this awaits, and will block again | ||
barrier.await | ||
} must nonTerminate | ||
} | ||
|
||
s"$name - clean up upon cancellation of await" in real { | ||
for { | ||
cb <- constructor(2) | ||
//This should time out and reduce the current capacity to 0 again | ||
_ <- cb.await.timeout(5.millis).attempt | ||
//Therefore the capacity should only be 1 when this awaits so will block again | ||
r <- cb.await.timeout(5.millis).attempt | ||
_ <- IO(r must beLike { | ||
case Left(e) => e must haveClass[TimeoutException] | ||
}) | ||
awaiting <- cb.awaiting | ||
res <- IO(awaiting must beEqualTo(0)) // | ||
} yield res | ||
s"$name - barrier of capacity 1 is a no op" in real { | ||
newBarrier(1).flatMap(_.await).mustEqual(()) | ||
} | ||
|
||
/* | ||
|
@@ -148,13 +99,21 @@ class CyclicBarrierSpec extends BaseSpec { | |
s"$name - race fiber cancel and barrier full" in real { | ||
val iterations = 100 | ||
|
||
val run = for { | ||
cb <- constructor(2) | ||
f <- cb.await.start | ||
_ <- IO.race(cb.await, f.cancel) | ||
awaiting <- cb.awaiting | ||
res <- IO(awaiting must beGreaterThanOrEqualTo(0)) | ||
} yield res | ||
val run = newBarrier(2) | ||
.flatMap { barrier => | ||
barrier.await.start.flatMap { fiber => | ||
barrier.await.race(fiber.cancel).flatMap { | ||
case Left(_) => | ||
// without the epoch check in CyclicBarrier, | ||
// a late cancelation would increment the count | ||
// after the barrier has already reset, | ||
// causing this code to never terminate (test times out) | ||
(barrier.await, barrier.await).parTupled.void | ||
case Right(_) => IO.unit | ||
} | ||
} | ||
} | ||
.mustEqual(()) | ||
|
||
List.fill(iterations)(run).reduce(_ >> _) | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ | |
package cats.effect.std | ||
|
||
import cats.~> | ||
import cats.effect.kernel.{Deferred, GenConcurrent, Ref} | ||
import cats.effect.kernel.{Deferred, GenConcurrent} | ||
import cats.effect.kernel.syntax.all._ | ||
import cats.syntax.all._ | ||
|
||
|
@@ -39,16 +39,6 @@ abstract class CyclicBarrier[F[_]] { self => | |
*/ | ||
def await: F[Unit] | ||
|
||
/* | ||
* The number of fibers required to trip the barrier | ||
*/ | ||
def remaining: F[Int] | ||
|
||
/* | ||
* The number of fibers currently awaiting | ||
*/ | ||
def awaiting: F[Int] | ||
|
||
/** | ||
* Modifies the context in which this cyclic barrier is executed using the natural | ||
* transformation `f`. | ||
|
@@ -59,57 +49,48 @@ abstract class CyclicBarrier[F[_]] { self => | |
def mapK[G[_]](f: F ~> G): CyclicBarrier[G] = | ||
new CyclicBarrier[G] { | ||
def await: G[Unit] = f(self.await) | ||
def remaining: G[Int] = f(self.remaining) | ||
def awaiting: G[Int] = f(self.awaiting) | ||
} | ||
|
||
} | ||
|
||
object CyclicBarrier { | ||
def apply[F[_]](n: Int)(implicit F: GenConcurrent[F, _]): F[CyclicBarrier[F]] = | ||
if (n < 1) | ||
def apply[F[_]](capacity: Int)(implicit F: GenConcurrent[F, _]): F[CyclicBarrier[F]] = { | ||
if (capacity < 1) | ||
throw new IllegalArgumentException( | ||
s"Cyclic barrier constructed with capacity $n. Must be > 0") | ||
else | ||
for { | ||
state <- State.initial[F] | ||
ref <- F.ref(state) | ||
} yield new ConcurrentCyclicBarrier(n, ref) | ||
|
||
private[std] class ConcurrentCyclicBarrier[F[_]](capacity: Int, state: Ref[F, State[F]])( | ||
implicit F: GenConcurrent[F, _]) | ||
extends CyclicBarrier[F] { | ||
|
||
val await: F[Unit] = | ||
F.deferred[Unit].flatMap { newSignal => | ||
F.uncancelable { poll => | ||
state.modify { | ||
case State(awaiting, epoch, signal) => | ||
if (awaiting < capacity - 1) { | ||
val cleanup = state.update { s => | ||
if (epoch == s.epoch) | ||
//The cyclic barrier hasn't been reset since the cancelled fiber start to await | ||
s.copy(awaiting = s.awaiting - 1) | ||
else s | ||
} | ||
|
||
val nextState = State(awaiting + 1, epoch, signal) | ||
(nextState, poll(signal.get).onCancel(cleanup)) | ||
} else (State(0, epoch + 1, newSignal), signal.complete(()).void) | ||
}.flatten | ||
s"Cyclic barrier constructed with capacity $capacity. Must be > 0") | ||
|
||
case class State[F[_]](awaiting: Int, epoch: Long, unblock: Deferred[F, Unit]) | ||
|
||
F.deferred[Unit] | ||
.map(gate => State(capacity,0, gate)) | ||
.flatMap(F.ref) | ||
.map { state => | ||
new CyclicBarrier[F] { | ||
val await: F[Unit] = | ||
F.deferred[Unit].flatMap { gate => | ||
F.uncancelable { poll => | ||
state.modify { | ||
case State(awaiting, epoch, unblock) => | ||
val awaitingNow = awaiting - 1 | ||
|
||
if (awaitingNow == 0) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice, this invariant is much clearer |
||
State(capacity, epoch + 1, gate) -> unblock.complete(()).void | ||
else { | ||
val newState = State(awaitingNow, epoch, unblock) | ||
// reincrement count if this await gets canceled, | ||
// but only if the barrier hasn't reset in the meantime | ||
val cleanup = state.update { s => | ||
if (s.epoch == epoch) s.copy(awaiting = s.awaiting + 1) | ||
else s | ||
} | ||
|
||
newState -> poll(unblock.get).onCancel(cleanup) | ||
} | ||
|
||
}.flatten | ||
} | ||
} | ||
} | ||
} | ||
|
||
val remaining: F[Int] = state.get.map(s => capacity - s.awaiting) | ||
|
||
val awaiting: F[Int] = state.get.map(_.awaiting) | ||
|
||
} | ||
|
||
private[std] case class State[F[_]](awaiting: Int, epoch: Long, signal: Deferred[F, Unit]) | ||
|
||
private[std] object State { | ||
def initial[F[_]](implicit F: GenConcurrent[F, _]): F[State[F]] = | ||
F.deferred[Unit].map { signal => State(0, 0, signal) } | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would be super useful in
BaseSpec
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll move it