diff --git a/Project.toml b/Project.toml index 5ab083d91..01a294bc2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.5.4" +version = "0.6.0" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 39f9d6673..8af37bc61 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -160,14 +160,21 @@ function scalar_frule_expr(f, call, setup_stmts, inputs, partials) pushforward_returns = pushforward_returns[1] end - return quote + pushforward = quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored + function $(propagator_name(f, :pushforward))(_, $(Δs...)) + $pushforward_returns + end + + end + + return quote function ChainRulesCore.frule(::typeof($f), $(inputs...), _, $(Δs...)) $(esc(:Ω)) = $call $(setup_stmts...) - return $(esc(:Ω)), $pushforward_returns - end + return $(esc(:Ω)), $pushforward + end end end diff --git a/src/rules.jl b/src/rules.jl index 0f0c84e69..dd479b269 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -35,7 +35,7 @@ julia> sinx, sin_pushforward = frule(sin, x, dself, 1) julia> sinx == sin(x) true -julia> sin_pushforward == cos(x) +julia> sin_pushforward(dself, 1) == cos(x) true ``` @@ -49,7 +49,7 @@ julia> sincosx, sincos_pushforward = frule(sincos, x, dself, 1); julia> sincosx == sincos(x) true -julia> sincos_pushforward == (cos(x), -sin(x)) +julia> sincos_pushforward(dself, 1) == (cos(x), -sin(x)) true ``` diff --git a/test/rules.jl b/test/rules.jl index f6b770b5a..5d680c360 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -39,7 +39,7 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) frx, cool_pushforward = frule(cool, 1, dself, 1) @test frx === 2 - @test cool_pushforward === 1 + @test cool_pushforward(dself, 1) === 1 rrx, cool_pullback = rrule(cool, 1) self, rr1 = cool_pullback(1) @test self === NO_FIELDS @@ -47,12 +47,13 @@ _second(t) = Base.tuple_type_head(Base.tuple_type_tail(t)) @test rr1 === 1 frx, nice_pushforward = frule(nice, 1, dself, 1) - @test nice_pushforward === Zero() + @test nice_pushforward(dself, 1) === Zero() rrx, nice_pullback = rrule(nice, 1) - @test (NO_FIELDS, Zero()) === nice_pullback(1) + @test (dself, Zero()) === nice_pullback(1) sx = @SVector [1, 2] sy = @SVector [3, 4] + _, fun = frule(very_nice, 1, 2, dself, sx, sy) # This actually is testing that @scalar_rule and `One()` play nice together, w.r.t broadcasting - @inferred frule(very_nice, 1, 2, Zero(), sx, sy) + @inferred fun(dself, sx, sy) end