diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl index 6d03c82e2..94c7a7fa3 100644 --- a/src/compiler/emit.jl +++ b/src/compiler/emit.jl @@ -64,7 +64,7 @@ function forward_stacks!(adj, F) isconcretetype(T) || (T = Any) rec = insert_node!(adj.forw, length(adj.forw.stmts), T, xtuple(recs...)) - if usetyped + if usetyped && length(adj.perm) > 1 rec = insert_node!(adj.forw, length(adj.forw.stmts), Pullback{F,T}, Expr(:call, Pullback{F,T}, rec)) else @@ -119,13 +119,18 @@ end varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing +function getmeta(T) + m = meta(T) + (usetyped && m != nothing) || return m + any(x -> isexpr(x, :goto, :gotoifnot), m.code.code) || return IRTools.meta(T) + return m +end + function _lookup_grad(T) - (m = meta(T)) == nothing && return - usetyped && m.ret == Union{} && return + (m = getmeta(T)) == nothing && return + m isa IRTools.TypedMeta && m.ret == Union{} && return va = varargs(m.method, length(T.parameters)) forw, back = stacks!(Adjoint(IRCode(m), varargs = va), T) - # verify_ir(forw) - # verify_ir(back) m, forw, back end diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index ad1f876c9..b627c0059 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -13,8 +13,7 @@ # `--' `" `--' `" `'-' using Base.Broadcast -using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize, - instantiate, flatten, combine_eltypes +using Base.Broadcast: AbstractArrayStyle, broadcasted, materialize # There's a saying that debugging code is about twice as hard as writing it in # the first place. So if you're as clever as you can be when writing code, how @@ -28,218 +27,75 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize, # Base's broadcasting is very cleverly written, and this makes differentiating # it... somewhat tricky. -# Structural utilities -# ==================== - -using Base: tail - -tcat(x) = x -tcat(x, y, z...) = tcat((x..., y...), z...) - -broadcast_args(x) = (x,) -broadcast_args(bc::Broadcasted) = tcat(map(broadcast_args, bc.args)...) - -_unflatten(x, xs) = first(xs), tail(xs) - -_unflatten(x::Tuple{}, xs) = (), xs - -function _unflatten(x::Tuple, xs) - t1, xs1 = _unflatten(first(x), xs) - t2, xs2 = _unflatten(tail(x), xs1) - (t1, t2...), xs2 -end - -function _unflatten(bc::Broadcasted, xs) - t, xs′ = _unflatten(bc.args, xs) - (args=t,f=nothing,axes=nothing), xs′ -end - -unflatten(x, xs) = _unflatten(x, xs)[1] - -unflatten(x, xs::Nothing) = nothing +# Utilities +# ========= accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims) # Work around reducedim_init issue accum_sum(xs::AbstractArray{Nothing}; dims = :) = nothing accum_sum(xs::AbstractArray{<:Real}; dims = :) = sum(xs, dims = dims) +accum_sum(xs::Real; dims = :) = xs trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - length(x) == length(Δ) ? trim(x, Δ) : - trim(x, accum_sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) - -unbroadcast(x::Union{Number,Ref}, Δ) = accum_sum(Δ) - -# Trivial Mode -# ============ +unbroadcast(x::AbstractArray, x̄) = + size(x) == size(x̄) ? x̄ : + length(x) == length(x̄) ? trim(x, x̄) : + trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄))))) -# In some cases, such as `exp.(a .+ b)`, we can see the the gradient only depends -# on the output. Handling these specially is great for performance and memory -# usage, though of course relatively limited. It happens that the set of cases -# lines up nicely with activation functions commonly used in neural nets, though. +unbroadcast(x::Union{Number,Ref}, x̄) = accum_sum(x̄) -# TODO fix this up and use it +# Split Reverse Mode +# ================== -Jtrivial(f, a...) = nothing -Jtrivial(::typeof(+), a...) = a -Jtrivial(::typeof(-), a, b) = (a..., .-b...) - -trivia(_) = (1,) -function trivia(bc::Broadcasted) - t = map(trivia, bc.args) - any(t -> t === nothing, t) && return - Jtrivial(bc.f, t...) -end - -Joutput(f, a...) = nothing -Joutput(::typeof(exp), x) = map(t -> y -> y*t, x) - -function Jbroadcast(bc::Broadcasted) - t = map(trivia, bc.args) - any(t -> t === nothing, t) && return - Joutput(bc.f, t...) -end - -@inline function unbroadcast_t(x, y, ȳ, j::J) where J - trim(x, j.(y).*ȳ) -end - -@inline function unbroadcast_t(x::Number, y, ȳ, j::J) where J - x̄ = zero(float(x)) - @simd for I in eachindex(y) - @inbounds x̄ += j(y[I])*ȳ[I] - end - return x̄ -end - -function ∇broadcast_t(bc::Broadcasted, J) - y = copy(instantiate(bc)) - back(ȳ) = map(unbroadcast_t, broadcast_args(bc), map(_ -> y, J), map(_ -> ȳ, J), J) - return y, back -end +# TODO: use DiffRules here. It's complicated a little by the fact that we need +# to do CSE, then broadcast-ify the expression so that the closure captures the +# right arrays. -# Forward Mode -# ============ +@adjoint broadcasted(::typeof(+), xs...) = + broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...) -# Forward mode has many limitations – mainly in that it only supports reals / -# arrays of reals and small numbers of inputs – but in those cases it works very -# generally across broadcasted functions, and handles loops particularly well. -# Most importantly it's easy on the compiler, so until we figure out reverse -# mode we're maintaining this implementation for common cases. +@adjoint broadcasted(::typeof(*), x, y) = x.*y, + z̄ -> (nothing, unbroadcast(x, z̄ .* y), unbroadcast(y, z̄ .* x)) -import ForwardDiff -using ForwardDiff: Dual - -dualtype(::Type{Dual{G,T,P}}) where {G,T,P} = T -dualtype(T) = T - -function dual_function(f::F) where F - function (args::Vararg{Any,N}) where N - ds = map(args, ntuple(identity,Val(N))) do x, i - Dual(x, ntuple(j -> i==j, Val(N))) - end - return f(ds...) - end -end - -dualify(bc::Broadcasted{S}) where S = Broadcasted{S}(dual_function(bc.f), bc.args, bc.axes) - -@inline function broadcast_gradient!(bc::Broadcasted, dest::AbstractArray, grads...) - @simd for I in eachindex(bc) - @inbounds begin - out = bc[I] - dest[I] = ForwardDiff.value(out) - Δs = out isa Dual ? out.partials.values : map(_ -> false, grads) - map((g, p) -> g[I] = p, grads, Δs) - end - end -end - -function broadcast_gradient(bc::Broadcasted, ::Type{T}) where T - dest = similar(bc, T) - grads = map(_ -> similar(bc, promote_type(T,Bool)), bc.args) - broadcast_gradient!(bc, dest, grads...) - return dest, grads -end - -@inline function ∇broadcast_f(bc′::Broadcasted) - bc = dualify(instantiate(flatten(bc′))) - T = combine_eltypes(bc.f, bc.args) - T <: Bool && return copy(bc′), _ -> nothing - y, gs = broadcast_gradient(bc, dualtype(T)) - back(Δ) = (unflatten(bc′, map((x, d) -> unbroadcast(x, Δ.*d), bc.args, gs)),) - return y, back -end - -function ∇broadcast_f(bc::Broadcasted{<:AbstractArrayStyle{0}}) - out = dualify(instantiate(flatten(bc)))[] - return out.value, Δ -> (unflatten(bc, map(x -> x*Δ, out.partials.values)),) -end - -# Compatibility test - -isrealinput(x) = x isa Union{Real,AbstractArray{<:Real}} -isrealinput(bc::Broadcasted) = all(isrealinput, bc.args) - -# Reverse Mode -# ============ +# General Fallback +# ================ # The fused reverse mode implementation is the most general but currently has # poor performance. It works by flattening the broadcast and mapping the call to # `_forward` over the input. # However, the core call -# broadcast(_forward, (cx,), bc′.f, bc′.args...) +# broadcast(_forward, (cx,), f, args...) # is already 10x slower than a simple broadcast (presumably due to inlining # issues, or something similar) and the other operations needed take it to about # 100x overhead. -# One thing to experiment with would be a non-fused reverse mode, which is the -# more typical option for this kind of AD. While less efficient than fusing -# in principle, it's likely that this can be much easier on the compiler. - @generated inclen(::NTuple{N,Any}) where N = Val(N+1) -function ∇broadcast_r(cx, bc::Broadcasted) - bc′, unflatten = _forward(cx, Broadcast.flatten, bc) - len = inclen(bc′.args) - y∂b = broadcast(_forward, (cx,), bc′.f, bc′.args...) +# Avoid hitting special cases for `Adjoint` etc. +_broadcast(f::F, x...) where F = materialize(broadcasted(f, x...)) + +@adjoint function broadcasted(::AbstractArrayStyle, f, args...) + len = inclen(args) + y∂b = _broadcast((x...) -> _forward(__context__, f, x...), args...) y = map(x -> x[1], y∂b) ∂b = map(x -> x[2], y∂b) y, function (ȳ) dxs_zip = map((∂b, ȳ) -> ∂b(ȳ), ∂b, ȳ) dxs = ntuple(i -> map(x -> x[i], dxs_zip), len) - (f = accum_sum(dxs[1]), - args = map(unbroadcast, bc′.args, Base.tail(dxs)), - axes = nothing) |> unflatten |> Base.tail + (nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...) end end -function ∇broadcast_r(bc::Broadcasted{<:AbstractArrayStyle{0}}) - bc′, unflatten = _forward(Broadcast.flatten, bc) - len = Val(length(bc′.args)+1) - y, ∂b = broadcast(_forward, bc′.f, bc′.args...) +@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...) + len = inclen(args) + y, ∂b = _broadcast((x...) -> _forward(__context__, f, x...), args...) y, function (ȳ) dxs = ∂b(ȳ) - (f = dxs[1], - args = Base.tail(dxs), - axes = nothing) |> unflatten |> Base.tail + (nothing, dxs...) end end -∇broadcast(cx, bc::Broadcasted, J) = ∇broadcast_t(bc, J) - -∇broadcast(cx, bc::Broadcasted, ::Nothing) = - isrealinput(bc) ? ∇broadcast_f(bc) : ∇broadcast_r(cx, bc) - -∇broadcast(cx, bc::Broadcasted) = ∇broadcast(cx, bc, Jbroadcast(bc)) - -@adjoint function broadcasted(f, args...) - broadcasted(f, args...), Δ -> (nothing, Δ.args...) -end - -@adjoint materialize(bc::Broadcasted{<:AbstractArrayStyle}) = - ∇broadcast(__context__, bc, nothing) +@adjoint! (b::typeof(broadcast))(f, args...) = _forward(__context__, broadcasted, f, args...) diff --git a/src/lib/grad.jl b/src/lib/grad.jl index 6cadd4b34..5dbda6265 100644 --- a/src/lib/grad.jl +++ b/src/lib/grad.jl @@ -4,6 +4,7 @@ using MacroTools: combinedef named(arg) = isexpr(arg, :(::)) && length(arg.args) == 1 ? :($(gensym())::$(arg.args[1])) : arg typeless(x) = MacroTools.prewalk(x -> isexpr(x, :(::)) ? x.args[1] : x, x) +isvararg(x) = isexpr(x, :(::)) && namify(x.args[2]) == :Vararg for n = 0:3 gradtuple = Symbol(:gradtuple, n) @@ -26,8 +27,11 @@ function gradm(ex, mut = false) (esc(gensym()), :(Core.Typeof($(esc(name))))) kT = :(Core.kwftype($T)) Ts == nothing && (Ts = []) - args = esc.(named.(args)) - argnames = typeless.(args) + args = named.(args) + argnames = Any[typeless(arg) for arg in args] + isvararg(args[end]) && (argnames[end] = :($(argnames[end])...)) + args = esc.(args) + argnames = esc.(argnames) Ts = esc.(Ts) cx = :($(esc(:__context__))::Context) fargs = kw == nothing ? [cx, :($f::$T), args...] : [kw, cx, :($f::$T), args...] diff --git a/src/lib/lib.jl b/src/lib/lib.jl index 674376bf6..6b335bb68 100644 --- a/src/lib/lib.jl +++ b/src/lib/lib.jl @@ -60,6 +60,8 @@ end # Tuples +using Base: tail + @adjoint tuple(xs...) = xs, identity @adjoint getindex(xs::NTuple{N,Any}, i::Integer) where N = diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 5230d4ff5..ebf4107f6 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1,8 +1,6 @@ -using Zygote, NNlib, Test, Random, LinearAlgebra +using Zygote, NNlib, Test, Random, LinearAlgebra, Statistics using Zygote: gradient using NNlib: conv -using Statistics -import Random function ngradient(f, xs::AbstractArray...) grads = zero.(xs) @@ -355,6 +353,6 @@ end @testset "broadcast" begin if !Zygote.usetyped - @test_broken gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1 + @test gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1 end end