diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index f872b0689afbb..2b9b358e0e227 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -440,11 +440,16 @@ const BandedMatrix = Union{Bidiagonal,Diagonal,Tridiagonal,SymTridiagonal} # or const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal} const TriSym = Union{Tridiagonal,SymTridiagonal} const BiTri = Union{Bidiagonal,Tridiagonal} -@inline _mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline _mul!(C::AbstractVector, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractVector, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractMatrix, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BandedMatrix, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline _mul!(C::AbstractMatrix, A::BandedMatrix, B::BandedMatrix, alpha::Number, beta::Number) = + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) lmul!(A::Bidiagonal, B::AbstractVecOrMat) = @inline _mul!(B, A, B, MulAddMul()) rmul!(B::AbstractMatrix, A::Bidiagonal) = @inline _mul!(B, B, A, MulAddMul()) diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index 35014cd520630..c98f3e7190bef 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -49,6 +49,74 @@ end end end +""" + @stable_muladdmul + +Replaces a function call, that has a `MulAddMul(alpha, beta)` constructor as an +argument, with a branch over possible values of `isone(alpha)` and `iszero(beta)` +and constructs `MulAddMul{isone(alpha), iszero(beta)}` explicitly in each branch. +For example, 'f(x, y, MulAddMul(alpha, beta))` is transformed into +``` +if isone(alpha) + if iszero(beta) + f(x, y, MulAddMul{true, true, typeof(alpha), typeof(beta)}(alpha, beta)) + else + f(x, y, MulAddMul{true, false, typeof(alpha), typeof(beta)}(alpha, beta)) + end +else + if iszero(beta) + f(x, y, MulAddMul{false, true, typeof(alpha), typeof(beta)}(alpha, beta)) + else + f(x, y, MulAddMul{false, false, typeof(alpha), typeof(beta)}(alpha, beta)) + end +end +``` +This avoids the type instability of the `MulAddMul(alpha, beta)` constructor, +which causes runtime dispatch in case alpha and zero are not constants. +""" +macro stable_muladdmul(expr) + expr.head == :call || throw(ArgumentError("Can only handle function calls.")) + for (i, e) in enumerate(expr.args) + e isa Expr || continue + if e.head == :call && e.args[1] == :MulAddMul && length(e.args) == 3 + e.args[2] isa Symbol || continue + e.args[3] isa Symbol || continue + local asym = e.args[2] + local bsym = e.args[3] + + local e_sub11 = copy(expr) + e_sub11.args[i] = :(MulAddMul{true, true, typeof($asym), typeof($bsym)}($asym, $bsym)) + + local e_sub10 = copy(expr) + e_sub10.args[i] = :(MulAddMul{true, false, typeof($asym), typeof($bsym)}($asym, $bsym)) + + local e_sub01 = copy(expr) + e_sub01.args[i] = :(MulAddMul{false, true, typeof($asym), typeof($bsym)}($asym, $bsym)) + + local e_sub00 = copy(expr) + e_sub00.args[i] = :(MulAddMul{false, false, typeof($asym), typeof($bsym)}($asym, $bsym)) + + local e_out = quote + if isone($asym) + if iszero($bsym) + $e_sub11 + else + $e_sub10 + end + else + if iszero($bsym) + $e_sub01 + else + $e_sub00 + end + end + end + return esc(e_out) + end + end + throw(ArgumentError("No valid MulAddMul expression found.")) +end + MulAddMul() = MulAddMul{true,true,Bool,Bool}(true, false) @inline (::MulAddMul{true})(x) = x diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 9c74addd6b69c..ad8d8e91af299 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -69,26 +69,34 @@ end @inline mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector, alpha::Number, beta::Number) = _mul!(y, A, x, alpha, beta) -@inline _mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector, +_mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector, alpha::Number, beta::Number) = - generic_matvecmul!(y, wrapper_char(A), _unwrap(A), x, MulAddMul(alpha, beta)) - + generic_matvecmul!(y, wrapper_char(A), _unwrap(A), x, alpha, beta) # BLAS cases # equal eltypes -@inline generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T}, - _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} = +generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T}, + alpha::Number, beta::Number) where {T<:BlasFloat} = + gemv!(y, tA, A, x, alpha, beta) +generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T}, + _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = gemv!(y, tA, A, x, _add.alpha, _add.beta) # Real (possibly transposed) matrix times complex vector. # Multiply the matrix with the real and imaginary parts separately -@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}}, - _add::MulAddMul=MulAddMul()) where {T<:BlasReal} = +generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}}, + alpha::Number, beta::Number) where {T<:BlasReal} = + gemv!(y, tA, A, x, alpha, beta) +generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}}, + _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = gemv!(y, tA, A, x, _add.alpha, _add.beta) # Complex matrix times real vector. # Reinterpret the matrix as a real matrix and do real matvec computation. # works only in cooperation with BLAS when A is untransposed (tA == 'N') # but that check is included in gemv! anyway -@inline generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, - _add::MulAddMul=MulAddMul()) where {T<:BlasReal} = +generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, + alpha::Number, beta::Number) where {T<:BlasReal} = + gemv!(y, tA, A, x, alpha, beta) +generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, + _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = gemv!(y, tA, A, x, _add.alpha, _add.beta) # Vector-Matrix multiplication @@ -291,7 +299,7 @@ true wrapper_char(B), _unwrap(A), _unwrap(B), - MulAddMul(α, β) + α, β ) """ @@ -361,27 +369,40 @@ julia> lmul!(F.Q, B) lmul!(A, B) # THE one big BLAS dispatch -# aggressive constant propagation makes mul!(C, A, B) invoke gemm_wrapper! directly Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - _add::MulAddMul=MulAddMul()) where {T<:BlasFloat} + α::Number, β::Number) where {T<:BlasFloat} + mA, nA = lapack_size(tA, A) + mB, nB = lapack_size(tB, B) + if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α) + if size(C) != (mA, nB) + throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) + end + return _rmul_or_fill!(C, β) + end + if size(C) == size(A) == size(B) == (2,2) + return @stable_muladdmul matmul2x2!(C, tA, tB, A, B, MulAddMul(α, β)) + end + if size(C) == size(A) == size(B) == (3,3) + return @stable_muladdmul matmul3x3!(C, tA, tB, A, B, MulAddMul(α, β)) + end # We convert the chars to uppercase to potentially unwrap a WrapperChar, # and extract the char corresponding to the wrapper type tA_uc, tB_uc = uppercase(tA), uppercase(tB) # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) if tA_uc == 'T' && tB_uc == 'N' && A === B - return syrk_wrapper!(C, 'T', A, _add) + return syrk_wrapper!(C, 'T', A, α, β) elseif tA_uc == 'N' && tB_uc == 'T' && A === B - return syrk_wrapper!(C, 'N', A, _add) + return syrk_wrapper!(C, 'N', A, α, β) elseif tA_uc == 'C' && tB_uc == 'N' && A === B - return herk_wrapper!(C, 'C', A, _add) + return herk_wrapper!(C, 'C', A, α, β) elseif tA_uc == 'N' && tB_uc == 'C' && A === B - return herk_wrapper!(C, 'N', A, _add) + return herk_wrapper!(C, 'N', A, α, β) else - return gemm_wrapper!(C, tA, tB, A, B, _add) + return gemm_wrapper!(C, tA, tB, A, B, α, β) end end - alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + alpha, beta = promote(α, β, zero(T)) if alpha isa Union{Bool,T} && beta isa Union{Bool,T} if tA_uc == 'S' && tB_uc == 'N' return BLAS.symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C) @@ -393,23 +414,30 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, return BLAS.hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C) end end - return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) + return _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) end +# legacy method +Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, + _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = + generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) # Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency. Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - _add::MulAddMul=MulAddMul()) where {T<:BlasReal} + α::Number, β::Number) where {T<:BlasReal} # We convert the chars to uppercase to potentially unwrap a WrapperChar, # and extract the char corresponding to the wrapper type tA_uc, tB_uc = uppercase(tA), uppercase(tB) # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) - gemm_wrapper!(C, tA, tB, A, B, _add) + gemm_wrapper!(C, tA, tB, A, B, α, β) else - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) end end - +# legacy method +Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, + _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = + generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) # Supporting functions for matrix multiplication @@ -457,9 +485,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar if tA_uc in ('S', 'H') # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically - return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) + return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) else - return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end end @@ -482,7 +510,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs return y else Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) - return _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β)) + return @stable_muladdmul _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β)) end end @@ -509,16 +537,16 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs elseif tA_uc in ('S', 'H') # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically - return _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) + return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) else - return _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) end end # the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it # to be concretely inferred Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, - _add = MulAddMul()) where {T<:BlasFloat} + alpha::Number, beta::Number) where {T<:BlasFloat} nC = checksquare(C) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if tA_uc == 'T' @@ -531,20 +559,11 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst if nC != mA throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) end - if mA == 0 || nA == 0 || iszero(_add.alpha) - return _rmul_or_fill!(C, _add.beta) - end - if mA == 2 && nA == 2 - return matmul2x2!(C, tA, tAt, A, A, _add) - end - if mA == 3 && nA == 3 - return matmul3x3!(C, tA, tAt, A, A, _add) - end # BLAS.syrk! only updates symmetric C # alternatively, make non-zero β a show-stopper for BLAS.syrk! - if iszero(_add.beta) || issymmetric(C) - alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + if iszero(beta) || issymmetric(C) + α, β = promote(alpha, beta, zero(T)) if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(C, 1) == 1 && @@ -553,13 +572,16 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst return copytri!(BLAS.syrk!('U', tA, alpha, A, beta, C), 'U') end end - return gemm_wrapper!(C, tA, tAt, A, A, _add) + return gemm_wrapper!(C, tA, tAt, A, A, alpha, beta) end +# legacy method +syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = + syrk_wrapper!(C, tA, A, _add.alpha, _add.beta) # the aggressive constprop pushes tA and tB into gemm_wrapper!, which is needed for wrap calls within it # to be concretely inferred Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, - _add = MulAddMul()) where {T<:BlasReal} + α::Number, β::Number) where {T<:BlasReal} nC = checksquare(C) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char if tA_uc == 'C' @@ -572,21 +594,12 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St if nC != mA throw(DimensionMismatch(lazy"output matrix has size: $(nC), but should have size $(mA)")) end - if mA == 0 || nA == 0 || iszero(_add.alpha) - return _rmul_or_fill!(C, _add.beta) - end - if mA == 2 && nA == 2 - return matmul2x2!(C, tA, tAt, A, A, _add) - end - if mA == 3 && nA == 3 - return matmul3x3!(C, tA, tAt, A, A, _add) - end # Result array does not need to be initialized as long as beta==0 # C = Matrix{T}(undef, mA, mA) - if iszero(_add.beta) || issymmetric(C) - alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + if iszero(β) || issymmetric(C) + alpha, beta = promote(α, β, zero(T)) if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(C, 1) == 1 && @@ -595,8 +608,12 @@ Base.@constprop :aggressive function herk_wrapper!(C::Union{StridedMatrix{T}, St return copytri!(BLAS.herk!('U', tA, alpha, A, beta, C), 'U', true) end end - return gemm_wrapper!(C, tA, tAt, A, A, _add) + return gemm_wrapper!(C, tA, tAt, A, A, α, β) end +# legacy method +herk_wrapper!(C::Union{StridedMatrix{T}, StridedMatrix{Complex{T}}}, tA::AbstractChar, A::Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}}, + _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = + herk_wrapper!(C, tA, A, _add.alpha, _add.beta) # Aggressive constprop helps propagate the values of tA and tB into wrap, which # makes the calls concretely inferred @@ -611,9 +628,9 @@ Base.@constprop :aggressive function gemm_wrapper(tA::AbstractChar, tB::Abstract tA_uc, tB_uc = uppercase(tA), uppercase(tB) # the map in all ensures constprop by acting on tA and tB individually, instead of looping over them. if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc))) - gemm_wrapper!(C, tA, tB, A, B) + gemm_wrapper!(C, tA, tB, A, B, true, false) else - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul()) end end @@ -621,7 +638,7 @@ end # makes the calls concretely inferred Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, - _add = MulAddMul()) where {T<:BlasFloat} + α::Number, β::Number) where {T<:BlasFloat} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) @@ -633,21 +650,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab throw(ArgumentError("output matrix must not be aliased with input matrix")) end - if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha) - if size(C) != (mA, nB) - throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) - end - return _rmul_or_fill!(C, _add.beta) - end - - if mA == 2 && nA == 2 && nB == 2 - return matmul2x2!(C, tA, tB, A, B, _add) - end - if mA == 3 && nA == 3 && nB == 3 - return matmul3x3!(C, tA, tB, A, B, _add) - end - - alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + alpha, beta = promote(α, β, zero(T)) if (alpha isa Union{Bool,T} && beta isa Union{Bool,T} && stride(A, 1) == stride(B, 1) == stride(C, 1) == 1 && @@ -656,14 +659,18 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab stride(C, 2) >= size(C, 1)) return BLAS.gemm!(tA, tB, alpha, A, B, beta, C) end - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) end +# legacy method +gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar, + A::StridedVecOrMat{T}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = + gemm_wrapper!(C, tA, tB, A, B, _add.alpha, _add.beta) # Aggressive constprop helps propagate the values of tA and tB into wrap, which # makes the calls concretely inferred Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, - _add = MulAddMul()) where {T<:BlasReal} + α::Number, β::Number) where {T<:BlasReal} mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) @@ -675,21 +682,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T} throw(ArgumentError("output matrix must not be aliased with input matrix")) end - if mA == 0 || nA == 0 || nB == 0 || iszero(_add.alpha) - if size(C) != (mA, nB) - throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)")) - end - return _rmul_or_fill!(C, _add.beta) - end - - if mA == 2 && nA == 2 && nB == 2 - return matmul2x2!(C, tA, tB, A, B, _add) - end - if mA == 3 && nA == 3 && nB == 3 - return matmul3x3!(C, tA, tB, A, B, _add) - end - - alpha, beta = promote(_add.alpha, _add.beta, zero(T)) + alpha, beta = promote(α, β, zero(T)) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char @@ -703,8 +696,12 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{Complex{T} BLAS.gemm!(tA, tB, alpha, reinterpret(T, A), B, beta, reinterpret(T, C)) return C end - _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) end +# legacy method +gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::AbstractChar, + A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T}, _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = + gemm_wrapper!(C, tA, tB, A, B, _add.alpha, _add.beta) # blas.jl defines matmul for floats; other integer and mixed precision # cases are handled here @@ -780,6 +777,8 @@ end # NOTE: the generic version is also called as fallback for # strides != 1 cases +generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, alpha::Number, beta::Number) = + @stable_muladdmul generic_matvecmul!(C, tA, A, B, MulAddMul(alpha, beta)) @inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char @@ -861,12 +860,15 @@ function generic_matmatmul(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S}) mA, nA = lapack_size(tA, A) mB, nB = lapack_size(tB, B) C = similar(B, promote_op(matprod, T, S), mA, nB) - generic_matmatmul!(C, tA, tB, A, B) + generic_matmatmul!(C, tA, tB, A, B, true, false) end # aggressive const prop makes mixed eltype mul!(C, A, B) invoke _generic_matmatmul! directly -Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul) = +# legacy method +Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, _add::MulAddMul = MulAddMul()) = _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add) +Base.@constprop :aggressive generic_matmatmul!(C::AbstractVecOrMat, tA, tB, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) = + _generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β)) @noinline function _generic_matmatmul!(C::AbstractVecOrMat{R}, A::AbstractVecOrMat{T}, B::AbstractVecOrMat{S}, _add::MulAddMul) where {T,S,R} @@ -935,6 +937,9 @@ end function matmul2x2!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) + if C === A || B === C + throw(ArgumentError("output matrix must not be aliased with input matrix")) + end if !(size(A) == size(B) == size(C) == (2,2)) throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) end @@ -1002,6 +1007,9 @@ end function matmul3x3!(C::AbstractMatrix, tA, tB, A::AbstractMatrix, B::AbstractMatrix, _add::MulAddMul = MulAddMul()) require_one_based_indexing(C, A, B) + if C === A || B === C + throw(ArgumentError("output matrix must not be aliased with input matrix")) + end if !(size(A) == size(B) == size(C) == (3,3)) throw(DimensionMismatch(lazy"A has size $(size(A)), B has size $(size(B)), C has size $(size(C))")) end diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 092dacdd11846..3c62517f20659 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -109,9 +109,9 @@ end # disambiguation between triangular and banded matrices, banded ones "dominate" _mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number, beta::Number) = - _mul!(C, A, B, MulAddMul(alpha, beta)) + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) _mul!(C::AbstractMatrix, A::BandedMatrix, B::AbstractTriangular, alpha::Number, beta::Number) = - _mul!(C, A, B, MulAddMul(alpha, beta)) + @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) function *(H::UpperHessenberg, B::Bidiagonal) T = promote_op(matprod, eltype(H), eltype(B)) diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index 922bd9a6bd91a..d7ca9f9854775 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -545,9 +545,9 @@ function copyto!(A::T, B::T) where {T<:Union{LowerTriangular,UnitLowerTriangular end @inline _rscale_add!(A::AbstractTriangular, B::AbstractTriangular, C::Number, alpha::Number, beta::Number) = - _triscale!(A, B, C, MulAddMul(alpha, beta)) + @stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta)) @inline _lscale_add!(A::AbstractTriangular, B::Number, C::AbstractTriangular, alpha::Number, beta::Number) = - _triscale!(A, B, C, MulAddMul(alpha, beta)) + @stable_muladdmul _triscale!(A, B, C, MulAddMul(alpha, beta)) function checksize1(A, B) szA, szB = size(A), size(B) @@ -891,7 +891,7 @@ for TC in (:AbstractVector, :AbstractMatrix) if isone(alpha) && iszero(beta) return _trimul!(C, A, B) else - return generic_matvecmul!(C, 'N', A, B, MulAddMul(alpha, beta)) + return @stable_muladdmul generic_matvecmul!(C, 'N', A, B, MulAddMul(alpha, beta)) end end end @@ -903,7 +903,7 @@ for (TA, TB) in ((:AbstractTriangular, :AbstractMatrix), if isone(alpha) && iszero(beta) return _trimul!(C, A, B) else - return generic_matmatmul!(C, 'N', 'N', A, B, MulAddMul(alpha, beta)) + return generic_matmatmul!(C, 'N', 'N', A, B, alpha, beta) end end end diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index db61fbe0ab45a..d64e0ee67f3ba 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -1094,7 +1094,8 @@ end end end -@testset "Issue #46865: mul!() with non-const alpha, beta" begin +#46865 +@testset "mul!() with non-const alpha, beta" begin f!(C,A,B,alphas,betas) = mul!(C, A, B, alphas[1], betas[1]) alphas = [1.0] betas = [0.5] @@ -1103,7 +1104,7 @@ end B = copy(A) C = copy(A) f!(C, A, B, alphas, betas) - @test_broken (@allocated f!(C, A, B, alphas, betas)) == 0 + @test (@allocated f!(C, A, B, alphas, betas)) == 0 end end