Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Complete the MVar2 interface #912

Merged
merged 1 commit into from
Jul 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions core/shared/src/main/scala/cats/effect/concurrent/MVar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ abstract class MVar[F[_], A] extends MVarDocumentation {
/**
* Modify the context `F` using transformation `f`.
*/
@deprecated("`mapK` is deprecated in favor of `imapK` which supports the new invariant `MVar2` interface", "2.2.0")
def mapK[G[_]](f: F ~> G): MVar[G, A] =
new TransformedMVar(this, f)
}
Expand All @@ -139,26 +140,64 @@ abstract class MVar2[F[_], A] extends MVar[F, A] {

/**
* Replaces a value in MVar and returns the old value.

*
* @note This operation is only safe from deadlocks if there are no other producers for this `MVar`.
*
* @param newValue is a new value
* @return the value taken
*/
def swap(newValue: A): F[A]

/**
* Returns the value without waiting or modifying.
*
* This operation is atomic.
* A non-blocking version of [[read]].
*
* @return an Option holding the current value, None means it was empty
*/
def tryRead: F[Option[A]]

/**
* Modify the context `F` using transformation `f`.
* Applies the effectful function `f` on the contents of this `MVar`. In case of failure, it sets the contents of the
* `MVar` to the original value.
*
* @note This operation is only safe from deadlocks if there are no other producers for this `MVar`.
*
* @param f effectful function that operates on the contents of this `MVar`
* @return the value produced by applying `f` to the contents of this `MVar`
*/
def use[B](f: A => F[B]): F[B]

/**
* Modifies the contents of the `MVar` using the effectful function `f`, but also allows for returning a value derived
* from the original contents of the `MVar`. Like [[use]], in case of failure, it sets the contents of the `MVar` to
* the original value.
*
* @note This operation is only safe from deadlocks if there are no other producers for this `MVar`.
*
* @param f effectful function that operates on the contents of this `MVar`
* @return the second value produced by applying `f` to the contents of this `MVar`
*/
def modify[B](f: A => F[(A, B)]): F[B]

/**
* Modifies the contents of the `MVar` using the effectful function `f`. Like [[use]], in case of failure, it sets the
* contents of the `MVar` to the original value.
*
* @note This operation is only safe from deadlocks if there are no other producers for this `MVar`.
*
* @param f effectful function that operates on the contents of this `MVar`
* @return no useful value. Executed only for the effects.
*/
def modify_(f: A => F[A]): F[Unit]

/**
* Modify the context `F` using natural isomorphism `f` with `g`.
*
* @param f functor transformation from `F` to `G`
* @param g functor transformation from `G` to `F`
* @return `MVar2` with a modified context `G` derived using a natural isomorphism from `F`
*/
override def mapK[G[_]](f: F ~> G): MVar2[G, A] =
new TransformedMVar2(this, f)
def imapK[G[_]](f: F ~> G, g: G ~> F): MVar2[G, A] =
new TransformedMVar2(this, f, g)
}

/** Builders for [[MVar]]. */
Expand Down Expand Up @@ -304,7 +343,9 @@ object MVar {
override def read: G[A] = trans(underlying.read)
}

final private[concurrent] class TransformedMVar2[F[_], G[_], A](underlying: MVar2[F, A], trans: F ~> G)
final private[concurrent] class TransformedMVar2[F[_], G[_], A](underlying: MVar2[F, A],
trans: F ~> G,
inverse: G ~> F)
extends MVar2[G, A] {
override def isEmpty: G[Boolean] = trans(underlying.isEmpty)
override def put(a: A): G[Unit] = trans(underlying.put(a))
Expand All @@ -314,5 +355,8 @@ object MVar {
override def read: G[A] = trans(underlying.read)
override def tryRead: G[Option[A]] = trans(underlying.tryRead)
override def swap(newValue: A): G[A] = trans(underlying.swap(newValue))
override def use[B](f: A => G[B]): G[B] = trans(underlying.use(a => inverse(f(a))))
override def modify[B](f: A => G[(A, B)]): G[B] = trans(underlying.modify(a => inverse(f(a))))
override def modify_(f: A => G[A]): G[Unit] = trans(underlying.modify_(a => inverse(f(a))))
}
}
14 changes: 14 additions & 0 deletions core/shared/src/main/scala/cats/effect/internals/MVarAsync.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,20 @@ final private[effect] class MVarAsync[F[_], A] private (initial: MVarAsync.State
}
}

def use[B](f: A => F[B]): F[B] =
modify(a => F.map(f(a))((a, _)))

def modify[B](f: A => F[(A, B)]): F[B] =
F.flatMap(take) { a =>
F.flatMap(F.onError(f(a)) { case _ => put(a) }) {
case (newA, b) =>
F.as(put(newA), b)
}
}

