Skip to content

Commit

Permalink
Avoid ReshapedArray using Int128 in kernel (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
tgymnich authored Jul 18, 2024
1 parent b06f0fe commit 76eacff
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/device/opencl/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,25 @@ end


# TODO: half and native

"""
    _mulhi(a::Int64, b::Int64)

Return the upper 64 bits of the signed product `a * b` using only 64-bit
arithmetic: each operand is split into a signed high half and a non-negative
low half of 32 bits, and the partial products are recombined with their carries.
"""
function _mulhi(a::Int64, b::Int64)
    half = 4 * sizeof(a)          # number of bits in one half of an Int64
    lomask = typemax(UInt32)      # selects the low half of a word
    # Arithmetic shift keeps the sign in the high half; the low half is non-negative.
    ahi, alo = a >> half, a & lomask
    bhi, blo = b >> half, b & lomask
    # Upper 32 bits of the low*low partial product (UInt32 method of `_mulhi`).
    carry = _mulhi(alo % UInt32, blo % UInt32)
    cross1 = ahi * blo + carry
    cross2 = alo * bhi + (cross1 & lomask)
    return ahi * bhi + (cross1 >> half) + (cross2 >> half)
end
# Choose the multiply-high plumbing depending on whether this Julia version's
# `Base.MultiplicativeInverses` defines `_mul_high`.
@static if isdefined(Base.MultiplicativeInverses, :_mul_high)
    # Generic method: forward to Base's helper.
    function _mulhi(a::T, b::T) where {T<:Union{Signed,Unsigned}}
        return Base.MultiplicativeInverses._mul_high(a, b)
    end
    # In device code, route Base's Int64 multiply-high through `_mulhi`.
    @device_override Base.MultiplicativeInverses._mul_high(a::Int64, b::Int64) = _mulhi(a, b)
else
    # Generic method: widen, multiply, and keep the upper half of the product.
    function _mulhi(a::T, b::T) where {T<:Union{Signed,Unsigned}}
        nbits = 8 * sizeof(a)
        return ((widen(a) * b) >>> nbits) % T
    end
    # Device-side division of an Int64 by a precomputed signed multiplicative
    # inverse, built on `_mulhi` instead of a widened multiplication.
    @device_override function Base.div(a::Int64, b::Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64})
        hi = _mulhi(a, b.multiplier)
        hi += (a * b.addmul) % Int64
        # Both candidates are computed up front; `ifelse` then selects one of them.
        unit_case = a * b.divisor
        general_case = (signbit(hi) + (hi >> b.shift)) % Int64
        return ifelse(abs(b.divisor) == 1, unit_case, general_case)
    end
end

0 comments on commit 76eacff

Please sign in to comment.