Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #1316: Add the MonadDefer type-class #1552

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions core/src/main/scala/cats/Eval.scala
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,12 @@ object Eval extends EvalInstances {

private[cats] trait EvalInstances extends EvalInstances0 {

implicit val catsBimonadForEval: Bimonad[Eval] with MonadError[Eval, Throwable] =
new Bimonad[Eval] with MonadError[Eval, Throwable] {
implicit val catsBimonadForEval: Bimonad[Eval] with MonadError[Eval, Throwable] with MonadDefer[Eval] =
new Bimonad[Eval] with MonadError[Eval, Throwable] with MonadDefer[Eval] {
override def map[A, B](fa: Eval[A])(f: A => B): Eval[B] = fa.map(f)
def pure[A](a: A): Eval[A] = Now(a)
override def delay[A](a: => A): Eval[A] = Always(a)
override def defer[A](fa: => Eval[A]): Eval[A] = Eval.defer(fa)
def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f)
def extract[A](la: Eval[A]): A = la.value
def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa))
Expand Down
28 changes: 28 additions & 0 deletions core/src/main/scala/cats/MonadDefer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package cats

import simulacrum.typeclass

/**
* A [[Monad monad]] that allows for arbitrarily delaying the
* evaluation of an operation, triggering its execution on each run.
*
* Instances of this type-class have the following properties:
*
* - suspend any side-effects for later, until evaluated
* - suspension has `always` semantics, meaning that on each
* evaluation of `F[_]` the evaluation, along with any
* side-effects, get repeated
* - the `flatMap` operation is stack safe and can be
* used in recursive loops
*/
@typeclass trait MonadDefer[F[_]] extends Monad[F] {
/**
* Returns an `F[A]` that evaluates the provided by-name `fa`
* parameter on each run. In essence it builds an `F[A]` factory.
*/
def defer[A](fa: => F[A]): F[A]

/** Lifts the given by-name value in the `F[_]` context. */
def delay[A](a: => A): F[A] =
defer(pure(a))
}
61 changes: 61 additions & 0 deletions laws/src/main/scala/cats/laws/MonadDeferLaws.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package cats
package laws

import cats.syntax.all._
import cats.laws.MonadDeferLaws.StatefulBox

trait MonadDeferLaws[F[_]] extends MonadLaws[F] {
implicit override def F: MonadDefer[F]

def delayEquivalenceWithPure[A](a: A): IsEq[F[A]] =
F.delay(a) <-> F.pure(a)

def delayEquivalenceWithDefer[A, B](a: A, f: A => B): IsEq[F[B]] =
F.delay(f(a)) <-> F.defer(F.pure(f(a)))

def delayRepeatsSideEffects[A, B](a: A, b: B, f: (A, B) => A): IsEq[F[A]] = {
val state = new StatefulBox(a)
val fa = F.delay(state.transform(a => f(a, b)))
fa.flatMap(_ => fa) <-> F.pure(f(f(a, b), b))
}

def deferRepeatsSideEffects[A, B](a: A, b: B, f: (A, B) => A): IsEq[F[A]] = {
val state = new StatefulBox(a)
val fa = F.defer(F.pure(state.transform(a => f(a, b))))
fa.flatMap(_ => fa) <-> F.pure(f(f(a, b), b))
}

lazy val flatMapStackSafety: IsEq[F[Int]] = {
// tailRecM expressed with flatMap
def loop[A, B](a: A)(f: A => F[Either[A, B]]): F[B] =
F.flatMap(f(a)) {
case Right(b) =>
F.pure(b)
case Left(nextA) =>
loop(nextA)(f)
}

val n = 50000
val res = loop(0)(i => F.pure(if (i < n) Either.left(i + 1) else Either.right(i)))
res <-> F.pure(n)
}
}

object MonadDeferLaws {
def apply[F[_]](implicit ev: MonadDefer[F]): MonadDeferLaws[F] =
new MonadDeferLaws[F] { def F: MonadDefer[F] = ev }

/**
* A boxed and synchronized variable to use for
* testing deferred side effects.
*/
final class StatefulBox[A](initial: A) {
private[this] var state = initial

def get: A =
synchronized(state)

def transform(f: A => A): A =
synchronized{ state = f(state); state }
}
}
48 changes: 48 additions & 0 deletions laws/src/main/scala/cats/laws/discipline/MonadDeferTests.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package cats
package laws
package discipline

import catalysts.Platform
import cats.laws.discipline.CartesianTests.Isomorphisms
import org.scalacheck.{Arbitrary, Cogen, Prop}
import org.scalacheck.Prop.forAll

trait MonadDeferTests[F[_]] extends MonadTests[F] {
def laws: MonadDeferLaws[F]

def monadDefer[A: Arbitrary: Eq, B: Arbitrary: Eq, C: Arbitrary: Eq](implicit
ArbFA: Arbitrary[F[A]],
ArbFB: Arbitrary[F[B]],
ArbFC: Arbitrary[F[C]],
ArbFAtoB: Arbitrary[F[A => B]],
ArbFBtoC: Arbitrary[F[B => C]],
CogenA: Cogen[A],
CogenB: Cogen[B],
CogenC: Cogen[C],
EqFA: Eq[F[A]],
EqFB: Eq[F[B]],
EqFC: Eq[F[C]],
EqFABC: Eq[F[(A, B, C)]],
EqFInt: Eq[F[Int]],
iso: Isomorphisms[F]
): RuleSet = {
new RuleSet {
def name: String = "monadDefer"
def bases: Seq[(String, RuleSet)] = Nil
def parents: Seq[RuleSet] = Seq(monad[A, B, C])
def props: Seq[(String, Prop)] = Seq(
"delay equivalence with pure" -> forAll(laws.delayEquivalenceWithPure[A] _),
"delay equivalence with defer" -> forAll(laws.delayEquivalenceWithDefer[A, B] _),
"delay repeats side effects" -> forAll(laws.delayRepeatsSideEffects[A, B] _),
"defer repeats side effects" -> forAll(laws.deferRepeatsSideEffects[A, B] _)
) ++ (if (Platform.isJvm) Seq[(String, Prop)]("flatMap stack safety" -> Prop.lzy(laws.flatMapStackSafety)) else Seq.empty)
}
}
}

object MonadDeferTests {
def apply[F[_]: MonadDefer]: MonadDeferTests[F] =
new MonadDeferTests[F] {
def laws: MonadDeferLaws[F] = MonadDeferLaws[F]
}
}
5 changes: 4 additions & 1 deletion tests/src/test/scala/cats/tests/EvalTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package tests
import scala.math.min
import scala.util.Try
import cats.laws.ComonadLaws
import cats.laws.discipline.{BimonadTests, CartesianTests, MonadErrorTests, ReducibleTests, SerializableTests}
import cats.laws.discipline.{BimonadTests, CartesianTests, MonadErrorTests, ReducibleTests, SerializableTests, MonadDeferTests}
import cats.laws.discipline.arbitrary._
import cats.kernel.laws.{GroupLaws, OrderLaws}

Expand Down Expand Up @@ -125,6 +125,7 @@ class EvalTests extends CatsSuite {
{
implicit val iso = CartesianTests.Isomorphisms.invariant[Eval]
checkAll("Eval[Int]", BimonadTests[Eval].bimonad[Int, Int, Int])
checkAll("Eval[Int]", MonadDeferTests[Eval].monadDefer[Int, Int, Int])

{
// we need exceptions which occur during .value calls to be
Expand All @@ -135,7 +136,9 @@ class EvalTests extends CatsSuite {
checkAll("Eval[Int]", MonadErrorTests[Eval, Throwable].monadError[Int, Int, Int])
}
}

checkAll("Bimonad[Eval]", SerializableTests.serializable(Bimonad[Eval]))
checkAll("MonadDefer[Eval]", SerializableTests.serializable(MonadDefer[Eval]))
checkAll("MonadError[Eval, Throwable]", SerializableTests.serializable(MonadError[Eval, Throwable]))

checkAll("Eval[Int]", ReducibleTests[Eval].reducible[Option, Int, Int])
Expand Down
1 change: 0 additions & 1 deletion tests/src/test/scala/cats/tests/TryTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package tests
import cats.laws.{ApplicativeLaws, CoflatMapLaws, FlatMapLaws, MonadLaws}
import cats.laws.discipline._
import cats.laws.discipline.arbitrary._

import scala.util.{Success, Try}

class TryTests extends CatsSuite {
Expand Down