From c89c5d036a145480d667a21c3d4c6c297f74b2eb Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 7 Oct 2019 14:35:18 +0100 Subject: [PATCH] Add ChainRules fallback Add Manifest do it in the right place and add blacklisting Update blacklist to include all higher order functions blacklist broken sum Update src/compiler/interface2.jl --- Manifest.toml | 33 ++++++++++------ Project.toml | 2 + src/Zygote.jl | 1 + src/compiler/interface.jl | 1 - src/compiler/interface2.jl | 81 +++++++++++++++++++++++++++++++++++++- 5 files changed, 105 insertions(+), 13 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index 59bc5ab6c..4a3d1bc83 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -27,6 +27,17 @@ git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b" uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" version = "0.6.2" +[[ChainRules]] +deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"] +git-tree-sha1 = "0d6f9017442dc7a00f53dcc1174e4e0c2a2c7751" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "0.2.1" + +[[ChainRulesCore]] +git-tree-sha1 = "a493cc9352df2d99790f9f1225dfd9fbc52cd13e" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "0.3.0" + [[CommonSubexpressions]] deps = ["Test"] git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" @@ -53,9 +64,9 @@ version = "4.0.0" [[DataStructures]] deps = ["InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a" +git-tree-sha1 = "f94423c68f2e47db0d6f626a26d4872266e0ec3d" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.17.0" +version = "0.17.2" [[Dates]] deps = ["Printf"] @@ -83,15 +94,15 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[FFTW]] deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"] -git-tree-sha1 = "e1a479d3c972f20c9a70563eec740bbfc786f515" +git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "0.3.0" 
+version = "1.0.1" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad" +git-tree-sha1 = "16974065d5bfa867446d3228bc63f05a440e910b" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.6.4" +version = "0.7.2" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"] @@ -101,7 +112,7 @@ version = "0.10.3" [[IRTools]] deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "a66befa9ebc63e465212281ac027c1526382bc4e" +git-tree-sha1 = "09e4091acb2c4aac29a673fab16e0f1aa8672b2a" repo-rev = "master" repo-url = "https://github.com/MikeInnes/IRTools.jl.git" uuid = "7869d1d1-7146-5819-86e3-90919afe41df" @@ -168,7 +179,7 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" version = "0.3.7" [[Pkg]] -deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[Printf]] @@ -213,10 +224,10 @@ deps = ["LinearAlgebra", "Random"] uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[SpecialFunctions]] -deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"] -git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea" +deps = ["BinDeps", "BinaryProvider", "Libdl"] +git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "0.7.2" +version = "0.8.0" [[StaticArrays]] deps = ["LinearAlgebra", "Random", "Statistics"] diff --git a/Project.toml b/Project.toml index ba5ab3748..ca05c5000 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" version = "0.3.4" [deps] +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" FFTW = 
"7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -20,6 +21,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] +ChainRules = "0.2.1" IRTools = "0.2.3" NNlib = "0.6" ZygoteRules = "0.2" diff --git a/src/Zygote.jl b/src/Zygote.jl index d2fcb199a..bb27729ce 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -5,6 +5,7 @@ using LinearAlgebra: copytri!, AbstractTriangular import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty +using ChainRules: ChainRules using IRTools using MacroTools, Requires using MacroTools: @forward diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 8182152f9..2639009e9 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -27,7 +27,6 @@ end # interface2.jl # Wrappers - _pullback(f, args...) = _pullback(Context(), f, args...) tailmemaybe(::Nothing) = nothing diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index b5b4e0816..509e3a48a 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -3,7 +3,86 @@ using IRTools.Inner: argnames!, update! ignore(T) = all(T -> T <: Type, T.parameters) -@generated function _pullback(ctx::AContext, f, args...) + +function _pullback(ctx::AContext, f, args...) + if chainrules_blacklist(f, args...) + # then don't even consider using ChainRules + return _pullback_via_source2source(ctx, f, args...) + end + + res = ChainRules.rrule(f, args...) + if res === nothing + # No ChainRule defined, time to do the source transform + return _pullback_via_source2source(ctx, f, args...) + else + # Can just use ChainRule answer + y, pb = res + return y, _pullback_via_chainrules(pb) + end +end + +#==""" + chainrules_blacklist(f, args...,) + +This is used to disable the use of ChainRules' definitions +for particular functions/methods. 
+ +It is not required if a Zygote rule has already been defined directly. +"""==# +chainrules_blacklist(f, args...) = false + +# ChainRules does higher-order functions badly +# see https://github.com/JuliaDiff/ChainRules.jl/issues/122 +chainrules_blacklist(::typeof(map), args...) = true +chainrules_blacklist(::typeof(broadcast), args...) = true +chainrules_blacklist(::typeof(mapreduce), args...) = true +chainrules_blacklist(::typeof(mapfoldl), args...) = true +chainrules_blacklist(::typeof(mapfoldr), args...) = true +chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true +# Except for sum(abs2, xs), that is fine +chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false + +# ChainRules' current Wirtinger derivative is not compatible +# reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29 +chainrules_blacklist(::typeof(abs), ::Complex) = true +chainrules_blacklist(::typeof(abs2), ::Complex) = true +chainrules_blacklist(::typeof(conj), ::Complex) = true +chainrules_blacklist(::typeof(adjoint), ::Complex) = true +chainrules_blacklist(::typeof(hypot), ::Complex) = true +chainrules_blacklist(::typeof(angle), ::Complex) = true +chainrules_blacklist(::typeof(imag), ::Complex) = true +chainrules_blacklist(::typeof(real), ::Complex) = true + +# Sum of nonarrays doesn't really work +# Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124 +chainrules_blacklist(::typeof(sum), x) = true +chainrules_blacklist(::typeof(sum), x::AbstractArray{<:Real}) = false + + +#==""" + _pullback_via_chainrules(pb) + +Converts a ChainRules pullback into a Zygote pullback. +`pb` should be a ChainRules pullback, as returned from the second return value of `rrule` +"""==# +function _pullback_via_chainrules(pb) + # The less optimized version of this code is + # cback2zback(pb) = (Δs...) -> zextern.(pb(Δs...)) + function zback(Δs...) + ∂s = pb(Δs...) 
+ ntuple(length(∂s)) do ii + ∂ = ∂s[ii] + zextern(∂) + end + end +end + +zextern(x) = ChainRules.extern(x) +zextern(::ChainRules.Zero) = nothing # Zygote loves calling things nothing +zextern(::ChainRules.DNE) = nothing # Zygote loves calling things nothing + + +@generated function _pullback_via_source2source(ctx::AContext, f, args...) T = Tuple{f,args...} ignore(T) && return :(f(args...), Pullback{$T}(())) g = try _lookup_grad(T) catch e e end