Skip to content

Commit

Permalink
=Make frule wrt self and rrule wrt self different [WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Aug 28, 2019
1 parent 1d9165e commit 061197f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broad
export AbstractRule, Rule, frule, rrule
export @scalar_rule, @thunk
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
export NO_FIELDS_RULE
export NO_FIELDS_RULE, ZERO_RULE

include("differentials.jl")
include("rules.jl")
Expand Down
7 changes: 4 additions & 3 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ methods for `frule` and `rrule`:
function ChainRulesCore.frule(::typeof(f), x₁::Number, x₂::Number, ...)
Ω = f(x₁, x₂, ...)
\$(statement₁, statement₂, ...)
return Ω, (NO_FIELDS_RULE,
return Ω, (ZERO_RULE,
Rule((Δx₁, Δx₂, ...) -> ∂f₁_∂x₁ * Δx₁ + ∂f₁_∂x₂ * Δx₂ + ...),
Rule((Δx₁, Δx₂, ...) -> ∂f₂_∂x₁ * Δx₁ + ∂f₂_∂x₂ * Δx₂ + ...),
...)
Expand All @@ -38,7 +38,8 @@ e.g. `f(x₁::Complex, x₂)`, which will constrain `x₁` to `Complex` and `x
At present this does not support defining rules for closures/functors.
This the first returned rule, representing the derivative with respect to the
function itself, is always the `NO_FIELDS_RULE`.
function itself, is always the `NO_FIELDS_RULE` (reverse-mode),
or `ZERO_RULE` (forward-mode).
The result of `f(x₁, x₂, ...)` is automatically bound to `Ω`. This
allows the primal result to be conveniently referenced (as `Ω`) within the
Expand Down Expand Up @@ -100,7 +101,7 @@ macro scalar_rule(call, maybe_setup, partials...)

# First pseudo-partial is derivative WRT function itself. Since this macro does not
# support closures, it is just the empty NamedTuple
forward_rules = Expr(:tuple, NO_FIELDS_RULE, forward_rules...)
forward_rules = Expr(:tuple, ZERO_RULE, forward_rules...)
reverse_rules = Expr(:tuple, NO_FIELDS_RULE, reverse_rules...)
return quote
if fieldcount(typeof($f)) > 0
Expand Down
15 changes: 13 additions & 2 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,23 @@ accumulate!(Δ, rule::Rule{F,U}, args...) where {F,U<:Function} = rule.u(Δ, arg
NO_FIELDS_RULE
Constant for the rule for the derivative with respect to structure that has no fields.
The most notable use for this is for the dertiviate with respect to the function itself,
when that function is not a closure.
The most notable use for this is for the reverse-mode derivative with respect to the
function itself, when that function is not a closure.
The rule returns an empty `NamedTuple` for all inputs.
"""
const NO_FIELDS_RULE = Rule((args...)->NamedTuple())

"""
ZERO_RULE
This is a rule that returns `Zero()` regardless of input.
The most notable use for this is for the forward-mode derivative with respect to the
function itself, when that function is not a closure.
"""
const ZERO_RULE = Rule((args...)->Zero())



#####
##### `DNERule`
#####
Expand Down

0 comments on commit 061197f

Please sign in to comment.