Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify mul! dispatch #49806

Merged
merged 2 commits into from
May 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/src/adjtrans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ inplace_adj_or_trans(::Type{<:Transpose}) = transpose!
adj_or_trans_char(::T) where {T<:AbstractArray} = adj_or_trans_char(T)
adj_or_trans_char(::Type{<:AbstractArray}) = 'N'
adj_or_trans_char(::Type{<:Adjoint}) = 'C'
adj_or_trans_char(::Type{<:Adjoint{<:Real}}) = 'T'
adj_or_trans_char(::Type{<:Transpose}) = 'T'

Base.dataids(A::Union{Adjoint, Transpose}) = Base.dataids(A.parent)
Expand Down
117 changes: 40 additions & 77 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,22 @@ end
alpha::Number, beta::Number) =
generic_matvecmul!(y, adj_or_trans_char(A), _parent(A), x, MulAddMul(alpha, beta))
# BLAS cases
@inline mul!(y::StridedVector{T}, A::StridedMaybeAdjOrTransVecOrMat{T}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, adj_or_trans_char(A), _parent(A), x, alpha, beta)
# catch the real adjoint case and rewrap to transpose
@inline mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
mul!(y, transpose(adjA.parent), x, alpha, beta)
# equal eltypes
@inline generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat} =
gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta)
# Real (possibly transposed) matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal} =
gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta)
# Complex matrix times real vector.
# Reinterpret the matrix as a real matrix and do real matvec computation.
@inline mul!(y::StridedVector{Complex{T}}, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, 'N', A, x, alpha, beta)
# Real matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
@inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, A isa StridedArray ? 'N' : 'T', _parent(A), x, alpha, beta)
# works only in cooperation with BLAS when A is untransposed (tA == 'N')
# but that check is included in gemv! anyway
@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal} =
gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta)

# Vector-Matrix multiplication
(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')'
Expand Down Expand Up @@ -341,66 +340,26 @@ julia> lmul!(F.Q, B)
"""
lmul!(A, B)

# generic case
@inline mul!(C::StridedMatrix{T}, A::StridedMaybeAdjOrTransVecOrMat{T}, B::StridedMaybeAdjOrTransVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemm_wrapper!(C, adj_or_trans_char(A), adj_or_trans_char(B), _parent(A), _parent(B), MulAddMul(alpha, beta))

# AtB & ABt (including B === A)
@inline function mul!(C::StridedMatrix{T}, tA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasFloat}
A = tA.parent
if A === B
return syrk_wrapper!(C, 'T', A, MulAddMul(alpha, beta))
else
return gemm_wrapper!(C, 'T', 'N', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, tB::Transpose{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasFloat}
B = tB.parent
if A === B
return syrk_wrapper!(C, 'N', A, MulAddMul(alpha, beta))
else
return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta))
end
end
# real adjoint cases, also needed for disambiguation
@inline mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasReal} =
mul!(C, A, transpose(adjB.parent), alpha, beta)
@inline mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Real, beta::Real) where {T<:BlasReal} =
mul!(C, transpose(adjA.parent), B, alpha, beta)

# AcB & ABc (including B === A)
@inline function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasComplex}
A = adjA.parent
if A === B
return herk_wrapper!(C, 'C', A, MulAddMul(alpha, beta))
@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
if tA == 'T' && tB == 'N' && A === B
return syrk_wrapper!(C, 'T', A, _add)
elseif tA == 'N' && tB == 'T' && A === B
return syrk_wrapper!(C, 'N', A, _add)
elseif tA == 'C' && tB == 'N' && A === B
return herk_wrapper!(C, 'C', A, _add)
elseif tA == 'N' && tB == 'C' && A === B
return herk_wrapper!(C, 'N', A, _add)
else
return gemm_wrapper!(C, 'C', 'N', A, B, MulAddMul(alpha, beta))
end
end
@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasComplex}
B = adjB.parent
if A === B
return herk_wrapper!(C, 'N', A, MulAddMul(alpha, beta))
else
return gemm_wrapper!(C, 'N', 'C', A, B, MulAddMul(alpha, beta))
return gemm_wrapper!(C, tA, tB, A, B, _add)
end
end

# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedMaybeAdjOrTransVecOrMat{Complex{T}}, B::StridedMaybeAdjOrTransVecOrMat{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemm_wrapper!(C, adj_or_trans_char(A), adj_or_trans_char(B), _parent(A), _parent(B), MulAddMul(alpha, beta))
# catch the real adjoint case and interpret it as a transpose
@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}},
alpha::Number, beta::Number) where {T<:BlasReal} =
mul!(C, A, transpose(adjB.parent), alpha, beta)
@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
gemm_wrapper!(C, tA, tB, A, B, _add)
end


# Supporting functions for matrix multiplication
Expand Down Expand Up @@ -438,7 +397,7 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x::
!iszero(stride(x, 1)) # We only check input's stride here.
return BLAS.gemv!(tA, alpha, A, x, beta, y)
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
end

Expand All @@ -459,7 +418,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
return y
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
end

Expand All @@ -482,7 +441,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa
BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :])
return y
else
return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
end
end

Expand Down Expand Up @@ -609,7 +568,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
stride(C, 2) >= size(C, 1))
return BLAS.gemm!(tA, tB, alpha, A, B, beta, C)
end
generic_matmatmul!(C, tA, tB, A, B, _add)
_generic_matmatmul!(C, tA, tB, A, B, _add)
end

function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
Expand Down Expand Up @@ -652,7 +611,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
return C
end
generic_matmatmul!(C, tA, tB, A, B, _add)
_generic_matmatmul!(C, tA, tB, A, B, _add)
end

# blas.jl defines matmul for floats; other integer and mixed precision
Expand Down Expand Up @@ -686,8 +645,12 @@ end
# NOTE: the generic version is also called as fallback for
# strides != 1 cases

function generic_matvecmul!(C::AbstractVector{R}, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul()) where R
generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul()) =
_generic_matvecmul!(C, tA, A, B, _add)

function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
mB = length(B)
mA, nA = lapack_size(tA, A)
Expand Down