Skip to content

Commit

Permalink
Merge #366
Browse files Browse the repository at this point in the history
366: Add ChainRules r=oxinabox a=oxinabox

This replaces #291

The bits from that OP that still matter

Step 1) Change Zygote to check for chainrules before doing its normal stuff,
and adapt the stuff it gets back from chainrules to play nice with Zygote's expectations

Step 2) adapt Zygote more deeply, so it can take full advantage of thunks etc.


This PR is Step 1.

<s> TODO: workout why this seems to segfault for me. </s>


Co-authored-by: Lyndon White <lyndon.white@invenialabs.co.uk>
Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
  • Loading branch information
4 people authored May 28, 2020
2 parents 69b2f2f + 41f4c17 commit 64c02dc
Show file tree
Hide file tree
Showing 10 changed files with 341 additions and 96 deletions.
12 changes: 4 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ version = "0.4.20"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
Expand All @@ -14,25 +14,21 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5"
ArrayLayouts = "0.1, 0.2, 0.3"
DiffRules = "0.0, 0.1, 1"
ArrayLayouts = "0.1, 0.2"
ChainRules = "0.6.0"
FillArrays = "0.8"
ForwardDiff = "0"
IRTools = "0.3"
IRTools = "=0.3.1"
MacroTools = "0.5"
NNlib = "0.6.5"
NaNMath = "0"
Requires = "0.5, 1.0"
SpecialFunctions = "0"
ZygoteRules = "0.2"
julia = "1"

Expand Down
10 changes: 10 additions & 0 deletions docs/src/adjoints.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Custom Adjoints

!!! note "Prefer to use ChainRules to define custom adjoints"
Zygote supports the use of [ChainRulesCore](http://www.juliadiff.org/ChainRulesCore.jl/stable/) to define custom sensitivities.
It is prefered to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Zygote.
These sensitivities can be added in your own package, or for Base functions they can be added to ChainRules.jl.

This documentation exists to descibe how Zygote works, and how adjoints can be directly defined for Zygote.
Defining adjoints this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Zygote works.
It allows for specific definitions of adjoints that are only defined for Zgyote (which might work differently to more generic definitions defined for all AD).


The `@adjoint` macro is an important part of Zygote's interface; customising your backwards pass is not only possible but widely used and encouraged. While there are specific utilities available for common things like gradient clipping, understanding adjoints will give you the most flexibility. We first give a bit more background on what these pullback things are.

## Pullbacks
Expand Down
2 changes: 2 additions & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ArrayLayouts: MemoryLayout, AbstractColumnMajor

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty

using ChainRules: ChainRules, rrule, unthunk
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand All @@ -17,6 +18,7 @@ include("tools/buffer.jl")

include("compiler/reverse.jl")
include("compiler/emit.jl")
include("compiler/chainrules.jl")
include("compiler/interface.jl")
include("compiler/show.jl")

Expand Down
111 changes: 111 additions & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
const chainrules_fallback = which(rrule, Tuple{Any})

"""
has_chain_rrule(T)
For a type-tuple `T` e.g. `Tuple{typeof(f), Int, Float64}`, checks if there is a `rrule` defined for it.
Excluding the generic fallback.
The first return value is `true` if the `rrule` exists, `false` otherwise.
If it does not, then the second argument is a list of edges to attach to the CodeInfo for a generated function,
such that if a suitable rule is defined later, the generated function will recompile.
"""
function has_chain_rrule(T)
m = meta(Tuple{typeof(rrule),T.parameters...})
if m.method !== chainrules_fallback
# found a rrule, no need to add any edges
return true, nothing
end

# did not find anything, will have to attach edges so it recompiles if one is added
@static if VERSION >= v"1.3"
@assert m.code.edges !== nothing
return false, m.code.edges
else
# pre-julia 1.3 there are no edges
return false, tuple()
end
end

"""
is_kwfunc(sigt...)
Determines if `sigt` is the type signature of a kwfunction.
Each element of `sigt` should be a type.
Either the first 3 types are a kwfunc type, a NamedTuple and the matching base function type,
or the first argument is the base function type and it is not a kwfunction.
the remaining types in `sigt` are the types of the argument.
"""
is_kwfunc(::Vararg{Any}) = false
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)