def modify_(f: A => F[A]): F[Unit] =
modify(a => F.map(f(a))((_, ())))

@tailrec
private def unsafeTryPut(a: A): F[Boolean] =
stateRef.get match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package internals

import java.util.concurrent.atomic.AtomicReference

import cats.effect.concurrent.MVar2
import cats.effect.concurrent.{MVar2, Ref}
import cats.effect.internals.Callback.rightUnit

import scala.annotation.tailrec
Expand Down Expand Up @@ -77,10 +77,35 @@ final private[effect] class MVarConcurrent[F[_], A] private (initial: MVarConcur
}

def swap(newValue: A): F[A] =
F.flatMap(take) { oldValue =>
F.map(put(newValue))(_ => oldValue)
F.continual(take) {
case Left(t) => F.raiseError(t)
case Right(oldValue) => F.as(put(newValue), oldValue)
}

def use[B](f: A => F[B]): F[B] =
modify(a => F.map(f(a))((a, _)))

def modify[B](f: A => F[(A, B)]): F[B] =
F.bracket(Ref[F].of[Option[A]](None)) { signal =>
F.flatMap(F.continual[A, A](take) {
case Left(t) => F.raiseError(t)
case Right(a) => F.as(signal.set(Some(a)), a)
}) { a =>
F.continual[(A, B), B](f(a)) {
case Left(t) => F.raiseError(t)
case Right((newA, b)) => F.as(signal.set(Some(newA)), b)
}
}
} { signal =>
F.flatMap(signal.get) {
case Some(a) => put(a)
case None => F.unit
}
}

def modify_(f: A => F[A]): F[Unit] =
modify(a => F.map(f(a))((_, ())))

@tailrec private def unsafeTryPut(a: A): F[Boolean] =
stateRef.get match {
case WaitForTake(_, _) => F.pure(false)
Expand Down
81 changes: 81 additions & 0 deletions core/shared/src/test/scala/cats/effect/concurrent/MVarTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,55 @@ class MVarConcurrentTests extends BaseMVarTests {
r shouldBe Right(0)
}
}

test("swap is cancelable on take") {
val task = for {
mVar <- empty[Int]
finished <- Deferred.uncancelable[IO, Int]
fiber <- mVar.swap(20).flatMap(finished.complete).start
_ <- fiber.cancel
_ <- mVar.put(10)
fallback = IO.sleep(100.millis) *> mVar.take
v <- IO.race(finished.get, fallback)
} yield v

for (r <- task.unsafeToFuture()) yield {
r shouldBe Right(10)
}
}

test("modify is cancelable on take") {
val task = for {
mVar <- empty[Int]
finished <- Deferred.uncancelable[IO, String]
fiber <- mVar.modify(n => IO.pure((n * 2, n.show))).flatMap(finished.complete).start
_ <- fiber.cancel
_ <- mVar.put(10)
fallback = IO.sleep(100.millis) *> mVar.take
v <- IO.race(finished.get, fallback)
} yield v

for (r <- task.unsafeToFuture()) yield {
r shouldBe Right(10)
}
}

test("modify is cancelable on f") {
val task = for {
mVar <- empty[Int]
finished <- Deferred.uncancelable[IO, String]
fiber <- mVar.modify(n => IO.never *> IO.pure((n * 2, n.show))).flatMap(finished.complete).start
_ <- mVar.put(10)
_ <- IO.sleep(10.millis)
_ <- fiber.cancel
fallback = IO.sleep(100.millis) *> mVar.take
v <- IO.race(finished.get, fallback)
} yield v

for (r <- task.unsafeToFuture()) yield {
r shouldBe Right(10)
}
}
}

class MVarAsyncTests extends BaseMVarTests {
Expand Down Expand Up @@ -424,4 +473,36 @@ abstract class BaseMVarTests extends AsyncFunSuite with Matchers {
r shouldBe count * 2
}
}

test("put; take; modify; put") {
val task = for {
mVar <- empty[Int]
_ <- mVar.put(10)
_ <- mVar.take
fiber <- mVar.modify(n => IO.pure((n * 2, n.toString))).start
_ <- mVar.put(20)
s <- fiber.join
v <- mVar.take
} yield (s, v)

for (r <- task.unsafeToFuture()) yield {
r shouldBe ("20" -> 40)
}
}

test("modify replaces the original value of the mvar on error") {
val error = new Exception("Boom!")
val task = for {
mVar <- empty[Int]
_ <- mVar.put(10)
finished <- Deferred.uncancelable[IO, String]
e <- mVar.modify(_ => IO.raiseError(error)).attempt
fallback = IO.sleep(100.millis) *> mVar.take
v <- IO.race(finished.get, fallback)
} yield (e, v)

for (r <- task.unsafeToFuture()) yield {
r shouldBe (Left(error) -> Right(10))
}
}
}