Skip to content

Commit

Permalink
Faster mapreduce for Broadcasted (#31020)
Browse files Browse the repository at this point in the history
* Better mapreduce for Broadcasted

* Use Axes in IndexStyle for Broadcasted

* Apply suggestions from code review

Co-Authored-By: tkf <29282+tkf@users.noreply.github.com>

* Update base/broadcast.jl

Co-Authored-By: tkf <29282+tkf@users.noreply.github.com>

* Fix IndexStyle for IndexLinear case

* Fix LinearIndices for Broadcasted

* Test that pairwise mapreduce is used

* Test count(::Broadcasted)

* Support Broadcasted in mapreducedim!

Co-authored-by: Matt Bauman <mbauman@gmail.com>
  • Loading branch information
tkf and mbauman authored Apr 30, 2020
1 parent ae2063f commit 2f90dde
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 33 deletions.
36 changes: 26 additions & 10 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
# methods that instead specialize on `BroadcastStyle`,
# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle})

struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple}
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted
f::F
args::Args
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`)
Expand All @@ -193,21 +193,25 @@ function Base.show(io::IO, bc::Broadcasted{Style}) where {Style}
end

## Allocating the output container
Base.similar(bc::Broadcasted{DefaultArrayStyle{N}}, ::Type{ElType}) where {N,ElType} =
similar(Array{ElType}, axes(bc))
Base.similar(bc::Broadcasted{DefaultArrayStyle{N}}, ::Type{Bool}) where N =
similar(BitArray, axes(bc))
Base.similar(bc::Broadcasted, ::Type{T}) where {T} = similar(bc, T, axes(bc))
Base.similar(::Broadcasted{DefaultArrayStyle{N}}, ::Type{ElType}, dims) where {N,ElType} =
similar(Array{ElType}, dims)
Base.similar(::Broadcasted{DefaultArrayStyle{N}}, ::Type{Bool}, dims) where N =
similar(BitArray, dims)
# In cases of conflict we fall back on Array
Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{ElType}) where ElType =
similar(Array{ElType}, axes(bc))
Base.similar(bc::Broadcasted{ArrayConflict}, ::Type{Bool}) =
similar(BitArray, axes(bc))
Base.similar(::Broadcasted{ArrayConflict}, ::Type{ElType}, dims) where ElType =
similar(Array{ElType}, dims)
Base.similar(::Broadcasted{ArrayConflict}, ::Type{Bool}, dims) =
similar(BitArray, dims)

@inline Base.axes(bc::Broadcasted) = _axes(bc, bc.axes)
_axes(::Broadcasted, axes::Tuple) = axes
@inline _axes(bc::Broadcasted, ::Nothing) = combine_axes(bc.args...)
_axes(bc::Broadcasted{<:AbstractArrayStyle{0}}, ::Nothing) = ()

@inline Base.axes(bc::Broadcasted{<:Any, <:NTuple{N}}, d::Integer) where N =
d <= N ? axes(bc)[d] : OneTo(1)

BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style()
BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} =
throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned"))
Expand All @@ -219,6 +223,12 @@ argtype(bc::Broadcasted) = argtype(typeof(bc))
_eachindex(t::Tuple{Any}) = t[1]
_eachindex(t::Tuple) = CartesianIndices(t)

Base.IndexStyle(bc::Broadcasted) = IndexStyle(typeof(bc))
Base.IndexStyle(::Type{<:Broadcasted{<:Any,<:Tuple{Any}}}) = IndexLinear()
Base.IndexStyle(::Type{<:Broadcasted{<:Any}}) = IndexCartesian()

Base.LinearIndices(bc::Broadcasted{<:Any,<:Tuple{Any}}) = axes(bc)[1]

Base.ndims(::Broadcasted{<:Any,<:NTuple{N,Any}}) where {N} = N
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N

Expand Down Expand Up @@ -564,7 +574,13 @@ end
@boundscheck checkbounds(bc, I)
@inbounds _broadcast_getindex(bc, I)
end
Base.@propagate_inbounds Base.getindex(bc::Broadcasted, i1::Integer, i2::Integer, I::Integer...) = bc[CartesianIndex((i1, i2, I...))]
Base.@propagate_inbounds Base.getindex(
bc::Broadcasted,
i1::Union{Integer,CartesianIndex},
i2::Union{Integer,CartesianIndex},
I::Union{Integer,CartesianIndex}...,
) =
bc[CartesianIndex((i1, i2, I...))]
Base.@propagate_inbounds Base.getindex(bc::Broadcasted) = bc[CartesianIndex(())]

@inline Base.checkbounds(bc::Broadcasted, I::Union{Integer,CartesianIndex}) =
Expand Down
20 changes: 12 additions & 8 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@ else
const SmallUnsigned = Union{UInt8,UInt16,UInt32}
end

abstract type AbstractBroadcasted end
const AbstractArrayOrBroadcasted = Union{AbstractArray, AbstractBroadcasted}

"""
Base.add_sum(x, y)
Expand Down Expand Up @@ -227,7 +230,8 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...)

# This is a generic implementation of `mapreduce_impl()`,
# certain `op` (e.g. `min` and `max`) may have their own specialized versions.
@noinline function mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer, blksize::Int)
@noinline function mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted,
ifirst::Integer, ilast::Integer, blksize::Int)
if ifirst == ilast
@inbounds a1 = A[ifirst]
return mapreduce_first(f, op, a1)
Expand All @@ -250,7 +254,7 @@ foldr(op, itr; kw...) = mapfoldr(identity, op, itr; kw...)
end
end

mapreduce_impl(f, op, A::AbstractArray, ifirst::Integer, ilast::Integer) =
mapreduce_impl(f, op, A::AbstractArrayOrBroadcasted, ifirst::Integer, ilast::Integer) =
mapreduce_impl(f, op, A, ifirst, ilast, pairwise_blocksize(f, op))

"""
Expand Down Expand Up @@ -383,13 +387,13 @@ The default is `reduce_first(op, f(x))`.
"""
mapreduce_first(f, op, x) = reduce_first(op, f(x))

_mapreduce(f, op, A::AbstractArray) = _mapreduce(f, op, IndexStyle(A), A)
_mapreduce(f, op, A::AbstractArrayOrBroadcasted) = _mapreduce(f, op, IndexStyle(A), A)

function _mapreduce(f, op, ::IndexLinear, A::AbstractArray{T}) where T
function _mapreduce(f, op, ::IndexLinear, A::AbstractArrayOrBroadcasted)
inds = LinearIndices(A)
n = length(inds)
if n == 0
return mapreduce_empty(f, op, T)
return mapreduce_empty_iter(f, op, A, IteratorEltype(A))
elseif n == 1
@inbounds a1 = A[first(inds)]
return mapreduce_first(f, op, a1)
Expand All @@ -410,7 +414,7 @@ end

mapreduce(f, op, a::Number) = mapreduce_first(f, op, a)

_mapreduce(f, op, ::IndexCartesian, A::AbstractArray) = mapfoldl(f, op, A)
_mapreduce(f, op, ::IndexCartesian, A::AbstractArrayOrBroadcasted) = mapfoldl(f, op, A)

"""
reduce(op, itr; [init])
Expand Down Expand Up @@ -560,7 +564,7 @@ isgoodzero(::typeof(max), x) = isbadzero(min, x)
isgoodzero(::typeof(min), x) = isbadzero(max, x)

function mapreduce_impl(f, op::Union{typeof(max), typeof(min)},
A::AbstractArray, first::Int, last::Int)
A::AbstractArrayOrBroadcasted, first::Int, last::Int)
a1 = @inbounds A[first]
v1 = mapreduce_first(f, op, a1)
v2 = v3 = v4 = v1
Expand Down Expand Up @@ -856,7 +860,7 @@ function count(pred, itr)
end
return n
end
function count(pred, a::AbstractArray)
function count(pred, a::AbstractArrayOrBroadcasted)
n = 0
for i in eachindex(a)
@inbounds n += pred(a[i])::Bool
Expand Down
37 changes: 22 additions & 15 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ No method is implemented for reducing index range of type $(typeof(i)). Please i
reduced_index for this index type or report this as an issue.
"""
))
reduced_indices(a::AbstractArray, region) = reduced_indices(axes(a), region)
reduced_indices(a::AbstractArrayOrBroadcasted, region) = reduced_indices(axes(a), region)

# for reductions that keep 0 dims as 0
reduced_indices0(a::AbstractArray, region) = reduced_indices0(axes(a), region)
Expand Down Expand Up @@ -89,8 +89,8 @@ for (Op, initval) in ((:(typeof(&)), true), (:(typeof(|)), false))
end

# reducedim_initarray is called by
reducedim_initarray(A::AbstractArray, region, init, ::Type{R}) where {R} = fill!(similar(A,R,reduced_indices(A,region)), init)
reducedim_initarray(A::AbstractArray, region, init::T) where {T} = reducedim_initarray(A, region, init, T)
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init, ::Type{R}) where {R} = fill!(similar(A,R,reduced_indices(A,region)), init)
reducedim_initarray(A::AbstractArrayOrBroadcasted, region, init::T) where {T} = reducedim_initarray(A, region, init, T)

# TODO: better way to handle reducedim initialization
#
Expand Down Expand Up @@ -156,8 +156,8 @@ end
reducedim_init(f::Union{typeof(abs),typeof(abs2)}, op::typeof(max), A::AbstractArray{T}, region) where {T} =
reducedim_initarray(A, region, zero(f(zero(T))), _realtype(f, T))

reducedim_init(f, op::typeof(&), A::AbstractArray, region) = reducedim_initarray(A, region, true)
reducedim_init(f, op::typeof(|), A::AbstractArray, region) = reducedim_initarray(A, region, false)
reducedim_init(f, op::typeof(&), A::AbstractArrayOrBroadcasted, region) = reducedim_initarray(A, region, true)
reducedim_init(f, op::typeof(|), A::AbstractArrayOrBroadcasted, region) = reducedim_initarray(A, region, false)

# specialize to make initialization more efficient for common cases

Expand All @@ -179,8 +179,11 @@ end

## generic (map)reduction

has_fast_linear_indexing(a::AbstractArray) = false
has_fast_linear_indexing(a::AbstractArrayOrBroadcasted) = false
has_fast_linear_indexing(a::Array) = true
has_fast_linear_indexing(::Number) = true # for Broadcasted
has_fast_linear_indexing(bc::Broadcast.Broadcasted) =
all(has_fast_linear_indexing, bc.args)

function check_reducedims(R, A)
# Check whether R has compatible dimensions w.r.t. A for reduction
Expand Down Expand Up @@ -233,7 +236,7 @@ _firstslice(i::OneTo) = OneTo(1)
_firstslice(i::Slice) = Slice(_firstslice(i.indices))
_firstslice(i) = i[firstindex(i):firstindex(i)]

function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArray)
function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArrayOrBroadcasted)
lsiz = check_reducedims(R,A)
isempty(A) && return R

Expand Down Expand Up @@ -271,10 +274,10 @@ function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArray)
return R
end

mapreducedim!(f, op, R::AbstractArray, A::AbstractArray) =
mapreducedim!(f, op, R::AbstractArray, A::AbstractArrayOrBroadcasted) =
(_mapreducedim!(f, op, R, A); R)

reducedim!(op, R::AbstractArray{RT}, A::AbstractArray) where {RT} =
reducedim!(op, R::AbstractArray{RT}, A::AbstractArrayOrBroadcasted) where {RT} =
mapreducedim!(identity, op, R, A)

"""
Expand Down Expand Up @@ -304,17 +307,21 @@ julia> mapreduce(isodd, |, a, dims=1)
1 1 1 1
```
"""
mapreduce(f, op, A::AbstractArray; dims=:, kw...) = _mapreduce_dim(f, op, kw.data, A, dims)
mapreduce(f, op, A::AbstractArray...; kw...) = reduce(op, map(f, A...); kw...)
mapreduce(f, op, A::AbstractArrayOrBroadcasted; dims=:, kw...) =
_mapreduce_dim(f, op, kw.data, A, dims)
mapreduce(f, op, A::AbstractArrayOrBroadcasted...; kw...) =
reduce(op, map(f, A...); kw...)

_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArray, ::Colon) = mapfoldl(f, op, A; nt...)
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArrayOrBroadcasted, ::Colon) =
mapfoldl(f, op, A; nt...)

_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArray, ::Colon) = _mapreduce(f, op, IndexStyle(A), A)
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArrayOrBroadcasted, ::Colon) =
_mapreduce(f, op, IndexStyle(A), A)

_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArray, dims) =
_mapreduce_dim(f, op, nt::NamedTuple{(:init,)}, A::AbstractArrayOrBroadcasted, dims) =
mapreducedim!(f, op, reducedim_initarray(A, dims, nt.init), A)

_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArray, dims) =
_mapreduce_dim(f, op, ::NamedTuple{()}, A::AbstractArrayOrBroadcasted, dims) =
mapreducedim!(f, op, reducedim_init(f, op, A, dims), A)

"""
Expand Down
56 changes: 56 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ end
# Broadcasted iterable/indexable APIs
let
bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5))
@test IndexStyle(bc) == IndexLinear()
@test eachindex(bc) === Base.OneTo(5)
@test length(bc) === 5
@test ndims(bc) === 1
Expand All @@ -831,6 +832,7 @@ let
@test ndims(copy(bc)) == ndims([v for v in bc]) == ndims(collect(bc)) == ndims(bc)

