-
Notifications
You must be signed in to change notification settings - Fork 62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Overhaul Rules #30
Overhaul Rules #30
Changes from all commits
0fabda9
7a72d84
7789a67
0b232a4
de2bb62
eb3c292
f6979ac
55dcefe
53d5f9e
5b6c0d8
5c3fdaa
f14e045
ef748c9
48a0391
752732e
765ecfc
62d3be3
3ae6449
03cb994
cb01743
3656389
1869be1
e51ff80
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,16 @@ | ||
module ChainRulesCore | ||
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable | ||
|
||
export AbstractRule, Rule, frule, rrule | ||
export frule, rrule | ||
export wirtinger_conjugate, wirtinger_primal, refine_differential | ||
export @scalar_rule, @thunk | ||
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule | ||
export extern, cast, store! | ||
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk | ||
export NO_FIELDS | ||
|
||
include("differentials.jl") | ||
include("differential_arithmetic.jl") | ||
include("rule_types.jl") | ||
include("operations.jl") | ||
include("rules.jl") | ||
include("rule_definition_tools.jl") | ||
end # module |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -173,6 +173,24 @@ Base.iterate(x::One) = (x, nothing) | |||||
Base.iterate(::One, ::Any) = nothing | ||||||
|
||||||
|
||||||
##### | ||||||
##### `AbstractThunk | ||||||
##### | ||||||
abstract type AbstractThunk <: AbstractDifferential end | ||||||
|
||||||
Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x)) | ||||||
|
||||||
@inline function Base.iterate(x::AbstractThunk) | ||||||
externed = extern(x) | ||||||
element, state = iterate(externed) | ||||||
return element, (externed, state) | ||||||
end | ||||||
|
||||||
@inline function Base.iterate(::AbstractThunk, (externed, state)) | ||||||
element, new_state = iterate(externed, state) | ||||||
return element, (externed, new_state) | ||||||
end | ||||||
|
||||||
##### | ||||||
##### `Thunk` | ||||||
##### | ||||||
|
@@ -181,8 +199,9 @@ Base.iterate(::One, ::Any) = nothing | |||||
Thunk(()->v) | ||||||
A thunk is a deferred computation. | ||||||
It wraps a zero argument closure that when invoked returns a differential. | ||||||
`@thunk(v)` is a macro that expands into `Thunk(()->v)`. | ||||||
|
||||||
Calling that thunk, calls the wrapped closure. | ||||||
Calling a thunk, calls the wrapped closure. | ||||||
`extern`ing thunks applies recursively, it also externs the differial that the closure returns. | ||||||
If you do not want that, then simply call the thunk | ||||||
|
||||||
|
@@ -199,31 +218,87 @@ Thunk(var"##8#10"()) | |||||
julia> t()() | ||||||
3 | ||||||
``` | ||||||
|
||||||
### When to `@thunk`? | ||||||
When writing `rrule`s (and to a lesser exent `frule`s), it is important to `@thunk` | ||||||
appropriately. | ||||||
Propagation rule's that return multiple derivatives are not able to do all the computing themselves. | ||||||
By `@thunk`ing the work required for each, they then compute only what is needed. | ||||||
|
||||||
#### So why not thunk everything? | ||||||
`@thunk` creates a closure over the expression, which (effectively) creates a `struct` | ||||||
with a field for each variable used in the expression, and call overloaded. | ||||||
|
||||||
Do not use `@thunk` if this would be equal or more work than actually evaluating the expression itself. Examples being: | ||||||
- The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)` | ||||||
- The expression being a constant | ||||||
- The expression being itself a `thunk` | ||||||
- The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already) | ||||||
""" | ||||||
struct Thunk{F} <: AbstractDifferential | ||||||
struct Thunk{F} <: AbstractThunk | ||||||
f::F | ||||||
end | ||||||
|
||||||
macro thunk(body) | ||||||
return :(Thunk(() -> $(esc(body)))) | ||||||
end | ||||||
|
||||||
# have to define this here after `@thunk` and `Thunk` is defined | ||||||
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x))) | ||||||
|
||||||
|
||||||
(x::Thunk)() = x.f() | ||||||
@inline extern(x::Thunk) = extern(x()) | ||||||
|
||||||
Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x)) | ||||||
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") | ||||||
|
||||||
@inline function Base.iterate(x::Thunk) | ||||||
externed = extern(x) | ||||||
element, state = iterate(externed) | ||||||
return element, (externed, state) | ||||||
""" | ||||||
InplaceableThunk(val::Thunk, add!::Function) | ||||||
nickrobinson251 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
A wrapper for a `Thunk`, that allows it to define an inplace `add!` function, | ||||||
which is used internally in `accumulate!(Δ, ::InplaceableThunk)`. | ||||||
|
||||||
`add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val` | ||||||
but it should do this more efficently than simply doing this directly. | ||||||
(Otherwise one can just use a normal `Thunk`). | ||||||
|
||||||
Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`; | ||||||
and destroy its inplacability. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
""" | ||||||
struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk | ||||||
val::T | ||||||
add!::F | ||||||
end | ||||||
|
||||||
@inline function Base.iterate(::Thunk, (externed, state)) | ||||||
element, new_state = iterate(externed, state) | ||||||
return element, (externed, new_state) | ||||||
(x::InplaceableThunk)() = x.val() | ||||||
@inline extern(x::InplaceableThunk) = extern(x.val) | ||||||
|
||||||
function Base.show(io::IO, x::InplaceableThunk) | ||||||
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))") | ||||||
end | ||||||
|
||||||
Base.conj(x::Thunk) = @thunk(conj(extern(x))) | ||||||
# The real reason we have this: | ||||||
accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ) | ||||||
store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))") | ||||||
""" | ||||||
NO_FIELDS | ||||||
|
||||||
Constant for the reverse-mode derivative with respect to a structure that has no fields. | ||||||
The most notable use for this is for the reverse-mode derivative with respect to the | ||||||
function itself, when that function is not a closure. | ||||||
""" | ||||||
const NO_FIELDS = DNE() | ||||||
|
||||||
""" | ||||||
refine_differential(𝒟::Type, der) | ||||||
|
||||||
oxinabox marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
Converts, if required, a differential object `der` | ||||||
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.), | ||||||
to another differential that is more suited for the domain given by the type 𝒟. | ||||||
Often this will behave as the identity function on `der`. | ||||||
""" | ||||||
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger) | ||||||
return wirtinger_primal(w) + wirtinger_conjugate(w) | ||||||
end | ||||||
refine_differential(::Any, der) = der # most of the time leave it alone. |
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,45 @@ | ||||||||
# TODO: This all needs a fair bit of rethinking | ||||||||
|
||||||||
""" | ||||||||
accumulate(Δ, ∂) | ||||||||
|
||||||||
Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore's | ||||||||
various `AbstractDifferential` types. | ||||||||
|
||||||||
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) | ||||||||
""" | ||||||||
accumulate(Δ, ∂) = Δ .+ ∂ | ||||||||
|
||||||||
""" | ||||||||
accumulate!(Δ, ∂) | ||||||||
|
||||||||
Similar to [`accumulate`](@ref), but attempts to compute `Δ + rule(args...)` in-place, | ||||||||
storing the result in `Δ`. | ||||||||
|
||||||||
Note: this function may not actually store the result in `Δ` if `Δ` is immutable, | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case. | ||||||||
|
||||||||
This function is overloadable by using a [`InplaceThunk`](@ref). | ||||||||
See also: [`accumulate`](@ref), [`store!`](@ref). | ||||||||
""" | ||||||||
function accumulate!(Δ, ∂) | ||||||||
return materialize!(Δ, broadcastable(cast(Δ) + ∂)) | ||||||||
end | ||||||||
|
||||||||
accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂) | ||||||||
|
||||||||
|
||||||||
|
||||||||
""" | ||||||||
store!(Δ, ∂) | ||||||||
|
||||||||
Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before. | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
potentially avoiding intermediate temporary allocations that might be | ||||||||
necessary for alternative approaches (e.g. `copyto!(Δ, extern(∂))`) | ||||||||
|
||||||||
Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended | ||||||||
to be customizable for specific rules/input types. | ||||||||
|
||||||||
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) | ||||||||
""" | ||||||||
store!(Δ, ∂) = materialize!(Δ, broadcastable(∂)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This entire section is great
Shall we move it to a page in the docs?