From e8b9b7feecec0bd547c4563741a1d479b77e4c64 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 26 May 2021 19:46:13 +0100 Subject: [PATCH 1/3] Remove all the rule overload generation stuff --- docs/make.jl | 5 +- docs/src/autodiff/operator_overloading.md | 80 ---------- .../overview.md => use_in_ad_system.md} | 3 +- src/ChainRulesCore.jl | 3 - src/precompile.jl | 1 - src/ruleset_loading.jl | 142 ------------------ test/demos/forwarddiffzero.jl | 91 ----------- test/demos/reversediffzero.jl | 141 ----------------- test/ruleset_loading.jl | 72 --------- test/runtests.jl | 6 - 10 files changed, 3 insertions(+), 541 deletions(-) delete mode 100644 docs/src/autodiff/operator_overloading.md rename docs/src/{autodiff/overview.md => use_in_ad_system.md} (89%) delete mode 100644 src/precompile.jl delete mode 100644 src/ruleset_loading.jl delete mode 100644 test/demos/forwarddiffzero.jl delete mode 100644 test/demos/reversediffzero.jl delete mode 100644 test/ruleset_loading.jl diff --git a/docs/make.jl b/docs/make.jl index 011502584..ea4876bb4 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -52,10 +52,7 @@ makedocs( "Deriving Array Rules" => "arrays.md", "Debug Mode" => "debug_mode.md", "Gradient Accumulation" => "gradient_accumulation.md", - "Usage in AD" => [ - "Overview" => "autodiff/overview.md", - "Operator Overloading" => "autodiff/operator_overloading.md", - ], + "Usage in AD" => "use_in_ad_system.md", "Design" => [ "Changing the Primal" => "design/changing_the_primal.md", "Many Differential Types" => "design/many_differentials.md", diff --git a/docs/src/autodiff/operator_overloading.md b/docs/src/autodiff/operator_overloading.md deleted file mode 100644 index 56f72e10c..000000000 --- a/docs/src/autodiff/operator_overloading.md +++ /dev/null @@ -1,80 +0,0 @@ -# Operator Overloading - -The principal interface for using the operator overload generation method is [`on_new_rule`](@ref). -This function allows one to register a hook to be run every time a new rule is defined. -The hook receives a signature type-type as input, and generally will use `eval` to define -an overload of an AD system's overloaded type. -For example, using the signature type `Tuple{typeof(+), Real, Real}` to make -`+(::DualNumber, ::DualNumber)` call the `frule` for `+`. -A signature type tuple always has the form: -`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}, ...}`, where `pos_arg1` is the -first positional argument. -One can dispatch on the signature type to make rules with argument types your AD does not support not call `eval`; -or more simply you can just use conditions for this. -For example if your AD only supports `AbstractMatrix{Float64}` and `Float64` inputs you might write: -```julia -const ACCEPT_TYPE = Union{Float64, AbstractMatrix{Float64}} -function define_overload(sig::Type{<:Tuple{F, Vararg{ACCEPT_TYPE}}) where F - @eval quote - # ... - end -end -define_overload(::Any) = nothing # don't do anything for any other signature - -on_new_rule(define_overload, frule) -``` - -or you might write: -```julia -const ACCEPT_TYPES = (Float64, AbstractMatrix{Float64}) -function define_overload(sig) - sig = Base.unwrap_unionall(sig) # not really handling most UnionAll, - opT, argTs = Iterators.peel(sig.parameters) - all(any(acceptT<: argT for acceptT in ACCEPT_TYPES) for argT in argTs) || return - @eval quote - # ... - end -end - -on_new_rule(define_overload, frule) -``` - -The generation of overloaded code is the responsibility of the AD implementor. -Packages like [ExprTools.jl](https://github.com/invenia/ExprTools.jl) can be helpful for this. -Its generally fairly simple, though can become complex if you need to handle complicated type-constraints. -Examples are shown below. - -The hook is automatically triggered whenever a package is loaded. -It can also be triggers manually using `refresh_rules`(@ref). -This is useful for example if new rules are define in the REPL, or if a package defining rules is modified. -(Revise.jl will not automatically trigger). -When the rules are refreshed (automatically or manually), the hooks are only triggered on new/modified rules; not ones that have already had the hooks triggered on. - -`clear_new_rule_hooks!`(@ref) clears all registered hooks. -It is useful to undo [`on_new_rule`] hook registration if you are iteratively developing your overload generation function. - -## Examples - -### ForwardDiffZero -The overload generation hook in this example is: `define_dual_overload`. - -````@eval -using Markdown -Markdown.parse(""" -```julia -$(read(joinpath(@__DIR__,"../../../test/demos/forwarddiffzero.jl"), String)) -``` -""") -```` - -### ReverseDiffZero -The overload generation hook in this example is: `define_tracked_overload`. - -````@eval -using Markdown -Markdown.parse(""" -```julia -$(read(joinpath(@__DIR__,"../../../test/demos/reversediffzero.jl"), String)) -``` -""") -```` diff --git a/docs/src/autodiff/overview.md b/docs/src/use_in_ad_system.md similarity index 89% rename from docs/src/autodiff/overview.md rename to docs/src/use_in_ad_system.md index 2638d6382..d0c4d0228 100644 --- a/docs/src/autodiff/overview.md +++ b/docs/src/use_in_ad_system.md @@ -7,7 +7,8 @@ to avoid having to code all your own AD primitives / custom sensitives. There are 3 main ways to access ChainRules rule sets in your AutoDiff system. -1. [Operation Overloading Generation](operator_overloading.html) +1. [Operator Overloading Generation](https://juliadiff.org/ChainRulesOverloadGeneration.jl/stable) + - using [ChainRulesOverloadGenerration.jl](https://github.com/JuliaDiff/ChainRulesOverloadGeneration.jl/) - This is primarily intended for operator overloading based AD systems which will generate overloads for primal functions based for their overloaded types based on the existence of an `rrule`/`frule`. - A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule. - This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all. diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 04f9ad4d4..070673458 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -4,7 +4,6 @@ using LinearAlgebra: LinearAlgebra using SparseArrays: SparseVector, SparseMatrixCSC using Compat: hasfield -export on_new_rule, refresh_rules # generation tools export frule, rrule # core function export @non_differentiable, @scalar_rule, @thunk, @not_implemented # definition helper macros export canonicalize, extern, unthunk # differential operations @@ -27,9 +26,7 @@ include("accumulation.jl") include("rules.jl") include("rule_definition_tools.jl") -include("ruleset_loading.jl") include("deprecated.jl") -include("precompile.jl") end # module diff --git a/src/precompile.jl b/src/precompile.jl deleted file mode 100644 index ba4e72c8d..000000000 --- a/src/precompile.jl +++ /dev/null @@ -1 +0,0 @@ -precompile(_package_hook, (Base.PkgId,)) diff --git a/src/ruleset_loading.jl b/src/ruleset_loading.jl deleted file mode 100644 index aa40fae26..000000000 --- a/src/ruleset_loading.jl +++ /dev/null @@ -1,142 +0,0 @@ -# Infastructure to support generating overloads from rules. -_package_hook(::Base.PkgId) = refresh_rules() -function __init__() - # Need to refresh rules when a package is loaded - push!(Base.package_callbacks, _package_hook) -end - -# Holds all the hook functions that are invokes when a new rule is defined -const RRULE_DEFINITION_HOOKS = Function[] -const FRULE_DEFINITION_HOOKS = Function[] -_hook_list(::typeof(rrule)) = RRULE_DEFINITION_HOOKS -_hook_list(::typeof(frule)) = FRULE_DEFINITION_HOOKS - -""" - on_new_rule(hook, frule | rrule) - -Register a `hook` function to run when new rules are defined. -The hook receives a signature type-type as input, and generally will use `eval` to define -an overload of an AD system's overloaded type -For example, using the signature type `Tuple{typeof(+), Real, Real}` to make -`+(::DualNumber, ::DualNumber)` call the `frule` for `+`. -A signature type tuple always has the form: -`Tuple{typeof(operation), typeof{pos_arg1}, typeof{pos_arg2}...}`, where `pos_arg1` is the -first positional argument. - -The hooks are automatically run on new rules whenever a package is loaded. -They can be manually triggered by [`refresh_rules`](@ref). -When a hook is first registered with `on_new_rule` it is run on all existing rules. -""" -function on_new_rule(hook_fun, rule_kind) - # apply the hook to the existing rules - ret = map(_rule_list(rule_kind)) do method - sig = _primal_sig(rule_kind, method) - _safe_hook_fun(hook_fun, sig) - end - - # register hook for new rules -- so all new rules get this function applied - push!(_hook_list(rule_kind), hook_fun) - return ret -end - -""" - clear_new_rule_hooks!(frule|rrule) - -Clears all hooks that were registered with corresponding [`on_new_rule`](@ref). -This is useful for while working interactively to define your rule generating hooks. -If you previously wrong an incorrect hook, you can use this to get rid of the old one. - -!!! warning - This absolutely should not be used in a package, as it will break any other AD system - using the rule hooks that might happen to be loaded. -""" -clear_new_rule_hooks!(rule_kind) = empty!(_hook_list(rule_kind)) - -""" - _rule_list(frule | rrule) - -Returns a list of all the methods of the currently defined rules of the given kind. -Excluding the fallback rule that returns `nothing` for every input. -""" -function _rule_list end -# The fallback rules are the only rules defined in ChainRulesCore & that is how we skip them -_rule_list(rule_kind) = (m for m in methods(rule_kind) if m.module != @__MODULE__) - - -const LAST_REFRESH_RRULE = Ref(0) -const LAST_REFRESH_FRULE = Ref(0) -last_refresh(::typeof(frule)) = LAST_REFRESH_FRULE -last_refresh(::typeof(rrule)) = LAST_REFRESH_RRULE - -""" - refresh_rules() - refresh_rules(frule | rrule) - -This triggers all [`on_new_rule`](@ref) hooks to run on any newly defined rules. -It is *automatically* run when ever a package is loaded. -It can also be manually called to run it directly, for example if a rule was defined -in the REPL or within the same file as the AD function. -""" -function refresh_rules() - refresh_rules(frule); - refresh_rules(rrule) -end - -function refresh_rules(rule_kind) - isempty(_rule_list(rule_kind)) && return # if no hooks, exit early, nothing to run - already_done_world_age = last_refresh(rule_kind)[] - for method in _rule_list(rule_kind) - _defined_world(method) < already_done_world_age && continue - sig = _primal_sig(rule_kind, method) - _trigger_new_rule_hooks(rule_kind, sig) - end - - last_refresh(rule_kind)[] = _current_world() - return nothing -end - -@static if VERSION >= v"1.2" - _current_world() = Base.get_world_counter() - _defined_world(method) = method.primary_world -else - _current_world() = ccall(:jl_get_world_counter, UInt, ()) - _defined_world(method) = method.min_world -end - -""" - _primal_sig(frule|rule, rule_method | rule_sig) - -Returns the signature as a `Tuple{function_type, arg1_type, arg2_type,...}`. -""" -_primal_sig(rule_kind, method::Method) = _primal_sig(rule_kind, method.sig) -function _primal_sig(::typeof(frule), rule_sig::DataType) - @assert rule_sig.parameters[1] == typeof(frule) - # need to skip frule and the deriviative info, so starting from the 3rd - return Tuple{rule_sig.parameters[3:end]...} -end -function _primal_sig(::typeof(rrule), rule_sig::DataType) - @assert rule_sig.parameters[1] == typeof(rrule) - # need to skip rrule so starting from the 2rd - return Tuple{rule_sig.parameters[2:end]...} -end -function _primal_sig(rule_kind, rule_sig::UnionAll) - # This looks a lot like Base.unwrap_unionall and Base.rewrap_unionall, but using those - # seems not to work - p_sig = _primal_sig(rule_kind, rule_sig.body) - return UnionAll(rule_sig.var, p_sig) -end - - -function _trigger_new_rule_hooks(rule_kind, sig) - for hook_fun in _hook_list(rule_kind) - _safe_hook_fun(hook_fun, sig) - end -end - -function _safe_hook_fun(hook_fun, sig) - try - hook_fun(sig) - catch err - @error "Error triggering hook" hook_fun sig exception=(err, catch_backtrace()) - end -end diff --git a/test/demos/forwarddiffzero.jl b/test/demos/forwarddiffzero.jl deleted file mode 100644 index 6ff22da25..000000000 --- a/test/demos/forwarddiffzero.jl +++ /dev/null @@ -1,91 +0,0 @@ -"The simplest viable forward mode a AD, only supports `Float64`" -module ForwardDiffZero -using ChainRulesCore -using Test - -######################################### -# Initial rule setup -@scalar_rule x + y (1, 1) -@scalar_rule x - y (1, -1) -########################## -# Define the AD - -# Note that we never directly define Dual Number Arithmetic on Dual numbers -# instead it is automatically defined from the `frules` -struct Dual <: Real - primal::Float64 - partial::Float64 -end - -primal(d::Dual) = d.primal -partial(d::Dual) = d.partial - -primal(d::Real) = d -partial(d::Real) = 0.0 - -# needed for `^` to work from having `*` defined -Base.to_power_type(x::Dual) = x - - -function define_dual_overload(sig) - sig = Base.unwrap_unionall(sig) # Not really handling most UnionAlls - opT, argTs = Iterators.peel(sig.parameters) - opT isa Type{<:Type} && return # not handling constructors - fieldcount(opT) == 0 || return # not handling functors - all(argT isa Type && Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. - - N = length(sig.parameters) - 1 # skip the op - fdef = quote - # we use the function call overloading form as it lets us avoid namespacing issues - # as we can directly interpolate the function type into to the AST. - function (op::$opT)(dual_args::Vararg{Union{Dual, Float64}, $N}; kwargs...) - ȧrgs = (NoTangent(), partial.(dual_args)...) - args = (op, primal.(dual_args)...) - y, ẏ = frule(ȧrgs, args...; kwargs...) - return Dual(y, ẏ) # if y, ẏ are not `Float64` this will error. - end - end - eval(fdef) -end - -# !Important!: Attach the define function to the `on_new_rule` hook -on_new_rule(define_dual_overload, frule) - -"Do a calculus. `f` should have a single input." -function derv(f, arg) - duals = Dual(arg, one(arg)) - return partial(f(duals...)) -end - -# End AD definition -################################ - -# add a rule later also -function ChainRulesCore.frule((_, Δx, Δy), ::typeof(*), x::Number, y::Number) - return (x * y, Δx * y + x * Δy) -end - -# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call -refresh_rules(); - -@testset "ForwardDiffZero" begin - foo(x) = x + x - @test derv(foo, 1.6) == 2 - - bar(x) = x + 2.1 * x - @test derv(bar, 1.2) == 3.1 - - baz(x) = 2.0 * x^2 + 3.0*x + 1.2 - @test derv(baz, 1.7) == 2*2.0*1.7 + 3.0 - - qux(x) = foo(x) + bar(x) + baz(x) - @test derv(qux, 1.7) == (2*2.0*1.7 + 3.0) + 3.1 + 2 - - function quux(x) - y = 2.0*x + 3.0*x - return 4.0*y + 5.0*y - end - @test derv(quux, 11.1) == 4*(2+3) + 5*(2+3) -end - -end # module diff --git a/test/demos/reversediffzero.jl b/test/demos/reversediffzero.jl deleted file mode 100644 index 17f33cf5a..000000000 --- a/test/demos/reversediffzero.jl +++ /dev/null @@ -1,141 +0,0 @@ -"The simplest viable reverse mode a AD, only supports `Float64`" -module ReverseDiffZero -using ChainRulesCore -using Test - -######################################### -# Initial rule setup -@scalar_rule x + y (1, 1) -@scalar_rule x - y (1, -1) -########################## -#Define the AD - -struct Tracked{F} <: Real - propagate::F - primal::Float64 - tape::Vector{Tracked} # a reference to a shared tape - partial::Base.RefValue{Float64} # current accumulated sensitivity -end - -"An intermediate value, a Branch in Nabla terms." -function Tracked(propagate, primal, tape) - v = Tracked(propagate, primal, tape, Ref(zero(primal))) - push!(tape, v) - return v -end - -"Marker for inputs (leaves) that don't need to propagate." -struct NoPropagate end - -"An input, a Leaf in Nabla terms. No inputs of its own to propagate to." -function Tracked(primal, tape) - # don't actually need to put these on the tape, since they don't need to propagate - return Tracked(NoPropagate(), primal, tape, Ref(zero(primal))) -end - -primal(d::Tracked) = d.primal -primal(d) = d - -partial(d::Tracked) = d.partial[] -partial(d) = nothing - -tape(d::Tracked) = d.tape -tape(d) = nothing - -"we have many inputs grab the tape from the first one that is tracked" -get_tape(ds) = something(tape.(ds)...) - -"propagate the currently stored partial back to my inputs." -propagate!(d::Tracked) = d.propagate(d.partial[]) - -"Accumulate the sensitivity, if the value is being tracked." -accum!(d::Tracked, x̄) = d.partial[] += x̄ -accum!(d, x̄) = nothing - -# needed for `^` to work from having `*` defined -Base.to_power_type(x::Tracked) = x - -"What to do when a new rrule is declared" -function define_tracked_overload(sig) - sig = Base.unwrap_unionall(sig) # not really handling most UnionAll - opT, argTs = Iterators.peel(sig.parameters) - opT isa Type{<:Type} && return # not handling constructors - fieldcount(opT) == 0 || return # not handling functors - all(argT isa Type && Float64 <: argT for argT in argTs) || return # only handling purely Float64 ops. - - N = length(sig.parameters) - 1 # skip the op - fdef = quote - # we use the function call overloading form as it lets us avoid namespacing issues - # as we can directly interpolate the function type into to the AST. - function (op::$opT)(tracked_args::Vararg{Union{Tracked, Float64}, $N}; kwargs...) - args = (op, primal.(tracked_args)...) - y, y_pullback = rrule(args...; kwargs...) - the_tape = get_tape(tracked_args) - y_tracked = Tracked(y, the_tape) do ȳ - # pull this partial back and propagate it to the input's partial store - _, ārgs = Iterators.peel(y_pullback(ȳ)) - accum!.(tracked_args, ārgs) - end - return y_tracked - end - end - eval(fdef) -end - -# !Important!: Attach the define function to the `on_new_rule` hook -on_new_rule(define_tracked_overload, rrule) - -"Do a calculus. `f` should have a single output." -function derv(f, args::Vararg; kwargs...) - the_tape = Vector{Tracked}() - tracked_inputs = Tracked.(args, Ref(the_tape)) - tracked_output = f(tracked_inputs...; kwargs...) - @assert tape(tracked_output) === the_tape - - # Now the backward pass - out = primal(tracked_output) - ōut = one(out) - accum!(tracked_output, ōut) - # By going down the tape backwards we know we will have fully accumulated partials - # before propagating them onwards - for op in reverse(the_tape) - propagate!(op) - end - return partial.(tracked_inputs) -end - -# End AD definition -################################ - -# add a rule later also -function ChainRulesCore.rrule(::typeof(*), x::Number, y::Number) - function times_pullback(ΔΩ) - # we will use thunks here to show we handle them fine. - return (NoTangent(), @thunk(ΔΩ * y'), @thunk(x' * ΔΩ)) - end - return x * y, times_pullback -end - -# Manual refresh needed as new rule added in same file as AD after the `on_new_rule` call -refresh_rules(); - -@testset "ReversedDiffZero" begin - foo(x) = x + x - @test derv(foo, 1.6) == (2.0,) - - bar(x) = x + 2.1 * x - @test derv(bar, 1.2) == (3.1,) - - baz(x) = 2.0 * x^2 + 3.0*x + 1.2 - @test derv(baz, 1.7) == (2 * 2.0 * 1.7 + 3.0,) - - qux(x) = foo(x) + bar(x) + baz(x) - @test derv(qux, 1.7) == ((2 * 2.0 * 1.7 + 3.0) + 3.1 + 2,) - - function quux(x) - y = 2.0*x + 3.0*x - return 4.0*y + 5.0*y - end - @test derv(quux, 11.1) == (4*(2+3) + 5*(2+3),) -end -end # module diff --git a/test/ruleset_loading.jl b/test/ruleset_loading.jl deleted file mode 100644 index e1743f0f9..000000000 --- a/test/ruleset_loading.jl +++ /dev/null @@ -1,72 +0,0 @@ -@testset "ruleset_loading.jl" begin - @testset "on_new_rule" begin - frule_history = [] - rrule_history = [] - on_new_rule(frule) do sig - op = sig.parameters[1] - push!(frule_history, op) - end - on_new_rule(rrule) do sig - op = sig.parameters[1] - push!(rrule_history, op) - end - - @testset "new rules hit the hooks" begin - # Now define some rules - @scalar_rule x + y (1, 1) - @scalar_rule x - y (1, -1) - refresh_rules() - - @test Set(frule_history[end-1:end]) == Set((typeof(+), typeof(-))) - @test Set(rrule_history[end-1:end]) == Set((typeof(+), typeof(-))) - end - - @testset "# Make sure nothing happens anymore once we clear the hooks" begin - ChainRulesCore.clear_new_rule_hooks!(frule) - ChainRulesCore.clear_new_rule_hooks!(rrule) - - old_frule_history = copy(frule_history) - old_rrule_history = copy(rrule_history) - - @scalar_rule sin(x) cos(x) - refresh_rules() - - @test old_rrule_history == rrule_history - @test old_frule_history == frule_history - end - - end - - @testset "_primal_sig" begin - _primal_sig = ChainRulesCore._primal_sig - @testset "frule" begin - @test isequal( # DataType without shared type but with constraint - _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), Int, Vector{Int}}), - Tuple{typeof(*), Int, Vector{Int}} - ) - @test isequal( # UnionAall without shared type but with constraint - _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), T, Int} where T<:Real), - Tuple{typeof(*), T, Int} where T<:Real - ) - @test isequal( # UnionAall with share type - _primal_sig(frule, Tuple{typeof(frule), Any, typeof(*), T, Vector{T}} where T), - Tuple{typeof(*), T, Vector{T}} where T - ) - end - - @testset "rrule" begin - @test isequal( # DataType without shared type but with constraint - _primal_sig(rrule, Tuple{typeof(rrule), typeof(*), Int, Vector{Int}}), - Tuple{typeof(*), Int, Vector{Int}} - ) - @test isequal( # UnionAall without shared type but with constraint - _primal_sig(rrule, Tuple{typeof(rrule), typeof(*), T, Int} where T<:Real), - Tuple{typeof(*), T, Int} where T<:Real - ) - @test isequal( # UnionAall with share type - _primal_sig(rrule, Tuple{typeof(rrule), typeof(*), T, Vector{T}} where T), - Tuple{typeof(*), T, Vector{T}} where T - ) - end - end -end diff --git a/test/runtests.jl b/test/runtests.jl index 080258c46..612e57f9f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,12 +16,6 @@ using Test include("accumulation.jl") - include("ruleset_loading.jl") include("rules.jl") include("rule_definition_tools.jl") - - @testset "demos" begin - include("demos/forwarddiffzero.jl") - include("demos/reversediffzero.jl") - end end From d46f0f785d196504d965198d70902a8e520229a3 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 26 May 2021 20:36:08 +0100 Subject: [PATCH 2/3] fix docs, and reverse dep test overload generation package --- .github/workflows/IntegrationTest.yml | 1 + docs/src/api.md | 8 -------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/.github/workflows/IntegrationTest.yml b/.github/workflows/IntegrationTest.yml index 8cf6bb1b7..d3c46b81c 100644 --- a/.github/workflows/IntegrationTest.yml +++ b/.github/workflows/IntegrationTest.yml @@ -16,6 +16,7 @@ jobs: os: [ubuntu-latest] package: - {user: JuliaDiff, repo: ChainRules.jl} + - {user: JuliaDiff, repo: ChainRulesOverloadGeneration.jl} - {user: JuliaMath, repo: SpecialFunctions.jl} - {user: invenia, repo: BlockDiagonals.jl} - {user: invenia, repo: PDMatsExtras.jl} diff --git a/docs/src/api.md b/docs/src/api.md index d6db4941b..1fffafd8f 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -34,16 +34,8 @@ add!! ChainRulesCore.is_inplaceable_destination ``` -## Ruleset Loading -```@autodocs -Modules = [ChainRulesCore] -Pages = ["ruleset_loading.jl"] -Private = false -``` - ## Internal ```@docs ChainRulesCore.AbstractTangent ChainRulesCore.debug_mode -ChainRulesCore.clear_new_rule_hooks! ``` From b70ead73eb657de3a2a78d6c4c1c1dc901d86a37 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 26 May 2021 22:41:04 +0100 Subject: [PATCH 3/3] Update docs/src/use_in_ad_system.md Co-authored-by: Nick Robinson --- docs/src/use_in_ad_system.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/use_in_ad_system.md b/docs/src/use_in_ad_system.md index d0c4d0228..7ad24ea32 100644 --- a/docs/src/use_in_ad_system.md +++ b/docs/src/use_in_ad_system.md @@ -8,7 +8,7 @@ to avoid having to code all your own AD primitives / custom sensitives. There are 3 main ways to access ChainRules rule sets in your AutoDiff system. 1. [Operator Overloading Generation](https://juliadiff.org/ChainRulesOverloadGeneration.jl/stable) - - using [ChainRulesOverloadGenerration.jl](https://github.com/JuliaDiff/ChainRulesOverloadGeneration.jl/) + - Use [ChainRulesOverloadGeneration.jl](https://github.com/JuliaDiff/ChainRulesOverloadGeneration.jl/). - This is primarily intended for operator overloading based AD systems which will generate overloads for primal functions based for their overloaded types based on the existence of an `rrule`/`frule`. - A source code generation based AD can also use this by overloading their transform generating function directly so as not to recursively generate a transform but to just return the rule. - This does not play nice with Revise.jl, adding or modifying rules in loaded files will not be reflected until a manual refresh, and deleting rules will not be reflected at all.