Skip to content

Commit

Permalink
rand BFloat16 sampling without conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
milankl authored Apr 12, 2024
1 parent afd2da3 commit 9e8e997
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
12 changes: 11 additions & 1 deletion src/bfloat16.jl
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,17 @@ Printf.tofloat(x::BFloat16) = Float32(x)

# Random
import Random: rand, randn, randexp, AbstractRNG, Sampler
rand(rng::AbstractRNG, ::Sampler{BFloat16}) = convert(BFloat16, rand(rng))

"""Sample a BFloat16 from [0,1) by setting random mantissa
bits for one(BFloat16) to obtain [1,2) (where floats are uniformly
distributed) then subtract 1 for [0,1)."""
function rand(rng::AbstractRNG, ::Sampler{BFloat16})
u = reinterpret(UInt16, one(BFloat16))
# shift random bits into BFloat16 mantissa (1 sign + 8 exp bits = 9)
u |= rand(rng, UInt16) >> 9 # u in [1,2)
return reinterpret(BFloat16, u) - one(BFloat16) # -1 for [0,1)
end

randn(rng::AbstractRNG, ::Type{BFloat16}) = convert(BFloat16, randn(rng))
randexp(rng::AbstractRNG, ::Type{BFloat16}) = convert(BFloat16, randexp(rng))

Expand Down
15 changes: 14 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,5 +175,18 @@ end
@test a-1+1 == a # but -1 can
end

@testset "rand sampling" begin
Random.seed(123)
mi, ma = extrema(rand(BFloat16, 1_000_000))

# zero should be the lowest BFloat16 sampled
@test mi === zero(BFloat16)

# prevfloat(one(BFloat16)) cannot be sampled bc
# prevfloat(BFloat16(2)) - 1 is _two_ before one(BFloat16)
# (a statistical flaw of the [1,2)-1 sampling)
@test ma === prevfloat(one(BFloat16), 2)
end

include("structure.jl")
include("mathfuncs.jl")
include("mathfuncs.jl")

0 comments on commit 9e8e997

Please sign in to comment.