Skip to content

Commit

Permalink
Merge pull request #603 from tkoolen/tk/dot-simdloop
Browse files Browse the repository at this point in the history
Use `@simd` in `_vecdot`
  • Loading branch information
tkoolen authored May 1, 2019
2 parents ec26f0a + 90ffc02 commit 407c65f
Showing 1 changed file with 13 additions and 28 deletions.
41 changes: 13 additions & 28 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,38 +214,23 @@ end
@inbounds return similar_type(a, typeof(Signed(a[2]*b[3])-Signed(a[3]*b[2])))(((Signed(a[2]*b[3])-Signed(a[3]*b[2]), Signed(a[3]*b[1])-Signed(a[1]*b[3]), Signed(a[1]*b[2])-Signed(a[2]*b[1]))))
end

@inline dot(a::StaticVector, b::StaticVector) = _vecdot(same_size(a, b), a, b)
@generated function _vecdot(::Size{S}, a::StaticArray, b::StaticArray) where {S}
if prod(S) == 0
return :(zero(promote_op(*, eltype(a), eltype(b))))
end

expr = :(adjoint(a[1]) * b[1])
for j = 2:prod(S)
expr = :($expr + adjoint(a[$j]) * b[$j])
end
# Entry points for the shared `_vecdot` kernel. Both forward the result of
# `same_size(a, b)` together with the two arrays; they differ only in the
# function handed over as the per-element `product`.
# `dot` passes `dot` itself, so each pair of elements is combined with `dot`.
@inline dot(a::StaticVector, b::StaticVector) = _vecdot(same_size(a, b), a, b, dot)
# `bilinear_vecdot` passes `*`, so each pair of elements is combined with a plain product.
@inline bilinear_vecdot(a::StaticArray, b::StaticArray) = _vecdot(same_size(a, b), a, b, *)

return quote
@_inline_meta
@inbounds return $expr
end
end

@inline bilinear_vecdot(a::StaticArray, b::StaticArray) = _bilinear_vecdot(same_size(a, b), a, b)
@generated function _bilinear_vecdot(::Size{S}, a::StaticArray, b::StaticArray) where {S}
@inline function _vecdot(::Size{S}, a::StaticArray, b::StaticArray, product) where {S}
if prod(S) == 0
return :(zero(promote_op(*, eltype(a), eltype(b))))
end

expr = :(a[1] * b[1])
for j = 2:prod(S)
expr = :($expr + a[$j] * b[$j])
za, zb = zero(eltype(a)), zero(eltype(b))
else
# Use an actual element if there is one, to support e.g. Vector{<:Number}
# element types for which runtime size information is required to construct
# a zero element.
za, zb = zero(a[1]), zero(b[1])
end

return quote
@_inline_meta
@inbounds return $expr
ret = product(za, zb) + product(za, zb)
@inbounds @simd for j = 1 : prod(S)
ret += product(a[j], b[j])
end
return ret
end

# Squared norm of a static vector: reduces `abs2` of every element with `+`,
# seeded with a zero of the real counterpart of the element type.
@inline function LinearAlgebra.norm_sqr(v::StaticVector)
    seed = zero(real(eltype(v)))
    return mapreduce(abs2, +, v; init=seed)
end
Expand Down

0 comments on commit 407c65f

Please sign in to comment.