"""
wrap_chainrules_output(x)
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally
(including conjugating complex gradients).
"""
@inline wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
# than happy.
@eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer}
xp = map(wrap_chainrules_output, x)
convert($T_outer, xp)
end
end

"""
wrap_chainrules_input(x)
Convert `x` from the format Zygote uses internally (including conjugated complex gradients)
to differentials types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = conj(x)
@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, xs)
ChainRules.Composite{Any, typeof(xp)}(xp)
end

"""
ZBack{F}(back) <: Function
Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions.
(A functor here is used rather than a closure to avoid boxing issues);
"""
struct ZBack{F} <: Function
back::F
end
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
# though it might be worth keeping as a performance optimization (benchmarking pending)
@inline (s::ZBack)(::Nothing) = nothing

"""
chain_rrule(f, args...)
Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`.
The pullback is appropriately wrapped up to follow Zygote conventions.
"""
@inline function chain_rrule(f, args...)
y, back = rrule(f, args...)
return y, ZBack(back)
end


"""
chain_rrule_kw(kwf, kwargs, f, args...)
As per [`chain_rrule`](@ref) but with support for kwargs.
`kwf` should be the kwfunc matching to `f`, and `kwargs` are a `NamedTuple` of keyword arguments.
"""
@inline function chain_rrule_kw(kwf, kwargs, f, args...)
y, back = rrule(f, args...; kwargs...)
kw_zpullback(dy) = (nothing, nothing, ZBack(back)(dy)...) # first two nothings are for kwfunc and kwargs
return y, kw_zpullback
end
1 change: 0 additions & 1 deletion src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ end
# interface2.jl

# Wrappers

_pullback(f, args...) = _pullback(Context(), f, args...)

tailmemaybe(::Nothing) = nothing
Expand Down
11 changes: 11 additions & 0 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,24 @@ ignore_sig(T) = all(T -> T <: Type, T.parameters)
@generated function _pullback(ctx::AContext, f, args...)
T = Tuple{f,args...}
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))

iskw = is_kwfunc(f, args...)
# if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
base_T = iskw ? Tuple{args[2:end]...} : T
hascr, cr_edges = has_chain_rrule(base_T)
chain_rrule_f = iskw ? :chain_rrule_kw : :chain_rrule
hascr && return :($chain_rrule_f(f, args...))

g = try _lookup_grad(T) catch e e end
!(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
forw = varargs!(meta, forw, 3)
# IRTools.verify(forw)
forw = slots!(pis!(inlineable!(forw)))
@static if VERSION >= v"1.3" # no edges pre-1.3
append!(meta.code.edges, cr_edges) # be ready to swap to using chainrule if one is declared
end
return update!(meta.code, forw)
end

Expand Down
16 changes: 2 additions & 14 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,6 @@ end
@adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),)
@adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),)

@adjoint dot(x::AbstractArray, y::AbstractArray) = dot(x, y), Δ->.* y, Δ .* x)

function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
mat1_rsh = reshape(mat1,(1,m1,1,n1))
Expand All @@ -361,18 +359,6 @@ end

@adjoint kron(a::AbstractMatrix, b::AbstractMatrix) = pullback(_kron, a, b)

@adjoint function Diagonal(d::AbstractVector)
back::NamedTuple) =.diag,)
back::AbstractMatrix) = (diag(Δ),)
return Diagonal(d), back
end

@adjoint diag(A::AbstractMatrix) = diag(A), Δ->(Diagonal(Δ),)

@adjoint det(xs::Union{Number, AbstractMatrix}) = det(xs), Δ ->* det(xs) * inv(xs)',)

@adjoint logdet(xs::Union{Number, AbstractMatrix}) = logdet(xs), Δ ->* inv(xs)',)

@adjoint logabsdet(xs::AbstractMatrix) = logabsdet(xs), Δ -> (Δ[1] * inv(xs)',)

@adjoint function inv(A::Union{Number, AbstractMatrix})
Expand Down Expand Up @@ -737,6 +723,8 @@ end
end
end

# ChainRules has this also but does not use FillArrays, so we have our own definition
# for improved performance. See https://github.com/JuliaDiff/ChainRules.jl/issues/46
Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix)
# x is a squre matrix checked by tr,
# so we could just use Eye(size(x, 1))
Expand Down
78 changes: 5 additions & 73 deletions src/lib/number.jl
Original file line number Diff line number Diff line change
@@ -1,37 +1,3 @@
using DiffRules, SpecialFunctions, NaNMath
using Base.FastMath: fast_op, make_fastmath

@nograd isinf, isnan, isfinite, div

# TODO use CSE here

for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue
Δ =
dx = DiffRules.diffrule(M, f, :x)
if f in [:abs, :abs2]
Δ = :(real($Δ))
else
dx = :(conj($dx))
end
@eval begin
@adjoint $M.$f(x::Number) = $M.$f(x),
Δ -> ($Δ * $dx,)
end
end

for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue
f == :^ && continue
da, db = DiffRules.diffrule(M, f, :a, :b)
@eval begin
@adjoint $M.$f(a::Number, b::Number) = $M.$f(a, b),
Δ ->* conj($da), Δ * conj($db))
end
end

@adjoint Base.:^(x::Number, p::Number) = x^p,
Δ ->* conj(p * x^(p-1)), Δ * conj(x^p * log(complex(x))))
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
Base.literal_pow(^,x,Val(p)),
Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)
Expand All @@ -45,20 +11,6 @@ end

@adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs)

@adjoint Base.muladd(x::Number, y::Number, z::Number) =
Base.muladd(x, y, z), ō -> (y'ō, x'ō, ō)

@adjoint Base.fma(x::Number, y::Number, z::Number) =
Base.fma(x, y, z), ō -> (y'ō, x'ō, ō)

@adjoint function sincos(x)
s, c = sincos(x)
(s, c), ((s̄, c̄),) -> (s̄*c -*s,)
end

@adjoint acosh(x::Complex) =
acosh(x), Δ ->* conj(inv(sqrt(x - 1) * sqrt(x + 1))),)

@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, -* a // b // b))

@nograd floor, ceil, trunc, round, hash
Expand All @@ -71,28 +23,8 @@ end
@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),)
@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,)

DiffRules._abs_deriv(x::Complex) = x/abs(x)

# adjoint for Fastmath operations
for (f, fastf) in fast_op
if DiffRules.hasdiffrule(:Base, f, 1)
dx = DiffRules.diffrule(:Base, f, :x)
Δ =
if f in [:abs, :abs2]
Δ = :(real($Δ))
else
dx = :(conj($dx))
end
@eval begin
@adjoint Base.FastMath.$fastf(x::Number) =
Base.FastMath.$fastf(x), Δ -> ($Δ * make_fastmath($dx),)
end
elseif DiffRules.hasdiffrule(:Base, f, 2)
dx, dy = DiffRules.diffrule(:Base, f, :x, :y)
@eval begin
@adjoint Base.FastMath.$fastf(x::Number, y::Number) =
Base.FastMath.$fastf(x, y),
Δ ->* make_fastmath(conj($dx)), Δ * make_fastmath(conj($dy)))
end
end
end
# we intentionally define these here rather than falling back on ChainRules.jl
# because ChainRules doesn't really handle nonanalytic complex functions
@adjoint abs(x::Real) = abs(x), Δ -> (real(Δ)*sign(x),)
@adjoint abs(x::Complex) = abs(x), Δ -> (real(Δ)*x/abs(x),)
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)
Loading

0 comments on commit 64c02dc

Please sign in to comment.