From 1c3ee6dc587cc74982443cc0d4b11b92a8d48b22 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 7 Oct 2019 14:35:18 +0100 Subject: [PATCH 01/35] Add ChainRules fallback Add Manifest do it in the right place and add blacklisting Update blacklist to include all higher order functions blacklist broken sum Update src/compiler/interface2.jl --- Project.toml | 2 + src/Zygote.jl | 1 + src/compiler/interface.jl | 1 - src/compiler/interface2.jl | 81 +++++++++++++++++++++++++++++++++++++- 4 files changed, 83 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index e0f1316f9..34592cbf3 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.4.20" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -26,6 +27,7 @@ AbstractFFTs = "0.5" ArrayLayouts = "0.1, 0.2, 0.3" DiffRules = "0.0, 0.1, 1" FillArrays = "0.8" +ChainRules = "0.2.1" ForwardDiff = "0" IRTools = "0.3" MacroTools = "0.5" diff --git a/src/Zygote.jl b/src/Zygote.jl index 550ee0bb0..e5206c65d 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -6,6 +6,7 @@ using ArrayLayouts: MemoryLayout, AbstractColumnMajor import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty +using ChainRules: ChainRules using IRTools using MacroTools, Requires using MacroTools: @forward diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl index 1a785b6ef..d9dd1b2cb 100644 --- a/src/compiler/interface.jl +++ b/src/compiler/interface.jl @@ -35,7 +35,6 @@ end # interface2.jl # Wrappers - _pullback(f, args...) = _pullback(Context(), f, args...) 
tailmemaybe(::Nothing) = nothing diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index a0484c0d1..3a28351af 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -3,7 +3,86 @@ using IRTools.Inner: argnames!, update! ignore_sig(T) = all(T -> T <: Type, T.parameters) -@generated function _pullback(ctx::AContext, f, args...) + +function _pullback(ctx::AContext, f, args...) + if chainrules_blacklist(f, args...) + # then don't even consider using ChainRules + return _pullback_via_source2source(ctx, f, args...) + end + + res = ChainRules.rrule(f, args...) + if res === nothing + # No ChainRule defined, time to do the source tranform + return _pullback_via_source2source(ctx, f, args...) + else + # Can just use ChainRule answer + y, pb = res + return y, _pullback_via_chainrules(pb) + end +end + +#==""" + chainrules_blacklist(f, args...,) + +This is used to disable the use of ChainRule's definitions +for particular functions/methods. + +It is not required if a Zygote rule has already been defined directly. +"""==# +chainrules_blacklist(f, args...) = false + +# ChainRules does higher-order functions badly +# see https://github.com/JuliaDiff/ChainRules.jl/issues/122 +chainrules_blacklist(::typeof(map), args...) = true +chainrules_blacklist(::typeof(broadcast), args...) = true +chainrules_blacklist(::typeof(mapreduce), args...) = true +chainrules_blacklist(::typeof(mapfoldl), args...) = true +chainrules_blacklist(::typeof(mapfoldr), args...) 
= true +chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true +# Except for sum(abs2, xs), that is fine +chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false + +# ChainRules current Wirtinger deriviative is not compatible +# reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29 +chainrules_blacklist(::typeof(abs), ::Complex) = true +chainrules_blacklist(::typeof(abs2), ::Complex) = true +chainrules_blacklist(::typeof(conj), ::Complex) = true +chainrules_blacklist(::typeof(adjoint), ::Complex) = true +chainrules_blacklist(::typeof(hypot), ::Complex) = true +chainrules_blacklist(::typeof(angle), ::Complex) = true +chainrules_blacklist(::typeof(imag), ::Complex) = true +chainrules_blacklist(::typeof(real), ::Complex) = true + +# Sum of nonarrays doesn't really work +# Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124 +chainrules_blacklist(::typeof(sum), x) = true +chainrules_blacklist(::typeof(sum), x::AbstractArray{<:Real}) = false + + +#==""" + _pullback_via_chainrules(pb) + +Converts a ChainRules pullback into a Zygote pullback. +`pb` should be a ChainRules pullback, as returned from the second return value of `rrule` +"""==# +function _pullback_via_chainrules(pb) + # The less optimized version of this code is + # cback2zback(pb) = (Δs...) -> zextern.(pb(Δs...)) + function zback(Δs...) + ∂s = pb(Δs...) + ntuple(length(∂s)) do ii + ∂ = ∂s[ii] + zextern(∂) + end + end +end + +zextern(x) = ChainRules.extern(x) +zextern(::ChainRules.Zero) = nothing # Zygote loves calling things nothing +zextern(::ChainRules.DNE) = nothing # Zygote loves calling things nothing + + +@generated function _pullback_via_source2source(ctx::AContext, f, args...) 
T = Tuple{f,args...} ignore_sig(T) && return :(f(args...), Pullback{$T}(())) g = try _lookup_grad(T) catch e e end From 0280d45b59fa3530b3d3c1774f9446fcad2d1570 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 18 Oct 2019 20:48:26 +0100 Subject: [PATCH 02/35] Test ChainRules integration directly Update test/chainrules.jl Update test/chainrules.jl --- test/chainrules.jl | 49 ++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 6 ++++++ 2 files changed, 55 insertions(+) create mode 100644 test/chainrules.jl diff --git a/test/chainrules.jl b/test/chainrules.jl new file mode 100644 index 000000000..e804c39b5 --- /dev/null +++ b/test/chainrules.jl @@ -0,0 +1,49 @@ +using Zygote, Test, ChainRules + +const cr_inner_demo_rrule_hitcount = Ref(0) +const cr_inner_demo_pullback_hitcount = Ref(0) +cr_inner_demo(x) = 5x +function ChainRules.rrule(::typeof(cr_inner_demo), x) + cr_inner_demo_rrule_hitcount[] += 1 + function cr_inner_demo_pullback(Δx) + cr_inner_demo_pullback_hitcount[] += 1 + return ChainRules.NO_FIELDS, 5.0*Δx + end + return cr_inner_demo(x), cr_inner_demo_pullback +end + +function cr_outer_demo(x) + 2 + 10cr_inner_demo(x) +end + +@testset "ChainRules Integration" begin + @testset "gradient inner" begin + cr_inner_demo_rrule_hitcount[] = 0 + cr_inner_demo_pullback_hitcount[] = 0 + @test (5.0,) == gradient(cr_inner_demo, 11) + @test cr_inner_demo_rrule_hitcount[] == 1 + @test cr_inner_demo_pullback_hitcount[] == 1 + end + + @testset "gradient outer" begin + cr_inner_demo_rrule_hitcount[] = 0 + cr_inner_demo_pullback_hitcount[] = 0 + @test (50.0,) == gradient(cr_outer_demo, 11) + @test cr_inner_demo_rrule_hitcount[] == 1 + @test cr_inner_demo_pullback_hitcount[] == 1 + end + + @testset "pullback inner" begin + cr_inner_demo_rrule_hitcount[] = 0 + cr_inner_demo_pullback_hitcount[] = 0 + y, pb = pullback(cr_inner_demo, 11) + @test y == 55 + @test cr_inner_demo_rrule_hitcount[] == 1 + @test cr_inner_demo_pullback_hitcount[] == 0 + @test 
pb(1)==(5.0,); + @test pb(2)==(10.0,); + @test pb(3)==(15.0,); + @test cr_inner_demo_pullback_hitcount[] == 3 + @test cr_inner_demo_rrule_hitcount[] == 1 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 6e02ebd9a..f8c30a958 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -27,6 +27,12 @@ end include("structures.jl") end +@info "Testing ChainRules integration" + +@testset "ChainRules" begin + include("chainrules.jl") +end + @info "Running Gradient Checks" @testset "Gradients" begin From c2b37c583547d07ee7a0017ebb49a3ab90d54841 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Sun, 20 Oct 2019 10:15:26 +0100 Subject: [PATCH 03/35] use metaprogramming in blacklist Update src/compiler/interface2.jl fix missing eval --- src/compiler/interface2.jl | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index 3a28351af..b949d2cc0 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -33,25 +33,18 @@ chainrules_blacklist(f, args...) = false # ChainRules does higher-order functions badly # see https://github.com/JuliaDiff/ChainRules.jl/issues/122 -chainrules_blacklist(::typeof(map), args...) = true -chainrules_blacklist(::typeof(broadcast), args...) = true -chainrules_blacklist(::typeof(mapreduce), args...) = true -chainrules_blacklist(::typeof(mapfoldl), args...) = true -chainrules_blacklist(::typeof(mapfoldr), args...) = true +for f in (map, broadcast, mapreduce, mapfoldl, mapfoldr) + @eval chainrules_blacklist(::typeof($f), args...) 
= true +end chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true # Except for sum(abs2, xs), that is fine chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false # ChainRules current Wirtinger deriviative is not compatible # reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29 -chainrules_blacklist(::typeof(abs), ::Complex) = true -chainrules_blacklist(::typeof(abs2), ::Complex) = true -chainrules_blacklist(::typeof(conj), ::Complex) = true -chainrules_blacklist(::typeof(adjoint), ::Complex) = true -chainrules_blacklist(::typeof(hypot), ::Complex) = true -chainrules_blacklist(::typeof(angle), ::Complex) = true -chainrules_blacklist(::typeof(imag), ::Complex) = true -chainrules_blacklist(::typeof(real), ::Complex) = true +for f in (abs, abs2, conj, adjoint, hypot, angle, imag, real) + @eval chainrules_blacklist(::typeof($f), ::Complex) = true +end # Sum of nonarrays doesn't really work # Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124 @@ -66,8 +59,8 @@ Converts a ChainRules pullback into a Zygote pullback. `pb` should be a ChainRules pullback, as returned from the second return value of `rrule` """==# function _pullback_via_chainrules(pb) - # The less optimized version of this code is - # cback2zback(pb) = (Δs...) -> zextern.(pb(Δs...)) + # This is the optimized version of + # _pullback_via_chainrules(pb) = (Δs...) -> zextern.(pb(Δs...)) function zback(Δs...) ∂s = pb(Δs...) 
ntuple(length(∂s)) do ii From ec5c65fa08a980d90bfa9bddebf3b12a6d265dc6 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 23 Oct 2019 15:56:04 +0100 Subject: [PATCH 04/35] compact choosing pullback mechanism code wip: ChainRules v0.3 Update blacklists --- Project.toml | 2 +- src/compiler/interface2.jl | 86 ++++++++++++++++---------------------- 2 files changed, 36 insertions(+), 52 deletions(-) diff --git a/Project.toml b/Project.toml index 34592cbf3..b40e69cdc 100644 --- a/Project.toml +++ b/Project.toml @@ -27,7 +27,7 @@ AbstractFFTs = "0.5" ArrayLayouts = "0.1, 0.2, 0.3" DiffRules = "0.0, 0.1, 1" FillArrays = "0.8" -ChainRules = "0.2.1" +ChainRules = "0.3.0" ForwardDiff = "0" IRTools = "0.3" MacroTools = "0.5" diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index b949d2cc0..268651acc 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -3,16 +3,9 @@ using IRTools.Inner: argnames!, update! ignore_sig(T) = all(T -> T <: Type, T.parameters) - function _pullback(ctx::AContext, f, args...) - if chainrules_blacklist(f, args...) - # then don't even consider using ChainRules - return _pullback_via_source2source(ctx, f, args...) - end - - res = ChainRules.rrule(f, args...) - if res === nothing - # No ChainRule defined, time to do the source tranform + if chainrules_blacklist(f, args...) || (res = ChainRules.rrule(f, args...)) === nothing + # Blacklisted or no ChainRule defined, time to do the source tranform return _pullback_via_source2source(ctx, f, args...) else # Can just use ChainRule answer @@ -21,6 +14,39 @@ function _pullback(ctx::AContext, f, args...) end end +@generated function _pullback_via_source2source(ctx::AContext, f, args...) 
+ T = Tuple{f,args...} + ignore(T) && return :(f(args...), Pullback{$T}(())) + g = try _lookup_grad(T) catch e e end + !(g isa Tuple) && return :(f(args...), Pullback{$T}((f,))) + meta, forw, _ = g + argnames!(meta, Symbol("#self#"), :ctx, :f, :args) + forw = varargs!(meta, forw, 3) + # IRTools.verify(forw) + forw = slots!(pis!(inlineable!(forw))) + return update!(meta.code, forw) +end + +@generated function (j::Pullback{T})(Δ) where T + ignore(T) && return :nothing + g = try _lookup_grad(T) + catch e + rethrow(CompileError(T,e)) + end + if g == nothing + Δ == Nothing && return :nothing + return :(error("Non-differentiable function $(repr(j.t[1]))")) + end + meta, _, back = g + argnames!(meta, Symbol("#self#"), :Δ) + # IRTools.verify(back) + back = slots!(inlineable!(back)) + return update!(meta.code, back) +end + + + + #==""" chainrules_blacklist(f, args...,) @@ -40,17 +66,6 @@ chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true # Except for sum(abs2, xs), that is fine chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false -# ChainRules current Wirtinger deriviative is not compatible -# reconsider after https://github.com/JuliaDiff/ChainRulesCore.jl/pull/29 -for f in (abs, abs2, conj, adjoint, hypot, angle, imag, real) - @eval chainrules_blacklist(::typeof($f), ::Complex) = true -end - -# Sum of nonarrays doesn't really work -# Fixed in https://github.com/JuliaDiff/ChainRules.jl/pull/124 -chainrules_blacklist(::typeof(sum), x) = true -chainrules_blacklist(::typeof(sum), x::AbstractArray{<:Real}) = false - #==""" _pullback_via_chainrules(pb) @@ -73,34 +88,3 @@ end zextern(x) = ChainRules.extern(x) zextern(::ChainRules.Zero) = nothing # Zygote loves calling things nothing zextern(::ChainRules.DNE) = nothing # Zygote loves calling things nothing - - -@generated function _pullback_via_source2source(ctx::AContext, f, args...) 
- T = Tuple{f,args...} - ignore_sig(T) && return :(f(args...), Pullback{$T}(())) - g = try _lookup_grad(T) catch e e end - !(g isa Tuple) && return :(f(args...), Pullback{$T}((f,))) - meta, forw, _ = g - argnames!(meta, Symbol("#self#"), :ctx, :f, :args) - forw = varargs!(meta, forw, 3) - # IRTools.verify(forw) - forw = slots!(pis!(inlineable!(forw))) - return update!(meta.code, forw) -end - -@generated function (j::Pullback{T})(Δ) where T - ignore_sig(T) && return :nothing - g = try _lookup_grad(T) - catch e - rethrow(CompileError(T,e)) - end - if g == nothing - Δ == Nothing && return :nothing - return :(error("Non-differentiable function $(repr(j.t[1]))")) - end - meta, _, back = g - argnames!(meta, Symbol("#self#"), :Δ) - # IRTools.verify(back) - back = slots!(inlineable!(back)) - return update!(meta.code, back) -end From 3d9bd6221f6a488685d26e5d09fbce5728e0ace2 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 13 Jan 2020 19:34:34 +0000 Subject: [PATCH 05/35] Simplify conversion Zygote style pullback don't conjugate on way out as ChainRules currently doesn't do complex so no need to fight that fight. --- src/compiler/interface2.jl | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index 268651acc..f24f71990 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -74,17 +74,12 @@ Converts a ChainRules pullback into a Zygote pullback. `pb` should be a ChainRules pullback, as returned from the second return value of `rrule` """==# function _pullback_via_chainrules(pb) - # This is the optimized version of - # _pullback_via_chainrules(pb) = (Δs...) -> zextern.(pb(Δs...)) - function zback(Δs...) + function zygote_pullback(Δs...) ∂s = pb(Δs...) - ntuple(length(∂s)) do ii - ∂ = ∂s[ii] - zextern(∂) - end + # TODO: Should not unthunk on the way out of a pullback, but rather on way in since + # that is when we know it is probably going to be used. 
+ ∂s_zy = map(ChainRules.unthunk, ∂s) + @info "Invoking via ChainRules" typeof(pb) typeof(∂s) typeof(∂s_zy) + return ∂s_zy end end - -zextern(x) = ChainRules.extern(x) -zextern(::ChainRules.Zero) = nothing # Zygote loves calling things nothing -zextern(::ChainRules.DNE) = nothing # Zygote loves calling things nothing From 449a6a77885efbfe992bd2f368a9d89662dc1c5b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 20 Apr 2020 12:54:43 +0100 Subject: [PATCH 06/35] add ChainRules, rm DiffRules --- Manifest.toml | 208 +++++++++++++++++++++++++++++++++++++ Project.toml | 8 +- src/Zygote.jl | 1 + src/compiler/interface2.jl | 21 ++-- src/lib/number.jl | 100 +++++++++--------- 5 files changed, 271 insertions(+), 67 deletions(-) create mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml new file mode 100644 index 000000000..ee7525f32 --- /dev/null +++ b/Manifest.toml @@ -0,0 +1,208 @@ +# This file is machine-generated - editing it directly is not advised + +[[AbstractFFTs]] +deps = ["LinearAlgebra"] +git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" +uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" +version = "0.5.0" + +[[ArrayLayouts]] +deps = ["FillArrays", "LinearAlgebra"] +git-tree-sha1 = "a504dca2ac7eda8761c8f7c1ed52427a1be75a3c" +uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" +version = "0.2.6" + +[[Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[BinaryProvider]] +deps = ["Libdl", "Logging", "SHA"] +git-tree-sha1 = "428e9106b1ff27593cbd979afac9b45b82372b8c" +uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" +version = "0.5.9" + +[[ChainRules]] +deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"] +git-tree-sha1 = "f7175b1c1275b55e67b926c8d5ba57188b01c679" +uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" +version = "0.5.2" + +[[ChainRulesCore]] +deps = ["MuladdMacro"] +git-tree-sha1 = "e7f1b2b4ba7146575e1a30345e0ae842ae4c74d8" +uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +version = "0.7.5" + 
+[[CommonSubexpressions]] +deps = ["Test"] +git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" +uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" +version = "0.2.0" + +[[CompilerSupportLibraries_jll]] +deps = ["Libdl", "Pkg"] +git-tree-sha1 = "7c4f882c41faa72118841185afc58a2eb00ef612" +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.3.3+0" + +[[Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[DiffResults]] +deps = ["StaticArrays"] +git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc" +uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" +version = "1.0.2" + +[[DiffRules]] +deps = ["NaNMath", "Random", "SpecialFunctions"] +git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1" +uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" +version = "1.0.1" + +[[Distributed]] +deps = ["Random", "Serialization", "Sockets"] +uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" + +[[FillArrays]] +deps = ["LinearAlgebra", "Random", "SparseArrays"] +git-tree-sha1 = "5322d34d7600d3429665b37bcf7628dc602a28cc" +uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" +version = "0.8.8" + +[[ForwardDiff]] +deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] +git-tree-sha1 = "869540e4367122fbffaace383a5bdc34d6e5e5ac" +uuid = "f6369f11-7733-5829-9624-2563aa707210" +version = "0.10.10" + +[[IRTools]] +deps = ["InteractiveUtils", "MacroTools", "Test"] +git-tree-sha1 = "1a4355e4b5b50be2311ebb644f34f3306dbd0410" +uuid = "7869d1d1-7146-5819-86e3-90919afe41df" +version = "0.3.1" + +[[InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[LibGit2]] +deps = ["Printf"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[MacroTools]] +deps = ["Markdown", "Random"] 
+git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a" +uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +version = "0.5.5" + +[[Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[MuladdMacro]] +git-tree-sha1 = "c6190f9a7fc5d9d5915ab29f2134421b12d24a68" +uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" +version = "0.2.2" + +[[NNlib]] +deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] +git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824" +uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +version = "0.6.6" + +[[NaNMath]] +git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f" +uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +version = "0.3.3" + +[[OpenSpecFun_jll]] +deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] +git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87" +uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" +version = "0.5.3+3" + +[[Pkg]] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" + +[[Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[Random]] +deps = ["Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[Reexport]] +deps = ["Pkg"] +git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" +uuid = "189a3867-3050-52da-a836-e630ba90ab69" +version = "0.2.0" + +[[Requires]] +deps = ["UUIDs"] +git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" +uuid = "ae029012-a4dd-5104-9daa-d747884805df" +version = "1.0.1" + +[[SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" + +[[Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[SpecialFunctions]] 
+deps = ["OpenSpecFun_jll"] +git-tree-sha1 = "e19b98acb182567bcb7b75bb5d9eedf3a3b5ec6c" +uuid = "276daf66-3868-5448-9aa4-cd146d93841b" +version = "0.10.0" + +[[StaticArrays]] +deps = ["LinearAlgebra", "Random", "Statistics"] +git-tree-sha1 = "5c06c0aeb81bef54aed4b3f446847905eb6cbda0" +uuid = "90137ffa-7385-5640-81b9-e52037218182" +version = "0.12.3" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + +[[Test]] +deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[[UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[ZygoteRules]] +deps = ["MacroTools"] +git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8" +uuid = "700de1a5-db45-46bc-99cf-38207098b444" +version = "0.2.0" diff --git a/Project.toml b/Project.toml index b40e69cdc..6fce8246f 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.4.20" AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" -DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Future = "9fa8497b-333b-5362-9e8d-4d0656e87820" @@ -15,26 +14,21 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" -SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5" -ArrayLayouts 
= "0.1, 0.2, 0.3" -DiffRules = "0.0, 0.1, 1" +ArrayLayouts = "0.1, 0.2" FillArrays = "0.8" ChainRules = "0.3.0" ForwardDiff = "0" IRTools = "0.3" MacroTools = "0.5" NNlib = "0.6.5" -NaNMath = "0" Requires = "0.5, 1.0" -SpecialFunctions = "0" ZygoteRules = "0.2" julia = "1" diff --git a/src/Zygote.jl b/src/Zygote.jl index e5206c65d..d24ba2fc1 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -3,6 +3,7 @@ module Zygote using LinearAlgebra, Statistics using LinearAlgebra: copytri!, AbstractTriangular using ArrayLayouts: MemoryLayout, AbstractColumnMajor +using ChainRules import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index f24f71990..445b37b22 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -3,27 +3,29 @@ using IRTools.Inner: argnames!, update! ignore_sig(T) = all(T -> T <: Type, T.parameters) -function _pullback(ctx::AContext, f, args...) - if chainrules_blacklist(f, args...) || (res = ChainRules.rrule(f, args...)) === nothing - # Blacklisted or no ChainRule defined, time to do the source tranform - return _pullback_via_source2source(ctx, f, args...) +const chainrules_fallback = which(rrule, Tuple{Any}) + +function has_chainrule(T) + m = meta(Tuple{typeof(rrule),T.parameters...}) + if m.method === chainrules_fallback + return false, m.code.edges else - # Can just use ChainRule answer - y, pb = res - return y, _pullback_via_chainrules(pb) + return true, nothing end end -@generated function _pullback_via_source2source(ctx::AContext, f, args...) +@generated function _pullback(ctx::AContext, f, args...) 
T = Tuple{f,args...} ignore(T) && return :(f(args...), Pullback{$T}(())) + hascr, cr_edges = has_chainrule(T) + hascr && return :(rrule(f, args...)) g = try _lookup_grad(T) catch e e end !(g isa Tuple) && return :(f(args...), Pullback{$T}((f,))) meta, forw, _ = g argnames!(meta, Symbol("#self#"), :ctx, :f, :args) forw = varargs!(meta, forw, 3) - # IRTools.verify(forw) forw = slots!(pis!(inlineable!(forw))) + append!(meta.code.edges, cr_edges) return update!(meta.code, forw) end @@ -39,7 +41,6 @@ end end meta, _, back = g argnames!(meta, Symbol("#self#"), :Δ) - # IRTools.verify(back) back = slots!(inlineable!(back)) return update!(meta.code, back) end diff --git a/src/lib/number.jl b/src/lib/number.jl index e28139a3a..1a4069801 100644 --- a/src/lib/number.jl +++ b/src/lib/number.jl @@ -1,34 +1,34 @@ using DiffRules, SpecialFunctions, NaNMath using Base.FastMath: fast_op, make_fastmath -@nograd isinf, isnan, isfinite, div +# @nograd isinf, isnan, isfinite, div # TODO use CSE here -for (M, f, arity) in DiffRules.diffrules() - arity == 1 || continue - Δ = :Δ - dx = DiffRules.diffrule(M, f, :x) - if f in [:abs, :abs2] - Δ = :(real($Δ)) - else - dx = :(conj($dx)) - end - @eval begin - @adjoint $M.$f(x::Number) = $M.$f(x), - Δ -> ($Δ * $dx,) - end -end - -for (M, f, arity) in DiffRules.diffrules() - arity == 2 || continue - f == :^ && continue - da, db = DiffRules.diffrule(M, f, :a, :b) - @eval begin - @adjoint $M.$f(a::Number, b::Number) = $M.$f(a, b), - Δ -> (Δ * conj($da), Δ * conj($db)) - end -end +# for (M, f, arity) in DiffRules.diffrules() +# arity == 1 || continue +# Δ = :Δ +# dx = DiffRules.diffrule(M, f, :x) +# if f in [:abs, :abs2] +# Δ = :(real($Δ)) +# else +# dx = :(conj($dx)) +# end +# @eval begin +# @adjoint $M.$f(x::Number) = $M.$f(x), +# Δ -> ($Δ * $dx,) +# end +# end +# +# for (M, f, arity) in DiffRules.diffrules() +# arity == 2 || continue +# f == :^ && continue +# da, db = DiffRules.diffrule(M, f, :a, :b) +# @eval begin +# @adjoint $M.$f(a::Number, 
b::Number) = $M.$f(a, b), +# Δ -> (Δ * conj($da), Δ * conj($db)) +# end +# end @adjoint Base.:^(x::Number, p::Number) = x^p, Δ -> (Δ * conj(p * x^(p-1)), Δ * conj(x^p * log(complex(x)))) @@ -71,28 +71,28 @@ end @adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),) @adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,) -DiffRules._abs_deriv(x::Complex) = x/abs(x) - - # adjoint for Fastmath operations -for (f, fastf) in fast_op - if DiffRules.hasdiffrule(:Base, f, 1) - dx = DiffRules.diffrule(:Base, f, :x) - Δ = :Δ - if f in [:abs, :abs2] - Δ = :(real($Δ)) - else - dx = :(conj($dx)) - end - @eval begin - @adjoint Base.FastMath.$fastf(x::Number) = - Base.FastMath.$fastf(x), Δ -> ($Δ * make_fastmath($dx),) - end - elseif DiffRules.hasdiffrule(:Base, f, 2) - dx, dy = DiffRules.diffrule(:Base, f, :x, :y) - @eval begin - @adjoint Base.FastMath.$fastf(x::Number, y::Number) = - Base.FastMath.$fastf(x, y), - Δ -> (Δ * make_fastmath(conj($dx)), Δ * make_fastmath(conj($dy))) - end - end -end +# DiffRules._abs_deriv(x::Complex) = x/abs(x) + +# # adjoint for Fastmath operations +# for (f, fastf) in fast_op +# if DiffRules.hasdiffrule(:Base, f, 1) +# dx = DiffRules.diffrule(:Base, f, :x) +# Δ = :Δ +# if f in [:abs, :abs2] +# Δ = :(real($Δ)) +# else +# dx = :(conj($dx)) +# end +# @eval begin +# @adjoint Base.FastMath.$fastf(x::Number) = +# Base.FastMath.$fastf(x), Δ -> ($Δ * make_fastmath($dx),) +# end +# elseif DiffRules.hasdiffrule(:Base, f, 2) +# dx, dy = DiffRules.diffrule(:Base, f, :x, :y) +# @eval begin +# @adjoint Base.FastMath.$fastf(x::Number, y::Number) = +# Base.FastMath.$fastf(x, y), +# Δ -> (Δ * make_fastmath(conj($dx)), Δ * make_fastmath(conj($dy))) +# end +# end +# end From c9f2dd03aee9a3bcfcd705629d7e8e85896fa02f Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 14 Apr 2020 12:36:53 +0100 Subject: [PATCH 07/35] wrap CR types cleanup after merging in Option 2 code add ChainRules, rm DiffRules cleanup after merging in Option 2 code --- Project.toml | 2 +- 
src/Zygote.jl | 4 +-- src/compiler/chainrules.jl | 19 +++++++++++++ src/compiler/interface2.jl | 57 ++------------------------------------ src/lib/number.jl | 3 +- 5 files changed, 26 insertions(+), 59 deletions(-) create mode 100644 src/compiler/chainrules.jl diff --git a/Project.toml b/Project.toml index 6fce8246f..0a40ce70d 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" AbstractFFTs = "0.5" ArrayLayouts = "0.1, 0.2" FillArrays = "0.8" -ChainRules = "0.3.0" +ChainRules = "0.5.0" ForwardDiff = "0" IRTools = "0.3" MacroTools = "0.5" diff --git a/src/Zygote.jl b/src/Zygote.jl index d24ba2fc1..b10f5fbd1 100644 --- a/src/Zygote.jl +++ b/src/Zygote.jl @@ -3,11 +3,10 @@ module Zygote using LinearAlgebra, Statistics using LinearAlgebra: copytri!, AbstractTriangular using ArrayLayouts: MemoryLayout, AbstractColumnMajor -using ChainRules import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty -using ChainRules: ChainRules +using ChainRules: ChainRules, rrule, unthunk using IRTools using MacroTools, Requires using MacroTools: @forward @@ -19,6 +18,7 @@ include("tools/buffer.jl") include("compiler/reverse.jl") include("compiler/emit.jl") +include("compiler/chainrules.jl") include("compiler/interface.jl") include("compiler/show.jl") diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl new file mode 100644 index 000000000..9869cc178 --- /dev/null +++ b/src/compiler/chainrules.jl @@ -0,0 +1,19 @@ +const chainrules_fallback = which(rrule, Tuple{Any}) + +function has_chain_rrule(T) + m = meta(Tuple{typeof(rrule),T.parameters...}) + if m.method === chainrules_fallback + return false, m.code.edges + else + return true, nothing + end +end + +# For now we are just not going to deal with thunks +wrap_chainrules(x) = unthunk(x) +wrap_chainrules(x::Tuple) = map(wrap_chainrules, x) + +function chain_rrule(f, args...) + y, back = rrule(f, args...) 
+ y, dy -> wrap_chainrules(back(dy)) +end diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index 445b37b22..6219091bb 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -3,29 +3,18 @@ using IRTools.Inner: argnames!, update! ignore_sig(T) = all(T -> T <: Type, T.parameters) -const chainrules_fallback = which(rrule, Tuple{Any}) - -function has_chainrule(T) - m = meta(Tuple{typeof(rrule),T.parameters...}) - if m.method === chainrules_fallback - return false, m.code.edges - else - return true, nothing - end -end - @generated function _pullback(ctx::AContext, f, args...) T = Tuple{f,args...} ignore(T) && return :(f(args...), Pullback{$T}(())) - hascr, cr_edges = has_chainrule(T) - hascr && return :(rrule(f, args...)) + hascr, cr_edges = has_chain_rrule(T) + hascr && return :(chain_rrule(f, args...)) g = try _lookup_grad(T) catch e e end !(g isa Tuple) && return :(f(args...), Pullback{$T}((f,))) meta, forw, _ = g argnames!(meta, Symbol("#self#"), :ctx, :f, :args) forw = varargs!(meta, forw, 3) forw = slots!(pis!(inlineable!(forw))) - append!(meta.code.edges, cr_edges) + append!(meta.code.edges, cr_edges) # be ready to swap to using chainrule if one is declared return update!(meta.code, forw) end @@ -44,43 +33,3 @@ end back = slots!(inlineable!(back)) return update!(meta.code, back) end - - - - -#==""" - chainrules_blacklist(f, args...,) - -This is used to disable the use of ChainRule's definitions -for particular functions/methods. - -It is not required if a Zygote rule has already been defined directly. -"""==# -chainrules_blacklist(f, args...) = false - -# ChainRules does higher-order functions badly -# see https://github.com/JuliaDiff/ChainRules.jl/issues/122 -for f in (map, broadcast, mapreduce, mapfoldl, mapfoldr) - @eval chainrules_blacklist(::typeof($f), args...) 
= true -end -chainrules_blacklist(::typeof(sum), f, x::AbstractArray{<:Real}) = true -# Except for sum(abs2, xs), that is fine -chainrules_blacklist(::typeof(sum), ::typeof(abs2), x::AbstractArray{<:Real}) = false - - -#==""" - _pullback_via_chainrules(pb) - -Converts a ChainRules pullback into a Zygote pullback. -`pb` should be a ChainRules pullback, as returned from the second return value of `rrule` -"""==# -function _pullback_via_chainrules(pb) - function zygote_pullback(Δs...) - ∂s = pb(Δs...) - # TODO: Should not unthunk on the way out of a pullback, but rather on way in since - # that is when we know it is probably going to be used. - ∂s_zy = map(ChainRules.unthunk, ∂s) - @info "Invoking via ChainRules" typeof(pb) typeof(∂s) typeof(∂s_zy) - return ∂s_zy - end -end diff --git a/src/lib/number.jl b/src/lib/number.jl index 1a4069801..4c716add2 100644 --- a/src/lib/number.jl +++ b/src/lib/number.jl @@ -1,5 +1,4 @@ -using DiffRules, SpecialFunctions, NaNMath -using Base.FastMath: fast_op, make_fastmath +#using Base.FastMath: fast_op, make_fastmath # @nograd isinf, isnan, isfinite, div From 56c8402dbdb24ce4c7c711a826da4deeaeff372a Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 20 Apr 2020 16:55:35 +0100 Subject: [PATCH 08/35] wrap `nothing` --- src/compiler/chainrules.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 9869cc178..a9d49a7c0 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -14,6 +14,8 @@ wrap_chainrules(x) = unthunk(x) wrap_chainrules(x::Tuple) = map(wrap_chainrules, x) function chain_rrule(f, args...) - y, back = rrule(f, args...) - y, dy -> wrap_chainrules(back(dy)) + y, By = rrule(f, args...) 
+ back(::Nothing) = nothing + back(dy) = wrap_chainrules(By(dy)) + return y, back end From f64062134224414737624ec7871cbf0d27770232 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 20 Apr 2020 16:56:41 +0100 Subject: [PATCH 09/35] record fastmath regression --- test/chainrules.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/chainrules.jl b/test/chainrules.jl index e804c39b5..fcad939ed 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -47,3 +47,7 @@ end @test cr_inner_demo_rrule_hitcount[] == 1 end end + +@test_broken gradient(2.0) do x + @fastmath x^2.0 +end == (4.0,) From 6b354e5a0d420d8411836a39f3e2e7dbe73056ec Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 21 Apr 2020 19:02:17 +0100 Subject: [PATCH 10/35] conjugate appropriately Implement adjoints for abs and abs2 directly also conjugate on way in convert thing on way in as well as out (WIP for sin''') --- src/compiler/chainrules.jl | 33 +++++++++++++++++++++++++++------ src/lib/number.jl | 5 +++++ 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index a9d49a7c0..8bcb44198 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -10,12 +10,33 @@ function has_chain_rrule(T) end # For now we are just not going to deal with thunks -wrap_chainrules(x) = unthunk(x) -wrap_chainrules(x::Tuple) = map(wrap_chainrules, x) +wrap_chainrules_output(x) = conj(unthunk(x)) +wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) +function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T} + T_outer = T <: NamedTuple ? 
NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types + # Composite supports map as name preserving, and is fast + xp = map(wrap_chainrules_output, x) + convert(T_outer, xp) +end + +wrap_chainrules_input(x) = conj(x) +wrap_chainrules_input(x::Tuple) = map(wrap_chainrules_input, x) +wrap_chainrules_input(::Nothing) = ChainRules.Zero() +function wrap_chainrules_input(xs::NamedTuple) + xs_comp = ChainRules.Composite{Any}(xs...) + # Composite supports map as name preserving, and is fast + xs_comp_p = map(wrap_chainrules_input, xs_comp) +end + function chain_rrule(f, args...) - y, By = rrule(f, args...) - back(::Nothing) = nothing - back(dy) = wrap_chainrules(By(dy)) - return y, back + #@info "Using ChainRule" f, typeof.(args) + y, back = rrule(f, args...) + + zpullback(dy) = wrap_chainrules_output(back(wrap_chainrules_input(dy))) + # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 + # though it might be worth keeping as a performance optimization (benchmarking pending) + zpullback(::Nothing) = nothing + + y, zpullback end diff --git a/src/lib/number.jl b/src/lib/number.jl index 4c716add2..f2bdd4c24 100644 --- a/src/lib/number.jl +++ b/src/lib/number.jl @@ -70,6 +70,11 @@ end @adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),) @adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,) +@adjoint abs(x::Real) = abs(x), Δ -> (real(Δ)*sign(x),) +@adjoint abs(x::Complex) = abs(x), Δ -> (real(Δ)*x/abs(x),) +@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),) + + # DiffRules._abs_deriv(x::Complex) = x/abs(x) # # adjoint for Fastmath operations From b86e09ef9c403ab92843b74258bb744977238269 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 23 Apr 2020 12:58:18 +0100 Subject: [PATCH 11/35] multiple input and multiple output functions with chainrules --- src/compiler/chainrules.jl | 7 +- test/chainrules.jl | 161 ++++++++++++++++++++++++++++--------- 2 files changed, 128 insertions(+), 40 
deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 8bcb44198..6af91158d 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -13,7 +13,7 @@ end wrap_chainrules_output(x) = conj(unthunk(x)) wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T} - T_outer = T <: NamedTuple ? NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types + T_outer = T <: NamedTuple ? NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types # Composite supports map as name preserving, and is fast xp = map(wrap_chainrules_output, x) convert(T_outer, xp) @@ -28,12 +28,15 @@ function wrap_chainrules_input(xs::NamedTuple) xs_comp_p = map(wrap_chainrules_input, xs_comp) end +wrap_chainrules(f, args...) = wrap_chainrules_output(f(wrap_chainrules_input(args)...)) + function chain_rrule(f, args...) #@info "Using ChainRule" f, typeof.(args) y, back = rrule(f, args...) - zpullback(dy) = wrap_chainrules_output(back(wrap_chainrules_input(dy))) + zpullback(dy) = wrap_chainrules(back, dy) + zpullback(dy::Tuple) = wrap_chainrules(back, dy...) 
# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 # though it might be worth keeping as a performance optimization (benchmarking pending) zpullback(::Nothing) = nothing diff --git a/test/chainrules.jl b/test/chainrules.jl index fcad939ed..e39ca6ff8 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,53 +1,138 @@ using Zygote, Test, ChainRules -const cr_inner_demo_rrule_hitcount = Ref(0) -const cr_inner_demo_pullback_hitcount = Ref(0) -cr_inner_demo(x) = 5x -function ChainRules.rrule(::typeof(cr_inner_demo), x) - cr_inner_demo_rrule_hitcount[] += 1 - function cr_inner_demo_pullback(Δx) - cr_inner_demo_pullback_hitcount[] += 1 - return ChainRules.NO_FIELDS, 5.0*Δx + +@testset "ChainRules Integration" begin + @testset "basic" begin + cr_inner_demo_rrule_hitcount = Ref(0) + cr_inner_demo_pullback_hitcount = Ref(0) + cr_inner_demo(x) = 5x + function ChainRules.rrule(::typeof(cr_inner_demo), x) + cr_inner_demo_rrule_hitcount[] += 1 + function cr_inner_demo_pullback(Δx) + cr_inner_demo_pullback_hitcount[] += 1 + return ChainRules.NO_FIELDS, 5.0*Δx + end + return cr_inner_demo(x), cr_inner_demo_pullback + end + + function cr_outer_demo(x) + 2 + 10cr_inner_demo(x) + end + + + @testset "gradient inner" begin + cr_inner_demo_rrule_hitcount[] = 0 + cr_inner_demo_pullback_hitcount[] = 0 + @test (5.0,) == gradient(cr_inner_demo, 11) + @test cr_inner_demo_rrule_hitcount[] == 1 + @test cr_inner_demo_pullback_hitcount[] == 1 + end + + @testset "gradient outer" begin + cr_inner_demo_rrule_hitcount[] = 0 + cr_inner_demo_pullback_hitcount[] = 0 + @test (50.0,) == gradient(cr_outer_demo, 11) + @test cr_inner_demo_rrule_hitcount[] == 1 + @test cr_inner_demo_pullback_hitcount[] == 1 + end + + @testset "pullback inner" begin + cr_inner_demo_rrule_hitcount[] = 0 + cr_inner_demo_pullback_hitcount[] = 0 + y, pb = pullback(cr_inner_demo, 11) + @test y == 55 + @test cr_inner_demo_rrule_hitcount[] == 1 + @test cr_inner_demo_pullback_hitcount[] 
== 0 + @test pb(1)==(5.0,); + @test pb(2)==(10.0,); + @test pb(3)==(15.0,); + @test cr_inner_demo_pullback_hitcount[] == 3 + @test cr_inner_demo_rrule_hitcount[] == 1 + end end - return cr_inner_demo(x), cr_inner_demo_pullback -end -function cr_outer_demo(x) - 2 + 10cr_inner_demo(x) -end + @testset "Multiple output single input" begin + simo_rrule_hitcount = Ref(0) + simo_pullback_hitcount = Ref(0) + simo(x) = (5x, 7x) + function ChainRules.rrule(::typeof(simo), x) + simo_rrule_hitcount[] += 1 + function simo_pullback(Δa, Δb) + simo_pullback_hitcount[] += 1 + return ChainRules.NO_FIELDS, 5*Δa + 7*Δb + end + return simo(x), simo_pullback + end -@testset "ChainRules Integration" begin - @testset "gradient inner" begin - cr_inner_demo_rrule_hitcount[] = 0 - cr_inner_demo_pullback_hitcount[] = 0 - @test (5.0,) == gradient(cr_inner_demo, 11) - @test cr_inner_demo_rrule_hitcount[] == 1 - @test cr_inner_demo_pullback_hitcount[] == 1 + simo_outer(x) = sum(simo(x)) + + @assert simo_rrule_hitcount[] == 0 + @assert simo_pullback_hitcount[] == 0 + @test (12,) == Zygote.gradient(simo_outer, π) + @test simo_rrule_hitcount[] == 1 + @test simo_pullback_hitcount[] == 1 end - @testset "gradient outer" begin - cr_inner_demo_rrule_hitcount[] = 0 - cr_inner_demo_pullback_hitcount[] = 0 - @test (50.0,) == gradient(cr_outer_demo, 11) - @test cr_inner_demo_rrule_hitcount[] == 1 - @test cr_inner_demo_pullback_hitcount[] == 1 + @testset "multiple input, Single output" begin + miso_rrule_hitcount = Ref(0) + miso_pullback_hitcount = Ref(0) + miso(a, b) = 5a + 7b + function ChainRules.rrule(::typeof(miso), a, b) + miso_rrule_hitcount[] += 1 + function miso_pullback(Δy) + miso_pullback_hitcount[] += 1 + return ChainRules.NO_FIELDS, 5Δy , 7Δy + end + return miso(a, b), miso_pullback + end + + miso_outer(x) = miso(100x, 10x) + + @assert miso_rrule_hitcount[] == 0 + @assert miso_pullback_hitcount[] == 0 + @test (570,) == Zygote.gradient(miso_outer, π) + @test miso_rrule_hitcount[] == 1 + @test 
miso_pullback_hitcount[] == 1 end + + @testset "multiple input multiple output" begin + mimo_rrule_hitcount = Ref(0) + mimo_pullback_hitcount = Ref(0) + mimo(a, b) = (5a + 7b, 100a, 10b) + function ChainRules.rrule(::typeof(mimo), a, b) + mimo_rrule_hitcount[] += 1 + function mimo_pullback(Δx, Δy, Δz) + mimo_pullback_hitcount[] += 1 + return ChainRules.NO_FIELDS, 5Δx + 100Δy , 7Δx + 10Δz + end + return mimo(a, b), mimo_pullback + end + + @assert mimo_rrule_hitcount[] == 0 + @assert mimo_pullback_hitcount[] == 0 + _, pb = Zygote.pullback(mimo, π, 2π) + @test (105, 17) == pb((1, 1, 1)) + @test mimo_rrule_hitcount[] == 1 + @test mimo_pullback_hitcount[] == 1 + + mimo_outer(x) = sum(mimo(x, x)) - @testset "pullback inner" begin - cr_inner_demo_rrule_hitcount[] = 0 - cr_inner_demo_pullback_hitcount[] = 0 - y, pb = pullback(cr_inner_demo, 11) - @test y == 55 - @test cr_inner_demo_rrule_hitcount[] == 1 - @test cr_inner_demo_pullback_hitcount[] == 0 - @test pb(1)==(5.0,); - @test pb(2)==(10.0,); - @test pb(3)==(15.0,); - @test cr_inner_demo_pullback_hitcount[] == 3 - @test cr_inner_demo_rrule_hitcount[] == 1 + mimo_rrule_hitcount[] = 0 + mimo_pullback_hitcount[] = 0 + @test (122,) == gradient(mimo_outer, π) + @test mimo_rrule_hitcount[] == 1 + @test mimo_pullback_hitcount[] == 1 end end @test_broken gradient(2.0) do x @fastmath x^2.0 end == (4.0,) + + + + +mimo(a, b) = (5a + 7b, 100a, 10b) +_, pb = Zygote.pullback(mimo, 10, 100) + +pb((1, 1, 1)) From 02a302cfe694af304f9cf2f90d609dafefce9cad Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 23 Apr 2020 14:49:04 +0100 Subject: [PATCH 12/35] delete extra code added to test by mistake --- test/chainrules.jl | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index e39ca6ff8..4e8145e29 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -94,7 +94,7 @@ using Zygote, Test, ChainRules @test miso_rrule_hitcount[] == 1 @test miso_pullback_hitcount[] == 1 
end - + @testset "multiple input multiple output" begin mimo_rrule_hitcount = Ref(0) mimo_pullback_hitcount = Ref(0) @@ -128,11 +128,3 @@ end @test_broken gradient(2.0) do x @fastmath x^2.0 end == (4.0,) - - - - -mimo(a, b) = (5a + 7b, 100a, 10b) -_, pb = Zygote.pullback(mimo, 10, 100) - -pb((1, 1, 1)) From 3e1e8e09ff4bb0a9e152c07a917a6373cb7519a1 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 24 Apr 2020 10:21:55 +0100 Subject: [PATCH 13/35] Integration test that we have identity working right --- test/chainrules.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/chainrules.jl b/test/chainrules.jl index 4e8145e29..8254a8791 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -123,6 +123,23 @@ using Zygote, Test, ChainRules @test mimo_rrule_hitcount[] == 1 @test mimo_pullback_hitcount[] == 1 end + + @testset "nested AD hitting identity(::Tuple) pullback" begin + # This is is a particularly fiddly case. + # the adjoint of `tuple` is `identity` + # and `identity(::Tuple)`s pullback has multiple inputs + # (since the primal had multiple outputs) + + function g(y) + f(x) = tuple(x, 2x, 3x) + a1, pb1 = Zygote.pullback(f, π) + + pb1((y,0,0)) + end + + a2, pb2 = Zygote.pullback(g, 1) + @test pb2(1) == (1,) + end end @test_broken gradient(2.0) do x From 5a9ec6a48f46197ba1307559726bb4321eedec2f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 24 Apr 2020 18:40:21 +0100 Subject: [PATCH 14/35] Nested AD working --- src/compiler/chainrules.jl | 28 ++++++++++++++++++++-------- test/chainrules.jl | 17 +++++++++++++---- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 6af91158d..c99e2995b 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -9,30 +9,31 @@ function has_chain_rrule(T) end end -# For now we are just not going to deal with thunks -wrap_chainrules_output(x) = conj(unthunk(x)) + +wrap_chainrules_output(x) = conj(unthunk(x)) 
# For now we are just not going to deal with thunks wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) +wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T} T_outer = T <: NamedTuple ? NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types - # Composite supports map as name preserving, and is fast xp = map(wrap_chainrules_output, x) convert(T_outer, xp) end wrap_chainrules_input(x) = conj(x) -wrap_chainrules_input(x::Tuple) = map(wrap_chainrules_input, x) wrap_chainrules_input(::Nothing) = ChainRules.Zero() -function wrap_chainrules_input(xs::NamedTuple) - xs_comp = ChainRules.Composite{Any}(xs...) - # Composite supports map as name preserving, and is fast - xs_comp_p = map(wrap_chainrules_input, xs_comp) +function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) + xp = map(wrap_chainrules_input, xs) + ChainRules.Composite{Any, typeof(xp)}(xp) end wrap_chainrules(f, args...) = wrap_chainrules_output(f(wrap_chainrules_input(args)...)) + function chain_rrule(f, args...) #@info "Using ChainRule" f, typeof.(args) +# Core.println("Using ChainRule ", f," ", typeof.(args)) + y, back = rrule(f, args...) zpullback(dy) = wrap_chainrules(back, dy) @@ -41,5 +42,16 @@ function chain_rrule(f, args...) 
# though it might be worth keeping as a performance optimization (benchmarking pending) zpullback(::Nothing) = nothing + #== + function _zpullback(dy) + Core.print("Using ChainRule f=", f," args=", typeof.(args), "\n\tdy=", typeof(dy)) + dx = zpullback(dy) + Core.println(" dx=", typeof.(dx)) + return dx + end + ==# y, zpullback end + +# Required for nested AD +@adjoint ChainRules.Composite{Any, T}(x::T) where T = ChainRules.Composite{Any, T}(x), x->(x,) diff --git a/test/chainrules.jl b/test/chainrules.jl index 8254a8791..0afa6c953 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -130,15 +130,24 @@ using Zygote, Test, ChainRules # and `identity(::Tuple)`s pullback has multiple inputs # (since the primal had multiple outputs) + f(x) = tuple(x, 2x, 3x) + function g(y) - f(x) = tuple(x, 2x, 3x) a1, pb1 = Zygote.pullback(f, π) - pb1((y,0,0)) end - a2, pb2 = Zygote.pullback(g, 1) - @test pb2(1) == (1,) + @test (1,) == g(1) + + function h(n) + a2, pb2 = Zygote.pullback(g, 1) + pb2(n) + end + + @test (1,) == h(1) + + a3, pb3 = Zygote.pullback(h, 1) + @test ((1,),) == pb3(1) end end From d9db59b53467ae9b5ed7b57f06c86aa387d79c8d Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 28 Apr 2020 17:31:14 +0100 Subject: [PATCH 15/35] delete rules that are in chainrules --- src/lib/array.jl | 14 ++------------ src/lib/number.jl | 49 ++--------------------------------------------- 2 files changed, 4 insertions(+), 59 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index e9c5daaa5..f2634e9ad 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -347,8 +347,6 @@ end @adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),) @adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),) -@adjoint dot(x::AbstractArray, y::AbstractArray) = dot(x, y), Δ->(Δ .* y, Δ .* x) - function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) m1, n1 = size(mat1) mat1_rsh = reshape(mat1,(1,m1,1,n1)) @@ -361,18 
+359,8 @@ end @adjoint kron(a::AbstractMatrix, b::AbstractMatrix) = pullback(_kron, a, b) -@adjoint function Diagonal(d::AbstractVector) - back(Δ::NamedTuple) = (Δ.diag,) - back(Δ::AbstractMatrix) = (diag(Δ),) - return Diagonal(d), back -end - @adjoint diag(A::AbstractMatrix) = diag(A), Δ->(Diagonal(Δ),) -@adjoint det(xs::Union{Number, AbstractMatrix}) = det(xs), Δ -> (Δ * det(xs) * inv(xs)',) - -@adjoint logdet(xs::Union{Number, AbstractMatrix}) = logdet(xs), Δ -> (Δ * inv(xs)',) - @adjoint logabsdet(xs::AbstractMatrix) = logabsdet(xs), Δ -> (Δ[1] * inv(xs)',) @adjoint function inv(A::Union{Number, AbstractMatrix}) @@ -737,6 +725,8 @@ end end end +# ChainRules has this also but does not use FillArrays, so we have out own defination +# for improved performance. See https://github.com/JuliaDiff/ChainRules.jl/issues/46 Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix) # x is a squre matrix checked by tr, # so we could just use Eye(size(x, 1)) diff --git a/src/lib/number.jl b/src/lib/number.jl index f2bdd4c24..0425281e5 100644 --- a/src/lib/number.jl +++ b/src/lib/number.jl @@ -1,36 +1,3 @@ -#using Base.FastMath: fast_op, make_fastmath - -# @nograd isinf, isnan, isfinite, div - -# TODO use CSE here - -# for (M, f, arity) in DiffRules.diffrules() -# arity == 1 || continue -# Δ = :Δ -# dx = DiffRules.diffrule(M, f, :x) -# if f in [:abs, :abs2] -# Δ = :(real($Δ)) -# else -# dx = :(conj($dx)) -# end -# @eval begin -# @adjoint $M.$f(x::Number) = $M.$f(x), -# Δ -> ($Δ * $dx,) -# end -# end -# -# for (M, f, arity) in DiffRules.diffrules() -# arity == 2 || continue -# f == :^ && continue -# da, db = DiffRules.diffrule(M, f, :a, :b) -# @eval begin -# @adjoint $M.$f(a::Number, b::Number) = $M.$f(a, b), -# Δ -> (Δ * conj($da), Δ * conj($db)) -# end -# end - -@adjoint Base.:^(x::Number, p::Number) = x^p, - Δ -> (Δ * conj(p * x^(p-1)), Δ * conj(x^p * log(complex(x)))) @adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} = Base.literal_pow(^,x,Val(p)), Δ 
-> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing) @@ -44,20 +11,6 @@ end @adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs) -@adjoint Base.muladd(x::Number, y::Number, z::Number) = - Base.muladd(x, y, z), ō -> (y'ō, x'ō, ō) - -@adjoint Base.fma(x::Number, y::Number, z::Number) = - Base.fma(x, y, z), ō -> (y'ō, x'ō, ō) - -@adjoint function sincos(x) - s, c = sincos(x) - (s, c), ((s̄, c̄),) -> (s̄*c - c̄*s,) -end - -@adjoint acosh(x::Complex) = - acosh(x), Δ -> (Δ * conj(inv(sqrt(x - 1) * sqrt(x + 1))),) - @adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, - c̄ * a // b // b)) @nograd floor, ceil, trunc, round, hash @@ -70,6 +23,8 @@ end @adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),) @adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,) +# we intentionally define these here rather than falling back on ChainRules.jl +# because ChainRules doesn't really handle nonanalytic complex functions @adjoint abs(x::Real) = abs(x), Δ -> (real(Δ)*sign(x),) @adjoint abs(x::Complex) = abs(x), Δ -> (real(Δ)*x/abs(x),) @adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),) From cb9ec4ce1e765e6a99b592d383cd88ecc88f938c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 28 Apr 2020 17:32:15 +0100 Subject: [PATCH 16/35] delete diag rule also --- src/lib/array.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index f2634e9ad..7efc6f66f 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -359,8 +359,6 @@ end @adjoint kron(a::AbstractMatrix, b::AbstractMatrix) = pullback(_kron, a, b) -@adjoint diag(A::AbstractMatrix) = diag(A), Δ->(Diagonal(Δ),) - @adjoint logabsdet(xs::AbstractMatrix) = logabsdet(xs), Δ -> (Δ[1] * inv(xs)',) @adjoint function inv(A::Union{Number, AbstractMatrix}) From aa2914430f37fd609c5d5dcb5cf3abdae4a2e488 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 29 Apr 2020 11:26:40 +0100 Subject: [PATCH 17/35] bump verion --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/Project.toml b/Project.toml index 0a40ce70d..5a5ef45f9 100644 --- a/Project.toml +++ b/Project.toml @@ -23,7 +23,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" AbstractFFTs = "0.5" ArrayLayouts = "0.1, 0.2" FillArrays = "0.8" -ChainRules = "0.5.0" +ChainRules = "0.5.1" ForwardDiff = "0" IRTools = "0.3" MacroTools = "0.5" From 1fd415ab00a981475925beed5b837c01bf1e4619 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 29 Apr 2020 13:33:27 +0100 Subject: [PATCH 18/35] add in docs about prefer to use ChainRules Update docs/src/adjoints.md Co-Authored-By: Nick Robinson fix typo in docs delete debug printing --- docs/src/adjoints.md | 10 ++++++++++ src/compiler/chainrules.jl | 11 ----------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md index 61abed9a3..3fe760b28 100644 --- a/docs/src/adjoints.md +++ b/docs/src/adjoints.md @@ -1,5 +1,15 @@ # Custom Adjoints +!!! note "Prefer to use ChainRules to define custom adjoints" + Zygote supports the use of [ChainRulesCore](http://www.juliadiff.org/ChainRulesCore.jl/stable/) to define custom sensitivities. + It is prefered to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Zygote. + These sensitivities can be added in your own package, or for Base functions they can be added to ChainRules.jl. + + This documentation exists to descibe how Zygote works, and how adjoints can be directly defined for Zygote. + Defining adjoints this way does not make them accessable to other AD systems, but does let you do things that directly depend on how Zygote works. + It allows for specific definations of adjoints that are only defined for Zgyote (which might work differently to more generic definations defined for all AD) + + The `@adjoint` macro is an important part of Zygote's interface; customising your backwards pass is not only possible but widely used and encouraged. 
While there are specific utilities available for common things like gradient clipping, understanding adjoints will give you the most flexibility. We first give a bit more background on what these pullback things are. ## Pullbacks diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index c99e2995b..f23e076f5 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -31,9 +31,6 @@ wrap_chainrules(f, args...) = wrap_chainrules_output(f(wrap_chainrules_input(arg function chain_rrule(f, args...) - #@info "Using ChainRule" f, typeof.(args) -# Core.println("Using ChainRule ", f," ", typeof.(args)) - y, back = rrule(f, args...) zpullback(dy) = wrap_chainrules(back, dy) @@ -42,14 +39,6 @@ function chain_rrule(f, args...) # though it might be worth keeping as a performance optimization (benchmarking pending) zpullback(::Nothing) = nothing - #== - function _zpullback(dy) - Core.print("Using ChainRule f=", f," args=", typeof.(args), "\n\tdy=", typeof(dy)) - dx = zpullback(dy) - Core.println(" dx=", typeof.(dx)) - return dx - end - ==# y, zpullback end From d9a53a0bb24a2cddab9dbb14ca43c9ccf803b79b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 30 Apr 2020 14:49:23 +0100 Subject: [PATCH 19/35] delete commented out fastmath code link to chainrules issue about fastmath --- src/lib/number.jl | 27 --------------------------- test/chainrules.jl | 2 ++ 2 files changed, 2 insertions(+), 27 deletions(-) diff --git a/src/lib/number.jl b/src/lib/number.jl index 0425281e5..bb60e46c2 100644 --- a/src/lib/number.jl +++ b/src/lib/number.jl @@ -28,30 +28,3 @@ end @adjoint abs(x::Real) = abs(x), Δ -> (real(Δ)*sign(x),) @adjoint abs(x::Complex) = abs(x), Δ -> (real(Δ)*x/abs(x),) @adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),) - - -# DiffRules._abs_deriv(x::Complex) = x/abs(x) - -# # adjoint for Fastmath operations -# for (f, fastf) in fast_op -# if DiffRules.hasdiffrule(:Base, f, 1) -# dx = DiffRules.diffrule(:Base, f, :x) -# Δ = :Δ -# if
f in [:abs, :abs2] -# Δ = :(real($Δ)) -# else -# dx = :(conj($dx)) -# end -# @eval begin -# @adjoint Base.FastMath.$fastf(x::Number) = -# Base.FastMath.$fastf(x), Δ -> ($Δ * make_fastmath($dx),) -# end -# elseif DiffRules.hasdiffrule(:Base, f, 2) -# dx, dy = DiffRules.diffrule(:Base, f, :x, :y) -# @eval begin -# @adjoint Base.FastMath.$fastf(x::Number, y::Number) = -# Base.FastMath.$fastf(x, y), -# Δ -> (Δ * make_fastmath(conj($dx)), Δ * make_fastmath(conj($dy))) -# end -# end -# end diff --git a/test/chainrules.jl b/test/chainrules.jl index 0afa6c953..a01ae7234 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -151,6 +151,8 @@ using Zygote, Test, ChainRules end end +# ChainRules doesn't have support for FastMath yet, so this fails +# https://github.com/JuliaDiff/ChainRules.jl/issues/174 @test_broken gradient(2.0) do x @fastmath x^2.0 end == (4.0,) From 5cf8b1066dbefd002a82931c87981eecc34651a6 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 30 Apr 2020 16:55:44 +0100 Subject: [PATCH 20/35] comment all about the chainrules interface --- src/compiler/chainrules.jl | 43 +++++++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index f23e076f5..9fd9e36d3 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -1,5 +1,14 @@ const chainrules_fallback = which(rrule, Tuple{Any}) +""" + has_chain_rrule(T) + +For a type-tuple `T` e.g. `Tuple{typeof(f), Int, Float64}`, checks if there is a `rrule` defined for it. +Excluding the generic fallback. +The first return value is a Bool is whether or not the `rrule` exists. +If it does not, then the second argument is a list of edges to attach to the CodeInfo for a generated function, +such that if a suitable rule is defined later, the generated function will recompile. 
+""" function has_chain_rrule(T) m = meta(Tuple{typeof(rrule),T.parameters...}) if m.method === chainrules_fallback @@ -9,7 +18,12 @@ function has_chain_rrule(T) end end +""" + wrap_chainrules_output(x) +Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally +(including conjugating complex gradients). +""" wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing @@ -19,6 +33,13 @@ function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T} convert(T_outer, xp) end + +""" + wrap_chainrules_input(x) + +Convert `x` from the format Zygote uses internally (including conjugated complex gradients) +to differentials types ChainRules uses. +""" wrap_chainrules_input(x) = conj(x) wrap_chainrules_input(::Nothing) = ChainRules.Zero() function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) @@ -26,15 +47,31 @@ function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) ChainRules.Composite{Any, typeof(xp)}(xp) end -wrap_chainrules(f, args...) = wrap_chainrules_output(f(wrap_chainrules_input(args)...)) +""" + wrap_chainrules_pullback(f, args...) + +Wrap a chainrule's pullback `f`, converting the format of the inputs (`args`), +and the outputs. +""" +function wrap_chainrules_pullback(pb, args...) + returun wrap_chainrules_output(pb(wrap_chainrules_input(args)...)) +end +""" + chain_rrule(f, args...) +Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`. +The pullback is appropriately wrapped up to follow Zygote conventions. +""" function chain_rrule(f, args...) y, back = rrule(f, args...) - zpullback(dy) = wrap_chainrules(back, dy) - zpullback(dy::Tuple) = wrap_chainrules(back, dy...) + # Dispatch here handles chainrules considing pullbacks to have multiple input if Tuple. 
+ # TODO: this could be removed if: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 + zpullback(dy) = wrap_chainrules_pullback(back, dy) + zpullback(dy::Tuple) = wrap_chainrules_pullback(back, dy...) + # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 # though it might be worth keeping as a performance optimization (benchmarking pending) zpullback(::Nothing) = nothing From 53000503dc7a829d59a54fdd80fbac07fec5ef01 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 30 Apr 2020 16:58:38 +0100 Subject: [PATCH 21/35] Update the Manifest.toml fix typo Pin IRTools to 0.3.2 because https://github.com/MikeInnes/IRTools.jl/issues/58 --- Project.toml | 4 ++-- src/compiler/chainrules.jl | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 5a5ef45f9..3b1778758 100644 --- a/Project.toml +++ b/Project.toml @@ -22,10 +22,10 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5" ArrayLayouts = "0.1, 0.2" -FillArrays = "0.8" ChainRules = "0.5.1" +FillArrays = "0.8" ForwardDiff = "0" -IRTools = "0.3" +IRTools = "=0.3.1" MacroTools = "0.5" NNlib = "0.6.5" Requires = "0.5, 1.0" diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 9fd9e36d3..f0bc83c7b 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -12,6 +12,7 @@ such that if a suitable rule is defined later, the generated function will recom function has_chain_rrule(T) m = meta(Tuple{typeof(rrule),T.parameters...}) if m.method === chainrules_fallback + @assert m.code.edges !== nothing return false, m.code.edges else return true, nothing @@ -54,7 +55,7 @@ Wrap a chainrule's pullback `f`, converting the format of the inputs (`args`), and the outputs. """ function wrap_chainrules_pullback(pb, args...) 
- returun wrap_chainrules_output(pb(wrap_chainrules_input(args)...)) + return wrap_chainrules_output(pb(wrap_chainrules_input(args)...)) end @@ -76,7 +77,7 @@ function chain_rrule(f, args...) # though it might be worth keeping as a performance optimization (benchmarking pending) zpullback(::Nothing) = nothing - y, zpullback + return y, zpullback end # Required for nested AD From 329b53063b3daf45a594d824a36268b6f100f59a Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 4 May 2020 11:47:50 +0100 Subject: [PATCH 22/35] Update docs/src/adjoints.md Co-authored-by: Pietro Vertechi Update docs/src/adjoints.md --- docs/src/adjoints.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md index 3fe760b28..82170164c 100644 --- a/docs/src/adjoints.md +++ b/docs/src/adjoints.md @@ -6,8 +6,8 @@ These sensitivities can be added in your own package, or for Base functions they can be added to ChainRules.jl. This documentation exists to descibe how Zygote works, and how adjoints can be directly defined for Zygote. - Defining adjoints this way does not make them accessable to other AD systems, but does let you do things that directly depend on how Zygote works. - It allows for specific definations of adjoints that are only defined for Zgyote (which might work differently to more generic definations defined for all AD) + Defining adjoints this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Zygote works. + It allows for specific definitions of adjoints that are only defined for Zygote (which might work differently to more generic definitions defined for all AD). The `@adjoint` macro is an important part of Zygote's interface; customising your backwards pass is not only possible but widely used and encouraged. While there are specific utilities available for common things like gradient clipping, understanding adjoints will give you the most flexibility.
We first give a bit more background on what these pullback things are. From 8cd76f4b84bb0b3ce186d8259395e89c614820fa Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 4 May 2020 17:22:24 +0100 Subject: [PATCH 23/35] Update src/lib/array.jl Co-authored-by: AzamatB --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 7efc6f66f..056043323 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -723,7 +723,7 @@ end end end -# ChainRules has this also but does not use FillArrays, so we have out own defination +# ChainRules has this also but does not use FillArrays, so we have our own definition # for improved performance. See https://github.com/JuliaDiff/ChainRules.jl/issues/46 Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix) # x is a squre matrix checked by tr, From 497846a10ff5060b7a3a6227b869ff20ad33eee3 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Mon, 4 May 2020 19:01:29 +0100 Subject: [PATCH 24/35] WIP: support kwargs make kwargs work Update src/compiler/chainrules.jl Update src/compiler/chainrules.jl and chainrules kwarg tests --- src/compiler/chainrules.jl | 46 +++++++++++++++++++++++++++++++++----- test/chainrules.jl | 26 +++++++++++++++++++++ 2 files changed, 66 insertions(+), 6 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index f0bc83c7b..d04017c9d 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -11,14 +11,37 @@ such that if a suitable rule is defined later, the generated function will recom """ function has_chain_rrule(T) m = meta(Tuple{typeof(rrule),T.parameters...}) - if m.method === chainrules_fallback - @assert m.code.edges !== nothing - return false, m.code.edges - else + if m.method !== chainrules_fallback + # found a rrule, no need to add any edges return true, nothing end + + # Could be a kwarg function, handle that case + if is_kwfunc(T.parameters...) 
+ # Need to check for rrule for function not the kwfunction. + base_T = Tuple{T.parameters[3:end]...} + return has_chain_rrule(base_T) + end + + # did not find anything, will have to attach edges so it recompiles if one is added + @assert m.code.edges !== nothing + return false, m.code.edges end +""" + is_kwfunc(sigt...) + +Determines if `sigt` is the type signature of a kwfunction. +Each element of `sigt` should be a type. +Either the first 3 types are a kwfunc type, a NamedTuple and the matching base function type, +or the first argument is the base function type and it is not a kwfunction. +the remaining types in `sigt` are the types of the argument. + +""" +is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k==Core.kwftype(f) +is_kwfunc(::Vararg{Any}) = false + + """ wrap_chainrules_output(x) @@ -66,7 +89,7 @@ Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainR The pullback is appropriately wrapped up to follow Zygote conventions. """ function chain_rrule(f, args...) - y, back = rrule(f, args...) + local back # Dispatch here handles chainrules considing pullbacks to have multiple input if Tuple. # TODO: this could be removed if: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 @@ -77,7 +100,18 @@ function chain_rrule(f, args...) # though it might be worth keeping as a performance optimization (benchmarking pending) zpullback(::Nothing) = nothing - return y, zpullback + if is_kwfunc(typeof(f), typeof.(args)...) + kwargs = args[1] + base_f = args[2] + pos_args = args[3:end] + y, back = rrule(base_f, pos_args...; kwargs...) + + kw_zpullback(dy) = (nothing, nothing, zpullback(dy)...) # first two nothings are for kwfunc and kwargs + return y, kw_zpullback + else + y, back = rrule(f, args...) 
+ return y, zpullback + end end # Required for nested AD diff --git a/test/chainrules.jl b/test/chainrules.jl index a01ae7234..720080825 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -149,6 +149,32 @@ using Zygote, Test, ChainRules a3, pb3 = Zygote.pullback(h, 1) @test ((1,),) == pb3(1) end + + + @testset "kwargs" begin + kwfoo_rrule_hitcount = Ref(0) + kwfoo_pullback_hitcount = Ref(0) + kwfoo(x; k=10) = x + k + function ChainRules.rrule(::typeof(kwfoo), x; k=10) + kwfoo_rrule_hitcount[] += 1 + function kwfoo_pullback(Δy) + kwfoo_pullback_hitcount[] += 1 + return ChainRules.NO_FIELDS, Δy + end + return kwfoo(x; k=k), kwfoo_pullback + end + + kwfoo_outer_unused(x) = kwfoo(x) + kwfoo_outer_used(x) = kwfoo(x; k=-15) + + @testset "$outer" for outer in (kwfoo_outer_used, kwfoo_outer_unused) + kwfoo_rrule_hitcount[] = 0 + kwfoo_pullback_hitcount[] = 0 + @test (1,) == Zygote.gradient(outer, π) + @test kwfoo_rrule_hitcount[] == 1 + @test kwfoo_pullback_hitcount[] == 1 + end + end end # ChainRules doesn't have support for FastMath yet, so this fails From 2ff6cc3ba6d29443e16d4d30a87399abadd619ed Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 5 May 2020 15:20:04 +0100 Subject: [PATCH 25/35] Fix type inference --- src/compiler/chainrules.jl | 39 ++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index d04017c9d..544bbb008 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -38,8 +38,10 @@ or the first argument is the base function type and it is not a kwfunction. the remaining types in `sigt` are the types of the argument. """ -is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k==Core.kwftype(f) is_kwfunc(::Vararg{Any}) = false +# Needs `@pure` because else will not run during type inference. 
+# This is pure enough, the only generic function it calls is in `Core` overloading `Core.kwftype` will no doubt break many other things also +Base.@pure is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f) """ @@ -81,6 +83,22 @@ function wrap_chainrules_pullback(pb, args...) return wrap_chainrules_output(pb(wrap_chainrules_input(args)...)) end +""" + ZBack{F}(back) <: Function + +Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions. +(A functor here is used rather than a closure to avoid boxing issues); +""" +struct ZBack{F} <: Function + back::F +end +(s::ZBack)(dy) = wrap_chainrules_pullback(s.back, dy) +# Dispatch here handles chainrules considing pullbacks to have multiple input if Tuple. +# TODO: this could be removed if: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 +(s::ZBack)(dy::Tuple) = wrap_chainrules_pullback(s.back, dy...) +# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 +# though it might be worth keeping as a performance optimization (benchmarking pending) +(s::ZBack)(::Nothing) = nothing """ chain_rrule(f, args...) @@ -89,28 +107,17 @@ Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainR The pullback is appropriately wrapped up to follow Zygote conventions. """ function chain_rrule(f, args...) - local back - - # Dispatch here handles chainrules considing pullbacks to have multiple input if Tuple. - # TODO: this could be removed if: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 - zpullback(dy) = wrap_chainrules_pullback(back, dy) - zpullback(dy::Tuple) = wrap_chainrules_pullback(back, dy...) - - # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 - # though it might be worth keeping as a performance optimization (benchmarking pending) - zpullback(::Nothing) = nothing - - if is_kwfunc(typeof(f), typeof.(args)...) + if is_kwfunc(typeof(f), map(typeof, args)...) 
kwargs = args[1] base_f = args[2] pos_args = args[3:end] - y, back = rrule(base_f, pos_args...; kwargs...) + y, base_f_back = rrule(base_f, pos_args...; kwargs...) - kw_zpullback(dy) = (nothing, nothing, zpullback(dy)...) # first two nothings are for kwfunc and kwargs + kw_zpullback(dy) = (nothing, nothing, ZBack(base_f_back)(dy)...) # first two nothings are for kwfunc and kwargs return y, kw_zpullback else y, back = rrule(f, args...) - return y, zpullback + return y, ZBack(back) end end From 513b787b427fade319890dea6ece072ab144c5db Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 5 May 2020 16:20:40 +0100 Subject: [PATCH 26/35] Fix nexting --- src/compiler/chainrules.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 544bbb008..841e2ea70 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -107,7 +107,9 @@ Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainR The pullback is appropriately wrapped up to follow Zygote conventions. """ function chain_rrule(f, args...) - if is_kwfunc(typeof(f), map(typeof, args)...) 
+ # Note we avoid using `map(typeof, args)...` in the condition as it complicates nested AD + # so we just check relevent ones by hand + if length(args) >= 2 && is_kwfunc(typeof(f), typeof(args[1]), typeof(args[2])) kwargs = args[1] base_f = args[2] pos_args = args[3:end] From 29e0244728f12e756111c61f8ecbacdb7bd9c975 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 5 May 2020 21:09:44 +0100 Subject: [PATCH 27/35] Update src/compiler/chainrules.jl Co-authored-by: Carlo Lucibello --- src/compiler/chainrules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 841e2ea70..2a2572f83 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -5,7 +5,7 @@ const chainrules_fallback = which(rrule, Tuple{Any}) For a type-tuple `T` e.g. `Tuple{typeof(f), Int, Float64}`, checks if there is a `rrule` defined for it. Excluding the generic fallback. -The first return value is a Bool is whether or not the `rrule` exists. +The first return value is `true` if the `rrule` exists, `false` otherwise. If it does not, then the second argument is a list of edges to attach to the CodeInfo for a generated function, such that if a suitable rule is defined later, the generated function will recompile. """ From 325d3047a247bf2f06026bf471b8dece3fe4fdc7 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Tue, 5 May 2020 21:12:07 +0100 Subject: [PATCH 28/35] Update src/compiler/chainrules.jl --- src/compiler/chainrules.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 2a2572f83..593468b92 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -40,7 +40,8 @@ the remaining types in `sigt` are the types of the argument. """ is_kwfunc(::Vararg{Any}) = false # Needs `@pure` because else will not run during type inference. 
-# This is pure enough, the only generic function it calls is in `Core` overloading `Core.kwftype` will no doubt break many other things also +# This is pure enough, the only generic function it calls is in `Core` +# overloading `Core.kwftype` will no doubt break many other things also Base.@pure is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f) From 1432563be51c6676af4ce4504171ebddd95dbffe Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 6 May 2020 12:02:10 +0100 Subject: [PATCH 29/35] pre1.3 do not worry about edges --- src/compiler/chainrules.jl | 9 +++++++-- src/compiler/interface2.jl | 4 +++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 593468b92..1cad9b112 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -24,8 +24,13 @@ function has_chain_rrule(T) end # did not find anything, will have to attach edges so it recompiles if one is added - @assert m.code.edges !== nothing - return false, m.code.edges + @static if VERSION >= v"1.3" + @assert m.code.edges !== nothing + return false, m.code.edges + else + # pre-julia 1.3 there are no edges + return false, tuple() + end end """ diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index 6219091bb..f55db66c3 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -14,7 +14,9 @@ ignore_sig(T) = all(T -> T <: Type, T.parameters) argnames!(meta, Symbol("#self#"), :ctx, :f, :args) forw = varargs!(meta, forw, 3) forw = slots!(pis!(inlineable!(forw))) - append!(meta.code.edges, cr_edges) # be ready to swap to using chainrule if one is declared + @static if VERSION >= v"1.3" # no edges pre-1.3 + append!(meta.code.edges, cr_edges) # be ready to swap to using chainrule if one is declared + end return update!(meta.code, forw) end From e93eca1e371241f1cb1d05a29044cfbc5c871201 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 6 May 2020 19:02:48 +0100 Subject: 
[PATCH 30/35] Fix constant folding Decide if is kwfunc at compile time. --- src/compiler/chainrules.jl | 70 ++++++++++++++++++++------------------ src/compiler/interface2.jl | 10 ++++-- 2 files changed, 44 insertions(+), 36 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 1cad9b112..b06b3d45d 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -16,13 +16,6 @@ function has_chain_rrule(T) return true, nothing end - # Could be a kwarg function, handle that case - if is_kwfunc(T.parameters...) - # Need to check for rrule for function not the kwfunction. - base_T = Tuple{T.parameters[3:end]...} - return has_chain_rrule(base_T) - end - # did not find anything, will have to attach edges so it recompiles if one is added @static if VERSION >= v"1.3" @assert m.code.edges !== nothing @@ -45,7 +38,7 @@ the remaining types in `sigt` are the types of the argument. """ is_kwfunc(::Vararg{Any}) = false # Needs `@pure` because else will not run during type inference. -# This is pure enough, the only generic function it calls is in `Core` +# This is pure enough, the only generic function it calls is in `Core` # overloading `Core.kwftype` will no doubt break many other things also Base.@pure is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f) @@ -56,10 +49,10 @@ Base.@pure is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f) Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally (including conjugating complex gradients). 
""" -wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks -wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) -wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing -function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T} +@inline wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks +@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) +@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing +@inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T} T_outer = T <: NamedTuple ? NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types xp = map(wrap_chainrules_output, x) convert(T_outer, xp) @@ -72,9 +65,9 @@ end Convert `x` from the format Zygote uses internally (including conjugated complex gradients) to differentials types ChainRules uses. """ -wrap_chainrules_input(x) = conj(x) -wrap_chainrules_input(::Nothing) = ChainRules.Zero() -function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) +@inline wrap_chainrules_input(x) = conj(x) +@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero() +@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple}) xp = map(wrap_chainrules_input, xs) ChainRules.Composite{Any, typeof(xp)}(xp) end @@ -85,10 +78,18 @@ end Wrap a chainrule's pullback `f`, converting the format of the inputs (`args`), and the outputs. """ -function wrap_chainrules_pullback(pb, args...) +@inline function wrap_chainrules_pullback(pb, args...) 
return wrap_chainrules_output(pb(wrap_chainrules_input(args)...)) end +# Note we hand-expess the single arg version of this to remove splatting +# because splatting breaks constant folding +# This can be removed after https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 +@inline function wrap_chainrules_pullback(pb, a) + return wrap_chainrules_output(pb(wrap_chainrules_input(a))) +end + + """ ZBack{F}(back) <: Function @@ -98,13 +99,13 @@ Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conven struct ZBack{F} <: Function back::F end -(s::ZBack)(dy) = wrap_chainrules_pullback(s.back, dy) +@inline (s::ZBack)(dy) = wrap_chainrules_pullback(s.back, dy) # Dispatch here handles chainrules considing pullbacks to have multiple input if Tuple. # TODO: this could be removed if: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 -(s::ZBack)(dy::Tuple) = wrap_chainrules_pullback(s.back, dy...) +@inline (s::ZBack)(dy::Tuple) = wrap_chainrules_pullback(s.back, dy...) # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 # though it might be worth keeping as a performance optimization (benchmarking pending) -(s::ZBack)(::Nothing) = nothing +@inline (s::ZBack)(::Nothing) = nothing """ chain_rrule(f, args...) @@ -112,21 +113,22 @@ end Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`. The pullback is appropriately wrapped up to follow Zygote conventions. """ -function chain_rrule(f, args...) - # Note we avoid using `map(typeof, args)...` in the condition as it complicates nested AD - # so we just check relevent ones by hand - if length(args) >= 2 && is_kwfunc(typeof(f), typeof(args[1]), typeof(args[2])) - kwargs = args[1] - base_f = args[2] - pos_args = args[3:end] - y, base_f_back = rrule(base_f, pos_args...; kwargs...) - - kw_zpullback(dy) = (nothing, nothing, ZBack(base_f_back)(dy)...) 
# first two nothings are for kwfunc and kwargs - return y, kw_zpullback - else - y, back = rrule(f, args...) - return y, ZBack(back) - end +@inline function chain_rrule(f, args...) + y, back = rrule(f, args...) + return y, ZBack(back) +end + + +""" + chain_rrule_kw(kwf, kwargs, f, args...) + +As per [`chain_rrule`](@ref) but with support for kwargs. +`kwf` should be the kwfunc matching to `f`, and `kwargs` are a `NamedTuple` of keyword arguments. +""" +@inline function chain_rrule_kw(kwf, kwargs, f, args...) + y, back = rrule(f, args...; kwargs...) + kw_zpullback(dy) = (nothing, nothing, ZBack(back)(dy)...) # first two nothings are for kwfunc and kwargs + return y, kw_zpullback end # Required for nested AD diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index f55db66c3..65a07ac8d 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -6,8 +6,14 @@ ignore_sig(T) = all(T -> T <: Type, T.parameters) @generated function _pullback(ctx::AContext, f, args...) T = Tuple{f,args...} ignore(T) && return :(f(args...), Pullback{$T}(())) - hascr, cr_edges = has_chain_rrule(T) - hascr && return :(chain_rrule(f, args...)) + + iskw = is_kwfunc(f, args...) + # if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function + base_T = iskw ? Tuple{args[2:end]...} : T + hascr, cr_edges = has_chain_rrule(base_T) + chain_rrule_f = iskw ? 
:chain_rrule_kw : :chain_rrule + hascr && return :($chain_rrule_f(f, args...)) + g = try _lookup_grad(T) catch e e end !(g isa Tuple) && return :(f(args...), Pullback{$T}((f,))) meta, forw, _ = g From 1b4db9befddf936670f3154999bb6dc0a25db4bc Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 7 May 2020 14:33:20 +0100 Subject: [PATCH 31/35] remove Manifest --- Manifest.toml | 208 -------------------------------------------------- 1 file changed, 208 deletions(-) delete mode 100644 Manifest.toml diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index ee7525f32..000000000 --- a/Manifest.toml +++ /dev/null @@ -1,208 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "051c95d6836228d120f5f4b984dd5aba1624f716" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "0.5.0" - -[[ArrayLayouts]] -deps = ["FillArrays", "LinearAlgebra"] -git-tree-sha1 = "a504dca2ac7eda8761c8f7c1ed52427a1be75a3c" -uuid = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" -version = "0.2.6" - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[BinaryProvider]] -deps = ["Libdl", "Logging", "SHA"] -git-tree-sha1 = "428e9106b1ff27593cbd979afac9b45b82372b8c" -uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232" -version = "0.5.9" - -[[ChainRules]] -deps = ["ChainRulesCore", "LinearAlgebra", "Reexport", "Requires", "Statistics"] -git-tree-sha1 = "f7175b1c1275b55e67b926c8d5ba57188b01c679" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.5.2" - -[[ChainRulesCore]] -deps = ["MuladdMacro"] -git-tree-sha1 = "e7f1b2b4ba7146575e1a30345e0ae842ae4c74d8" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.7.5" - -[[CommonSubexpressions]] -deps = ["Test"] -git-tree-sha1 = "efdaf19ab11c7889334ca247ff4c9f7c322817b0" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.2.0" - -[[CompilerSupportLibraries_jll]] -deps = ["Libdl", "Pkg"] -git-tree-sha1 = 
"7c4f882c41faa72118841185afc58a2eb00ef612" -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "0.3.3+0" - -[[Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[DiffResults]] -deps = ["StaticArrays"] -git-tree-sha1 = "da24935df8e0c6cf28de340b958f6aac88eaa0cc" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.0.2" - -[[DiffRules]] -deps = ["NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "eb0c34204c8410888844ada5359ac8b96292cfd1" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.0.1" - -[[Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "5322d34d7600d3429665b37bcf7628dc602a28cc" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.8.8" - -[[ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "869540e4367122fbffaace383a5bdc34d6e5e5ac" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.10" - -[[IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "1a4355e4b5b50be2311ebb644f34f3306dbd0410" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.3.1" - -[[InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[LibGit2]] -deps = ["Printf"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[LinearAlgebra]] -deps = ["Libdl"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.5" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[MuladdMacro]] -git-tree-sha1 = 
"c6190f9a7fc5d9d5915ab29f2134421b12d24a68" -uuid = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" -version = "0.2.2" - -[[NNlib]] -deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"] -git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824" -uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" -version = "0.6.6" - -[[NaNMath]] -git-tree-sha1 = "928b8ca9b2791081dc71a51c55347c27c618760f" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "0.3.3" - -[[OpenSpecFun_jll]] -deps = ["CompilerSupportLibraries_jll", "Libdl", "Pkg"] -git-tree-sha1 = "d51c416559217d974a1113522d5919235ae67a87" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.3+3" - -[[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" - -[[Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[Reexport]] -deps = ["Pkg"] -git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "0.2.0" - -[[Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "d37400976e98018ee840e0ca4f9d20baa231dc6b" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.0.1" - -[[SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[SparseArrays]] -deps = ["LinearAlgebra", "Random"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[SpecialFunctions]] -deps = ["OpenSpecFun_jll"] -git-tree-sha1 = "e19b98acb182567bcb7b75bb5d9eedf3a3b5ec6c" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "0.10.0" - -[[StaticArrays]] -deps = ["LinearAlgebra", "Random", "Statistics"] -git-tree-sha1 = 
"5c06c0aeb81bef54aed4b3f446847905eb6cbda0" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "0.12.3" - -[[Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" - -[[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[ZygoteRules]] -deps = ["MacroTools"] -git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.0" From cfeda77d8f4b70a8c9dfedd325241a05c4dfc41f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 7 May 2020 18:49:35 +0100 Subject: [PATCH 32/35] Remove pure as not needed anymore --- src/compiler/chainrules.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index b06b3d45d..2ccd37603 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -37,10 +37,7 @@ the remaining types in `sigt` are the types of the argument. """ is_kwfunc(::Vararg{Any}) = false -# Needs `@pure` because else will not run during type inference. -# This is pure enough, the only generic function it calls is in `Core` -# overloading `Core.kwftype` will no doubt break many other things also -Base.@pure is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f) +is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) 
= k===Core.kwftype(f) """ From cef928c5f4af09cd1f828758fc34c13340cbbb19 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 20 May 2020 18:36:55 +0100 Subject: [PATCH 33/35] Remove special handling of multiple inputs to pullbacks from ChainRules (#1) * ChainRules pullbacks always have 1 input https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 * swap to version of chainrules that don't use multiarg pullbacks * update tests * make so don't need custom rule anymore * add comment * Update src/compiler/chainrules.jl Co-authored-by: willtebbutt Co-authored-by: willtebbutt --- Project.toml | 2 +- src/compiler/chainrules.jl | 39 +++++++++----------------------------- test/chainrules.jl | 18 ++++++++---------- 3 files changed, 18 insertions(+), 41 deletions(-) diff --git a/Project.toml b/Project.toml index 3b1778758..aded1c503 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5" ArrayLayouts = "0.1, 0.2" -ChainRules = "0.5.1" +ChainRules = "0.6.0" FillArrays = "0.8" ForwardDiff = "0" IRTools = "=0.3.1" diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 2ccd37603..5136c73dd 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -49,13 +49,16 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote u @inline wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks @inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) @inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing -@inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T} - T_outer = T <: NamedTuple ? 
NamedTuple : Tuple # must be a Tuple or NamedTuple, don't care about exact parameter types - xp = map(wrap_chainrules_output, x) - convert(T_outer, xp) +for T_outer in (:Tuple, :NamedTuple) + # we create separate methods rather than using a `Union` + an `if` so that we avoid a + # branch that changes output type, because nested AD on that kinda thing makes Zygote less + # than happy. + @eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer} + xp = map(wrap_chainrules_output, x) + convert($T_outer, xp) + end end - """ wrap_chainrules_input(x) @@ -69,24 +72,6 @@ to differentials types ChainRules uses. ChainRules.Composite{Any, typeof(xp)}(xp) end -""" - wrap_chainrules_pullback(f, args...) - -Wrap a chainrule's pullback `f`, converting the format of the inputs (`args`), -and the outputs. -""" -@inline function wrap_chainrules_pullback(pb, args...) - return wrap_chainrules_output(pb(wrap_chainrules_input(args)...)) -end - -# Note we hand-expess the single arg version of this to remove splatting -# because splatting breaks constant folding -# This can be removed after https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 -@inline function wrap_chainrules_pullback(pb, a) - return wrap_chainrules_output(pb(wrap_chainrules_input(a))) -end - - """ ZBack{F}(back) <: Function @@ -96,10 +81,7 @@ Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conven struct ZBack{F} <: Function back::F end -@inline (s::ZBack)(dy) = wrap_chainrules_pullback(s.back, dy) -# Dispatch here handles chainrules considing pullbacks to have multiple input if Tuple. -# TODO: this could be removed if: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/152 -@inline (s::ZBack)(dy::Tuple) = wrap_chainrules_pullback(s.back, dy...) 
+@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy))) # `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603 # though it might be worth keeping as a performance optimization (benchmarking pending) @inline (s::ZBack)(::Nothing) = nothing @@ -127,6 +109,3 @@ As per [`chain_rrule`](@ref) but with support for kwargs. kw_zpullback(dy) = (nothing, nothing, ZBack(back)(dy)...) # first two nothings are for kwfunc and kwargs return y, kw_zpullback end - -# Required for nested AD -@adjoint ChainRules.Composite{Any, T}(x::T) where T = ChainRules.Composite{Any, T}(x), x->(x,) diff --git a/test/chainrules.jl b/test/chainrules.jl index 720080825..5fceb863c 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -57,7 +57,7 @@ using Zygote, Test, ChainRules simo(x) = (5x, 7x) function ChainRules.rrule(::typeof(simo), x) simo_rrule_hitcount[] += 1 - function simo_pullback(Δa, Δb) + function simo_pullback((Δa, Δb)) simo_pullback_hitcount[] += 1 return ChainRules.NO_FIELDS, 5*Δa + 7*Δb end @@ -101,7 +101,7 @@ using Zygote, Test, ChainRules mimo(a, b) = (5a + 7b, 100a, 10b) function ChainRules.rrule(::typeof(mimo), a, b) mimo_rrule_hitcount[] += 1 - function mimo_pullback(Δx, Δy, Δz) + function mimo_pullback((Δx, Δy, Δz)) mimo_pullback_hitcount[] += 1 return ChainRules.NO_FIELDS, 5Δx + 100Δy , 7Δx + 10Δz end @@ -126,9 +126,7 @@ using Zygote, Test, ChainRules @testset "nested AD hitting identity(::Tuple) pullback" begin # This is is a particularly fiddly case. - # the adjoint of `tuple` is `identity` - # and `identity(::Tuple)`s pullback has multiple inputs - # (since the primal had multiple outputs) + # Its kind of a simplified version of `sin'''(0.5)` but different in some places. 
f(x) = tuple(x, 2x, 3x) @@ -177,8 +175,8 @@ using Zygote, Test, ChainRules end end -# ChainRules doesn't have support for FastMath yet, so this fails -# https://github.com/JuliaDiff/ChainRules.jl/issues/174 -@test_broken gradient(2.0) do x - @fastmath x^2.0 -end == (4.0,) +@testset "FastMath support" begin + @test gradient(2.0) do x + @fastmath x^2.0 + end == (4.0,) +end From e7ecdd64abe0ef4526952c90182a9834da7549cd Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 28 May 2020 11:46:28 +0100 Subject: [PATCH 34/35] Follow up after rebase --- src/compiler/interface2.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index 65a07ac8d..26f7f3370 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -5,7 +5,7 @@ ignore_sig(T) = all(T -> T <: Type, T.parameters) @generated function _pullback(ctx::AContext, f, args...) T = Tuple{f,args...} - ignore(T) && return :(f(args...), Pullback{$T}(())) + ignore_sig(T) && return :(f(args...), Pullback{$T}(())) iskw = is_kwfunc(f, args...) 
# if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function @@ -19,6 +19,7 @@ ignore_sig(T) = all(T -> T <: Type, T.parameters) meta, forw, _ = g argnames!(meta, Symbol("#self#"), :ctx, :f, :args) forw = varargs!(meta, forw, 3) + # IRTools.verify(forw) forw = slots!(pis!(inlineable!(forw))) @static if VERSION >= v"1.3" # no edges pre-1.3 append!(meta.code.edges, cr_edges) # be ready to swap to using chainrule if one is declared @@ -27,7 +28,7 @@ ignore_sig(T) = all(T -> T <: Type, T.parameters) end @generated function (j::Pullback{T})(Δ) where T - ignore(T) && return :nothing + ignore_sig(T) && return :nothing g = try _lookup_grad(T) catch e rethrow(CompileError(T,e)) @@ -38,6 +39,7 @@ end end meta, _, back = g argnames!(meta, Symbol("#self#"), :Δ) + # IRTools.verify(back) back = slots!(inlineable!(back)) return update!(meta.code, back) end From 41f4c1717a41844dc744e71d52f5f4a70e85468b Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 28 May 2020 12:16:02 +0100 Subject: [PATCH 35/35] mark as broken the test that fails on 1.0 --- test/chainrules.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/chainrules.jl b/test/chainrules.jl index 5fceb863c..e87d161ab 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -144,10 +144,18 @@ using Zygote, Test, ChainRules @test (1,) == h(1) - a3, pb3 = Zygote.pullback(h, 1) - @test ((1,),) == pb3(1) - end + if VERSION > v"1" + a3, pb3 = Zygote.pullback(h, 1) + @test ((1,),) == pb3(1) + else + # broken on Julia 1.0 because of https://github.com/FluxML/Zygote.jl/issues/638 + @test_broken begin + a3, pb3 = Zygote.pullback(h, 1); # line that errors + ((1,),) == pb3(1) # line actually being tested + end + end + end @testset "kwargs" begin kwfoo_rrule_hitcount = Ref(0)