Skip to content

Commit

Permalink
Remove special handling of multiple inputs to pullbacks from ChainRul…
Browse files Browse the repository at this point in the history
…es (#1)

* ChainRules pullbacks always have 1 input JuliaDiff/ChainRulesCore.jl#152

* swap to a version of ChainRules that doesn't use multi-arg pullbacks

* update tests

* make it so we don't need the custom rule anymore

* add comment

* Update src/compiler/chainrules.jl

Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>

Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
  • Loading branch information
oxinabox and willtebbutt authored May 20, 2020
1 parent 378f767 commit bf913a2
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 41 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5"
ArrayLayouts = "0.1, 0.2"
ChainRules = "0.5.1"
ChainRules = "0.6.0"
FillArrays = "0.8"
ForwardDiff = "0"
IRTools = "=0.3.1"
Expand Down
39 changes: 9 additions & 30 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,16 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote u
@inline function wrap_chainrules_output(x)
    # Thunks are not handled lazily for now: force the value, then conjugate it.
    return conj(unthunk(x))
end
@inline function wrap_chainrules_output(x::Tuple)
    # Tuples are converted element by element.
    return map(wrap_chainrules_output, x)
end
@inline function wrap_chainrules_output(::ChainRules.AbstractZero)
    # ChainRules zero differentials become `nothing`.
    return nothing
end
@inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T}
T_outer = T <: NamedTuple ? NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types
xp = map(wrap_chainrules_output, x)
convert(T_outer, xp)
# We define one method per backing type (`Tuple` and `NamedTuple`) rather than a single
# method using a `Union` + an `if`, so that we avoid a branch that changes output type:
# nested AD on that kind of thing makes Zygote less than happy.
@inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:Tuple}
    converted = map(wrap_chainrules_output, x)
    return convert(Tuple, converted)
end
@inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:NamedTuple}
    converted = map(wrap_chainrules_output, x)
    return convert(NamedTuple, converted)
end


"""
wrap_chainrules_input(x)
Expand All @@ -69,24 +72,6 @@ to differentials types ChainRules uses.
ChainRules.Composite{Any, typeof(xp)}(xp)
end

"""
wrap_chainrules_pullback(pb, args...)
Wrap a chainrule's pullback `pb`, converting the format of the inputs (`args`),
and the outputs.
"""
@inline function wrap_chainrules_pullback(pb, args...)
    # Convert all inputs together as one tuple, splat the converted values into
    # the ChainRules pullback, then convert its result back.
    chainrules_args = wrap_chainrules_input(args)
    return wrap_chainrules_output(pb(chainrules_args...))
end

# Note we hand-express the single-arg version of this to remove splatting,
# because splatting breaks constant folding.
# This can be removed after https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152
@inline function wrap_chainrules_pullback(pb, a)
    # Single-input path: no tuple packing or splatting is involved.
    chainrules_arg = wrap_chainrules_input(a)
    return wrap_chainrules_output(pb(chainrules_arg))
end


"""
ZBack{F}(back) <: Function
Expand All @@ -96,10 +81,7 @@ Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conven
struct ZBack{F} <: Function
# The wrapped ChainRules pullback; its type is the parameter `F`, so the field is concretely typed.
back::F
end
@inline (s::ZBack)(dy) = wrap_chainrules_pullback(s.back, dy)
# Dispatch here handles ChainRules considering pullbacks to have multiple inputs if given a Tuple.
# TODO: this could be removed if: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152
@inline (s::ZBack)(dy::Tuple) = wrap_chainrules_pullback(s.back, dy...)
@inline function (s::ZBack)(dy)
    # Convert the incoming value for ChainRules, call the wrapped pullback with that
    # single input, then convert the result back.
    chainrules_dy = wrap_chainrules_input(dy)
    return wrap_chainrules_output(s.back(chainrules_dy))
end
# The `nothing -> nothing` method can be deleted after
# https://github.com/FluxML/Zygote.jl/issues/603, though it might be worth keeping as a
# performance optimization (benchmarking pending).
@inline function (s::ZBack)(::Nothing)
    return nothing
end
Expand Down Expand Up @@ -127,6 +109,3 @@ As per [`chain_rrule`](@ref) but with support for kwargs.
kw_zpullback(dy) = (nothing, nothing, ZBack(back)(dy)...) # first two nothings are for kwfunc and kwargs
return y, kw_zpullback
end

# Required for nested AD.
# The adjoint of constructing a `Composite{Any, T}` from `x`: the primal result is the
# constructed `Composite`, and the pullback returns its input unchanged, wrapped in a
# 1-tuple. (The pullback's argument `x` shadows the constructor argument `x`.)
@adjoint ChainRules.Composite{Any, T}(x::T) where T = ChainRules.Composite{Any, T}(x), x->(x,)
18 changes: 8 additions & 10 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ using Zygote, Test, ChainRules
simo(x) = (5x, 7x)
function ChainRules.rrule(::typeof(simo), x)
simo_rrule_hitcount[] += 1
function simo_pullback(Δa, Δb)
function simo_pullback((Δa, Δb))
simo_pullback_hitcount[] += 1
return ChainRules.NO_FIELDS, 5*Δa + 7*Δb
end
Expand Down Expand Up @@ -101,7 +101,7 @@ using Zygote, Test, ChainRules
mimo(a, b) = (5a + 7b, 100a, 10b)
function ChainRules.rrule(::typeof(mimo), a, b)
mimo_rrule_hitcount[] += 1
function mimo_pullback(Δx, Δy, Δz)
function mimo_pullback((Δx, Δy, Δz))
mimo_pullback_hitcount[] += 1
return ChainRules.NO_FIELDS, 5Δx + 100Δy , 7Δx + 10Δz
end
Expand All @@ -126,9 +126,7 @@ using Zygote, Test, ChainRules

@testset "nested AD hitting identity(::Tuple) pullback" begin
# This is a particularly fiddly case.
# the adjoint of `tuple` is `identity`
# and `identity(::Tuple)`s pullback has multiple inputs
# (since the primal had multiple outputs)
# It's kind of a simplified version of `sin'''(0.5)`, but different in some places.

f(x) = tuple(x, 2x, 3x)

Expand Down Expand Up @@ -177,8 +175,8 @@ using Zygote, Test, ChainRules
end
end

# ChainRules doesn't have support for FastMath yet, so this fails
# https://github.com/JuliaDiff/ChainRules.jl/issues/174
@test_broken gradient(2.0) do x
@fastmath x^2.0
end == (4.0,)
@testset "FastMath support" begin
    # Differentiate x^2.0 evaluated under `@fastmath` at x = 2.0; expected gradient is 4.0.
    fastmath_square(x) = @fastmath x^2.0
    @test gradient(fastmath_square, 2.0) == (4.0,)
end

0 comments on commit bf913a2

Please sign in to comment.