Skip to content

Commit

Permalink
LinearAlgebra: improve type-inference in Symmetric/Hermitian matmul (…
Browse files Browse the repository at this point in the history
…#54303)

(cherry picked from commit c77671a315096b58f11cca911bb9c27ea816337b)
(cherry picked from commit 4d665a982a6957d8065a6b3e51c3f0c0abc2d96a)
  • Loading branch information
jishnub authored and KristofferC committed Dec 2, 2024
1 parent bbe4713 commit 2ee110e
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 46 deletions.
52 changes: 38 additions & 14 deletions src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,32 +516,56 @@ const ⋅ = dot
const × = cross
export , ×

# Separate the char corresponding to the wrapper from that corresponding to the uplo
# In most cases, the former may be constant-propagated, while the latter usually can't be.
# This improves type-inference in wrap for Symmetric/Hermitian matrices
# A WrapperChar is equivalent to `isuppertri ? uppercase(wrapperchar) : lowercase(wrapperchar)`
struct WrapperChar <: AbstractChar
wrapperchar :: Char
isuppertri :: Bool
end
function Base.Char(w::WrapperChar)
T = w.wrapperchar
if T ('N', 'T', 'C') # known cases where isuppertri is true
T
else
_isuppertri(w) ? uppercase(T) : lowercase(T)
end
end
Base.codepoint(w::WrapperChar) = codepoint(Char(w))
WrapperChar(n::UInt32) = WrapperChar(Char(n))
WrapperChar(c::Char) = WrapperChar(c, isuppercase(c))
# We extract the wrapperchar so that the result may be constant-propagated
# This doesn't return a value of the same type on purpose
Base.uppercase(w::WrapperChar) = uppercase(w.wrapperchar)
Base.lowercase(w::WrapperChar) = lowercase(w.wrapperchar)
_isuppertri(w::WrapperChar) = w.isuppertri
_isuppertri(x::AbstractChar) = isuppercase(x) # compatibility with earlier Char-based implementation
_uplosym(x) = _isuppertri(x) ? (:U) : (:L)

wrapper_char(::AbstractArray) = 'N'
wrapper_char(::Adjoint) = 'C'
wrapper_char(::Adjoint{<:Real}) = 'T'
wrapper_char(::Transpose) = 'T'
wrapper_char(A::Hermitian) = A.uplo == 'U' ? 'H' : 'h'
wrapper_char(A::Hermitian{<:Real}) = A.uplo == 'U' ? 'S' : 's'
wrapper_char(A::Symmetric) = A.uplo == 'U' ? 'S' : 's'
wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U')
wrapper_char(A::Hermitian{<:Real}) = WrapperChar('S', A.uplo == 'U')
wrapper_char(A::Symmetric) = WrapperChar('S', A.uplo == 'U')

Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar)
# merge the result of this before return, so that we can type-assert the return such
# that even if the tmerge is inaccurate, inference can still identify that the
# `_generic_matmatmul` signature still matches and doesn't require missing backedges
B = if tA == 'N'
tA_uc = uppercase(tA)
B = if tA_uc == 'N'
A
elseif tA == 'T'
elseif tA_uc == 'T'
transpose(A)
elseif tA == 'C'
elseif tA_uc == 'C'
adjoint(A)
elseif tA == 'H'
Hermitian(A, :U)
elseif tA == 'h'
Hermitian(A, :L)
elseif tA == 'S'
Symmetric(A, :U)
else # tA == 's'
Symmetric(A, :L)
elseif tA_uc == 'H'
Hermitian(A, _uplosym(tA))
elseif tA_uc == 'S'
Symmetric(A, _uplosym(tA))
end
return B::AbstractVecOrMat
end
Expand Down
96 changes: 64 additions & 32 deletions src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,28 +364,32 @@ lmul!(A, B)
# aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasFloat}
if all(in(('N', 'T', 'C')), (tA, tB))
if tA == 'T' && tB == 'N' && A === B
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
if tA_uc == 'T' && tB_uc == 'N' && A === B
return syrk_wrapper!(C, 'T', A, _add)
elseif tA == 'N' && tB == 'T' && A === B
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
return syrk_wrapper!(C, 'N', A, _add)
elseif tA == 'C' && tB == 'N' && A === B
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
return herk_wrapper!(C, 'C', A, _add)
elseif tA == 'N' && tB == 'C' && A === B
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
return herk_wrapper!(C, 'N', A, _add)
else
return gemm_wrapper!(C, tA, tB, A, B, _add)
end
end
alpha, beta = promote(_add.alpha, _add.beta, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
if (tA == 'S' || tA == 's') && tB == 'N'
if tA_uc == 'S' && tB_uc == 'N'
return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C)
elseif (tB == 'S' || tB == 's') && tA == 'N'
elseif tA_uc == 'N' && tB_uc == 'S'
return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C)
elseif (tA == 'H' || tA == 'h') && tB == 'N'
elseif tA_uc == 'H' && tB_uc == 'N'
return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C)
elseif (tB == 'H' || tB == 'h') && tA == 'N'
elseif tA_uc == 'N' && tB_uc == 'H'
return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
end
end
Expand All @@ -395,7 +399,11 @@ end
# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add::MulAddMul=MulAddMul()) where {T<:BlasReal}
if all(in(('N', 'T', 'C')), (tA, tB))
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
gemm_wrapper!(C, tA, tB, A, B, _add)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
Expand Down Expand Up @@ -434,18 +442,19 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
!iszero(stride(x, 1)) && # We only check input's stride here.
if tA in ('N', 'T', 'C')
if tA_uc in ('N', 'T', 'C')
return BLAS.gemv!(tA, alpha, A, x, beta, y)
elseif tA in ('S', 's')
elseif tA_uc == 'S'
return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y)
elseif tA in ('H', 'h')
elseif tA_uc == 'H'
return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y)
end
end
if tA in ('S', 's', 'H', 'h')
if tA_uc in ('S', 'H')
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
Expand All @@ -464,14 +473,15 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
stride(y, 1) == 1 && tA == 'N' && # reinterpret-based optimization is valid only for contiguous `y`
stride(y, 1) == 1 && tA_uc == 'N' && # reinterpret-based optimization is valid only for contiguous `y`
!iszero(stride(x, 1))
BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y))
return y
else
Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA)
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β))
end
end
Expand All @@ -487,15 +497,16 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
mA == 0 && return y
nA == 0 && return _rmul_or_fill!(y, β)
alpha, beta = promote(α, β, zero(T))
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
@views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} &&
stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) &&
!iszero(stride(x, 1)) && tA in ('N', 'T', 'C')
!iszero(stride(x, 1)) && tA_uc in ('N', 'T', 'C')
xfl = reinterpret(reshape, T, x) # Use reshape here.
yfl = reinterpret(reshape, T, y)
BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :])
BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :])
return y
elseif tA in ('S', 's', 'H', 'h')
elseif tA_uc in ('S', 'H')
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
Expand All @@ -504,10 +515,13 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
end
end

