Skip to content

Commit

Permalink
Make Broadcasted iterable and more indexable (#26987)
Browse files Browse the repository at this point in the history
Defines iteration and some related traits as well as a few functions to make indexing a bit friendlier.
  • Loading branch information
mbauman authored May 9, 2018
1 parent 04fe0f4 commit 0db674d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 13 deletions.
55 changes: 42 additions & 13 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,15 @@ end
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)

Base.show(io::IO, bc::Broadcasted{Style}) where {Style} = print(io, Broadcasted, '{', Style, "}(", bc.f, ", ", bc.args, ')')
function Base.show(io::IO, bc::Broadcasted{Style}) where {Style}
print(io, Broadcasted)
# Only show the style parameter if we have a set of axes — representing an instantiated
# "outermost" Broadcasted. The styles of nested Broadcasteds represent an intermediate
# computation that is not relevant for dispatch, confusing, and just extra line noise.
bc.axes isa Tuple && print(io, '{', Style, '}')
print(io, '(', bc.f, ", ", bc.args, ')')
nothing
end

## Allocating the output container
"""
Expand Down Expand Up @@ -218,8 +226,6 @@ This should only be specialized for objects that do not define axes but want to
"""
broadcast_axes

### End of methods that users will typically have to specialize ###

@inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes)
_axes(::Broadcasted, axes::Tuple) = axes
@inline _axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...)
Expand All @@ -239,19 +245,39 @@ _not_nested(t::Tuple) = _not_nested(tail(t))
_not_nested(::NestedTuple) = false
_not_nested(::Tuple{}) = true

@inline Base.eachindex(bc::Broadcasted) = _eachindex(axes(bc))
_eachindex(t::Tuple{Any}) = t[1]
_eachindex(t::Tuple) = CartesianIndices(t)

Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N

Base.length(bc::Broadcasted) = prod(map(length, axes(bc)))
Base.size(bc::Broadcasted) = _size(axes(bc))
_size(::Tuple{Vararg{Base.OneTo}}) = map(length, axes(bc))

Base.start(bc::Broadcasted) = (iter = eachindex(bc); (iter, start(iter)))
Base.@propagate_inbounds function Base.next(bc::Broadcasted, s)
iter, state = s
i, newstate = next(iter, state)
return (bc[i], (iter, newstate))
end
Base.done(bc::Broadcasted, s) = done(s[1], s[2])

Base.IteratorSize(::Type{<:Broadcasted{<:Any,<:NTuple{N,Base.OneTo}}}) where {N} = Base.HasShape{N}()
Base.IteratorEltype(::Type{<:Broadcasted}) = Base.EltypeUnknown()

## Instantiation fills in the "missing" fields in Broadcasted.
instantiate(x) = x

"""
Broadcast.instantiate(bc::Broadcasted)
Construct the axes and indexing helpers for the lazy Broadcasted object `bc`.
Construct and check the axes for the lazy Broadcasted object `bc`.
Custom `BroadcastStyle`s may override this default in cases where it is fast and easy
to compute the resulting `axes` and indexing helpers on-demand, leaving those fields
of the `Broadcasted` object empty (populated with `nothing`). If they do so, however,
they must provide their own `Base.axes(::Broadcasted{Style})` and
`Base.getindex(::Broadcasted{Style}, I::Union{Int,CartesianIndex})` methods as appropriate.
to compute and verify the resulting `axes` on-demand, leaving the `axis` field
of the `Broadcasted` object empty (populated with `nothing`).
"""
@inline function instantiate(bc::Broadcasted{Style}) where {Style}
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
Expand Down Expand Up @@ -481,6 +507,7 @@ Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = ()
# If dot-broadcasting were already defined, this would be `ifelse.(keep, I, Idefault)`.
@inline newindex(I::CartesianIndex, keep, Idefault) = CartesianIndex(_newindex(I.I, keep, Idefault))
@inline newindex(i::Int, keep::Tuple{Bool}, idefault) = ifelse(keep[1], i, idefault[1])
@inline newindex(i::Int, keep::Tuple{}, idefault) = CartesianIndex(())
@inline _newindex(I, keep, Idefault) =
(ifelse(keep[1], I[1], Idefault[1]), _newindex(tail(I), tail(keep), tail(Idefault))...)
@inline _newindex(I, keep::Tuple{}, Idefault) = () # truncate if keep is shorter than I
Expand All @@ -496,12 +523,14 @@ Base.@propagate_inbounds _newindex(ax::Tuple{}, I::Tuple{}) = ()
(length(ind1)!=1, keep...), (first(ind1), Idefault...)
end

