Skip to content

Commit

Permalink
Use a global ThreadLocal[IOLocalState]
Browse files Browse the repository at this point in the history
  • Loading branch information
armanbilge committed Dec 13, 2024
1 parent a01fba0 commit 5803aa9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 24 deletions.
23 changes: 3 additions & 20 deletions core/jvm/src/main/scala/cats/effect/IOLocalPlatform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,33 +30,16 @@ private[effect] trait IOLocalPlatform[A] { self: IOLocal[A] =>
*/
def unsafeThreadLocal(): ThreadLocal[A] = if (ioLocalPropagation)
new ThreadLocal[A] {
override def initialValue(): A = self.getOrDefault(IOLocalState.empty)

override def get(): A = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) {
self.getOrDefault(fiber.getLocalState())
} else {
super.get()
}
self.getOrDefault(IOLocal.getThreadLocalState())
}

override def set(value: A): Unit = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) {
fiber.setLocalState(self.set(fiber.getLocalState(), value))
} else {
super.set(value)
}
IOLocal.setThreadLocalState(self.set(IOLocal.getThreadLocalState(), value))
}

override def remove(): Unit = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) {
fiber.setLocalState(self.reset(fiber.getLocalState()))
} else {
super.remove()
}
IOLocal.setThreadLocalState(self.reset(IOLocal.getThreadLocalState()))
}
}
else
Expand Down
15 changes: 11 additions & 4 deletions core/shared/src/main/scala/cats/effect/IOLocal.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,14 +273,21 @@ object IOLocal {
*/
def isPropagating: Boolean = IOFiberConstants.ioLocalPropagation

private[effect] def getThreadLocalState() = {
private[effect] val threadLocal = new ThreadLocal[IOLocalState] {
override def initialValue() = IOLocalState.empty
}

private[effect] def getThreadLocalState(): IOLocalState = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) fiber.getLocalState() else IOLocalState.empty
if (fiber ne null) fiber.getLocalState() else threadLocal.get()
}

private[effect] def setThreadLocalState(state: IOLocalState) = {
private[effect] def setThreadLocalState(state: IOLocalState): Unit = {
val fiber = IOFiber.currentIOFiber()
if (fiber ne null) fiber.setLocalState(state)
if (fiber ne null)
fiber.setLocalState(state)
else
threadLocal.set(state)
}

private final class IOLocalImpl[A](default: A) extends IOLocal[A] {
Expand Down

0 comments on commit 5803aa9

Please sign in to comment.