diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index 6ed272ab42f02..b8412fc361d3f 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -516,32 +516,56 @@ const ⋅ = dot const × = cross export ⋅, × +# Separate the char corresponding to the wrapper from that corresponding to the uplo +# In most cases, the former may be constant-propagated, while the latter usually can't be. +# This improves type-inference in wrap for Symmetric/Hermitian matrices +# A WrapperChar is equivalent to `isuppertri ? uppercase(wrapperchar) : lowercase(wrapperchar)` +struct WrapperChar <: AbstractChar + wrapperchar :: Char + isuppertri :: Bool +end +function Base.Char(w::WrapperChar) + T = w.wrapperchar + if T ∈ ('N', 'T', 'C') # known cases where isuppertri is true + T + else + _isuppertri(w) ? uppercase(T) : lowercase(T) + end +end +Base.codepoint(w::WrapperChar) = codepoint(Char(w)) +WrapperChar(n::UInt32) = WrapperChar(Char(n)) +WrapperChar(c::Char) = WrapperChar(c, isuppercase(c)) +# We extract the wrapperchar so that the result may be constant-propagated +# This doesn't return a value of the same type on purpose +Base.uppercase(w::WrapperChar) = uppercase(w.wrapperchar) +Base.lowercase(w::WrapperChar) = lowercase(w.wrapperchar) +_isuppertri(w::WrapperChar) = w.isuppertri +_isuppertri(x::AbstractChar) = isuppercase(x) # compatibility with earlier Char-based implementation +_uplosym(x) = _isuppertri(x) ? (:U) : (:L) + wrapper_char(::AbstractArray) = 'N' wrapper_char(::Adjoint) = 'C' wrapper_char(::Adjoint{<:Real}) = 'T' wrapper_char(::Transpose) = 'T' -wrapper_char(A::Hermitian) = A.uplo == 'U' ? 'H' : 'h' -wrapper_char(A::Hermitian{<:Real}) = A.uplo == 'U' ? 'S' : 's' -wrapper_char(A::Symmetric) = A.uplo == 'U' ? 'S' : 's' +wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U') +wrapper_char(A::Hermitian{<:Real}) = WrapperChar('S', A.uplo == 'U') +wrapper_char(A::Symmetric) = WrapperChar('S', A.uplo == 'U') Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar) # merge the result of this before return, so that we can type-assert the return such # that even if the tmerge is inaccurate, inference can still identify that the # `_generic_matmatmul` signature still matches and doesn't require missing backedges - B = if tA == 'N' + tA_uc = uppercase(tA) + B = if tA_uc == 'N' A - elseif tA == 'T' + elseif tA_uc == 'T' transpose(A) - elseif tA == 'C' + elseif tA_uc == 'C' adjoint(A) - elseif tA == 'H' - Hermitian(A, :U) - elseif tA == 'h' - Hermitian(A, :L) - elseif tA == 'S' - Symmetric(A, :U) - else # tA == 's' - Symmetric(A, :L) + elseif tA_uc == 'H' + Hermitian(A, _uplosym(tA)) + elseif tA_uc == 'S' + Symmetric(A, _uplosym(tA)) end return B::AbstractVecOrMat end diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 9ed8bd1b677aa..9c74addd6b69c 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -364,14 +364,18 @@ lmul!(A, B) # aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} - if all(in(('N', 'T', 'C')), (tA, tB)) - if tA == 'T' && tB == 'N' && A === B + # We convert the chars to uppercase to potentially unwrap a WrapperChar, + # and extract the char corresponding to the wrapper type + tA_uc, tB_uc = uppercase(tA), uppercase(tB) + # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. + if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) + if tA_uc == 'T' && tB_uc == 'N' && A === B return syrk_wrapper!(C, 'T', A, _add) - elseif tA == 'N' && tB == 'T' && A === B + elseif tA_uc == 'N' && tB_uc == 'T' && A === B return syrk_wrapper!(C, 'N', A, _add) - elseif tA == 'C' && tB == 'N' && A === B + elseif tA_uc == 'C' && tB_uc == 'N' && A === B return herk_wrapper!(C, 'C', A, _add) - elseif tA == 'N' && tB == 'C' && A === B + elseif tA_uc == 'N' && tB_uc == 'C' && A === B return herk_wrapper!(C, 'N', A, _add) else return gemm_wrapper!(C, tA, tB, A, B, _add) @@ -379,13 +383,13 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, end alpha, beta = promote(_add.alpha, _add.beta, zero(T)) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} - if (tA == 'S' || tA == 's') && tB == 'N' + if tA_uc == 'S' && tB_uc == 'N' return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C) - elseif (tB == 'S' || tB == 's') && tA == 'N' + elseif tA_uc == 'N' && tB_uc == 'S' return BLAS.symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C) - elseif (tA == 'H' || tA == 'h') && tB == 'N' + elseif tA_uc == 'H' && tB_uc == 'N' return BLAS.hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C) - elseif (tB == 'H' || tB == 'h') && tA == 'N' + elseif tA_uc == 'N' && tB_uc == 'H' return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C) end end @@ -395,7 +399,11 @@ end # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul=MulAddMul()) where {T<:BlasReal} - if all(in(('N', 'T', 'C')), (tA, tB)) + # We convert the chars to uppercase to potentially unwrap a WrapperChar, + # and extract the char corresponding to the wrapper type + tA_uc, tB_uc = uppercase(tA), uppercase(tB) + # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. + if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) gemm_wrapper!(C, tA, tB, A, B, _add) else _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) @@ -434,18 +442,19 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && !iszero(stride(x, 1)) && # We only check input's stride here. - if tA in ('N', 'T', 'C') + if tA_uc in ('N', 'T', 'C') return BLAS.gemv!(tA, alpha, A, x, beta, y) - elseif tA in ('S', 's') + elseif tA_uc == 'S' return BLAS.symv!(tA == 'S' ? 'U' : 'L', alpha, A, x, beta, y) - elseif tA in ('H', 'h') + elseif tA_uc == 'H' return BLAS.hemv!(tA == 'H' ? 'U' : 'L', alpha, A, x, beta, y) end end - if tA in ('S', 's', 'H', 'h') + if tA_uc in ('S', 'H') # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) @@ -464,14 +473,15 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && - stride(y, 1) == 1 && tA == 'N' && # reinterpret-based optimization is valid only for contiguous `y` + stride(y, 1) == 1 && tA_uc == 'N' && # reinterpret-based optimization is valid only for contiguous `y` !iszero(stride(x, 1)) BLAS.gemv!(tA, alpha, reinterpret(T, A), x, beta, reinterpret(T, y)) return y else - Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) + Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β)) end end @@ -487,15 +497,16 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs mA == 0 && return y nA == 0 && return _rmul_or_fill!(y, β) alpha, beta = promote(α, β, zero(T)) + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char @views if alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == 1 && abs(stride(A, 2)) >= size(A, 1) && - !iszero(stride(x, 1)) && tA in ('N', 'T', 'C') + !iszero(stride(x, 1)) && tA_uc in ('N', 'T', 'C') xfl = reinterpret(reshape, T, x) # Use reshape here. yfl = reinterpret(reshape, T, y) BLAS.gemv!(tA, alpha, A, xfl[1, :], beta, yfl[1, :]) BLAS.gemv!(tA, alpha, A, xfl[2, :], beta, yfl[2, :]) return y - elseif tA in ('S', 's', 'H', 'h') + elseif tA_uc in ('S', 'H') # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) @@ -504,10 +515,13 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs end end -function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, +# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it +# to be concretely inferred +Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasFloat} nC = checksquare(C) - if tA == 'T' + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char + if tA_uc == 'T' (nA, mA) = size(A,1), size(A,2) tAt = 'N' else @@ -542,10 +556,13 @@ function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat return gemm_wrapper!(C, tA, tAt, A, A, _add) end -function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, +# the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it +# to be concretely inferred +Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, _add = MulAddMul()) where {T<:BlasReal} nC = checksquare(C) - if tA == 'C' + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char + if tA_uc == 'C' (nA, mA) = size(A,1), size(A,2) tAt = 'N' else @@ -581,20 +598,28 @@ function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA return gemm_wrapper!(C, tA, tAt, A, A, _add) end -function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, +# Aggressive constprop helps propagate the values of tA and tB into wrap, which +# makes the calls concretely inferred +Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) C = similar(B, T, mA, nB) - if all(in(('N', 'T', 'C')), (tA, tB)) + # We convert the chars to uppercase to potentially unwrap a WrapperChar, + # and extract the char corresponding to the wrapper type + tA_uc, tB_uc = uppercase(tA), uppercase(tB) + # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. + if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) gemm_wrapper!(C, tA, tB, A, B) else _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end end -function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, +# Aggressive constprop helps propagate the values of tA and tB into wrap, which +# makes the calls concretely inferred +Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) @@ -634,7 +659,9 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) end -function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, +# Aggressive constprop helps propagate the values of tA and tB into wrap, which +# makes the calls concretely inferred +Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add = MulAddMul()) where {T<:BlasReal} mA, nA = lapack_size(tA, A) @@ -664,13 +691,15 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char + # Make-sure reinterpret-based optimization is BLAS-compatible. if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && stride(A, 2) >= size(A, 1) && stride(B, 2) >= size(B, 1) && - stride(C, 2) >= size(C, 1) && tA == 'N') + stride(C, 2) >= size(C, 1) && tA_uc == 'N') BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end @@ -703,9 +732,10 @@ parameters must satisfy `length(ir_dest) == length(ir_src)` and See also [`copy_transpose!`](@ref) and [`copy_adjoint!`](@ref). """ function copyto!(B::AbstractVecOrMat, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) - if tM == 'N' + tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char + if tM_uc == 'N' copyto!(B, ir_dest, jr_dest, M, ir_src, jr_src) - elseif tM == 'T' + elseif tM_uc == 'T' copy_transpose!(B, ir_dest, jr_dest, M, jr_src, ir_src) else copy_adjoint!(B, ir_dest, jr_dest, M, jr_src, ir_src) @@ -734,11 +764,12 @@ range parameters must satisfy `length(ir_dest) == length(jr_src)` and See also [`copyto!`](@ref) and [`copy_adjoint!`](@ref). """ function copy_transpose!(B::AbstractMatrix, ir_dest::AbstractUnitRange{Int}, jr_dest::AbstractUnitRange{Int}, tM::AbstractChar, M::AbstractVecOrMat, ir_src::AbstractUnitRange{Int}, jr_src::AbstractUnitRange{Int}) - if tM == 'N' + tM_uc = uppercase(tM) # potentially convert a WrapperChar to a Char + if tM_uc == 'N' copy_transpose!(B, ir_dest, jr_dest, M, ir_src, jr_src) else copyto!(B, ir_dest, jr_dest, M, jr_src, ir_src) - tM == 'C' && conj!(@view B[ir_dest, jr_dest]) + tM_uc == 'C' && conj!(@view B[ir_dest, jr_dest]) end B end @@ -751,7 +782,8 @@ end @inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) - Anew, ta = tA in ('S', 's', 'H', 'h') ? (wrap(A, tA), 'N') : (A, tA) + tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char + Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) return _generic_matvecmul!(C, ta, Anew, B, _add) end diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index c760f1adeffdd..db61fbe0ab45a 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -30,6 +30,30 @@ mul_wrappers = [ h(A) = LinearAlgebra.wrap(LinearAlgebra._unwrap(A), LinearAlgebra.wrapper_char(A)) @test @inferred(h(transpose(A))) === transpose(A) @test @inferred(h(adjoint(A))) === transpose(A) + + M = rand(2,2) + for S in (Symmetric(M), Hermitian(M)) + @test @inferred((A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(S)) === Symmetric(M) + end + M = rand(ComplexF64,2,2) + for S in (Symmetric(M), Hermitian(M)) + @test @inferred((A -> LinearAlgebra.wrap(parent(A), LinearAlgebra.wrapper_char(A)))(S)) === S + end + + @testset "WrapperChar" begin + @test LinearAlgebra.WrapperChar('c') == 'c' + @test LinearAlgebra.WrapperChar('C') == 'C' + @testset "constant propagation in uppercase/lowercase" begin + v = @inferred (() -> Val(uppercase(LinearAlgebra.WrapperChar('C'))))() + @test v isa Val{'C'} + v = @inferred (() -> Val(uppercase(LinearAlgebra.WrapperChar('s'))))() + @test v isa Val{'S'} + v = @inferred (() -> Val(lowercase(LinearAlgebra.WrapperChar('C'))))() + @test v isa Val{'c'} + v = @inferred (() -> Val(lowercase(LinearAlgebra.WrapperChar('s'))))() + @test v isa Val{'s'} + end + end end @testset "matrices with zero dimensions" begin