Skip to content

Commit

Permalink
simplify Broadcast object computations (#49395)
Browse files Browse the repository at this point in the history
Code should normally preserve values, not the types of values. This
ensures the user can define styles with metadata, and requires less
type-parameter-based programming, but rather can focus on the values.
  • Loading branch information
vtjnash authored Apr 20, 2023
1 parent c237c0a commit 02b7b04
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 30 deletions.
70 changes: 41 additions & 29 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,28 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle})

struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted
style::Style
f::F
args::Args
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`)
end

Broadcasted(f::F, args::Args, axes=nothing) where {F, Args<:Tuple} =
Broadcasted{typeof(combine_styles(args...))}(f, args, axes)
function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Args<:Tuple}
# using Core.Typeof rather than F preserves inferrability when f is a type
Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes)
Broadcasted(style::Union{Nothing,BroadcastStyle}, f::Tuple, args::Tuple) = error() # disambiguation: tuple is not callable
function Broadcasted(style::Union{Nothing,BroadcastStyle}, f::F, args::Tuple, axes=nothing) where {F}
# using Core.Typeof rather than F preserves inferrability when f is a type
return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(style, f, args, axes)
end

function Broadcasted(f::F, args::Tuple, axes=nothing) where {F}
Broadcasted(combine_styles(args...)::BroadcastStyle, f, args, axes)
end

function Broadcasted{Style}(f::F, args, axes=nothing) where {Style, F}
return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(Style()::Style, f, args, axes)
end

function Broadcasted{Style,Axes,F,Args}(f, args, axes) where {Style,Axes,F,Args}
return new{Style, Axes, F, Args}(Style()::Style, f, args, axes)
end
end

struct AndAnd end
Expand All @@ -194,16 +206,16 @@ function broadcasted(::OrOr, a, bc::Broadcasted)
broadcasted((a, args...) -> a || bcf.f(args...), a, bcf.args...)
end

Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{<:Any,Axes,F,Args}) where {NewStyle,Axes,F,Args} =
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)::Broadcasted{NewStyle,Axes,F,Args}

function Base.show(io::IO, bc::Broadcasted{Style}) where {Style}
print(io, Broadcasted)
# Only show the style parameter if we have a set of axes — representing an instantiated
# "outermost" Broadcasted. The styles of nested Broadcasteds represent an intermediate
# computation that is not relevant for dispatch, confusing, and just extra line noise.
bc.axes isa Tuple && print(io, '{', Style, '}')
print(io, '(', bc.f, ", ", bc.args, ')')
bc.axes isa Tuple && print(io, "{", Style, "}")
print(io, "(", bc.f, ", ", bc.args, ")")
nothing
end

Expand Down Expand Up @@ -231,7 +243,7 @@ BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style()
BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} =
throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned"))

argtype(::Type{Broadcasted{Style,Axes,F,Args}}) where {Style,Axes,F,Args} = Args
argtype(::Type{BC}) where {BC<:Broadcasted} = fieldtype(BC, :args)
argtype(bc::Broadcasted) = argtype(typeof(bc))

@inline Base.eachindex(bc::Broadcasted) = _eachindex(axes(bc))
Expand Down Expand Up @@ -262,7 +274,7 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s)
end

Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2))
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, :args))
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N

_maxndims(T::Type{<:Tuple}) = reduce(max, (ntuple(n -> _ndims(fieldtype(T, n)), Base._counttuple(T))))
Expand All @@ -289,14 +301,14 @@ Custom [`BroadcastStyle`](@ref)s may override this default in cases where it is
to compute and verify the resulting `axes` on-demand, leaving the `axis` field
of the `Broadcasted` object empty (populated with [`nothing`](@ref)).
"""
@inline function instantiate(bc::Broadcasted{Style}) where {Style}
@inline function instantiate(bc::Broadcasted)
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
axes = combine_axes(bc.args...)
else
axes = bc.axes
check_broadcast_axes(axes, bc.args...)
end
return Broadcasted{Style}(bc.f, bc.args, axes)
return Broadcasted(bc.style, bc.f, bc.args, axes)
end
instantiate(bc::Broadcasted{<:AbstractArrayStyle{0}}) = bc
# Tuples don't need axes, but when they have axes (for .= assignment), we need to check them (#33020)
Expand Down Expand Up @@ -325,7 +337,7 @@ becomes
This is an optional operation that may make custom implementation of broadcasting easier in
some cases.
"""
function flatten(bc::Broadcasted{Style}) where {Style}
function flatten(bc::Broadcasted)
isflat(bc) && return bc
# concatenate the nested arguments into {a, b, c, d}
args = cat_nested(bc)
Expand All @@ -341,7 +353,7 @@ function flatten(bc::Broadcasted{Style}) where {Style}
newf = @inline function(args::Vararg{Any,N}) where N
f(makeargs(args...)...)
end
return Broadcasted{Style}(newf, args, bc.axes)
return Broadcasted(bc.style, newf, args, bc.axes)
end
end

Expand Down Expand Up @@ -895,11 +907,11 @@ materialize(x) = x
return materialize!(dest, instantiate(Broadcasted(identity, (x,), axes(dest))))
end

@inline function materialize!(dest, bc::Broadcasted{Style}) where {Style}
@inline function materialize!(dest, bc::Broadcasted{<:Any})
return materialize!(combine_styles(dest, bc), dest, bc)
end
@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{Style}) where {Style}
return copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{<:Any})
return copyto!(dest, instantiate(Broadcasted(bc.style, bc.f, bc.args, axes(dest))))
end

## general `copy` methods
Expand All @@ -909,7 +921,7 @@ copy(bc::Broadcasted{<:Union{Nothing,Unknown}}) =

const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}

@inline function copy(bc::Broadcasted{Style}) where {Style}
@inline function copy(bc::Broadcasted)
ElType = combine_eltypes(bc.f, bc.args)
if Base.isconcretetype(ElType)
# We can trust it and defer to the simpler `copyto!`
Expand Down Expand Up @@ -968,7 +980,7 @@ broadcast_unalias(::Nothing, src) = src
# Preprocessing a `Broadcasted` does two things:
# * unaliases any arguments from `dest`
# * "extrudes" the arguments where it is advantageous to pre-compute the broadcasted indices
@inline preprocess(dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(dest, bc.args), bc.axes)
@inline preprocess(dest, bc::Broadcasted) = Broadcasted(bc.style, bc.f, preprocess_args(dest, bc.args), bc.axes)
preprocess(dest, x) = extrude(broadcast_unalias(dest, x))

@inline preprocess_args(dest, args::Tuple) = (preprocess(dest, args[1]), preprocess_args(dest, tail(args))...)
Expand Down Expand Up @@ -1038,11 +1050,11 @@ ischunkedbroadcast(R, args::Tuple{<:BroadcastedChunkableOp,Vararg{Any}}) = ischu
ischunkedbroadcast(R, args::Tuple{}) = true

# Convert compatible functions to chunkable ones. They must also be green-lighted as ChunkableOps
liftfuncs(bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(sign)}) where {Style} = Broadcasted{Style}(identity, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(!)}) where {Style} = Broadcasted{Style}(~, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(*)}) where {Style} = Broadcasted{Style}(&, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(==)}) where {Style} = Broadcasted{Style}((~)(xor), map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,<:Any}) = Broadcasted(bc.style, bc.f, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(sign)}) = Broadcasted(bc.style, identity, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(!)}) = Broadcasted(bc.style, ~, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(*)}) = Broadcasted(bc.style, &, map(liftfuncs, bc.args), bc.axes)
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(==)}) = Broadcasted(bc.style, (~)(xor), map(liftfuncs, bc.args), bc.axes)
liftfuncs(x) = x

liftchunks(::Tuple{}) = ()
Expand Down Expand Up @@ -1315,26 +1327,26 @@ end
return broadcasted((args...) -> f(args...; kwargs...), args...)
end
end
@inline function broadcasted(f, args...)
@inline function broadcasted(f::F, args...) where {F}
args′ = map(broadcastable, args)
broadcasted(combine_styles(args′...), f, args′...)
end
# Due to the current Type{T}/DataType specialization heuristics within Tuples,
# the totally generic varargs broadcasted(f, args...) method above loses Type{T}s in
# mapping broadcastable across the args. These additional methods with explicit
# arguments ensure we preserve Type{T}s in the first or second argument position.
@inline function broadcasted(f, arg1, args...)
@inline function broadcasted(f::F, arg1, args...) where {F}
arg1′ = broadcastable(arg1)
args′ = map(broadcastable, args)
broadcasted(combine_styles(arg1′, args′...), f, arg1′, args′...)
end
@inline function broadcasted(f, arg1, arg2, args...)
@inline function broadcasted(f::F, arg1, arg2, args...) where {F}
arg1′ = broadcastable(arg1)
arg2′ = broadcastable(arg2)
args′ = map(broadcastable, args)
broadcasted(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...)
end
@inline broadcasted(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args)
@inline broadcasted(style::BroadcastStyle, f::F, args...) where {F} = Broadcasted(style, f, args)

"""
BroadcastFunction{F} <: Function
Expand Down
2 changes: 1 addition & 1 deletion test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ let
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{Broadcast.ArrayConflict}
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

@test @inferred(Base.IteratorSize(Broadcast.broadcasted((1,2,3),a1,zeros(3,3,3)))) === Base.HasShape{3}()
@test @inferred(Base.IteratorSize(Broadcast.broadcasted(+, (1,2,3), a1, zeros(3,3,3)))) === Base.HasShape{3}()

# inference on nested
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))
Expand Down

0 comments on commit 02b7b04

Please sign in to comment.