diff --git a/Project.toml b/Project.toml
index a05a1d36..dd9b0f10 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,6 +17,7 @@ DistributedFactorGraphs = "b5cc3c7e-6572-11e9-2517-99fb8daf2f04"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
+FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 FunctionalStateMachine = "3e9e306e-7e3c-11e9-12d2-8f8f67a2f951"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 KernelDensityEstimate = "2472808a-b354-52ea-a80e-1658a3c6056d"
@@ -24,6 +25,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
 ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
+ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
 MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
 NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
 NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
@@ -82,9 +84,10 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
+Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["DifferentialEquations", "Flux", "Graphs", "InteractiveUtils", "Interpolations", "Pkg", "Rotations", "Test"]
+test = ["DifferentialEquations", "Flux", "Graphs", "Manopt", "InteractiveUtils", "Interpolations", "Pkg", "Rotations", "Test"]
diff --git a/src/IncrementalInference.jl b/src/IncrementalInference.jl
index 849e74a6..9673ceec 100644
--- a/src/IncrementalInference.jl
+++ b/src/IncrementalInference.jl
@@ -15,6 +15,8 @@ using Reexport
 using Manifolds
 using RecursiveArrayTools: ArrayPartition
 export ArrayPartition
+using ManifoldDiff
+using FiniteDifferences
 
 using OrderedCollections: OrderedDict
 
diff --git a/src/ManifoldsExtentions.jl b/src/ManifoldsExtentions.jl
index 10bf8fd6..a63216c3 100644
--- a/src/ManifoldsExtentions.jl
+++ b/src/ManifoldsExtentions.jl
@@ -1,3 +1,50 @@
+
+## ================================================================================================
+## Manifold and ManifoldDiff use with Optim
+## ================================================================================================
+
+# Modified from: https://gist.github.com/mateuszbaran/0354c0edfb9cdf25e084a2b915816a09
+"""
+    ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
+
+Adapts Manifolds.jl manifolds for use in Optim.jl
+"""
+struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
+  M::TM
+end
+
+function Optim.retract!(M::ManifoldWrapper, x)
+  ManifoldsBase.embed_project!(M.M, x, x)
+  return x
+end
+
+function Optim.project_tangent!(M::ManifoldWrapper, g, x)
+  ManifoldsBase.embed_project!(M.M, g, x, g)
+  return g
+end
+
+# experimental
+function optimizeManifold_FD(
+  M::AbstractManifold,
+  cost::Function,
+  x0::AbstractArray;
+  algorithm = Optim.ConjugateGradient(; manifold=ManifoldWrapper(M))
+)
+  # finitediff setup
+  r_backend = ManifoldDiff.TangentDiffBackend(
+    ManifoldDiff.FiniteDifferencesBackend()
+  )
+
+  ## finitediff gradient (non-manual)
+  function costgrad_FD!(X,p)
+    X .= ManifoldDiff.gradient(M, cost, p, r_backend)
+    X
+  end
+
+  Optim.optimize(cost, costgrad_FD!, x0, algorithm)
+end
+
+
 ## ================================================================================================
 ## AbstractPowerManifold with N as field to avoid excessive compiling time.
 ## ================================================================================================
diff --git a/test/manifolds/manifolddiff.jl b/test/manifolds/manifolddiff.jl
new file mode 100644
index 00000000..d5b50b68
--- /dev/null
+++ b/test/manifolds/manifolddiff.jl
@@ -0,0 +1,237 @@
+
+# using Revise
+using Test
+using LinearAlgebra
+using IncrementalInference
+using ManifoldsBase
+using Manifolds, Manopt
+import Optim
+using FiniteDifferences, ManifoldDiff
+import Rotations as _Rot
+
+##
+
+# finitediff setup
+r_backend = ManifoldDiff.TangentDiffBackend(
+  ManifoldDiff.FiniteDifferencesBackend()
+)
+
+##
+@testset "ManifoldDiff, Basic test" begin
+##
+
+# problem setup
+n = 100
+σ = π / 8
+M = Manifolds.Sphere(2)
+p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
+data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n];
+
+# objective function
+f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2)
+# f_(p) = f(M,p)
+
+# non-manual: intrinsic finite differences gradient
+function grad_f_FD(M,p)
+  f_(p_) = f(M,p_)
+  ManifoldDiff.gradient(M, f_, p, r_backend)
+end
+# manual gradient
+# grad_f(M, p) = sum(1 / n * grad_distance.(Ref(M), data, Ref(p)));
+
+
+# and solve
+@time m1 = gradient_descent(M, f, grad_f_FD, data[1])
+
+@info "Basic Manopt test" string(m1')
+@test isapprox(p, m1; atol=0.15)
+
+##
+end
+
+##
+
+"""
+    ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
+
+Adapts Manifolds.jl manifolds for use in Optim.jl
+"""
+struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
+  M::TM
+end
+
+function Optim.retract!(M::ManifoldWrapper, x)
+  ManifoldsBase.embed_project!(M.M, x, x)
+  return x
+end
+
+function Optim.project_tangent!(M::ManifoldWrapper, g, x)
+  ManifoldsBase.embed_project!(M.M, g, x, g)
+  return g
+end
+
+##
+@testset "Optim.jl ManifoldWrapper example from mateuszbaran (copied to catch issues on future changes)" begin
+##
+# Example modified from: https://gist.github.com/mateuszbaran/0354c0edfb9cdf25e084a2b915816a09
+
+# example usage of Manifolds.jl manifolds in Optim.jl
+M = Manifolds.Sphere(2)
+x0 = [1.0, 0.0, 0.0]
+q = [0.0, 1.0, 0.0]
+
+f(p) = 0.5 * distance(M, p, q)^2
+
+# manual gradient
+function g!(X, p)
+  log!(M, X, p, q)
+  X .*= -1
+  println(p, X)
+end
+
+##
+
+sol = Optim.optimize(f, g!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
+@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
+
+
+## finitediff gradient (non-manual)
+
+function g_FD!(X,p)
+  X .= ManifoldDiff.gradient(M, f, p, r_backend)
+  X
+end
+
+#
+x0 = [1.0, 0.0, 0.0]
+
+sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
+@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
+
+##
+
+# x0 = [1.0, 0.0, 0.0]
+# # internal ForwardDiff doesn't work out of the box on Manifolds
+# sol = Optim.optimize(f, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)); autodiff=:forward )
+# @test isapprox([0,1,0.], sol.minimizer; atol=1e-8)
+
+##
+end
+
+
+@testset "Modified Manifolds.jl ManifoldWrapper <: Optim.Manifold for SpecialEuclidean(2)" begin
+##
+
+M = Manifolds.SpecialEuclidean(2)
+e0 = ArrayPartition([0,0.], [1 0; 0 1.])
+
+x0 = deepcopy(e0)
+Cq = 9*ones(3)
+while 1.5 < abs(Cq[3])
+  @show Cq .= randn(3)
+  # Cq[3] = 1.5 # breaks ConjugateGradient
+end
+q = exp(M,e0,hat(M,e0,Cq))
+
+f(p) = distance(M, p, q)^2
+
+## finitediff gradient (non-manual)
+function g_FD!(X,p)
+  X .= ManifoldDiff.gradient(M, f, p, r_backend)
+  X
+end
+
+## sanity check gradients
+
+X = hat(M, e0, zeros(3))
+g_FD!(X, q)
+# gradient at the optimal point should be zero
+@show X_ = [X.x[1][:]; X.x[2][:]]
+@test isapprox(0, sum(abs.(X_)); atol=1e-8 )
+
+# gradient away from the optimal point should be non-zero
+g_FD!(X, e0)
+@show X_ = [X.x[1][:]; X.x[2][:]]
+@test 0.01 < sum(abs.(X_))
+
+## do optimization
+x0 = deepcopy(e0)
+sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
+# Cq .= randn(3)
+@show sol.minimizer
+@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
+@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-5)
+
+##
+end
+
+
+@testset "Modified ManifoldWrapper for Optim.Manifolds, SpecialEuclidean(3)" begin
+##
+
+M = Manifolds.SpecialEuclidean(3)
+e0 = ArrayPartition([0,0,0.], Matrix(_Rot.RotXYZ(0,0,0.)))
+
+x0 = deepcopy(e0)
+Cq = 0.5*randn(6)
+q = exp(M,e0,hat(M,e0,Cq))
+
+f(p) = distance(M, p, q)^2
+
+## finitediff gradient (non-manual)
+function g_FD!(X,p)
+  X .= ManifoldDiff.gradient(M, f, p, r_backend)
+  X
+end
+
+## sanity check gradients
+
+X = hat(M, e0, zeros(6))
+g_FD!(X, q)
+
+@show X_ = [X.x[1][:]; X.x[2][:]]
+# gradient at the optimal point should be zero
+@test isapprox(0, sum(abs.(X_)); atol=1e-8 )
+
+# gradient away from the optimal point should be non-zero
+g_FD!(X, e0)
+@show X_ = [X.x[1][:]; X.x[2][:]]
+@test 0.01 < sum(abs.(X_))
+
+## do optimization
+x0 = deepcopy(e0)
+sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
+# Cq .= 0.5*randn(6)
+@show sol.minimizer
+@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
+@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-3)
+
+
+##
+end
+
+
+@testset "Optim.Manifolds, SpecialEuclidean(3), using IIF.optimizeManifold_FD" begin
+##
+
+M = Manifolds.SpecialEuclidean(3)
+e0 = ArrayPartition([0,0,0.], Matrix(_Rot.RotXYZ(0,0,0.)))
+
+x0 = deepcopy(e0)
+Cq = 0.5*randn(6)
+q = exp(M,e0,hat(M,e0,Cq))
+
+f(p) = distance(M, p, q)^2
+
+sol = IncrementalInference.optimizeManifold_FD(M,f,x0)
+
+@show sol.minimizer
+@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
+@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-5)
+
+
+##
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 07c1b124..cfaeb6ee 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -11,6 +11,7 @@ end
 
 if TEST_GROUP in ["all", "basic_functional_group"]
 # more frequent stochasic failures from numerics
+include("manifolds/manifolddiff.jl")
 include("testSpecialEuclidean2Mani.jl")
 include("testEuclidDistance.jl")
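
# ------------------------------------------------------------------------------------------------
# Usage sketch (not part of the diff above): a minimal illustration of calling the experimental
# optimizeManifold_FD introduced in src/ManifoldsExtentions.jl directly, outside the test harness.
# The Sphere(2) cost f and the points q, x0 are illustrative assumptions mirroring the unit tests.

using Manifolds
using IncrementalInference

M  = Manifolds.Sphere(2)
q  = [0.0, 1.0, 0.0]               # target point; the cost is minimized at p == q
f(p) = 0.5 * distance(M, p, q)^2   # smooth squared-distance cost on the sphere

x0 = [1.0, 0.0, 0.0]               # starting point on M

# defaults to Optim.ConjugateGradient constrained to M via ManifoldWrapper, with the gradient
# computed intrinsically through ManifoldDiff + FiniteDifferences
sol = IncrementalInference.optimizeManifold_FD(M, f, x0)

sol.minimizer                      # expected to be close to q, i.e. ≈ [0.0, 1.0, 0.0]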