Skip to content

Commit

Permalink
Implement next/prevfloat(::BFloat16,::Integer) (#49)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <tim.besard@gmail.com>
  • Loading branch information
milankl and maleadt authored Oct 6, 2023
1 parent 2165157 commit a42c4fa
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 40 deletions.
86 changes: 48 additions & 38 deletions src/bfloat16.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Base: isfinite, isnan, precision, iszero, eps,
signbit, exponent, significand, frexp, ldexp,
round, Int16, Int32, Int64,
+, -, *, /, ^, ==, <, <=, >=, >, !=, inv,
abs, abs2, sqrt, cbrt,
abs, abs2, uabs, sqrt, cbrt,
exp, exp2, exp10, expm1,
log, log2, log10, log1p,
sin, cos, tan, csc, sec, cot,
Expand All @@ -17,6 +17,9 @@ import Base: isfinite, isnan, precision, iszero, eps,

primitive type BFloat16 <: AbstractFloat 16 end

# Bit-pattern access through the abstract integer kinds: `Unsigned` yields the
# 16-bit unsigned view of a BFloat16 and `Signed` the 16-bit signed view.
function Base.reinterpret(::Type{Unsigned}, x::BFloat16)
    return reinterpret(UInt16, x)
end
function Base.reinterpret(::Type{Signed}, x::BFloat16)
    return reinterpret(Int16, x)
end

# Floating point property queries
for f in (:sign_mask, :exponent_mask, :exponent_one,
:exponent_half, :significand_mask)
Expand All @@ -26,21 +29,21 @@ end
# Bias subtracted from the stored exponent field to obtain the actual exponent.
Base.exponent_bias(::Type{BFloat16}) = 127
# Width of the exponent field, in bits.
Base.exponent_bits(::Type{BFloat16}) = 8
# Width of the stored significand (fraction) field, in bits.
Base.significand_bits(::Type{BFloat16}) = 7
# Return `true` when the top bit (mask 0x8000) of the 16-bit pattern of `x` is set.
# This is a pure bit test: it does not look at the exponent or significand fields.
Base.signbit(x::BFloat16) = (reinterpret(Unsigned, x) & 0x8000) !== 0x0000

# Significand of `x` carrying the sign of `x`: the magnitude comes from
# `abs_significand` and is negated when the sign bit of `x` is set.
function Base.significand(x::BFloat16)
    magnitude = abs_significand(x)
    return ifelse(signbit(x), -magnitude, magnitude)
end

# Magnitude of the significand of `x`: one plus the stored fraction bits scaled by 2^-7.
# NOTE(review): the leading one is added unconditionally, so zero, subnormals,
# Inf and NaN get no special treatment here — confirm that callers expect this.
@inline function abs_significand(x::BFloat16)
    # Keep only the fraction field of the bit pattern.
    usig = Base.significand_mask(BFloat16) & reinterpret(Unsigned, x)
    isig = Int16(usig)
    return 1 + isig / BFloat16(2)^7
end

# Unbiased exponent of `x`: isolate the exponent field, shift it down past the
# 7 significand bits and subtract the exponent bias.
# NOTE(review): the exponent field is used as-is, so zero, subnormals, Inf and
# NaN are not special-cased — confirm that this is intended.
Base.exponent(x::BFloat16) =
    ((reinterpret(Unsigned, x) & Base.exponent_mask(BFloat16)) >> 7) - Base.exponent_bias(BFloat16)

function Base.frexp(x::BFloat16)
xp = exponent(x) + 1
Expand All @@ -56,9 +59,9 @@ function Base.rem(x::BFloat16, ::Type{T}) where {T<:Integer}
T(trunc(x))
end

# Zero test: every bit except the sign bit is clear, so +0.0 and -0.0 both match.
iszero(x::BFloat16) = reinterpret(Unsigned, x) & ~sign_mask(BFloat16) == 0x0000
# Finite test: the exponent field is not all ones.
isfinite(x::BFloat16) = (reinterpret(Unsigned, x) & exponent_mask(BFloat16)) != exponent_mask(BFloat16)
# NaN test: with the sign bit cleared, the pattern is strictly above the
# all-ones-exponent pattern, i.e. the exponent is all ones and the fraction is non-zero.
isnan(x::BFloat16) = (reinterpret(Unsigned, x) & ~sign_mask(BFloat16)) > exponent_mask(BFloat16)
# Significand precision in bits: the 7 stored fraction bits plus the implicit leading bit.
precision(::Type{BFloat16}) = 8
# Machine epsilon, given directly as a bit pattern: 0x3c00 has a zero fraction and
# exponent field 120, which with a bias of 127 is 2^-7.
eps(::Type{BFloat16}) = Base.bitcast(BFloat16, 0x3c00)

Expand Down Expand Up @@ -129,7 +132,7 @@ end

# Expansion to Float32
# Widen to Float32: the 16 bits of `x` become the upper half of the 32-bit
# pattern and the lower 16 bits are zero.
function Base.Float32(x::BFloat16)
    return reinterpret(Float32, UInt32(reinterpret(Unsigned, x)) << 16)
end

# Expansion to Float64
Expand All @@ -145,7 +148,7 @@ Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x))
# Binary arithmetic: convert both operands to Float32, apply the operator there,
# and convert the result back to BFloat16.
for op in (:+, :-, :*, :/, :^)
    @eval ($op)(x::BFloat16, y::BFloat16) = BFloat16($(op)(Float32(x), Float32(y)))