function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
# to be concretely inferred
Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasFloat}
nC = checksquare(C)
if tA == 'T'
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if tA_uc == 'T'
(nA, mA) = size(A,1), size(A,2)
tAt = 'N'
else
Expand Down Expand Up @@ -542,10 +556,13 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat
return gemm_wrapper!(C, tA, tAt, A, A, _add)
end

function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it
# to be concretely inferred
Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}},
_add = MulAddMul()) where {T<:BlasReal}
nC = checksquare(C)
if tA == 'C'
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
if tA_uc == 'C'
(nA, mA) = size(A,1), size(A,2)
tAt = 'N'
else
Expand Down Expand Up @@ -581,20 +598,28 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA
return gemm_wrapper!(C, tA, tAt, A, A, _add)
end

function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
# makes the calls concretely inferred
Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{T},
B::StridedVecOrMat{T}) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
C = similar(B, T, mA, nB)
if all(in(('N', 'T', 'C')), (tA, tB))
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
gemm_wrapper!(C, tA, tB, A, B)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end
end

function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
# makes the calls concretely inferred
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
Expand Down Expand Up @@ -634,7 +659,9 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add)
end

function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
# Aggressive constprop helps propagate the values of tA and tB into wrap, which
# makes the calls concretely inferred
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar,
A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
_add = MulAddMul()) where {T<:BlasReal}
mA, nA = lapack_size(tA, A)
Expand Down Expand Up @@ -664,13 +691,15 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs

