diff --git a/core/shared/src/main/scala/cats/effect/concurrent/MVar.scala b/core/shared/src/main/scala/cats/effect/concurrent/MVar.scala index a23f5c2815..12c9142c0d 100644 --- a/core/shared/src/main/scala/cats/effect/concurrent/MVar.scala +++ b/core/shared/src/main/scala/cats/effect/concurrent/MVar.scala @@ -124,6 +124,7 @@ abstract class MVar[F[_], A] extends MVarDocumentation { /** * Modify the context `F` using transformation `f`. */ + @deprecated("`mapK` is deprecated in favor of `imapK` which supports the new invariant `MVar2` interface", "2.2.0") def mapK[G[_]](f: F ~> G): MVar[G, A] = new TransformedMVar(this, f) } @@ -139,26 +140,64 @@ abstract class MVar2[F[_], A] extends MVar[F, A] { /** * Replaces a value in MVar and returns the old value. - + * + * @note This operation is only safe from deadlocks if there are no other producers for this `MVar`. + * * @param newValue is a new value * @return the value taken */ def swap(newValue: A): F[A] /** - * Returns the value without waiting or modifying. - * - * This operation is atomic. + * A non-blocking version of [[read]]. * * @return an Option holding the current value, None means it was empty */ def tryRead: F[Option[A]] /** - * Modify the context `F` using transformation `f`. + * Applies the effectful function `f` on the contents of this `MVar`. In case of failure, it sets the contents of the + * `MVar` to the original value. + * + * @note This operation is only safe from deadlocks if there are no other producers for this `MVar`. + * + * @param f effectful function that operates on the contents of this `MVar` + * @return the value produced by applying `f` to the contents of this `MVar` + */ + def use[B](f: A => F[B]): F[B] + + /** + * Modifies the contents of the `MVar` using the effectful function `f`, but also allows for returning a value derived + * from the original contents of the `MVar`. Like [[use]], in case of failure, it sets the contents of the `MVar` to + * the original value. + * + * @note This operation is only safe from deadlocks if there are no other producers for this `MVar`. + * + * @param f effectful function that operates on the contents of this `MVar` + * @return the second value produced by applying `f` to the contents of this `MVar` + */ + def modify[B](f: A => F[(A, B)]): F[B] + + /** + * Modifies the contents of the `MVar` using the effectful function `f`. Like [[use]], in case of failure, it sets the + * contents of the `MVar` to the original value. + * + * @note This operation is only safe from deadlocks if there are no other producers for this `MVar`. + * + * @param f effectful function that operates on the contents of this `MVar` + * @return no useful value. Executed only for the effects. + */ + def modify_(f: A => F[A]): F[Unit] + + /** + * Modify the context `F` using natural isomorphism `f` with `g`. + * + * @param f functor transformation from `F` to `G` + * @param g functor transformation from `G` to `F` + * @return `MVar2` with a modified context `G` derived using a natural isomorphism from `F` */ - override def mapK[G[_]](f: F ~> G): MVar2[G, A] = - new TransformedMVar2(this, f) + def imapK[G[_]](f: F ~> G, g: G ~> F): MVar2[G, A] = + new TransformedMVar2(this, f, g) } /** Builders for [[MVar]]. */ @@ -304,7 +343,9 @@ object MVar { override def read: G[A] = trans(underlying.read) } - final private[concurrent] class TransformedMVar2[F[_], G[_], A](underlying: MVar2[F, A], trans: F ~> G) + final private[concurrent] class TransformedMVar2[F[_], G[_], A](underlying: MVar2[F, A], + trans: F ~> G, + inverse: G ~> F) extends MVar2[G, A] { override def isEmpty: G[Boolean] = trans(underlying.isEmpty) override def put(a: A): G[Unit] = trans(underlying.put(a)) @@ -314,5 +355,8 @@ object MVar { override def read: G[A] = trans(underlying.read) override def tryRead: G[Option[A]] = trans(underlying.tryRead) override def swap(newValue: A): G[A] = trans(underlying.swap(newValue)) + override def use[B](f: A => G[B]): G[B] = trans(underlying.use(a => inverse(f(a)))) + override def modify[B](f: A => G[(A, B)]): G[B] = trans(underlying.modify(a => inverse(f(a)))) + override def modify_(f: A => G[A]): G[Unit] = trans(underlying.modify_(a => inverse(f(a)))) } } diff --git a/core/shared/src/main/scala/cats/effect/internals/MVarAsync.scala b/core/shared/src/main/scala/cats/effect/internals/MVarAsync.scala index 5459b697b9..c636e6b9ec 100644 --- a/core/shared/src/main/scala/cats/effect/internals/MVarAsync.scala +++ b/core/shared/src/main/scala/cats/effect/internals/MVarAsync.scala @@ -77,6 +77,20 @@ final private[effect] class MVarAsync[F[_], A] private (initial: MVarAsync.State } } + def use[B](f: A => F[B]): F[B] = + modify(a => F.map(f(a))((a, _))) + + def modify[B](f: A => F[(A, B)]): F[B] = + F.flatMap(take) { a => + F.flatMap(F.onError(f(a)) { case _ => put(a) }) { + case (newA, b) => + F.as(put(newA), b) + } + } + + def modify_(f: A => F[A]): F[Unit] = + modify(a => F.map(f(a))((_, ()))) + @tailrec private def unsafeTryPut(a: A): F[Boolean] = stateRef.get match { diff --git a/core/shared/src/main/scala/cats/effect/internals/MVarConcurrent.scala b/core/shared/src/main/scala/cats/effect/internals/MVarConcurrent.scala index 0662c0160a..9b83c8b5a0 100644 --- a/core/shared/src/main/scala/cats/effect/internals/MVarConcurrent.scala +++ b/core/shared/src/main/scala/cats/effect/internals/MVarConcurrent.scala @@ -19,7 +19,7 @@ package internals import java.util.concurrent.atomic.AtomicReference -import cats.effect.concurrent.MVar2 +import cats.effect.concurrent.{MVar2, Ref} import cats.effect.internals.Callback.rightUnit import scala.annotation.tailrec @@ -77,10 +77,35 @@ final private[effect] class MVarConcurrent[F[_], A] private (initial: MVarConcur } def swap(newValue: A): F[A] = - F.flatMap(take) { oldValue => - F.map(put(newValue))(_ => oldValue) + F.continual(take) { + case Left(t) => F.raiseError(t) + case Right(oldValue) => F.as(put(newValue), oldValue) } + def use[B](f: A => F[B]): F[B] = + modify(a => F.map(f(a))((a, _))) + + def modify[B](f: A => F[(A, B)]): F[B] = + F.bracket(Ref[F].of[Option[A]](None)) { signal => + F.flatMap(F.continual[A, A](take) { + case Left(t) => F.raiseError(t) + case Right(a) => F.as(signal.set(Some(a)), a) + }) { a => + F.continual[(A, B), B](f(a)) { + case Left(t) => F.raiseError(t) + case Right((newA, b)) => F.as(signal.set(Some(newA)), b) + } + } + } { signal => + F.flatMap(signal.get) { + case Some(a) => put(a) + case None => F.unit + } + } + + def modify_(f: A => F[A]): F[Unit] = + modify(a => F.map(f(a))((_, ()))) + @tailrec private def unsafeTryPut(a: A): F[Boolean] = stateRef.get match { case WaitForTake(_, _) => F.pure(false) diff --git a/core/shared/src/test/scala/cats/effect/concurrent/MVarTests.scala b/core/shared/src/test/scala/cats/effect/concurrent/MVarTests.scala index ad27071218..451c9b5a3e 100644 --- a/core/shared/src/test/scala/cats/effect/concurrent/MVarTests.scala +++ b/core/shared/src/test/scala/cats/effect/concurrent/MVarTests.scala @@ -84,6 +84,55 @@ class MVarConcurrentTests extends BaseMVarTests { r shouldBe Right(0) } } + + test("swap is cancelable on take") { + val task = for { + mVar <- empty[Int] + finished <- Deferred.uncancelable[IO, Int] + fiber <- mVar.swap(20).flatMap(finished.complete).start + _ <- fiber.cancel + _ <- mVar.put(10) + fallback = IO.sleep(100.millis) *> mVar.take + v <- IO.race(finished.get, fallback) + } yield v + + for (r <- task.unsafeToFuture()) yield { + r shouldBe Right(10) + } + } + + test("modify is cancelable on take") { + val task = for { + mVar <- empty[Int] + finished <- Deferred.uncancelable[IO, String] + fiber <- mVar.modify(n => IO.pure((n * 2, n.show))).flatMap(finished.complete).start + _ <- fiber.cancel + _ <- mVar.put(10) + fallback = IO.sleep(100.millis) *> mVar.take + v <- IO.race(finished.get, fallback) + } yield v + + for (r <- task.unsafeToFuture()) yield { + r shouldBe Right(10) + } + } + + test("modify is cancelable on f") { + val task = for { + mVar <- empty[Int] + finished <- Deferred.uncancelable[IO, String] + fiber <- mVar.modify(n => IO.never *> IO.pure((n * 2, n.show))).flatMap(finished.complete).start + _ <- mVar.put(10) + _ <- IO.sleep(10.millis) + _ <- fiber.cancel + fallback = IO.sleep(100.millis) *> mVar.take + v <- IO.race(finished.get, fallback) + } yield v + + for (r <- task.unsafeToFuture()) yield { + r shouldBe Right(10) + } + } } class MVarAsyncTests extends BaseMVarTests { @@ -424,4 +473,36 @@ abstract class BaseMVarTests extends AsyncFunSuite with Matchers { r shouldBe count * 2 } } + + test("put; take; modify; put") { + val task = for { + mVar <- empty[Int] + _ <- mVar.put(10) + _ <- mVar.take + fiber <- mVar.modify(n => IO.pure((n * 2, n.toString))).start + _ <- mVar.put(20) + s <- fiber.join + v <- mVar.take + } yield (s, v) + + for (r <- task.unsafeToFuture()) yield { + r shouldBe ("20" -> 40) + } + } + + test("modify replaces the original value of the mvar on error") { + val error = new Exception("Boom!") + val task = for { + mVar <- empty[Int] + _ <- mVar.put(10) + finished <- Deferred.uncancelable[IO, String] + e <- mVar.modify(_ => IO.raiseError(error)).attempt + fallback = IO.sleep(100.millis) *> mVar.take + v <- IO.race(finished.get, fallback) + } yield (e, v) + + for (r <- task.unsafeToFuture()) yield { + r shouldBe (Left(error) -> Right(10)) + } + } }