@inline function Base.getindex(bc::Broadcasted, I)
@inline function Base.getindex(bc::Broadcasted, I::Union{Int,CartesianIndex})
@boundscheck checkbounds(bc, I)
@inbounds _broadcast_getindex(bc, I)
end
Base.@propagate_inbounds Base.getindex(bc::Broadcasted, i1::Int, i2::Int, I::Int...) = bc[CartesianIndex((i1, i2, I...))]
Base.@propagate_inbounds Base.getindex(bc::Broadcasted) = bc[CartesianIndex(())]

@inline Base.checkbounds(bc::Broadcasted, I) =
@inline Base.checkbounds(bc::Broadcasted, I::Union{Int,CartesianIndex}) =
Base.checkbounds_indices(Bool, axes(bc), (I,)) || Base.throw_boundserror(bc, (I,))


Expand Down Expand Up @@ -739,7 +768,7 @@ const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}
# value to determine the starting output eltype; copyto_nonleaf!
# will widen `dest` as needed to accommodate later values.
bc′ = preprocess(nothing, bc)
iter = CartesianIndices(axes(bc′))
iter = eachindex(bc′)
state = start(iter)
if done(iter, state)
# if empty, take the ElType at face value
Expand Down Expand Up @@ -807,7 +836,7 @@ preprocess_args(dest, args::Tuple{}) = ()
end
end
bc′ = preprocess(dest, bc)
@simd for I in CartesianIndices(axes(bc′))
@simd for I in eachindex(bc′)
@inbounds dest[I] = bc′[I]
end
return dest
Expand All @@ -822,7 +851,7 @@ end
destc = dest.chunks
ind = cind = 1
bc′ = preprocess(dest, bc)
@simd for I in CartesianIndices(axes(bc′))
@simd for I in eachindex(bc′)
@inbounds tmp[ind] = bc′[I]
ind += 1
if ind > bitcache_size
Expand Down
23 changes: 23 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -726,3 +726,26 @@ let f(args...) = *(args...)
@test f.(x..., y, z...) == broadcast(f, x..., y, z...) == 120
@test f.(x..., f.(x..., y, z...), y, z...) == broadcast(f, x..., broadcast(f, x..., y, z...), y, z...) == 120*120
end

# Broadcasted iterable/indexable APIs
let
bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5))
@test eachindex(bc) === Base.OneTo(5)
@test length(bc) === 5
@test ndims(bc) === 1
@test ndims(typeof(bc)) === 1
@test bc[1] === bc[CartesianIndex((1,))] === 5.0
@test copy(bc) == [v for v in bc] == collect(bc)
@test eltype(copy(bc)) == eltype([v for v in bc]) == eltype(collect(bc))
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)

bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5*ones(1, 4)))
@test eachindex(bc) === CartesianIndices((Base.OneTo(5), Base.OneTo(4)))
@test length(bc) === 20
@test ndims(bc) === 2
@test ndims(typeof(bc)) === 2
@test bc[1,1] == bc[CartesianIndex((1,1))] === 5.0
@test copy(bc) == [v for v in bc] == collect(bc)
@test eltype(copy(bc)) == eltype([v for v in bc]) == eltype(collect(bc))
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)
end

0 comments on commit 0db674d

Please sign in to comment.