Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Even faster async mutex #3409

Merged
merged 10 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ class MutexBenchmark {
var iterations: Int = _

private def happyPathImpl(mutex: IO[Mutex[IO]]): Unit = {
mutex
.flatMap { m => m.lock.use_.replicateA_(fibers) }
.replicateA_(iterations)
.unsafeRunSync()
mutex.flatMap { m => m.lock.use_.replicateA_(fibers * iterations) }.unsafeRunSync()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder what is the rationale for this change? And why only to the happy path?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I misunderstood the purpose of the benchmark, but we want to replicate many acquire/releases of the same mutex from the same fiber—we don't need to allocate a new mutex in each iteration, and "fibers" is not really accurate term. It's just iterations in the end.

Actually you are right, we can probably make a similar change to the other benchmarks.

}

@Benchmark
Expand Down
9 changes: 6 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -906,11 +906,14 @@ lazy val std = crossProject(JSPlatform, JVMPlatform, NativePlatform)
"cats.effect.std.Queue#DroppingQueue.onOfferNoCapacity"),
// introduced by #3346
// private stuff
ProblemFilters.exclude[MissingClassProblem](
"cats.effect.std.Mutex$Impl"),
ProblemFilters.exclude[MissingClassProblem]("cats.effect.std.Mutex$Impl"),
// introduced by #3347
// private stuff
ProblemFilters.exclude[MissingClassProblem]("cats.effect.std.AtomicCell$Impl")
ProblemFilters.exclude[MissingClassProblem]("cats.effect.std.AtomicCell$Impl"),
// introduced by #3409
// extracted UnsafeUnbounded private data structure
ProblemFilters.exclude[MissingClassProblem]("cats.effect.std.Queue$UnsafeUnbounded"),
ProblemFilters.exclude[MissingClassProblem]("cats.effect.std.Queue$UnsafeUnbounded$Cell")
)
)
.jsSettings(
Expand Down
120 changes: 43 additions & 77 deletions std/shared/src/main/scala/cats/effect/std/Mutex.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package std
import cats.effect.kernel._
import cats.syntax.all._

import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.atomic.AtomicBoolean

/**
* A purely functional mutex.
Expand Down Expand Up @@ -84,9 +84,7 @@ object Mutex {
* Creates a new `Mutex`. Like `apply` but initializes state using another effect constructor.
*/
def in[F[_], G[_]](implicit F: Sync[F], G: Async[G]): F[Mutex[G]] =
F.delay(
new AtomicReference[LockCell]()
).map(state => new AsyncImpl[G](state)(G))
F.delay(new AsyncImpl[G])

private final class ConcurrentImpl[F[_]](sem: Semaphore[F]) extends Mutex[F] {
override final val lock: Resource[F, Unit] =
Expand All @@ -96,92 +94,65 @@ object Mutex {
new ConcurrentImpl(sem.mapK(f))
}

private final class AsyncImpl[F[_]](state: AtomicReference[LockCell])(implicit F: Async[F])
extends Mutex[F] {
// Cancels a Fiber waiting for the Mutex.
private def cancel(thisCB: CB, thisCell: LockCell, previousCell: LockCell): F[Unit] =
F.delay {
// If we are canceled.
// First, we check if the state still contains ourselves,
// if that is the case, we swap it with the previousCell.
// This ensures any consequent attempt to acquire the Mutex
// will register its callback on the appropriate cell.
// Additionally, that confirms there is no Fiber
// currently waiting for us.
if (!state.compareAndSet(thisCell, previousCell)) {
// Otherwise,
// it means we have a Fiber waiting for us.
// Thus, we need to tell the previous cell
// to awake that Fiber instead.
var nextCB = thisCell.get()
while (nextCB eq null) {
// There is a tiny fraction of time when
// the next cell has acquired ourselves,
// but hasn't registered itself yet.
// Thus, we spin loop until that happens
nextCB = thisCell.get()
}
if (!previousCell.compareAndSet(thisCB, nextCB)) {
// However, in case the previous cell had already completed,
// then the Mutex is free and we can awake our waiting fiber.
if (nextCB ne null) nextCB.apply(Either.unit)
}
}
}
private final class AsyncImpl[F[_]](implicit F: Async[F]) extends Mutex[F] {
import AsyncImpl._

// Awaits until the Mutex is free.
private def await(thisCell: LockCell): F[Unit] =
F.asyncCheckAttempt[Unit] { thisCB =>
F.delay {
val previousCell = state.getAndSet(thisCell)
private[this] val locked = new AtomicBoolean(false)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we swap this AtomicBoolean for an AtomicInteger then do we basically have an AsyncSemaphore?

private[this] val waiters = new UnsafeUnbounded[Either[Throwable, Boolean] => Unit]

if (previousCell eq null) {
// If the previous cell was null,
// then the Mutex is free.
Either.unit
private[this] val acquire: F[Unit] = F
.asyncCheckAttempt[Boolean] { cb =>
F.delay {
if (locked.compareAndSet(false, true)) { // acquired
RightTrue
} else {
// Otherwise,
// we check again that the previous cell haven't been completed yet,
// if not we tell the previous cell to awake us when they finish.
if (!previousCell.compareAndSet(null, thisCB)) {
// If it was already completed,
// then the Mutex is free.
Either.unit
val cancel = waiters.put(cb)
if (locked.compareAndSet(false, true)) { // try again
cancel()
RightTrue
} else {
Left(Some(cancel(thisCB, thisCell, previousCell)))
Left(Some(F.delay(cancel())))
}
}
}
}

// Acquires the Mutex.
private def acquire(poll: Poll[F]): F[LockCell] =
F.delay(new AtomicReference[CB]()).flatMap { thisCell =>
poll(await(thisCell).map(_ => thisCell))
.flatMap { acquired =>
if (acquired) F.unit // home free
else acquire // wokened, but need to acquire
}

// Releases the Mutex.
private def release(thisCell: LockCell): F[Unit] =
F.delay {
// If the state still contains our own cell,
// then it means nobody was waiting for the Mutex,
// and thus it can be put on a free state again.
if (!state.compareAndSet(thisCell, null)) {
// Otherwise,
// our cell is probably not empty,
// we must awake whatever Fiber is waiting for us.
val nextCB = thisCell.getAndSet(Sentinel)
if (nextCB ne null) nextCB.apply(Either.unit)
}
private[this] val _release: F[Unit] = F.delay {
try { // look for a waiter
var waiter = waiters.take()
while (waiter eq null) waiter = waiters.take()
waiter(RightTrue) // pass the buck
} catch { // no waiter found
case FailureSignal =>
locked.set(false) // release
try {
var waiter = waiters.take()
while (waiter eq null) waiter = waiters.take()
waiter(RightFalse) // waken any new waiters
} catch {
Comment on lines +130 to +137
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's some fairness corruption here under contention, where an acquirer may cut-in-line of an acquirer that had placed itself in the queue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the FIFO semantics of the current Mutex something we would like to preserve? If so, should we be louder about it on the docs?

BTW, does the Semaphore based one guarantee that as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall it's still FIFO (or should be, we should add a test if we don't have already). It's just under contention it can be slightly corrupted—but under contention, who was really "first" anyway?

I think this was the long-running debate Daniel and Fabio had for the async queue :)

case FailureSignal => // do nothing
}
}
}

private[this] val release: Unit => F[Unit] = _ => _release

override final val lock: Resource[F, Unit] =
Resource.makeFull[F, LockCell](acquire)(release).void
Resource.makeFull[F, Unit](poll => poll(acquire))(release)

override def mapK[G[_]](f: F ~> G)(implicit G: MonadCancel[G, _]): Mutex[G] =
new Mutex.TransformedMutex(this, f)
}

private object AsyncImpl {
private val RightTrue = Right(true)
private val RightFalse = Right(false)
}

private final class TransformedMutex[F[_], G[_]](
underlying: Mutex[F],
f: F ~> G
Expand All @@ -194,9 +165,4 @@ object Mutex {
new Mutex.TransformedMutex(this, f)
}

private type CB = Either[Throwable, Unit] => Unit

private final val Sentinel: CB = _ => ()

private type LockCell = AtomicReference[CB]
}
95 changes: 1 addition & 94 deletions std/shared/src/main/scala/cats/effect/std/Queue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import scala.annotation.tailrec
import scala.collection.immutable.{Queue => ScalaQueue}
import scala.collection.mutable.ListBuffer

import java.util.concurrent.atomic.{AtomicLong, AtomicLongArray, AtomicReference}
import java.util.concurrent.atomic.{AtomicLong, AtomicLongArray}

/**
* A purely functional, concurrent data structure which allows insertion and retrieval of
Expand Down Expand Up @@ -492,8 +492,6 @@ object Queue {
}

private val EitherUnit: Either[Nothing, Unit] = Right(())
private val FailureSignal: Throwable = new RuntimeException
with scala.util.control.NoStackTrace

/*
* Does not correctly handle bound = 0 because take waiters are async[Unit]
Expand Down Expand Up @@ -1039,97 +1037,6 @@ object Queue {
((idx & Int.MaxValue) % bound).toInt
}

final class UnsafeUnbounded[A] {
private[this] val first = new AtomicReference[Cell]
private[this] val last = new AtomicReference[Cell]

def size(): Int = {
var current = first.get()
var count = 0
while (current != null) {
count += 1
current = current.get()
}
count
}

def put(data: A): () => Unit = {
val cell = new Cell(data)

val prevLast = last.getAndSet(cell)

if (prevLast eq null)
first.set(cell)
else
prevLast.set(cell)

cell
}

@tailrec
def take(): A = {
val taken = first.get()
if (taken ne null) {
val next = taken.get()
if (first.compareAndSet(taken, next)) { // WINNING
if ((next eq null) && !last.compareAndSet(taken, null)) {
// we emptied the first, but someone put at the same time
// in this case, they might have seen taken in the last slot
// at which point they would *not* fix up the first pointer
// instead of fixing first, they would have written into taken
// so we fix first for them. but we might be ahead, so we loop
// on taken.get() to wait for them to make it not-null

var next2 = taken.get()
while (next2 eq null) {
next2 = taken.get()
}

first.set(next2)
}

val ret = taken.data()
taken() // Attempt to clear out data we've consumed
ret
} else {
take() // We lost, try again
}
} else {
if (last.get() ne null) {
take() // Waiting for prevLast.set(cell), so recurse
} else {
throw FailureSignal
}
}
}

def debug(): String = {
val f = first.get()

if (f == null) {
"[]"
} else {
f.debug()
}
}

private final class Cell(private[this] final var _data: A)
extends AtomicReference[Cell]
with (() => Unit) {

def data(): A = _data

final override def apply(): Unit = {
_data = null.asInstanceOf[A] // You want a lazySet here
}

def debug(): String = {
val tail = get()
s"${_data} -> ${if (tail == null) "[]" else tail.debug()}"
}
}
}

implicit def catsInvariantForQueue[F[_]: Functor]: Invariant[Queue[F, *]] =
new Invariant[Queue[F, *]] {
override def imap[A, B](fa: Queue[F, A])(f: A => B)(g: B => A): Queue[F, B] =
Expand Down
Loading