From c067ec2731be14ba8255bd827b83e081e9e01dce Mon Sep 17 00:00:00 2001 From: NthPortal Date: Sat, 11 Jan 2020 04:51:00 -0500 Subject: [PATCH] Add sizeCompare, sizeIs and lengthIs --- .../scala/collection/compat/package.scala | 8 ++- .../collection/compat/PackageShared.scala | 70 +++++++++++++++++++ .../scala/collection/compat/package.scala | 8 ++- .../scala/collection/CollectionTest.scala | 50 +++++++++++++ 4 files changed, 132 insertions(+), 4 deletions(-) diff --git a/compat/src/main/scala-2.11/scala/collection/compat/package.scala b/compat/src/main/scala-2.11/scala/collection/compat/package.scala index 5f7309a1..39dda6c4 100644 --- a/compat/src/main/scala-2.11/scala/collection/compat/package.scala +++ b/compat/src/main/scala-2.11/scala/collection/compat/package.scala @@ -13,10 +13,14 @@ package scala.collection import scala.collection.generic.IsTraversableLike +import scala.{collection => c} package object compat extends compat.PackageShared { implicit def toTraversableLikeExtensionMethods[Repr](self: Repr)( - implicit traversable: IsTraversableLike[Repr]) - : TraversableLikeExtensionMethods[traversable.A, Repr] = + implicit traversable: IsTraversableLike[Repr]) + : TraversableLikeExtensionMethods[traversable.A, Repr] = new TraversableLikeExtensionMethods[traversable.A, Repr](traversable.conversion(self)) + + implicit def toSeqExtensionMethods[A](self: c.Seq[A]): SeqExtensionMethods[A] = + new SeqExtensionMethods[A](self) } diff --git a/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala b/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala index 09d8fcbc..785fbcf4 100644 --- a/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala +++ b/compat/src/main/scala-2.11_2.12/scala/collection/compat/PackageShared.scala @@ -242,6 +242,76 @@ class TraversableOnceExtensionMethods[A](private val self: c.TraversableOnce[A]) class TraversableExtensionMethods[A](private val self: c.Traversable[A]) extends AnyVal { def iterableFactory: GenericCompanion[Traversable] = self.companion + + def sizeCompare(otherSize: Int): Int = SizeCompareImpl.sizeCompareInt(self)(otherSize) + def sizeIs: SizeCompareOps = new SizeCompareOps(self) + def sizeCompare(that: c.Traversable[_]): Int = SizeCompareImpl.sizeCompareColl(self)(that) +} + +class SeqExtensionMethods[A](private val self: c.Seq[A]) extends AnyVal { + def lengthIs: SizeCompareOps = new SizeCompareOps(self) +} + +class SizeCompareOps private[compat] (private val it: c.Traversable[_]) extends AnyVal { + import SizeCompareImpl._ + + /** Tests if the size of the collection is less than some value. */ + @inline def <(size: Int): Boolean = sizeCompareInt(it)(size) < 0 + + /** Tests if the size of the collection is less than or equal to some value. */ + @inline def <=(size: Int): Boolean = sizeCompareInt(it)(size) <= 0 + + /** Tests if the size of the collection is equal to some value. */ + @inline def ==(size: Int): Boolean = sizeCompareInt(it)(size) == 0 + + /** Tests if the size of the collection is not equal to some value. */ + @inline def !=(size: Int): Boolean = sizeCompareInt(it)(size) != 0 + + /** Tests if the size of the collection is greater than or equal to some value. */ + @inline def >=(size: Int): Boolean = sizeCompareInt(it)(size) >= 0 + + /** Tests if the size of the collection is greater than some value. */ + @inline def >(size: Int): Boolean = sizeCompareInt(it)(size) > 0 +} + +private object SizeCompareImpl { + def sizeCompareInt(self: c.Traversable[_])(otherSize: Int): Int = + self match { + case self: c.SeqLike[_, _] => self.lengthCompare(otherSize) + case _ => + if (otherSize < 0) 1 + else { + var i = 0 + val it = self.toIterator + while (it.hasNext) { + if (i == otherSize) return 1 + it.next() + i += 1 + } + i - otherSize + } + } + + // `IndexedSeq` is the only thing that we can safely say has a known size + def sizeCompareColl(self: c.Traversable[_])(that: c.Traversable[_]): Int = + that match { + case that: c.IndexedSeq[_] => sizeCompareInt(self)(that.length) + case _ => + self match { + case self: c.IndexedSeq[_] => + val res = sizeCompareInt(that)(self.length) + // can't just invert the result, because `-Int.MinValue == Int.MinValue` + if (res == Int.MinValue) 1 else -res + case _ => + val thisIt = self.toIterator + val thatIt = that.toIterator + while (thisIt.hasNext && thatIt.hasNext) { + thisIt.next() + thatIt.next() + } + java.lang.Boolean.compare(thisIt.hasNext, thatIt.hasNext) + } + } } class TraversableLikeExtensionMethods[A, Repr](private val self: c.GenTraversableLike[A, Repr]) diff --git a/compat/src/main/scala-2.12/scala/collection/compat/package.scala b/compat/src/main/scala-2.12/scala/collection/compat/package.scala index ec34c117..346f437b 100644 --- a/compat/src/main/scala-2.12/scala/collection/compat/package.scala +++ b/compat/src/main/scala-2.12/scala/collection/compat/package.scala @@ -13,6 +13,7 @@ package scala.collection import scala.collection.generic.IsTraversableLike +import scala.{collection => c} import scala.collection.{mutable => m} package object compat extends compat.PackageShared { @@ -27,7 +28,10 @@ package object compat extends compat.PackageShared { } implicit def toTraversableLikeExtensionMethods[Repr](self: Repr)( - implicit traversable: IsTraversableLike[Repr]) - : TraversableLikeExtensionMethods[traversable.A, Repr] = + implicit traversable: IsTraversableLike[Repr]) + : TraversableLikeExtensionMethods[traversable.A, Repr] = new TraversableLikeExtensionMethods[traversable.A, Repr](traversable.conversion(self)) + + implicit def toSeqExtensionMethods[A](self: c.Seq[A]): SeqExtensionMethods[A] = + new SeqExtensionMethods[A](self) } diff --git a/compat/src/test/scala/test/scala/collection/CollectionTest.scala b/compat/src/test/scala/test/scala/collection/CollectionTest.scala index c7ed505e..730a7386 100644 --- a/compat/src/test/scala/test/scala/collection/CollectionTest.scala +++ b/compat/src/test/scala/test/scala/collection/CollectionTest.scala @@ -110,4 +110,54 @@ class CollectionTest { stream.force assertEquals(18, count) } + + @Test + def sizeCompare(): Unit = { + assertTrue(Set(1, 2, 3).sizeCompare(4) < 0) + assertTrue(Set(1, 2, 3).sizeCompare(2) > 0) + assertTrue(Set(1, 2, 3).sizeCompare(3) == 0) + + assertTrue(List(1, 2, 3).sizeCompare(4) < 0) + assertTrue(List(1, 2, 3).sizeCompare(2) > 0) + assertTrue(List(1, 2, 3).sizeCompare(3) == 0) + + assertTrue(Set(1, 2, 3).sizeCompare(List(1, 2, 3, 4)) < 0) + assertTrue(Set(1, 2, 3).sizeCompare(List(1, 2)) > 0) + assertTrue(Set(1, 2, 3).sizeCompare(List(1, 2, 3)) == 0) + + assertTrue(Set(1, 2, 3).sizeCompare(Vector(1, 2, 3, 4)) < 0) + assertTrue(Set(1, 2, 3).sizeCompare(Vector(1, 2)) > 0) + assertTrue(Set(1, 2, 3).sizeCompare(Vector(1, 2, 3)) == 0) + + assertTrue(Vector(1, 2, 3).sizeCompare(Set(1, 2, 3, 4)) < 0) + assertTrue(Vector(1, 2, 3).sizeCompare(Set(1, 2)) > 0) + assertTrue(Vector(1, 2, 3).sizeCompare(Set(1, 2, 3)) == 0) + } + + @Test + def sizeIsLengthIs(): Unit = { + assertTrue(Set(1, 2, 3).sizeIs < 4) + assertTrue(Set(1, 2, 3).sizeIs <= 4) + assertTrue(Set(1, 2, 3).sizeIs <= 3) + assertTrue(Set(1, 2, 3).sizeIs == 3) + assertTrue(Set(1, 2, 3).sizeIs >= 3) + assertTrue(Set(1, 2, 3).sizeIs >= 2) + assertTrue(Set(1, 2, 3).sizeIs > 2) + + assertTrue(List(1, 2, 3).sizeIs < 4) + assertTrue(List(1, 2, 3).sizeIs <= 4) + assertTrue(List(1, 2, 3).sizeIs <= 3) + assertTrue(List(1, 2, 3).sizeIs == 3) + assertTrue(List(1, 2, 3).sizeIs >= 3) + assertTrue(List(1, 2, 3).sizeIs >= 2) + assertTrue(List(1, 2, 3).sizeIs > 2) + + assertTrue(List(1, 2, 3).lengthIs < 4) + assertTrue(List(1, 2, 3).lengthIs <= 4) + assertTrue(List(1, 2, 3).lengthIs <= 3) + assertTrue(List(1, 2, 3).lengthIs == 3) + assertTrue(List(1, 2, 3).lengthIs >= 3) + assertTrue(List(1, 2, 3).lengthIs >= 2) + assertTrue(List(1, 2, 3).lengthIs > 2) + } }