Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cyclic Barrier follow up #1417

Merged
merged 16 commits into from
Nov 14, 2020
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 54 additions & 95 deletions core/shared/src/test/scala/cats/effect/std/CyclicBarrierSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,15 @@
* limitations under the License.
*/

/*
* These tests have been inspired by and adapted from `monix-catnap`'s `ConcurrentQueueSuite`, available at
* https://github.com/monix/monix/blob/series/3.x/monix-catnap/shared/src/test/scala/monix/catnap/ConcurrentQueueSuite.scala.
*/

package cats.effect
package std

import cats.implicits._
import cats.arrow.FunctionK
import org.specs2.specification.core.Fragments

import scala.concurrent.duration._
import java.util.concurrent.TimeoutException

import org.specs2.specification.core.Fragments
import scala.reflect.ClassTag

class CyclicBarrierSpec extends BaseSpec {

Expand All @@ -38,106 +33,62 @@ class CyclicBarrierSpec extends BaseSpec {
CyclicBarrier.apply[IO](_).map(_.mapK(FunctionK.id)))
}

private def cyclicBarrierTests(
name: String,
constructor: Int => IO[CyclicBarrier[IO]]): Fragments = {
s"$name - raise an exception when constructed with a negative capacity" in real {
val test = IO.defer(constructor(-1)).attempt
test.flatMap { res =>
implicit class Assertions[A](fa: IO[A]) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this would be super useful in BaseSpec

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll move it

def mustFailWith[E <: Throwable: ClassTag] =
fa.attempt.flatMap { res =>
IO {
res must beLike {
case Left(e) => e must haveClass[IllegalArgumentException]
case Left(e) => e must haveClass[E]
}
}
}

def mustEqual(a: A) = fa.flatMap { res => IO(res must beEqualTo(a)) }
}

private def cyclicBarrierTests(
name: String,
newBarrier: Int => IO[CyclicBarrier[IO]]): Fragments = {
s"$name - raise an exception when constructed with a negative capacity" in real {
IO.defer(newBarrier(-1)).mustFailWith[IllegalArgumentException]
}

s"$name - raise an exception when constructed with zero capacity" in real {
val test = IO.defer(constructor(0)).attempt
test.flatMap { res =>
IO {
res must beLike {
case Left(e) => e must haveClass[IllegalArgumentException]
}
}
}
IO.defer(newBarrier(0)).mustFailWith[IllegalArgumentException]
}

s"$name - remaining when contructed" in real {
for {
cb <- constructor(5)
awaiting <- cb.awaiting
_ <- IO(awaiting must beEqualTo(0))
r <- cb.remaining
res <- IO(r must beEqualTo(5))
} yield res
s"$name - await is blocking" in ticked { implicit ticker =>
newBarrier(2).flatMap(_.await) must nonTerminate
}

s"$name - await releases all fibers" in real {
for {
cb <- constructor(2)
f1 <- cb.await.start
f2 <- cb.await.start
r <- (f1.joinAndEmbedNever, f2.joinAndEmbedNever).tupled
awaiting <- cb.awaiting
_ <- IO(awaiting must beEqualTo(0))
res <- IO(r must beEqualTo(((), ())))
} yield res
s"$name - await is cancelable" in ticked { implicit ticker =>
newBarrier(2).flatMap(_.await).timeoutTo(1.second, IO.unit) must completeAs(())
}

s"$name - await is blocking" in real {
for {
cb <- constructor(2)
r <- cb.await.timeout(5.millis).attempt
res <- IO(r must beLike {
case Left(e) => e must haveClass[TimeoutException]
})
} yield res
s"$name - await releases all fibers" in real {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

some tests use real and others use ticked, but i can't really tell why :)

Copy link
Contributor Author

@SystemFw SystemFw Nov 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ones that use ticked need to assert nontermination or cancelation, and ticked let's you do that much more easily. I used real whenever I could though

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nontermination makes sense, but how come cancellation is an issue?

s"$name - await is cancelable" in ticked { implicit ticker =>
      for {	      newBarrier(2).flatMap(_.await).timeoutTo(1.second, IO.unit) must completeAs(())

i thought this was a good candidate for real but i'm probably missing something about how TC works

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, is it because cancellation is more or less nondeterministic on the sleep when using real?

Copy link
Contributor Author

@SystemFw SystemFw Nov 14, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, and ticked is deterministic and fast. With real you tradeoff slow test vs predictable test. Also these tests don't have any of the things that make ticked deadlock-prone (lots of fibers depending on each other) and untrustworthy (real memory barrier things, it's just Ref + Deferred + uncancelable)

newBarrier(2).flatMap { barrier =>
(barrier.await, barrier.await).parTupled.void.mustEqual(())
}
}

s"$name - await is cancelable" in real {
for {
cb <- constructor(2)
f <- cb.await.start
_ <- IO.sleep(1.milli)
_ <- f.cancel
r <- f.join
awaiting <- cb.awaiting
_ <- IO(awaiting must beEqualTo(0))
res <- IO(r must beEqualTo(Outcome.Canceled()))
} yield res
s"$name - reset once full" in ticked { implicit ticker =>
newBarrier(2).flatMap { barrier =>
(barrier.await, barrier.await).parTupled >>
barrier.await
} must nonTerminate
}

s"$name - reset once full" in real {
for {
cb <- constructor(2)
f1 <- cb.await.start
f2 <- cb.await.start
r <- (f1.joinAndEmbedNever, f2.joinAndEmbedNever).tupled
_ <- IO(r must beEqualTo(((), ())))
//Should have reset at this point
awaiting <- cb.awaiting
_ <- IO(awaiting must beEqualTo(0))
r <- cb.await.timeout(5.millis).attempt
res <- IO(r must beLike {
case Left(e) => e must haveClass[TimeoutException]
})
} yield res
s"$name - clean up upon cancellation of await" in ticked { implicit ticker =>
newBarrier(2).flatMap { barrier =>
// This will time out, so count goes back to 2
barrier.await.timeoutTo(1.second, IO.unit) >>
// Therefore count goes only down to 1 when this awaits, and will block again
barrier.await
} must nonTerminate
}

s"$name - clean up upon cancellation of await" in real {
for {
cb <- constructor(2)
//This should time out and reduce the current capacity to 0 again
_ <- cb.await.timeout(5.millis).attempt
//Therefore the capacity should only be 1 when this awaits so will block again
r <- cb.await.timeout(5.millis).attempt
_ <- IO(r must beLike {
case Left(e) => e must haveClass[TimeoutException]
})
awaiting <- cb.awaiting
res <- IO(awaiting must beEqualTo(0)) //
} yield res
s"$name - barrier of capacity 1 is a no op" in real {
newBarrier(1).flatMap(_.await).mustEqual(())
}

/*
Expand All @@ -148,13 +99,21 @@ class CyclicBarrierSpec extends BaseSpec {
s"$name - race fiber cancel and barrier full" in real {
val iterations = 100

val run = for {
cb <- constructor(2)
f <- cb.await.start
_ <- IO.race(cb.await, f.cancel)
awaiting <- cb.awaiting
res <- IO(awaiting must beGreaterThanOrEqualTo(0))
} yield res
val run = newBarrier(2)
.flatMap { barrier =>
barrier.await.start.flatMap { fiber =>
barrier.await.race(fiber.cancel).flatMap {
case Left(_) =>
// without the epoch check in CyclicBarrier,
// a late cancelation would increment the count
// after the barrier has already reset,
// causing this code to never terminate (test times out)
(barrier.await, barrier.await).parTupled.void
case Right(_) => IO.unit
}
}
}
.mustEqual(())

List.fill(iterations)(run).reduce(_ >> _)
}
Expand Down
91 changes: 36 additions & 55 deletions std/shared/src/main/scala/cats/effect/std/CyclicBarrier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package cats.effect.std

import cats.~>
import cats.effect.kernel.{Deferred, GenConcurrent, Ref}
import cats.effect.kernel.{Deferred, GenConcurrent}
import cats.effect.kernel.syntax.all._
import cats.syntax.all._

Expand All @@ -39,16 +39,6 @@ abstract class CyclicBarrier[F[_]] { self =>
*/
def await: F[Unit]

/*
* The number of fibers required to trip the barrier
*/
def remaining: F[Int]

/*
* The number of fibers currently awaiting
*/
def awaiting: F[Int]

/**
* Modifies the context in which this cyclic barrier is executed using the natural
* transformation `f`.
Expand All @@ -59,57 +49,48 @@ abstract class CyclicBarrier[F[_]] { self =>
def mapK[G[_]](f: F ~> G): CyclicBarrier[G] =
new CyclicBarrier[G] {
def await: G[Unit] = f(self.await)
def remaining: G[Int] = f(self.remaining)
def awaiting: G[Int] = f(self.awaiting)
}

}

object CyclicBarrier {
def apply[F[_]](n: Int)(implicit F: GenConcurrent[F, _]): F[CyclicBarrier[F]] =
if (n < 1)
def apply[F[_]](capacity: Int)(implicit F: GenConcurrent[F, _]): F[CyclicBarrier[F]] = {
if (capacity < 1)
throw new IllegalArgumentException(
s"Cyclic barrier constructed with capacity $n. Must be > 0")
else
for {
state <- State.initial[F]
ref <- F.ref(state)
} yield new ConcurrentCyclicBarrier(n, ref)

private[std] class ConcurrentCyclicBarrier[F[_]](capacity: Int, state: Ref[F, State[F]])(
implicit F: GenConcurrent[F, _])
extends CyclicBarrier[F] {

val await: F[Unit] =
F.deferred[Unit].flatMap { newSignal =>
F.uncancelable { poll =>
state.modify {
case State(awaiting, epoch, signal) =>
if (awaiting < capacity - 1) {
val cleanup = state.update { s =>
if (epoch == s.epoch)
//The cyclic barrier hasn't been reset since the cancelled fiber start to await
s.copy(awaiting = s.awaiting - 1)
else s
}

val nextState = State(awaiting + 1, epoch, signal)
(nextState, poll(signal.get).onCancel(cleanup))
} else (State(0, epoch + 1, newSignal), signal.complete(()).void)
}.flatten
s"Cyclic barrier constructed with capacity $capacity. Must be > 0")

case class State[F[_]](awaiting: Int, epoch: Long, unblock: Deferred[F, Unit])

F.deferred[Unit]
.map(gate => State(capacity,0, gate))
.flatMap(F.ref)
.map { state =>
new CyclicBarrier[F] {
val await: F[Unit] =
F.deferred[Unit].flatMap { gate =>
F.uncancelable { poll =>
state.modify {
case State(awaiting, epoch, unblock) =>
val awaitingNow = awaiting - 1

if (awaitingNow == 0)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, this invariant is much clearer

State(capacity, epoch + 1, gate) -> unblock.complete(()).void
else {
val newState = State(awaitingNow, epoch, unblock)
// reincrement count if this await gets canceled,
// but only if the barrier hasn't reset in the meantime
val cleanup = state.update { s =>
if (s.epoch == epoch) s.copy(awaiting = s.awaiting + 1)
else s
}

newState -> poll(unblock.get).onCancel(cleanup)
}

}.flatten
}
}
}
}

val remaining: F[Int] = state.get.map(s => capacity - s.awaiting)

val awaiting: F[Int] = state.get.map(_.awaiting)

}

private[std] case class State[F[_]](awaiting: Int, epoch: Long, signal: Deferred[F, Unit])

private[std] object State {
def initial[F[_]](implicit F: GenConcurrent[F, _]): F[State[F]] =
F.deferred[Unit].map { signal => State(0, 0, signal) }
}
}