From a42c4fa6cf879186d4c724bb70a392f1cae8428e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milan=20Kl=C3=B6wer?= Date: Fri, 6 Oct 2023 05:11:59 -0700 Subject: [PATCH] Implement next/prevfloat(::BFloat16,::Integer) (#49) Co-authored-by: Tim Besard --- src/bfloat16.jl | 86 +++++++++++++++++++++++++++--------------------- test/runtests.jl | 22 +++++++++++-- 2 files changed, 68 insertions(+), 40 deletions(-) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 4c732b0..218d48a 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -6,7 +6,7 @@ import Base: isfinite, isnan, precision, iszero, eps, signbit, exponent, significand, frexp, ldexp, round, Int16, Int32, Int64, +, -, *, /, ^, ==, <, <=, >=, >, !=, inv, - abs, abs2, sqrt, cbrt, + abs, abs2, uabs, sqrt, cbrt, exp, exp2, exp10, expm1, log, log2, log10, log1p, sin, cos, tan, csc, sec, cot, @@ -17,6 +17,9 @@ import Base: isfinite, isnan, precision, iszero, eps, primitive type BFloat16 <: AbstractFloat 16 end +Base.reinterpret(::Type{Unsigned}, x::BFloat16) = reinterpret(UInt16, x) +Base.reinterpret(::Type{Signed}, x::BFloat16) = reinterpret(Int16, x) + # Floating point property queries for f in (:sign_mask, :exponent_mask, :exponent_one, :exponent_half, :significand_mask) @@ -26,7 +29,7 @@ end Base.exponent_bias(::Type{BFloat16}) = 127 Base.exponent_bits(::Type{BFloat16}) = 8 Base.significand_bits(::Type{BFloat16}) = 7 -Base.signbit(x::BFloat16) = (reinterpret(UInt16, x) & 0x8000) !== 0x0000 +Base.signbit(x::BFloat16) = (reinterpret(Unsigned, x) & 0x8000) !== 0x0000 function Base.significand(x::BFloat16) result = abs_significand(x) @@ -34,13 +37,13 @@ function Base.significand(x::BFloat16) end @inline function abs_significand(x::BFloat16) - usig = Base.significand_mask(BFloat16) & reinterpret(UInt16, x) + usig = Base.significand_mask(BFloat16) & reinterpret(Unsigned, x) isig = Int16(usig) 1 + isig / BFloat16(2)^7 end Base.exponent(x::BFloat16) = - ((reinterpret(UInt16, x) & Base.exponent_mask(BFloat16)) >> 7) - Base.exponent_bias(BFloat16) + ((reinterpret(Unsigned, x) & Base.exponent_mask(BFloat16)) >> 7) - Base.exponent_bias(BFloat16) function Base.frexp(x::BFloat16) xp = exponent(x) + 1 @@ -56,9 +59,9 @@ function Base.rem(x::BFloat16, ::Type{T}) where {T<:Integer} T(trunc(x)) end -iszero(x::BFloat16) = reinterpret(UInt16, x) & ~sign_mask(BFloat16) == 0x0000 -isfinite(x::BFloat16) = (reinterpret(UInt16,x) & exponent_mask(BFloat16)) != exponent_mask(BFloat16) -isnan(x::BFloat16) = (reinterpret(UInt16,x) & ~sign_mask(BFloat16)) > exponent_mask(BFloat16) +iszero(x::BFloat16) = reinterpret(Unsigned, x) & ~sign_mask(BFloat16) == 0x0000 +isfinite(x::BFloat16) = (reinterpret(Unsigned,x) & exponent_mask(BFloat16)) != exponent_mask(BFloat16) +isnan(x::BFloat16) = (reinterpret(Unsigned,x) & ~sign_mask(BFloat16)) > exponent_mask(BFloat16) precision(::Type{BFloat16}) = 8 eps(::Type{BFloat16}) = Base.bitcast(BFloat16, 0x3c00) @@ -129,7 +132,7 @@ end # Expansion to Float32 function Base.Float32(x::BFloat16) - reinterpret(Float32, UInt32(reinterpret(UInt16, x)) << 16) + reinterpret(Float32, UInt32(reinterpret(Unsigned, x)) << 16) end # Expansion to Float64 @@ -145,7 +148,7 @@ Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x)) for f in (:+, :-, :*, :/, :^) @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y))) end --(x::BFloat16) = reinterpret(BFloat16, reinterpret(UInt16, x) ⊻ sign_mask(BFloat16)) +-(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16)) ^(x::BFloat16, y::Integer) = BFloat16(^(Float32(x), y)) const ZeroBFloat16 = BFloat16(0.0f0) @@ -157,8 +160,8 @@ inv(x::BFloat16) = one(BFloat16) / x # Floating point comparison function ==(x::BFloat16, y::BFloat16) - ix = reinterpret(UInt16, x) - iy = reinterpret(UInt16, y) + ix = reinterpret(Unsigned, x) + iy = reinterpret(Unsigned, y) # NaNs (isnan(x) || isnan(y)) if (ix|iy)&~sign_mask(BFloat16) > exponent_mask(BFloat16) return false @@ -205,37 +208,45 @@ randn(rng::AbstractRNG, ::Type{BFloat16}) = convert(BFloat16, randn(rng)) randexp(rng::AbstractRNG, ::Type{BFloat16}) = convert(BFloat16, randexp(rng)) # Bitstring -bitstring(x::BFloat16) = bitstring(reinterpret(UInt16, x)) +bitstring(x::BFloat16) = bitstring(reinterpret(Unsigned, x)) # next/prevfloat -function Base.nextfloat(x::BFloat16) - if isfinite(x) - ui = reinterpret(UInt16,x) - if ui < 0x8000 # positive numbers - return reinterpret(BFloat16,ui+0x0001) - elseif ui == 0x8000 # =-zero(T) - return reinterpret(BFloat16,0x0001) - else # negative numbers - return reinterpret(BFloat16,ui-0x0001) +function Base.nextfloat(f::BFloat16, d::Integer) + F = typeof(f) + fumax = reinterpret(Unsigned, F(Inf)) + U = typeof(fumax) + + isnan(f) && return f + fi = reinterpret(Signed, f) + fneg = fi < 0 + fu = unsigned(fi & typemax(fi)) + + dneg = d < 0 + da = uabs(d) + if da > typemax(U) + fneg = dneg + fu = fumax + else + du = da % U + if fneg ⊻ dneg + if du > fu + fu = min(fumax, du - fu) + fneg = !fneg + else + fu = fu - du + end + else + if fumax - fu < du + fu = fumax + else + fu = fu + du + end end - else # NaN / Inf case - return x end -end - -function Base.prevfloat(x::BFloat16) - if isfinite(x) - ui = reinterpret(UInt16,x) - if ui == 0x0000 # =zero(T) - return reinterpret(BFloat16,0x8001) - elseif ui < 0x8000 # positive numbers - return reinterpret(BFloat16,ui-0x0001) - else # negative numbers - return reinterpret(BFloat16,ui+0x0001) - end - else # NaN / Inf case - return x + if fneg + fu |= sign_mask(F) end + reinterpret(F, fu) end # math functions @@ -250,4 +261,3 @@ for F in (:abs, :abs2, :sqrt, :cbrt, Base.$F(x::BFloat16) = BFloat16($F(Float32(x))) end end - diff --git a/test/runtests.jl b/test/runtests.jl index dd3688a..44f4bcd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -128,9 +128,27 @@ end @test isinf(nextfloat(BFloat16s.InfB16)) @test isnan(prevfloat(BFloat16s.NaNB16)) - @test isinf(prevfloat(BFloat16s.InfB16)) end +@testset "Next/prevfloat(x,::Integer)" begin + + x = one(BFloat16) + @test x == prevfloat(nextfloat(x,100),100) + @test x == nextfloat(prevfloat(x,100),100) + + x = -one(BFloat16) + @test x == prevfloat(nextfloat(x,100),100) + @test x == nextfloat(prevfloat(x,100),100) + + x = one(BFloat16) + @test nextfloat(x,5) == prevfloat(x,-5) + @test prevfloat(x,-5) == nextfloat(x,5) + + @test isinf(nextfloat(floatmax(BFloat16),5)) + @test prevfloat(floatmin(BFloat16),2^8) < 0 + @test nextfloat(-floatmin(BFloat16),2^8) > 0 +end + + include("structure.jl") include("mathfuncs.jl") -