Skip to content

Commit

Permalink
Integrate lazy broadcast representation into new broadcast machinery
Browse files Browse the repository at this point in the history
Among other things, this supports returning AbstractRanges for appropriate inputs.

Fixes #21094, fixes #22053
  • Loading branch information
timholy committed Jan 3, 2018
1 parent 367a41f commit a049a70
Show file tree
Hide file tree
Showing 7 changed files with 759 additions and 512 deletions.
961 changes: 546 additions & 415 deletions base/broadcast.jl

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions base/mpfr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ promote_rule(::Type{BigFloat}, ::Type{<:AbstractFloat}) = BigFloat

big(::Type{<:AbstractFloat}) = BigFloat

# Support conversion of AbstractRanges to high precision
Base.Broadcast.maybe_range_safe_f(::typeof(big)) = true

function (::Type{Rational{BigInt}})(x::AbstractFloat)
isnan(x) && return zero(BigInt) // zero(BigInt)
isinf(x) && return copysign(one(BigInt),x) // zero(BigInt)
Expand Down
1 change: 1 addition & 0 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ end
show_comma_array(io::IO, itr, o, c) = show_delim_array(io, itr, o, ',', c, false)
show(io::IO, t::Tuple) = show_delim_array(io, t, '(', ',', ')', true)
show(io::IO, v::SimpleVector) = show_delim_array(io, v, "svec(", ',', ')', false)
show(io::IO, t::TupleLL) = show_delim_array(io, t, '{', ',', '}', true)

show(io::IO, s::Symbol) = show_unquoted_quote_expr(io, s, 0, 0)

Expand Down
210 changes: 116 additions & 94 deletions base/sparse/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ module HigherOrderFns

# This module provides higher order functions specialized for sparse arrays,
# particularly map[!]/broadcast[!] for SparseVectors and SparseMatrixCSCs at present.
import Base: map, map!, broadcast, broadcast!
import Base: map, map!, broadcast, copy, copyto!

using Base: front, tail, to_shape
using Base: TupleLL, front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
using Base.Broadcast: BroadcastStyle
using Base.Broadcast: BroadcastStyle, Broadcasted, flatten

# This module is organized as follows:
# (0) Define BroadcastStyle rules and convenience types for dispatch
# (1) Define a common interface to SparseVectors and SparseMatrixCSCs sufficient for
# map[!]/broadcast[!]'s purposes. The methods below are written against this interface.
# (2) Define entry points for map[!] (short children of _map_[not]zeropres!).
Expand All @@ -28,11 +29,70 @@ using Base.Broadcast: BroadcastStyle
# (12) Define map[!] methods handling combinations of sparse and structured matrices.


# (0) BroadcastStyle rules and convenience types for dispatch

SparseVecOrMat = Union{SparseVector,SparseMatrixCSC}

# broadcast container type promotion for combinations of sparse arrays and other types
struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle()
Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle()
const SPVM = Union{SparseVecStyle,SparseMatStyle}

# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
# SparseVecStyle promotes to SparseMatStyle for 2 dimensions.
# Fall back to DefaultArrayStyle for higher dimensionality.
SparseVecStyle(::Val{0}) = SparseVecStyle()
SparseVecStyle(::Val{1}) = SparseVecStyle()
SparseVecStyle(::Val{2}) = SparseMatStyle()
SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
SparseMatStyle(::Val{0}) = SparseMatStyle()
SparseMatStyle(::Val{1}) = SparseMatStyle()
SparseMatStyle(::Val{2}) = SparseMatStyle()
SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle()

struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()

PromoteToSparse(::Val{0}) = PromoteToSparse()
PromoteToSparse(::Val{1}) = PromoteToSparse()
PromoteToSparse(::Val{2}) = PromoteToSparse()
PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()