bc = Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5*ones(1, 4)))
@test IndexStyle(bc) == IndexCartesian()
@test eachindex(bc) === CartesianIndices((Base.OneTo(5), Base.OneTo(4)))
@test length(bc) === 20
@test ndims(bc) === 2
Expand All @@ -851,6 +853,60 @@ let a = rand(5), b = rand(5), c = copy(a)
@test x == [2]
end

@testset "broadcasted mapreduce" begin
xs = 1:10
ys = 1:2:20
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs, ys))
@test IndexStyle(bc) == IndexLinear()
@test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys))

xs2 = reshape(xs, 1, :)
ys2 = reshape(ys, 1, :)
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs2, ys2))
@test IndexStyle(bc) == IndexCartesian()
@test sum(bc) == mapreduce(Base.splat(*), +, zip(xs, ys))

xs = 1:5:3*5
ys = 1:4:3*4
bc = Broadcast.instantiate(
Broadcast.broadcasted(iseven, Broadcast.broadcasted(-, xs, ys)))
@test count(bc) == count(iseven, map(-, xs, ys))

xs = reshape(1:6, (2, 3))
ys = 1:2
bc = Broadcast.instantiate(Broadcast.broadcasted(*, xs, ys))
@test reduce(+, bc; dims=1, init=0) == [5 11 17]

