Skip to content

Commit

Permalink
add XoshiroSplit type
Browse files Browse the repository at this point in the history
  • Loading branch information
nhz2 committed Sep 11, 2023
1 parent 7b9fdf8 commit d086f21
Show file tree
Hide file tree
Showing 6 changed files with 196 additions and 105 deletions.
6 changes: 3 additions & 3 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ typeof_rng(::_GLOBAL_RNG) = TaskLocalRNG
"""
default_rng() -> rng
Return the default global random number generator (RNG).
Return the default task-local random number generator (RNG).
!!! note
What the default RNG is is an implementation detail. Across different versions of
Expand All @@ -346,8 +346,8 @@ Return the default global random number generator (RNG).
@inline default_rng() = TaskLocalRNG()
@inline default_rng(tid::Int) = TaskLocalRNG()

copy!(dst::Xoshiro, ::_GLOBAL_RNG) = copy!(dst, default_rng())
copy!(::_GLOBAL_RNG, src::Xoshiro) = copy!(default_rng(), src)
copy!(dst::XoshiroSplit, ::_GLOBAL_RNG) = copy!(dst, default_rng())
copy!(::_GLOBAL_RNG, src::XoshiroSplit) = copy!(default_rng(), src)
copy(::_GLOBAL_RNG) = copy(default_rng())

GLOBAL_SEED = 0
Expand Down
211 changes: 146 additions & 65 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

## Xoshiro RNG
# Lots of implementation is shared with TaskLocalRNG
# Lots of implementation is shared with TaskLocalRNG and XoshiroSplit

"""
Xoshiro(seed)
Expand Down Expand Up @@ -53,14 +53,25 @@ mutable struct Xoshiro <: AbstractRNG
Xoshiro(seed=nothing) = seed!(new(), seed)
end