# FIXME: switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle
BroadcastStyle(::Type{<:Base.Adjoint{T,<:Vector}}) where T = Broadcast.MatrixStyle() # Adjoint not yet defined when broadcast.jl loaded
BroadcastStyle(::Type{<:Base.Transpose{T,<:Vector}}) where T = Broadcast.MatrixStyle() # Transpose not yet defined when broadcast.jl loaded
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.VectorStyle) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.MatrixStyle) = PromoteToSparse()
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(1)))
Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(2)))
# end FIXME

# Tuples promote to dense
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}()
Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()
Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

# Dispatch on broadcast operations by number of arguments
const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},ElType,Axes,Indexing<:Union{Nothing,TupleLL{Nothing,Nothing}},F} =
Broadcasted{Style,ElType,Axes,Indexing,F,TupleLL{Nothing,Nothing}}
const SpBroadcasted1{Style<:SPVM,ElType,Axes,Indexing<:Union{Nothing,TupleLL},F,Args<:TupleLL{<:SparseVecOrMat,Nothing}} =
Broadcasted{Style,ElType,Axes,Indexing,F,Args}
const SpBroadcasted2{Style<:SPVM,ElType,Axes,Indexing<:Union{Nothing,TupleLL},F,Args<:TupleLL{<:SparseVecOrMat,TupleLL{<:SparseVecOrMat,Nothing}}} =
Broadcasted{Style,ElType,Axes,Indexing,F,Args}

