Overhaul Rules #30

Merged
merged 23 commits into from
Sep 17, 2019
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.2.1-DEV"
version = "0.3.0"

[compat]
julia = "^1.0"
9 changes: 6 additions & 3 deletions src/ChainRulesCore.jl
@@ -1,13 +1,16 @@
module ChainRulesCore
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable

export AbstractRule, Rule, frule, rrule
export frule, rrule
export wirtinger_conjugate, wirtinger_primal, refine_differential
export @scalar_rule, @thunk
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
export extern, cast, store!
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
export NO_FIELDS

include("differentials.jl")
include("differential_arithmetic.jl")
include("rule_types.jl")
include("operations.jl")
include("rules.jl")
include("rule_definition_tools.jl")
end # module
26 changes: 13 additions & 13 deletions src/differential_arithmetic.jl
@@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered.
Thus we can avoid any ambiguities.

Notice:
The precedence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :Any)
The precedence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
Thus each of the @eval loops creating definitions of + and *
defines the combination of this type with all types of lower precedence.
This means each eval loop is 1 item smaller than the previous.
@@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
end

for T in (:Casted, :Zero, :DNE, :One, :Thunk, :Any)
for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b

@@ -47,7 +47,7 @@ end

Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
for T in (:Zero, :DNE, :One, :Thunk, :Any)
for T in (:Zero, :DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b))
@eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value))

@@ -58,7 +58,7 @@ end

Base.:+(::Zero, b::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()
for T in (:DNE, :One, :Thunk, :Any)
for T in (:DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(::Zero, b::$T) = b
@eval Base.:+(a::$T, ::Zero) = a

@@ -69,7 +69,7 @@ end

Base.:+(::DNE, ::DNE) = DNE()
Base.:*(::DNE, ::DNE) = DNE()
for T in (:One, :Thunk, :Any)
for T in (:One, :AbstractThunk, :Any)
@eval Base.:+(::DNE, b::$T) = b
@eval Base.:+(a::$T, ::DNE) = a

@@ -80,7 +80,7 @@ end

Base.:+(a::One, b::One) = extern(a) + extern(b)
Base.:*(::One, ::One) = One()
for T in (:Thunk, :Any)
for T in (:AbstractThunk, :Any)
@eval Base.:+(a::One, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::One) = a + extern(b)

@@ -89,12 +89,12 @@ for T in (:Thunk, :Any)
end


Base.:+(a::Thunk, b::Thunk) = extern(a) + extern(b)
Base.:*(a::Thunk, b::Thunk) = extern(a) * extern(b)
for T in (:Any,) #This loop is redundant but for consistency...
@eval Base.:+(a::Thunk, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::Thunk) = a + extern(b)
Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b)
Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b)
for T in (:Any,)
@eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b)

@eval Base.:*(a::Thunk, b::$T) = extern(a) * b
@eval Base.:*(a::$T, b::Thunk) = a * extern(b)
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
end
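
The precedence scheme used in this file can be illustrated in isolation. The following is a minimal sketch, assuming two toy types (`ToyZero` and `ToyOne` are hypothetical names, not part of ChainRulesCore): each type defines its combination with itself by hand, then an `@eval` loop defines its combination with every type of lower precedence, so no pair is defined twice and dispatch stays unambiguous.

```julia
# Toy differential types standing in for the real ones (hypothetical names).
abstract type ToyDifferential end
struct ToyZero <: ToyDifferential end
struct ToyOne <: ToyDifferential end

# Same-type combinations are written by hand.
Base.:+(::ToyZero, ::ToyZero) = ToyZero()
Base.:+(::ToyOne, ::ToyOne) = 2

# ToyZero has the highest precedence here, so its loop covers ToyOne and Any.
for T in (:ToyOne, :Any)
    @eval Base.:+(::ToyZero, b::$T) = b
    @eval Base.:+(a::$T, ::ToyZero) = a
end

# ToyOne outranks only Any, so its loop is one item shorter.
for T in (:Any,)
    @eval Base.:+(::ToyOne, b::$T) = 1 + b
    @eval Base.:+(a::$T, ::ToyOne) = a + 1
end
```

Because every mixed pair has exactly one most-specific method, a call such as `ToyZero() + ToyOne()` resolves without a method ambiguity error.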
99 changes: 87 additions & 12 deletions src/differentials.jl
@@ -173,6 +173,24 @@ Base.iterate(x::One) = (x, nothing)
Base.iterate(::One, ::Any) = nothing


#####
##### `AbstractThunk`
#####
abstract type AbstractThunk <: AbstractDifferential end

Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x))

@inline function Base.iterate(x::AbstractThunk)
externed = extern(x)
element, state = iterate(externed)
return element, (externed, state)
end

@inline function Base.iterate(::AbstractThunk, (externed, state))
element, new_state = iterate(externed, state)
return element, (externed, new_state)
end

#####
##### `Thunk`
#####
@@ -181,8 +199,9 @@ Base.iterate(::One, ::Any) = nothing
Thunk(()->v)
A thunk is a deferred computation.
It wraps a zero argument closure that when invoked returns a differential.
`@thunk(v)` is a macro that expands into `Thunk(()->v)`.

Calling that thunk, calls the wrapped closure.
Calling a thunk, calls the wrapped closure.
`extern`ing thunks applies recursively: it also externs the differential that the closure returns.
If you do not want that, then simply call the thunk.

@@ -199,31 +218,87 @@ Thunk(var"##8#10"())
julia> t()()
3
```

### When to `@thunk`?
When writing `rrule`s (and to a lesser extent `frule`s), it is important to `@thunk`
appropriately.
Propagation rules that return multiple derivatives may not have all of those derivatives used.
By `@thunk`ing the work required for each derivative, they compute only what is needed.

#### So why not thunk everything?
`@thunk` creates a closure over the expression, which (effectively) creates a `struct`
with a field for each variable used in the expression, and an overloaded call operator.

Do not use `@thunk` if this would be as much work as, or more work than, evaluating the expression itself. Examples:
- The expression wraps something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`
- The expression is a constant
- The expression is itself a `Thunk`
- The expression comes from another `rrule` or `frule` (the defining rule would already have `@thunk`ed it if required)

Review comment (Contributor):
This entire section is great. Shall we move it to a page in the docs?

"""
struct Thunk{F} <: AbstractDifferential
struct Thunk{F} <: AbstractThunk
f::F
end

macro thunk(body)
return :(Thunk(() -> $(esc(body))))
end

# have to define this here after `@thunk` and `Thunk` are defined
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))


(x::Thunk)() = x.f()
@inline extern(x::Thunk) = extern(x())

Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x))
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")

@inline function Base.iterate(x::Thunk)
externed = extern(x)
element, state = iterate(externed)
return element, (externed, state)
"""
InplaceableThunk(val::Thunk, add!::Function)
nickrobinson251 marked this conversation as resolved.

A wrapper for a `Thunk`, that allows it to define an inplace `add!` function,
which is used internally in `accumulate!(Δ, ::InplaceableThunk)`.

`add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val`
but it should do this more efficiently than simply doing this directly.
(Otherwise one can just use a normal `Thunk`).

Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
and destroy its inplacability.

Review comment (Contributor), suggested change:
- and destroy its inplacability.
+ and destroy its ability to work inplace.

"""
struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
val::T
add!::F
end

@inline function Base.iterate(::Thunk, (externed, state))
element, new_state = iterate(externed, state)
return element, (externed, new_state)
(x::InplaceableThunk)() = x.val()
@inline extern(x::InplaceableThunk) = extern(x.val)

function Base.show(io::IO, x::InplaceableThunk)
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
end

Base.conj(x::Thunk) = @thunk(conj(extern(x)))
# The real reason we have this:
accumulate!(Δ, ∂::InplaceableThunk) = ∂.add!(Δ)
store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it.

Review comment (Contributor), suggested change:
- store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ.*=false)) # zero it, then add to it.
+ store!(Δ, ∂::InplaceableThunk) = ∂.add!((Δ .*= false)) # zero it, then add to it.


Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")
"""
NO_FIELDS

Constant for the reverse-mode derivative with respect to a structure that has no fields.
The most notable use for this is for the reverse-mode derivative with respect to the
function itself, when that function is not a closure.
"""
const NO_FIELDS = DNE()

"""
refine_differential(𝒟::Type, der)

oxinabox marked this conversation as resolved.
Converts, if required, a differential object `der`
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
to another differential that is more suited for the domain given by the type 𝒟.
Often this will behave as the identity function on `der`.
"""
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger)
return wirtinger_primal(w) + wirtinger_conjugate(w)
end
refine_differential(::Any, der) = der # most of the time leave it alone.
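
The "When to `@thunk`?" guidance in this file can be made concrete with a small self-contained sketch. It does not use ChainRulesCore itself: `ToyThunk` and `toy_mul_pullback` are hypothetical stand-ins, and the `calls` counter exists only to show which deferred branches actually ran. The point it illustrates is that a pullback returning several derivatives pays only for the ones the caller forces.

```julia
# Toy stand-in for the `Thunk` type in this diff (hypothetical name).
struct ToyThunk{F}
    f::F
end
(t::ToyThunk)() = t.f()   # calling the thunk runs the wrapped closure

# Hand-written pullback for z = x * y (matrices). `calls` is a counter added
# purely for illustration, so we can see which branches actually ran.
function toy_mul_pullback(x, y, Δz, calls)
    ∂x = ToyThunk(() -> (calls[] += 1; Δz * y'))   # worth deferring: a matmul
    ∂y = ToyThunk(() -> (calls[] += 1; x' * Δz))   # worth deferring: a matmul
    return ∂x, ∂y
end

calls = Ref(0)
x = [1.0 2.0; 3.0 4.0]
y = [0.0 1.0; 1.0 0.0]
Δz = [1.0 0.0; 0.0 1.0]

∂x, ∂y = toy_mul_pullback(x, y, Δz, calls)
dx = ∂x()   # only the derivative w.r.t. x is forced; ∂y is never computed
```

By contrast, a cheap wrapper such as `Adjoint(x)` would cost about as much to wrap in a closure as to build directly, which is why the docstring advises against thunking it.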
45 changes: 45 additions & 0 deletions src/operations.jl
@@ -0,0 +1,45 @@
# TODO: This all needs a fair bit of rethinking

"""
accumulate(Δ, ∂)

Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore's
various `AbstractDifferential` types.

See also: [`accumulate!`](@ref), [`store!`](@ref)
"""
accumulate(Δ, ∂) = Δ .+ ∂

"""
accumulate!(Δ, ∂)

Similar to [`accumulate`](@ref), but attempts to compute `Δ + ∂` in-place,
storing the result in `Δ`.

Note: this function may not actually store the result in `Δ` if `Δ` is immutable,

Review comment (Contributor), suggested change:
- Note: this function may not actually store the result in `Δ` if `Δ` is immutable,
+ !!! note
+     this function may not actually store the result in `Δ` if `Δ` is immutable,

so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in case.

This function is overloadable by using an [`InplaceableThunk`](@ref).
See also: [`accumulate`](@ref), [`store!`](@ref).
"""
function accumulate!(Δ, ∂)
return materialize!(Δ, broadcastable(cast(Δ) + ∂))
end

accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂)



"""
store!(Δ, ∂)

Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before.

Review comment (Contributor), suggested change:
- Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before.
+ Stores `∂` in `Δ` overwriting whatever was in `Δ` before.

potentially avoiding intermediate temporary allocations that might be
necessary for alternative approaches (e.g. `copyto!(Δ, extern(∂))`)

Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
to be customizable for specific rules/input types.

See also: [`accumulate`](@ref), [`accumulate!`](@ref)
"""
store!(Δ, ∂) = materialize!(Δ, broadcastable(∂))
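
How an in-place `add!` plugs into accumulation can be sketched without the package. The names below (`ToyInplaceable`, `toy_accumulate!`, `toy_store!`) are hypothetical stand-ins for `InplaceableThunk`, `accumulate!` and `store!`, and the `val` field is a plain closure rather than a `Thunk`. This is a sketch of the dispatch pattern only, not the library's implementation.

```julia
# Toy stand-in for `InplaceableThunk` (hypothetical name and fields).
struct ToyInplaceable{V,F}
    val::V    # zero-argument closure producing the differential
    add!::F   # function that adds the differential into a buffer in place
end

# Generic fallback: broadcast-add into the buffer and return it.
toy_accumulate!(Δ, ∂) = (Δ .+= ∂; Δ)
# Specialised path: let the differential decide how to add itself.
toy_accumulate!(Δ, ∂::ToyInplaceable) = ∂.add!(Δ)
# Zero the buffer (multiplying by `false` also clears NaNs), then add to it.
toy_store!(Δ, ∂::ToyInplaceable) = ∂.add!((Δ .*= false))

# The differential of sum(x) w.r.t. x is all ones; `add!` adds 1 to each
# element directly instead of first allocating an array of ones.
∂ = ToyInplaceable(() -> ones(3), Δ -> (Δ .+= 1; Δ))

Δ = [1.0, 2.0, 3.0]
Δ = toy_accumulate!(Δ, ∂)        # Δ is now [2.0, 3.0, 4.0]
after_accumulate = copy(Δ)
Δ = toy_store!(Δ, ∂)             # previous contents are overwritten
```

Reassigning the result (`Δ = toy_accumulate!(Δ, ∂)`) mirrors the advice in the `accumulate!` docstring, since an immutable `Δ` cannot be updated in place.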