New AsyncMutex implementation #3562

Merged 14 commits on Apr 28, 2023
@@ -56,12 +56,7 @@ class MutexBenchmark {

@Benchmark
def happyPathConcurrent(): Unit = {
happyPathImpl(mutex = Mutex.concurrent)
}

@Benchmark
def happyPathAsync(): Unit = {
happyPathImpl(mutex = Mutex.async)
happyPathImpl(mutex = Mutex.apply)
}

private def highContentionImpl(mutex: IO[Mutex[IO]]): Unit = {
@@ -73,12 +68,7 @@ class MutexBenchmark {

@Benchmark
def highContentionConcurrent(): Unit = {
highContentionImpl(mutex = Mutex.concurrent)
}

@Benchmark
def highContentionAsync(): Unit = {
highContentionImpl(mutex = Mutex.async)
highContentionImpl(mutex = Mutex.apply)
}

private def cancellationImpl(mutex: IO[Mutex[IO]]): Unit = {
@@ -94,11 +84,6 @@ class MutexBenchmark {

@Benchmark
def cancellationConcurrent(): Unit = {
cancellationImpl(mutex = Mutex.concurrent)
}

@Benchmark
def cancellationAsync(): Unit = {
cancellationImpl(mutex = Mutex.async)
cancellationImpl(mutex = Mutex.apply)
}
}
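The bodies of happyPathImpl, highContentionImpl and cancellationImpl are collapsed in this view. As a rough orientation only, here is a minimal sketch of what a happy-path run over the new constructor could look like; the fiber and iteration counts, object name and method name are made up for illustration and are not the benchmark's real parameters.

```scala
import cats.effect.IO
import cats.effect.std.Mutex
import cats.effect.unsafe.implicits.global
import cats.syntax.all._

object HappyPathSketch {
  // Hypothetical knobs; the real JMH benchmark's parameters are not shown in this diff.
  val fibers = 10
  val iterations = 1000

  // Every fiber repeatedly acquires and releases the same mutex without holding
  // it for long, so contention stays low (the "happy path").
  def happyPath(mutex: IO[Mutex[IO]]): Unit =
    mutex
      .flatMap { m =>
        (1 to fibers).toList.parTraverse_(_ => m.lock.use_.replicateA_(iterations))
      }
      .unsafeRunSync()

  def main(args: Array[String]): Unit =
    happyPath(Mutex.apply[IO])
}
```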
4 changes: 2 additions & 2 deletions std/shared/src/main/scala/cats/effect/std/AtomicCell.scala
@@ -154,11 +154,11 @@ object AtomicCell {
}

private[effect] def async[F[_], A](init: A)(implicit F: Async[F]): F[AtomicCell[F, A]] =
Mutex.async[F].map(mutex => new AsyncImpl(init, mutex))
Mutex.apply[F].map(mutex => new AsyncImpl(init, mutex))

private[effect] def concurrent[F[_], A](init: A)(
implicit F: Concurrent[F]): F[AtomicCell[F, A]] =
(Ref.of[F, A](init), Mutex.concurrent[F]).mapN { (ref, m) => new ConcurrentImpl(ref, m) }
(Ref.of[F, A](init), Mutex.apply[F]).mapN { (ref, m) => new ConcurrentImpl(ref, m) }

private final class ConcurrentImpl[F[_], A](
ref: Ref[F, A],
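The concurrent AtomicCell constructor pairs a plain Ref with a Mutex. As an illustration of that pattern (GuardedCell and its methods are made-up names, not AtomicCell's actual internals), a cell can serialize effectful updates by holding the lock around each read-modify-write:

```scala
import cats.effect.IO
import cats.effect.kernel.Ref
import cats.effect.std.Mutex
import cats.syntax.all._

// Illustrative only: serializes effectful updates by holding the mutex
// for the whole read-modify-write of the underlying Ref.
final class GuardedCell[A](ref: Ref[IO, A], mutex: Mutex[IO]) {
  def get: IO[A] = ref.get

  // At most one evalUpdate runs at a time, even when `f` is a slow effect.
  def evalUpdate(f: A => IO[A]): IO[Unit] =
    mutex.lock.surround(ref.get.flatMap(f).flatMap(ref.set))
}

object GuardedCell {
  def of[A](init: A): IO[GuardedCell[A]] =
    (Ref.of[IO, A](init), Mutex[IO]).mapN(new GuardedCell(_, _))
}
```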
167 changes: 69 additions & 98 deletions std/shared/src/main/scala/cats/effect/std/Mutex.scala
@@ -21,10 +21,6 @@ package std
import cats.effect.kernel._
import cats.syntax.all._

import scala.annotation.tailrec

import java.util.concurrent.atomic.AtomicBoolean

/**
* A purely functional mutex.
*
@@ -68,114 +64,90 @@ object Mutex {
* Creates a new `Mutex`.
*/
def apply[F[_]](implicit F: Concurrent[F]): F[Mutex[F]] =
F match {
case ff: Async[F] =>
async[F](ff)

case _ =>
concurrent[F](F)
}

private[effect] def async[F[_]](implicit F: Async[F]): F[Mutex[F]] =
in[F, F](F, F)

private[effect] def concurrent[F[_]](implicit F: Concurrent[F]): F[Mutex[F]] =
Semaphore[F](n = 1).map(sem => new ConcurrentImpl[F](sem))
Ref
.of[F, ConcurrentImpl.LockQueue](
// Initialize the state with an already completed cell.
ConcurrentImpl.Empty
)
.map(state => new ConcurrentImpl[F](state))

/**
* Creates a new `Mutex`. Like `apply` but initializes state using another effect constructor.
*/
def in[F[_], G[_]](implicit F: Sync[F], G: Async[G]): F[Mutex[G]] =
F.delay(new AsyncImpl[G])

private final class ConcurrentImpl[F[_]](sem: Semaphore[F]) extends Mutex[F] {
override final val lock: Resource[F, Unit] =
sem.permit

override def mapK[G[_]](f: F ~> G)(implicit G: MonadCancel[G, _]): Mutex[G] =
new ConcurrentImpl(sem.mapK(f))
}

private final class AsyncImpl[F[_]](implicit F: Async[F]) extends Mutex[F] {
import AsyncImpl._

private[this] val locked = new AtomicBoolean(false)
private[this] val waiters = new UnsafeUnbounded[Either[Throwable, Boolean] => Unit]
private[this] val FailureSignal = cats.effect.std.FailureSignal // prefetch

private[this] val acquire: F[Unit] = F.uncancelable { poll =>
F.onCancel(
poll(F.asyncCheckAttempt[Boolean] { cb =>
F.delay {
if (locked.compareAndSet(false, true)) { // acquired
RightTrue
} else {
val cancel = waiters.put(cb)
if (locked.compareAndSet(false, true)) { // try again
cancel()
RightTrue
} else {
Left(Some(F.delay(cancel())))
}
}
}
}).flatMap { acquired =>
if (acquired) F.unit // home free
else poll(acquire) // wokened, but need to acquire
},
// If we were cancelled, that could mean
// a lost wakeup from `release`, so we
// wake up someone else instead of us;
// the worst that could happen is an
// unnecessary wakeup, which causes a
// waiter to go to the end of the queue:
F.delay(notifyOne())
Ref
.in[F, G, ConcurrentImpl.LockQueue](
// Initialize the state with an already completed cell.
ConcurrentImpl.Empty
)
}
.map(state => new ConcurrentImpl[G](state))

private final class ConcurrentImpl[F[_]](
state: Ref[F, ConcurrentImpl.LockQueue]
)(
implicit F: Concurrent[F]
) extends Mutex[F] {
// Acquires the Mutex.
private def acquire(poll: Poll[F]): F[ConcurrentImpl.Next[F]] =
ConcurrentImpl.LockQueueCell[F].flatMap { ourCell =>
// Atomically get the last cell in the queue,
// and put ourselves as the last one.
state.getAndSet(ourCell).flatMap { lastCell =>
// Then we check what the current cell is.
// There are two options:
// + Empty: Signaling that the mutex is free.
// + Next(cell): Which means there is someone ahead of us in the queue.
// Thus, wait for that cell to complete; and check again.
//
// Only the waiting process is cancelable.
// If we are cancelled while waiting,
// we then notify our waiter to wait for the cell ahead of us instead.
def loop(currentCell: ConcurrentImpl.LockQueue): F[ConcurrentImpl.Next[F]] =
if (currentCell eq ConcurrentImpl.Empty) F.pure(ourCell)
else {
F.onCancel(
poll(currentCell.asInstanceOf[ConcurrentImpl.Next[F]].get),
ourCell.complete(currentCell).void
).flatMap { nextCell => loop(currentCell = nextCell) }
}

private[this] val _release: F[Unit] = F.delay {
locked.set(false) // release
notifyOne() // try to wake someone
}
loop(currentCell = lastCell)
}
}

private[this] val release: Unit => F[Unit] = _ => _release
// Releases the Mutex.
private def release(thisCell: ConcurrentImpl.Next[F]): F[Unit] =
state.access.flatMap {
// If the current last cell in the queue is ours,
// then that means nobody is waiting for us.
// Thus, we can just reset the state to the Empty cell.
// Otherwise, we awake whoever is waiting for us.
case (lastCell, setter) =>
if (lastCell eq thisCell) setter(ConcurrentImpl.Empty)
else F.pure(false)
} flatMap {
case false => thisCell.complete(ConcurrentImpl.Empty).void
case true => F.unit
}

override final val lock: Resource[F, Unit] =
Resource.makeFull[F, Unit](poll => poll(acquire))(release)
Resource.makeFull[F, ConcurrentImpl.Next[F]](acquire)(release).void

override def mapK[G[_]](f: F ~> G)(implicit G: MonadCancel[G, _]): Mutex[G] =
new Mutex.TransformedMutex(this, f)

@tailrec
private[this] final def notifyOne(): Unit = {
val retry =
try {
val waiter = waiters.take()
if (waiter ne null) {
// wake the waiter; we don't
// pass `true`, so it has to
// go through the normal acquire,
// so its cancellation is handled
// properly:
waiter(RightFalse)
false
} else {
true
}
} catch {
case FailureSignal =>
false
}

if (retry) {
notifyOne()
}
}
}

private object AsyncImpl {
private val RightTrue = Right(true)
private val RightFalse = Right(false)
private object ConcurrentImpl {
// Represents a queue of waiters for the mutex.
private[Mutex] final type LockQueue = AnyRef
// Represents the first cell of the queue.
private[Mutex] final type Empty = LockQueue
private[Mutex] final val Empty: Empty = null
// Represents a cell in the queue of waiters.
private[Mutex] final type Next[F[_]] = Deferred[F, LockQueue]

private[Mutex] def LockQueueCell[F[_]](implicit F: Concurrent[F]): F[Next[F]] =
Deferred[F, LockQueue]
}

private final class TransformedMutex[F[_], G[_]](
@@ -189,5 +161,4 @@ object Mutex {
override def mapK[H[_]](f: G ~> H)(implicit H: MonadCancel[H, _]): Mutex[H] =
new Mutex.TransformedMutex(this, f)
}

}
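For orientation, a small usage sketch of the reworked Mutex (illustrative only, not part of this PR): several fibers contend on `lock.surround`, so their critical sections never overlap, and the lock is released on exit or cancellation because `lock` is a Resource.

```scala
import cats.effect.{IO, IOApp}
import cats.effect.std.Mutex
import cats.syntax.all._

// Illustrative usage only: ten fibers take the lock one at a time;
// `surround` acquires on entry and releases on exit (or cancellation).
object MutexDemo extends IOApp.Simple {
  val run: IO[Unit] =
    for {
      m <- Mutex[IO]
      ref <- IO.ref(List.empty[Int])
      _ <- (1 to 10).toList.parTraverse_ { i =>
             m.lock.surround(ref.update(i :: _))
           }
      res <- ref.get
      _ <- IO.println(res)
    } yield ()
}
```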
72 changes: 67 additions & 5 deletions tests/shared/src/test/scala/cats/effect/std/MutexSpec.scala
@@ -30,11 +30,7 @@ final class MutexSpec extends BaseSpec with DetectPlatform {
final override def executionTimeout = 2.minutes

"ConcurrentMutex" should {
tests(Mutex.concurrent[IO])
}

"AsyncMutex" should {
tests(Mutex.async[IO])
tests(Mutex.apply[IO])
}

"Mutex with dual constructors" should {
@@ -147,5 +143,71 @@ final class MutexSpec extends BaseSpec with DetectPlatform {

t.timeoutTo(executionTimeout - 1.second, IO(ko)) mustEqual (())
}

"handle multiple concurrent cancels during release" in real {
val t = mutex.flatMap { m =>
val task = for {
f1 <- m.lock.allocated
(_, f1Release) = f1
f2 <- m.lock.use_.start
_ <- IO.sleep(5.millis)
f3 <- m.lock.use_.start
_ <- IO.sleep(5.millis)
f4 <- m.lock.use_.start
_ <- IO.sleep(5.millis)
_ <- (f1Release, f2.cancel, f3.cancel).parTupled
_ <- f4.join
} yield ()

task.replicateA_(if (isJS || isNative) 5 else 1000)
}

t.timeoutTo(executionTimeout - 1.second, IO(ko)) mustEqual (())
}

"preserve waiters order (FIFO) on a non-race cancellation" in ticked { implicit ticker =>
val numbers = List.range(1, 10)
val p = (mutex, IO.ref(List.empty[Int])).flatMapN {
case (m, ref) =>
for {
f1 <- m.lock.allocated
(_, f1Release) = f1
f2 <- m.lock.use_.start
_ <- IO.sleep(1.millis)
t <- numbers.parTraverse_ { i =>
IO.sleep(i.millis) >>
m.lock.surround(ref.update(acc => i :: acc))
}.start
_ <- IO.sleep(100.millis)
_ <- f2.cancel
_ <- f1Release
_ <- t.join
r <- ref.get
} yield r.reverse
}

p must completeAs(numbers)
}

"cancellation must not corrupt Mutex" in ticked { implicit ticker =>
val p = mutex.flatMap { m =>
for {
f1 <- m.lock.allocated
(_, f1Release) = f1
f2 <- m.lock.use_.start
_ <- IO.sleep(1.millis)
f3 <- m.lock.use_.start
_ <- IO.sleep(1.millis)
f4 <- m.lock.use_.start
_ <- IO.sleep(1.millis)
_ <- f2.cancel
_ <- f3.cancel
_ <- f4.join
_ <- f1Release
} yield ()
}

p must nonTerminate
}
}
}