function setstate!(x::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
# NON-PUBLIC
@inline function get_xoshiro_state(x::Xoshiro)
x.s0, x.s1, x.s2, x.s3
end

# NON-PUBLIC
@inline function set_xoshiro_state!(x::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
x.s0 = s0
x.s1 = s1
x.s2 = s2
x.s3 = s3
x
end

# NON-PUBLIC
function seedstate!(x::Xoshiro, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
set_xoshiro_state!(x, s0, s1, s2, s3)
end

copy(rng::Xoshiro) = Xoshiro(rng.s0, rng.s1, rng.s2, rng.s3)

function copy!(dst::Xoshiro, src::Xoshiro)
Expand All @@ -72,21 +83,66 @@ function ==(a::Xoshiro, b::Xoshiro)
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3
end

rng_native_52(::Xoshiro) = UInt64

@inline function rand(rng::Xoshiro, ::SamplerType{UInt64})
s0, s1, s2, s3 = rng.s0, rng.s1, rng.s2, rng.s3
tmp = s0 + s3
res = ((tmp << 23) | (tmp >> 41)) + s0
t = s1 << 17
s2 = xor(s2, s0)
s3 = xor(s3, s1)
s1 = xor(s1, s2)
s0 = xor(s0, s3)
s2 = xor(s2, t)
s3 = s3 << 45 | s3 >> 19
rng.s0, rng.s1, rng.s2, rng.s3 = s0, s1, s2, s3
res
"""
XoshiroSplit(seed)
XoshiroSplit()
Creates the same stream as Xoshiro, but has an additional splitting ability.
For more discussion, cf rng_split in task.c
This is the type currently returned by `copy(default_rng())`.
!!! note
What the default RNG is is an implementation detail. Across different versions of
Julia, you should not expect the default RNG to be always the same, nor that it will
return the same stream of random numbers for a given seed.
"""
mutable struct XoshiroSplit <: AbstractRNG
s0::UInt64
s1::UInt64
s2::UInt64
s3::UInt64
s4::UInt64

XoshiroSplit(
s0::Integer, s1::Integer, s2::Integer, s3::Integer, # xoshiro256 state
s4::Integer, # internal splitmix state
) = new(s0, s1, s2, s3, s4)
XoshiroSplit(seed=nothing) = seed!(new(), seed)
end

# NON-PUBLIC
@inline function get_xoshiro_state(x::XoshiroSplit)
x.s0, x.s1, x.s2, x.s3
end

# NON-PUBLIC
@inline function set_xoshiro_state!(x::XoshiroSplit, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
x.s0 = s0
x.s1 = s1
x.s2 = s2
x.s3 = s3
x
end

# NON-PUBLIC
function seedstate!(x::XoshiroSplit, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
set_xoshiro_state!(x, s0, s1, s2, s3)
x.s4 = 1s0 + 3s1 + 5s2 + 7s3
x
end

copy(rng::XoshiroSplit) = XoshiroSplit(rng.s0, rng.s1, rng.s2, rng.s3, rng.s4)

function copy!(dst::XoshiroSplit, src::XoshiroSplit)
dst.s0, dst.s1, dst.s2, dst.s3, dst.s4 = src.s0, src.s1, src.s2, src.s3, src.s4
dst
end

function ==(a::XoshiroSplit, b::XoshiroSplit)
a.s0 == b.s0 && a.s1 == b.s1 && a.s2 == b.s2 && a.s3 == b.s3 && a.s4 == b.s4
end


Expand All @@ -111,25 +167,74 @@ is undefined behavior: it will work most of the time, and may sometimes fail sil
"""
struct TaskLocalRNG <: AbstractRNG end
TaskLocalRNG(::Nothing) = TaskLocalRNG()
rng_native_52(::TaskLocalRNG) = UInt64

function setstate!(
x::TaskLocalRNG,
s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64, # xoshiro256 state
s4::UInt64 = 1s0 + 3s1 + 5s2 + 7s3, # internal splitmix state
)
# NON-PUBLIC
@inline function get_xoshiro_state(x::TaskLocalRNG)
t = current_task()
t.rngState0, t.rngState1, t.rngState2, t.rngState3
end

# NON-PUBLIC
@inline function set_xoshiro_state!(x::TaskLocalRNG, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
t = current_task()
t.rngState0 = s0
t.rngState1 = s1
t.rngState2 = s2
t.rngState3 = s3
t.rngState4 = s4
x
end

@inline function rand(::TaskLocalRNG, ::SamplerType{UInt64})
task = current_task()
s0, s1, s2, s3 = task.rngState0, task.rngState1, task.rngState2, task.rngState3
# NON-PUBLIC
function seedstate!(x::TaskLocalRNG, s0::UInt64, s1::UInt64, s2::UInt64, s3::UInt64)
t = current_task()
t.rngState0 = s0
t.rngState1 = s1
t.rngState2 = s2
t.rngState3 = s3
t.rngState4 = 1s0 + 3s1 + 5s2 + 7s3
x
end

function copy(rng::TaskLocalRNG)
t = current_task()
XoshiroSplit(t.rngState0, t.rngState1, t.rngState2, t.rngState3, t.rngState4)
end

function copy!(dst::TaskLocalRNG, src::XoshiroSplit)
t = current_task()
t.rngState0 = src.s0
t.rngState1 = src.s1
t.rngState2 = src.s2
t.rngState3 = src.s3
t.rngState4 = src.s4
return dst
end

function copy!(dst::XoshiroSplit, src::TaskLocalRNG)
t = current_task()
dst.s0 = t.rngState0
dst.s1 = t.rngState1
dst.s2 = t.rngState2
dst.s3 = t.rngState3
dst.s4 = t.rngState4
return dst
end

function ==(a::XoshiroSplit, b::TaskLocalRNG)
t = current_task()
a.s0 == t.rngState0 && a.s1 == t.rngState1 && a.s2 == t.rngState2 && a.s3 == t.rngState3 && a.s4 == t.rngState4
end

==(a::TaskLocalRNG, b::XoshiroSplit) = b == a

# Shared implementation between Xoshiro, XoshiroSplit, and TaskLocalRNG

const XoshiroLike = Union{TaskLocalRNG, Xoshiro, XoshiroSplit}

rng_native_52(::XoshiroLike) = UInt64

@inline function rand(rng::XoshiroLike, ::SamplerType{UInt64})
s0, s1, s2, s3 = get_xoshiro_state(rng)
tmp = s0 + s3
res = ((tmp << 23) | (tmp >> 41)) + s0
t = s1 << 17
Expand All @@ -139,78 +244,54 @@ end
s0 ⊻= s3
s2 ⊻= t
s3 = s3 << 45 | s3 >> 19
task.rngState0, task.rngState1, task.rngState2, task.rngState3 = s0, s1, s2, s3
set_xoshiro_state!(rng, s0, s1, s2, s3)
res
end

# Shared implementation between Xoshiro and TaskLocalRNG -- seeding

function seed!(rng::Union{TaskLocalRNG,Xoshiro})
function seed!(rng::XoshiroLike)
# as we get good randomness from RandomDevice, we can skip hashing
rd = RandomDevice()
setstate!(rng, rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64))
seedstate!(rng, rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64), rand(rd, UInt64))
end

function seed!(rng::Union{TaskLocalRNG,Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}})
function seed!(rng::XoshiroLike, seed::Union{Vector{UInt32}, Vector{UInt64}})
c = SHA.SHA2_256_CTX()
SHA.update!(c, reinterpret(UInt8, seed))
s0, s1, s2, s3 = reinterpret(UInt64, SHA.digest!(c))
setstate!(rng, s0, s1, s2, s3)
seedstate!(rng, s0, s1, s2, s3)
end

seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed))
seed!(rng::XoshiroLike, seed::Integer) = seed!(rng, make_seed(seed))


@inline function rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt128})
@inline function rand(rng::XoshiroLike, ::SamplerType{UInt128})
first = rand(rng, UInt64)
second = rand(rng,UInt64)
second + UInt128(first) << 64
end

@inline rand(rng::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{Int128}) = rand(rng, UInt128) % Int128
@inline rand(rng::XoshiroLike, ::SamplerType{Int128}) = rand(rng, UInt128) % Int128

@inline function rand(rng::Union{TaskLocalRNG, Xoshiro},
@inline function rand(rng::XoshiroLike,
T::SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64))
S = T[]
# use upper bits
(rand(rng, UInt64) >>> (64 - 8*sizeof(S))) % S
end

function copy(rng::TaskLocalRNG)
t = current_task()
Xoshiro(t.rngState0, t.rngState1, t.rngState2, t.rngState3)
end

function copy!(dst::TaskLocalRNG, src::Xoshiro)
t = current_task()
setstate!(dst, src.s0, src.s1, src.s2, src.s3)
return dst
end

function copy!(dst::Xoshiro, src::TaskLocalRNG)
t = current_task()
setstate!(dst, t.rngState0, t.rngState1, t.rngState2, t.rngState3)
return dst
end

function ==(a::Xoshiro, b::TaskLocalRNG)
t = current_task()
a.s0 == t.rngState0 && a.s1 == t.rngState1 && a.s2 == t.rngState2 && a.s3 == t.rngState3
end

==(a::TaskLocalRNG, b::Xoshiro) = b == a

# for partial words, use upper bits from Xoshiro

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52Raw{UInt64}}) = rand(r, UInt64) >>> 12
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt64) >>> 12
rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw())
rand(r::XoshiroLike, ::SamplerTrivial{UInt52Raw{UInt64}}) = rand(r, UInt64) >>> 12
rand(r::XoshiroLike, ::SamplerTrivial{UInt52{UInt64}}) = rand(r, UInt64) >>> 12
rand(r::XoshiroLike, ::SamplerTrivial{UInt104{UInt128}}) = rand(r, UInt104Raw())

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float16}}) =
rand(r::XoshiroLike, ::SamplerTrivial{CloseOpen01{Float16}}) =
Float16(Float32(rand(r, UInt16) >>> 5) * Float32(0x1.0p-11))

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01{Float32}}) =
rand(r::XoshiroLike, ::SamplerTrivial{CloseOpen01{Float32}}) =
Float32(rand(r, UInt32) >>> 8) * Float32(0x1.0p-24)

rand(r::Union{TaskLocalRNG, Xoshiro}, ::SamplerTrivial{CloseOpen01_64}) =
rand(r::XoshiroLike, ::SamplerTrivial{CloseOpen01_64}) =
Float64(rand(r, UInt64) >>> 11) * 0x1.0p-53
Loading

0 comments on commit d086f21

Please sign in to comment.