Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix batched queue starvation #2264

Merged
merged 8 commits into from
Aug 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions core/jvm/src/main/scala/cats/effect/unsafe/LocalQueue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,80 @@ private final class LocalQueue {
null
}

/**
* Steals a batch of enqueued fibers and transfers the whole batch to the
* batched queue.
*
* This method is called by the runtime to restore fairness guarantees between
* fibers in the local queue compared to fibers on the overflow and batched
* queues. Every few iterations, the overflow and batched queues are checked
* for fibers and those fibers are executed. In the case of the batched queue,
* a batch of fibers might be obtained, which cannot fully fit into the local
* queue due to insufficient capacity. In that case, this method is called to
* drain one full batch of fibers, which in turn creates space for the fibers
* arriving from the batched queue.
*
* Conceptually, this method is identical to [[LocalQueue#dequeue]], with the
* main difference being that the `head` of the queue is moved forward by as
* many places as there are in a batch, thus securing all those fibers and
* transferring them to the batched queue.
*
* @note Can '''only''' be correctly called by the owner [[WorkerThread]].
*
* @param batched the batched queue to transfer a batch of fibers into
* @param random a reference to an uncontended source of randomness, to be
* passed along to the striped concurrent queues when executing
* their enqueue operations
*/
def drainBatch(batched: ScalQueue[Array[IOFiber[_]]], random: ThreadLocalRandom): Unit = {
// A plain, unsynchronized load of the tail of the local queue.
val tl = tail

while (true) {
// A load of the head of the queue using `acquire` semantics.
val hd = head.get()

val real = lsb(hd)

if (tl == real) {
// The tail and the "real" value of the head are equal. The queue is
// empty. There is nothing more to be done.
return
}

// Move the "real" value of the head by the size of a batch.
val newReal = unsignedShortAddition(real, OverflowBatchSize)

// Make sure to preserve the "steal" tag in the presence of a concurrent
// stealer. Otherwise, move the "steal" tag along with the "real" value.
val steal = msb(hd)
val newHd = if (steal == real) pack(newReal, newReal) else pack(steal, newReal)

if (head.compareAndSet(hd, newHd)) {
// The head has been successfully moved forward and a batch of fibers
// secured. Proceed to null out the references to the fibers and
// transfer them to the batch.
val batch = new Array[IOFiber[_]](OverflowBatchSize)
var i = 0

while (i < OverflowBatchSize) {
val idx = index(real + i)
val f = buffer(idx)
buffer(idx) = null
batch(i) = f
i += 1
}

// The fibers have been transferred, enqueue the whole batch on the
// batched queue.
batchedSpilloverCount += OverflowBatchSize
tailPublisher.lazySet(tl)
batched.offer(batch, random)
return
}
}
}

/**
* Steals all enqueued fibers and transfers them to the provided array.
*
Expand Down
49 changes: 42 additions & 7 deletions core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ private final class WorkerThread(
*
* 0: To increase the fairness towards fibers scheduled by threads which
* are external to the `WorkStealingThreadPool`, every
* `OverflowQueueTicks` number of iterations, the overflow queue takes
* precedence over the local queue.
* `OverflowQueueTicks` number of iterations, the overflow and batched
* queues take precedence over the local queue.
*
* If a fiber is successfully dequeued from the overflow queue, it will
* be executed. The `WorkerThread` unconditionally transitions to
Expand Down Expand Up @@ -232,6 +232,8 @@ private final class WorkerThread(
*/
var state = 0

var fairness: Int = 0

def parkLoop(): Unit = {
var cont = true
while (cont && !isInterrupted()) {
Expand All @@ -246,12 +248,45 @@ private final class WorkerThread(
while (!isInterrupted()) {
((state & OverflowQueueTicksMask): @switch) match {
case 0 =>
// Dequeue a fiber from the overflow queue.
val fiber = overflow.poll(rnd)
if (fiber ne null) {
// Run the fiber.
fiber.run()
// Alternate between checking the overflow and batched queues with a
// 2:1 ration in favor of the overflow queue, for now.
(fairness: @switch) match {
case 0 =>
// Dequeue a fiber from the overflow queue.
val fiber = overflow.poll(rnd)
if (fiber ne null) {
// Run the fiber.
fiber.run()
}
fairness = 1

case 1 =>
// Look into the batched queue for a batch of fibers.
val batch = batched.poll(rnd)
if (batch ne null) {
if (queue.size() > HalfLocalQueueCapacity) {
// Make room for the batch if the local queue cannot
// accommodate the batch as is.
queue.drainBatch(batched, rnd)
}

// Enqueue the batch at the back of the local queue and execute
// the first fiber.
val fiber = queue.enqueueBatch(batch)
fiber.run()
}
fairness = 2

case 2 =>
// Dequeue a fiber from the overflow queue.
val fiber = overflow.poll(rnd)
if (fiber ne null) {
// Run the fiber.
fiber.run()
}
fairness = 0
}

// Transition to executing fibers from the local queue.
state = 7

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package cats.effect
package unsafe

import cats.syntax.parallel._
import cats.syntax.traverse._

import scala.concurrent.{Future, Promise}
import scala.concurrent.duration._
Expand Down Expand Up @@ -52,7 +52,7 @@ class StripedHashtableSpec extends BaseSpec with Runners {
IO(new CountDownLatch(iterations)).flatMap { counter =>
(0 until iterations)
.toList
.parTraverse { n => IO(io(n).unsafeRunAsync { _ => counter.countDown() }(rt)) }
.traverse { n => IO(io(n).unsafeRunAsync { _ => counter.countDown() }(rt)) }
.flatMap { _ => IO.fromFuture(IO.delay(counter.await())) }
.flatMap { _ =>
IO.blocking {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package cats.effect
package unsafe

import cats.syntax.parallel._
import cats.syntax.traverse._

import scala.concurrent.duration._

Expand Down Expand Up @@ -72,7 +72,7 @@ class StripedHashtableSpec extends BaseSpec with Runners {
IO(new CountDownLatch(iterations)).flatMap { counter =>
(0 until iterations)
.toList
.parTraverse { n => IO(io(n).unsafeRunAsync { _ => counter.countDown() }(rt)) }
.traverse { n => IO(io(n).unsafeRunAsync { _ => counter.countDown() }(rt)) }
.flatMap { _ => IO.blocking(counter.await()) }
.flatMap { _ =>
IO.blocking {
Expand Down
17 changes: 17 additions & 0 deletions tests/shared/src/test/scala/cats/effect/IOSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,23 @@ class IOSpec extends BaseSpec with Discipline with IOPlatformSpecification {
}
}

"fiber repeated yielding test" in real {
def yieldUntil(ref: Ref[IO, Boolean]): IO[Unit] =
ref.get.flatMap(b => if (b) IO.unit else IO.cede *> yieldUntil(ref))

for {
n <- IO(java.lang.Runtime.getRuntime.availableProcessors)
done <- Ref.of[IO, Boolean](false)
fibers <- List.range(0, n - 1).traverse(_ => yieldUntil(done).start)
_ <- IO.unit.start.replicateA(200)
_ <- done.set(true).start
_ <- IO.unit.start.replicateA(1000)
_ <- yieldUntil(done)
_ <- fibers.traverse(_.join)
res <- IO(ok)
} yield res
}

platformSpecs
}

Expand Down