From 1cfda3f9b1a88c8f6069b2cec03fbc957f3ccd3f Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 1 Oct 2024 18:10:42 +0530 Subject: [PATCH] Strong zero in Diagonal triple multiplication (#55927) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Currently, triple multiplication with a `LinearAlgebra.BandedMatrix` sandwiched between two `Diagonal`s isn't associative, as this is implemented using broadcasting, which doesn't assume a strong zero, whereas the two-term matrix multiplication does. ```julia julia> D = Diagonal(StepRangeLen(NaN, 0, 3)); julia> B = Bidiagonal(1:3, 1:2, :U); julia> D * B * D 3×3 Matrix{Float64}: NaN NaN NaN NaN NaN NaN NaN NaN NaN julia> (D * B) * D 3×3 Bidiagonal{Float64, Vector{Float64}}: NaN NaN ⋅ ⋅ NaN NaN ⋅ ⋅ NaN julia> D * (B * D) 3×3 Bidiagonal{Float64, Vector{Float64}}: NaN NaN ⋅ ⋅ NaN NaN ⋅ ⋅ NaN ``` This PR ensures that the 3-term multiplication is evaluated as a sequence of two-term multiplications, which fixes this issue. This also improves performance, as only the bands need to be evaluated now. ```julia julia> D = Diagonal(1:1000); B = Bidiagonal(1:1000, 1:999, :U); julia> @btime $D * $B * $D; 656.364 μs (11 allocations: 7.63 MiB) # v"1.12.0-DEV.1262" 2.483 μs (12 allocations: 31.50 KiB) # This PR ``` --- stdlib/LinearAlgebra/src/special.jl | 2 ++ stdlib/LinearAlgebra/test/diagonal.jl | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 5a7c98cfdf32c..32a5476842933 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -112,6 +112,8 @@ for op in (:+, :-) end end +(*)(Da::Diagonal, A::BandedMatrix, Db::Diagonal) = _tri_matmul(Da, A, Db) + # disambiguation between triangular and banded matrices, banded ones "dominate" _mul!(C::AbstractMatrix, A::AbstractTriangular, B::BandedMatrix, alpha::Number, beta::Number) = @stable_muladdmul _mul!(C, A, B, MulAddMul(alpha, beta)) diff --git a/stdlib/LinearAlgebra/test/diagonal.jl b/stdlib/LinearAlgebra/test/diagonal.jl index dfb901908ba69..98f5498c71033 100644 --- a/stdlib/LinearAlgebra/test/diagonal.jl +++ b/stdlib/LinearAlgebra/test/diagonal.jl @@ -1265,6 +1265,17 @@ end @test *(Diagonal(ones(n)), Diagonal(1:n), Diagonal(ones(n)), Diagonal(1:n)) isa Diagonal end +@testset "triple multiplication with a sandwiched BandedMatrix" begin + D = Diagonal(StepRangeLen(NaN, 0, 4)); + B = Bidiagonal(1:4, 1:3, :U) + C = D * B * D + @test iszero(diag(C, 2)) + # test associativity + C1 = (D * B) * D + C2 = D * (B * D) + @test diag(C,2) == diag(C1,2) == diag(C2,2) +end + @testset "diagind" begin D = Diagonal(1:4) M = Matrix(D)