Skip to content

Commit

Permalink
Merge pull request #1913 from vasilmkd/striped
Browse files Browse the repository at this point in the history
Striped fiber callback hashtable
  • Loading branch information
djspiewak authored Apr 24, 2021
2 parents 490949f + 8378a42 commit 4e9c92a
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 84 deletions.
7 changes: 6 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,12 @@ lazy val core = crossProject(JSPlatform, JVMPlatform)
mimaBinaryIssueFilters ++= Seq(
// introduced by #1837, removal of package private class
ProblemFilters.exclude[MissingClassProblem]("cats.effect.AsyncPropagateCancelation"),
ProblemFilters.exclude[MissingClassProblem]("cats.effect.AsyncPropagateCancelation$")
ProblemFilters.exclude[MissingClassProblem]("cats.effect.AsyncPropagateCancelation$"),
// introduced by #1913, striped fiber callback hashtable, changes to package private code
ProblemFilters.exclude[MissingClassProblem]("cats.effect.unsafe.FiberErrorHashtable"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("cats.effect.unsafe.IORuntime.fiberErrorCbs"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("cats.effect.unsafe.IORuntime.this"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("cats.effect.unsafe.IORuntime.<init>$default$6")
)
)
.jvmSettings(
Expand Down
27 changes: 18 additions & 9 deletions core/shared/src/main/scala/cats/effect/IOFiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,18 +137,27 @@ private final class IOFiber[A](
case t: Throwable =>
Thread.interrupted()
currentCtx.reportFailure(t)
runtime.fiberErrorCbs.synchronized {
var idx = 0
val len = runtime.fiberErrorCbs.hashtable.length
while (idx < len) {
val cb = runtime.fiberErrorCbs.hashtable(idx)
if (cb != null) {
cb(t)
runtime.shutdown()

var idx = 0
val tables = runtime.fiberErrorCbs.tables
val numTables = runtime.fiberErrorCbs.numTables
while (idx < numTables) {
val table = tables(idx).hashtable
val len = table.length
table.synchronized {
var i = 0
while (i < len) {
val cb = table(i)
if (cb ne null) {
cb(t)
}
i += 1
}
idx += 1
}
idx += 1
}
runtime.shutdown()

Thread.currentThread().interrupt()
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ final class IORuntime private[effect] (
val scheduler: Scheduler,
val shutdown: () => Unit,
val config: IORuntimeConfig,
private[effect] val fiberErrorCbs: FiberErrorHashtable = new FiberErrorHashtable(16)
private[effect] val fiberErrorCbs: StripedHashtable = new StripedHashtable()
) {
override def toString: String = s"IORuntime($compute, $scheduler, $config)"
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect.unsafe

/**
* A conceptual hash table which balances between several
* [[ThreadSafeHashtable]]s, in order to reduce the contention on the single
* lock by spreading it to several different locks controlling parts of the
* hash table.
*/
private[effect] final class StripedHashtable {
val numTables: Int = {
val cpus = Runtime.getRuntime().availableProcessors()
// Bit twiddling hacks.
// http://graphics.stanford.edu/~seander/bithacks.html#RoundUpPowerOf2
var value = cpus - 1
value |= value >> 1
value |= value >> 2
value |= value >> 4
value |= value >> 8
value |= value >> 16
value + 1
}

private[this] val mask: Int = numTables - 1

private[this] val initialCapacity: Int = 8

val tables: Array[ThreadSafeHashtable] = {
val array = new Array[ThreadSafeHashtable](numTables)
var i = 0
while (i < numTables) {
array(i) = new ThreadSafeHashtable(initialCapacity)
i += 1
}
array
}

def put(cb: Throwable => Unit): Unit = {
val hash = System.identityHashCode(cb)
val idx = hash & mask
tables(idx).put(cb, hash)
}

def remove(cb: Throwable => Unit): Unit = {
val hash = System.identityHashCode(cb)
val idx = hash & mask
tables(idx).remove(cb, hash)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect
package unsafe

/**
* A primitive thread safe hash table implementation specialized for a single
* purpose, to hold references to the error callbacks of fibers. The hashing
* function is [[System.identityHashCode]] simply because the callbacks are
* functions and therefore have no defined notion of [[Object#hashCode]]. The
* thread safety is achieved by pessimistically locking the whole structure.
* This is fine in practice because this data structure is only accessed when
* running [[cats.effect.IO#unsafeRunFiber]], which is not expected to be
* executed often in a realistic system.
*
* @param initialCapacity the initial capacity of the hashtable, ''must'' be a
* power of 2
*/
private[effect] final class ThreadSafeHashtable(initialCapacity: Int) {
var hashtable: Array[Throwable => Unit] = new Array(initialCapacity)
private[this] var size = 0
private[this] var mask = initialCapacity - 1
private[this] var capacity = initialCapacity

def put(cb: Throwable => Unit, hash: Int): Unit = this.synchronized {
val cap = capacity
if (size == cap) {
val newCap = cap * 2
val newHashtable = new Array[Throwable => Unit](newCap)
System.arraycopy(hashtable, 0, newHashtable, 0, cap)
hashtable = newHashtable
mask = newCap - 1
capacity = newCap
}

var idx = hash & mask
while (true) {
if (hashtable(idx) == null) {
hashtable(idx) = cb
size += 1
return
} else {
idx += 1
idx &= mask
}
}
}

def remove(cb: Throwable => Unit, hash: Int): Unit = this.synchronized {
val init = hash & mask
var idx = init
while (true) {
if (cb eq hashtable(idx)) {
hashtable(idx) = null
size -= 1
return
} else {
idx += 1
idx &= mask
if (idx == init) {
return
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect
package unsafe

import cats.syntax.parallel._

import scala.concurrent.duration._

import java.util.concurrent.CountDownLatch

class StripedHashtableSpec extends BaseSpec with Runners {

override def executionTimeout: FiniteDuration = 30.seconds

def hashtableRuntime(): IORuntime = {
lazy val rt: IORuntime = {
val (blocking, blockDown) =
IORuntime.createDefaultBlockingExecutionContext(threadPrefix =
s"io-blocking-${getClass.getName}")
val (scheduler, schedDown) =
IORuntime.createDefaultScheduler(threadPrefix = s"io-scheduler-${getClass.getName}")
val (compute, compDown) =
IORuntime.createDefaultComputeThreadPool(
rt,
threadPrefix = s"io-compute-${getClass.getName}")

new IORuntime(
compute,
blocking,
scheduler,
{ () =>
compDown()
blockDown()
schedDown()
},
IORuntimeConfig()
)
}

rt
}

"StripedHashtable" should {
"work correctly in the presence of many unsafeRuns" in real {
val iterations = 1000000

object Boom extends RuntimeException("Boom!")

def io(n: Int): IO[Unit] =
(n % 3) match {
case 0 => IO.unit
case 1 => IO.canceled
case 2 => IO.raiseError[Unit](Boom)
}

Resource.make(IO(hashtableRuntime()))(rt => IO(rt.shutdown())).use { rt =>
IO(new CountDownLatch(iterations)).flatMap { counter =>
(0 until iterations)
.toList
.parTraverse { n => IO(io(n).unsafeRunAsync { _ => counter.countDown() }(rt)) }
.flatMap { _ => IO.blocking(counter.await()) }
.flatMap { _ =>
IO.blocking {
rt.fiberErrorCbs.synchronized {
rt.fiberErrorCbs.tables.forall(_.hashtable.forall(_ eq null)) mustEqual true
}
}
}
}
}
}
}
}

0 comments on commit 4e9c92a

Please sign in to comment.