Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix size-1 StructuredMatrix's broadcast. #54190

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ find_uplo(bc::Broadcasted) = mapfoldl(find_uplo, merge_uplos, Broadcast.cat_nest
function structured_broadcast_alloc(bc, ::Type{Bidiagonal}, ::Type{ElType}, n) where {ElType}
uplo = n > 0 ? find_uplo(bc) : 'U'
n1 = max(n - 1, 0)
if uplo == 'T'
if count_structedmatrix(Bidiagonal, bc) > 1 && uplo == 'T'
return Tridiagonal(Array{ElType}(undef, n1), Array{ElType}(undef, n), Array{ElType}(undef, n1))
end
return Bidiagonal(Array{ElType}(undef, n),Array{ElType}(undef, n1), uplo)
Expand Down Expand Up @@ -135,24 +135,36 @@ iszerodefined(::Type{<:Number}) = true
iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T)
iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T)

fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0))
# Count how many arguments of `bc` are instances of `T`, looking through nested
# `Broadcasted` objects via `Broadcast.cat_nested`. The result is an `Int` (0 when
# nothing matches).
function count_structedmatrix(T, bc::Broadcasted)
    isT = Base.Fix2(isa, T)
    return sum(isT, Broadcast.cat_nested(bc); init = 0)
end

# Decide whether the broadcast expression `bc` maps structural zeros to zero.
# `fzero` evaluates `bc` on representative zero values and yields `missing` when
# it cannot decide; in that case the answer is `false`.
function fzeropreserving(bc)
    nstructured = count_structedmatrix(StructuredMatrix, bc)
    # The `Val` flag tells `fzero` whether exactly one structured matrix is involved.
    z = fzero(bc, Val(nstructured == 1))
    ismissing(z) && return false
    return iszerodefined(typeof(z)) ? iszero(z) : z == 0
end
# Like sparse matrices, we assume that the zero-preservation property of a broadcasted
# expression is stable. We can test the zero-preservability by applying the function
# in cases where all other arguments are known scalars against a zero from the structured
# matrix. If any non-structured matrix argument is not a known scalar, we give up.
fzero(x::Number) = x
fzero(::Type{T}) where T = T
fzero(r::Ref) = r[]
fzero(t::Tuple{Any}) = t[1]
fzero(S::StructuredMatrix) = zero(eltype(S))
fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = haszero(T) ? zero(T)*I : missing
fzero(x) = missing
function fzero(bc::Broadcast.Broadcasted)
args = map(fzero, bc.args)
return any(ismissing, args) ? missing : bc.f(args...)
# Scalar-like arguments evaluate to themselves, or to the single value they wrap.
fzero(x::Number, ::Val) = x
fzero(::Type{T}, ::Val) where {T} = T
fzero(r::Union{Ref,AbstractArray{<:Any,0}}, ::Val) = r[]
fzero(t::Tuple{Any}, ::Val) = t[1]
# Size-1 `StructuredMatrix`s behave like scalars during broadcast, so a structural
# zero is not a valid stand-in for them. When the `Val` flag is `false` (the caller
# found a number of structured-matrix arguments other than one), a size-1 matrix
# contributes its single stored element instead of a zero.
function fzero(S::StructuredMatrix, ::Val{O}) where {O}
    if !O && isone(size(S, 1))
        return S[1, 1]
    end
    return zero(eltype(S))
end
function fzero(S::StructuredMatrix{<:AbstractMatrix{T}}, ::Val{O}) where {T<:Number,O}
    if !O && isone(size(S, 1))
        return S[1, 1]
    elseif haszero(T)
        # Block-valued matrix: use a scaled identity as the representative zero.
        return zero(T) * I
    else
        return missing
    end
end
# Anything else is not a known scalar: give up.
fzero(x, ::Val) = missing
# Evaluate a (possibly nested) broadcast expression on the representative values of
# its arguments. Returns `missing` as soon as any argument is undecidable, otherwise
# the result of applying `bc.f` to those values.
@inline function fzero(bc::Broadcast.Broadcasted, v::Val)
    vals = map(Base.Fix2(fzero, v), bc.args)
    anymissing(vals) && return missing
    return bc.f(vals...)
end
# Report whether a tuple contains a `missing`. Implemented as a recursion over the
# tuple (one element peeled off per call) rather than a loop, to force unrolling
# and keep the result type-stable.
anymissing(::Tuple{}) = false
anymissing(::Tuple{Missing,Vararg}) = true
function anymissing(rest::Tuple{Any,Vararg})
    return anymissing(Base.tail(rest))
end

function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) where {T,ElType}
Base.@constprop :aggressive function Base.similar(bc::Broadcasted{StructuredMatrixStyle{T}}, ::Type{ElType}) where {T,ElType}
inds = axes(bc)
fzerobc = fzeropreserving(bc)
if isstructurepreserving(bc) || (fzerobc && !(T <: Union{SymTridiagonal,UnitLowerTriangular,UnitUpperTriangular}))
Expand Down
63 changes: 38 additions & 25 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,21 @@ using Test, LinearAlgebra
M = Matrix(rand(N,N))
structuredarrays = (D, B, T, U, L, M)
fstructuredarrays = map(Array, structuredarrays)

# Helper functions used to check that simple structured broadcasts are type-stable.
# Each helper wraps one broadcast expression in a local function `g` and calls it
# through `@inferred`, which errors if the return type of `g(X)` is not inferred.
# Every `g` operates on its own argument `x`, so the inferred call is exactly the
# broadcast being tested. `pow2` uses a non-literal exponent; the `t`-suffixed
# helpers wrap the scalar operand in a 1-tuple.
mul2(X) = (g(x) = x .* 2.0; @inferred(g(X)))
mult2(X) = (g(x) = x .* (2.0,); @inferred(g(X)))
mulinf(X) = (g(x) = x .* Inf; @inferred(g(X)))
lpow2(X) = (g(x) = x .^ 2; @inferred(g(X)))
lpow0(X) = (g(x) = x .^ 0; @inferred(g(X)))
pow2(X) = (g(x) = (two = 2; x.^two); @inferred(g(X)))
lpow_1(X) = (g(x) = x .^ -1; @inferred(g(X)))
powt2(X) = (g(x) = x .^ (2,); @inferred(g(X)))

for (X, fX) in zip(structuredarrays, fstructuredarrays)
@test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == broadcast(sin, fX))
@test (Q = @inferred(broadcast(sin, X)); typeof(Q) == typeof(X) && Q == broadcast(sin, fX))
@test broadcast!(sin, Z, X) == broadcast(sin, fX)
@test (Q = broadcast(cos, X); Q isa Matrix && Q == broadcast(cos, fX))
@test (Q = @inferred(broadcast(cos, X)); Q isa Matrix && Q == broadcast(cos, fX))
@test broadcast!(cos, Z, X) == broadcast(cos, fX)
@test (Q = broadcast(*, s, X); typeof(Q) == typeof(X) && Q == broadcast(*, s, fX))
@test broadcast!(*, Z, s, X) == broadcast(*, s, fX)
Expand All @@ -29,18 +40,12 @@ using Test, LinearAlgebra
@test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX))
@test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX)

@test X .* 2.0 == X .* (2.0,) == fX .* 2.0
@test X .* 2.0 isa typeof(X)
@test X .* (2.0,) isa typeof(X)
@test isequal(X .* Inf, fX .* Inf)
@test mul2(X)::typeof(X) == mult2(X)::typeof(X) == mult2(fX)
@test isequal(mulinf(X), mulinf(fX))

two = 2
@test X .^ 2 == X .^ (2,) == fX .^ 2 == X .^ two
@test X .^ 2 isa typeof(X)
@test X .^ (2,) isa typeof(X)
@test X .^ two isa typeof(X)
@test X .^ 0 == fX .^ 0
@test X .^ -1 == fX .^ -1
@test lpow2(X)::typeof(X) == powt2(X)::typeof(X) == pow2(X)::typeof(X) == lpow2(fX)
@test lpow0(X) == lpow0(fX)
@test lpow_1(X) == lpow_1(fX)

