Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve traverseViaChain API #3535

Merged
merged 5 commits into from
Aug 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions alleycats-core/src/main/scala/alleycats/std/map.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package std

import cats._
import cats.data.Chain
import cats.kernel.instances.StaticMethods.wrapMutableIndexedSeq

object map extends MapInstances

Expand All @@ -15,7 +16,11 @@ trait MapInstances {
def traverse[G[_], A, B](fa: Map[K, A])(f: A => G[B])(implicit G: Applicative[G]): G[Map[K, B]] =
if (fa.isEmpty) G.pure(Map.empty[K, B])
else
G.map(Chain.traverseViaChain(fa.iterator) {
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[(K, A)]()
as ++= fa
wrapMutableIndexedSeq(as)
} {
case (k, a) => G.map(f(a))((k, _))
}) { chain => chain.foldLeft(Map.empty[K, B]) { case (m, (k, b)) => m.updated(k, b) } }

Expand Down Expand Up @@ -62,7 +67,11 @@ trait MapInstances {
def traverseFilter[G[_], A, B](fa: Map[K, A])(f: A => G[Option[B]])(implicit G: Applicative[G]): G[Map[K, B]] =
if (fa.isEmpty) G.pure(Map.empty[K, B])
else
G.map(Chain.traverseFilterViaChain(fa.iterator) {
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[(K, A)]()
as ++= fa
wrapMutableIndexedSeq(as)
} {
case (k, a) =>
G.map(f(a)) { optB =>
if (optB.isDefined) Some((k, optB.get))
Expand Down
42 changes: 21 additions & 21 deletions core/src/main/scala/cats/data/Chain.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import Chain._
import cats.kernel.instances.StaticMethods

import scala.annotation.tailrec
import scala.collection.immutable.{SortedMap, TreeSet}
import scala.collection.immutable.{SortedMap, TreeSet, IndexedSeq => ImIndexedSeq}
import scala.collection.mutable.ListBuffer

/**
Expand Down Expand Up @@ -382,14 +382,6 @@ sealed abstract class Chain[+A] {
go(this, Chain.nil)
}

/**
* Applies the supplied function to each element, left to right.
*/
final private def foreach(f: A => Unit): Unit =
foreachUntil { a =>
f(a); false
}

/**
* Applies the supplied function to each element, left to right, but stops when true is returned
*/
Expand Down Expand Up @@ -508,11 +500,11 @@ sealed abstract class Chain[+A] {
val builder = new StringBuilder("Chain(")
var first = true

foreach { a =>
foreachUntil { a =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just miscellaneous clean-up, right, since foreach isn't used anywhere else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. foreach is used exactly once and is private. Just use foreachUntil (or iterator).

if (first) {
builder ++= AA.show(a); first = false
} else builder ++= ", " + AA.show(a)
()
false
}
builder += ')'
builder.result
Expand Down Expand Up @@ -629,13 +621,13 @@ object Chain extends ChainInstances {
def apply[A](as: A*): Chain[A] =
fromSeq(as)

def traverseViaChain[G[_], A, B](iter: Iterator[A])(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] =
if (!iter.hasNext) G.pure(Chain.nil)
def traverseViaChain[G[_], A, B](
as: ImIndexedSeq[A]
)(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] =
if (as.isEmpty) G.pure(Chain.nil)
else {
// we branch out by this factor
val width = 128
val as = collection.mutable.Buffer[A]()
as ++= iter
// By making a tree here we don't blow the stack
// even if the List is very long
// by construction, this is never called with start == end
Expand Down Expand Up @@ -676,14 +668,12 @@ object Chain extends ChainInstances {
}

def traverseFilterViaChain[G[_], A, B](
iter: Iterator[A]
as: ImIndexedSeq[A]
)(f: A => G[Option[B]])(implicit G: Applicative[G]): G[Chain[B]] =
if (!iter.hasNext) G.pure(Chain.nil)
if (as.isEmpty) G.pure(Chain.nil)
else {
// we branch out by this factor
val width = 128
val as = collection.mutable.Buffer[A]()
as ++= iter
// By making a tree here we don't blow the stack
// even if the List is very long
// by construction, this is never called with start == end
Expand Down Expand Up @@ -862,7 +852,12 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {

def traverse[G[_], A, B](fa: Chain[A])(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] =
if (fa.isEmpty) G.pure(Chain.nil)
else traverseViaChain(fa.iterator)(f)
else
traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
StaticMethods.wrapMutableIndexedSeq(as)
}(f)

def empty[A]: Chain[A] = Chain.nil
def combineK[A](c: Chain[A], c2: Chain[A]): Chain[A] = Chain.concat(c, c2)
Expand Down Expand Up @@ -963,7 +958,12 @@ sealed abstract private[data] class ChainInstances extends ChainInstances1 {

def traverseFilter[G[_], A, B](fa: Chain[A])(f: A => G[Option[B]])(implicit G: Applicative[G]): G[Chain[B]] =
if (fa.isEmpty) G.pure(Chain.nil)
else traverseFilterViaChain(fa.iterator)(f)
else
traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa.iterator
StaticMethods.wrapMutableIndexedSeq(as)
}(f)

override def filterA[G[_], A](fa: Chain[A])(f: A => G[Boolean])(implicit G: Applicative[G]): G[Chain[A]] =
traverse
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/scala/cats/instances/list.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cats
package instances

import cats.data.{Chain, ZipList}
import cats.kernel.instances.StaticMethods.wrapMutableIndexedSeq
import cats.syntax.show._

import scala.annotation.tailrec
Expand Down Expand Up @@ -87,7 +88,12 @@ trait ListInstances extends cats.kernel.instances.ListInstances {

def traverse[G[_], A, B](fa: List[A])(f: A => G[B])(implicit G: Applicative[G]): G[List[B]] =
if (fa.isEmpty) G.pure(Nil)
else G.map(Chain.traverseViaChain(fa.iterator)(f))(_.toList)
else
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa
wrapMutableIndexedSeq(as)
}(f))(_.toList)

def functor: Functor[List] = this

Expand Down Expand Up @@ -212,7 +218,12 @@ private[instances] trait ListInstancesBinCompat0 {

def traverseFilter[G[_], A, B](fa: List[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[List[B]] =
if (fa.isEmpty) G.pure(Nil)
else G.map(Chain.traverseFilterViaChain(fa.iterator)(f))(_.toList)
else
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa
wrapMutableIndexedSeq(as)
}(f))(_.toList)

override def filterA[G[_], A](fa: List[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[List[A]] =
traverse
Expand Down
13 changes: 11 additions & 2 deletions core/src/main/scala/cats/instances/queue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cats
package instances

import cats.data.Chain
import cats.kernel.instances.StaticMethods.wrapMutableIndexedSeq
import cats.syntax.show._
import scala.annotation.tailrec
import scala.collection.immutable.Queue
Expand Down Expand Up @@ -82,7 +83,11 @@ trait QueueInstances extends cats.kernel.instances.QueueInstances {
def traverse[G[_], A, B](fa: Queue[A])(f: A => G[B])(implicit G: Applicative[G]): G[Queue[B]] =
if (fa.isEmpty) G.pure(Queue.empty[B])
else
G.map(Chain.traverseViaChain(fa.iterator)(f)) { chain =>
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa
wrapMutableIndexedSeq(as)
}(f)) { chain =>
chain.foldLeft(Queue.empty[B])(_ :+ _)
}

Expand Down Expand Up @@ -177,7 +182,11 @@ private object QueueInstances {
def traverseFilter[G[_], A, B](fa: Queue[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[Queue[B]] =
if (fa.isEmpty) G.pure(Queue.empty[B])
else
G.map(Chain.traverseFilterViaChain(fa.iterator)(f)) { chain =>
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[A]()
as ++= fa
wrapMutableIndexedSeq(as)
}(f)) { chain =>
chain.foldLeft(Queue.empty[B])(_ :+ _)
}

Expand Down
13 changes: 11 additions & 2 deletions core/src/main/scala/cats/instances/sortedMap.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package cats.instances
import cats._
import cats.data.{Chain, Ior}
import cats.kernel.{CommutativeMonoid, CommutativeSemigroup}
import cats.kernel.instances.StaticMethods.wrapMutableIndexedSeq

import scala.annotation.tailrec
import scala.collection.immutable.SortedMap
Expand Down Expand Up @@ -35,7 +36,11 @@ trait SortedMapInstances extends SortedMapInstances2 {
implicit val ordering: Ordering[K] = fa.ordering
if (fa.isEmpty) G.pure(SortedMap.empty[K, B])
else
G.map(Chain.traverseViaChain(fa.iterator) {
G.map(Chain.traverseViaChain {
val as = collection.mutable.ArrayBuffer[(K, A)]()
as ++= fa
wrapMutableIndexedSeq(as)
} {
case (k, a) => G.map(f(a))((k, _))
}) { chain => chain.foldLeft(SortedMap.empty[K, B]) { case (m, (k, b)) => m.updated(k, b) } }
}
Expand Down Expand Up @@ -194,7 +199,11 @@ private[instances] trait SortedMapInstancesBinCompat0 {
implicit val ordering: Ordering[K] = fa.ordering
if (fa.isEmpty) G.pure(SortedMap.empty[K, B])
else
G.map(Chain.traverseFilterViaChain(fa.iterator) {
G.map(Chain.traverseFilterViaChain {
val as = collection.mutable.ArrayBuffer[(K, A)]()
as ++= fa
wrapMutableIndexedSeq(as)
} {
case (k, a) =>
G.map(f(a)) { optB =>
if (optB.isDefined) Some((k, optB.get))
Expand Down
97 changes: 2 additions & 95 deletions core/src/main/scala/cats/instances/vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,52 +93,7 @@ trait VectorInstances extends cats.kernel.instances.VectorInstances {
}

final override def traverse[G[_], A, B](fa: Vector[A])(f: A => G[B])(implicit G: Applicative[G]): G[Vector[B]] =
if (fa.isEmpty) G.pure(empty)
else {
// this is a specialized version of Chain.traverseViaChain since
// we don't need to materialize the Vector first

// we branch out by this factor
val width = 128
// By making a tree here we don't blow the stack
// even if the List is very long
// by construction, this is never called with start == end
def loop(start: Int, end: Int): Eval[G[Chain[B]]] =
if (end - start <= width) {
// Here we are at the leafs of the trees
// we don't use map2Eval since it is always
// at most width in size.
val aend = fa(end - 1)
var flist = Eval.later(G.map(f(aend))(_ :: Nil))
var idx = end - 2
while (start <= idx) {
val a = fa(idx)
// don't capture a var in the defer
val right = flist
flist = Eval.defer(G.map2Eval(f(a), right)(_ :: _))
idx = idx - 1
}
flist.map { glist => G.map(glist)(Chain.fromSeq(_)) }
} else {
// we have width + 1 or more nodes left
val step = (end - start) / width

var fchain = Eval.defer(loop(start, start + step))
var start0 = start + step
var end0 = start0 + step

while (start0 < end) {
val end1 = math.min(end, end0)
val right = loop(start0, end1)
fchain = fchain.flatMap(G.map2Eval(_, right)(_.concat(_)))
start0 = start0 + step
end0 = end0 + step
}
fchain
}

G.map(loop(0, fa.size).value)(_.toVector)
}
G.map(Chain.traverseViaChain(fa)(f))(_.toVector)

override def mapWithIndex[A, B](fa: Vector[A])(f: (A, Int) => B): Vector[B] =
fa.iterator.zipWithIndex.map(ai => f(ai._1, ai._2)).toVector
Expand Down Expand Up @@ -225,55 +180,7 @@ private[instances] trait VectorInstancesBinCompat0 {
override def flattenOption[A](fa: Vector[Option[A]]): Vector[A] = fa.flatten

def traverseFilter[G[_], A, B](fa: Vector[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[Vector[B]] =
if (fa.isEmpty) G.pure(Vector.empty[B])
else {
// we branch out by this factor
val width = 128
// By making a tree here we don't blow the stack
// even if the List is very long
// by construction, this is never called with start == end
def loop(start: Int, end: Int): Eval[G[Chain[B]]] =
if (end - start <= width) {
// Here we are at the leafs of the trees
// we don't use map2Eval since it is always
// at most width in size.
val aend = fa(end - 1)
var flist = Eval.later(G.map(f(aend)) { optB =>
if (optB.isDefined) optB.get :: Nil
else Nil
})
var idx = end - 2
while (start <= idx) {
val a = fa(idx)
// don't capture a var in the defer
val right = flist
flist = Eval.defer(G.map2Eval(f(a), right) { (optB, tail) =>
if (optB.isDefined) optB.get :: tail
else tail
})
idx = idx - 1
}
flist.map { glist => G.map(glist)(Chain.fromSeq(_)) }
} else {
// we have width + 1 or more nodes left
val step = (end - start) / width

var fchain = Eval.defer(loop(start, start + step))
var start0 = start + step
var end0 = start0 + step

while (start0 < end) {
val end1 = math.min(end, end0)
val right = loop(start0, end1)
fchain = fchain.flatMap(G.map2Eval(_, right)(_.concat(_)))
start0 = start0 + step
end0 = end0 + step
}
fchain
}

G.map(loop(0, fa.size).value)(_.toVector)
}
G.map(Chain.traverseFilterViaChain(fa)(f))(_.toVector)

override def filterA[G[_], A](fa: Vector[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[Vector[A]] =
traverse
Expand Down
18 changes: 18 additions & 0 deletions kernel/src/main/scala/cats/kernel/instances/StaticMethods.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package cats
package kernel
package instances

import scala.collection.immutable.{IndexedSeq => ImIndexedSeq}
import scala.collection.mutable
import compat.scalaVersionSpecific._

@suppressUnusedImportWarningForScalaVersionSpecific
object StaticMethods extends cats.kernel.compat.HashCompat {

Expand All @@ -17,6 +19,22 @@ object StaticMethods extends cats.kernel.compat.HashCompat {
def iterator: Iterator[(K, V)] = m.iterator
}

/**
* When you "own" this m, and will not mutate it again, this
* is safe to call. It is unsafe to call this, then mutate
* the original collection.
*
* You are giving up ownership when calling this method
*/
def wrapMutableIndexedSeq[A](m: mutable.IndexedSeq[A]): ImIndexedSeq[A] =
new WrappedIndexedSeq(m)

private[kernel] class WrappedIndexedSeq[A](m: mutable.IndexedSeq[A]) extends ImIndexedSeq[A] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can never remember exactly how this works inside objects, but is there any reason to make this package-private (which is just public in the JVM API) instead of just private (which I think should be properly JVM-private)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not that I know of, I just copied the idiom below. Should I fix that up?

override def length: Int = m.length
override def apply(i: Int): A = m(i)
override def iterator: Iterator[A] = m.iterator
}

// scalastyle:off return
def iteratorCompare[A](xs: Iterator[A], ys: Iterator[A])(implicit ev: Order[A]): Int = {
while (true) {
Expand Down