Skip to content

Commit

Permalink
Add sizeCompare, sizeIs and lengthIs
Browse files Browse the repository at this point in the history
  • Loading branch information
NthPortal committed Jan 18, 2020
1 parent 2badce8 commit c067ec2
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@
package scala.collection

import scala.collection.generic.IsTraversableLike
import scala.{collection => c}

package object compat extends compat.PackageShared {
implicit def toTraversableLikeExtensionMethods[Repr](self: Repr)(
implicit traversable: IsTraversableLike[Repr])
: TraversableLikeExtensionMethods[traversable.A, Repr] =
implicit traversable: IsTraversableLike[Repr])
: TraversableLikeExtensionMethods[traversable.A, Repr] =
new TraversableLikeExtensionMethods[traversable.A, Repr](traversable.conversion(self))

implicit def toSeqExtensionMethods[A](self: c.Seq[A]): SeqExtensionMethods[A] =
new SeqExtensionMethods[A](self)
}
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,76 @@ class TraversableOnceExtensionMethods[A](private val self: c.TraversableOnce[A])

class TraversableExtensionMethods[A](private val self: c.Traversable[A]) extends AnyVal {
def iterableFactory: GenericCompanion[Traversable] = self.companion

def sizeCompare(otherSize: Int): Int = SizeCompareImpl.sizeCompareInt(self)(otherSize)
def sizeIs: SizeCompareOps = new SizeCompareOps(self)
def sizeCompare(that: c.Traversable[_]): Int = SizeCompareImpl.sizeCompareColl(self)(that)
}

class SeqExtensionMethods[A](private val self: c.Seq[A]) extends AnyVal {
def lengthIs: SizeCompareOps = new SizeCompareOps(self)
}

class SizeCompareOps private[compat] (private val it: c.Traversable[_]) extends AnyVal {
import SizeCompareImpl._

/** Tests if the size of the collection is less than some value. */
@inline def <(size: Int): Boolean = sizeCompareInt(it)(size) < 0

/** Tests if the size of the collection is less than or equal to some value. */
@inline def <=(size: Int): Boolean = sizeCompareInt(it)(size) <= 0

/** Tests if the size of the collection is equal to some value. */
@inline def ==(size: Int): Boolean = sizeCompareInt(it)(size) == 0

/** Tests if the size of the collection is not equal to some value. */
@inline def !=(size: Int): Boolean = sizeCompareInt(it)(size) != 0

/** Tests if the size of the collection is greater than or equal to some value. */
@inline def >=(size: Int): Boolean = sizeCompareInt(it)(size) >= 0

/** Tests if the size of the collection is greater than some value. */
@inline def >(size: Int): Boolean = sizeCompareInt(it)(size) > 0
}

private object SizeCompareImpl {
def sizeCompareInt(self: c.Traversable[_])(otherSize: Int): Int =
self match {
case self: c.SeqLike[_, _] => self.lengthCompare(otherSize)
case _ =>
if (otherSize < 0) 1
else {
var i = 0
val it = self.toIterator
while (it.hasNext) {
if (i == otherSize) return 1
it.next()
i += 1
}
i - otherSize
}
}

// `IndexedSeq` is the only thing that we can safely say has a known size
def sizeCompareColl(self: c.Traversable[_])(that: c.Traversable[_]): Int =
that match {
case that: c.IndexedSeq[_] => sizeCompareInt(self)(that.length)
case _ =>
self match {
case self: c.IndexedSeq[_] =>
val res = sizeCompareInt(that)(self.length)
// can't just invert the result, because `-Int.MinValue == Int.MinValue`
if (res == Int.MinValue) 1 else -res
case _ =>
val thisIt = self.toIterator
val thatIt = that.toIterator
while (thisIt.hasNext && thatIt.hasNext) {
thisIt.next()
thatIt.next()
}
java.lang.Boolean.compare(thisIt.hasNext, thatIt.hasNext)
}
}
}

class TraversableLikeExtensionMethods[A, Repr](private val self: c.GenTraversableLike[A, Repr])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package scala.collection