for (Y, fY) in zip(structuredarrays, fstructuredarrays)
@test broadcast(+, X, Y) == broadcast(+, fX, fY)
Expand All @@ -65,9 +70,9 @@ using Test, LinearAlgebra
Ttris = typeof.((UpperTriangular(parent(UU)), LowerTriangular(parent(UU))))
funittriangulars = map(Array, unittriangulars)
for (X, fX, Ttri) in zip(unittriangulars, funittriangulars, Ttris)
@test (Q = broadcast(sin, X); typeof(Q) == Ttri && Q == broadcast(sin, fX))
@test (Q = @inferred(broadcast(sin, X)); typeof(Q) == Ttri && Q == broadcast(sin, fX))
@test broadcast!(sin, Z, X) == broadcast(sin, fX)
@test (Q = broadcast(cos, X); Q isa Matrix && Q == broadcast(cos, fX))
@test (Q = @inferred(broadcast(cos, X)); Q isa Matrix && Q == broadcast(cos, fX))
@test broadcast!(cos, Z, X) == broadcast(cos, fX)
@test (Q = broadcast(*, s, X); typeof(Q) == Ttri && Q == broadcast(*, s, fX))
@test broadcast!(*, Z, s, X) == broadcast(*, s, fX)
Expand All @@ -76,18 +81,14 @@ using Test, LinearAlgebra
@test (Q = broadcast(*, s, fV, fA, X); Q isa Matrix && Q == broadcast(*, s, fV, fA, fX))
@test broadcast!(*, Z, s, fV, fA, X) == broadcast(*, s, fV, fA, fX)

@test X .* 2.0 == X .* (2.0,) == fX .* 2.0
@test X .* 2.0 isa Ttri
@test X .* (2.0,) isa Ttri
@test isequal(X .* Inf, fX .* Inf)
@test mul2(X)::Ttri == mult2(X)::Ttri == mul2(fX)
@test isequal(mulinf(X), mulinf(fX))

two = 2
@test X .^ 2 == X .^ (2,) == fX .^ 2 == X .^ two
@test X .^ 2 isa typeof(X) # special cased, as isstructurepreserving
@test X .^ (2,) isa Ttri
@test X .^ two isa Ttri
@test X .^ 0 == fX .^ 0
@test X .^ -1 == fX .^ -1
@test lpow2(X)::typeof(X) == # special cased, as isstructurepreserving
powt2(X)::Ttri == pow2(X)::Ttri == lpow2(fX)
@test lpow0(X) == lpow0(fX)
@test lpow_1(X) == lpow_1(fX)

for (Y, fY) in zip(unittriangulars, funittriangulars)
@test broadcast(+, X, Y) == broadcast(+, fX, fY)
Expand Down Expand Up @@ -338,4 +339,16 @@ end
end
end

# Regression test for issue 54087: broadcasting between structured matrices when
# some of them have size 1.
@testset "Issue 54087: size-1 structured matrix's broadcast" begin
# Build one size-1 and one size-3 instance of each structured type.
Ns = 1, 3
D1, D2 = map(N->Diagonal(rand(N)), Ns)
B1, B2 = map(N->Bidiagonal(rand(N), rand(N - 1), :U), Ns)
T1, T2 = map(N->Tridiagonal(rand(N - 1), rand(N), rand(N - 1)), Ns)
Ss = [D1, D2, B1, B2, T1, T2]
# Dense copies provide the reference result for each broadcast.
MSs = Matrix.(Ss)
# Check every ordered pair of operands, including same-size pairs and self pairs.
for ((S1, M1), (S2, M2)) in Iterators.product(zip(Ss, MSs), zip(Ss, MSs))
@test S1 .+ S2 == M1 .+ M2
end
end

end