diff --git a/core/shared/src/main/scala/cats/effect/FiberRef.scala b/core/shared/src/main/scala/cats/effect/FiberRef.scala new file mode 100644 index 0000000000..a92149b128 --- /dev/null +++ b/core/shared/src/main/scala/cats/effect/FiberRef.scala @@ -0,0 +1,77 @@ +/* + * Copyright 2020-2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect + +import cats.syntax.all._ + +trait FiberRef[F[_], A] extends Ref[F, A] { + + /** + * Divorces the current reference from parent fiber and + * sets a new reference for the duration of `fa` evaluation. + */ + def locally(fa: F[A]): F[A] + +} + +object FiberRef { + + def apply[A](default: A): IO[FiberRef[IO, A]] = + for { + ref <- Ref.of[IO, A](default) + local <- FiberLocal[Ref[IO, A]](ref) + } yield new FiberRef[IO, A] { + override def locally(fa: IO[A]): IO[A] = { + val acquire = local.get.product(Ref.of[IO, A](default)).flatTap { + case (_, nextRef) => + local.set(nextRef) + } + def release(oldRef: Ref[IO, A]): IO[Unit] = + local.set(oldRef) + + acquire.bracket(_ => fa) { case (oldRef, _) => release(oldRef) } + } + + override def get: IO[A] = + local.get.flatMap(_.get) + + override def set(a: A): IO[Unit] = + local.get.flatMap(_.set(a)) + + override def access: IO[(A, A => IO[Boolean])] = + local.get.flatMap(_.access) + + override def tryUpdate(f: A => A): IO[Boolean] = + local.get.flatMap(_.tryUpdate(f)) + + override def tryModify[B](f: A => (A, B)): IO[Option[B]] = + local.get.flatMap(_.tryModify(f)) + + override def update(f: A => A): IO[Unit] = + local.get.flatMap(_.update(f)) + + override def modify[B](f: A => (A, B)): IO[B] = + local.get.flatMap(_.modify(f)) + + override def tryModifyState[B](state: cats.data.State[A, B]): IO[Option[B]] = + local.get.flatMap(_.tryModifyState(state)) + + override def modifyState[B](state: cats.data.State[A, B]): IO[B] = + local.get.flatMap(_.modifyState(state)) + } + +} diff --git a/tests/shared/src/test/scala/cats/effect/FiberRefSpec.scala b/tests/shared/src/test/scala/cats/effect/FiberRefSpec.scala new file mode 100644 index 0000000000..62fc4d90fa --- /dev/null +++ b/tests/shared/src/test/scala/cats/effect/FiberRefSpec.scala @@ -0,0 +1,97 @@ +/* + * Copyright 2020-2021 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats +package effect + +class FiberRefSpec extends BaseSpec { + + "FiberRef" should { + "return a default value" in ticked { implicit ticker => + val io = FiberRef(0).flatMap(_.get) + + io must completeAs(0) + } + + "set and get a value" in ticked { implicit ticker => + val io = for { + local <- FiberRef(0) + _ <- local.set(10) + value <- local.get + } yield value + + io must completeAs(10) + } + + "preserve locals across async boundaries" in ticked { implicit ticker => + val io = for { + local <- FiberRef(0) + _ <- local.set(10) + _ <- IO.cede + value <- local.get + } yield value + + io must completeAs(10) + } + + "children fibers can read locals" in ticked { implicit ticker => + val io = for { + local <- FiberRef(0) + _ <- local.set(10) + f <- local.get.start + value <- f.joinWithNever + } yield value + + io must completeAs(10) + } + + "child local manipulation is visible to parents" in ticked { implicit ticker => + val io = for { + local <- FiberRef(0) + f <- local.set(20).start + _ <- f.join + value <- local.get + } yield value + + io must completeAs(20) + } + + "parent local manipulation is visible to children" in ticked { implicit ticker => + val io = for { + local <- FiberRef(0) + d1 <- Deferred[IO, Unit] + f <- (d1.get *> local.get).start + _ <- local.set(10) + _ <- d1.complete(()) + value <- f.joinWithNever + } yield value + + io must completeAs(10) + } + + "locally" in ticked { implicit ticker => + val io = for { + local <- FiberRef(0) + f <- (local.locally(local.set(1) >> local.get)).start + v1 <- f.joinWithNever + v2 <- local.get + } yield (v1, v2) + + io must completeAs((1, 0)) + } + } + +}