Skip to content

Commit

Permalink
Merge pull request #1393 from RaasAhsan/fiber-local
Browse files Browse the repository at this point in the history
Fiber locals
  • Loading branch information
djspiewak authored Mar 26, 2021
2 parents 65aa8fb + dc7e11d commit 6093b4a
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 0 deletions.
69 changes: 69 additions & 0 deletions core/shared/src/main/scala/cats/effect/FiberLocal.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats.effect

trait FiberLocal[F[_], A] {

def get: F[A]

def set(value: A): F[Unit]

def clear: F[Unit]

def update(f: A => A): F[Unit]

def modify[B](f: A => (A, B)): F[B]

def getAndSet(value: A): F[A]

def getAndClear: F[A]

}

object FiberLocal {

def apply[A](default: A): IO[FiberLocal[IO, A]] =
IO {
new FiberLocal[IO, A] { self =>
override def get: IO[A] =
IO.Local(state => (state, state.get(self).map(_.asInstanceOf[A]).getOrElse(default)))

override def set(value: A): IO[Unit] =
IO.Local(state => (state + (self -> value), ()))

override def clear: IO[Unit] =
IO.Local(state => (state - self, ()))

override def update(f: A => A): IO[Unit] =
get.flatMap(a => set(f(a)))

override def modify[B](f: A => (A, B)): IO[B] =
get.flatMap { a =>
val (a2, b) = f(a)
set(a2).as(b)
}

override def getAndSet(value: A): IO[A] =
get <* set(value)

override def getAndClear: IO[A] =
get <* clear

}
}

}
6 changes: 6 additions & 0 deletions core/shared/src/main/scala/cats/effect/IO.scala
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,7 @@ sealed abstract class IO[+A] private () extends IOPlatform[A] {

val fiber = new IOFiber[A](
0,
Map(),
oc =>
oc.fold(
canceled,
Expand Down Expand Up @@ -1438,6 +1439,11 @@ object IO extends IOCompanionPlatform with IOLowPriorityImplicits {
def tag = 20
}

private[effect] final case class Local[+A](f: IOLocalState => (IOLocalState, A))
extends IO[A] {
def tag = 21
}

// INTERNAL, only created by the runloop itself as the terminal state of several operations
private[effect] case object EndFiber extends IO[Nothing] {
def tag = -1
Expand Down
11 changes: 11 additions & 0 deletions core/shared/src/main/scala/cats/effect/IOFiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import scala.util.control.NoStackTrace
*/
private final class IOFiber[A](
initMask: Int,
initLocalState: IOLocalState,
cb: OutcomeIO[A] => Unit,
startIO: IO[A],
startEC: ExecutionContext,
Expand Down Expand Up @@ -106,6 +107,8 @@ private final class IOFiber[A](

private[this] val callbacks = new CallbackStack[A](cb)

private[this] var localState: IOLocalState = initLocalState

@volatile
private[this] var outcome: OutcomeIO[A] = _

Expand Down Expand Up @@ -759,6 +762,7 @@ private final class IOFiber[A](
val ec = currentCtx
val fiber = new IOFiber[Any](
initMask2,
localState,
null,
cur.ioa,
ec,
Expand Down Expand Up @@ -811,6 +815,13 @@ private final class IOFiber[A](
} else {
runLoop(interruptibleImpl(cur, runtime.blocking), nextIteration)
}

case 21 =>
val cur = cur0.asInstanceOf[Local[Any]]

val (nextLocalState, value) = cur.f(localState)
localState = nextLocalState
runLoop(succeeded(value, 0), nextIteration)
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions core/shared/src/main/scala/cats/effect/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,6 @@ package object effect {

type Ref[F[_], A] = cekernel.Ref[F, A]
val Ref = cekernel.Ref

private[effect] type IOLocalState = scala.collection.immutable.Map[FiberLocal[IO, _], Any]
}
86 changes: 86 additions & 0 deletions tests/shared/src/test/scala/cats/effect/FiberLocalSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2020-2021 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package cats
package effect

class FiberLocalSpec extends BaseSpec {

"Local" should {
"return a default value" in ticked { implicit ticker =>
val io = FiberLocal(0).flatMap(_.get)

io must completeAs(0)
}

"set and get a value" in ticked { implicit ticker =>
val io = for {
local <- FiberLocal(0)
_ <- local.set(10)
value <- local.get
} yield value

io must completeAs(10)
}

"preserve locals across async boundaries" in ticked { implicit ticker =>
val io = for {
local <- FiberLocal(0)
_ <- local.set(10)
_ <- IO.cede
value <- local.get
} yield value

io must completeAs(10)
}

"copy locals to children fibers" in ticked { implicit ticker =>
val io = for {
local <- FiberLocal(0)
_ <- local.set(10)
f <- local.get.start
value <- f.joinWithNever
} yield value

io must completeAs(10)
}

"child local manipulation is invisible to parents" in ticked { implicit ticker =>
val io = for {
local <- FiberLocal(10)
f <- local.set(20).start
_ <- f.join
value <- local.get
} yield value

io must completeAs(10)
}

"parent local manipulation is invisible to children" in ticked { implicit ticker =>
val io = for {
local <- FiberLocal(0)
d1 <- Deferred[IO, Unit]
f <- (d1.get *> local.get).start
_ <- local.set(10)
_ <- d1.complete(())
value <- f.joinWithNever
} yield value

io must completeAs(0)
}
}

}

0 comments on commit 6093b4a

Please sign in to comment.