Skip to content

Commit

Permalink
Updates for VectorizationBase 0.18
Browse files Browse the repository at this point in the history
  • Loading branch information
chriselrod committed Jan 31, 2021
1 parent a974ff5 commit f1b3e57
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 114 deletions.
2 changes: 1 addition & 1 deletion src/SLEEFPirates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Base.Math: uinttype, exponent_bias, exponent_mask, significand_bits, IEEEF
using Libdl, VectorizationBase

using VectorizationBase: vzero, AbstractSIMD, _Vec, fma_fast, data, VecUnroll, NativeTypes, FloatingTypes,
vfmadd, vfnmadd, vfmsub, vfnmsub
vfmadd, vfnmadd, vfmsub, vfnmsub, True, False

import IfElse: ifelse

Expand Down
202 changes: 96 additions & 106 deletions src/double.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,125 +193,115 @@ end
end

# two-prod-fma
@inline function dmul(x::vIEEEFloat, y::vIEEEFloat)
if fma_fast()
z = (x * y)
Double(z, vfmsub(x, y, z))
else
hx, lx = splitprec(x)
hy, ly = splitprec(y)
z = (x * y)
Double(z, ((hx * hy - z) + lx * hy + hx * ly) + lx * ly)
end
end

@inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat)
if fma_fast()
z = (x.hi * y)
Double(z, vfmsub(x.hi, y, z) + x.lo * y)
else
hx, lx = splitprec(x.hi)
hy, ly = splitprec(y)
z = x.hi * y
Double(z, (hx * hy - z) + lx * hy + hx * ly + lx * ly + x.lo * y)
end
@inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::True)
z = (x * y)
Double(z, vfmsub(x, y, z))
end
@inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::False)
hx, lx = splitprec(x)
hy, ly = splitprec(y)
z = (x * y)
Double(z, ((hx * hy - z) + lx * hy + hx * ly) + lx * ly)
end
@inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::True)
z = (x.hi * y)
Double(z, vfmsub(x.hi, y, z) + x.lo * y)
end
@inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::False)
hx, lx = splitprec(x.hi)
hy, ly = splitprec(y)
z = x.hi * y
Double(z, (hx * hy - z) + lx * hy + hx * ly + lx * ly + x.lo * y)
end
@inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True)
z = x.hi * y.hi
Double(z, vfmsub(x.hi, y.hi, z) + x.hi * y.lo + x.lo * y.hi)
end
@inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False)
hx, lx = splitprec(x.hi)
hy, ly = splitprec(y.hi)
z = x.hi * y.hi
Double(z, (((hx * hy - z) + lx * hy + hx * ly) + lx * ly) + x.hi * y.lo + x.lo * y.hi)
end
@inline dmul(x::vIEEEFloat, y::Double{<:vIEEEFloat}) = dmul(y, x)

@inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat})
if fma_fast()
z = x.hi * y.hi
Double(z, vfmsub(x.hi, y.hi, z) + x.hi * y.lo + x.lo * y.hi)
else
hx, lx = splitprec(x.hi)
hy, ly = splitprec(y.hi)
z = x.hi * y.hi
Double(z, (((hx * hy - z) + lx * hy + hx * ly) + lx * ly) + x.hi * y.lo + x.lo * y.hi)
end
end
@inline dmul(x, y) = dmul(x, y, fma_fast())
# x^2
@inline function dsqu(x::T) where {T<:vIEEEFloat}
if fma_fast()
z = x * x
Double(z, vfmsub(x, x, z))
else
hx, lx = splitprec(x)
z = x * x
Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx)
end
@inline function dsqu(x::T, ::True) where {T<:vIEEEFloat}
z = x * x
Double(z, vfmsub(x, x, z))
end

