Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ChainRulesCore a weak dependency #445

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Roots"
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
version = "2.1.8"
version = "2.2.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand All @@ -9,12 +9,14 @@ CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
SymPyPythonCall = "bc8888f7-b21e-4b7c-a06a-5d9c9496438c"

[extensions]
RootsChainRulesCoreExt = "ChainRulesCore"
RootsForwardDiffExt = "ForwardDiff"
RootsIntervalRootFindingExt = "IntervalRootFinding"
RootsSymPyExt = "SymPy"
Expand Down
22 changes: 14 additions & 8 deletions src/chain_rules.jl → ext/RootsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module RootsChainRulesCoreExt

using Roots
import ChainRulesCore

# View find_zero as solving `f(x, p) = 0` for `xᵅ(p)`.
# This is implicitly defined. By the implicit function theorem, we have:
# ∇f = 0 => ∂/∂ₓ f(xᵅ, p) ⋅ ∂xᵅ/∂ₚ + ∂/∂ₚf(x\^α, p) ⋅ I = 0
Expand All @@ -15,7 +20,6 @@
# that is fixable.)

# this assumes a function and a parameter `p` passed in
import ChainRulesCore: Tangent, NoTangent, frule, rrule
function ChainRulesCore.frule(
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
(_, _, _, Δp),
Expand All @@ -42,17 +46,17 @@
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
xdots,
::typeof(solve),
ZP::Roots.ZeroProblem,
ZP::ZeroProblem,
M::Roots.AbstractUnivariateZeroMethod,
::Nothing;
kwargs...,
) = frule(config, xdots, solve, ZP, M; kwargs...)
) = ChainRulesCore.frule(config, xdots, solve, ZP, M; kwargs...)

Check warning on line 53 in ext/RootsChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RootsChainRulesCoreExt.jl#L53

Added line #L53 was not covered by tests

function ChainRulesCore.frule(
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
(_, Δq, _),
::typeof(solve),
ZP::Roots.ZeroProblem,
ZP::ZeroProblem,
M::Roots.AbstractUnivariateZeroMethod;
kwargs...,
)
Expand All @@ -61,12 +65,12 @@
zprob2 = ZeroProblem(|>, ZP.x₀)
nms = fieldnames(typeof(foo))
nt = NamedTuple{nms}(getfield(foo, n) for n in nms)
dfoo = Tangent{typeof(foo)}(; nt...)
dfoo = ChainRulesCore.Tangent{typeof(foo)}(; nt...)

Check warning on line 68 in ext/RootsChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RootsChainRulesCoreExt.jl#L68

Added line #L68 was not covered by tests

return frule(
return ChainRulesCore.frule(

Check warning on line 70 in ext/RootsChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RootsChainRulesCoreExt.jl#L70

Added line #L70 was not covered by tests
config,
(NoTangent(), NoTangent(), NoTangent(), dfoo),
Roots.solve,
(ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), dfoo),
solve,
zprob2,
M,
foo,
Expand Down Expand Up @@ -146,3 +150,5 @@

return xᵅ, pullback_solve_ZeroProblem
end

end # module
6 changes: 4 additions & 2 deletions src/Roots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ using Printf
import CommonSolve
import CommonSolve: solve, solve!, init
using Accessors
import ChainRulesCore

export fzero, fzeros, secant_method

Expand Down Expand Up @@ -53,7 +52,6 @@ include("functions.jl")
include("trace.jl")
include("find_zero.jl")
include("hybrid.jl")
include("chain_rules.jl")

include("Bracketing/bracketing.jl")
include("Bracketing/bisection.jl")
Expand Down Expand Up @@ -83,4 +81,8 @@ include("find_zeros.jl")
include("simple.jl")
include("alternative_interfaces.jl")

if !isdefined(Base, :get_extension)
include("../ext/RootsChainRulesCoreExt.jl")
end

end
Loading