From 6655f49908c7a3d7ac16f2a1abcbeb75ed7c465f Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sat, 13 May 2023 16:49:00 +0200 Subject: [PATCH 1/2] Simplify `mul!` dispatch --- stdlib/LinearAlgebra/src/adjtrans.jl | 1 + stdlib/LinearAlgebra/src/matmul.jl | 74 +++++++--------------------- 2 files changed, 18 insertions(+), 57 deletions(-) diff --git a/stdlib/LinearAlgebra/src/adjtrans.jl b/stdlib/LinearAlgebra/src/adjtrans.jl index f6d8b33eb7639..2f5c5508e0ee3 100644 --- a/stdlib/LinearAlgebra/src/adjtrans.jl +++ b/stdlib/LinearAlgebra/src/adjtrans.jl @@ -97,6 +97,7 @@ inplace_adj_or_trans(::Type{<:Transpose}) = transpose! adj_or_trans_char(::T) where {T<:AbstractArray} = adj_or_trans_char(T) adj_or_trans_char(::Type{<:AbstractArray}) = 'N' adj_or_trans_char(::Type{<:Adjoint}) = 'C' +adj_or_trans_char(::Type{<:Adjoint{<:Real}}) = 'T' adj_or_trans_char(::Type{<:Transpose}) = 'T' Base.dataids(A::Union{Adjoint, Transpose}) = Base.dataids(A.parent) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 7a7f615e6f3e3..da259d70f1132 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -341,66 +341,26 @@ julia> lmul!(F.Q, B) """ lmul!(A, B) -# generic case -@inline mul!(C::StridedMatrix{T}, A::StridedMaybeAdjOrTransVecOrMat{T}, B::StridedMaybeAdjOrTransVecOrMat{T}, - alpha::Number, beta::Number) where {T<:BlasFloat} = - gemm_wrapper!(C, adj_or_trans_char(A), adj_or_trans_char(B), _parent(A), _parent(B), MulAddMul(alpha, beta)) - -# AtB & ABt (including B === A) -@inline function mul!(C::StridedMatrix{T}, tA::Transpose{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}, - alpha::Number, beta::Number) where {T<:BlasFloat} - A = tA.parent - if A === B - return syrk_wrapper!(C, 'T', A, MulAddMul(alpha, beta)) - else - return gemm_wrapper!(C, 'T', 'N', A, B, MulAddMul(alpha, beta)) - end -end -@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, tB::Transpose{<:Any,<:StridedVecOrMat{T}}, - alpha::Number, beta::Number) where {T<:BlasFloat} - B = tB.parent - if A === B - return syrk_wrapper!(C, 'N', A, MulAddMul(alpha, beta)) - else - return gemm_wrapper!(C, 'N', 'T', A, B, MulAddMul(alpha, beta)) - end -end -# real adjoint cases, also needed for disambiguation -@inline mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}}, - alpha::Number, beta::Number) where {T<:BlasReal} = - mul!(C, A, transpose(adjB.parent), alpha, beta) -@inline mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}, - alpha::Real, beta::Real) where {T<:BlasReal} = - mul!(C, transpose(adjA.parent), B, alpha, beta) - -# AcB & ABc (including B === A) -@inline function mul!(C::StridedMatrix{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, B::StridedVecOrMat{T}, - alpha::Number, beta::Number) where {T<:BlasComplex} - A = adjA.parent - if A === B - return herk_wrapper!(C, 'C', A, MulAddMul(alpha, beta)) - else - return gemm_wrapper!(C, 'C', 'N', A, B, MulAddMul(alpha, beta)) - end -end -@inline function mul!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}}, - alpha::Number, beta::Number) where {T<:BlasComplex} - B = adjB.parent - if A === B - return herk_wrapper!(C, 'N', A, MulAddMul(alpha, beta)) +@inline function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, + _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} + if tA == 'T' && tB == 'N' && A === B + return syrk_wrapper!(C, 'T', A, _add) + elseif tA == 'N' && tB == 'T' && A === B + return syrk_wrapper!(C, 'N', A, _add) + elseif tA == 'C' && tB == 'N' && A === B + return herk_wrapper!(C, 'C', A, _add) + elseif tA == 'N' && tB == 'C' && A === B + return herk_wrapper!(C, 'N', A, _add) else - return gemm_wrapper!(C, 'N', 'C', A, B, MulAddMul(alpha, beta)) + return gemm_wrapper!(C, tA, tB, A, B, _add) end end # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. -@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedMaybeAdjOrTransVecOrMat{Complex{T}}, B::StridedMaybeAdjOrTransVecOrMat{T}, - alpha::Number, beta::Number) where {T<:BlasReal} = - gemm_wrapper!(C, adj_or_trans_char(A), adj_or_trans_char(B), _parent(A), _parent(B), MulAddMul(alpha, beta)) -# catch the real adjoint case and interpret it as a transpose -@inline mul!(C::StridedMatrix{Complex{T}}, A::StridedVecOrMat{Complex{T}}, adjB::Adjoint{<:Any,<:StridedVecOrMat{T}}, - alpha::Number, beta::Number) where {T<:BlasReal} = - mul!(C, A, transpose(adjB.parent), alpha, beta) +@inline function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + _add::MulAddMul=MulAddMul()) where {T<:BlasReal} + gemm_wrapper!(C, tA, tB, A, B, _add) +end # Supporting functions for matrix multiplication @@ -609,7 +569,7 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar stride(C, 2) >= size(C, 1)) return BLAS.gemm!(tA, tB, alpha, A, B, beta, C) end - generic_matmatmul!(C, tA, tB, A, B, _add) + _generic_matmatmul!(C, tA, tB, A, B, _add) end function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, @@ -652,7 +612,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end - generic_matmatmul!(C, tA, tB, A, B, _add) + _generic_matmatmul!(C, tA, tB, A, B, _add) end # blas.jl defines matmul for floats; other integer and mixed precision From 63111d850bf9f9dcc643c73f305e8faff36882ea Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Sun, 14 May 2023 13:55:11 +0200 Subject: [PATCH 2/2] apply same trick to matvec mul --- stdlib/LinearAlgebra/src/matmul.jl | 43 ++++++++++++++++-------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index da259d70f1132..170aacee6682f 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -70,23 +70,22 @@ end alpha::Number, beta::Number) = generic_matvecmul!(y, adj_or_trans_char(A), _parent(A), x, MulAddMul(alpha, beta)) # BLAS cases -@inline mul!(y::StridedVector{T}, A::StridedMaybeAdjOrTransVecOrMat{T}, x::StridedVector{T}, - alpha::Number, beta::Number) where {T<:BlasFloat} = - gemv!(y, adj_or_trans_char(A), _parent(A), x, alpha, beta) -# catch the real adjoint case and rewrap to transpose -@inline mul!(y::StridedVector{T}, adjA::Adjoint{<:Any,<:StridedVecOrMat{T}}, x::StridedVector{T}, - alpha::Number, beta::Number) where {T<:BlasReal} = - mul!(y, transpose(adjA.parent), x, alpha, beta) +# equal eltypes +@inline generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T}, + _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} = + gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta) +# Real (possibly transposed) matrix times complex vector. +# Multiply the matrix with the real and imaginary parts separately +@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}}, + _add::MulAddMul=MulAddMul()) where {T<:BlasReal} = + gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta) # Complex matrix times real vector. # Reinterpret the matrix as a real matrix and do real matvec computation. -@inline mul!(y::StridedVector{Complex{T}}, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, - alpha::Number, beta::Number) where {T<:BlasReal} = - gemv!(y, 'N', A, x, alpha, beta) -# Real matrix times complex vector. -# Multiply the matrix with the real and imaginary parts separately -@inline mul!(y::StridedVector{Complex{T}}, A::StridedMaybeAdjOrTransMat{T}, x::StridedVector{Complex{T}}, - alpha::Number, beta::Number) where {T<:BlasReal} = - gemv!(y, A isa StridedArray ? 'N' : 'T', _parent(A), x, alpha, beta) +# works only in cooperation with BLAS when A is untransposed (tA == 'N') +# but that check is included in gemv! anyway +@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, + _add::MulAddMul=MulAddMul()) where {T<:BlasReal} = + gemv!(y, tA, _parent(A), x, _add.alpha, _add.beta) # Vector-Matrix multiplication (*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')' @@ -398,7 +397,7 @@ function gemv!(y::StridedVector{T}, tA::AbstractChar, A::StridedVecOrMat{T}, x:: !iszero(stride(x, 1)) # We only check input's stride here. return BLAS.gemv!(tA, alpha, A, x, beta, y) else - return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end end @@ -419,7 +418,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) return y else - return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end end @@ -442,7 +441,7 @@ function gemv!(y::StridedVector{Complex{T}}, tA::AbstractChar, A::StridedVecOrMa BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :]) return y else - return generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end end @@ -646,8 +645,12 @@ end # NOTE: the generic version is also called as fallback for # strides != 1 cases -function generic_matvecmul!(C::AbstractVector{R}, tA, A::AbstractVecOrMat, B::AbstractVector, - _add::MulAddMul = MulAddMul()) where R +generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, + _add::MulAddMul = MulAddMul()) = + _generic_matvecmul!(C, tA, A, B, _add) + +function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, + _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) mB = length(B) mA, nA = lapack_size(tA, A)