# Provenance: additions from a patch to src/device/opencl/math.jl
# (index a7e9555b..02e82a86, hunk @@ -182,3 +182,25 @@).

# TODO: half and native

# High 64 bits of the full signed product `a * b`, computed without widening
# to a 128-bit integer. Each operand is split into a signed upper half and an
# unsigned lower half (32 bits each); the partial products are then combined,
# propagating the carries between the halves.
# NOTE(review): presumably written this way because 128-bit integers are not
# usable in device code -- confirm.
function _mulhi(a::Int64, b::Int64)
    halfbits = 4 * sizeof(a)        # half of the operand width, in bits
    lomask = typemax(UInt32)        # selects the lower half of a word

    a_hi = a >> halfbits            # arithmetic shift keeps the sign
    a_lo = a & lomask
    b_hi = b >> halfbits
    b_lo = b & lomask

    # Upper half of the low-by-low product; dispatches to the generic
    # `_mulhi(::T, ::T)` method for `UInt32`.
    lo_carry = _mulhi(a_lo % UInt32, b_lo % UInt32)

    cross1 = a_hi * b_lo + lo_carry
    cross2 = a_lo * b_hi + (cross1 & lomask)

    return a_hi * b_hi + (cross1 >> halfbits) + (cross2 >> halfbits)
end

@static if isdefined(Base.MultiplicativeInverses, :_mul_high)
    # Base provides `_mul_high`: use it for the generic case ...
    function _mulhi(a::T, b::T) where {T<:Union{Signed, Unsigned}}
        return Base.MultiplicativeInverses._mul_high(a, b)
    end

    # ... and route its `Int64` method to the half-word implementation above
    # when compiling for the device.
    @device_override function Base.MultiplicativeInverses._mul_high(a::Int64, b::Int64)
        return _mulhi(a, b)
    end
else
    # Base has no `_mul_high`: the generic case takes the upper half of the
    # widened product directly.
    function _mulhi(a::T, b::T) where {T<:Union{Signed, Unsigned}}
        wide = widen(a) * b
        return (wide >>> (8 * sizeof(a))) % T
    end

    # Device replacement for division of an `Int64` by a precomputed signed
    # multiplicative inverse, built on `_mulhi` instead of a widened multiply.
    @device_override function Base.div(
        a::Int64, b::Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}
    )
        hi = _mulhi(a, b.multiplier)
        q = hi + (a * b.addmul) % Int64
        unit_divisor = abs(b.divisor) == 1
        # `ifelse` evaluates both candidates and selects one without branching.
        return ifelse(unit_divisor, a * b.divisor, (signbit(q) + (q >> b.shift)) % Int64)
    end
end