Skip to content

Commit

Permalink
Add @ieee macro.
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Feb 5, 2021
1 parent 0050a39 commit e4afa71
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 22 deletions.
83 changes: 66 additions & 17 deletions src/double.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,44 @@ end
@inline trunclo(x::Float32) = reinterpret(Float32, reinterpret(UInt32, x) & 0xffff_f000) # clear lowest 12 bits (leave upper 12 bits)

# @inline trunclo(x::VecProduct) = trunclo(Vec(data(x)))
@inline function trunclo(x::Vec{N,Float64}) where {N}
@inline function trunclo(x::AbstractSIMD{N,Float64}) where {N}
reinterpret(Vec{N,Float64}, reinterpret(Vec{N,UInt64}, x) & vbroadcast(Val{N}(), 0xffff_ffff_f800_0000)) # clear lower 27 bits (leave upper 26 bits)
end
@inline function trunclo(x::Vec{N,Float32}) where {N}
@inline function trunclo(x::AbstractSIMD{N,Float32}) where {N}
reinterpret(Vec{N,Float32}, reinterpret(Vec{N,UInt32}, x) & vbroadcast(Val{N}(), 0xffff_f000)) # clear lowest 12 bits (leave upper 12 bits)
end
for (op,f,ff) [("fadd",:add_ieee,:(+)),("fsub",:sub_ieee,:(-)),("fmul",:mul_ieee,:(*)),("fdiv",:fdiv_ieee,:(/)),("frem",:rem_ieee,:(%))]
@eval begin
@generated $f(v1::Vec{W,T}, v2::Vec{W,T}) where {W,T<:Union{Float32,Float64}} = VectorizationBase.binary_op($op, W, T)
@inline $f(s1::T, s2::T) where {T<:Union{Float32,Float64}} = $ff(s1,s2)
end
end
@inline add_ieee(a, b, c) = add_ieee(add_ieee(a, b), c)
function sub_ieee!(ex)
ex isa Expr || return
if ex.head === :call
_f = ex.args[1]
if _f isa Symbol
f::Symbol = _f
if f === :(+)
ex.args[1] = :add_ieee
elseif f === :(-)
ex.args[1] = :sub_ieee
elseif f === :(*)
ex.args[1] = :mul_ieee
elseif f === :(/)
ex.args[1] = :fdiv_ieee
elseif f === :(%)
ex.args[1] = :rem_ieee
end
end
end
foreach(sub_ieee!, ex.args)
end
macro ieee(ex)
sub_ieee!(ex)
end


@inline function splitprec(x::vIEEEFloat)
hx = trunclo(x)
Expand Down Expand Up @@ -200,8 +232,10 @@ end
@inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::False)
hx, lx = splitprec(x)
hy, ly = splitprec(y)
z = (x * y)
Double(z, ((hx * hy - z) + lx * hy + hx * ly) + lx * ly)
@ieee begin
z = x * y
Double(z, (((hx * hy - z) + lx * hy + hx * ly) + lx * ly))
end
end
@inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::True)
z = (x.hi * y)
Expand All @@ -210,8 +244,10 @@ end
@inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::False)
hx, lx = splitprec(x.hi)
hy, ly = splitprec(y)
z = x.hi * y
Double(z, (hx * hy - z) + lx * hy + hx * ly + lx * ly + x.lo * y)
@ieee begin
z = x.hi * y
Double(z, (hx * hy - z) + lx * hy + hx * ly + lx * ly + x.lo * y)
end
end
@inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True)
z = x.hi * y.hi
Expand All @@ -220,8 +256,10 @@ end
@inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False)
hx, lx = splitprec(x.hi)
hy, ly = splitprec(y.hi)
@ieee begin
z = x.hi * y.hi
Double(z, (((hx * hy - z) + lx * hy + hx * ly) + lx * ly) + x.hi * y.lo + x.lo * y.hi)
end
end
@inline dmul(x::vIEEEFloat, y::Double{<:vIEEEFloat}) = dmul(y, x)
@inline dmul(x, y) = dmul(x, y, fma_fast())
Expand All @@ -232,17 +270,21 @@ end
end
@inline function dsqu(x::T, ::False) where {T<:vIEEEFloat}
hx, lx = splitprec(x)
@ieee begin
z = x * x
Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx)
end
end
@inline function dsqu(x::Double{T}, ::True) where {T<:vIEEEFloat}
z = x.hi * x.hi
Double(z, vfmsub(x.hi, x.hi, z) + (x.hi * (x.lo + x.lo)))
end
@inline function dsqu(x::Double{T}, ::False) where {T<:vIEEEFloat}
hx, lx = splitprec(x.hi)
@ieee begin
z = x.hi * x.hi
Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx + x.hi * (x.lo + x.lo))
end
end
@inline dsqu(x) = dsqu(x, fma_fast())

