Skip to content

Commit

Permalink
Merge pull request #30 from JuliaDiff/ox/wrtfunction
Browse files Browse the repository at this point in the history
Overhaul Rules
  • Loading branch information
oxinabox authored Sep 17, 2019
2 parents ad1a7a4 + e51ff80 commit a133468
Show file tree
Hide file tree
Showing 12 changed files with 509 additions and 396 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "0.2.1-DEV"
version = "0.3.0"

[compat]
julia = "^1.0"
Expand Down
9 changes: 6 additions & 3 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
module ChainRulesCore
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable

export AbstractRule, Rule, frule, rrule
export frule, rrule
export wirtinger_conjugate, wirtinger_primal, refine_differential
export @scalar_rule, @thunk
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule
export extern, cast, store!
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
export NO_FIELDS

include("differentials.jl")
include("differential_arithmetic.jl")
include("rule_types.jl")
include("operations.jl")
include("rules.jl")
include("rule_definition_tools.jl")
end # module
26 changes: 13 additions & 13 deletions src/differential_arithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ subtypes, as we know the full set that might be encountered.
Thus we can avoid any ambiguities.
Notice:
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :Thunk, :Any)
The precidence goes: (:Wirtinger, :Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
Thus each of the @eval loops creating definitions of + and *
defines the combination this type with all types of lower precidence.
This means each eval loops is 1 item smaller than the previous.
Expand Down Expand Up @@ -36,7 +36,7 @@ function Base.:+(a::Wirtinger, b::Wirtinger)
return Wirtinger(+(a.primal, b.primal), a.conjugate + b.conjugate)
end

for T in (:Casted, :Zero, :DNE, :One, :Thunk, :Any)
for T in (:Casted, :Zero, :DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(a::Wirtinger, b::$T) = a + Wirtinger(b, Zero())
@eval Base.:+(a::$T, b::Wirtinger) = Wirtinger(a, Zero()) + b

Expand All @@ -47,7 +47,7 @@ end

Base.:+(a::Casted, b::Casted) = Casted(broadcasted(+, a.value, b.value))
Base.:*(a::Casted, b::Casted) = Casted(broadcasted(*, a.value, b.value))
for T in (:Zero, :DNE, :One, :Thunk, :Any)
for T in (:Zero, :DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(a::Casted, b::$T) = Casted(broadcasted(+, a.value, b))
@eval Base.:+(a::$T, b::Casted) = Casted(broadcasted(+, a, b.value))

Expand All @@ -58,7 +58,7 @@ end

Base.:+(::Zero, b::Zero) = Zero()
Base.:*(::Zero, ::Zero) = Zero()
for T in (:DNE, :One, :Thunk, :Any)
for T in (:DNE, :One, :AbstractThunk, :Any)
@eval Base.:+(::Zero, b::$T) = b
@eval Base.:+(a::$T, ::Zero) = a

Expand All @@ -69,7 +69,7 @@ end

Base.:+(::DNE, ::DNE) = DNE()
Base.:*(::DNE, ::DNE) = DNE()
for T in (:One, :Thunk, :Any)
for T in (:One, :AbstractThunk, :Any)
@eval Base.:+(::DNE, b::$T) = b
@eval Base.:+(a::$T, ::DNE) = a

Expand All @@ -80,7 +80,7 @@ end

Base.:+(a::One, b::One) = extern(a) + extern(b)
Base.:*(::One, ::One) = One()
for T in (:Thunk, :Any)
for T in (:AbstractThunk, :Any)
@eval Base.:+(a::One, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::One) = a + extern(b)

Expand All @@ -89,12 +89,12 @@ for T in (:Thunk, :Any)
end


Base.:+(a::Thunk, b::Thunk) = extern(a) + extern(b)
Base.:*(a::Thunk, b::Thunk) = extern(a) * extern(b)
for T in (:Any,) #This loop is redundant but for consistency...
@eval Base.:+(a::Thunk, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::Thunk) = a + extern(b)
Base.:+(a::AbstractThunk, b::AbstractThunk) = extern(a) + extern(b)
Base.:*(a::AbstractThunk, b::AbstractThunk) = extern(a) * extern(b)
for T in (:Any,)
@eval Base.:+(a::AbstractThunk, b::$T) = extern(a) + b
@eval Base.:+(a::$T, b::AbstractThunk) = a + extern(b)

@eval Base.:*(a::Thunk, b::$T) = extern(a) * b
@eval Base.:*(a::$T, b::Thunk) = a * extern(b)
@eval Base.:*(a::AbstractThunk, b::$T) = extern(a) * b
@eval Base.:*(a::$T, b::AbstractThunk) = a * extern(b)
end
99 changes: 87 additions & 12 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,24 @@ Base.iterate(x::One) = (x, nothing)
Base.iterate(::One, ::Any) = nothing


#####
##### `AbstractThunk
#####
abstract type AbstractThunk <: AbstractDifferential end

Base.Broadcast.broadcastable(x::AbstractThunk) = broadcastable(extern(x))

@inline function Base.iterate(x::AbstractThunk)
externed = extern(x)
element, state = iterate(externed)
return element, (externed, state)
end

@inline function Base.iterate(::AbstractThunk, (externed, state))
element, new_state = iterate(externed, state)
return element, (externed, new_state)
end

#####
##### `Thunk`
#####
Expand All @@ -181,8 +199,9 @@ Base.iterate(::One, ::Any) = nothing
Thunk(()->v)
A thunk is a deferred computation.
It wraps a zero argument closure that when invoked returns a differential.
`@thunk(v)` is a macro that expands into `Thunk(()->v)`.
Calling that thunk, calls the wrapped closure.
Calling a thunk, calls the wrapped closure.
`extern`ing thunks applies recursively, it also externs the differial that the closure returns.
If you do not want that, then simply call the thunk
Expand All @@ -199,31 +218,87 @@ Thunk(var"##8#10"())
julia> t()()
3
```
### When to `@thunk`?
When writing `rrule`s (and to a lesser exent `frule`s), it is important to `@thunk`
appropriately.
Propagation rule's that return multiple derivatives are not able to do all the computing themselves.
By `@thunk`ing the work required for each, they then compute only what is needed.
#### So why not thunk everything?
`@thunk` creates a closure over the expression, which (effectively) creates a `struct`
with a field for each variable used in the expression, and call overloaded.
Do not use `@thunk` if this would be equal or more work than actually evaluating the expression itself. Examples being:
- The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`
- The expression being a constant
- The expression being itself a `thunk`
- The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already)
"""
struct Thunk{F} <: AbstractDifferential
struct Thunk{F} <: AbstractThunk
f::F
end

macro thunk(body)
return :(Thunk(() -> $(esc(body))))
end

# have to define this here after `@thunk` and `Thunk` is defined
Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))


(x::Thunk)() = x.f()
@inline extern(x::Thunk) = extern(x())

Base.Broadcast.broadcastable(x::Thunk) = broadcastable(extern(x))
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")

@inline function Base.iterate(x::Thunk)
externed = extern(x)
element, state = iterate(externed)
return element, (externed, state)
"""
InplaceableThunk(val::Thunk, add!::Function)
A wrapper for a `Thunk`, that allows it to define an inplace `add!` function,
which is used internally in `accumulate!(Δ, ::InplaceableThunk)`.
`add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val`
but it should do this more efficently than simply doing this directly.
(Otherwise one can just use a normal `Thunk`).
Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
and destroy its inplacability.
"""
struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
val::T
add!::F
end

@inline function Base.iterate(::Thunk, (externed, state))
element, new_state = iterate(externed, state)
return element, (externed, new_state)
(x::InplaceableThunk)() = x.val()
@inline extern(x::InplaceableThunk) = extern(x.val)

function Base.show(io::IO, x::InplaceableThunk)
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
end

Base.conj(x::Thunk) = @thunk(conj(extern(x)))
# The real reason we have this:
accumulate!(Δ, ∂::InplaceableThunk) =.add!(Δ)
store!(Δ, ∂::InplaceableThunk) =.add!((Δ.*=false)) # zero it, then add to it.

Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")
"""
NO_FIELDS
Constant for the reverse-mode derivative with respect to a structure that has no fields.
The most notable use for this is for the reverse-mode derivative with respect to the
function itself, when that function is not a closure.
"""
const NO_FIELDS = DNE()

"""
refine_differential(𝒟::Type, der)
Converts, if required, a differential object `der`
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.),
to another differential that is more suited for the domain given by the type 𝒟.
Often this will behave as the identity function on `der`.
"""
function refine_differential(::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger)
return wirtinger_primal(w) + wirtinger_conjugate(w)
end
refine_differential(::Any, der) = der # most of the time leave it alone.
45 changes: 45 additions & 0 deletions src/operations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# TODO: This all needs a fair bit of rethinking

"""
accumulate(Δ, ∂)
Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore's
various `AbstractDifferential` types.
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref)
"""
accumulate(Δ, ∂) = Δ .+

"""
accumulate!(Δ, ∂)
Similar to [`accumulate`](@ref), but attempts to compute `Δ + rule(args...)` in-place,
storing the result in `Δ`.
Note: this function may not actually store the result in `Δ` if `Δ` is immutable,
so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case.
This function is overloadable by using a [`InplaceThunk`](@ref).
See also: [`accumulate`](@ref), [`store!`](@ref).
"""
function accumulate!(Δ, ∂)
return materialize!(Δ, broadcastable(cast(Δ) + ∂))
end

accumulate!::Number, ∂) = accumulate(Δ, ∂)



"""
store!(Δ, ∂)
Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before.
potentially avoiding intermediate temporary allocations that might be
necessary for alternative approaches (e.g. `copyto!(Δ, extern(∂))`)
Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended
to be customizable for specific rules/input types.
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref)
"""
store!(Δ, ∂) = materialize!(Δ, broadcastable(∂))
Loading

2 comments on commit a133468

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/3606

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" a13346864df4382108d9f8294ea08afff91edc75
git push origin v0.3.0

Please sign in to comment.