# (1) The definitions below provide a common interface to sparse vectors and matrices
# sufficient for the purposes of map[!]/broadcast[!]. This interface treats sparse vectors
# as n-by-one sparse matrices which, though technically incorrect, is how broacast[!] views
# sparse vectors in practice.
SparseVecOrMat = Union{SparseVector,SparseMatrixCSC}
@inline numrows(A::SparseVector) = A.n
@inline numrows(A::SparseMatrixCSC) = A.m
@inline numcols(A::SparseVector) = 1
Expand Down Expand Up @@ -91,11 +151,11 @@ function _noshapecheck_map(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N
_map_notzeropres!(f, fofzeros, C, A, Bs...)
end
# (3) broadcast[!] entry points
broadcast(f::Tf, A::SparseVector) where {Tf} = _noshapecheck_map(f, A)
broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
copy(bc::SpBroadcasted1) = _noshapecheck_map(bc.f, bc.args.head)

@inline function broadcast!(f::Tf, C::SparseVecOrMat, ::Nothing) where Tf
@inline function copyto!(C::SparseVecOrMat, bc::Broadcasted0{Nothing})
isempty(C) && return _finishempty!(C)
f = bc.f
fofnoargs = f()
if _iszero(fofnoargs) # f() is zero, so empty C
trimstorage!(C, 0)
Expand All @@ -108,13 +168,6 @@ broadcast(f::Tf, A::SparseMatrixCSC) where {Tf} = _noshapecheck_map(f, A)
return C
end

# the following three similar defs are necessary for type stability in the mixed vector/matrix case
broadcast(f::Tf, A::SparseVector, Bs::Vararg{SparseVector,N}) where {Tf,N} =
_aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
broadcast(f::Tf, A::SparseMatrixCSC, Bs::Vararg{SparseMatrixCSC,N}) where {Tf,N} =
_aresameshape(A, Bs...) ? _noshapecheck_map(f, A, Bs...) : _diffshape_broadcast(f, A, Bs...)
broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N} =
_diffshape_broadcast(f, A, Bs...)
function _diffshape_broadcast(f::Tf, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
Expand All @@ -139,7 +192,14 @@ end
@inline _aresameshape(A) = true
@inline _aresameshape(A, B) = size(A) == size(B)
@inline _aresameshape(A, B, Cs...) = _aresameshape(A, B) ? _aresameshape(B, Cs...) : false
@inline _aresameshape(t::TupleLL{<:Any,Nothing}) = true
@inline _aresameshape(t::TupleLL{<:Any,<:TupleLL}) =
_aresameshape(t.head, t.rest.head) ? _aresameshape(t.rest) : false
@inline _checksameshape(As...) = _aresameshape(As...) || throw(DimensionMismatch("argument shapes must match"))
@inline _all_args_isa(t::TupleLL{<:Any,Nothing}, ::Type{T}) where T = isa(t.head, T)
@inline _all_args_isa(t::TupleLL, ::Type{T}) where T = isa(t.head, T) & _all_args_isa(t.rest, T)
@inline _all_args_isa(t::TupleLL{<:Broadcasted,Nothing}, ::Type{T}) where T = _all_args_isa(t.head.args, T)
@inline _all_args_isa(t::TupleLL{<:Broadcasted}, ::Type{T}) where T = _all_args_isa(t.head.args, T) & _all_args_isa(t.rest, T)
@inline _densennz(shape::NTuple{1}) = shape[1]
@inline _densennz(shape::NTuple{2}) = shape[1] * shape[2]
_maxnnzfrom(shape::NTuple{1}, A) = nnz(A) * div(shape[1], A.n)
Expand Down Expand Up @@ -892,37 +952,42 @@ end

# (10) broadcast over combinations of broadcast scalars and sparse vectors/matrices

# broadcast container type promotion for combinations of sparse arrays and other types
struct SparseVecStyle <: Broadcast.AbstractArrayStyle{1} end
struct SparseMatStyle <: Broadcast.AbstractArrayStyle{2} end
Broadcast.BroadcastStyle(::Type{<:SparseVector}) = SparseVecStyle()
Broadcast.BroadcastStyle(::Type{<:SparseMatrixCSC}) = SparseMatStyle()
const SPVM = Union{SparseVecStyle,SparseMatStyle}

# SparseVecStyle handles 0-1 dimensions, SparseMatStyle 0-2 dimensions.
# SparseVecStyle promotes to SparseMatStyle for 2 dimensions.
# Fall back to DefaultArrayStyle for higher dimensionality.
SparseVecStyle(::Val{0}) = SparseVecStyle()
SparseVecStyle(::Val{1}) = SparseVecStyle()
SparseVecStyle(::Val{2}) = SparseMatStyle()
SparseVecStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()
SparseMatStyle(::Val{0}) = SparseMatStyle()
SparseMatStyle(::Val{1}) = SparseMatStyle()
SparseMatStyle(::Val{2}) = SparseMatStyle()
SparseMatStyle(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

Broadcast.BroadcastStyle(::SparseMatStyle, ::SparseVecStyle) = SparseMatStyle()

# Tuples promote to dense
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{1}()
Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

# broadcast entry points for combinations of sparse arrays and other (scalar) types
function broadcast(f, ::SPVM, ::Nothing, ::Nothing, mixedargs::Vararg{Any,N}) where N
parevalf, passedargstup = capturescalars(f, mixedargs)
function copy(bc::Broadcasted{<:SPVM})
bcf = flatten(bc)
_all_args_isa(bcf.args, SparseVector) && return _shapecheckbc(bcf)
_all_args_isa(bcf.args, SparseMatrixCSC) && return _shapecheckbc(bcf)
args = Tuple(bcf.args)
_all_args_isa(bcf.args, SparseVecOrMat) && return _diffshape_broadcast(bcf.f, args...)
parevalf, passedargstup = capturescalars(bcf.f, args)
return broadcast(parevalf, passedargstup...)
end
# for broadcast! see (11)
function _shapecheckbc(bc::Broadcasted)
args = Tuple(bc.args)
_aresameshape(bc.args) ? _noshapecheck_map(bc.f, args...) : _diffshape_broadcast(bc.f, args...)
end

function copyto!(dest::SparseVecOrMat, bc::Broadcasted{<:SPVM})
if bc.f === identity && bc isa SpBroadcasted1 && Base.axes(dest) == (A = bc.args.head; Base.axes(A))
return copyto!(dest, A)
end
bcf = flatten(bc)
As = Tuple(bcf.args)
if _all_args_isa(bcf.args, SparseVecOrMat)
_aresameshape(dest, As...) && return _noshapecheck_map!(bcf.f, dest, As...)
Base.Broadcast.check_broadcast_indices(axes(dest), As...)
fofzeros = bcf.f(_zeros_eltypes(As...)...)
fpreszeros = _iszero(fofzeros)
fpreszeros ? _broadcast_zeropres!(bcf.f, dest, As...) :
_broadcast_notzeropres!(bcf.f, fofzeros, dest, As...)
else
# As contains nothing but SparseVecOrMat and scalars
# See below for capturescalars
parevalf, passedsrcargstup = capturescalars(bcf.f, As)
broadcast!(parevalf, dest, passedsrcargstup...)
end
return dest
end

# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
Expand Down Expand Up @@ -971,59 +1036,16 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
# vectors/matrices, promote all structured matrices and dense vectors/matrices to sparse
# and rebroadcast. otherwise, divert to generic AbstractArray broadcast code.

struct PromoteToSparse <: Broadcast.AbstractArrayStyle{2} end
StructuredMatrix = Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}
Broadcast.BroadcastStyle(::Type{<:StructuredMatrix}) = PromoteToSparse()

PromoteToSparse(::Val{0}) = PromoteToSparse()
PromoteToSparse(::Val{1}) = PromoteToSparse()
PromoteToSparse(::Val{2}) = PromoteToSparse()
PromoteToSparse(::Val{N}) where N = Broadcast.DefaultArrayStyle{N}()

Broadcast.BroadcastStyle(::PromoteToSparse, ::SPVM) = PromoteToSparse()
Broadcast.BroadcastStyle(::PromoteToSparse, ::Broadcast.Style{Tuple}) = Broadcast.DefaultArrayStyle{2}()

# FIXME: switch to DefaultArrayStyle once we can delete VectorStyle and MatrixStyle
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{0}) = PromoteToSparse()
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{1}) = PromoteToSparse()
# Broadcast.BroadcastStyle(::SPVM, ::Broadcast.DefaultArrayStyle{2}) = PromoteToSparse()
BroadcastStyle(::Type{<:Base.Adjoint{T,<:Vector}}) where T = Broadcast.MatrixStyle() # Adjoint not yet defined when broadcast.jl loaded
BroadcastStyle(::Type{<:Base.Transpose{T,<:Vector}}) where T = Broadcast.MatrixStyle() # Transpose not yet defined when broadcast.jl loaded
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.VectorStyle) = PromoteToSparse()
Broadcast.BroadcastStyle(::SPVM, ::Broadcast.MatrixStyle) = PromoteToSparse()
Broadcast.BroadcastStyle(::SparseVecStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(1)))
Broadcast.BroadcastStyle(::SparseMatStyle, ::Broadcast.DefaultArrayStyle{N}) where N =
Broadcast.DefaultArrayStyle(Broadcast._max(Val(N), Val(2)))
# end FIXME

