From a4b17393c658860882ceffbf347d491201020e06 Mon Sep 17 00:00:00 2001 From: longemen3000 Date: Sat, 10 Jul 2021 01:33:16 -0400 Subject: [PATCH] support for more number types --- Project.toml | 2 +- src/main.jl | 122 +++++++++++++++++++++++++++-------------------- test/runtests.jl | 16 ++++++- 3 files changed, 85 insertions(+), 55 deletions(-) diff --git a/Project.toml b/Project.toml index 338ddd1..352c460 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SpeedMapping" uuid = "f1835b91-879b-4a3f-a438-e4baacf14412" authors = ["Nicolas Lepage-Saucier <42039487+nicolasLepageSaucier@users.noreply.github.com> and contributors"] -version = "0.2.0" +version = "0.2.1" [deps] AccurateArithmetic = "22286c92-06ac-501d-9306-4abd417d9753" diff --git a/src/main.jl b/src/main.jl index 226fc45..46da5c2 100644 --- a/src/main.jl +++ b/src/main.jl @@ -2,29 +2,41 @@ using AccurateArithmetic using ForwardDiff using LinearAlgebra + +#this allows to bail out on other numeric types, but +#keep using the accurate arithmetic package when available + +function accurate_dot(x::AbstractArray{T1},y::AbstractArray{T2}) where {T1<:Union{Float32,Float64},T2<:Union{Float32,Float64}} + return dot_oro(x,y) +end + +function accurate_dot(x,y) + return dot(x,y) +end + + _isbad(x) = isnan(x) || isinf(x) _mod1(x, m) = (x - 1) % m + 1 -Base.@kwdef mutable struct State - +Base.@kwdef mutable struct State{T<:Real} check_obj :: Bool autodiff :: Bool - Lp :: Float64 - maps_limit :: Float64 - time_limit :: Float64 - buffer :: Float64 - t0 :: Float64 - tol :: Float64 + Lp :: T + maps_limit :: T + time_limit :: T + buffer :: T + t0 :: T + tol :: T go_on :: Bool = true converged :: Bool = false maps :: Int64 = 0 # Or the # of g! calls if optim == true. k :: Int64 = 0 f_calls :: Int64 = 0 - σ :: Float64 = 0.0 - α :: Float64 = 1.0 - obj_now :: Float64 = Inf - norm_∇ :: Float64 = Inf + σ :: T = zero(T) + α :: T = one(T) + obj_now :: T = T(Inf) + norm_∇ :: T = T(Inf) ix₀ :: Int64 = 1 ix :: Int64 = 1 @@ -32,19 +44,24 @@ Base.@kwdef mutable struct State ix_new :: Int64 = 1 ord_best :: Int64 = 0 i_ord :: Int64 = 0 - α_best :: Float64 = 1.0 - σ_mult_fail :: Float64 = 1.0 - σ_mult_loop :: Float64 = 1.0 - α_mult :: Float64 = 1.0 - norm_best :: Float64 = Inf - obj_best :: Float64 = Inf - α_boost :: Float64 = 1 - - σs :: Array{Float64} = zeros(10) - norm_∇s :: Array{Float64} = zeros(10) + α_best :: T = one(T) + σ_mult_fail :: T = one(T) + σ_mult_loop :: T = one(T) + α_mult :: T = one(T) + norm_best :: T = T(Inf) + obj_best :: T = T(Inf) + α_boost :: T = one(T) + + σs :: Vector{T} = zeros(T,10) + norm_∇s :: Vector{T} = zeros(T,10) σs_i :: Int64 = 1 end + + +function Base.eltype(::State{F}) where F + return F +end ##### ##### Initialization funtions ##### @@ -64,7 +81,8 @@ function check_arguments( (lower ≠ nothing && maximum(lower .- x_in) > 0) throw(DomainError(x_in, "infeasible starting point")) end - if !(eltype(x_in) <: AbstractFloat) + + if !((eltype(x_in) <: Real) | (eltype(x_in) <: Integer)) throw(ArgumentError("starting point must be of type Float")) end if s.autodiff @@ -74,13 +92,13 @@ function check_arguments( end function α_too_large!( - f, g_auto, g!, s :: State, ∇temp :: T, ∇ :: T, x_in :: T, ∇∇ :: Float64 -) :: Bool where {T<:AbstractArray} + f, g_auto, g!, s :: State{F}, ∇temp :: T, ∇ :: T, x_in :: T, ∇∇ :: F +) :: Bool where {F,T<:AbstractArray} if f ≠ nothing obj_new = f(x_in - s.α * ∇) s.f_calls += 1 - return _isbad(obj_new) || (obj_new > s.obj_now - 0.25s.α * ∇∇) + return _isbad(obj_new) || (obj_new > s.obj_now - F(0.25)*s.α * ∇∇) else s.autodiff ? ∇temp .= g_auto(x_in .- s.α * ∇) : g!(∇temp, x_in .- s.α * ∇) s.maps += 1 @@ -90,11 +108,10 @@ function α_too_large!( end function initialize_α!( - f, g_auto, g!, s :: State, ∇ :: T, x_in :: T, temp :: T -) where {T<:AbstractArray} - - ∇∇ = Float64(∇ ⋅ ∇) - if ∇∇ == 0 + f, g_auto, g!, s :: State{F}, ∇ :: T, x_in :: T, temp :: T +) where {F,T<:AbstractArray} + ∇∇ = F(∇ ⋅ ∇) + if iszero(∇∇) throw(DomainError(x_in, "∇f(x_in) = 0 (extremum or saddle point)")) end @@ -120,7 +137,7 @@ end # Makes sure x_try[i] stays within boundaries with a gap buffer * (bound[i] - x_old[i]). function bound!( - extr, x_try :: T, x_old :: T, bound :: T, buffer :: Float64 + extr, x_try :: T, x_old :: T, bound :: T, buffer ) where {T<:AbstractArray} for i ∈ eachindex(x_try) @@ -196,12 +213,12 @@ end function prodΔ(∇1 :: T, ∇2 :: T, temp :: T) where {T<:AbstractArray} # if p == 2 temp .= ∇2 .- ∇1 - return dot_oro(temp, ∇1), dot_oro(temp, temp) + return accurate_dot(temp, ∇1), accurate_dot(temp, temp) end function prodΔ(∇1 :: T, ∇2 :: T, ∇3 :: T, temp :: T) where {T<:AbstractArray} # if p == 3 temp .= ∇3 .- 2∇2 .+ ∇1 - return dot_oro(temp, (∇2 .- ∇1)), dot_oro(temp, temp) + return accurate_dot(temp, (∇2 .- ∇1)), accurate_dot(temp, temp) end function extrapolate!( # Not a real extrapolation, just a stabilizing step @@ -229,15 +246,15 @@ function extrapolate!( return nothing end -function update_α!(s :: State, σ_new :: Float64, ΔᵇΔᵇ :: Float64) +function update_α!(s :: State{F}, σ_new :: F, ΔᵇΔᵇ :: F) where F # Increasing / decreasing α to maintain σ ∈ [1, 2] as much as possible. - s.α *= 1.5^((σ_new > 2) - (σ_new < 1)) + s.α *= F(1.5)^((σ_new > 2) - (σ_new < 1)) # Boosting α if ΔᵇΔᵇ gets really small if ΔᵇΔᵇ < 1e-100 s.α_boost *= 4 # Increasing more aggressively each time - s.α = min(s.α * s.α_boost, 1.0) + s.α = min(s.α * s.α_boost, one(F)) end return nothing end @@ -249,7 +266,7 @@ end # update_progress! is used to know which iterate to fall back on when # backtracking and, optionally, which x was the smallest minimizer in case the # problem has multiple minima. -function update_progress!(f, s :: State, x :: AbstractArray) +function update_progress!(f, s :: State{F}, x :: AbstractArray) where F if s.go_on && s.check_obj s.obj_now = f(x) @@ -262,7 +279,7 @@ function update_progress!(f, s :: State, x :: AbstractArray) s.ix_best = s.ix₀ s.ix_new = _mod1(s.ix₀ + 1, 2) s.ord_best, s.α_best = (s.i_ord, s.α) - s.σ_mult_fail = s.α_mult = 1.0 + s.σ_mult_fail = s.α_mult = one(F) s.check_obj ? s.obj_best = s.obj_now : s.norm_best = s.norm_∇ end return nothing @@ -452,21 +469,21 @@ julia> speedmapping(x₀; f, g!, upper) ``` """ function speedmapping( - x_in :: AbstractArray; f = nothing, g! = nothing, m! = nothing, - orders :: Array{Int64} = [3,3,2], σ_min :: Real = 0.0, stabilize :: Bool = false, - check_obj :: Bool = false, tol :: Float64 = 1e-8, Lp :: Real = 2, + x_in :: AbstractArray{F}; f = nothing, g! = nothing, m! = nothing, + orders :: Array{Int64} = [3,3,2], σ_min :: Real = zero(F), stabilize :: Bool = false, + check_obj :: Bool = false, tol = F(1e-8), Lp :: Real = 2, maps_limit :: Real = 1e6, time_limit :: Real = 1000, lower :: Union{AbstractArray, Nothing} = nothing, - upper :: Union{AbstractArray, Nothing} = nothing, buffer :: Float64 = 0.01, + upper :: Union{AbstractArray, Nothing} = nothing, buffer = F(0.01), store_info :: Bool = false -) +) where F <: Real - s = State(; autodiff = f ≠ nothing && m! === nothing && g! === nothing, tol, + + s = State{F}(; autodiff = f ≠ nothing && m! === nothing && g! === nothing, tol, buffer, Lp, maps_limit, time_limit, t0 = time(), check_obj) - g_auto = s.autodiff ? x -> ForwardDiff.gradient(f, x) : nothing - type_x = eltype(x_in) + type_x = F if lower !== nothing && eltype(lower) ≠ type_x; lower = type_x.(lower) end if upper !== nothing && eltype(upper) ≠ type_x; lower = type_x.(upper) end @@ -476,10 +493,11 @@ function speedmapping( if stabilize; orders = Int.(vec(hcat(ones(length(orders)),orders)')) end # Two x₀s to avoid copying at each improvement (maybe this is excessive optimization?) + #using copy instead of similar because BigFloats have problems x₀ = [copy(x_in), similar(x_in)] - xs = [similar(x₀[1]) for i ∈ 1:maximum(orders)] - ∇s = [similar(x₀[1]) for i ∈ 1:maximum(orders)] # Storing ∇s is equivalent to Δs - temp = similar(x₀[1]) # temp storage + xs = [copy(x₀[1]) for i ∈ 1:maximum(orders)] + ∇s = [copy(x₀[1]) for i ∈ 1:maximum(orders)] # Storing ∇s is equivalent to Δs + temp = copy(x₀[1]) # temp storage if f !== nothing && (m! === nothing || check_obj) s.obj_now = s.obj_best = f(x_in) # Useful for initialize_α and tracking progress @@ -514,7 +532,7 @@ function speedmapping( if !s.converged && s.go_on if p > 1 ΔᵃΔᵇ, ΔᵇΔᵇ = prodΔ(∇s[1:p]..., temp) - σ_new = Float64(abs(ΔᵃΔᵇ) > 1e-100 && ΔᵇΔᵇ > 1e-100 ? abs(ΔᵃΔᵇ) / ΔᵇΔᵇ : 1.0) + σ_new = F(abs(ΔᵃΔᵇ) > 1e-100 && ΔᵇΔᵇ > 1e-100 ? abs(ΔᵃΔᵇ) / ΔᵇΔᵇ : 1.0) s.σ = max(σ_min, σ_new) * s.σ_mult_fail * s.σ_mult_loop end @@ -530,7 +548,7 @@ function speedmapping( s.ix₀ = s.ix_new if p > 1 - if m! === nothing; update_α!(s, σ_new, Float64(ΔᵇΔᵇ)) end + if m! === nothing; update_α!(s, σ_new, F(ΔᵇΔᵇ)) end check_∞_loop!(s, io == length(orders)) end elseif !s.converged && !s.go_on diff --git a/test/runtests.jl b/test/runtests.jl index ed10b98..d60e037 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using SpeedMapping, Test, LinearAlgebra +using SpeedMapping, Test, LinearAlgebra, ForwardDiff function f(x) # Rosenbrock objective f_out = 0.0 @@ -28,6 +28,7 @@ end C = [1 2 3; 4 5 6; 7 8 9] A = C + C' + function m!(x_out, x_in) # map for the power method mul!(x_out, A, x_in) x_out ./= norm(x_out, Inf) @@ -46,6 +47,7 @@ function exception(expr, error) catch e goodexception = isa(e, error) end + return goodexception end @@ -71,9 +73,19 @@ end # Can't provide both g! and m! @test exception(:(speedmapping([0.0, 0.0]; g!, m!)), ArgumentError) # eltype(x_in) is Int - @test exception(:(speedmapping([0, 0]; g!)), ArgumentError) + #using parametric types,it fails before check_arguments + @test_broken exception(:(speedmapping([0, 0]; g!)), ArgumentError) # check_obj without providing f @test exception(:(speedmapping([0.0, 0.0]; g!, check_obj = true)), ArgumentError) # The gradient is zero @test exception(:(speedmapping([1.0, 1.0]; g!)), DomainError) end + + +@testset "other number types" begin + #support for bigfloats + @test speedmapping(zeros(BigFloat,2); g!).minimizer isa Vector{BigFloat} + #support for ForwardDiff.Dual + @test ForwardDiff.jacobian(x -> speedmapping(x; g!).minimizer + ,[0.0,0.0]) isa Matrix{Float64} +end