Skip to content

Commit

Permalink
Split generic_matmul for strided matrices into two halves (#54552)
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored May 25, 2024
1 parent 2a5d553 commit b54dce2
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 28 deletions.
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,10 @@ wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U')
wrapper_char(A::Hermitian{<:Real}) = WrapperChar('S', A.uplo == 'U')
wrapper_char(A::Symmetric) = WrapperChar('S', A.uplo == 'U')

wrapper_char_NTC(A::AbstractArray) = uppercase(wrapper_char(A)) == 'N'
wrapper_char_NTC(A::Union{StridedArray, Adjoint, Transpose}) = true
wrapper_char_NTC(A::Union{Symmetric, Hermitian}) = false

Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar)
# merge the result of this before return, so that we can type-assert the return such
# that even if the tmerge is inaccurate, inference can still identify that the
Expand Down
78 changes: 50 additions & 28 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,15 +293,24 @@ true
@inline mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = _mul!(C, A, B, α, β)
# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
@inline _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
generic_matmatmul!(
generic_matmatmul_wrapper!(
C,
wrapper_char(A),
wrapper_char(B),
_unwrap(A),
_unwrap(B),
α, β
α, β,
Val(wrapper_char_NTC(A) & wrapper_char_NTC(B))
)

# this indirection allows is to specialize on the types of the wrappers of A and B to some extent,
# even though the wrappers are stripped off in mul!
# By default, we ignore the wrapper info and forward the arguments to generic_matmatmul!
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C, tA, tB, A, B, α, β, @nospecialize(val))
generic_matmatmul!(C, tA, tB, A, B, α, β)
end


"""
rmul!(A, B)
Expand Down Expand Up @@ -368,9 +377,9 @@ julia> lmul!(F.Q, B)
"""
lmul!(A, B)

# THE one big BLAS dispatch
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number, β::Number) where {T<:BlasFloat}
# THE one big BLAS dispatch. This is split into two methods to improve latency
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{true}) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
Expand All @@ -389,19 +398,37 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
if tA_uc == 'T' && tB_uc == 'N' && A === B
return syrk_wrapper!(C, 'T', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
return syrk_wrapper!(C, 'N', A, α, β)
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
return herk_wrapper!(C, 'C', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
return herk_wrapper!(C, 'N', A, α, β)
else
return gemm_wrapper!(C, tA, tB, A, B, α, β)
if tA_uc == 'T' && tB_uc == 'N' && A === B
return syrk_wrapper!(C, 'T', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
return syrk_wrapper!(C, 'N', A, α, β)
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
return herk_wrapper!(C, 'C', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
return herk_wrapper!(C, 'N', A, α, β)
else
return gemm_wrapper!(C, tA, tB, A, B, α, β)
end
end
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{false}) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
if size(C) != (mA, nB)
throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
end
return _rmul_or_fill!(C, β)
end
if size(C) == size(A) == size(B) == (2,2)
return matmul2x2!(C, tA, tB, A, B, α, β)
end
if size(C) == size(A) == size(B) == (3,3)
return matmul3x3!(C, tA, tB, A, B, α, β)
end
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
alpha, beta = promote(α, β, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
if tA_uc == 'S' && tB_uc == 'N'
Expand All @@ -421,18 +448,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
α::Number, β::Number) where {T<:BlasReal}
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
gemm_wrapper!(C, tA, tB, A, B, α, β)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
end
function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{true}) where {T<:BlasReal}
gemm_wrapper!(C, tA, tB, A, B, α, β)
end
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{false}) where {T<:BlasReal}
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
end
# legacy method
Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
Expand Down

0 comments on commit b54dce2

Please sign in to comment.