Skip to content

Commit

Permalink
Make frule return a closure which returns the derivative.
Browse files Browse the repository at this point in the history
Co-authored-by: "Yingbo Ma" <mayingbo5@gmail.com>
Co-authored-by: "Shashi Gowda" <gowda@mit.edu>
  • Loading branch information
3 people committed Jan 16, 2020
1 parent 6454db2 commit 2a468ca
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.5.4"
version = "0.6.0"

[deps]
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
Expand Down
15 changes: 11 additions & 4 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,21 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials)
pushforward_returns = pushforward_returns[1]
end

return quote
pushforward = quote
# _ is the input derivative w.r.t. function internals. since we do not
# allow closures/functors with @scalar_rule, it is always ignored
function ChainRulesCore.frule(::typeof($f), $(inputs...), _, $(Δs...))
function $(propagator_name(f, :pushforward))(_, $(Δs...))
$pushforward_returns
end

end

return quote
function ChainRulesCore.frule(::typeof($f), $(inputs...), _, _...)
$(esc()) = $call
$(setup_stmts...)
return $(esc()), $pushforward_returns
end
return $(esc()), $pushforward
end
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ julia> sinx, sin_pushforward = frule(sin, x, dself, 1)
julia> sinx == sin(x)
true
julia> sin_pushforward == cos(x)
julia> sin_pushforward(dself, 1) == cos(x)
true
```
Expand All @@ -49,7 +49,7 @@ julia> sincosx, sincos_pushforward = frule(sincos, x, dself, 1);
julia> sincosx == sincos(x)
true
julia> sincos_pushforward == (cos(x), -sin(x))
julia> sincos_pushforward(dself, 1) == (cos(x), -sin(x))
true
```
Expand Down
9 changes: 5 additions & 4 deletions test/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,21 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t))

frx, cool_pushforward = frule(cool, 1, dself, 1)
@test frx === 2
@test cool_pushforward === 1
@test cool_pushforward(dself, 1) === 1
rrx, cool_pullback = rrule(cool, 1)
self, rr1 = cool_pullback(1)
@test self === NO_FIELDS
@test rrx === 2
@test rr1 === 1

frx, nice_pushforward = frule(nice, 1, dself, 1)
@test nice_pushforward === Zero()
@test nice_pushforward(dself, 1) === Zero()
rrx, nice_pullback = rrule(nice, 1)
@test (NO_FIELDS, Zero()) === nice_pullback(1)
@test (dself, Zero()) === nice_pullback(1)

sx = @SVector [1, 2]
sy = @SVector [3, 4]
_, fun = frule(very_nice, 1, 2, dself, sx, sy)
# This actually is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting
@inferred frule(very_nice, 1, 2, Zero(), sx, sy)
@inferred fun(dself, sx, sy)
end

0 comments on commit 2a468ca

Please sign in to comment.