Sort out frule API (#129)
* Change frule implementation

* Add API regression tests

* Core._apply

* Remove unnecessary tail

* Fix inference issue

* Some tweaks

* Update test/rules.jl

Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>

* tweak style

* Tweak docs further

* Update Project.toml

* Require tests pass on 1.3

* Update .travis.yml

Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>

* Style

Co-authored-by: Lyndon White <oxinabox@ucc.asn.au>
willtebbutt and oxinabox authored Feb 21, 2020
1 parent 30f9852 commit 2ef3f20
Showing 6 changed files with 100 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -7,10 +7,10 @@ julia:
- 1.1
- 1.2
- 1.3
- 1.4
- nightly
jobs:
allow_failures:
- julia: 1.3
- julia: nightly
include:
- stage: "Documentation"
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.6.1"
version = "0.7.0"

[deps]
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
20 changes: 10 additions & 10 deletions docs/src/index.md
@@ -59,7 +59,7 @@ Almost always the _pullback_ will be declared locally within the `rrule`, and wi

The `frule` is written:
```julia
function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...)
function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...)
...
return y, ∂Y
end
@@ -175,15 +175,15 @@ end
```
But because it is fused into frule we see it as part of:
```julia
function frule(::typeof(foo), args..., Δself, Δargs...; kwargs...)
function frule((Δself, Δargs...), ::typeof(foo), args...; kwargs...)
...
return y, ∂y
end
```


The input to the pushforward is often called the _perturbation_.
If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule(f, x, ṡelf, ẋ))`.
If the function is `y = f(x)` often the pushforward will be written `ẏ = last(frule((ṡelf, ẋ), f, x))`.
`ẏ` is commonly used to represent the perturbation for `y`.

!!! note
@@ -238,14 +238,14 @@ If we would like to know the the directional derivative of `f` for an input chan

```julia
direction = (1.5, 0.4, -1) # (ȧ, ḃ, ċ)
y, ẏ = frule(f, a, b, c, Zero(), direction)
y, ẏ = frule((Zero(), direction...), f, a, b, c)
```

On the basis directions one gets the partial derivatives of `y`:
```julia
y, ∂y_∂a = frule(f, a, b, c, Zero(), 1, 0, 0)
y, ∂y_∂b = frule(f, a, b, c, Zero(), 0, 1, 0)
y, ∂y_∂c = frule(f, a, b, c, Zero(), 0, 0, 1)
y, ∂y_∂a = frule((Zero(), 1, 0, 0), f, a, b, c)
y, ∂y_∂b = frule((Zero(), 0, 1, 0), f, a, b, c)
y, ∂y_∂c = frule((Zero(), 0, 0, 1), f, a, b, c)
```
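The relationship between the directional derivative and the basis-direction calls can be checked without the package. This is a minimal sketch under stated assumptions: `g` is a hypothetical example function (the `f` used in the docs is not shown in this hunk), `toy_frule` is a local stand-in that only mimics the tuple-first argument order, and `nothing` takes the place of `Zero()`.

```julia
# Hypothetical example function, chosen only for illustration.
g(a, b, c) = a * b + c

# Hand-written forward rule in the tuple-first style.
# This is a local stand-in, not a ChainRulesCore method.
function toy_frule((_, ȧ, ḃ, ċ), ::typeof(g), a, b, c)
    y = g(a, b, c)
    ẏ = b * ȧ + a * ḃ + ċ   # total differential of a*b + c
    return y, ẏ
end

a, b, c = 2.0, 3.0, 5.0
direction = (1.5, 0.4, -1.0)  # (ȧ, ḃ, ċ)

# Directional derivative along `direction`.
y, ẏ = toy_frule((nothing, direction...), g, a, b, c)

# Partial derivatives from the three basis directions.
_, ∂y_∂a = toy_frule((nothing, 1, 0, 0), g, a, b, c)
_, ∂y_∂b = toy_frule((nothing, 0, 1, 0), g, a, b, c)
_, ∂y_∂c = toy_frule((nothing, 0, 0, 1), g, a, b, c)

# The directional derivative is the direction-weighted sum of the partials.
println(ẏ ≈ sum((∂y_∂a, ∂y_∂b, ∂y_∂c) .* direction))
```

Because the pushforward is linear in the perturbation, one call with the full direction gives the same result as combining three basis-direction calls.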

Similarly, the most trivial use of `rrule` and returned `pullback` is to calculate the [Gradient](https://en.wikipedia.org/wiki/Gradient):
@@ -320,10 +320,10 @@ x = 3;
ẋ = 1; # ∂x/∂x
nofields = Zero(); # ∂self/∂self
a, ȧ = frule(sin, x, nofields, ẋ); # ∂a/∂x
b, ḃ = frule(*, 2, a, nofields, Zero(), unthunk(ȧ)); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
a, ȧ = frule((nofields, ẋ), sin, x); # ∂a/∂x
b, ḃ = frule((nofields, Zero(), unthunk(ȧ)), *, 2, a); # ∂b/∂x = ∂b/∂a⋅∂a/∂x
c, ċ = frule(asin, b, nofields, unthunk(ḃ)); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
c, ċ = frule((nofields, unthunk(ḃ)), asin, b); # ∂c/∂x = ∂c/∂b⋅∂b/∂x = ∂f/∂x
unthunk(ċ)
# output
-2.0638950738662625
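The chained calls in this hunk can be reproduced without the package. The following is a minimal sketch, assuming local stand-ins: `toy_frule` is a hypothetical function written here only to mimic the tuple-first argument order, plain numbers replace thunks (so no `unthunk` is needed), and `nothing` stands in for `Zero()`. It is not ChainRulesCore's implementation.

```julia
# Hypothetical stand-ins that mimic the new argument order:
# differentials first (as one tuple), then the function, then its inputs.
toy_frule((_, Δx), ::typeof(sin), x) = (sin(x), cos(x) * Δx)
toy_frule((_, Δx), ::typeof(asin), x) = (asin(x), Δx / sqrt(1 - x^2))
toy_frule((_, Δa, Δb), ::typeof(*), a, b) = (a * b, Δa * b + a * Δb)

x = 3
ẋ = 1            # seed: ∂x/∂x
Δself = nothing  # placeholder for the differential of the function itself

a, ȧ = toy_frule((Δself, ẋ), sin, x)        # a = sin(x)
b, ḃ = toy_frule((Δself, 0, ȧ), *, 2, a)    # b = 2a; the constant 2 has zero differential
c, ċ = toy_frule((Δself, ḃ), asin, b)       # c = asin(b)

println(ċ)  # derivative of asin(2sin(x)) at x = 3
```

Each call takes the differential produced by the previous call and places it in the leading tuple, so the differentials never interleave with the primal arguments.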
4 changes: 2 additions & 2 deletions src/rule_definition_tools.jl
@@ -148,7 +148,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)

# Δs is the input to the propagator rule
# because this is push-forward there is one per input to the function
Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs]
Δs = [esc(Symbol(:Δ, i)) for i in 1:n_inputs]
pushforward_returns = map(1:n_outputs) do output_i
∂s = partials[output_i].args
propagation_expr(Δs, ∂s)
@@ -163,7 +163,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
return quote
# _ is the input derivative w.r.t. function internals. since we do not
# allow closures/functors with @scalar_rule, it is always ignored
function ChainRulesCore.frule(::typeof($f), $(inputs...), _, $(Δs...))
function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...))
$(esc(:Ω)) = $call
$(setup_stmts...)
return $(esc(:Ω)), $pushforward_returns
28 changes: 11 additions & 17 deletions src/rules.jl
@@ -2,54 +2,48 @@
##### `frule`/`rrule`
#####

# TODO: remember to update the examples
"""
frule(f, x..., ṡelf, Δx...)
frule((Δf, Δx...), f, x...)
Expressing `x` as the tuple `(x₁, x₂, ...)`, `Δx` as the tuple `(Δx₁, Δx₂,
...)`, and the output tuple of `f(x...)` as `Ω`, return the tuple:
Expressing the output of `f(x...)` as `Ω`, return the tuple:
(Ω, (Ω̇₁, Ω̇₂, ...))
(Ω, ΔΩ)
The second return value is the propagation rule, or the pushforward.
It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`)
and `ṡelf` the internal values of the function (for closures).
The second return value is the differential w.r.t. the output.
If no method matching `frule(f, x..., ṡelf, Δx...)` has been defined, then
return `nothing`.
If no method matching `frule((Δf, Δx...), f, x...)` has been defined, then return `nothing`.
Examples:
unary input, unary output scalar function:
```
```jldoctest
julia> dself = Zero()
Zero()
julia> x = rand();
julia> sinx, sin_pushforward = frule(sin, x, dself, 1)
julia> sinx, Δsinx = frule(sin, x, dself, 1)
(0.35696518021277485, 0.9341176907197836)
julia> sinx == sin(x)
true
julia> sin_pushforward == cos(x)
julia> Δsinx == cos(x)
true
```
unary input, binary output scalar function:
```
```jldoctest
julia> x = rand();
julia> sincosx, sincos_pushforward = frule(sincos, x, dself, 1);
julia> sincosx, Δsincosx = frule(sincos, x, dself, 1);
julia> sincosx == sincos(x)
true
julia> sincos_pushforward == (cos(x), -sin(x))
julia> Δsincosx == (cos(x), -sin(x))
true
```
83 changes: 75 additions & 8 deletions test/rules.jl
@@ -15,14 +15,56 @@ nice(x) = 1
very_nice(x, y) = x + y
@scalar_rule(very_nice(x, y), (One(), One()))


# Tests that aim to ensure that the API for frules doesn't regress and make these things
# hard to implement.

varargs_function(x...) = sum(x)
function ChainRulesCore.frule(dargs, ::typeof(varargs_function), x...)
Δx = Base.tail(dargs)
return sum(x), sum(Δx)
end

mixed_vararg(x, y, z...) = x + y + sum(z)
function ChainRulesCore.frule(
dargs::Tuple{Any, Any, Any, Vararg},
::typeof(mixed_vararg), x, y, z...,
)
Δx = dargs[2]
Δy = dargs[3]
Δz = dargs[4:end]
return mixed_vararg(x, y, z...), Δx + Δy + sum(Δz)
end

type_constraints(x::Int, y::Float64) = x + y
function ChainRulesCore.frule(
(_, Δx, Δy)::Tuple{Any, Int, Float64},
::typeof(type_constraints), x::Int, y::Float64,
)
return type_constraints(x, y), Δx + Δy
end

mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64}) = x + y + sum(z)
function ChainRulesCore.frule(
dargs::Tuple{Any, Float64, Real, Vararg{Float64}},
::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64},
)
Δx = dargs[2]
Δy = dargs[3]
Δz = dargs[4:end]
return x + y + sum(z), Δx + Δy + sum(Δz)
end

ChainRulesCore.frule(dargs, ::typeof(Core._apply), f, x...) = frule(dargs[2:end], f, x...)

#######

_second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))

@testset "frule and rrule" begin
dself = Zero()
@test frule(cool, 1, dself, 1) === nothing
@test frule(cool, 1, dself, 1; iscool=true) === nothing
@test frule((dself, 1), cool, 1) === nothing
@test frule((dself, 1), cool, 1; iscool=true) === nothing
@test rrule(cool, 1) === nothing
@test rrule(cool, 1; iscool=true) === nothing

@@ -37,7 +79,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
Tuple{typeof(rrule),typeof(cool),String}])
@test cool_methods == only_methods

frx, cool_pushforward = frule(cool, 1, dself, 1)
frx, cool_pushforward = frule((dself, 1), cool, 1)
@test frx === 2
@test cool_pushforward === 1
rrx, cool_pullback = rrule(cool, 1)
@@ -46,13 +88,38 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@test rrx === 2
@test rr1 === 1

frx, nice_pushforward = frule(nice, 1, dself, 1)
frx, nice_pushforward = frule((dself, 1), nice, 1)
@test nice_pushforward === Zero()
rrx, nice_pullback = rrule(nice, 1)
@test (NO_FIELDS, Zero()) === nice_pullback(1)

sx = @SVector [1, 2]
sy = @SVector [3, 4]
# This is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting
@inferred frule(very_nice, 1, 2, Zero(), sx, sy)

# Test that these run. Do not care about numerical correctness.
@test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0)

@test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0)

@test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0)
@test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing

@test(frule(
(nothing, 3.0, 2.0, 1.0, 0.0),
mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0,
) == (6.0, 6.0))

# violates type constraints, thus an frule should not be found.
@test frule(
(nothing, 3, 2.0, 1.0, 5.0),
mixed_vararg_type_constaint, 3, 2.0, 1.0, 0,
) == nothing

@test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0)

@testset "broadcasting One" begin
sx = @SVector [1, 2]
sy = @SVector [3, 4]

# Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting
@inferred frule((Zero(), sx, sy), very_nice, 1, 2)
end
end
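The regression tests above suggest one reason for moving the differentials into a leading tuple: Julia accepts a splatted argument only in the final position, so a method of the old shape `frule(f, x..., Δself, Δx...)` cannot be written for a varargs function. The following is a minimal sketch of that dispatch pattern, assuming local stand-ins: `toy_frule`, `total`, and `call_with` are hypothetical names, not package API.

```julia
# Catch-all fallback: no rule defined, so report `nothing`.
toy_frule(dargs, f, x...) = nothing

# Rule for a varargs function: all differentials arrive in one tuple,
# so the trailing `x...` can absorb any number of inputs.
total(x...) = sum(x)
function toy_frule(dargs::Tuple, ::typeof(total), x::Float64...)
    Δx = Base.tail(dargs)  # drop the differential of `total` itself
    return total(x...), sum(Δx)
end

# Rule for a higher-order helper: strip the helper's own differential
# and forward the rest to the rule of the wrapped function.
call_with(f, x...) = f(x...)
toy_frule(dargs::Tuple, ::typeof(call_with), f, x...) = toy_frule(Base.tail(dargs), f, x...)

println(toy_frule((nothing, 1.0, 1.0, 1.0), total, 0.5, 0.5, 0.5))
println(toy_frule((nothing, 1.0), total, 1))  # Int input: no rule matches
println(toy_frule((nothing, nothing, 2.0, 3.0), call_with, total, 1.0, 4.0))
```

Type constraints on the primal arguments select the rule, and any call that violates them falls through to the catch-all method, which matches the `=== nothing` checks in the tests.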

2 comments on commit 2ef3f20

@willtebbutt
Member Author


@JuliaRegistrator register()

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/9875

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.0 -m "<description of version>" 2ef3f20b1e506b608ea0f97f4908747015441377
git push origin v0.7.0
