From 9e8e9971c9ce2b9e42465597abfcf2158fd8e0a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milan=20Kl=C3=B6wer?= Date: Fri, 12 Apr 2024 14:01:23 -0400 Subject: [PATCH 1/2] rand BFloat16 sampling without conversion --- src/bfloat16.jl | 12 +++++++++++- test/runtests.jl | 15 ++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index a643df9..9a6f7d2 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -350,7 +350,17 @@ Printf.tofloat(x::BFloat16) = Float32(x) # Random import Random: rand, randn, randexp, AbstractRNG, Sampler -rand(rng::AbstractRNG, ::Sampler{BFloat16}) = convert(BFloat16, rand(rng)) + +"""Sample a BFloat16 from [0,1) by setting random mantissa +bits for one(BFloat16) to obtain [1,2) (where floats are uniformly +distributed) then subtract 1 for [0,1).""" +function rand(rng::AbstractRNG, ::Sampler{BFloat16}) + u = reinterpret(UInt16, one(BFloat16)) + # shift random bits into BFloat16 mantissa (1 sign + 8 exp bits = 9) + u |= rand(rng, UInt16) >> 9 # u in [1,2) + return reinterpret(BFloat16, u) - one(BFloat16) # -1 for [0,1) +end + randn(rng::AbstractRNG, ::Type{BFloat16}) = convert(BFloat16, randn(rng)) randexp(rng::AbstractRNG, ::Type{BFloat16}) = convert(BFloat16, randexp(rng)) diff --git a/test/runtests.jl b/test/runtests.jl index 908ddf6..0a1c344 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -175,5 +175,18 @@ end @test a-1+1 == a # but -1 can end +@testset "rand sampling" begin + Random.seed(123) + mi, ma = extrema(rand(BFloat16, 1_000_000)) + + # zero should be the lowest BFloat16 sampled + @test mi === zero(BFloat16) + + # prevfloat(one(BFloat16)) cannot be sampled bc + # prevfloat(BFloat16(2)) - 1 is _two_ before one(BFloat16) + # (a statistical flaw of the [1,2)-1 sampling) + @test ma === prevfloat(one(BFloat16), 2) +end + include("structure.jl") -include("mathfuncs.jl") +include("mathfuncs.jl") \ No newline at end of file From 4837ccc272da04257345f650ad769160e4e66c05 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milan=20Kl=C3=B6wer?= Date: Fri, 12 Apr 2024 14:16:02 -0400 Subject: [PATCH 2/2] Random.seed! not seed --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 0a1c344..ca14594 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -176,7 +176,7 @@ end end @testset "rand sampling" begin - Random.seed(123) + Random.seed!(123) mi, ma = extrema(rand(BFloat16, 1_000_000)) # zero should be the lowest BFloat16 sampled