diff --git a/src/double.jl b/src/double.jl index d18755f..eb20f4a 100644 --- a/src/double.jl +++ b/src/double.jl @@ -62,12 +62,44 @@ end @inline trunclo(x::Float32) = reinterpret(Float32, reinterpret(UInt32, x) & 0xffff_f000) # clear lowest 12 bits (leave upper 12 bits) # @inline trunclo(x::VecProduct) = trunclo(Vec(data(x))) -@inline function trunclo(x::Vec{N,Float64}) where {N} +@inline function trunclo(x::AbstractSIMD{N,Float64}) where {N} reinterpret(Vec{N,Float64}, reinterpret(Vec{N,UInt64}, x) & vbroadcast(Val{N}(), 0xffff_ffff_f800_0000)) # clear lower 27 bits (leave upper 26 bits) end -@inline function trunclo(x::Vec{N,Float32}) where {N} +@inline function trunclo(x::AbstractSIMD{N,Float32}) where {N} reinterpret(Vec{N,Float32}, reinterpret(Vec{N,UInt32}, x) & vbroadcast(Val{N}(), 0xffff_f000)) # clear lowest 12 bits (leave upper 12 bits) end +for (op,f,ff) ∈ [("fadd",:add_ieee,:(+)),("fsub",:sub_ieee,:(-)),("fmul",:mul_ieee,:(*)),("fdiv",:fdiv_ieee,:(/)),("frem",:rem_ieee,:(%))] + @eval begin + @generated $f(v1::Vec{W,T}, v2::Vec{W,T}) where {W,T<:Union{Float32,Float64}} = VectorizationBase.binary_op($op, W, T) + @inline $f(s1::T, s2::T) where {T<:Union{Float32,Float64}} = $ff(s1,s2) + end +end +@inline add_ieee(a, b, c) = add_ieee(add_ieee(a, b), c) +function sub_ieee!(ex) + ex isa Expr || return + if ex.head === :call + _f = ex.args[1] + if _f isa Symbol + f::Symbol = _f + if f === :(+) + ex.args[1] = :add_ieee + elseif f === :(-) + ex.args[1] = :sub_ieee + elseif f === :(*) + ex.args[1] = :mul_ieee + elseif f === :(/) + ex.args[1] = :fdiv_ieee + elseif f === :(%) + ex.args[1] = :rem_ieee + end + end + end + foreach(sub_ieee!, ex.args) +end +macro ieee(ex) + sub_ieee!(ex) +end + @inline function splitprec(x::vIEEEFloat) hx = trunclo(x) @@ -200,8 +232,10 @@ end @inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::False) hx, lx = splitprec(x) hy, ly = splitprec(y) - z = (x * y) - Double(z, ((hx * hy - z) + lx * hy + hx * ly) + lx * ly) + @ieee begin + z = x * y + Double(z, (((hx * hy - z) + lx * hy + hx * ly) + lx * ly)) + end end @inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::True) z = (x.hi * y) @@ -210,8 +244,10 @@ end @inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::False) hx, lx = splitprec(x.hi) hy, ly = splitprec(y) - z = x.hi * y - Double(z, (hx * hy - z) + lx * hy + hx * ly + lx * ly + x.lo * y) + @ieee begin + z = x.hi * y + Double(z, (hx * hy - z) + lx * hy + hx * ly + lx * ly + x.lo * y) + end end @inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True) z = x.hi * y.hi @@ -220,8 +256,10 @@ end @inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False) hx, lx = splitprec(x.hi) hy, ly = splitprec(y.hi) + @ieee begin z = x.hi * y.hi Double(z, (((hx * hy - z) + lx * hy + hx * ly) + lx * ly) + x.hi * y.lo + x.lo * y.hi) + end end @inline dmul(x::vIEEEFloat, y::Double{<:vIEEEFloat}) = dmul(y, x) @inline dmul(x, y) = dmul(x, y, fma_fast()) @@ -232,8 +270,10 @@ end end @inline function dsqu(x::T, ::False) where {T<:vIEEEFloat} hx, lx = splitprec(x) + @ieee begin z = x * x Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx) + end end @inline function dsqu(x::Double{T}, ::True) where {T<:vIEEEFloat} z = x.hi * x.hi @@ -241,8 +281,10 @@ end end @inline function dsqu(x::Double{T}, ::False) where {T<:vIEEEFloat} hx, lx = splitprec(x.hi) + @ieee begin z = x.hi * x.hi Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx + x.hi * (x.lo + x.lo)) + end end @inline dsqu(x) = dsqu(x, fma_fast()) @@ -253,8 +295,8 @@ end end @inline function dsqrt(x::Double{T}, ::False) where {T<:vIEEEFloat} c = _sqrt(x.hi) - u = dsqu(c) - Double(c, (x.hi - u.hi - u.lo + x.lo) / (c + c)) + u = dsqu(c, False()) + @ieee Double(c, (x.hi - u.hi - u.lo + x.lo) / (c + c)) end @inline dsqrt(x) = dsqrt(x, fma_fast()) @@ -265,10 +307,12 @@ end Double(zhi, ((vfnmadd(zhi, y.hi, x.hi) + vfnmadd(zhi, y.lo, x.lo)) * invy)) end @inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False) - invy = inv(y.hi) - c = x.hi * invy - u = dmul(c, y.hi) - Double(c, ((((x.hi - u.hi) - u.lo) + x.lo) - c * y.lo) * invy) + @ieee begin + invy = one(y.hi) / y.hi + c = x.hi * invy + u = dmul(c, y.hi, False()) + Double(c, ((((x.hi - u.hi) - u.lo) + x.lo) - c * y.lo) * invy) + end end @inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::True) ry = inv(y) @@ -276,11 +320,13 @@ end Double(r, (vfnmadd(r, y, x) * ry)) end @inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::False) - ry = inv(y) + @ieee begin + ry = one(y) / y r = x * ry hx, lx = splitprec(r) hy, ly = splitprec(y) Double(r, (((-hx * hy + r * y) - lx * hy - hx * ly) - lx * ly) * ry) + end end @inline ddiv(x, y) = ddiv(x, y, fma_fast()) # 1/x @@ -289,9 +335,11 @@ end Double(zhi, (vfnmadd(zhi, x, one(eltype(x))) * zhi)) end @inline function drec(x::vIEEEFloat, ::False) - c = inv(x) - u = dmul(c, x) + @ieee begin + c = one(x) / x + u = dmul(c, x, False()) Double(c, (one(eltype(u.hi)) - u.hi - u.lo) * c) + end end @inline function drec(x::Double{<:vIEEEFloat}, ::True) @@ -299,9 +347,10 @@ end Double(zhi, ((vfnmadd(zhi, x.hi, one(eltype(x))) - (zhi * x.lo)) * zhi)) end @inline function drec(x::Double{<:vIEEEFloat}, ::False) + @ieee begin c = inv(x.hi) - u = dmul(c, x.hi) + u = dmul(c, x.hi, False()) Double(c, (one(eltype(u.hi)) - u.hi - u.lo - c * x.lo) * c) + end end @inline drec(x) = drec(x, fma_fast()) - diff --git a/test/accuracy.jl b/test/accuracy.jl index f325b18..1e31c5a 100644 --- a/test/accuracy.jl +++ b/test/accuracy.jl @@ -1,9 +1,4 @@ -MRANGE(::Type{Float64}) = 10000000 -MRANGE(::Type{Float32}) = 10000 -IntF(::Type{Float64}) = Int64 -IntF(::Type{Float32}) = Int32 - @testset "Accuracy (max error in ulp) for $T" for T in (Float32, Float64) println("Accuracy tests for $T") diff --git a/test/testsetup.jl b/test/testsetup.jl index b3216b2..6034188 100644 --- a/test/testsetup.jl +++ b/test/testsetup.jl @@ -89,6 +89,10 @@ end strip_module_name(f::Function) = last(split(string(f), '.')) # strip module name from function f +MRANGE(::Type{Float64}) = 10000000 +MRANGE(::Type{Float32}) = 10000 +IntF(::Type{Float64}) = Int64 +IntF(::Type{Float32}) = Int32 function tovector(u::VectorizationBase.VecUnroll{_N,W,T}) where {_N,W,T} N = _N + 1; i = 0