Skip to content

Commit

Permalink
close #446; address type instability in init_options
Browse files Browse the repository at this point in the history
  • Loading branch information
jverzani committed Sep 16, 2024
1 parent a2e58ee commit 6064aaf
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Roots"
uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
version = "2.2.0"
version = "2.2.1"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
16 changes: 16 additions & 0 deletions src/Bracketing/alefeld_potra_shi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,22 @@ function init_state(::AbstractAlefeldPotraShi, F, x₀, x₁, fx₀, fx₁; c=no
AbstractAlefeldPotraShiState(promote(b, a, d, ee)..., promote(fb, fa, fd, fe)...)
end

# avoid type-stability issue due to dynamic dispatch based on kwargs
function init_options(
M::AbstractAlefeldPotraShi,
state::AbstractUnivariateZeroState{T,S};
kwargs...,
) where {T,S}
d = kwargs
defs = default_tolerances(M, T, S)
δₐ = get(d, :xatol, get(d, :xabstol, defs[1]))
δᵣ = get(d, :xrtol, get(d, :xreltol, defs[2]))
maxiters = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
strict = get(d, :strict, defs[6])
Roots.FExactOptions(δₐ, δᵣ, maxiters, strict)
end


# fn calls w/in calculateΔ
# 1 is default, but this should be adjusted for different methods
fncalls_per_step(::AbstractAlefeldPotraShi) = 1
Expand Down
4 changes: 2 additions & 2 deletions src/Bracketing/bracketing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ end
function default_tolerances(::AbstractBracketingMethod, ::Type{T}, ::Type{S}) where {T,S}
xatol = eps(real(T))^3 * oneunit(real(T))
xrtol = eps(real(T)) # unitless
atol = 0 * oneunit(real(S))
rtol = 0 * one(real(S))
atol = zero(oneunit(real(S)))
rtol = zero(one(real(S)))
maxevals = 60
strict = true
(xatol, xrtol, atol, rtol, maxevals, strict)
Expand Down
19 changes: 19 additions & 0 deletions src/Bracketing/itp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,25 @@ function init_state(M::ITP, F, x₀, x₁, fx₀, fx₁)
ITPState(promote(b, a)..., promote(fb, fa)..., 0, ϵ2n₁₂, a)
end

function init_options(
M::ITP,
state::AbstractUnivariateZeroState{T,S};
kwargs...,
) where {T,S}

d = kwargs
defs = default_tolerances(M, T, S)
δₐ = get(d, :xatol, get(d, :xabstol, defs[1]))
δᵣ = get(d, :xrtol, get(d, :xreltol, defs[2]))
ϵₐ = get(d, :atol, get(d, :abstol, defs[3]))
ϵᵣ = get(d, :rtol, get(d, :reltol, defs[4]))
maxiters = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
strict = get(d, :strict, defs[6])

return UnivariateZeroOptions(δₐ, δᵣ, ϵₐ, ϵᵣ, maxiters, strict)
end


function update_state(M::ITP, F, o::ITPState{T,S,R}, options, l=NullTracks()) where {T,S,R}
a, b = o.xn0, o.xn1
fa, fb = o.fxn0, o.fxn1
Expand Down
33 changes: 27 additions & 6 deletions src/convergence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,45 @@ init_options(
kwargs...,
) where {T,S} = init_options(M, T, S; kwargs...)

# this function is an issue (#446) it is type unstable.
# this is a fall back now, but in #446 more
# specific choices based on M are made.
function init_options(M, T=Float64, S=Float64; kwargs...)
d = kwargs

defs = default_tolerances(M, T, S)
δₐ = get(d, :xatol, get(d, :xabstol, defs[1]))
δᵣ = get(d, :xrtol, get(d, :xreltol, defs[2]))
ϵₐ = get(d, :atol, get(d, :abstol, defs[3]))
ϵᵣ = get(d, :rtol, get(d, :reltol, defs[4]))
M = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
maxiters = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
strict = get(d, :strict, defs[6])

iszero(δₐ) && iszero(δᵣ) && iszero(ϵₐ) && iszero(ϵᵣ) && return ExactOptions(M, strict)
iszero(δₐ) && iszero(δᵣ) && return XExactOptions(ϵₐ, ϵᵣ, M, strict)
iszero(ϵₐ) && iszero(ϵᵣ) && return FExactOptions(δₐ, δᵣ, M, strict)
iszero(δₐ) && iszero(δᵣ) && iszero(ϵₐ) && iszero(ϵᵣ) && return ExactOptions(maxiters, strict)
iszero(δₐ) && iszero(δᵣ) && return XExactOptions(ϵₐ, ϵᵣ, maxiters, strict)
iszero(ϵₐ) && iszero(ϵᵣ) && return FExactOptions(δₐ, δᵣ, maxiters, strict)

return UnivariateZeroOptions(δₐ, δᵣ, ϵₐ, ϵᵣ, M, strict)
return UnivariateZeroOptions(δₐ, δᵣ, ϵₐ, ϵᵣ, maxiters, strict)
end

function init_options(
M::AbstractNonBracketingMethod,
state::AbstractUnivariateZeroState{T,S};
kwargs...,
) where {T,S}

d = kwargs
defs = default_tolerances(M, T, S)
δₐ = get(d, :xatol, get(d, :xabstol, defs[1]))
δᵣ = get(d, :xrtol, get(d, :xreltol, defs[2]))
ϵₐ = get(d, :atol, get(d, :abstol, defs[3]))
ϵᵣ = get(d, :rtol, get(d, :reltol, defs[4]))
maxiters = get(d, :maxiters, get(d, :maxevals, get(d, :maxsteps, defs[5])))
strict = get(d, :strict, defs[6])

return UnivariateZeroOptions(δₐ, δᵣ, ϵₐ, ϵᵣ, maxiters, strict)
end


## --------------------------------------------------

"""
Expand Down
3 changes: 3 additions & 0 deletions test/test_allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import BenchmarkTools
@testset "solve: zero allocations" begin
fs = (sin, cos, x -> -sin(x))
x0 = (3, 4)
x0′ = big.(x0)
Ms = (
Order0(),
Order1(),
Expand All @@ -23,9 +24,11 @@ import BenchmarkTools
Ns = (Roots.Newton(), Roots.Halley(), Roots.Schroder())
for M in Ms
@test BenchmarkTools.@ballocated(solve(ZeroProblem($fs, $x0), $M)) == 0
@inferred solve(ZeroProblem(fs, x0′), M)
end
for M in Ns
@test BenchmarkTools.@ballocated(solve(ZeroProblem($fs, $x0), $M)) == 0
@inferred solve(ZeroProblem(fs, x0′), M)
end

# Allocations in Lith
Expand Down

0 comments on commit 6064aaf

Please sign in to comment.