Skip to content

Commit

Permalink
Centralize broadcast support for structured matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Jan 7, 2018
1 parent 4c02b07 commit e4d1962
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 36 deletions.
3 changes: 3 additions & 0 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ BroadcastStyle(::Type{<:Ref}) = DefaultArrayStyle{0}()
# 3 or more arguments still return an `ArrayConflict`.
struct ArrayConflict <: AbstractArrayStyle{Any} end

# This will be used for Diagonal, Bidiagonal, Tridiagonal, and SymTridiagonal
struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end

### Binary BroadcastStyle rules
"""
BroadcastStyle(::Style1, ::Style2) = Style3()
Expand Down
27 changes: 17 additions & 10 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,23 @@ Bidiagonal{T}(A::Bidiagonal) where {T} =
# When asked to convert Bidiagonal to AbstractMatrix{T}, preserve structure by converting to Bidiagonal{T} <: AbstractMatrix{T}
AbstractMatrix{T}(A::Bidiagonal) where {T} = convert(Bidiagonal{T}, A)

broadcast(::typeof(big), B::Bidiagonal) = Bidiagonal(big.(B.dv), big.(B.ev), B.uplo)
function copyto!(dest::Bidiagonal, bc::Broadcasted{PromoteToSparse})
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
end
if dest.uplo == 'U'
for i = 1:size(dest, 1)-1
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
end
else
for i = 1:size(dest, 1)-1
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i))
end
end
dest
end

# For B<:Bidiagonal, similar(B[, neweltype]) should yield a Bidiagonal matrix.
# On the other hand, similar(B, [neweltype,] shape...) should yield a sparse matrix.
Expand Down Expand Up @@ -234,18 +250,9 @@ function size(M::Bidiagonal, d::Integer)
end

#Elementary operations
broadcast(::typeof(abs), M::Bidiagonal) = Bidiagonal(abs.(M.dv), abs.(M.ev), M.uplo)
broadcast(::typeof(round), M::Bidiagonal) = Bidiagonal(round.(M.dv), round.(M.ev), M.uplo)
broadcast(::typeof(trunc), M::Bidiagonal) = Bidiagonal(trunc.(M.dv), trunc.(M.ev), M.uplo)
broadcast(::typeof(floor), M::Bidiagonal) = Bidiagonal(floor.(M.dv), floor.(M.ev), M.uplo)
broadcast(::typeof(ceil), M::Bidiagonal) = Bidiagonal(ceil.(M.dv), ceil.(M.ev), M.uplo)
for func in (:conj, :copy, :real, :imag)
@eval ($func)(M::Bidiagonal) = Bidiagonal(($func)(M.dv), ($func)(M.ev), M.uplo)
end
broadcast(::typeof(round), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(round.(T, M.dv), round.(T, M.ev), M.uplo)
broadcast(::typeof(trunc), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(trunc.(T, M.dv), trunc.(T, M.ev), M.uplo)
broadcast(::typeof(floor), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(floor.(T, M.dv), floor.(T, M.ev), M.uplo)
broadcast(::typeof(ceil), ::Type{T}, M::Bidiagonal) where {T<:Integer} = Bidiagonal(ceil.(T, M.dv), ceil.(T, M.ev), M.uplo)

transpose(M::Bidiagonal) = Bidiagonal(M.dv, M.ev, M.uplo == 'U' ? :L : :U)
adjoint(M::Bidiagonal) = Bidiagonal(conj(M.dv), conj(M.ev), M.uplo == 'U' ? :L : :U)
Expand Down
10 changes: 9 additions & 1 deletion base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,18 @@ isposdef(D::Diagonal) = all(x -> x > 0, D.diag)

factorize(D::Diagonal) = D

broadcast(::typeof(abs), D::Diagonal) = Diagonal(abs.(D.diag))
real(D::Diagonal) = Diagonal(real(D.diag))
imag(D::Diagonal) = Diagonal(imag(D.diag))

function copyto!(dest::Diagonal, bc::Broadcasted{PromoteToSparse})
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
dest.diag[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
end
dest
end

istriu(D::Diagonal) = true
istril(D::Diagonal) = true
function triu!(D::Diagonal,k::Integer=0)
Expand Down
2 changes: 2 additions & 0 deletions base/linalg/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
StridedReshapedArray, strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
using Base: hvcat_fill, iszero, IndexLinear, _length, promote_op, promote_typeof,
@propagate_inbounds, @pure, reduce, typed_vcat
using Base.Broadcast: Broadcasted, PromoteToSparse

# We use `_length` because of non-1 indices; releases after julia 0.5
# can go back to `length`. `_length(A)` is equivalent to `length(linearindices(A))`.

Expand Down
47 changes: 25 additions & 22 deletions base/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,22 @@ end
similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T), similar(S.ev, T))
similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...)

function copyto!(dest::SymTridiagonal, bc::Broadcasted{PromoteToSparse})
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
dest.dv[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
end
for i = 1:size(dest, 1)-1
dest.ev[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
end
dest
end

#Elementary operations
broadcast(::typeof(abs), M::SymTridiagonal) = SymTridiagonal(abs.(M.dv), abs.(M.ev))
broadcast(::typeof(round), M::SymTridiagonal) = SymTridiagonal(round.(M.dv), round.(M.ev))
broadcast(::typeof(trunc), M::SymTridiagonal) = SymTridiagonal(trunc.(M.dv), trunc.(M.ev))
broadcast(::typeof(floor), M::SymTridiagonal) = SymTridiagonal(floor.(M.dv), floor.(M.ev))
broadcast(::typeof(ceil), M::SymTridiagonal) = SymTridiagonal(ceil.(M.dv), ceil.(M.ev))
for func in (:conj, :copy, :real, :imag)
@eval ($func)(M::SymTridiagonal) = SymTridiagonal(($func)(M.dv), ($func)(M.ev))
end
broadcast(::typeof(round), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(round.(T, M.dv), round.(T, M.ev))
broadcast(::typeof(trunc), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(trunc.(T, M.dv), trunc.(T, M.ev))
broadcast(::typeof(floor), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(floor.(T, M.dv), floor.(T, M.ev))
broadcast(::typeof(ceil), ::Type{T}, M::SymTridiagonal) where {T<:Integer} = SymTridiagonal(ceil.(T, M.dv), ceil.(T, M.ev))

transpose(M::SymTridiagonal) = M #Identity operation
adjoint(M::SymTridiagonal) = conj(M)
Expand Down Expand Up @@ -500,24 +503,11 @@ similar(M::Tridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spz
copyto!(dest::Tridiagonal, src::Tridiagonal) = (copyto!(dest.dl, src.dl); copyto!(dest.d, src.d); copyto!(dest.du, src.du); dest)

#Elementary operations
broadcast(::typeof(abs), M::Tridiagonal) = Tridiagonal(abs.(M.dl), abs.(M.d), abs.(M.du))
broadcast(::typeof(round), M::Tridiagonal) = Tridiagonal(round.(M.dl), round.(M.d), round.(M.du))
broadcast(::typeof(trunc), M::Tridiagonal) = Tridiagonal(trunc.(M.dl), trunc.(M.d), trunc.(M.du))
broadcast(::typeof(floor), M::Tridiagonal) = Tridiagonal(floor.(M.dl), floor.(M.d), floor.(M.du))
broadcast(::typeof(ceil), M::Tridiagonal) = Tridiagonal(ceil.(M.dl), ceil.(M.d), ceil.(M.du))
for func in (:conj, :copy, :real, :imag)
@eval function ($func)(M::Tridiagonal)
Tridiagonal(($func)(M.dl), ($func)(M.d), ($func)(M.du))
end
end
broadcast(::typeof(round), ::Type{T}, M::Tridiagonal) where {T<:Integer} =
Tridiagonal(round.(T, M.dl), round.(T, M.d), round.(T, M.du))
broadcast(::typeof(trunc), ::Type{T}, M::Tridiagonal) where {T<:Integer} =
Tridiagonal(trunc.(T, M.dl), trunc.(T, M.d), trunc.(T, M.du))
broadcast(::typeof(floor), ::Type{T}, M::Tridiagonal) where {T<:Integer} =
Tridiagonal(floor.(T, M.dl), floor.(T, M.d), floor.(T, M.du))
broadcast(::typeof(ceil), ::Type{T}, M::Tridiagonal) where {T<:Integer} =
Tridiagonal(ceil.(T, M.dl), ceil.(T, M.d), ceil.(T, M.du))

transpose(M::Tridiagonal) = Tridiagonal(M.du, M.d, M.dl)
adjoint(M::Tridiagonal) = conj(transpose(M))
Expand Down Expand Up @@ -576,6 +566,19 @@ function Base.replace_in_print_matrix(A::Tridiagonal,i::Integer,j::Integer,s::Ab
i==j-1||i==j||i==j+1 ? s : Base.replace_with_centered_mark(s)
end

function copyto!(dest::Tridiagonal, bc::Broadcasted{PromoteToSparse})
axs = axes(dest)
axes(bc) == axs || Broadcast.throwdm(axes(bc), axs)
for i in axs[1]
dest.d[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i))
end
for i = 1:size(dest, 1)-1
dest.du[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i, i+1))
dest.dl[i] = Broadcast._broadcast_getindex(bc, CartesianIndex(i+1, i))
end
dest
end

#tril and triu

istriu(M::Tridiagonal) = iszero(M.dl)
Expand Down
16 changes: 14 additions & 2 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import Base: map, map!, broadcast, copy, copyto!
using Base: TupleLL, TupleLLEnd, front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
using Base.Broadcast: BroadcastStyle, Broadcasted, PromoteToSparse, Args1, Args2, flatten

# This module is organized as follows:
# (0) Define BroadcastStyle rules and convenience types for dispatch
Expand Down Expand Up @@ -54,7 +54,6 @@ SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle()

struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()

Expand Down Expand Up @@ -969,6 +968,7 @@ function _copy(::Any, bc::Broadcasted{<:SPVM})
parevalf, passedargstup = capturescalars(bcf.f, args)
return broadcast(parevalf, passedargstup...)
end

function _shapecheckbc(bc::Broadcasted)
args = Tuple(bc.args)
_aresameshape(bc.args) ? _noshapecheck_map(bc.f, args...) : _diffshape_broadcast(bc.f, args...)
Expand Down Expand Up @@ -1044,10 +1044,22 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.

function copy(bc::Broadcasted{PromoteToSparse})
if bc.args isa Args1{<:StructuredMatrix} || bc.args isa Args2{<:Type,<:StructuredMatrix}
if _iszero(fzero(bc.f, bc.args))
T = Broadcast.combine_eltypes(bc.f, bc.args)
M = get_matrix(bc.args)
dest = similar(M, T)
return copyto!(dest, bc)
end
end
bcf = flatten(bc)
As = Tuple(bcf.args)
broadcast(bcf.f, map(_sparsifystructured, As)...)
end
get_matrix(args::Args1{<:StructuredMatrix}) = args.head
get_matrix(args::Args2{<:Type,<:StructuredMatrix}) = args.rest.head
fzero(f::Tf, args::Args1{<:StructuredMatrix}) where Tf = f(zero(eltype(get_matrix(args))))
fzero(f::Tf, args::Args2{<:Type, <:StructuredMatrix}) where Tf = f(args.head, zero(eltype(get_matrix(args))))

function copyto!(dest::SparseVecOrMat, bc::Broadcasted{PromoteToSparse})
bcf = flatten(bc)
Expand Down
2 changes: 1 addition & 1 deletion test/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ end
structuredarrays = (D, B, T, S)
fstructuredarrays = map(Array, structuredarrays)
for (X, fX) in zip(structuredarrays, fstructuredarrays)
@test (Q = broadcast(sin, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(sin, fX)))
@test (Q = broadcast(sin, X); typeof(Q) == typeof(X) && Q == sparse(broadcast(sin, fX)))
@test broadcast!(sin, Z, X) == sparse(broadcast(sin, fX))
@test (Q = broadcast(cos, X); Q isa SparseMatrixCSC && Q == sparse(broadcast(cos, fX)))
@test broadcast!(cos, Z, X) == sparse(broadcast(cos, fX))
Expand Down

0 comments on commit e4d1962

Please sign in to comment.