Skip to content

Commit

Permalink
rename InplaceableThunk InplaceThunk
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed Sep 17, 2019
1 parent 959b869 commit 85b5bf9
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ export frule, rrule
export wirtinger_conjugate, wirtinger_primal, differential
export @scalar_rule, @thunk
export extern, cast, store!
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceThunk
export NO_FIELDS

include("differentials.jl")
Expand Down
20 changes: 10 additions & 10 deletions src/differentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,33 +254,33 @@ Base.conj(x::AbstractThunk) = @thunk(conj(extern(x)))
Base.show(io::IO, x::Thunk) = println(io, "Thunk($(repr(x.f)))")

"""
InplaceableThunk(val::Thunk, add!::Function)
InplaceThunk(val::Thunk, add!::Function)
A wrapper for a `Thunk`, that allows it to define an inplace `add!` function,
which is used internally in `accumulate!(Δ, ::InplaceableThunk)`.
which is used internally in `accumulate!(Δ, ::InplaceThunk)`.
`add!` should be defined such that: `ithunk.add!(Δ) = Δ .+= ithunk.val`
but it should do this more efficently than simply doing this directly.
(Otherwise one can just use a normal `Thunk`).
Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`;
Most operations on an `InplaceThunk` treat it just like a normal `Thunk`;
and destroy its inplacability.
"""
struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
struct InplaceThunk{T<:Thunk, F} <: AbstractThunk
val::T
add!::F
end

(x::InplaceableThunk)() = x.val()
@inline extern(x::InplaceableThunk) = extern(x.val)
(x::InplaceThunk)() = x.val()
@inline extern(x::InplaceThunk) = extern(x.val)

function Base.show(io::IO, x::InplaceableThunk)
println(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
function Base.show(io::IO, x::InplaceThunk)
println(io, "InplaceThunk($(repr(x.val)), $(repr(x.add!)))")
end

# The real reason we have this:
accumulate!(Δ, ∂::InplaceableThunk) =.add!(Δ)
store!(Δ, ∂::InplaceableThunk) =.add!((Δ.*=false)) # zero it, then add to it.
accumulate!(Δ, ∂::InplaceThunk) =.add!(Δ)
store!(Δ, ∂::InplaceThunk) =.add!((Δ.*=false)) # zero it, then add to it.

"""
NO_FIELDS
Expand Down

0 comments on commit 85b5bf9

Please sign in to comment.