From ee54dd528ea12ff0007c3cd85fff3483449ff45e Mon Sep 17 00:00:00 2001 From: Daniel Karrasch Date: Wed, 30 Nov 2022 14:13:41 +0100 Subject: [PATCH] reduce ambiguity in *diagonal multiplication code (#47683) --- stdlib/LinearAlgebra/src/bidiag.jl | 106 ++++++++++--------------- stdlib/LinearAlgebra/src/diagonal.jl | 6 +- stdlib/LinearAlgebra/src/hessenberg.jl | 4 +- stdlib/LinearAlgebra/src/tridiag.jl | 54 +------------ 4 files changed, 52 insertions(+), 118 deletions(-) diff --git a/stdlib/LinearAlgebra/src/bidiag.jl b/stdlib/LinearAlgebra/src/bidiag.jl index 218fd67a1b9d2..90f5c03f7fcfb 100644 --- a/stdlib/LinearAlgebra/src/bidiag.jl +++ b/stdlib/LinearAlgebra/src/bidiag.jl @@ -206,7 +206,7 @@ AbstractMatrix{T}(A::Bidiagonal) where {T} = convert(Bidiagonal{T}, A) convert(::Type{T}, m::AbstractMatrix) where {T<:Bidiagonal} = m isa T ? m : T(m)::T similar(B::Bidiagonal, ::Type{T}) where {T} = Bidiagonal(similar(B.dv, T), similar(B.ev, T), B.uplo) -similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = zeros(T, dims...) +similar(B::Bidiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(B.dv, T, dims) tr(B::Bidiagonal) = sum(B.dv) @@ -407,38 +407,32 @@ end const BiTriSym = Union{Bidiagonal,Tridiagonal,SymTridiagonal} const BiTri = Union{Bidiagonal,Tridiagonal} -@inline mul!(C::AbstractMatrix, A::SymTridiagonal, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTri, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -# for A::SymTridiagonal see tridiagonal.jl -@inline mul!(C::AbstractVector, A::BiTri, B::AbstractVector, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTri, B::AbstractMatrix, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTri, B::Diagonal, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::SymTridiagonal, B::Diagonal, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTri, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::BiTri, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = A_mul_B_td!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractVector, A::BiTriSym, B::AbstractVector, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::AbstractMatrix, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Transpose{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Adjoint{<:Any,<:AbstractVecOrMat}, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) +@inline mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, alpha::Number, beta::Number) = _mul!(C, A, B, MulAddMul(alpha, beta)) function check_A_mul_B!_sizes(C, A, B) - require_one_based_indexing(C) - require_one_based_indexing(A) - require_one_based_indexing(B) - nA, mA = size(A) - nB, mB = size(B) - nC, mC = size(C) - if nA != nC - throw(DimensionMismatch("sizes size(A)=$(size(A)) and size(C) = $(size(C)) must match at first entry.")) - elseif mA != nB - throw(DimensionMismatch("second entry of size(A)=$(size(A)) and first entry of size(B) = $(size(B)) must match.")) - elseif mB != mC - throw(DimensionMismatch("sizes size(B)=$(size(B)) and size(C) = $(size(C)) must match at first second entry.")) + mA, nA = size(A) + mB, nB = size(B) + mC, nC = size(C) + if mA != mC + throw(DimensionMismatch("first dimension of A, $mA, and first dimension of output C, $mC, must match")) + elseif nA != mB + throw(DimensionMismatch("second dimension of A, $nA, and first dimension of B, $mB, must match")) + elseif nB != nC + throw(DimensionMismatch("second dimension of output C, $nC, and second dimension of B, $nB, must match")) end end # function to get the internally stored vectors for Bidiagonal and [Sym]Tridiagonal -# to avoid allocations in A_mul_B_td! below (#24324, #24578) +# to avoid allocations in _mul! below (#24324, #24578) _diag(A::Tridiagonal, k) = k == -1 ? A.dl : k == 0 ? A.d : A.du _diag(A::SymTridiagonal, k) = k == 0 ? A.dv : A.ev function _diag(A::Bidiagonal, k) @@ -451,8 +445,7 @@ function _diag(A::Bidiagonal, k) end end -function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, - _add::MulAddMul = MulAddMul()) +function _mul!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, _add::MulAddMul = MulAddMul()) check_A_mul_B!_sizes(C, A, B) n = size(A,1) n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) @@ -509,10 +502,11 @@ function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::BiTriSym, C end -function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, - _add::MulAddMul = MulAddMul()) +function _mul!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, _add::MulAddMul = MulAddMul()) + require_one_based_indexing(C) check_A_mul_B!_sizes(C, A, B) n = size(A,1) + iszero(n) && return C n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) _rmul_or_fill!(C, _add.beta) # see the same use above iszero(_add.alpha) && return C @@ -544,10 +538,8 @@ function A_mul_B_td!(C::AbstractMatrix, A::BiTriSym, B::Diagonal, C end -function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, - _add::MulAddMul = MulAddMul()) - require_one_based_indexing(C) - require_one_based_indexing(B) +function _mul!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, _add::MulAddMul = MulAddMul()) + require_one_based_indexing(C, B) nA = size(A,1) nB = size(B,2) if !(size(C,1) == size(B,1) == nA) @@ -556,6 +548,7 @@ function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, if size(C,2) != nB throw(DimensionMismatch("A has second dimension $nA, B has $(size(B,2)), C has $(size(C,2)) but all must match")) end + iszero(nA) && return C iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) nA <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) l = _diag(A, -1) @@ -575,8 +568,8 @@ function A_mul_B_td!(C::AbstractVecOrMat, A::BiTriSym, B::AbstractVecOrMat, C end -function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, - _add::MulAddMul = MulAddMul()) +function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, _add::MulAddMul = MulAddMul()) + require_one_based_indexing(C, A) check_A_mul_B!_sizes(C, A, B) iszero(_add.alpha) && return _rmul_or_fill!(C, _add.beta) n = size(A,1) @@ -610,8 +603,8 @@ function A_mul_B_td!(C::AbstractMatrix, A::AbstractMatrix, B::BiTriSym, C end -function A_mul_B_td!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, - _add::MulAddMul = MulAddMul()) +function _mul!(C::AbstractMatrix, A::Diagonal, B::BiTriSym, _add::MulAddMul = MulAddMul()) + require_one_based_indexing(C) check_A_mul_B!_sizes(C, A, B) n = size(A,1) n <= 3 && return mul!(C, Array(A), Array(B), _add.alpha, _add.beta) @@ -648,51 +641,38 @@ end function *(A::UpperOrUnitUpperTriangular, B::Bidiagonal) TS = promote_op(matprod, eltype(A), eltype(B)) - C = A_mul_B_td!(zeros(TS, size(A)), A, B) + C = mul!(similar(A, TS, size(A)), A, B) return B.uplo == 'U' ? UpperTriangular(C) : C end function *(A::LowerOrUnitLowerTriangular, B::Bidiagonal) TS = promote_op(matprod, eltype(A), eltype(B)) - C = A_mul_B_td!(zeros(TS, size(A)), A, B) + C = mul!(similar(A, TS, size(A)), A, B) return B.uplo == 'L' ? LowerTriangular(C) : C end function *(A::Bidiagonal, B::UpperOrUnitUpperTriangular) TS = promote_op(matprod, eltype(A), eltype(B)) - C = A_mul_B_td!(zeros(TS, size(A)), A, B) + C = mul!(similar(B, TS, size(B)), A, B) return A.uplo == 'U' ? UpperTriangular(C) : C end function *(A::Bidiagonal, B::LowerOrUnitLowerTriangular) TS = promote_op(matprod, eltype(A), eltype(B)) - C = A_mul_B_td!(zeros(TS, size(A)), A, B) + C = mul!(similar(B, TS, size(B)), A, B) return A.uplo == 'L' ? LowerTriangular(C) : C end -function *(A::BiTri, B::Diagonal) - TS = promote_op(matprod, eltype(A), eltype(B)) - A_mul_B_td!(similar(A, TS), A, B) -end - -function *(A::Diagonal, B::BiTri) - TS = promote_op(matprod, eltype(A), eltype(B)) - A_mul_B_td!(similar(B, TS), A, B) -end - function *(A::Diagonal, B::SymTridiagonal) - TS = promote_op(matprod, eltype(A), eltype(B)) - A_mul_B_td!(Tridiagonal(zeros(TS, size(A, 1)-1), zeros(TS, size(A, 1)), zeros(TS, size(A, 1)-1)), A, B) + TS = promote_op(*, eltype(A), eltype(B)) + out = Tridiagonal(similar(A, TS, size(A, 1)-1), similar(A, TS, size(A, 1)), similar(A, TS, size(A, 1)-1)) + mul!(out, A, B) end function *(A::SymTridiagonal, B::Diagonal) - TS = promote_op(matprod, eltype(A), eltype(B)) - A_mul_B_td!(Tridiagonal(zeros(TS, size(A, 1)-1), zeros(TS, size(A, 1)), zeros(TS, size(A, 1)-1)), A, B) -end - -function *(A::BiTriSym, B::BiTriSym) - TS = promote_op(matprod, eltype(A), eltype(B)) - mul!(similar(A, TS, size(A)), A, B) + TS = promote_op(*, eltype(A), eltype(B)) + out = Tridiagonal(similar(A, TS, size(A, 1)-1), similar(A, TS, size(A, 1)), similar(A, TS, size(A, 1)-1)) + mul!(out, A, B) end function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector) @@ -924,7 +904,7 @@ function inv(B::Bidiagonal{T}) where T end # Eigensystems -eigvals(M::Bidiagonal) = M.dv +eigvals(M::Bidiagonal) = copy(M.dv) function eigvecs(M::Bidiagonal{T}) where T n = length(M.dv) Q = Matrix{T}(undef, n,n) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 1cb4240722ce2..03c5a2bbdeba4 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -121,7 +121,7 @@ Construct an uninitialized `Diagonal{T}` of length `n`. See `undef`. Diagonal{T}(::UndefInitializer, n::Integer) where T = Diagonal(Vector{T}(undef, n)) similar(D::Diagonal, ::Type{T}) where {T} = Diagonal(similar(D.diag, T)) -similar(::Diagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = zeros(T, dims...) +similar(D::Diagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(D.diag, T, dims) copyto!(D1::Diagonal, D2::Diagonal) = (copyto!(D1.diag, D2.diag); D1) @@ -270,8 +270,12 @@ function (*)(D::Diagonal, V::AbstractVector) end (*)(A::AbstractMatrix, D::Diagonal) = + mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), A, D) +(*)(A::HermOrSym, D::Diagonal) = mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D) (*)(D::Diagonal, A::AbstractMatrix) = + mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag))), D, A) +(*)(D::Diagonal, A::HermOrSym) = mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A) rmul!(A::AbstractMatrix, D::Diagonal) = @inline mul!(A, A, D) diff --git a/stdlib/LinearAlgebra/src/hessenberg.jl b/stdlib/LinearAlgebra/src/hessenberg.jl index eb7c3642ae7b4..c7c4c3a50a239 100644 --- a/stdlib/LinearAlgebra/src/hessenberg.jl +++ b/stdlib/LinearAlgebra/src/hessenberg.jl @@ -164,12 +164,12 @@ end function *(H::UpperHessenberg, B::Bidiagonal) TS = promote_op(matprod, eltype(H), eltype(B)) - A = A_mul_B_td!(zeros(TS, size(H)), H, B) + A = mul!(similar(H, TS, size(H)), H, B) return B.uplo == 'U' ? UpperHessenberg(A) : A end function *(B::Bidiagonal, H::UpperHessenberg) TS = promote_op(matprod, eltype(B), eltype(H)) - A = A_mul_B_td!(zeros(TS, size(H)), B, H) + A = mul!(similar(H, TS, size(H)), B, H) return B.uplo == 'U' ? UpperHessenberg(A) : A end diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index e43e9e699e3a9..8821dc064a960 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -149,7 +149,7 @@ function size(A::SymTridiagonal, d::Integer) end similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T), similar(S.ev, T)) -similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = zeros(T, dims...) +similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(S.dv, T, dims) copyto!(dest::SymTridiagonal, src::SymTridiagonal) = (copyto!(dest.dv, src.dv); copyto!(dest.ev, _evview(src)); dest) @@ -215,56 +215,6 @@ end \(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev) ==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (_evview(A)==_evview(B)) -@inline mul!(A::AbstractVector, B::SymTridiagonal, C::AbstractVector, - alpha::Number, beta::Number) = - _mul!(A, B, C, MulAddMul(alpha, beta)) -@inline mul!(A::AbstractMatrix, B::SymTridiagonal, C::AbstractVecOrMat, - alpha::Number, beta::Number) = - _mul!(A, B, C, MulAddMul(alpha, beta)) -# disambiguation -@inline mul!(C::AbstractMatrix, A::SymTridiagonal, B::Transpose{<:Any,<:AbstractVecOrMat}, - alpha::Number, beta::Number) = - _mul!(C, A, B, MulAddMul(alpha, beta)) -@inline mul!(C::AbstractMatrix, A::SymTridiagonal, B::Adjoint{<:Any,<:AbstractVecOrMat}, - alpha::Number, beta::Number) = - _mul!(C, A, B, MulAddMul(alpha, beta)) - -@inline function _mul!(C::AbstractVecOrMat, S::SymTridiagonal, B::AbstractVecOrMat, - _add::MulAddMul) - m, n = size(B, 1), size(B, 2) - if !(m == size(S, 1) == size(C, 1)) - throw(DimensionMismatch("A has first dimension $(size(S,1)), B has $(size(B,1)), C has $(size(C,1)) but all must match")) - end - if n != size(C, 2) - throw(DimensionMismatch("second dimension of B, $n, doesn't match second dimension of C, $(size(C,2))")) - end - - if m == 0 - return C - elseif iszero(_add.alpha) - return _rmul_or_fill!(C, _add.beta) - end - - α = S.dv - β = S.ev - @inbounds begin - for j = 1:n - x₊ = B[1, j] - x₀ = zero(x₊) - # If m == 1 then β[1] is out of bounds - β₀ = m > 1 ? zero(β[1]) : zero(eltype(β)) - for i = 1:m - 1 - x₋, x₀, x₊ = x₀, x₊, B[i + 1, j] - β₋, β₀ = β₀, β[i] - _modify!(_add, β₋*x₋ + α[i]*x₀ + β₀*x₊, C, (i, j)) - end - _modify!(_add, β₀*x₀ + α[m]*x₊, C, (m, j)) - end - end - - return C -end - function dot(x::AbstractVector, S::SymTridiagonal, y::AbstractVector) require_one_based_indexing(x, y) nx, ny = length(x), length(y) @@ -605,7 +555,7 @@ Matrix(M::Tridiagonal{T}) where {T} = Matrix{promote_type(T, typeof(zero(T)))}(M Array(M::Tridiagonal) = Matrix(M) similar(M::Tridiagonal, ::Type{T}) where {T} = Tridiagonal(similar(M.dl, T), similar(M.d, T), similar(M.du, T)) -similar(M::Tridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = zeros(T, dims...) +similar(M::Tridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = similar(M.d, T, dims) # Operations on Tridiagonal matrices copyto!(dest::Tridiagonal, src::Tridiagonal) = (copyto!(dest.dl, src.dl); copyto!(dest.d, src.d); copyto!(dest.du, src.du); dest)