import scala.collection.generic.IsTraversableLike
import scala.{collection => c}
import scala.collection.{mutable => m}

package object compat extends compat.PackageShared {
Expand All @@ -27,7 +28,10 @@ package object compat extends compat.PackageShared {
}

implicit def toTraversableLikeExtensionMethods[Repr](self: Repr)(
implicit traversable: IsTraversableLike[Repr])
: TraversableLikeExtensionMethods[traversable.A, Repr] =
implicit traversable: IsTraversableLike[Repr])
: TraversableLikeExtensionMethods[traversable.A, Repr] =
new TraversableLikeExtensionMethods[traversable.A, Repr](traversable.conversion(self))

implicit def toSeqExtensionMethods[A](self: c.Seq[A]): SeqExtensionMethods[A] =
new SeqExtensionMethods[A](self)
}
50 changes: 50 additions & 0 deletions compat/src/test/scala/test/scala/collection/CollectionTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,54 @@ class CollectionTest {
stream.force
assertEquals(18, count)
}

@Test
def sizeCompare(): Unit = {
assertTrue(Set(1, 2, 3).sizeCompare(4) < 0)
assertTrue(Set(1, 2, 3).sizeCompare(2) > 0)
assertTrue(Set(1, 2, 3).sizeCompare(3) == 0)

assertTrue(List(1, 2, 3).sizeCompare(4) < 0)
assertTrue(List(1, 2, 3).sizeCompare(2) > 0)
assertTrue(List(1, 2, 3).sizeCompare(3) == 0)

assertTrue(Set(1, 2, 3).sizeCompare(List(1, 2, 3, 4)) < 0)
assertTrue(Set(1, 2, 3).sizeCompare(List(1, 2)) > 0)
assertTrue(Set(1, 2, 3).sizeCompare(List(1, 2, 3)) == 0)

assertTrue(Set(1, 2, 3).sizeCompare(Vector(1, 2, 3, 4)) < 0)
assertTrue(Set(1, 2, 3).sizeCompare(Vector(1, 2)) > 0)
assertTrue(Set(1, 2, 3).sizeCompare(Vector(1, 2, 3)) == 0)

assertTrue(Vector(1, 2, 3).sizeCompare(Set(1, 2, 3, 4)) < 0)
assertTrue(Vector(1, 2, 3).sizeCompare(Set(1, 2)) > 0)
assertTrue(Vector(1, 2, 3).sizeCompare(Set(1, 2, 3)) == 0)
}

@Test
def sizeIsLengthIs(): Unit = {
assertTrue(Set(1, 2, 3).sizeIs < 4)
assertTrue(Set(1, 2, 3).sizeIs <= 4)
assertTrue(Set(1, 2, 3).sizeIs <= 3)
assertTrue(Set(1, 2, 3).sizeIs == 3)
assertTrue(Set(1, 2, 3).sizeIs >= 3)
assertTrue(Set(1, 2, 3).sizeIs >= 2)
assertTrue(Set(1, 2, 3).sizeIs > 2)

assertTrue(List(1, 2, 3).sizeIs < 4)
assertTrue(List(1, 2, 3).sizeIs <= 4)
assertTrue(List(1, 2, 3).sizeIs <= 3)
assertTrue(List(1, 2, 3).sizeIs == 3)
assertTrue(List(1, 2, 3).sizeIs >= 3)
assertTrue(List(1, 2, 3).sizeIs >= 2)
assertTrue(List(1, 2, 3).sizeIs > 2)

assertTrue(List(1, 2, 3).lengthIs < 4)
assertTrue(List(1, 2, 3).lengthIs <= 4)
assertTrue(List(1, 2, 3).lengthIs <= 3)
assertTrue(List(1, 2, 3).lengthIs == 3)
assertTrue(List(1, 2, 3).lengthIs >= 3)
assertTrue(List(1, 2, 3).lengthIs >= 2)
assertTrue(List(1, 2, 3).lengthIs > 2)
}
}

0 comments on commit c067ec2

Please sign in to comment.