Commit 6b1b80f

unfused reverse broadcast

MikeInnes committed Mar 29, 2019
1 parent 7d930ab commit 6b1b80f

Showing 5 changed files with 53 additions and 188 deletions.
15 changes: 10 additions & 5 deletions src/compiler/emit.jl
@@ -64,7 +64,7 @@ function forward_stacks!(adj, F)
isconcretetype(T) || (T = Any)
rec = insert_node!(adj.forw, length(adj.forw.stmts), T,
xtuple(recs...))
-    if usetyped
+    if usetyped && length(adj.perm) > 1
rec = insert_node!(adj.forw, length(adj.forw.stmts), Pullback{F,T},
Expr(:call, Pullback{F,T}, rec))
else
@@ -119,13 +119,18 @@ end

varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing

+function getmeta(T)
+  m = meta(T)
+  (usetyped && m != nothing) || return m
+  any(x -> isexpr(x, :goto, :gotoifnot), m.code.code) || return IRTools.meta(T)
+  return m
+end

function _lookup_grad(T)
-  (m = meta(T)) == nothing && return
-  usetyped && m.ret == Union{} && return
+  (m = getmeta(T)) == nothing && return
+  m isa IRTools.TypedMeta && m.ret == Union{} && return
va = varargs(m.method, length(T.parameters))
forw, back = stacks!(Adjoint(IRCode(m), varargs = va), T)
+  # verify_ir(forw)
+  # verify_ir(back)
m, forw, back
end

210 changes: 33 additions & 177 deletions src/lib/broadcast.jl
@@ -13,8 +13,7 @@
# `--' `" `--' `" `'-'

using Base.Broadcast
-using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize,
-  instantiate, flatten, combine_eltypes
+using Base.Broadcast: AbstractArrayStyle, broadcasted, materialize

# There's a saying that debugging code is about twice as hard as writing it in
# the first place. So if you're as clever as you can be when writing code, how
@@ -28,218 +27,75 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize,
# Base's broadcasting is very cleverly written, and this makes differentiating
# it... somewhat tricky.

-# Structural utilities
-# ====================

-using Base: tail

-tcat(x) = x
-tcat(x, y, z...) = tcat((x..., y...), z...)

-broadcast_args(x) = (x,)
-broadcast_args(bc::Broadcasted) = tcat(map(broadcast_args, bc.args)...)

-_unflatten(x, xs) = first(xs), tail(xs)

-_unflatten(x::Tuple{}, xs) = (), xs

-function _unflatten(x::Tuple, xs)
-  t1, xs1 = _unflatten(first(x), xs)
-  t2, xs2 = _unflatten(tail(x), xs1)
-  (t1, t2...), xs2
-end

-function _unflatten(bc::Broadcasted, xs)
-  t, xs′ = _unflatten(bc.args, xs)
-  (args=t,f=nothing,axes=nothing), xs′
-end

-unflatten(x, xs) = _unflatten(x, xs)[1]

-unflatten(x, xs::Nothing) = nothing
+# Utilities
+# =========

accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims)

+# Work around reducedim_init issue
+accum_sum(xs::AbstractArray{Nothing}; dims = :) = nothing
accum_sum(xs::AbstractArray{<:Real}; dims = :) = sum(xs, dims = dims)
accum_sum(xs::Real; dims = :) = xs

trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

-unbroadcast(x::AbstractArray, Δ) =
-  size(x) == size(Δ) ? Δ :
-  length(x) == length(Δ) ? trim(x, Δ) :
-  trim(x, accum_sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))

-unbroadcast(x::Union{Number,Ref}, Δ) = accum_sum(Δ)

-# Trivial Mode
-# ============
+unbroadcast(x::AbstractArray, x̄) =
+  size(x) == size(x̄) ? x̄ :
+  length(x) == length(x̄) ? trim(x, x̄) :
+  trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))

-# In some cases, such as `exp.(a .+ b)`, we can see that the gradient only depends
-# on the output. Handling these specially is great for performance and memory
-# usage, though of course relatively limited. It happens that the set of cases
-# lines up nicely with activation functions commonly used in neural nets, though.
+unbroadcast(x::Union{Number,Ref}, x̄) = accum_sum(x̄)
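A quick sketch of what `unbroadcast` computes (illustrative values, not part of this commit): the gradient arrives in the broadcasted shape, and summing over the dimensions where the argument had size 1 restores the argument's shape.

x = ones(3, 1)      # argument that was broadcast against a 3×4 array
x̄ = ones(3, 4)      # incoming gradient, in the broadcasted shape
sum(x̄, dims = 2)    # 3×1 array of 4.0s, the reduction unbroadcast performs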

-# TODO fix this up and use it
+# Split Reverse Mode
+# ==================

-Jtrivial(f, a...) = nothing
-Jtrivial(::typeof(+), a...) = a
-Jtrivial(::typeof(-), a, b) = (a..., .-b...)

-trivia(_) = (1,)
-function trivia(bc::Broadcasted)
-  t = map(trivia, bc.args)
-  any(t -> t === nothing, t) && return
-  Jtrivial(bc.f, t...)
-end

-Joutput(f, a...) = nothing
-Joutput(::typeof(exp), x) = map(t -> y -> y*t, x)

-function Jbroadcast(bc::Broadcasted)
-  t = map(trivia, bc.args)
-  any(t -> t === nothing, t) && return
-  Joutput(bc.f, t...)
-end

-@inline function unbroadcast_t(x, y, ȳ, j::J) where J
-  trim(x, j.(y).*ȳ)
-end

-@inline function unbroadcast_t(x::Number, y, ȳ, j::J) where J
-  x̄ = zero(float(x))
-  @simd for I in eachindex(y)
-    @inbounds x̄ += j(y[I])*ȳ[I]
-  end
-  return x̄
-end

-function ∇broadcast_t(bc::Broadcasted, J)
-  y = copy(instantiate(bc))
-  back(ȳ) = map(unbroadcast_t, broadcast_args(bc), map(_ -> y, J), map(_ -> ȳ, J), J)
-  return y, back
-end
+# TODO: use DiffRules here. It's complicated a little by the fact that we need
+# to do CSE, then broadcast-ify the expression so that the closure captures the
+# right arrays.

-# Forward Mode
-# ============
+@adjoint broadcasted(::typeof(+), xs...) =
+  broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)

-# Forward mode has many limitations – mainly in that it only supports reals /
-# arrays of reals and small numbers of inputs – but in those cases it works very
-# generally across broadcasted functions, and handles loops particularly well.
-# Most importantly it's easy on the compiler, so until we figure out reverse
-# mode we're maintaining this implementation for common cases.
+@adjoint broadcasted(::typeof(*), x, y) = x.*y,
+  z̄ -> (nothing, unbroadcast(x, z̄ .* y), unbroadcast(y, z̄ .* x))

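A usage sketch for the `.*` rule above (hypothetical REPL session; the values follow from the pullback definition):

using Zygote
x, y = [1.0, 2.0, 3.0], 2.0
dx, dy = gradient((x, y) -> sum(x .* y), x, y)
# dx == [2.0, 2.0, 2.0]   (z̄ .* y, already in x's shape)
# dy == 6.0               (z̄ .* x, reduced to a scalar by unbroadcast)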
-import ForwardDiff
-using ForwardDiff: Dual

-dualtype(::Type{Dual{G,T,P}}) where {G,T,P} = T
-dualtype(T) = T

-function dual_function(f::F) where F
-  function (args::Vararg{Any,N}) where N
-    ds = map(args, ntuple(identity,Val(N))) do x, i
-      Dual(x, ntuple(j -> i==j, Val(N)))
-    end
-    return f(ds...)
-  end
-end

-dualify(bc::Broadcasted{S}) where S = Broadcasted{S}(dual_function(bc.f), bc.args, bc.axes)
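For context on the forward-mode path deleted above, a minimal sketch of the `Dual` trick it relied on (illustration only, not part of the diff): seeding an argument with a unit partial makes one evaluation return both the value and the derivative.

using ForwardDiff: Dual
d = sin(Dual(0.5, (1.0,)))   # one-hot seed, as in dual_function
d.value                      # ≈ 0.4794, i.e. sin(0.5)
d.partials.values[1]         # ≈ 0.8776, i.e. cos(0.5)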

-@inline function broadcast_gradient!(bc::Broadcasted, dest::AbstractArray, grads...)
-  @simd for I in eachindex(bc)
-    @inbounds begin
-      out = bc[I]
-      dest[I] = ForwardDiff.value(out)
-      Δs = out isa Dual ? out.partials.values : map(_ -> false, grads)
-      map((g, p) -> g[I] = p, grads, Δs)
-    end
-  end
-end

-function broadcast_gradient(bc::Broadcasted, ::Type{T}) where T
-  dest = similar(bc, T)
-  grads = map(_ -> similar(bc, promote_type(T,Bool)), bc.args)
-  broadcast_gradient!(bc, dest, grads...)
-  return dest, grads
-end

-@inline function ∇broadcast_f(bc′::Broadcasted)
-  bc = dualify(instantiate(flatten(bc′)))
-  T = combine_eltypes(bc.f, bc.args)
-  T <: Bool && return copy(bc′), _ -> nothing
-  y, gs = broadcast_gradient(bc, dualtype(T))
-  back(Δ) = (unflatten(bc′, map((x, d) -> unbroadcast(x, Δ.*d), bc.args, gs)),)
-  return y, back
-end

-function ∇broadcast_f(bc::Broadcasted{<:AbstractArrayStyle{0}})
-  out = dualify(instantiate(flatten(bc)))[]
-  return out.value, Δ -> (unflatten(bc, map(x -> x*Δ, out.partials.values)),)
-end

-# Compatibility test

-isrealinput(x) = x isa Union{Real,AbstractArray{<:Real}}
-isrealinput(bc::Broadcasted) = all(isrealinput, bc.args)

-# Reverse Mode
-# ============
+# General Fallback
+# ================

# The fused reverse mode implementation is the most general but currently has
# poor performance. It works by flattening the broadcast and mapping the call to
# `_forward` over the input.

# However, the core call
-#   broadcast(_forward, (cx,), bc′.f, bc′.args...)
+#   broadcast(_forward, (cx,), f, args...)
# is already 10x slower than a simple broadcast (presumably due to inlining
# issues, or something similar) and the other operations needed take it to about
# 100x overhead.

-# One thing to experiment with would be a non-fused reverse mode, which is the
-# more typical option for this kind of AD. While less efficient than fusing
-# in principle, it's likely that this can be much easier on the compiler.

@generated inclen(::NTuple{N,Any}) where N = Val(N+1)

-function ∇broadcast_r(cx, bc::Broadcasted)
-  bc′, unflatten = _forward(cx, Broadcast.flatten, bc)
-  len = inclen(bc′.args)
-  y∂b = broadcast(_forward, (cx,), bc′.f, bc′.args...)
+# Avoid hitting special cases for `Adjoint` etc.
+_broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))

+@adjoint function broadcasted(::AbstractArrayStyle, f, args...)
+  len = inclen(args)
+  y∂b = _broadcast((x...) -> _forward(__context__, f, x...), args...)
y = map(x -> x[1], y∂b)
∂b = map(x -> x[2], y∂b)
y, function (ȳ)
dxs_zip = map((∂b, ȳ) -> ∂b(ȳ), ∂b, ȳ)
dxs = ntuple(i -> map(x -> x[i], dxs_zip), len)
-    (f = accum_sum(dxs[1]),
-     args = map(unbroadcast, bc′.args, Base.tail(dxs)),
-     axes = nothing) |> unflatten |> Base.tail
+    (nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
end
end
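This adjoint is the new unfused fallback: broadcasting `_forward` yields an array of (value, pullback) pairs, which the closure then splits apart elementwise. A self-contained sketch of the same pattern, with a hand-written pullback standing in for `_forward` (assumed names, not Zygote API):

fwd(x) = (sin(x), ȳ -> ȳ * cos(x))       # value plus pullback for sin
y∂b = map(fwd, [0.0, π/2, π])            # array of (value, pullback) pairs
y  = map(x -> x[1], y∂b)                 # forward result, sin.(xs)
∂b = map(x -> x[2], y∂b)                 # per-element pullbacks
x̄  = map((b, ȳ) -> b(ȳ), ∂b, ones(3))   # gradient of sum(sin.(xs)), equals cos.(xs)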

-function ∇broadcast_r(bc::Broadcasted{<:AbstractArrayStyle{0}})
-  bc′, unflatten = _forward(Broadcast.flatten, bc)
-  len = Val(length(bc′.args)+1)
-  y, ∂b = broadcast(_forward, bc′.f, bc′.args...)
+@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...)
+  len = inclen(args)
+  y, ∂b = _broadcast((x...) -> _forward(__context__, f, x...), args...)
y, function (ȳ)
dxs = ∂b(ȳ)
-    (f = dxs[1],
-     args = Base.tail(dxs),
-     axes = nothing) |> unflatten |> Base.tail
+    (nothing, dxs...)
end
end

-∇broadcast(cx, bc::Broadcasted, J) = ∇broadcast_t(bc, J)

-∇broadcast(cx, bc::Broadcasted, ::Nothing) =
-  isrealinput(bc) ? ∇broadcast_f(bc) : ∇broadcast_r(cx, bc)

-∇broadcast(cx, bc::Broadcasted) = ∇broadcast(cx, bc, Jbroadcast(bc))

-@adjoint function broadcasted(f, args...)
-  broadcasted(f, args...), Δ -> (nothing, Δ.args...)
-end

-@adjoint materialize(bc::Broadcasted{<:AbstractArrayStyle}) =
-  ∇broadcast(__context__, bc, nothing)
+@adjoint! (b::typeof(broadcast))(f, args...) = _forward(__context__, broadcasted, f, args...)
8 changes: 6 additions & 2 deletions src/lib/grad.jl
@@ -4,6 +4,7 @@ using MacroTools: combinedef
named(arg) = isexpr(arg, :(::)) && length(arg.args) == 1 ? :($(gensym())::$(arg.args[1])) : arg

typeless(x) = MacroTools.prewalk(x -> isexpr(x, :(::)) ? x.args[1] : x, x)
+isvararg(x) = isexpr(x, :(::)) && namify(x.args[2]) == :Vararg

for n = 0:3
gradtuple = Symbol(:gradtuple, n)
@@ -26,8 +27,11 @@ function gradm(ex, mut = false)
(esc(gensym()), :(Core.Typeof($(esc(name)))))
kT = :(Core.kwftype($T))
Ts == nothing && (Ts = [])
-  args = esc.(named.(args))
-  argnames = typeless.(args)
+  args = named.(args)
+  argnames = Any[typeless(arg) for arg in args]
+  isvararg(args[end]) && (argnames[end] = :($(argnames[end])...))
+  args = esc.(args)
+  argnames = esc.(argnames)
Ts = esc.(Ts)
cx = :($(esc(:__context__))::Context)
fargs = kw == nothing ? [cx, :($f::$T), args...] : [kw, cx, :($f::$T), args...]
2 changes: 2 additions & 0 deletions src/lib/lib.jl
@@ -60,6 +60,8 @@ end

# Tuples

+using Base: tail

@adjoint tuple(xs...) = xs, identity

@adjoint getindex(xs::NTuple{N,Any}, i::Integer) where N =
6 changes: 2 additions & 4 deletions test/gradcheck.jl
@@ -1,8 +1,6 @@
-using Zygote, NNlib, Test, Random, LinearAlgebra
+using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics
using Zygote: gradient
using NNlib: conv
-using Statistics
-import Random

function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
@@ -355,6 +353,6 @@ end

@testset "broadcast" begin
if !Zygote.usetyped
-    @test_broken gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1
+    @test gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1
end
end