alpha, beta = promote(_add.alpha, _add.beta, zero(T))

tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char

# Make-sure reinterpret-based optimization is BLAS-compatible.
if (alpha isa Union{Bool,T} &&
beta isa Union{Bool,T} &&
stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 &&
stride(A, 2) >= size(A, 1) &&
stride(B, 2) >= size(B, 1) &&
stride(C, 2) >= size(C, 1) && tA == 'N')
stride(C, 2) >= size(C, 1) && tA_uc == 'N')
BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C))
return C
end
Expand Down Expand Up @@ -703,9 +732,10 @@ parameters must satisfy `length(ir_dest) == length(ir_src)` and
See also [`copy_transpose!`](@ref) and [`copy_adjoint!`](@ref).
"""
function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
if tM == 'N'
tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char
if tM_uc == 'N'
copyto!(B, ir_dest, jr_dest, M, ir_src, jr_src)
elseif tM == 'T'
elseif tM_uc == 'T'
copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src)
else
copy_adjoint!(B, ir_dest, jr_dest, M, jr_src, ir_src)
Expand Down Expand Up @@ -734,11 +764,12 @@ range parameters must satisfy `length(ir_dest) == length(jr_src)` and
See also [`copyto!`](@ref) and [`copy_adjoint!`](@ref).
"""
function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int})
if tM == 'N'
tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char
if tM_uc == 'N'
copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src)
else
copyto!(B, ir_dest, jr_dest, M, jr_src, ir_src)
tM == 'C' && conj!(@view B[ir_dest, jr_dest])
tM_uc == 'C' && conj!(@view B[ir_dest, jr_dest])
end
B
end
Expand All @@ -751,7 +782,8 @@ end

@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul())
Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA)
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
return _generic_matvecmul!(C, ta, Anew, B, _add)
end

Expand Down
24 changes: 24 additions & 0 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,30 @@ mul_wrappers = [
h(A) = LinearAlgebra.wrap(LinearAlgebra._unwrap(A), LinearAlgebra.wrapper_char(A))
@test @inferred(h(transpose(A))) === transpose(A)
@test @inferred(h(adjoint(A))) === transpose(A)

M = rand(2,2)
for S in (Symmetric(M), Hermitian(M))
@test @inferred((A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(S)) === Symmetric(M)
end
M = rand(ComplexF64,2,2)
for S in (Symmetric(M), Hermitian(M))
@test @inferred((A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(S)) === S
end

@testset "WrapperChar" begin
@test LinearAlgebra.WrapperChar('c') == 'c'
@test LinearAlgebra.WrapperChar('C') == 'C'
@testset "constant propagation in uppercase/lowercase" begin
v = @inferred (() -> Val(uppercase(LinearAlgebra.WrapperChar('C'))))()
@test v isa Val{'C'}
v = @inferred (() -> Val(uppercase(LinearAlgebra.WrapperChar('s'))))()
@test v isa Val{'S'}
v = @inferred (() -> Val(lowercase(LinearAlgebra.WrapperChar('C'))))()
@test v isa Val{'c'}
v = @inferred (() -> Val(lowercase(LinearAlgebra.WrapperChar('s'))))()
@test v isa Val{'s'}
end
end
end

@testset "matrices with zero dimensions" begin
Expand Down

0 comments on commit 2ee110e

Please sign in to comment.