Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create and use java.util.ConcurrentThreadLocalRandom on the same flatMap call #2784

Merged
merged 4 commits into from
Feb 1, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 108 additions & 64 deletions std/shared/src/main/scala/cats/effect/std/Random.scala
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,7 @@ object Random {
}

def javaUtilConcurrentThreadLocalRandom[F[_]: Sync]: Random[F] =
new ScalaRandom[F](
Sync[F].delay(new SRandom(java.util.concurrent.ThreadLocalRandom.current()))) {}
new ThreadLocalRandom[F] {}

def javaSecuritySecureRandom[F[_]: Sync](n: Int): F[Random[F]] =
for {
Expand All @@ -345,21 +344,40 @@ object Random {
def javaSecuritySecureRandom[F[_]: Sync]: F[Random[F]] =
Sync[F].delay(new java.security.SecureRandom).flatMap(r => javaUtilRandom(r))

private abstract class ScalaRandom[F[_]: Sync](f: F[SRandom]) extends Random[F] {
private sealed abstract class RandomCommon[F[_]: Sync] extends Random[F] {
def betweenDouble(minInclusive: Double, maxExclusive: Double): F[Double] =
for {
_ <- require(minInclusive < maxExclusive, "Invalid bounds")
d <- nextDouble
} yield {
val next = d * (maxExclusive - minInclusive) + minInclusive
if (next < maxExclusive) next
else Math.nextAfter(maxExclusive, Double.NegativeInfinity)
}

def betweenLong(minInclusive: Long, maxExclusive: Long): F[Long] =
def betweenFloat(minInclusive: Float, maxExclusive: Float): F[Float] =
for {
_ <- require(minInclusive < maxExclusive, "Invalid bounds")
f <- nextFloat
} yield {
val next = f * (maxExclusive - minInclusive) + minInclusive
if (next < maxExclusive) next
else Math.nextAfter(maxExclusive, Float.NegativeInfinity)
}

def betweenInt(minInclusive: Int, maxExclusive: Int): F[Int] =
require(minInclusive < maxExclusive, "Invalid bounds") *> {
val difference = maxExclusive - minInclusive
for {
out <-
if (difference >= 0) {
nextLongBounded(difference).map(_ + minInclusive)
nextIntBounded(difference).map(_ + minInclusive)
} else {
/* The interval size here is greater than Long.MaxValue,
/* The interval size here is greater than Int.MaxValue,
* so the loop will exit with a probability of at least 1/2.
*/
def loop(): F[Long] = {
nextLong.flatMap { n =>
def loop(): F[Int] = {
nextInt.flatMap { n =>
if (n >= minInclusive && n < maxExclusive) n.pure[F]
else loop()
}
Expand All @@ -369,19 +387,19 @@ object Random {
} yield out
}

def betweenInt(minInclusive: Int, maxExclusive: Int): F[Int] =
def betweenLong(minInclusive: Long, maxExclusive: Long): F[Long] =
require(minInclusive < maxExclusive, "Invalid bounds") *> {
val difference = maxExclusive - minInclusive
for {
out <-
if (difference >= 0) {
nextIntBounded(difference).map(_ + minInclusive)
nextLongBounded(difference).map(_ + minInclusive)
} else {
/* The interval size here is greater than Int.MaxValue,
/* The interval size here is greater than Long.MaxValue,
* so the loop will exit with a probability of at least 1/2.
*/
def loop(): F[Int] = {
nextInt.flatMap { n =>
def loop(): F[Long] = {
nextLong.flatMap { n =>
if (n >= minInclusive && n < maxExclusive) n.pure[F]
else loop()
}
Expand All @@ -391,31 +409,45 @@ object Random {
} yield out
}

def betweenFloat(minInclusive: Float, maxExclusive: Float): F[Float] =
for {
_ <- require(minInclusive < maxExclusive, "Invalid bounds")
f <- nextFloat
} yield {
val next = f * (maxExclusive - minInclusive) + minInclusive
if (next < maxExclusive) next
else Math.nextAfter(maxExclusive, Float.NegativeInfinity)
}

def betweenDouble(minInclusive: Double, maxExclusive: Double): F[Double] =
for {
_ <- require(minInclusive < maxExclusive, "Invalid bounds")
d <- nextDouble
} yield {
val next = d * (maxExclusive - minInclusive) + minInclusive
if (next < maxExclusive) next
else Math.nextAfter(maxExclusive, Double.NegativeInfinity)
}

def nextAlphaNumeric: F[Char] = {
val chars = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
nextIntBounded(chars.length()).map(chars.charAt(_))
}

def nextLongBounded(n: Long): F[Long] = {
/*
* Divide n by two until small enough for nextInt. On each
* iteration (at most 31 of them but usually much less),
* randomly choose both whether to include high bit in result
* (offset) and whether to continue with the lower vs upper
* half (which makes a difference only if odd).
*/
for {
_ <- require(n > 0, s"n must be positive, but was $n")
offset <- Ref[F].of(0L)
_n <- Ref[F].of(n)
_ <- Monad[F].whileM_(_n.get.map(_ >= Integer.MAX_VALUE))(
for {
bits <- nextIntBounded(2)
halfn <- _n.get.map(_ >>> 1)
nextN <- if ((bits & 2) == 0) halfn.pure[F] else _n.get.map(_ - halfn)
_ <-
if ((bits & 1) == 0) _n.get.flatMap(n => offset.update(_ + (n - nextN)))
else Applicative[F].unit
_ <- _n.set(nextN)
} yield ()
)
finalOffset <- offset.get
int <- _n.get.flatMap(l => nextIntBounded(l.toInt))
} yield finalOffset + int
}

private def require(condition: Boolean, errorMessage: => String): F[Unit] =
if (condition) ().pure[F]
else new IllegalArgumentException(errorMessage).raiseError[F, Unit]
}

private abstract class ScalaRandom[F[_] : Sync](f: F[SRandom]) extends RandomCommon[F] {
def nextBoolean: F[Boolean] =
for {
r <- f
Expand Down Expand Up @@ -465,34 +497,6 @@ object Random {
out <- Sync[F].delay(r.nextLong())
} yield out

def nextLongBounded(n: Long): F[Long] = {
/*
* Divide n by two until small enough for nextInt. On each
* iteration (at most 31 of them but usually much less),
* randomly choose both whether to include high bit in result
* (offset) and whether to continue with the lower vs upper
* half (which makes a difference only if odd).
*/
for {
_ <- require(n > 0, s"n must be positive, but was $n")
offset <- Ref[F].of(0L)
_n <- Ref[F].of(n)
_ <- Monad[F].whileM_(_n.get.map(_ >= Integer.MAX_VALUE))(
for {
bits <- nextIntBounded(2)
halfn <- _n.get.map(_ >>> 1)
nextN <- if ((bits & 2) == 0) halfn.pure[F] else _n.get.map(_ - halfn)
_ <-
if ((bits & 1) == 0) _n.get.flatMap(n => offset.update(_ + (n - nextN)))
else Applicative[F].unit
_ <- _n.set(nextN)
} yield ()
)
finalOffset <- offset.get
int <- _n.get.flatMap(l => nextIntBounded(l.toInt))
} yield finalOffset + int
}

def nextPrintableChar: F[Char] =
for {
r <- f
Expand All @@ -516,9 +520,49 @@ object Random {
r <- f
out <- Sync[F].delay(r.shuffle(v))
} yield out
}

private def require(condition: Boolean, errorMessage: => String): F[Unit] =
if (condition) ().pure[F]
else new IllegalArgumentException(errorMessage).raiseError[F, Unit]
private abstract class ThreadLocalRandom[F[_]: Sync] extends RandomCommon[F] {
def nextBoolean: F[Boolean] =
Sync[F].delay(localRandom().nextBoolean())

def nextBytes(n: Int): F[Array[Byte]] = {
val bytes = new Array[Byte](0 max n)
Sync[F]
.delay(localRandom().nextBytes(bytes))
.as(bytes)
}

def nextDouble: F[Double] =
Sync[F].delay(localRandom().nextDouble())

def nextFloat: F[Float] =
Sync[F].delay(localRandom().nextFloat())

def nextGaussian: F[Double] =
Sync[F].delay(localRandom().nextGaussian())

def nextInt: F[Int] =
Sync[F].delay(localRandom().nextInt())

def nextIntBounded(n: Int): F[Int] =
Sync[F].delay(localRandom().self.nextInt(n))

def nextLong: F[Long] =
Sync[F].delay(localRandom().nextLong())

def nextPrintableChar: F[Char] =
Sync[F].delay(localRandom().nextPrintableChar())

def nextString(length: Int): F[String] =
Sync[F].delay(localRandom().nextString(length))

def shuffleList[A](l: List[A]): F[List[A]] =
Sync[F].delay(localRandom().shuffle(l))

def shuffleVector[A](v: Vector[A]): F[Vector[A]] =
Sync[F].delay(localRandom().shuffle(v))
}

private[this] def localRandom() = new SRandom(java.util.concurrent.ThreadLocalRandom.current())
}