-
-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
366: Add ChainRules r=oxinabox a=oxinabox This replaces #291 The bits from that OP that still matter Step 1) Change Zygote to check for chainrules before doing its normal stuff, and adapt the stuff it gets back from chainrules to play nice with Zygote's expectations Step 2) adapt Zygote more deeply, so it can take full advantage of thunks etc. This PR is Step 1. <s> TODO: workout why this seems to segfault for me. </s> Co-authored-by: Lyndon White <lyndon.white@invenialabs.co.uk> Co-authored-by: Mike J Innes <mike.j.innes@gmail.com> Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
- Loading branch information
Showing
10 changed files
with
341 additions
and
96 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
const chainrules_fallback = which(rrule, Tuple{Any}) | ||
|
||
""" | ||
has_chain_rrule(T) | ||
For a type-tuple `T` e.g. `Tuple{typeof(f), Int, Float64}`, checks if there is a `rrule` defined for it. | ||
Excluding the generic fallback. | ||
The first return value is `true` if the `rrule` exists, `false` otherwise. | ||
If it does not, then the second argument is a list of edges to attach to the CodeInfo for a generated function, | ||
such that if a suitable rule is defined later, the generated function will recompile. | ||
""" | ||
function has_chain_rrule(T) | ||
m = meta(Tuple{typeof(rrule),T.parameters...}) | ||
if m.method !== chainrules_fallback | ||
# found a rrule, no need to add any edges | ||
return true, nothing | ||
end | ||
|
||
# did not find anything, will have to attach edges so it recompiles if one is added | ||
@static if VERSION >= v"1.3" | ||
@assert m.code.edges !== nothing | ||
return false, m.code.edges | ||
else | ||
# pre-julia 1.3 there are no edges | ||
return false, tuple() | ||
end | ||
end | ||
|
||
""" | ||
is_kwfunc(sigt...) | ||
Determines if `sigt` is the type signature of a kwfunction. | ||
Each element of `sigt` should be a type. | ||
Either the first 3 types are a kwfunc type, a NamedTuple and the matching base function type, | ||
or the first argument is the base function type and it is not a kwfunction. | ||
the remaining types in `sigt` are the types of the argument. | ||
""" | ||
is_kwfunc(::Vararg{Any}) = false | ||
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f) | ||
|
||
|
||
""" | ||
wrap_chainrules_output(x) | ||
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally | ||
(including conjugating complex gradients). | ||
""" | ||
@inline wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks | ||
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) | ||
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing | ||
for T_outer in (:Tuple, :NamedTuple) | ||
# we create separate methods rather than using a `Union` + an `if` so that we avoid a | ||
# branch that changes output type, because nested AD on that kinda thing makes Zygote less | ||
# than happy. | ||
@eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer} | ||
xp = map(wrap_chainrules_output, x) | ||
convert($T_outer, xp) | ||
end | ||
end | ||
|
||
""" | ||
wrap_chainrules_input(x) | ||
Convert `x` from the format Zygote uses internally (including conjugated complex gradients) | ||
to differentials types ChainRules uses. | ||
""" | ||
@inline wrap_chainrules_input(x) = conj(x) | ||
@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero() | ||
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) | ||
xp = map(wrap_chainrules_input, xs) | ||
ChainRules.Composite{Any, typeof(xp)}(xp) | ||
end | ||
|
||
""" | ||
ZBack{F}(back) <: Function | ||
Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions. | ||
(A functor here is used rather than a closure to avoid boxing issues); | ||
""" | ||
struct ZBack{F} <: Function | ||
back::F | ||
end | ||
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy))) | ||
# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 | ||
# though it might be worth keeping as a performance optimization (benchmarking pending) | ||
@inline (s::ZBack)(::Nothing) = nothing | ||
|
||
""" | ||
chain_rrule(f, args...) | ||
Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`. | ||
The pullback is appropriately wrapped up to follow Zygote conventions. | ||
""" | ||
@inline function chain_rrule(f, args...) | ||
y, back = rrule(f, args...) | ||
return y, ZBack(back) | ||
end | ||
|
||
|
||
""" | ||
chain_rrule_kw(kwf, kwargs, f, args...) | ||
As per [`chain_rrule`](@ref) but with support for kwargs. | ||
`kwf` should be the kwfunc matching to `f`, and `kwargs` are a `NamedTuple` of keyword arguments. | ||
""" | ||
@inline function chain_rrule_kw(kwf, kwargs, f, args...) | ||
y, back = rrule(f, args...; kwargs...) | ||
kw_zpullback(dy) = (nothing, nothing, ZBack(back)(dy)...) # first two nothings are for kwfunc and kwargs | ||
return y, kw_zpullback | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.