@inline function dsqu(x::Double{T}) where {T<:vIEEEFloat}
if fma_fast()
z = x.hi * x.hi
Double(z, vfmsub(x.hi, x.hi, z) + (x.hi * (x.lo + x.lo)))
else
hx, lx = splitprec(x.hi)
z = x.hi * x.hi
Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx + x.hi * (x.lo + x.lo))
end
@inline function dsqu(x::T, ::False) where {T<:vIEEEFloat}
hx, lx = splitprec(x)
z = x * x
Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx)
end

# sqrt(x)
@inline function dsqrt(x::Double{T}) where {T<:vIEEEFloat}
if fma_fast()
zhi = _sqrt(x.hi)
Double(zhi, (x.lo + vfnmadd(zhi, zhi, x.hi)) / (zhi + zhi))
else
c = _sqrt(x.hi)
u = dsqu(c)
Double(c, (x.hi - u.hi - u.lo + x.lo) / (c + c))
end
@inline function dsqu(x::Double{T}, ::True) where {T<:vIEEEFloat}
z = x.hi * x.hi
Double(z, vfmsub(x.hi, x.hi, z) + (x.hi * (x.lo + x.lo)))
end

# x/y
@inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat})
if fma_fast()
invy = inv(y.hi)
zhi = (x.hi * invy)
Double(zhi, ((vfnmadd(zhi, y.hi, x.hi) + vfnmadd(zhi, y.lo, x.lo)) * invy))
else
invy = 1 / y.hi
c = x.hi * invy
u = dmul(c, y.hi)
Double(c, ((((x.hi - u.hi) - u.lo) + x.lo) - c * y.lo) * invy)
end
@inline function dsqu(x::Double{T}, ::False) where {T<:vIEEEFloat}
hx, lx = splitprec(x.hi)
z = x.hi * x.hi
Double(z, (hx * hx - z) + lx * (hx + hx) + lx * lx + x.hi * (x.lo + x.lo))
end
@inline dsqu(x) = dsqu(x, fma_fast())

@inline function ddiv(x::vIEEEFloat, y::vIEEEFloat)
if fma_fast()
ry = inv(y)
r = (x * ry)
Double(r, (vfnmadd(r, y, x) * ry))
else
ry = 1 / y
r = x * ry
hx, lx = splitprec(r)
hy, ly = splitprec(y)
Double(r, (((-hx * hy + r * y) - lx * hy - hx * ly) - lx * ly) * ry)
end
# sqrt(x)
@inline function dsqrt(x::Double{T}, ::True) where {T<:vIEEEFloat}
zhi = _sqrt(x.hi)
Double(zhi, (x.lo + vfnmadd(zhi, zhi, x.hi)) / (zhi + zhi))
end
@inline function dsqrt(x::Double{T}, ::False) where {T<:vIEEEFloat}
c = _sqrt(x.hi)
u = dsqu(c)
Double(c, (x.hi - u.hi - u.lo + x.lo) / (c + c))
end
@inline dsqrt(x) = dsqrt(x, fma_fast())

# x/y
@inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True)
invy = inv(y.hi)
zhi = (x.hi * invy)
Double(zhi, ((vfnmadd(zhi, y.hi, x.hi) + vfnmadd(zhi, y.lo, x.lo)) * invy))
end
@inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False)
invy = 1 / y.hi
c = x.hi * invy
u = dmul(c, y.hi)
Double(c, ((((x.hi - u.hi) - u.lo) + x.lo) - c * y.lo) * invy)
end
@inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::True)
ry = inv(y)
r = (x * ry)
Double(r, (vfnmadd(r, y, x) * ry))
end
@inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::False)
ry = 1 / y
r = x * ry
hx, lx = splitprec(r)
hy, ly = splitprec(y)
Double(r, (((-hx * hy + r * y) - lx * hy - hx * ly) - lx * ly) * ry)
end
@inline ddiv(x, y) = ddiv(x, y, fma_fast())
# 1/x
@inline function drec(x::vIEEEFloat)
if fma_fast()
zhi = inv(x)
Double(zhi, (vfnmadd(zhi, x, one(eltype(x))) * zhi))
else
c = 1 / x
u = dmul(c, x)
Double(c, (one(T) - u.hi - u.lo) * c)
end
@inline function drec(x::vIEEEFloat, ::True)
zhi = inv(x)
Double(zhi, (vfnmadd(zhi, x, one(eltype(x))) * zhi))
end
@inline function drec(x::vIEEEFloat, ::False)
c = 1 / x
u = dmul(c, x)
Double(c, (one(T) - u.hi - u.lo) * c)
end

