Skip to content

Commit

Permalink
Reduce stack depth in StateT (#1466)
Browse files Browse the repository at this point in the history
* Reduce stack depth in StateT

* only use FlatMap for product

* Lower required typeclasses

* Add more methods
  • Loading branch information
johnynek authored Dec 31, 2016
1 parent f8c766f commit a6efd9d
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 30 deletions.
100 changes: 73 additions & 27 deletions core/src/main/scala/cats/data/StateT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,26 +11,46 @@ import cats.syntax.either._
*/
final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable {

def flatMap[B](fas: A => StateT[F, S, B])(implicit F: Monad[F]): StateT[F, S, B] =
StateT(s =>
F.flatMap(runF) { fsf =>
F.flatMap(fsf(s)) { case (s, a) =>
def flatMap[B](fas: A => StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, B] =
StateT.applyF(F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) =>
fas(a).run(s)
}
})
}
})

def flatMapF[B](faf: A => F[B])(implicit F: Monad[F]): StateT[F, S, B] =
StateT(s =>
F.flatMap(runF) { fsf =>
F.flatMap(fsf(s)) { case (s, a) =>
F.map(faf(a))((s, _))
}
def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): StateT[F, S, B] =
StateT.applyF(F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) }
}
)
})

def map[B](f: A => B)(implicit F: Monad[F]): StateT[F, S, B] =
def map[B](f: A => B)(implicit F: Functor[F]): StateT[F, S, B] =
transform { case (s, a) => (s, f(a)) }

def map2[B, Z](sb: StateT[F, S, B])(fn: (A, B) => Z)(implicit F: FlatMap[F]): StateT[F, S, Z] =
StateT.applyF(F.map2(runF, sb.runF) { (ssa, ssb) =>
ssa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) =>
F.map(ssb(s)) { case (s, b) => (s, fn(a, b)) }
}
}
})

def map2Eval[B, Z](sb: Eval[StateT[F, S, B]])(fn: (A, B) => Z)(implicit F: FlatMap[F]): Eval[StateT[F, S, Z]] =
F.map2Eval(runF, sb.map(_.runF)) { (ssa, ssb) =>
ssa.andThen { fsa =>
F.flatMap(fsa) { case (s, a) =>
F.map(ssb(s)) { case (s, b) => (s, fn(a, b)) }
}
}
}.map(StateT.applyF)

def product[B](sb: StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, (A, B)] =
map2(sb)((_, _))

/**
* Run with the provided initial state value
*/
Expand Down Expand Up @@ -69,10 +89,13 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable
/**
* Like [[map]], but also allows the state (`S`) value to be modified.
*/
def transform[B](f: (S, A) => (S, B))(implicit F: Monad[F]): StateT[F, S, B] =
transformF { fsa =>
F.map(fsa){ case (s, a) => f(s, a) }
}
def transform[B](f: (S, A) => (S, B))(implicit F: Functor[F]): StateT[F, S, B] =
StateT.applyF(
F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
F.map(fsa) { case (s, a) => f(s, a) }
}
})

/**
* Like [[transform]], but allows the context to change from `F` to `G`.
Expand All @@ -98,31 +121,31 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable
* res1: Option[(GlobalEnv, Double)] = Some(((6,hello),5.0))
* }}}
*/
def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Monad[F]): StateT[F, R, A] =
StateT { r =>
F.flatMap(runF) { ff =>
def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Functor[F]): StateT[F, R, A] =
StateT.applyF(F.map(runF) { sfsa =>
{ r: R =>
val s = f(r)
val nextState = ff(s)
F.map(nextState) { case (s, a) => (g(r, s), a) }
val fsa = sfsa(s)
F.map(fsa) { case (s, a) => (g(r, s), a) }
}
}
})

/**
* Modify the state (`S`) component.
*/
def modify(f: S => S)(implicit F: Monad[F]): StateT[F, S, A] =
def modify(f: S => S)(implicit F: Functor[F]): StateT[F, S, A] =
transform((s, a) => (f(s), a))

/**
* Inspect a value from the input state, without modifying the state.
*/
def inspect[B](f: S => B)(implicit F: Monad[F]): StateT[F, S, B] =
def inspect[B](f: S => B)(implicit F: Functor[F]): StateT[F, S, B] =
transform((s, _) => (s, f(s)))

/**
* Get the input state, without modifying the state.
*/
def get(implicit F: Monad[F]): StateT[F, S, S] =
def get(implicit F: Functor[F]): StateT[F, S, S] =
inspect(identity)
}