broadcast(f, ::PromoteToSparse, ::Nothing, ::Nothing, As::Vararg{Any,N}) where {N} =
broadcast(f, map(_sparsifystructured, As)...)

# For broadcast! with ::Any inputs, we need a layer of indirection to determine whether
# the inputs can be promoted to SparseVecOrMat. If it's just SparseVecOrMat and scalars,
# we can handle it here, otherwise see below for the promotion machinery.
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, A::SparseVecOrMat, Bs::Vararg{SparseVecOrMat,N}) where {Tf,N}
if f isa typeof(identity) && N == 0 && Base.axes(dest) == Base.axes(A)
return copyto!(dest, A)
end
_aresameshape(dest, A, Bs...) && return _noshapecheck_map!(f, dest, A, Bs...)
Base.Broadcast.check_broadcast_indices(axes(dest), A, Bs...)
fofzeros = f(_zeros_eltypes(A, Bs...)...)
fpreszeros = _iszero(fofzeros)
fpreszeros ? _broadcast_zeropres!(f, dest, A, Bs...) :
_broadcast_notzeropres!(f, fofzeros, dest, A, Bs...)
return dest
function copy(bc::Broadcasted{PromoteToSparse})
bcf = flatten(bc)
As = Tuple(bcf.args)
broadcast(bcf.f, map(_sparsifystructured, As)...)
end
function broadcast!(f::Tf, dest::SparseVecOrMat, ::SPVM, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
# mixedsrcargs contains nothing but SparseVecOrMat and scalars
parevalf, passedsrcargstup = capturescalars(f, mixedsrcargs)
broadcast!(parevalf, dest, passedsrcargstup...)
return dest
end
function broadcast!(f::Tf, dest::SparseVecOrMat, ::PromoteToSparse, mixedsrcargs::Vararg{Any,N}) where {Tf,N}
broadcast!(f, dest, map(_sparsifystructured, mixedsrcargs)...)
return dest

function copyto!(dest::SparseVecOrMat, bc::Broadcasted{PromoteToSparse})
bcf = flatten(bc)
As = Tuple(bcf.args)
broadcast!(bcf.f, dest, map(_sparsifystructured, As)...)
end

_sparsifystructured(M::AbstractMatrix) = SparseMatrixCSC(M)
Expand Down
50 changes: 50 additions & 0 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,3 +352,53 @@ any(x::Tuple{Bool, Bool, Bool}) = x[1]|x[2]|x[3]
Returns an empty tuple, `()`.
"""
empty(x::Tuple) = ()

## Linked-list representation of a tuple. Inferrable even for Type elements.

struct TupleLL{T, Rest}
head::T # car
rest::Rest # cdr
TupleLL(x, rest::TupleLL) where {} = new{Core.Typeof(x), typeof(rest)}(x, rest) # (cons x rest)
TupleLL(x, rest::Nothing) where {} = new{Core.Typeof(x), typeof(rest)}(x, rest) # (cons x nil)
TupleLL(x) where {} = new{Core.Typeof(x), Nothing}(x, nothing) # (list x)
TupleLL() where {} = new{Nothing, Nothing}(nothing, nothing)
end
# (apply list a)
make_TupleLL() = TupleLL()
make_TupleLL(a) = TupleLL(a)
make_TupleLL(a, args...) = TupleLL(a, make_TupleLL(args...))

# (map f tt)
map(f, tt::TupleLL{Nothing, Nothing}) = ()
map(f, tt::TupleLL{<:Any, Nothing}) = (f(tt.head),)
function map(f, tt::TupleLL)
return (f(tt.head), map(f, tt.rest)...)
end

mapTupleLL(f, tt::TupleLL{Nothing, Nothing}) = TupleLL()
mapTupleLL(f, tt::TupleLL{<:Any, Nothing}) = TupleLL(f(tt.head),)
function mapTupleLL(f, tt::TupleLL)
return TupleLL(f(tt.head), mapTupleLL(f, tt.rest))
end

convert(::Type{Tuple}, tt::TupleLL) = map(identity, tt)
(::Type{Tuple})(tt::TupleLL) = convert(Tuple, tt)

any(f::Function, tt::TupleLL{Nothing, Nothing}) = false
any(f::Function, tt::TupleLL{<:Any, Nothing}) = f(tt.head)
any(f::Function, tt::TupleLL) = f(tt.head) || any(f, tt.rest)

all(f::Function, tt::TupleLL{Nothing, Nothing}) = true
all(f::Function, tt::TupleLL{<:Any, Nothing}) = f(tt.head)
all(f::Function, tt::TupleLL) = f(tt.head) && all(f, tt.rest)

start(tt::TupleLL) = tt
next(::TupleLL, tt::TupleLL) = (tt.head, tt.rest)
done(::TupleLL{Nothing, Nothing}, tt::TupleLL{Nothing, Nothing}) = true
done(::TupleLL, tt::Nothing) = true
done(::TupleLL, tt::TupleLL) = false

length(tt::TupleLL{Nothing, Nothing}) = 0
length(tt::TupleLL) = _length(1, tt.rest)
_length(l::Int, tt::TupleLL) = _length(l+1, tt.rest)
_length(l::Int, ::Nothing) = l
Loading

0 comments on commit a049a70

Please sign in to comment.