@inline function drec(x::Double{<:vIEEEFloat})
if fma_fast()
@inline function drec(x::Double{<:vIEEEFloat}, ::True)
zhi = inv(x.hi)
Double(zhi, ((vfnmadd(zhi, x.hi, one(eltype(x))) - (zhi * x.lo)) * zhi))

else
c = 1 / x.hi
u = dmul(c, x.hi)
Double(c, (one(T) - u.hi - u.lo - c * x.lo) * c)
end
end
@inline function drec(x::Double{<:vIEEEFloat}, ::False)
c = 1 / x.hi
u = dmul(c, x.hi)
Double(c, (one(T) - u.hi - u.lo - c * x.lo) * c)
end
@inline drec(x) = drec(x, fma_fast())

14 changes: 9 additions & 5 deletions src/exp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ const J_TABLE= Float64[2.0^(big(j-1)/256) for j in 1:256];
@inline target_trunc(v) = target_trunc(v, VectorizationBase.has_feature(Val(:x86_64_avx512dq)))

for (func, base) in (:exp2=>Val(2), :exp=>Val(ℯ), :exp10=>Val(10))
Ndef1 = :(targetspecific_truncate(reinterpret(UInt64, N_float)))
Ndef1 = :(target_trunc(reinterpret(UInt64, N_float)))
func_fast = Symbol(func, :_fast)
@eval begin
@inline function $func_fast(x::FloatType64)
Expand Down Expand Up @@ -174,10 +174,14 @@ end
return exthorner(x, (1.0f0, 0.5f0, hi_order))
end

@inline widest_supported_integer(::VectorizationBase.True) = Int64
@inline widest_supported_integer(::VectorizationBase.False) = Int32
@inline inttype(::Type{Float64}) = widest_supported_integer(VectorizationBase.has_feature(Val(:x86_64_avx512dq)))
@inline inttype(::Type{Float32}) = Int32
if (Sys.ARCH === :x86_64) || (Sys.ARCH === :i686)
@inline widest_supported_integer(::VectorizationBase.True) = Int64
@inline widest_supported_integer(::VectorizationBase.False) = Int32
@inline inttype(::Type{Float64}) = widest_supported_integer(VectorizationBase.has_feature(Val(:x86_64_avx512dq)))
@inline inttype(::Type{Float32}) = Int32
else
@inline inttype(_) = Int
end

@inline function expm1_fast(x::FloatType)
T = eltype(x)
Expand Down
4 changes: 2 additions & 2 deletions test/testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,9 @@ function test_acc(T, fun_table, xx, tol; debug = false, tol_debug = 5)
# Vector test is mostly to make sure that they do not error
# Results should either be the same as scalar
# Or they're from another library (e.g., GLIBC), and may differ slighlty
test_vector(xfun, fun, VectorizationBase.pick_vector_width_val(T), first(xx), last(xx), tol)
test_vector(xfun, fun, Val(2), first(xx), last(xx), tol)
W = VectorizationBase.pick_vector_width(T)
test_vector(xfun, fun, W, first(xx), last(xx), tol)
test_vector(xfun, fun, Val(2), first(xx), last(xx), tol)
if W 4
test_vector(xfun, fun, Val(4), first(xx), last(xx), tol)
end
Expand Down

0 comments on commit f1b3e57

Please sign in to comment.