From c8776b7b33380f9cb93c2b23a77437d9687cc100 Mon Sep 17 00:00:00 2001 From: NthPortal Date: Sun, 12 Jan 2020 04:36:10 -0500 Subject: [PATCH] Add tapEach --- .../scala/collection/compat/PackageShared.scala | 3 +++ .../test/scala/collection/CollectionTest.scala | 15 +++++++++++++++ 2 files changed, 18 insertions(+) diff --git a/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala b/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala index d319e912..09d8fcbc 100644 --- a/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala +++ b/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala @@ -233,6 +233,7 @@ class IteratorExtensionMethods[A](private val self: c.Iterator[A]) extends AnyVa self.sameElements(that.iterator) } def concat[B >: A](that: c.TraversableOnce[B]): c.TraversableOnce[B] = self ++ that + def tapEach[U](f: A => U): c.Iterator[A] = self.map(a => { f(a); a }) } class TraversableOnceExtensionMethods[A](private val self: c.TraversableOnce[A]) extends AnyVal { @@ -245,6 +246,8 @@ class TraversableExtensionMethods[A](private val self: c.Traversable[A]) extends class TraversableLikeExtensionMethods[A, Repr](private val self: c.GenTraversableLike[A, Repr]) extends AnyVal { + def tapEach[U](f: A => U)(implicit bf: CanBuildFrom[Repr, A, Repr]): Repr = + self.map(a => { f(a); a }) def groupMap[K, B, That](key: A => K)(f: A => B)( implicit bf: CanBuildFrom[Repr, B, That]): Map[K, That] = { diff --git a/compat/src/test/scala/test/scala/collection/CollectionTest.scala b/compat/src/test/scala/test/scala/collection/CollectionTest.scala index 755366af..c7ed505e 100644 --- a/compat/src/test/scala/test/scala/collection/CollectionTest.scala +++ b/compat/src/test/scala/test/scala/collection/CollectionTest.scala @@ -95,4 +95,19 @@ class CollectionTest { .groupMapReduce(_.length)(_ => 1)(_ + _) assertEquals(Map(3 -> 3, 4 -> 1), res) } + + @Test + def tapEach(): Unit = { + var count = 0 + val it = Iterator(1, 2, 3).tapEach(count += _) + assertEquals(0, count) + it.foreach(_ => ()) + assertEquals(6, count) + List(1, 2, 3).tapEach(count += _) + assertEquals(12, count) + val stream = Stream(1, 2, 3).tapEach(count += _) + assertEquals(13, count) + stream.force + assertEquals(18, count) + } }