Skip to content

Commit

Permalink
faster randn by separating out unlikely branch in a function
Browse files Browse the repository at this point in the history
All credits to @ViralBShah (cf. #8941 and #9126).
This change probably allows better inlining.
  • Loading branch information
rfourquet committed Nov 24, 2014
1 parent 3ccaf3c commit 616e1d7
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type Close1Open2 <: FloatInterval end

@inline rand_ui52_raw_inbounds(r::MersenneTwister) = reinterpret(UInt64, rand_inbounds(r, Close1Open2))
@inline rand_ui52_raw(r::MersenneTwister) = (reserve_1(r); rand_ui52_raw_inbounds(r))
@inline rand_ui52(r::MersenneTwister) = rand_ui52_raw(r) & 0x000fffffffffffff
@inline rand_ui2x52_raw(r::MersenneTwister) = rand_ui52_raw(r) % UInt128 << 64 | rand_ui52_raw(r)

function srand(r::MersenneTwister, seed::Vector{UInt32})
Expand Down Expand Up @@ -943,29 +944,30 @@ const ziggurat_nor_r = 3.6541528853610087963519472518
const ziggurat_nor_inv_r = inv(ziggurat_nor_r)
const ziggurat_exp_r = 7.6971174701310497140446280481

@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand(rng, Close1Open2)) & 0x000fffffffffffff

function randmtzig_randn(rng::MersenneTwister=GLOBAL_RNG)
@inbounds begin
r = rand_ui52(rng)
rabs = int64(r>>1) # One bit for the sign
idx = rabs & 0xFF
x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1]
rabs < ki[idx+1] && return x # 99.3% of the time we return here 1st try
return randmtzig_randn_unlikely(rng, idx, rabs, x)
end
end

# this unlikely branch is put in a separate function for better efficiency
function randmtzig_randn_unlikely(rng, idx, rabs, x)
@inbounds if idx == 0
while true
r = randi(rng)
rabs = int64(r>>1) # One bit for the sign
idx = rabs & 0xFF
x = ifelse(r % Bool, -rabs, rabs)*wi[idx+1]
if rabs < ki[idx+1]
return x # 99.3% of the time we return here 1st try
elseif idx == 0
while true
xx = -ziggurat_nor_inv_r*log(rand(rng))
yy = -log(rand(rng))
if yy+yy > xx*xx
return (rabs & 0x100) != 0x000000000 ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
end
end
elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
return x # return from the triangular area
end
xx = -ziggurat_nor_inv_r*log(rand(rng))
yy = -log(rand(rng))
yy+yy > xx*xx && return (rabs >> 8) % Bool ? -ziggurat_nor_r-xx : ziggurat_nor_r+xx
end
elseif (fi[idx] - fi[idx+1])*rand(rng) + fi[idx+1] < exp(-0.5*x*x)
return x # return from the triangular area
else
return randmtzig_randn(rng)
end
end

Expand Down

0 comments on commit 616e1d7

Please sign in to comment.