From 75f4c310ffeb4186e47453fac074ae99c529cc5d Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sun, 16 Feb 2020 11:26:21 +0000 Subject: [PATCH 01/13] Change frule implementation --- src/rule_definition_tools.jl | 9 ++++++++- src/rules.jl | 28 +++++++++++----------------- test/rules.jl | 11 ++++++----- 3 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 29046ecd7..2da9a4d4b 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,6 +1,13 @@ # These are some macros (and supporting functions) to make it easier to define rules. using MuladdMacro: @muladd +""" + tail(x::Tuple) + +Produce a new `Tuple` containing all but the first element of `x`. +""" +tail(x::Tuple) = x[2:end] + """ @scalar_rule(f(x₁, x₂, ...), @setup(statement₁, statement₂, ...), @@ -163,7 +170,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials) return quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored - function ChainRulesCore.frule(::typeof($f), $(inputs...), _, $(Δs...)) + function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...)) $(esc(:Ω)) = $call $(setup_stmts...) return $(esc(:Ω)), $pushforward_returns diff --git a/src/rules.jl b/src/rules.jl index 0f0c84e69..0fb73ece2 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -2,54 +2,48 @@ ##### `frule`/`rrule` ##### -# TODO: remember to update the examples """ - frule(f, x..., ṡelf, Δx...) + frule((Δf, Δx...), f, x...) -Expressing `x` as the tuple `(x₁, x₂, ...)`, `Δx` as the tuple `(Δx₁, Δx₂, -...)`, and the output tuple of `f(x...)` as `Ω`, return the tuple: +Expressing the output of `f(x...)` as `Ω`, return the tuple: - (Ω, (Ω̇₁, Ω̇₂, ...)) + (Ω, ΔΩ) -The second return value is the propagation rule, or the pushforward. 
-It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`) -and `ṡelf` the internal values of the function (for closures). +The second return value is the differential w.r.t. the output. - -If no method matching `frule(f, x..., ṡelf, Δx...)` has been defined, then -return `nothing`. +If no method matching `frule((Δf, Δx...), f, x...)` has been defined, then return `nothing`. Examples: unary input, unary output scalar function: -``` +```jldoctest julia> dself = Zero() Zero() julia> x = rand(); -julia> sinx, sin_pushforward = frule(sin, x, dself, 1) +julia> sinx, Δsinx = frule((dself, 1), sin, x) (0.35696518021277485, 0.9341176907197836) julia> sinx == sin(x) true -julia> sin_pushforward == cos(x) +julia> Δsinx == cos(x) true ``` unary input, binary output scalar function: -``` +```jldoctest julia> x = rand(); -julia> sincosx, sincos_pushforward = frule(sincos, x, dself, 1); +julia> sincosx, Δsincosx = frule((dself, 1), sincos, x); julia> sincosx == sincos(x) true -julia> sincos_pushforward == (cos(x), -sin(x)) +julia> Δsincosx == (cos(x), -sin(x)) true ``` diff --git a/test/rules.jl b/test/rules.jl index bd502d9f5..2910c9015 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -21,8 +21,8 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @testset "frule and rrule" begin dself = Zero() - @test frule(cool, 1, dself, 1) === nothing - @test frule(cool, 1, dself, 1; iscool=true) === nothing + @test frule((dself, 1), cool, 1) === nothing + @test frule((dself, 1), cool, 1; iscool=true) === nothing @test rrule(cool, 1) === nothing @test rrule(cool, 1; iscool=true) === nothing @@ -37,7 +37,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) Tuple{typeof(rrule),typeof(cool),String}]) @test cool_methods == only_methods - frx, cool_pushforward = frule(cool, 1, dself, 1) + frx, cool_pushforward = frule((dself, 1), cool, 1) @test frx === 2 @test cool_pushforward === 1 rrx, cool_pullback = rrule(cool, 1) @@ -46,13 +46,14 @@ _second(t) = 
Base.tuple_type_head(Base.tuple_type_tail(t)) @test rrx === 2 @test rr1 === 1 - frx, nice_pushforward = frule(nice, 1, dself, 1) + frx, nice_pushforward = frule((dself, 1), nice, 1) @test nice_pushforward === Zero() rrx, nice_pullback = rrule(nice, 1) @test (NO_FIELDS, Zero()) === nice_pullback(1) sx = @SVector [1, 2] sy = @SVector [3, 4] + # This is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting - @inferred frule(very_nice, 1, 2, Zero(), sx, sy) + @inferred frule((Zero(), sx, sy), very_nice, 1, 2) end From d09d5f89d59c98a9670d7d1453d2b244c7b269c0 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sun, 16 Feb 2020 11:54:32 +0000 Subject: [PATCH 02/13] Add API regression tests --- test/rules.jl | 72 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 68 insertions(+), 4 deletions(-) diff --git a/test/rules.jl b/test/rules.jl index 2910c9015..87f4f5834 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -1,6 +1,7 @@ ####### # Demo setup using StaticArrays: @SVector +import ChainRulesCore: frule, tail cool(x) = x + 1 cool(x, y) = x + y + 1 @@ -15,6 +16,48 @@ nice(x) = 1 very_nice(x, y) = x + y @scalar_rule(very_nice(x, y), (One(), One())) + +# +# Tests that aim to ensure that the API for frules doesn't regress and make these things +# hard to implement. +# + +varargs_function(x...) = sum(x) +function frule(dargs, ::typeof(varargs_function), x...) + Δx = tail(dargs) + return sum(x), sum(Δx) +end + +mixed_vararg(x, y, z...) 
= x + y + sum(z) +function frule( + dargs::Tuple{Any, Any, Any, Vararg}, + ::typeof(mixed_vararg), x, y, z..., +) + Δx = dargs[2] + Δy = dargs[3] + Δz = dargs[4:end] + return mixed_vararg(x, y, z...), Δx + Δy + sum(Δz) +end + +type_constraints(x::Int, y::Float64) = x + y +function frule( + (_, Δx, Δy)::Tuple{Any, Int, Float64}, + ::typeof(type_constraints), x::Int, y::Float64, +) + return type_constraints(x, y), Δx + Δy +end + +mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) +function frule( + dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, + ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, +) + Δx = dargs[2] + Δy = dargs[3] + Δz = dargs[4:end] + return x + y + sum(z), Δx + Δy + sum(Δz) +end + ####### _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @@ -51,9 +94,30 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) rrx, nice_pullback = rrule(nice, 1) @test (NO_FIELDS, Zero()) === nice_pullback(1) - sx = @SVector [1, 2] - sy = @SVector [3, 4] - # This is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting - @inferred frule((Zero(), sx, sy), very_nice, 1, 2) + # Test that these run. Do not care about numerical correctness. 
+ @test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0) + + @test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0) + + @test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0) + @test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing + + @test(frule( + (nothing, 3.0, 2.0, 1.0, 0.0), + mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0, + ) == (6.0, 6.0)) + + @test frule( + (nothing, 3, 2.0, 1.0, 5.0), # violates type constraints + mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, + ) == nothing + + @testset "broadcasting One" begin + sx = @SVector [1, 2] + sy = @SVector [3, 4] + + # Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting + @inferred frule((Zero(), sx, sy), very_nice, 1, 2) + end end From 785fdf5e562511b0090eba4d552c178516db013e Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sun, 16 Feb 2020 12:02:43 +0000 Subject: [PATCH 03/13] Core._apply --- test/rules.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/rules.jl b/test/rules.jl index 87f4f5834..481d103c6 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -58,6 +58,8 @@ function frule( return x + y + sum(z), Δx + Δy + sum(Δz) end +frule(dargs, ::typeof(Core._apply), f, x...) = frule(dargs[2:end], f, x...) 
+ ####### _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @@ -113,6 +115,8 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, ) == nothing + @test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0) + @testset "broadcasting One" begin sx = @SVector [1, 2] sy = @SVector [3, 4] From 06358fc1282bd15435200aa721c8df38542795ad Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Sun, 16 Feb 2020 15:16:00 +0000 Subject: [PATCH 04/13] Remove unnecessary tail --- src/rule_definition_tools.jl | 7 ------- test/rules.jl | 4 ++-- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 2da9a4d4b..e414a5aed 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,13 +1,6 @@ # These are some macros (and supporting functions) to make it easier to define rules. using MuladdMacro: @muladd -""" - tail(x::Tuple) - -Produce a new `Tuple` containing all but the first element of `x`. -""" -tail(x::Tuple) = x[2:end] - """ @scalar_rule(f(x₁, x₂, ...), @setup(statement₁, statement₂, ...), diff --git a/test/rules.jl b/test/rules.jl index 481d103c6..f93b08295 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -1,7 +1,7 @@ ####### # Demo setup using StaticArrays: @SVector -import ChainRulesCore: frule, tail +import ChainRulesCore: frule cool(x) = x + 1 cool(x, y) = x + y + 1 @@ -24,7 +24,7 @@ very_nice(x, y) = x + y varargs_function(x...) = sum(x) function frule(dargs, ::typeof(varargs_function), x...) 
- Δx = tail(dargs) + Δx = Base.tail(dargs) return sum(x), sum(Δx) end From 7ed97637259f3977be28ef9e11a6f782f789e84b Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Mon, 17 Feb 2020 17:59:03 +0000 Subject: [PATCH 05/13] Fix inference issue --- src/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index e414a5aed..4edc8c88f 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -148,7 +148,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials) # Δs is the input to the propagator rule # because this is push-forward there is one per input to the function - Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs] + Δs = [esc(Symbol(:Δ, i)) for i in 1:n_inputs] pushforward_returns = map(1:n_outputs) do output_i ∂s = partials[output_i].args propagation_expr(Δs, ∂s) From f0e706fa6dd535602b16202ab2aac3e2d17158ae Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Mon, 17 Feb 2020 18:35:23 +0000 Subject: [PATCH 06/13] Some tweaks --- Project.toml | 2 +- docs/src/index.md | 2 +- test/rules.jl | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index a8d04495e..2f5deb032 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.6.1" +version = "0.6.1-DEV" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" diff --git a/docs/src/index.md b/docs/src/index.md index 9c0b70b9f..795fd0d5a 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -59,7 +59,7 @@ Almost always the _pullback_ will be declared locally within the `rrule`, and wi The `frule` is written: ```julia -function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...) +function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) ... 
return y, ∂Y end diff --git a/test/rules.jl b/test/rules.jl index f93b08295..32be45d23 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -110,8 +110,9 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0, ) == (6.0, 6.0)) + # violates type constraints, thus an frule should not be found. @test frule( - (nothing, 3, 2.0, 1.0, 5.0), # violates type constraints + (nothing, 3, 2.0, 1.0, 5.0), mixed_vararg_type_constaint, 3, 2.0, 1.0, 0, ) == nothing From d3655b5e79e0c208cd9c9fbb6373e3d7b0286e85 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Mon, 17 Feb 2020 18:35:31 +0000 Subject: [PATCH 07/13] Update test/rules.jl Co-Authored-By: Lyndon White --- test/rules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/rules.jl b/test/rules.jl index f93b08295..d473d04a9 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -17,7 +17,6 @@ very_nice(x, y) = x + y @scalar_rule(very_nice(x, y), (One(), One())) -# # Tests that aim to ensure that the API for frules doesn't regress and make these things # hard to implement. # From c69df6391b798112c5a0882fdd2eff0d944bc791 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Mon, 17 Feb 2020 18:36:15 +0000 Subject: [PATCH 08/13] tweak style --- test/rules.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/rules.jl b/test/rules.jl index b074003a0..b04c290db 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -19,7 +19,6 @@ very_nice(x, y) = x + y # Tests that aim to ensure that the API for frules doesn't regress and make these things # hard to implement. -# varargs_function(x...) = sum(x) function frule(dargs, ::typeof(varargs_function), x...) 
From 0908e3f5c3b66e1e60e267b3d425e22e3d3b51b8 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Mon, 17 Feb 2020 18:56:16 +0000 Subject: [PATCH 09/13] Tweak docs further --- docs/src/index.md | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 795fd0d5a..38c5c4cfa 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -175,7 +175,7 @@ end ``` But because it is fused into frule we see it as part of: ```julia -function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...) +function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...) ... return y, ∂y end @@ -183,7 +183,7 @@ end The input to the pushforward is often called the _perturbation_. -If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule(f, x, ṡelf, ẋ))`. +If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule((ṡelf, ẋ), f, x))`. `ẏ` is commonly used to represent the perturbation for `y`. !!! 
note @@ -238,14 +238,14 @@ If we would like to know the the directional derivative of `f` for an input chan ```julia direction = (1.5, 0.4, -1) # (ȧ, ḃ, ċ) -y, ẏ = frule(f, a, b, c, Zero(), direction) +y, ẏ = frule((Zero(), direction...), f, a, b, c) ``` On the basis directions one gets the partial derivatives of `y`: ```julia -y, ∂y_∂a = frule(f, a, b, c, Zero(), 1, 0, 0) -y, ∂y_∂b = frule(f, a, b, c, Zero(), 0, 1, 0) -y, ∂y_∂c = frule(f, a, b, c, Zero(), 0, 0, 1) +y, ∂y_∂a = frule((Zero(), 1, 0, 0), f, a, b, c) +y, ∂y_∂b = frule((Zero(), 0, 1, 0), f, a, b, c) +y, ∂y_∂c = frule((Zero(), 0, 0, 1), f, a, b, c) ``` Similarly, the most trivial use of `rrule` and returned `pullback` is to calculate the [Gradient](https://en.wikipedia.org/wiki/Gradient): @@ -320,10 +320,10 @@ x = 3; ẋ = 1; # ∂x/∂x nofields = Zero(); # ∂self/∂self -a, ȧ = frule(sin, x, nofields, ẋ); # ∂a/∂x -b, ḃ = frule(*, 2, a, nofields, Zero(), unthunk(ȧ)); # ∂b/∂x = ∂b/∂a⋅∂a/∂x +a, ȧ = frule((nofields, ẋ), sin, x); # ∂a/∂x +b, ḃ = frule((nofields, Zero(), unthunk(ȧ)), *, 2, a); # ∂b/∂x = ∂b/∂a⋅∂a/∂x -c, ċ = frule(asin, b, nofields, unthunk(ḃ)); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x +c, ċ = frule((nofields, unthunk(ḃ)), asin, b); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x unthunk(ċ) # output -2.0638950738662625 From 408ef2875d19b98e1b703fe510614008785bc43e Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 20 Feb 2020 18:49:32 +0000 Subject: [PATCH 10/13] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2f5deb032..c032e4fc4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.6.1-DEV" +version = "0.7.0" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" From fd88489761bec3768834a82cf5b316a20b765537 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 20 Feb 2020 18:52:23 +0000 Subject: [PATCH 11/13] Require tests pass on 1.3 --- 
.travis.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 8cc900185..fd0526ccf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -10,7 +10,6 @@ julia: - nightly jobs: allow_failures: - - julia: 1.3 - julia: nightly include: - stage: "Documentation" From 390b3fa216d45f27cefb5156491ccf1166452e50 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 21 Feb 2020 17:09:58 +0000 Subject: [PATCH 12/13] Update .travis.yml Co-Authored-By: Lyndon White --- .travis.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.travis.yml b/.travis.yml index fd0526ccf..3c0a64eb9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ julia: - 1.1 - 1.2 - 1.3 + - 1.4 - nightly jobs: allow_failures: From d6b668175964e28ab61576365e75cd9c01bb4cbd Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 21 Feb 2020 17:21:36 +0000 Subject: [PATCH 13/13] Style --- test/rules.jl | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/rules.jl b/test/rules.jl index b04c290db..a80000376 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -1,7 +1,6 @@ ####### # Demo setup using StaticArrays: @SVector -import ChainRulesCore: frule cool(x) = x + 1 cool(x, y) = x + y + 1 @@ -21,13 +20,13 @@ very_nice(x, y) = x + y # hard to implement. varargs_function(x...) = sum(x) -function frule(dargs, ::typeof(varargs_function), x...) +function ChainRulesCore.frule(dargs, ::typeof(varargs_function), x...) Δx = Base.tail(dargs) return sum(x), sum(Δx) end mixed_vararg(x, y, z...) 
= x + y + sum(z) -function frule( +function ChainRulesCore.frule( dargs::Tuple{Any, Any, Any, Vararg}, ::typeof(mixed_vararg), x, y, z..., ) @@ -38,7 +37,7 @@ function frule( end type_constraints(x::Int, y::Float64) = x + y -function frule( +function ChainRulesCore.frule( (_, Δx, Δy)::Tuple{Any, Int, Float64}, ::typeof(type_constraints), x::Int, y::Float64, ) @@ -46,7 +45,7 @@ function frule( end mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z) -function frule( +function ChainRulesCore.frule( dargs::Tuple{Any, Float64, Real, Vararg{Float64}}, ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64}, ) @@ -56,7 +55,7 @@ function frule( return x + y + sum(z), Δx + Δy + sum(Δz) end -frule(dargs, ::typeof(Core._apply), f, x...) = frule(dargs[2:end], f, x...) +ChainRulesCore.frule(dargs, ::typeof(Core._apply), f, x...) = frule(dargs[2:end], f, x...) #######