end
# Unary minus: flip only the sign bit of the bit pattern; no arithmetic is performed.
-(x::BFloat16) = reinterpret(BFloat16, xor(reinterpret(Unsigned, x), sign_mask(BFloat16)))
# Integer power: exponentiate in Float32 with the integer exponent, then convert back.
function ^(x::BFloat16, y::Integer)
    return BFloat16(^(Float32(x), y))
end

# BFloat16 zero, built by converting the Float32 literal 0.0f0.
const ZeroBFloat16 = BFloat16(0.0f0)
Expand All @@ -157,8 +160,8 @@ inv(x::BFloat16) = one(BFloat16) / x

# Floating point comparison
function ==(x::BFloat16, y::BFloat16)
ix = reinterpret(UInt16, x)
iy = reinterpret(UInt16, y)
ix = reinterpret(Unsigned, x)
iy = reinterpret(Unsigned, y)
# NaNs (isnan(x) || isnan(y))
if (ix|iy)&~sign_mask(BFloat16) > exponent_mask(BFloat16)
return false
Expand Down Expand Up @@ -205,37 +208,45 @@ randn(rng::AbstractRNG, ::Type{BFloat16}) = convert(BFloat16, randn(rng))
# Draw from `randexp(rng)` (no element type given) and convert the sample to BFloat16.
function randexp(rng::AbstractRNG, ::Type{BFloat16})
    sample = randexp(rng)
    return convert(BFloat16, sample)
end

# Bitstring
# Bit string of `x`, taken from its 16-bit unsigned bit pattern.
bitstring(x::BFloat16) = bitstring(reinterpret(Unsigned, x))

# next/prevfloat
# nextfloat(f::BFloat16, d::Integer)
#
# Return the BFloat16 that lies `d` representable steps above `f` (below `f`
# when `d` is negative).
#
# - NaN is returned unchanged.
# - Stepping across zero flips the sign.
# - Stepping past the largest finite magnitude saturates at infinity of the
#   direction of travel.
#
# Base's generic `nextfloat(x)`, `prevfloat(x)` and `prevfloat(x, d)` methods for
# `AbstractFloat` are defined in terms of `nextfloat(x, d)`, so this single
# method serves all four entry points.
function Base.nextfloat(f::BFloat16, d::Integer)
    F = typeof(f)
    fumax = reinterpret(Unsigned, F(Inf))   # bit pattern of +Inf = largest magnitude
    U = typeof(fumax)

    isnan(f) && return f

    # Split `f` into a sign flag and an unsigned magnitude (sign bit cleared).
    fi = reinterpret(Signed, f)
    fneg = fi < 0
    fu = unsigned(fi & typemax(fi))

    # Split the step count the same way.
    dneg = d < 0
    da = Base.uabs(d)

    if da > typemax(U)
        # More steps than there are bit patterns: saturate in the direction of `d`.
        fneg = dneg
        fu = fumax
    else
        du = da % U
        if xor(fneg, dneg)
            # Moving towards zero, and possibly across it.
            if du > fu
                fu = min(fumax, du - fu)
                fneg = !fneg
            else
                fu = fu - du
            end
        else
            # Moving away from zero: clamp at infinity instead of overflowing.
            if fumax - fu < du
                fu = fumax
            else
                fu = fu + du
            end
        end
    end

    # Reattach the sign bit.
    if fneg
        fu |= sign_mask(F)
    end
    return reinterpret(F, fu)
end

# math functions
Expand All @@ -250,4 +261,3 @@ for F in (:abs, :abs2, :sqrt, :cbrt,
Base.$F(x::BFloat16) = BFloat16($F(Float32(x)))
end
end

22 changes: 20 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,27 @@ end
@test isinf(nextfloat(BFloat16s.InfB16))

@test isnan(prevfloat(BFloat16s.NaNB16))
@test isinf(prevfloat(BFloat16s.InfB16))
end

@testset "Next/prevfloat(x,::Integer)" begin

    # stepping n times up and n times down (or the reverse) is a round trip
    x = one(BFloat16)
    @test x == prevfloat(nextfloat(x,100),100)
    @test x == nextfloat(prevfloat(x,100),100)

    x = -one(BFloat16)
    @test x == prevfloat(nextfloat(x,100),100)
    @test x == nextfloat(prevfloat(x,100),100)

    # a negative step count is the same as stepping in the other direction;
    # cover both directions rather than the same identity twice
    x = one(BFloat16)
    @test nextfloat(x,5) == prevfloat(x,-5)
    @test prevfloat(x,5) == nextfloat(x,-5)

    # saturation at infinity and sign change when crossing zero
    @test isinf(nextfloat(floatmax(BFloat16),5))
    @test prevfloat(floatmin(BFloat16),2^8) < 0
    @test nextfloat(-floatmin(BFloat16),2^8) > 0
end


include("structure.jl")
include("mathfuncs.jl")

0 comments on commit a42c4fa

Please sign in to comment.