diff --git a/src/SLEEFPirates.jl b/src/SLEEFPirates.jl index d8e8a47..091174e 100644 --- a/src/SLEEFPirates.jl +++ b/src/SLEEFPirates.jl @@ -6,7 +6,7 @@ using Base.Math: uinttype, exponent_bias, exponent_mask, significand_bits, IEEEF using Libdl, VectorizationBase using VectorizationBase: vzero, AbstractSIMD, _Vec, fma_fast, data, VecUnroll, NativeTypes, FloatingTypes, - vfmadd, vfnmadd, vfmsub, vfnmsub + vfmadd, vfnmadd, vfmsub, vfnmsub, True, False import IfElse: ifelse diff --git a/src/double.jl b/src/double.jl index be0818e..6417e73 100644 --- a/src/double.jl +++ b/src/double.jl @@ -193,125 +193,115 @@ end end # two-prod-fma -@inline function dmul(x::vIEEEFloat, y::vIEEEFloat) - if fma_fast() - z = (x * y) - Double(z, vfmsub(x, y, z)) - else - hx, lx = splitprec(x) - hy, ly = splitprec(y) - z = (x * y) - Double(z, ((hx * hy - z) + lx * hy + hx * ly) + lx * ly) - end -end - -@inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat) - if fma_fast() - z = (x.hi * y) - Double(z, vfmsub(x.hi, y, z) + x.lo * y) - else - hx, lx = splitprec(x.hi) - hy, ly = splitprec(y) - z = x.hi * y - Double(z, (hx * hy - z) + lx * hy + hx * ly + lx * ly + x.lo * y) - end +@inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::True) + z = (x * y) + Double(z, vfmsub(x, y, z)) +end +@inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::False) + hx, lx = splitprec(x) + hy, ly = splitprec(y) + z = (x * y) + Double(z, ((hx * hy - z) + lx * hy + hx * ly) + lx * ly) +end +@inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::True) + z = (x.hi * y) + Double(z, vfmsub(x.hi, y, z) + x.lo * y) +end +@inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::False) + hx, lx = splitprec(x.hi) + hy, ly = splitprec(y) + z = x.hi * y + Double(z, (hx * hy - z) + lx * hy + hx * ly + lx * ly + x.lo * y) +end +@inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True) + z = x.hi * y.hi + Double(z, vfmsub(x.hi, y.hi, z) + x.hi * y.lo + x.lo * y.hi) +end +@inline 
function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False) + hx, lx = splitprec(x.hi) + hy, ly = splitprec(y.hi) + z = x.hi * y.hi + Double(z, (((hx * hy - z) + lx * hy + hx * ly) + lx * ly) + x.hi * y.lo + x.lo * y.hi) end @inline dmul(x::vIEEEFloat, y::Double{<:vIEEEFloat}) = dmul(y, x) - -@inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}) - if fma_fast() - z = x.hi * y.hi - Double(z, vfmsub(x.hi, y.hi, z) + x.hi * y.lo + x.lo * y.hi) - else - hx, lx = splitprec(x.hi) - hy, ly = splitprec(y.hi) - z = x.hi * y.hi - Double(z, (((hx * hy - z) + lx * hy + hx * ly) + lx * ly) + x.hi * y.lo + x.lo * y.hi) - end -end +@inline dmul(x, y) = dmul(x, y, fma_fast()) # x^2 -@inline function dsqu(x::T) where {T<:vIEEEFloat} - if fma_fast() - z = x * x - Double(z, vfmsub(x, x, z)) - else - hx, lx = splitprec(x) - z = x * x - Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx) - end +@inline function dsqu(x::T, ::True) where {T<:vIEEEFloat} + z = x * x + Double(z, vfmsub(x, x, z)) end - -@inline function dsqu(x::Double{T}) where {T<:vIEEEFloat} - if fma_fast() - z = x.hi * x.hi - Double(z, vfmsub(x.hi, x.hi, z) + (x.hi * (x.lo + x.lo))) - else - hx, lx = splitprec(x.hi) - z = x.hi * x.hi - Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx + x.hi * (x.lo + x.lo)) - end +@inline function dsqu(x::T, ::False) where {T<:vIEEEFloat} + hx, lx = splitprec(x) + z = x * x + Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx) end - - # sqrt(x) -@inline function dsqrt(x::Double{T}) where {T<:vIEEEFloat} - if fma_fast() - zhi = _sqrt(x.hi) - Double(zhi, (x.lo + vfnmadd(zhi, zhi, x.hi)) / (zhi + zhi)) - else - c = _sqrt(x.hi) - u = dsqu(c) - Double(c, (x.hi - u.hi - u.lo + x.lo) / (c + c)) - end +@inline function dsqu(x::Double{T}, ::True) where {T<:vIEEEFloat} + z = x.hi * x.hi + Double(z, vfmsub(x.hi, x.hi, z) + (x.hi * (x.lo + x.lo))) end - - # x/y -@inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}) - if fma_fast() - invy = inv(y.hi) 
- zhi = (x.hi * invy) - Double(zhi, ((vfnmadd(zhi, y.hi, x.hi) + vfnmadd(zhi, y.lo, x.lo)) * invy)) - else - invy = 1 / y.hi - c = x.hi * invy - u = dmul(c, y.hi) - Double(c, ((((x.hi - u.hi) - u.lo) + x.lo) - c * y.lo) * invy) - end +@inline function dsqu(x::Double{T}, ::False) where {T<:vIEEEFloat} + hx, lx = splitprec(x.hi) + z = x.hi * x.hi + Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx + x.hi * (x.lo + x.lo)) end +@inline dsqu(x) = dsqu(x, fma_fast()) -@inline function ddiv(x::vIEEEFloat, y::vIEEEFloat) - if fma_fast() - ry = inv(y) - r = (x * ry) - Double(r, (vfnmadd(r, y, x) * ry)) - else - ry = 1 / y - r = x * ry - hx, lx = splitprec(r) - hy, ly = splitprec(y) - Double(r, (((-hx * hy + r * y) - lx * hy - hx * ly) - lx * ly) * ry) - end + # sqrt(x) +@inline function dsqrt(x::Double{T}, ::True) where {T<:vIEEEFloat} + zhi = _sqrt(x.hi) + Double(zhi, (x.lo + vfnmadd(zhi, zhi, x.hi)) / (zhi + zhi)) +end +@inline function dsqrt(x::Double{T}, ::False) where {T<:vIEEEFloat} + c = _sqrt(x.hi) + u = dsqu(c) + Double(c, (x.hi - u.hi - u.lo + x.lo) / (c + c)) end +@inline dsqrt(x) = dsqrt(x, fma_fast()) + # x/y +@inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True) + invy = inv(y.hi) + zhi = (x.hi * invy) + Double(zhi, ((vfnmadd(zhi, y.hi, x.hi) + vfnmadd(zhi, y.lo, x.lo)) * invy)) +end +@inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False) + invy = 1 / y.hi + c = x.hi * invy + u = dmul(c, y.hi) + Double(c, ((((x.hi - u.hi) - u.lo) + x.lo) - c * y.lo) * invy) +end +@inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::True) + ry = inv(y) + r = (x * ry) + Double(r, (vfnmadd(r, y, x) * ry)) +end +@inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::False) + ry = 1 / y + r = x * ry + hx, lx = splitprec(r) + hy, ly = splitprec(y) + Double(r, (((-hx * hy + r * y) - lx * hy - hx * ly) - lx * ly) * ry) +end +@inline ddiv(x, y) = ddiv(x, y, fma_fast()) # 1/x -@inline function drec(x::vIEEEFloat) - if fma_fast() - zhi 
= inv(x) - Double(zhi, (vfnmadd(zhi, x, one(eltype(x))) * zhi)) - else - c = 1 / x - u = dmul(c, x) - Double(c, (one(T) - u.hi - u.lo) * c) - end +@inline function drec(x::vIEEEFloat, ::True) + zhi = inv(x) + Double(zhi, (vfnmadd(zhi, x, one(eltype(x))) * zhi)) +end +@inline function drec(x::vIEEEFloat, ::False) + c = 1 / x + u = dmul(c, x) + Double(c, (one(eltype(x)) - u.hi - u.lo) * c) end -@inline function drec(x::Double{<:vIEEEFloat}) - if fma_fast() +@inline function drec(x::Double{<:vIEEEFloat}, ::True) zhi = inv(x.hi) Double(zhi, ((vfnmadd(zhi, x.hi, one(eltype(x))) - (zhi * x.lo)) * zhi)) - - else - c = 1 / x.hi - u = dmul(c, x.hi) - Double(c, (one(T) - u.hi - u.lo - c * x.lo) * c) - end end +@inline function drec(x::Double{<:vIEEEFloat}, ::False) + c = 1 / x.hi + u = dmul(c, x.hi) + Double(c, (one(eltype(x)) - u.hi - u.lo - c * x.lo) * c) +end +@inline drec(x) = drec(x, fma_fast()) + diff --git a/src/exp.jl b/src/exp.jl index 3470b4e..c2bc690 100644 --- a/src/exp.jl +++ b/src/exp.jl @@ -78,7 +78,7 @@ const J_TABLE= Float64[2.0^(big(j-1)/256) for j in 1:256]; @inline target_trunc(v) = target_trunc(v, VectorizationBase.has_feature(Val(:x86_64_avx512dq))) for (func, base) in (:exp2=>Val(2), :exp=>Val(ℯ), :exp10=>Val(10)) - Ndef1 = :(targetspecific_truncate(reinterpret(UInt64, N_float))) + Ndef1 = :(target_trunc(reinterpret(UInt64, N_float))) func_fast = Symbol(func, :_fast) @eval begin @inline function $func_fast(x::FloatType64) @@ -174,10 +174,14 @@ end return exthorner(x, (1.0f0, 0.5f0, hi_order)) end -@inline widest_supported_integer(::VectorizationBase.True) = Int64 -@inline widest_supported_integer(::VectorizationBase.False) = Int32 -@inline inttype(::Type{Float64}) = widest_supported_integer(VectorizationBase.has_feature(Val(:x86_64_avx512dq))) -@inline inttype(::Type{Float32}) = Int32 +if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686) + @inline widest_supported_integer(::VectorizationBase.True) = Int64 + @inline widest_supported_integer(::VectorizationBase.False) = 
Int32 + @inline inttype(::Type{Float64}) = widest_supported_integer(VectorizationBase.has_feature(Val(:x86_64_avx512dq))) + @inline inttype(::Type{Float32}) = Int32 +else + @inline inttype(_) = Int +end @inline function expm1_fast(x::FloatType) T = eltype(x) diff --git a/test/testsetup.jl b/test/testsetup.jl index 233c2f4..b3216b2 100644 --- a/test/testsetup.jl +++ b/test/testsetup.jl @@ -210,9 +210,9 @@ function test_acc(T, fun_table, xx, tol; debug = false, tol_debug = 5) # Vector test is mostly to make sure that they do not error # Results should either be the same as scalar # Or they're from another library (e.g., GLIBC), and may differ slighlty - test_vector(xfun, fun, VectorizationBase.pick_vector_width_val(T), first(xx), last(xx), tol) - test_vector(xfun, fun, Val(2), first(xx), last(xx), tol) W = VectorizationBase.pick_vector_width(T) + test_vector(xfun, fun, W, first(xx), last(xx), tol) + test_vector(xfun, fun, Val(2), first(xx), last(xx), tol) if W ≥ 4 test_vector(xfun, fun, Val(4), first(xx), last(xx), tol) end