From 7a39103d83a4d89c19cbc69094f49fa0ae1435f2 Mon Sep 17 00:00:00 2001 From: NthPortal Date: Fri, 10 Jan 2020 05:03:49 -0500 Subject: [PATCH] Add groupMap and groupMapReduce extensions --- .../scala/collection/compat/package.scala | 9 +++++- .../collection/compat/PackageShared.scala | 30 +++++++++++++++++++ .../scala/collection/compat/package.scala | 6 ++++ .../scala/collection/CollectionTest.scala | 14 +++++++++ 4 files changed, 58 insertions(+), 1 deletion(-) diff --git a/compat/src/main/scala-2.11/scala/collection/compat/package.scala b/compat/src/main/scala-2.11/scala/collection/compat/package.scala index c60b579f..5f7309a1 100644 --- a/compat/src/main/scala-2.11/scala/collection/compat/package.scala +++ b/compat/src/main/scala-2.11/scala/collection/compat/package.scala @@ -12,4 +12,11 @@ package scala.collection -package object compat extends compat.PackageShared +import scala.collection.generic.IsTraversableLike + +package object compat extends compat.PackageShared { + implicit def toTraversableLikeExtensionMethods[Repr](self: Repr)( + implicit traversable: IsTraversableLike[Repr]) + : TraversableLikeExtensionMethods[traversable.A, Repr] = + new TraversableLikeExtensionMethods[traversable.A, Repr](traversable.conversion(self)) +} diff --git a/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala b/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala index 76382afa..d319e912 100644 --- a/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala +++ b/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala @@ -243,6 +243,36 @@ class TraversableExtensionMethods[A](private val self: c.Traversable[A]) extends def iterableFactory: GenericCompanion[Traversable] = self.companion } +class TraversableLikeExtensionMethods[A, Repr](private val self: c.GenTraversableLike[A, Repr]) + extends AnyVal { + + def groupMap[K, B, That](key: A => K)(f: A => B)( + implicit bf: CanBuildFrom[Repr, B, That]): Map[K, That] = { + val map = m.Map.empty[K, m.Builder[B, That]] + for (elem <- self) { + val k = key(elem) + val bldr = map.getOrElseUpdate(k, bf(self.repr)) + bldr += f(elem) + } + val res = Map.newBuilder[K, That] + for ((k, bldr) <- map) res += ((k, bldr.result())) + res.result() + } + + def groupMapReduce[K, B](key: A => K)(f: A => B)(reduce: (B, B) => B): Map[K, B] = { + val map = m.Map.empty[K, B] + for (elem <- self) { + val k = key(elem) + val v = map.get(k) match { + case Some(b) => reduce(b, f(elem)) + case None => f(elem) + } + map.put(k, v) + } + map.toMap + } +} + class MapViewExtensionMethods[K, V, C <: scala.collection.Map[K, V]]( private val self: IterableView[(K, V), C]) extends AnyVal { diff --git a/compat/src/main/scala-2.12/scala/collection/compat/package.scala b/compat/src/main/scala-2.12/scala/collection/compat/package.scala index 987e35b4..ec34c117 100644 --- a/compat/src/main/scala-2.12/scala/collection/compat/package.scala +++ b/compat/src/main/scala-2.12/scala/collection/compat/package.scala @@ -12,6 +12,7 @@ package scala.collection +import scala.collection.generic.IsTraversableLike import scala.collection.{mutable => m} package object compat extends compat.PackageShared { @@ -24,4 +25,9 @@ package object compat extends compat.PackageShared { def from[K: Ordering, V](source: TraversableOnce[(K, V)]): m.SortedMap[K, V] = build(m.SortedMap.newBuilder[K, V], source) } + + implicit def toTraversableLikeExtensionMethods[Repr](self: Repr)( + implicit traversable: IsTraversableLike[Repr]) + : TraversableLikeExtensionMethods[traversable.A, Repr] = + new TraversableLikeExtensionMethods[traversable.A, Repr](traversable.conversion(self)) } diff --git a/compat/src/test/scala/test/scala/collection/CollectionTest.scala b/compat/src/test/scala/test/scala/collection/CollectionTest.scala index 3401d5ed..755366af 100644 --- a/compat/src/test/scala/test/scala/collection/CollectionTest.scala +++ b/compat/src/test/scala/test/scala/collection/CollectionTest.scala @@ -81,4 +81,18 @@ class CollectionTest { assertFalse(it1.iterator.sameElements(it2)) assertTrue(it2.iterator.sameElements(it3)) } + + @Test + def groupMap(): Unit = { + val res = Seq("foo", "test", "bar", "baz") + .groupMap(_.length)(_.toUpperCase()) + assertEquals(Map(3 -> Seq("FOO", "BAR", "BAZ"), 4 -> Seq("TEST")), res) + } + + @Test + def groupMapReduce(): Unit = { + val res = Seq("foo", "test", "bar", "baz") + .groupMapReduce(_.length)(_ => 1)(_ + _) + assertEquals(Map(3 -> 3, 4 -> 1), res) + } }