Expand All @@ -253,8 +295,8 @@ end
end
@inline function dsqrt(x::Double{T}, ::False) where {T<:vIEEEFloat}
c = _sqrt(x.hi)
u = dsqu(c)
Double(c, (x.hi - u.hi - u.lo + x.lo) / (c + c))
u = dsqu(c, False())
@ieee Double(c, (x.hi - u.hi - u.lo + x.lo) / (c + c))
end
@inline dsqrt(x) = dsqrt(x, fma_fast())

Expand All @@ -265,22 +307,26 @@ end
Double(zhi, ((vfnmadd(zhi, y.hi, x.hi) + vfnmadd(zhi, y.lo, x.lo)) * invy))
end
@inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False)
invy = inv(y.hi)
c = x.hi * invy
u = dmul(c, y.hi)
Double(c, ((((x.hi - u.hi) - u.lo) + x.lo) - c * y.lo) * invy)
@ieee begin
invy = one(y.hi) / y.hi
c = x.hi * invy
u = dmul(c, y.hi, False())
Double(c, ((((x.hi - u.hi) - u.lo) + x.lo) - c * y.lo) * invy)
end
end
@inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::True)
ry = inv(y)
r = (x * ry)
Double(r, (vfnmadd(r, y, x) * ry))
end
@inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::False)
ry = inv(y)
@ieee begin
ry = one(y) / y
r = x * ry
hx, lx = splitprec(r)
hy, ly = splitprec(y)
Double(r, (((-hx * hy + r * y) - lx * hy - hx * ly) - lx * ly) * ry)
end
end
@inline ddiv(x, y) = ddiv(x, y, fma_fast())
# 1/x
Expand All @@ -289,19 +335,22 @@ end
Double(zhi, (vfnmadd(zhi, x, one(eltype(x))) * zhi))
end
@inline function drec(x::vIEEEFloat, ::False)
c = inv(x)
u = dmul(c, x)
@ieee begin
c = one(x) / x
u = dmul(c, x, False())
Double(c, (one(eltype(u.hi)) - u.hi - u.lo) * c)
end
end

@inline function drec(x::Double{<:vIEEEFloat}, ::True)
zhi = inv(x.hi)
Double(zhi, ((vfnmadd(zhi, x.hi, one(eltype(x))) - (zhi * x.lo)) * zhi))
end
@inline function drec(x::Double{<:vIEEEFloat}, ::False)
@ieee begin
c = inv(x.hi)
u = dmul(c, x.hi)
u = dmul(c, x.hi, False())
Double(c, (one(eltype(u.hi)) - u.hi - u.lo - c * x.lo) * c)
end
end
@inline drec(x) = drec(x, fma_fast())

5 changes: 0 additions & 5 deletions test/accuracy.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@

MRANGE(::Type{Float64}) = 10000000
MRANGE(::Type{Float32}) = 10000
IntF(::Type{Float64}) = Int64
IntF(::Type{Float32}) = Int32


@testset "Accuracy (max error in ulp) for $T" for T in (Float32, Float64)
println("Accuracy tests for $T")
Expand Down
4 changes: 4 additions & 0 deletions test/testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ end

strip_module_name(f::Function) = last(split(string(f), '.')) # strip module name from function f

MRANGE(::Type{Float64}) = 10000000
MRANGE(::Type{Float32}) = 10000
IntF(::Type{Float64}) = Int64
IntF(::Type{Float32}) = Int32

function tovector(u::VectorizationBase.VecUnroll{_N,W,T}) where {_N,W,T}
N = _N + 1; i = 0
Expand Down

0 comments on commit e4afa71

Please sign in to comment.