From 7155a1dc7aa236221420f334bbf7a3136bc52d22 Mon Sep 17 00:00:00 2001 From: Chris Elrod Date: Sun, 10 Jan 2021 07:02:18 -0500 Subject: [PATCH] Update for VectorizationBase 0.15 --- Project.toml | 2 +- src/SLEEFPirates.jl | 4 +- src/double.jl | 125 +++++++++++++++++++++----------------------- src/exp.jl | 22 ++++---- src/priv.jl | 4 +- src/trig.jl | 10 ++-- 6 files changed, 82 insertions(+), 85 deletions(-) diff --git a/Project.toml b/Project.toml index 8135ab9..2a4a5df 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,7 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" [compat] IfElse = "0.1" -VectorizationBase = "0.13,0.14" +VectorizationBase = "0.15" julia = "1.5" [extras] diff --git a/src/SLEEFPirates.jl b/src/SLEEFPirates.jl index 7478d17..7a4cc15 100644 --- a/src/SLEEFPirates.jl +++ b/src/SLEEFPirates.jl @@ -5,8 +5,8 @@ using Base.Math: uinttype, exponent_bias, exponent_mask, significand_bits, IEEEF using Libdl, VectorizationBase -using VectorizationBase: vzero, AbstractSIMD, _Vec, FMA_FAST, data, vsub, vmul, VecUnroll, NativeTypes, FloatingTypes, - vadd, vmul, vsub +using VectorizationBase: vzero, AbstractSIMD, _Vec, FMA_FAST, data, VecUnroll, NativeTypes, FloatingTypes, + vfmadd, vfnmadd, vfmsub, vfnmsub import IfElse: ifelse diff --git a/src/double.jl b/src/double.jl index 136ab23..1c09441 100644 --- a/src/double.jl +++ b/src/double.jl @@ -75,17 +75,17 @@ end @inline function splitprec(x::vIEEEFloat) hx = trunclo(x) - hx, vsub(x, hx) + hx, x - hx end @inline function dnormalize(x::Double{T}) where {T} - r = vadd(x.hi, x.lo) - Double(r, vadd(vsub(x.hi, r), x.lo)) + r = x.hi + x.lo + Double(r, (x.hi - r) + x.lo) end @inline flipsign(x::Double{<:vIEEEFloat}, y::vIEEEFloat) = Double(flipsign(x.hi, y), flipsign(x.lo, y)) -@inline scale(x::Double{<:vIEEEFloat}, s::vIEEEFloat) = Double(vmul(s, x.hi), vmul(s, x.lo)) +@inline scale(x::Double{<:vIEEEFloat}, s::vIEEEFloat) = Double(s * x.hi, s * x.lo) @inline (-)(x::Double{T}) where {T<:vIEEEFloat} = Double(-x.hi, -x.lo) @@ -104,91 +104,91 @@ end # quick-two-sum x+y @inline function dadd(x::vIEEEFloat, y::vIEEEFloat) #WARNING |x| >= |y| - s = vadd(x, y) - Double(s, vadd(vsub(x, s), y)) + s = x + y + Double(s, ((x - s) + y)) end @inline function dadd(x::vIEEEFloat, y::Double{<:vIEEEFloat}) #WARNING |x| >= |y| - s = vadd(x, y.hi) - Double(s, vadd(vadd(vsub(x, s), y.hi), y.lo)) + s = x + y.hi + Double(s, (((x - s) + y.hi) + y.lo)) end @inline function dadd(x::Double{<:vIEEEFloat}, y::vIEEEFloat) #WARNING |x| >= |y| - s = vadd(x.hi, y) - Double(s, vadd(vadd(vsub(x.hi, s), y), x.lo)) + s = x.hi + y + Double(s, (((x.hi - s) + y) + x.lo)) end @inline function dadd(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}) #WARNING |x| >= |y| - s = vadd(x.hi, y.hi) - Double(s, vadd(vadd(vadd(vsub(x.hi, s), y.hi), y.lo), x.lo)) + s = x.hi + y.hi + Double(s, ((((x.hi - s) + y.hi) + y.lo) + x.lo)) end @inline function dsub(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}) #WARNING |x| >= |y| - s = vsub(x.hi, y.hi) - Double(s, vadd(vsub(vsub(vsub(x.hi, s), y.hi), y.lo), x.lo)) + s = x.hi - y.hi + Double(s, ((((x.hi - s) - y.hi) - y.lo) + x.lo)) end @inline function dsub(x::Double{<:vIEEEFloat}, y::vIEEEFloat) #WARNING |x| >= |y| - s = vsub(x.hi, y) - Double(s, vadd(vsub(vsub(x.hi, s), y), x.lo)) + s = x.hi - y + Double(s, (((x.hi - s) - y) + x.lo)) end @inline function dsub(x::vIEEEFloat, y::Double{<:vIEEEFloat}) #WARNING |x| >= |y| - s = vsub(x, y.hi) - Double(s, vsub(vsub(vsub(x, s), y.hi, y.lo))) + s = x - y.hi + Double(s, (((x - s) - y.hi - y.lo))) end @inline function dsub(x::vIEEEFloat, y::vIEEEFloat) #WARNING |x| >= |y| - s = vsub(x, y) - Double(s, vsub(vsub(x, s), y)) + s = x - y + Double(s, ((x - s) - y)) end # two-sum x+y NO BRANCH @inline function dadd2(x::vIEEEFloat, y::vIEEEFloat) - s = vadd(x, y) - v = vsub(s, x) - Double(s, vadd(vsub(x, vsub(s, v)), vsub(y, v))) + s = x + y + v = s - x + Double(s, ((x - (s - v)) + (y - v))) end @inline function dadd2(x::vIEEEFloat, y::Double{<:vIEEEFloat}) - s = vadd(x, y.hi) - v = vsub(s, x) - Double(s, vsub(x, vsub(s, v)) + vsub(y.hi, v) + y.lo) + s = x + y.hi + v = s - x + Double(s, (x - (s - v)) + (y.hi - v) + y.lo) end @inline dadd2(x::Double{<:vIEEEFloat}, y::vIEEEFloat) = dadd2(y, x) @inline function dadd2(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}) - s = vadd(x.hi, y.hi) - v = vsub(s, x.hi) - smv = vsub(s, v) - yhimv = vsub(y.hi, v) - Double(s, vadd(vadd(vadd(vsub(x.hi, smv), yhimv), x.lo), y.lo)) + s = (x.hi + y.hi) + v = (s - x.hi) + smv = (s - v) + yhimv = (y.hi - v) + Double(s, ((((x.hi - smv) + yhimv) + x.lo) + y.lo)) end @inline function dsub2(x::vIEEEFloat, y::vIEEEFloat) - s = vsub(x, y) - v = vsub(s, x) - Double(s, vadd(vsub(x, vsub(s, v)), vsub(-y, v))) + s = x - y + v = s - x + Double(s, ((x - (s - v)) - (y + v))) end @inline function dsub2(x::vIEEEFloat, y::Double{<:vIEEEFloat}) - s = vsub(x, y.hi) - v = vsub(s, x) - Double(s, vsub(vadd(vsub(x, vsub(s, v)), vsub(-y.hi, v)), y.lo)) + s = (x - y.hi) + v = (s - x) + Double(s, (((x - (s - v)) - (y.hi + v)) - y.lo)) end @inline function dsub2(x::Double{<:vIEEEFloat}, y::vIEEEFloat) - s = vsub(x.hi, y) - v = vsub(s, x.hi) - Double(s, vadd(vadd(vsub(x.hi, vsub(s, v)), vsub(-y, v)), x.lo)) + s = x.hi - y + v = s - x.hi + Double(s, (((x.hi - (s - v)) - (y + v)) + x.lo)) end @inline function dsub2(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}) - s = vsub(x.hi, y.hi) - v = vsub(s, x.hi) - Double(s, vsub(vadd(vadd(vsub(x.hi, vsub(s, v)), vsub(-y.hi, v)), x.lo), y.lo)) + s = x.hi - y.hi + v = s - x.hi + Double(s, ((((x.hi - (s - v)) - (y.hi + v)) + x.lo) - y.lo)) end @inline function ifelse(b::Mask{N}, x::Double{T1}, y::Double{T2}) where {N,T<:Union{Float32,Float64},T1<:Union{T,Vec{N,T}},T2<:Union{T,Vec{N,T}}} @@ -200,64 +200,61 @@ if FMA_FAST # two-prod-fma @inline function dmul(x::vIEEEFloat, y::vIEEEFloat) - z = vmul(x, y) - Double(z, fma(x, y, -z)) + z = (x * y) + Double(z, vfmsub(x, y, z)) end @inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat) - z = vmul(x.hi, y) - # Double(z, fma(x.hi, y, -z) + x.lo * y) - Double(z, vadd(fma(x.hi, y, -z), vmul(x.lo, y))) + z = (x.hi * y) + Double(z, vfmsub(x.hi, y, z) + x.lo * y) end @inline dmul(x::vIEEEFloat, y::Double{<:vIEEEFloat}) = dmul(y, x) @inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}) - z = vmul(x.hi, y.hi) - # Double(z, fma(x.hi, y.hi, -z) + x.hi * y.lo + x.lo * y.hi) - Double(z, vadd(vadd(fma(x.hi, y.hi, -z), vmul(x.hi, y.lo)), vmul(x.lo, y.hi))) + z = x.hi * y.hi + Double(z, vfmsub(x.hi, y.hi, z) + x.hi * y.lo + x.lo * y.hi) end # x^2 @inline function dsqu(x::T) where {T<:vIEEEFloat} - z = vmul(x, x) - Double(z, fma(x, x, -z)) + z = x * x + Double(z, vfmsub(x, x, z)) end @inline function dsqu(x::Double{T}) where {T<:vIEEEFloat} - z = vmul(x.hi, x.hi) - Double(z, fma(x.hi, x.hi, -z) + vmul(x.hi, vadd(x.lo, x.lo))) + z = x.hi * x.hi + Double(z, vfmsub(x.hi, x.hi, z) + (x.hi * (x.lo + x.lo))) end # sqrt(x) @inline function dsqrt(x::Double{T}) where {T<:vIEEEFloat} zhi = _sqrt(x.hi) - Double(zhi, vadd(x.lo, fma(-zhi, zhi, x.hi)) / vadd(zhi, zhi)) + Double(zhi, (x.lo + vfnmadd(zhi, zhi, x.hi)) / (zhi + zhi)) end # x/y @inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}) invy = inv(y.hi) - zhi = vmul(x.hi, invy) - Double(zhi, vmul((fma(-zhi, y.hi, x.hi) + fma(-zhi, y.lo, x.lo)), invy)) + zhi = (x.hi * invy) + Double(zhi, ((vfnmadd(zhi, y.hi, x.hi) + vfnmadd(zhi, y.lo, x.lo)) * invy)) end @inline function ddiv(x::vIEEEFloat, y::vIEEEFloat) ry = inv(y) - r = vmul(x, ry) - Double(r, vmul(vfnmadd(r, y, x), ry)) - # Double(r, vmul(fma(-r, y, x), ry)) + r = (x * ry) + Double(r, (vfnmadd(r, y, x) * ry)) end # 1/x @inline function drec(x::vIEEEFloat) zhi = inv(x) - Double(zhi, vmul(fma(-zhi, x, one(eltype(x))), zhi)) + Double(zhi, (vfnmadd(zhi, x, one(eltype(x))) * zhi)) end @inline function drec(x::Double{<:vIEEEFloat}) zhi = inv(x.hi) - Double(zhi, vmul(vsub(fma(-zhi, x.hi, one(eltype(x))), vmul(zhi, x.lo)), zhi)) + Double(zhi, ((vfnmadd(zhi, x.hi, one(eltype(x))) - (zhi * x.lo)) * zhi)) end else @@ -266,7 +263,7 @@ else @inline function dmul(x::vIEEEFloat, y::vIEEEFloat) hx, lx = splitprec(x) hy, ly = splitprec(y) - z = vmul(x, y) + z = (x * y) Double(z, ((hx * hy - z) + lx * hy + hx * ly) + lx * ly) end diff --git a/src/exp.jl b/src/exp.jl index f1c3b16..5a5ea08 100644 --- a/src/exp.jl +++ b/src/exp.jl @@ -1,3 +1,4 @@ + # magic rounding constant: 1.5*2^52 Adding, then subtracting it from a float rounds it to an Int. MAGIC_ROUND_CONST(::Type{Float64}) = 6.755399441055744e15 MAGIC_ROUND_CONST(::Type{Float32}) = 1.2582912f7 @@ -75,7 +76,6 @@ const J_TABLE= Float64[2.0^(big(j-1)/256) for j in 1:256]; for (func, base) in (:exp2=>Val(2), :exp=>Val(ℯ), :exp10=>Val(10)) Ndef1 = :(reinterpret(UInt64, N_float)) Ndef1 = VectorizationBase.AVX512DQ ? Ndef1 : :($Ndef1 % UInt32) - # twopkpreshift = :(VectorizationBase.vadd(k, 0x0000000000000035)) twopkpreshift = VectorizationBase.AVX512DQ ? :k : :(k % UInt64) FF = VectorizationBase.AVX512DQ ? 0x00000000000000ff : 0x000000ff @eval begin @@ -85,14 +85,14 @@ for (func, base) in (:exp2=>Val(2), :exp=>Val(ℯ), :exp10=>Val(10)) N = $Ndef1 N_float = N_float - MAGIC_ROUND_CONST(Float64) # @show N_float - r = muladd(N_float, LogBo256U($base, Float64), x) - r = muladd(N_float, LogBo256L($base, Float64), r) + r = vfmadd(N_float, LogBo256U($base, Float64), x) + r = vfmadd(N_float, LogBo256L($base, Float64), r) # @show r (N & $FF) js = vload(VectorizationBase.zero_offsets(stridedpointer(J_TABLE)), (N & $FF,)) # @show N js k = N >>> 0x0000000000000008 - small_part = reinterpret(UInt64, muladd(js, expm1b_kernel($base, r), js)) + small_part = reinterpret(UInt64, vfmadd(js, expm1b_kernel($base, r), js)) twopk = $twopkpreshift << 0x0000000000000034 # @show k small_part twopk twopk + small_part r res = reinterpret(Float64, twopk + small_part) @@ -103,12 +103,12 @@ for (func, base) in (:exp2=>Val(2), :exp=>Val(ℯ), :exp10=>Val(10)) end @inline function ($func)(x::FloatType32) - N_float = muladd(x, LogBINV($base, Float32), MAGIC_ROUND_CONST(Float32)) + N_float = vfmadd(x, LogBINV($base, Float32), MAGIC_ROUND_CONST(Float32)) N = reinterpret(UInt32, N_float) - N_float = vsub(N_float, MAGIC_ROUND_CONST(Float32)) + N_float = (N_float - MAGIC_ROUND_CONST(Float32)) - r = muladd(N_float, LogBU($base, Float32), x) - r = muladd(N_float, LogBL($base, Float32), r) + r = vfmadd(N_float, LogBU($base, Float32), x) + r = vfmadd(N_float, LogBL($base, Float32), r) small_part = reinterpret(UInt32, expb_kernel($base, r)) twopk = N << 0x00000017 @@ -181,8 +181,8 @@ inttype(::Type{Float32}) = Int32 T = eltype(x) N_float = round(x*Ln2INV(T)) N = unsafe_trunc(inttype(T), N_float) - r = muladd(N_float, Ln2U(T), x) - r = muladd(N_float, Ln2L(T), r) + r = vfmadd(N_float, Ln2U(T), x) + r = vfmadd(N_float, Ln2L(T), r) hi, lo = expm1_kernel(r) small_part = r*hi small_round = fma(r, lo, fma(r, hi, -small_part)) @@ -193,7 +193,7 @@ inttype(::Type{Float32}) = Int32 # x > MAX_EXPM1(T) && return T(Inf) # x < MIN_EXPM1(T) && return T(-1) # if N == exponent_max(T) - # return muladd(small_part, T(2), T(2)) * T(2)^(exponent_max(T)-1) + # return vfmadd(small_part, T(2), T(2)) * T(2)^(exponent_max(T)-1) # end # end res = fma(twopk, small_round, fma(twopk, small_part, twopk-one(T))) diff --git a/src/priv.jl b/src/priv.jl index d908f27..239ed8b 100644 --- a/src/priv.jl +++ b/src/priv.jl @@ -49,8 +49,8 @@ end @inline function ldexp2k(x::FloatType, e::I) where {I <: IntegerType} eshift = e >> one(I) - vmul( - vmul(x, pow2i(eltype(x), eshift)), + ( + (x * pow2i(eltype(x), eshift)) * pow2i(eltype(x), e - eshift) ) end diff --git a/src/trig.jl b/src/trig.jl index 1c4b73d..c35e96a 100644 --- a/src/trig.jl +++ b/src/trig.jl @@ -31,7 +31,7 @@ end c3 = -0.0001981069071916863322258f0 c2 = 0.00833307858556509017944336f0 c1 = -0.166666597127914428710938f0 - return dadd(c1, vmul(x.hi, (@horner x.hi c2 c3 c4))) + return dadd(c1, (x.hi * (@horner x.hi c2 c3 c4))) end @inline function sin(d::V) where V <: FloatType64 @@ -69,10 +69,10 @@ end q = round(d * T(M_1_PI)) - s = dadd2(d, vmul(q, -PI_A(T))) - s = dadd2(s, vmul(q, -PI_B(T))) - s = dadd2(s, vmul(q, -PI_C(T))) - s = dadd2(s, vmul(q, -PI_D(T))) + s = dadd2(d, (q * -PI_A(T))) + s = dadd2(s, (q * -PI_B(T))) + s = dadd2(s, (q * -PI_C(T))) + s = dadd2(s, (q * -PI_D(T))) t = s s = dsqu(s)