Merge pull request #2264 from vasilmkd/batched-starvation
Fix batched queue starvation
djspiewak authored Aug 27, 2021
2 parents 0a9f113 + 701eb9f commit 0b12f95
Showing 5 changed files with 137 additions and 11 deletions.
74 changes: 74 additions & 0 deletions core/jvm/src/main/scala/cats/effect/unsafe/LocalQueue.scala
@@ -644,6 +644,80 @@ private final class LocalQueue {
null
}

/**
* Steals a batch of enqueued fibers and transfers the whole batch to the
* batched queue.
*
* This method is called by the runtime to restore fairness between fibers
* in the local queue and fibers on the overflow and batched queues. Every
* few iterations, the overflow and batched queues are checked for fibers
* and those fibers are executed. In the case of the batched queue, a batch
* of fibers might be obtained which cannot fully fit into the local queue
* due to insufficient capacity. In that case, this method is called to
* drain one full batch of fibers, which in turn creates space for the
* fibers arriving from the batched queue.
*
* Conceptually, this method is identical to [[LocalQueue#dequeue]], with
* the main difference being that the `head` of the queue is moved forward
* by the size of a full batch, thus securing all of those fibers at once
* and transferring them to the batched queue.
*
* @note Can '''only''' be correctly called by the owner [[WorkerThread]].
*
* @param batched the batched queue to transfer a batch of fibers into
* @param random a reference to an uncontended source of randomness, to be
* passed along to the striped concurrent queues when executing
* their enqueue operations
*/
def drainBatch(batched: ScalQueue[Array[IOFiber[_]]], random: ThreadLocalRandom): Unit = {
// A plain, unsynchronized load of the tail of the local queue.
val tl = tail

while (true) {
// A load of the head of the queue using `acquire` semantics.
val hd = head.get()

val real = lsb(hd)

if (tl == real) {
// The tail and the "real" value of the head are equal. The queue is
// empty. There is nothing more to be done.
return
}

// Move the "real" value of the head by the size of a batch.
val newReal = unsignedShortAddition(real, OverflowBatchSize)

// Make sure to preserve the "steal" tag in the presence of a concurrent
// stealer. Otherwise, move the "steal" tag along with the "real" value.
val steal = msb(hd)
val newHd = if (steal == real) pack(newReal, newReal) else pack(steal, newReal)

if (head.compareAndSet(hd, newHd)) {
// The head has been successfully moved forward and a batch of fibers
// secured. Proceed to null out the references to the fibers and
// transfer them to the batch.
val batch = new Array[IOFiber[_]](OverflowBatchSize)
var i = 0

while (i < OverflowBatchSize) {
val idx = index(real + i)
val f = buffer(idx)
buffer(idx) = null
batch(i) = f
i += 1
}

// The fibers have been transferred, enqueue the whole batch on the
// batched queue.
batchedSpilloverCount += OverflowBatchSize
tailPublisher.lazySet(tl)
batched.offer(batch, random)
return
}
}
}

/**
* Steals all enqueued fibers and transfers them to the provided array.
*
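For context: `lsb`, `msb`, `pack` and `unsignedShortAddition` used by `drainBatch` are pre-existing private helpers of `LocalQueue` and are not part of this diff. A minimal sketch of what they compute, assuming the encoding described in the comments above, where `head` packs two unsigned 16-bit values into a single `Int` (the "steal" tag in the high half, the "real" head in the low half):

    // Sketch only; these shapes are assumptions based on the comments above.
    def lsb(n: Int): Int = n & 0xffff   // extract the "real" head
    def msb(n: Int): Int = n >>> 16     // extract the "steal" tag
    def pack(msb: Int, lsb: Int): Int = (msb << 16) | lsb
    def unsignedShortAddition(n: Int, inc: Int): Int = (n + inc) & 0xffff // wraps at 2^16
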
49 changes: 42 additions & 7 deletions core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
@@ -153,8 +153,8 @@ private final class WorkerThread(
*
* 0: To increase the fairness towards fibers scheduled by threads which
* are external to the `WorkStealingThreadPool`, every
* `OverflowQueueTicks` number of iterations, the overflow queue takes
* precedence over the local queue.
* `OverflowQueueTicks` number of iterations, the overflow and batched
* queues take precedence over the local queue.
*
* If a fiber is successfully dequeued from the overflow queue, it will
* be executed. The `WorkerThread` unconditionally transitions to
@@ -232,6 +232,8 @@
*/
var state = 0

var fairness: Int = 0

def parkLoop(): Unit = {
var cont = true
while (cont && !isInterrupted()) {
@@ -246,12 +248,45 @@
while (!isInterrupted()) {
((state & OverflowQueueTicksMask): @switch) match {
case 0 =>
// Dequeue a fiber from the overflow queue.
val fiber = overflow.poll(rnd)
if (fiber ne null) {
// Run the fiber.
fiber.run()
// Alternate between checking the overflow and batched queues with a
// 2:1 ratio in favor of the overflow queue, for now.
(fairness: @switch) match {
case 0 =>
// Dequeue a fiber from the overflow queue.
val fiber = overflow.poll(rnd)
if (fiber ne null) {
// Run the fiber.
fiber.run()
}
fairness = 1

case 1 =>
// Look into the batched queue for a batch of fibers.
val batch = batched.poll(rnd)
if (batch ne null) {
if (queue.size() > HalfLocalQueueCapacity) {
// Make room for the batch if the local queue cannot
// accommodate the batch as is.
queue.drainBatch(batched, rnd)
}

// Enqueue the batch at the back of the local queue and execute
// the first fiber.
val fiber = queue.enqueueBatch(batch)
fiber.run()
}
fairness = 2

case 2 =>
// Dequeue a fiber from the overflow queue.
val fiber = overflow.poll(rnd)
if (fiber ne null) {
// Run the fiber.
fiber.run()
}
fairness = 0
}

// Transition to executing fibers from the local queue.
state = 7

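The `fairness` counter introduced above cycles 0 → 1 → 2 → 0, so across every three external-check ticks the overflow queue is polled twice and the batched queue once. A minimal standalone sketch of that alternation (the two poll operations are illustrative stand-ins, not the runtime's API):

    // Sketch only: models the 2:1 overflow/batched polling cycle.
    final class FairnessCycle(pollOverflow: () => Unit, pollBatched: () => Unit) {
      private var fairness: Int = 0
      def tick(): Unit = fairness match {
        case 0 => pollOverflow(); fairness = 1
        case 1 => pollBatched(); fairness = 2
        case 2 => pollOverflow(); fairness = 0
      }
    }

Note also the capacity guard above: if the local queue is more than half full, `drainBatch` first transfers one full batch to the batched queue so the incoming batch can fit.
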
@@ -17,7 +17,7 @@
package cats.effect
package unsafe

import cats.syntax.parallel._
import cats.syntax.traverse._

import scala.concurrent.{Future, Promise}
import scala.concurrent.duration._
@@ -52,7 +52,7 @@ class StripedHashtableSpec extends BaseSpec with Runners {
IO(new CountDownLatch(iterations)).flatMap { counter =>
(0 until iterations)
.toList
.parTraverse { n => IO(io(n).unsafeRunAsync { _ => counter.countDown() }(rt)) }
.traverse { n => IO(io(n).unsafeRunAsync { _ => counter.countDown() }(rt)) }
.flatMap { _ => IO.fromFuture(IO.delay(counter.await())) }
.flatMap { _ =>
IO.blocking {
@@ -17,7 +17,7 @@
package cats.effect
package unsafe

import cats.syntax.parallel._
import cats.syntax.traverse._

import scala.concurrent.duration._

@@ -72,7 +72,7 @@ class StripedHashtableSpec extends BaseSpec with Runners {
IO(new CountDownLatch(iterations)).flatMap { counter =>
(0 until iterations)
.toList
.parTraverse { n => IO(io(n).unsafeRunAsync { _ => counter.countDown() }(rt)) }
.traverse { n => IO(io(n).unsafeRunAsync { _ => counter.countDown() }(rt)) }
.flatMap { _ => IO.blocking(counter.await()) }
.flatMap { _ =>
IO.blocking {
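In both `StripedHashtableSpec` diffs above, `parTraverse` is replaced with `traverse`, so the effects that submit work via `unsafeRunAsync` are now started sequentially rather than concurrently. A minimal illustration of the difference between the two combinators (the `task` function is hypothetical, for demonstration only):

    import cats.effect.IO
    import cats.syntax.parallel._
    import cats.syntax.traverse._

    def task(n: Int): IO[Int] = IO(n * 2) // hypothetical effect

    // `traverse` sequences the effects one after another:
    val sequential: IO[List[Int]] = List(1, 2, 3).traverse(task)

    // `parTraverse` runs them concurrently, preserving result order:
    val concurrent: IO[List[Int]] = List(1, 2, 3).parTraverse(task)
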
17 changes: 17 additions & 0 deletions tests/shared/src/test/scala/cats/effect/IOSpec.scala
@@ -1312,6 +1312,23 @@ class IOSpec extends BaseSpec with Discipline with IOPlatformSpecification {
}
}

"fiber repeated yielding test" in real {
def yieldUntil(ref: Ref[IO, Boolean]): IO[Unit] =
ref.get.flatMap(b => if (b) IO.unit else IO.cede *> yieldUntil(ref))

for {
n <- IO(java.lang.Runtime.getRuntime.availableProcessors)
done <- Ref.of[IO, Boolean](false)
fibers <- List.range(0, n - 1).traverse(_ => yieldUntil(done).start)
_ <- IO.unit.start.replicateA(200)
_ <- done.set(true).start
_ <- IO.unit.start.replicateA(1000)
_ <- yieldUntil(done)
_ <- fibers.traverse(_.join)
res <- IO(ok)
} yield res
}

platformSpecs
}

