Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort out frule API #129

Merged
merged 14 commits into from
Feb 21, 2020
4 changes: 2 additions & 2 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)

# Δs is the input to the propagator rule
# because this is push-forward there is one per input to the function
Δs = [Symbol(string(:Δ, i)) for i in 1:n_inputs]
Δs = [esc(Symbol(:Δ, i)) for i in 1:n_inputs]
pushforward_returns = map(1:n_outputs) do output_i
∂s = partials[output_i].args
propagation_expr(Δs, ∂s)
Expand All @@ -163,7 +163,7 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
return quote
# _ is the input derivative w.r.t. function internals. since we do not
# allow closures/functors with @scalar_rule, it is always ignored
function ChainRulesCore.frule(::typeof($f), $(inputs...), _, $(Δs...))
function ChainRulesCore.frule((_, $(Δs...)), ::typeof($f), $(inputs...))
$(esc(:Ω)) = $call
$(setup_stmts...)
return $(esc(:Ω)), $pushforward_returns
Expand Down
28 changes: 11 additions & 17 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,54 +2,48 @@
##### `frule`/`rrule`
#####

# TODO: remember to update the examples
"""
frule(f, x..., ṡelf, Δx...)
frule((Δf, Δx...), f, x...)

Expressing `x` as the tuple `(x₁, x₂, ...)`, `Δx` as the tuple `(Δx₁, Δx₂,
...)`, and the output tuple of `f(x...)` as `Ω`, return the tuple:
Expressing the output of `f(x...)` as `Ω`, return the tuple:

(Ω, (Ω̇₁, Ω̇₂, ...))
(Ω, ΔΩ)

The second return value is the propagation rule, or the pushforward.
It takes in differentials corresponding to the inputs (`ẋ₁, ẋ₂, ...`)
and `ṡelf` the internal values of the function (for closures).
The second return value is the differential w.r.t. the output.


If no method matching `frule(f, x..., ṡelf, Δx...)` has been defined, then
return `nothing`.
If no method matching `frule((Δf, Δx...), f, x...)` has been defined, then return `nothing`.

Examples:

unary input, unary output scalar function:

```
```jldoctest
julia> dself = Zero()
Zero()

julia> x = rand();

julia> sinx, sin_pushforward = frule(sin, x, dself, 1)
julia> sinx, Δsinx = frule((dself, 1), sin, x)
(0.35696518021277485, 0.9341176907197836)

julia> sinx == sin(x)
true

julia> sin_pushforward == cos(x)
julia> Δsinx == cos(x)
true
```

unary input, binary output scalar function:

```
```jldoctest
julia> x = rand();

julia> sincosx, sincos_pushforward = frule(sincos, x, dself, 1);
julia> sincosx, Δsincosx = frule((dself, 1), sincos, x);

julia> sincosx == sincos(x)
true

julia> sincos_pushforward == (cos(x), -sin(x))
julia> Δsincosx == (cos(x), -sin(x))
true
```

Expand Down
85 changes: 77 additions & 8 deletions test/rules.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#######
# Demo setup
using StaticArrays: @SVector
import ChainRulesCore: frule
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved

# Demo function: one-argument method adds one to its input.
function cool(x)
    return x + 1
end

# Demo function: two-argument method adds one to the sum of its inputs.
function cool(x, y)
    return x + y + 1
end
Expand All @@ -15,14 +16,58 @@ nice(x) = 1
# Demo function: plain sum of its two arguments.
function very_nice(x, y)
    return x + y
end
# Declare rules for `very_nice` with `@scalar_rule`; `One()` is supplied as the partial
# for each of the two arguments.
@scalar_rule(very_nice(x, y), (One(), One()))


#
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
# Tests that aim to ensure that the API for frules doesn't regress and make these things
# hard to implement.
#
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved

# Sum of an arbitrary number of positional arguments.
function varargs_function(x...)
    return sum(x)
end

# Forward rule for `varargs_function`, written against the
# `frule((Δf, Δx...), f, x...)` calling convention.
function frule(dargs, ::typeof(varargs_function), x...)
    # The leading entry of `dargs` belongs to the function itself; drop it so
    # that only the per-argument differentials remain.
    input_differentials = Base.tail(dargs)
    primal = sum(x)
    return primal, sum(input_differentials)
end

# Two required positional arguments followed by any number of extras; returns
# the sum of all of them.
function mixed_vararg(x, y, z...)
    return x + y + sum(z)
end

# Forward rule for `mixed_vararg`. `dargs` must carry at least three entries:
# one for the function itself, then one per required argument, then one for
# each trailing argument in `z`.
function frule(
    dargs::Tuple{Any, Any, Any, Vararg},
    ::typeof(mixed_vararg), x, y, z...,
)
    primal = mixed_vararg(x, y, z...)
    # Entry 1 is the function's own differential and is not used here.
    Δrequired = dargs[2] + dargs[3]
    Δtrailing = sum(dargs[4:end])
    return primal, Δrequired + Δtrailing
end

# Demo function whose arguments are restricted to concrete types.
function type_constraints(x::Int, y::Float64)
    return x + y
end

# Forward rule for `type_constraints`. The differential tuple is destructured
# directly in the signature and is constrained to match the argument types, so
# mismatched differentials do not dispatch to this method.
function frule(
    (_, Δx, Δy)::Tuple{Any, Int, Float64},
    ::typeof(type_constraints), x::Int, y::Float64,
)
    primal = type_constraints(x, y)
    tangent = Δx + Δy
    return primal, tangent
end

# Combines type-constrained required arguments with a type-constrained
# variadic tail; returns the sum of everything passed in.
function mixed_vararg_type_constaint(x::Float64, y::Real, z::Vararg{Float64})
    return x + y + sum(z)
end

# Forward rule for `mixed_vararg_type_constaint`. The differential tuple is
# constrained in the same way as the primal arguments (after a leading entry
# for the function itself).
function frule(
    dargs::Tuple{Any, Float64, Real, Vararg{Float64}},
    ::typeof(mixed_vararg_type_constaint), x::Float64, y::Real, z::Vararg{Float64},
)
    primal = x + y + sum(z)
    # Entry 1 is the function's own differential and is not used here.
    Δrequired = dargs[2] + dargs[3]
    Δtrailing = sum(dargs[4:end])
    return primal, Δrequired + Δtrailing
end

# Forward rule for `Core._apply`: discard the leading entry of `dargs` and
# delegate to the `frule` of the function being applied.
function frule(dargs, ::typeof(Core._apply), f, x...)
    inner_dargs = dargs[2:end]
    return frule(inner_dargs, f, x...)
end

#######

# Return the second type parameter of the tuple type `t`.
function _second(t)
    after_first = Base.tuple_type_tail(t)
    return Base.tuple_type_head(after_first)
end

@testset "frule and rrule" begin
dself = Zero()
@test frule(cool, 1, dself, 1) === nothing
@test frule(cool, 1, dself, 1; iscool=true) === nothing
@test frule((dself, 1), cool, 1) === nothing
@test frule((dself, 1), cool, 1; iscool=true) === nothing
@test rrule(cool, 1) === nothing
@test rrule(cool, 1; iscool=true) === nothing

Expand All @@ -37,7 +82,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
Tuple{typeof(rrule),typeof(cool),String}])
@test cool_methods == only_methods

frx, cool_pushforward = frule(cool, 1, dself, 1)
frx, cool_pushforward = frule((dself, 1), cool, 1)
@test frx === 2
@test cool_pushforward === 1
rrx, cool_pullback = rrule(cool, 1)
Expand All @@ -46,13 +91,37 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))
@test rrx === 2
@test rr1 === 1

frx, nice_pushforward = frule(nice, 1, dself, 1)
frx, nice_pushforward = frule((dself, 1), nice, 1)
@test nice_pushforward === Zero()
rrx, nice_pullback = rrule(nice, 1)
@test (NO_FIELDS, Zero()) === nice_pullback(1)

sx = @SVector [1, 2]
sy = @SVector [3, 4]
# This is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting
@inferred frule(very_nice, 1, 2, Zero(), sx, sy)

# Test that these run. Do not care about numerical correctness.
@test frule((nothing, 1.0, 1.0, 1.0), varargs_function, 0.5, 0.5, 0.5) == (1.5, 3.0)

@test frule((nothing, 1.0, 2.0, 3.0, 4.0), mixed_vararg, 1.0, 2.0, 3.0, 4.0) == (10.0, 10.0)

@test frule((nothing, 3, 2.0), type_constraints, 5, 4.0) == (9.0, 5.0)
@test frule((nothing, 3.0, 2.0im), type_constraints, 5, 4.0) == nothing

@test(frule(
(nothing, 3.0, 2.0, 1.0, 0.0),
mixed_vararg_type_constaint, 3.0, 2.0, 1.0, 0.0,
) == (6.0, 6.0))

@test frule(
(nothing, 3, 2.0, 1.0, 5.0), # violates type constraints
willtebbutt marked this conversation as resolved.
Show resolved Hide resolved
mixed_vararg_type_constaint, 3, 2.0, 1.0, 0,
) == nothing

@test frule((nothing, nothing, 5.0), Core._apply, dummy_identity, 4.0) == (4.0, 5.0)

@testset "broadcasting One" begin
sx = @SVector [1, 2]
sy = @SVector [3, 4]

# Test that @scalar_rule and `One()` play nice together, w.r.t broadcasting
@inferred frule((Zero(), sx, sy), very_nice, 1, 2)
end
end