-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30 from JuliaDiff/ox/wrtfunction
Overhaul Rules
- Loading branch information
Showing
12 changed files
with
509 additions
and
396 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,16 @@ | ||
module ChainRulesCore | ||
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable | ||
|
||
export AbstractRule, Rule, frule, rrule | ||
export frule, rrule | ||
export wirtinger_conjugate, wirtinger_primal, refine_differential | ||
export @scalar_rule, @thunk | ||
export extern, cast, store!, Wirtinger, Zero, One, Casted, DNE, Thunk, DNERule | ||
export extern, cast, store! | ||
export Wirtinger, Zero, One, Casted, DNE, Thunk, InplaceableThunk | ||
export NO_FIELDS | ||
|
||
include("differentials.jl") | ||
include("differential_arithmetic.jl") | ||
include("rule_types.jl") | ||
include("operations.jl") | ||
include("rules.jl") | ||
include("rule_definition_tools.jl") | ||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# TODO: This all needs a fair bit of rethinking | ||
|
||
""" | ||
accumulate(Δ, ∂) | ||
Return `Δ + ∂` evaluated in a manner that supports ChainRulesCore's | ||
various `AbstractDifferential` types. | ||
See also: [`accumulate!`](@ref), [`store!`](@ref), [`AbstractRule`](@ref) | ||
""" | ||
accumulate(Δ, ∂) = Δ .+ ∂ | ||
|
||
""" | ||
accumulate!(Δ, ∂) | ||
Similar to [`accumulate`](@ref), but attempts to compute `Δ + rule(args...)` in-place, | ||
storing the result in `Δ`. | ||
Note: this function may not actually store the result in `Δ` if `Δ` is immutable, | ||
so it is best to always call this as `Δ = accumulate!(Δ, ∂)` just in-case. | ||
This function is overloadable by using a [`InplaceThunk`](@ref). | ||
See also: [`accumulate`](@ref), [`store!`](@ref). | ||
""" | ||
function accumulate!(Δ, ∂) | ||
return materialize!(Δ, broadcastable(cast(Δ) + ∂)) | ||
end | ||
|
||
accumulate!(Δ::Number, ∂) = accumulate(Δ, ∂) | ||
|
||
|
||
|
||
""" | ||
store!(Δ, ∂) | ||
Stores `∂`, in `Δ`, overwriting what ever was in `Δ` before. | ||
potentially avoiding intermediate temporary allocations that might be | ||
necessary for alternative approaches (e.g. `copyto!(Δ, extern(∂))`) | ||
Like [`accumulate`](@ref) and [`accumulate!`](@ref), this function is intended | ||
to be customizable for specific rules/input types. | ||
See also: [`accumulate`](@ref), [`accumulate!`](@ref), [`AbstractRule`](@ref) | ||
""" | ||
store!(Δ, ∂) = materialize!(Δ, broadcastable(∂)) |
Oops, something went wrong.
a133468
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JuliaRegistrator register()
a133468
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Registration pull request created: JuliaRegistries/General/3606
After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.
This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via: