Skip to content

Commit

Permalink
Add ChainRules fallback
Browse files Browse the repository at this point in the history
Add Manifest

do it in the right place and add blacklisting

Update blacklist to include all higher order functions

blacklist broken sum

Update src/compiler/interface2.jl
  • Loading branch information
oxinabox committed Oct 19, 2019
1 parent d9474cc commit c89c5d0
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 13 deletions.
33 changes: 22 additions & 11 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ git-tree-sha1 = "c69698c3d4a7255bc1b4bc2afc09f59db910243b"
uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f"
version = "0.6.2"

[[ChainRules]]
deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"]
git-tree-sha1 = "0d6f9017442dc7a00f53dcc1174e4e0c2a2c7751"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.2.1"

[[ChainRulesCore]]
git-tree-sha1 = "a493cc9352df2d99790f9f1225dfd9fbc52cd13e"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.3.0"

[[CommonSubexpressions]]
deps = ["Test"]
git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0"
Expand All @@ -53,9 +64,9 @@ version = "4.0.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "0809951a1774dc724da22d26e4289bbaab77809a"
git-tree-sha1 = "f94423c68f2e47db0d6f626a26d4872266e0ec3d"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.0"
version = "0.17.2"

[[Dates]]
deps = ["Printf"]
Expand Down Expand Up @@ -83,15 +94,15 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[FFTW]]
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
git-tree-sha1 = "e1a479d3c972f20c9a70563eec740bbfc786f515"
git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
version = "0.3.0"
version = "1.0.1"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "8fba6ddaf66b45dec830233cea0aae43eb1261ad"
git-tree-sha1 = "16974065d5bfa867446d3228bc63f05a440e910b"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.6.4"
version = "0.7.2"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "InteractiveUtils", "LinearAlgebra", "NaNMath", "Random", "SparseArrays", "SpecialFunctions", "StaticArrays", "Test"]
Expand All @@ -101,7 +112,7 @@ version = "0.10.3"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "a66befa9ebc63e465212281ac027c1526382bc4e"
git-tree-sha1 = "09e4091acb2c4aac29a673fab16e0f1aa8672b2a"
repo-rev = "master"
repo-url = "https://github.com/MikeInnes/IRTools.jl.git"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
Expand Down Expand Up @@ -168,7 +179,7 @@ uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.3.7"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]
Expand Down Expand Up @@ -213,10 +224,10 @@ deps = ["LinearAlgebra", "Random"]
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[[SpecialFunctions]]
deps = ["BinDeps", "BinaryProvider", "Libdl", "Test"]
git-tree-sha1 = "0b45dc2e45ed77f445617b99ff2adf0f5b0f23ea"
deps = ["BinDeps", "BinaryProvider", "Libdl"]
git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e"
uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
version = "0.7.2"
version = "0.8.0"

[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.3.4"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -20,6 +21,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
ChainRules = "0.2.1"
IRTools = "0.2.3"
NNlib = "0.6"
ZygoteRules = "0.2"
Expand Down
1 change: 1 addition & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using LinearAlgebra: copytri!, AbstractTriangular

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty

using ChainRules: ChainRules
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand Down
1 change: 0 additions & 1 deletion src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ end
# interface2.jl

# Wrappers

_pullback(f, args...) = _pullback(Context(), f, args...)

tailmemaybe(::Nothing) = nothing
Expand Down
81 changes: 80 additions & 1 deletion src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,86 @@ using IRTools.Inner: argnames!, update!

ignore(T) = all(T -> T <: Type, T.parameters)

@generated function _pullback(ctx::AContext, f, args...)

function _pullback(ctx::AContext, f, args...)
if chainrules_blacklist(f, args...)
# then don't even consider using ChainRules
return _pullback_via_source2source(ctx, f, args...)
end

res = ChainRules.rrule(f, args...)
if res === nothing
# No ChainRule defined, time to do the source tranform
return _pullback_via_source2source(ctx, f, args...)
else
# Can just use ChainRule answer
y, pb = res
return y, _pullback_via_chainrules(pb)
end
end

#=="""
chainrules_blacklist(f, args...,)
This is used to disable the use of ChainRule's definitions
for particular functions/methods.
It is not required if a Zygote rule has already been defined directly.
"""==#
chainrules_blacklist(f, args...) = false

# ChainRules does higher-order functions badly
# see https://github.com/JuliaDiff/ChainRules.jl/issues/122
chainrules_blacklist(::typeof(map), args...) = true
chainrules_blacklist(::typeof(broadcast), args...) = true
chainrules_blacklist(::typeof(mapreduce), args...) = true
chainrules_blacklist(::typeof(mapfoldl), args...) = true
chainrules_blacklist(::typeof(mapfoldr), args...) = true
chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true
# Except for sum(abs2, xs), that is fine
chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false

# ChainRules current Wirtinger deriviative is not compatible
# reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29
chainrules_blacklist(::typeof(abs), ::Complex) = true
chainrules_blacklist(::typeof(abs2), ::Complex) = true
chainrules_blacklist(::typeof(conj), ::Complex) = true
chainrules_blacklist(::typeof(adjoint), ::Complex) = true
chainrules_blacklist(::typeof(hypot), ::Complex) = true
chainrules_blacklist(::typeof(angle), ::Complex) = true
chainrules_blacklist(::typeof(imag), ::Complex) = true
chainrules_blacklist(::typeof(real), ::Complex) = true

# Sum of nonarrays doesn't really work
# Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124
chainrules_blacklist(::typeof(sum), x) = true
chainrules_blacklist(::typeof(sum), x::AbstractArray{<:Real}) = false


#=="""
_pullback_via_chainrules(pb)
Converts a ChainRules pullback into a Zygote pullback.
`pb` should be a ChainRules pullback, as returned from the second return value of `rrule`
"""==#
function _pullback_via_chainrules(pb)
# The less optimized version of this code is
# cback2zback(pb) = (Δs...) -> zextern.(pb(Δs...))
function zback(Δs...)
∂s = pb(Δs...)
ntuple(length(∂s)) do ii
= ∂s[ii]
zextern(∂)
end
end
end

zextern(x) = ChainRules.extern(x)
zextern(::ChainRules.Zero) = nothing # Zygote loves calling things nothing
zextern(::ChainRules.DNE) = nothing # Zygote loves calling things nothing


@generated function _pullback_via_source2source(ctx::AContext, f, args...)
T = Tuple{f,args...}
ignore(T) && return :(f(args...), Pullback{$T}(()))
g = try _lookup_grad(T) catch e e end
Expand Down

0 comments on commit c89c5d0

Please sign in to comment.