From 5be3e27e029835cb56dd6934d302680c26f6e21b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 20 Nov 2020 04:48:05 +0100 Subject: [PATCH] Let `muladd` accept a more restricted set of arrays (#38250) This adjusts #37065 to be much more cautious about what arrays it acts on: it calls mul! on StridedArrays, treats a few special types like Diagonal, UpperTriangular, and UniformScaling, and sends anything else to muladd(A,y,z) = A*y .+ z. However this broadcasting restricts the shape of z, mostly such that A*y .= z would work. That ensures you should get the same error from the mul!(::StridedMatrix, ...) method, as from the fallback broadcasting one. Both allow z of lower dimension than the existing muladd(x,y,z) = x*y+z. But x*y+z also allows z to have trailing dimensions, as long as they are of size 1. I made the broadcasting method allow these too, which I think should make this non-breaking. (I presume this is rarely used, and thus not worth sending to the fast method.) Structured matrices such as UpperTriangular should all go to x*y+z. Some combinations could be made more efficient but it gets complicated. Only the case of 3 diagonals is handled. --- stdlib/LinearAlgebra/src/diagonal.jl | 4 + stdlib/LinearAlgebra/src/matmul.jl | 56 ++++++--- stdlib/LinearAlgebra/src/uniformscaling.jl | 9 ++ stdlib/LinearAlgebra/test/matmul.jl | 130 +++++++++++++++------ 4 files changed, 147 insertions(+), 52 deletions(-) diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 52a0b782c2c77..467dea057bd9e 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -752,3 +752,7 @@ function logabsdet(A::Diagonal) mapreduce(x -> (log(abs(x)), sign(x)), ((d1, s1), (d2, s2)) -> (d1 + d2, s1 * s2), A.diag) end + +function Base.muladd(A::Diagonal, B::Diagonal, z::Diagonal) + Diagonal(A.diag .* B.diag .+ z.diag) +end diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 7e5c6f06cdd92..4fed8a577bd63 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -201,34 +201,54 @@ julia> muladd(A, B, z) 107.0 107.0 ``` """ -function Base.muladd(A::AbstractMatrix{TA}, y::AbstractVector{Ty}, z) where {TA, Ty} - T = promote_type(TA, Ty, eltype(z)) +function Base.muladd(A::AbstractMatrix, y::AbstractVecOrMat, z::Union{Number, AbstractArray}) + Ay = A * y + for d in 1:ndims(Ay) + # Same error as Ay .+= z would give, to match StridedMatrix method: + size(z,d) > size(Ay,d) && throw(DimensionMismatch("array could not be broadcast to match destination")) + end + for d in ndims(Ay)+1:ndims(z) + # Similar error to what Ay + z would give, to match (Any,Any,Any) method: + size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ", + axes(z), ", must have singleton at dim ", d))) + end + Ay .+ z +end + +function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z::Union{Number, AbstractArray}) + if size(z,1) > length(u) || size(z,2) > length(v) + # Same error as (u*v) .+= z: + throw(DimensionMismatch("array could not be broadcast to match destination")) + end + for d in 3:ndims(z) + # Similar error to (u*v) + z: + size(z,d) > 1 && throw(DimensionMismatch(string("dimensions must match: z has dims ", + axes(z), ", must have singleton at dim ", d))) + end + (u .* v) .+ z +end + +Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat}) = + muladd(A', x', z')' +Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z::Union{Number, AbstractVecOrMat}) = + transpose(muladd(transpose(A), transpose(x), transpose(z))) + +StridedMaybeAdjOrTransMat{T} = Union{StridedMatrix{T}, Adjoint{T, <:StridedMatrix}, Transpose{T, <:StridedMatrix}} + +function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, y::AbstractVector{<:Number}, z::Union{Number, AbstractVector}) + T = promote_type(eltype(A), eltype(y), eltype(z)) C = similar(A, T, axes(A,1)) C .= z mul!(C, A, y, true, true) end -function Base.muladd(A::AbstractMatrix{TA}, B::AbstractMatrix{TB}, z) where {TA, TB} - T = promote_type(TA, TB, eltype(z)) +function Base.muladd(A::StridedMaybeAdjOrTransMat{<:Number}, B::StridedMaybeAdjOrTransMat{<:Number}, z::Union{Number, AbstractVecOrMat}) + T = promote_type(eltype(A), eltype(B), eltype(z)) C = similar(A, T, axes(A,1), axes(B,2)) C .= z mul!(C, A, B, true, true) end -Base.muladd(x::AdjointAbsVec, A::AbstractMatrix, z) = muladd(A', x', z')' -Base.muladd(x::TransposeAbsVec, A::AbstractMatrix, z) = transpose(muladd(transpose(A), transpose(x), transpose(z))) - -function Base.muladd(u::AbstractVector, v::AdjOrTransAbsVec, z) - ndims(z) > 2 && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions")) - (u .* v) .+ z -end - -function Base.muladd(u::AdjOrTransAbsVec, v::AbstractVector, z) - uv = _dot_nonrecursive(u, v) - ndims(z) > ndims(uv) && throw(DimensionMismatch("cannot broadcast array to have fewer dimensions")) - uv .+ z -end - """ mul!(Y, A, B) -> Y diff --git a/stdlib/LinearAlgebra/src/uniformscaling.jl b/stdlib/LinearAlgebra/src/uniformscaling.jl index a22b2ca06c8d6..c59871e0641ef 100644 --- a/stdlib/LinearAlgebra/src/uniformscaling.jl +++ b/stdlib/LinearAlgebra/src/uniformscaling.jl @@ -483,3 +483,12 @@ Diagonal(s::UniformScaling, m::Integer) = Diagonal{eltype(s)}(s, m) dot(x::AbstractVector, J::UniformScaling, y::AbstractVector) = dot(x, J.λ, y) dot(x::AbstractVector, a::Number, y::AbstractVector) = sum(t -> dot(t[1], a, t[2]), zip(x, y)) dot(x::AbstractVector, a::Union{Real,Complex}, y::AbstractVector) = a*dot(x, y) + +# muladd +Base.muladd(A::UniformScaling, B::UniformScaling, z::UniformScaling) = + UniformScaling(A.λ * B.λ + z.λ) +Base.muladd(A::Union{Diagonal, UniformScaling}, B::Union{Diagonal, UniformScaling}, z::Union{Diagonal, UniformScaling}) = + Diagonal(_diag_or_value(A) .* _diag_or_value(B) .+ _diag_or_value(z)) + +_diag_or_value(A::Diagonal) = A.diag +_diag_or_value(A::UniformScaling) = A.λ diff --git a/stdlib/LinearAlgebra/test/matmul.jl b/stdlib/LinearAlgebra/test/matmul.jl index 564f5882e65e0..6a81222174e1f 100644 --- a/stdlib/LinearAlgebra/test/matmul.jl +++ b/stdlib/LinearAlgebra/test/matmul.jl @@ -293,45 +293,107 @@ end end @testset "muladd" begin - A23 = reshape(1:6, 2,3) + A23 = reshape(1:6, 2,3) .+ 0 B34 = reshape(1:12, 3,4) .+ im u2 = [10,20] v3 = [3,5,7] .+ im w4 = [11,13,17,19im] - @test muladd(A23, B34, 100) == A23 * B34 .+ 100 - @test muladd(A23, B34, u2) == A23 * B34 .+ u2 - @test muladd(A23, B34, w4') == A23 * B34 .+ w4' - @test_throws DimensionMismatch muladd(B34, A23, 1) - @test_throws DimensionMismatch muladd(A23, B34, ones(2,4,1)) - - @test muladd(A23, v3, 100) == A23 * v3 .+ 100 - @test muladd(A23, v3, u2) == A23 * v3 .+ u2 - @test muladd(A23, v3, im) isa Vector{Complex{Int}} - @test_throws DimensionMismatch muladd(A23, v3, ones(2,2)) - - @test muladd(v3', B34, 0) isa Adjoint - @test muladd(v3', B34, 2im) == v3' * B34 .+ 2im - @test muladd(v3', B34, w4') == v3' * B34 .+ w4' - @test_throws DimensionMismatch muladd(v3', B34, ones(1,4)) - - @test muladd(u2, v3', 0) isa Matrix - @test muladd(u2, v3', 99) == u2 * v3' .+ 99 - @test muladd(u2, v3', A23) == u2 * v3' .+ A23 - @test_throws DimensionMismatch muladd(u2, v3', ones(2,3,4)) - - @test muladd(u2', u2, 0) isa Number - @test muladd(v3', v3, im) == dot(v3,v3) + im - @test_throws DimensionMismatch muladd(v3', v3, [1]) - - vofm = [rand(1:9,2,2) for _ in 1:3] - Mofm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3] - - @test muladd(vofm', vofm, vofm[1]) == vofm' * vofm .+ vofm[1] # inner - @test muladd(vofm, vofm', Mofm) == vofm * vofm' .+ Mofm # outer - @test muladd(vofm', Mofm, vofm') == vofm' * Mofm .+ vofm' # bra-mat - @test muladd(Mofm, Mofm, vofm) == Mofm * Mofm .+ vofm # mat-mat - @test_broken muladd(Mofm, vofm, vofm) == Mofm * vofm .+ vofm # mat-vec + @testset "matrix-matrix" begin + @test muladd(A23, B34, 0) == A23 * B34 + @test muladd(A23, B34, 100) == A23 * B34 .+ 100 + @test muladd(A23, B34, u2) == A23 * B34 .+ u2 + @test muladd(A23, B34, w4') == A23 * B34 .+ w4' + @test_throws DimensionMismatch muladd(B34, A23, 1) + @test muladd(ones(1,3), ones(3,4), ones(1,4)) == fill(4.0,1,4) + @test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(9,4)) + + # broadcasting fallback method allows trailing dims + @test muladd(A23, B34, ones(2,4,1)) == A23 * B34 + ones(2,4,1) + @test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(9,4,1)) + @test_throws DimensionMismatch muladd(ones(1,3), ones(3,4), ones(1,4,9)) + # and catches z::Array{T,0} + @test muladd(A23, B34, fill(0)) == A23 * B34 + end + @testset "matrix-vector" begin + @test muladd(A23, v3, 0) == A23 * v3 + @test muladd(A23, v3, 100) == A23 * v3 .+ 100 + @test muladd(A23, v3, u2) == A23 * v3 .+ u2 + @test muladd(A23, v3, im) isa Vector{Complex{Int}} + @test muladd(ones(1,3), ones(3), ones(1)) == [4] + @test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(7)) + + # fallback + @test muladd(A23, v3, ones(2,1,1)) == A23 * v3 + ones(2,1,1) + @test_throws DimensionMismatch muladd(A23, v3, ones(2,2)) + @test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(7,1)) + @test_throws DimensionMismatch muladd(ones(1,3), ones(3), ones(1,7)) + @test muladd(A23, v3, fill(0)) == A23 * v3 + end + @testset "adjoint-matrix" begin + @test muladd(v3', B34, 0) isa Adjoint + @test muladd(v3', B34, 2im) == v3' * B34 .+ 2im + @test muladd(v3', B34, w4') == v3' * B34 .+ w4' + + # via fallback + @test muladd(v3', B34, ones(1,4)) == (B34' * v3 + ones(4,1))' + @test_throws DimensionMismatch muladd(v3', B34, ones(7,4)) + @test_throws DimensionMismatch muladd(v3', B34, ones(1,4,7)) + @test muladd(v3', B34, fill(0)) == v3' * B34 # does not make an Adjoint + end + @testset "vector-adjoint" begin + @test muladd(u2, v3', 0) isa Matrix + @test muladd(u2, v3', 99) == u2 * v3' .+ 99 + @test muladd(u2, v3', A23) == u2 * v3' .+ A23 + + @test muladd(u2, v3', ones(2,3,1)) == u2 * v3' + ones(2,3,1) + @test_throws DimensionMismatch muladd(u2, v3', ones(2,3,4)) + @test_throws DimensionMismatch muladd([1], v3', ones(7,3)) + @test muladd(u2, v3', fill(0)) == u2 * v3' + end + @testset "dot" begin # all use muladd(::Any, ::Any, ::Any) + @test muladd(u2', u2, 0) isa Number + @test muladd(v3', v3, im) == dot(v3,v3) + im + @test muladd(u2', u2, [1]) == [dot(u2,u2) + 1] + @test_throws DimensionMismatch muladd(u2', u2, [1,1]) == [dot(u2,u2) + 1] + @test muladd(u2', u2, fill(0)) == dot(u2,u2) + end + @testset "arrays of arrays" begin + vofm = [rand(1:9,2,2) for _ in 1:3] + Mofm = [rand(1:9,2,2) for _ in 1:3, _ in 1:3] + + @test muladd(vofm', vofm, vofm[1]) == vofm' * vofm .+ vofm[1] # inner + @test muladd(vofm, vofm', Mofm) == vofm * vofm' .+ Mofm # outer + @test muladd(vofm', Mofm, vofm') == vofm' * Mofm .+ vofm' # bra-mat + @test muladd(Mofm, Mofm, vofm) == Mofm * Mofm .+ vofm # mat-mat + @test muladd(Mofm, vofm, vofm) == Mofm * vofm .+ vofm # mat-vec + end +end + +@testset "muladd & structured matrices" begin + A33 = reshape(1:9, 3,3) .+ im + v3 = [3,5,7im] + + # no special treatment + @test muladd(Symmetric(A33), Symmetric(A33), 1) == Symmetric(A33) * Symmetric(A33) .+ 1 + @test muladd(Hermitian(A33), Hermitian(A33), v3) == Hermitian(A33) * Hermitian(A33) .+ v3 + @test muladd(adjoint(A33), transpose(A33), A33) == A33' * transpose(A33) .+ A33 + + u1 = muladd(UpperTriangular(A33), UpperTriangular(A33), Diagonal(v3)) + @test u1 isa UpperTriangular + @test u1 == UpperTriangular(A33) * UpperTriangular(A33) + Diagonal(v3) + + # diagonal + @test muladd(Diagonal(v3), Diagonal(A33), Diagonal(v3)).diag == ([1,5,9] .+ im .+ 1) .* v3 + + # uniformscaling + @test muladd(Diagonal(v3), I, I).diag == v3 .+ 1 + @test muladd(2*I, 3*I, I).λ == 7 + @test muladd(A33, A33', I) == A33 * A33' + I + + # https://github.com/JuliaLang/julia/issues/38426 + @test @evalpoly(A33, 1.0*I, 1.0*I) == I + A33 + @test @evalpoly(A33, 1.0*I, 1.0*I, 1.0*I) == I + A33 + A33^2 end # issue #6450