# Let's test that `Broadcasted` actually hits the efficient
# `mapreduce` method as intended. We are going to invoke `reduce`
# with this *NON-ASSOCIATIVE* binary operator to see what
# associativity is chosen by the implementation:
paren = (x, y) -> "($x,$y)"
# Next, we construct data `xs` such that `length(xs)` is greater
# than short array cutoff of `_mapreduce`:
alphabets = 'a':'z'
blksize = Base.pairwise_blocksize(identity, paren) ÷ length(alphabets)
xs = repeat(alphabets, 2 * blksize)
@test length(xs) > blksize
# So far we constructed the data `xs` and reducing function
# `paren` such that `reduce` and `foldl` results are different.
# That is to say, this `reduce` does not hit the fall-back `foldl`
# branch:
@test foldl(paren, xs) != reduce(paren, xs)

# Now let's try it with `Broadcasted`:
bcraw = Broadcast.broadcasted(identity, xs)
bc = Broadcast.instantiate(bcraw)
# If `Broadcasted` has `IndexLinear` style, it should hit the
# `reduce` branch:
@test IndexStyle(bc) == IndexLinear()
@test reduce(paren, bc) == reduce(paren, xs)
# If `Broadcasted` does not have `IndexLinear` style, it should
# hit the `foldl` branch:
@test IndexStyle(bcraw) == IndexCartesian()
@test reduce(paren, bcraw) == foldl(paren, xs)
end

# treat Pair as scalar:
@test replace.(split("The quick brown fox jumps over the lazy dog"), r"[aeiou]"i => "_") ==
["Th_", "q__ck", "br_wn", "f_x", "j_mps", "_v_r", "th_", "l_zy", "d_g"]
Expand Down

0 comments on commit 2f90dde

Please sign in to comment.