Expand Down Expand Up @@ -182,11 +205,16 @@ private[data] sealed trait StateTInstances2 extends StateTInstances3 {
new StateTSemigroupK[F, S] { implicit def F = F0; implicit def G = G0 }
}

private[data] sealed trait StateTInstances3 {
private[data] sealed trait StateTInstances3 extends StateTInstances4 {
implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] =
new StateTMonad[F, S] { implicit def F = F0 }
}

private[data] sealed trait StateTInstances4 {
implicit def catsDataFunctorForStateT[F[_], S](implicit F0: Functor[F]): Functor[StateT[F, S, ?]] =
new StateTFunctor[F, S] { implicit def F = F0 }
}

// To workaround SI-7139 `object State` needs to be defined inside the package object
// together with the type alias.
private[data] abstract class StateFunctions {
Expand Down Expand Up @@ -220,6 +248,12 @@ private[data] abstract class StateFunctions {
def set[S](s: S): State[S, Unit] = State(_ => (s, ()))
}

private[data] sealed trait StateTFunctor[F[_], S] extends Functor[StateT[F, S, ?]] {
implicit def F: Functor[F]

def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f)
}

private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] {
implicit def F: Monad[F]

Expand All @@ -229,8 +263,20 @@ private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] {
def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] =
fa.flatMap(f)

override def ap[A, B](ff: StateT[F, S, A => B])(fa: StateT[F, S, A]): StateT[F, S, B] =
ff.map2(fa) { case (f, a) => f(a) }

override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f)

override def map2[A, B, Z](fa: StateT[F, S, A], fb: StateT[F, S, B])(fn: (A, B) => Z): StateT[F, S, Z] =
fa.map2(fb)(fn)

override def map2Eval[A, B, Z](fa: StateT[F, S, A], fb: Eval[StateT[F, S, B]])(fn: (A, B) => Z): Eval[StateT[F, S, Z]] =
fa.map2Eval(fb)(fn)

override def product[A, B](fa: StateT[F, S, A], fb: StateT[F, S, B]): StateT[F, S, (A, B)] =
fa.product(fb)

def tailRecM[A, B](a: A)(f: A => StateT[F, S, Either[A, B]]): StateT[F, S, B] =
StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) {
case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) }
Expand Down
24 changes: 21 additions & 3 deletions tests/src/test/scala/cats/tests/StateTTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class StateTTests extends CatsSuite {
}

test("State.get and StateT.get are consistent") {
forAll{ (s: String) =>
forAll{ (s: String) =>
val state: State[String, String] = State.get
val stateT: State[String, String] = StateT.get
state.run(s) should === (stateT.run(s))
Expand Down Expand Up @@ -195,7 +195,25 @@ class StateTTests extends CatsSuite {
}


implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataMonadForStateT(ListWrapper.monad))
implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataFunctorForStateT(ListWrapper.monad))

{
// F has a Functor
implicit val F: Functor[ListWrapper] = ListWrapper.monad
// We only need a Functor on F to find a Functor on StateT
Functor[StateT[ListWrapper, Int, ?]]
}

{
// F needs a Monad to do Eq on StateT
implicit val F: Monad[ListWrapper] = ListWrapper.monad
implicit val FS: Functor[StateT[ListWrapper, Int, ?]] = StateT.catsDataFunctorForStateT

checkAll("StateT[ListWrapper, Int, Int]", FunctorTests[StateT[ListWrapper, Int, ?]].functor[Int, Int, Int])
checkAll("Functor[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(Functor[StateT[ListWrapper, Int, ?]]))

Functor[StateT[ListWrapper, Int, ?]]
}

{
// F has a Monad
Expand Down Expand Up @@ -265,7 +283,7 @@ class StateTTests extends CatsSuite {
// F has a MonadError
implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]]
implicit val eqEitherTFA: Eq[EitherT[StateT[Option, Int , ?], Unit, Int]] = EitherT.catsDataEqForEitherT[StateT[Option, Int , ?], Unit, Int]

checkAll("StateT[Option, Int, Int]", MonadErrorTests[StateT[Option, Int , ?], Unit].monadError[Int, Int, Int])
checkAll("MonadError[StateT[Option, Int , ?], Unit]", SerializableTests.serializable(MonadError[StateT[Option, Int , ?], Unit]))
}
Expand Down

0 comments on commit a6efd9d

Please sign in to comment.