Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

basic manifold diff sandbox test #1697

Merged
merged 14 commits into from
Mar 27, 2023
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ DistributedFactorGraphs = "b5cc3c7e-6572-11e9-2517-99fb8daf2f04"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
FunctionalStateMachine = "3e9e306e-7e3c-11e9-12d2-8f8f67a2f951"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
KernelDensityEstimate = "2472808a-b354-52ea-a80e-1658a3c6056d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Manifolds = "1cead3c2-87b3-11e9-0ccd-23c62b72b94e"
ManifoldsBase = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
ManifoldDiff = "af67fdf4-a580-4b9f-bbec-742ef357defd"
MetaGraphs = "626554b9-1ddb-594c-aa3c-2596fe9399a5"
NLSolversBase = "d41bc354-129a-5804-8e4c-c37616107c6c"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Expand Down Expand Up @@ -82,9 +84,10 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Manopt = "0fc0a36d-df90-57f3-8f93-d78a9fc72bb5"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["DifferentialEquations", "Flux", "Graphs", "InteractiveUtils", "Interpolations", "Pkg", "Rotations", "Test"]
test = ["DifferentialEquations", "Flux", "Graphs", "Manopt", "InteractiveUtils", "Interpolations", "Pkg", "Rotations", "Test"]
2 changes: 2 additions & 0 deletions src/IncrementalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ using Reexport
using Manifolds
using RecursiveArrayTools: ArrayPartition
export ArrayPartition
using ManifoldDiff
using FiniteDifferences

using OrderedCollections: OrderedDict

Expand Down
47 changes: 47 additions & 0 deletions src/ManifoldsExtentions.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,50 @@

## ================================================================================================
## Manifold and ManifoldDiff use with Optim
## ================================================================================================

# Modified from: https://gist.github.com/mateuszbaran/0354c0edfb9cdf25e084a2b915816a09
"""
ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold

Adapts Manifolds.jl manifolds for use in Optim.jl
"""
struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
M::TM
end

function Optim.retract!(M::ManifoldWrapper, x)
ManifoldsBase.embed_project!(M.M, x, x)
return x
end

function Optim.project_tangent!(M::ManifoldWrapper, g, x)
ManifoldsBase.embed_project!(M.M, g, x, g)
return g
end

# experimental
function optimizeManifold_FD(
M::AbstractManifold,
cost::Function,
x0::AbstractArray;
algorithm = Optim.ConjugateGradient(; manifold=ManifoldWrapper(M))
)
# finitediff setup
r_backend = ManifoldDiff.TangentDiffBackend(
ManifoldDiff.FiniteDifferencesBackend()
)

## finitediff gradient (non-manual)
function costgrad_FD!(X,p)
X .= ManifoldDiff.gradient(M, cost, p, r_backend)
X
end

Optim.optimize(cost, costgrad_FD!, x0, algorithm)
end


## ================================================================================================
## AbstractPowerManifold with N as field to avoid excessive compiling time.
## ================================================================================================
Expand Down
237 changes: 237 additions & 0 deletions test/manifolds/manifolddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@

# using Revise
using Test
using LinearAlgebra
using IncrementalInference
using ManifoldsBase
using Manifolds, Manopt
import Optim
using FiniteDifferences, ManifoldDiff
import Rotations as _Rot

##

# finitediff setup
r_backend = ManifoldDiff.TangentDiffBackend(
ManifoldDiff.FiniteDifferencesBackend()
)

##
@testset "ManifoldDiff, Basic test" begin
##

# problem setup
n = 100
σ = π / 8
M = Manifolds.Sphere(2)
p = 1 / sqrt(2) * [1.0, 0.0, 1.0]
data = [exp(M, p, σ * rand(M; vector_at=p)) for i in 1:n];

# objective function
f(M, p) = sum(1 / (2 * n) * distance.(Ref(M), Ref(p), data) .^ 2)
# f_(p) = f(M,p)

# non-manual: intrinsic finite differences gradient
function grad_f_FD(M,p)
f_(p_) = f(M,p_)
ManifoldDiff.gradient(M, f_, p, r_backend)
end
# manual gradient
# grad_f(M, p) = sum(1 / n * grad_distance.(Ref(M), data, Ref(p)));


# and solve
@time m1 = gradient_descent(M, f, grad_f_FD, data[1])

@info "Basic Manopt test" string(m1')
@test isapprox(p, m1; atol=0.15)

##
end

##

"""
ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold

Adapts Manifolds.jl manifolds for use in Optim.jl
"""
struct ManifoldWrapper{TM<:AbstractManifold} <: Optim.Manifold
M::TM
end

function Optim.retract!(M::ManifoldWrapper, x)
ManifoldsBase.embed_project!(M.M, x, x)
return x
end

function Optim.project_tangent!(M::ManifoldWrapper, g, x)
ManifoldsBase.embed_project!(M.M, g, x, g)
return g
end

##
@testset "Optim.jl ManifoldWrapper example from mateuszbaran (copied to catch issues on future changes)" begin
##
# Example modified from: https://gist.github.com/mateuszbaran/0354c0edfb9cdf25e084a2b915816a09

# example usage of Manifolds.jl manifolds in Optim.jl
M = Manifolds.Sphere(2)
x0 = [1.0, 0.0, 0.0]
q = [0.0, 1.0, 0.0]

f(p) = 0.5 * distance(M, p, q)^2

# manual gradient
function g!(X, p)
log!(M, X, p, q)
X .*= -1
println(p, X)
end

##

sol = Optim.optimize(f, g!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)


## finitediff gradient (non-manual)

function g_FD!(X,p)
X .= ManifoldDiff.gradient(M, f, p, r_backend)
X
end

#
x0 = [1.0, 0.0, 0.0]

sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
@test isapprox([0,1,0.], sol.minimizer; atol=1e-8)

##

# x0 = [1.0, 0.0, 0.0]
# # internal ForwardDfif doesnt work out the box on Manifolds
# sol = Optim.optimize(f, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)); autodiff=:forward )
# @test isapprox([0,1,0.], sol.minimizer; atol=1e-8)

##
end


@testset "Modified Manifolds.jl ManifoldWrapper <: Optim.Manifold for SpecialEuclidean(2)" begin
##

M = Manifolds.SpecialEuclidean(2)
e0 = ArrayPartition([0,0.], [1 0; 0 1.])

x0 = deepcopy(e0)
Cq = 9*ones(3)
while 1.5 < abs(Cq[3])
@show Cq .= randn(3)
# Cq[3] = 1.5 # breaks ConjugateGradient
end
q = exp(M,e0,hat(M,e0,Cq))

f(p) = distance(M, p, q)^2

## finitediff gradient (non-manual)
function g_FD!(X,p)
X .= ManifoldDiff.gradient(M, f, p, r_backend)
X
end

## sanity check gradients

X = hat(M, e0, zeros(3))
g_FD!(X, q)
# gradient at the optimal point should be zero
@show X_ = [X.x[1][:]; X.x[2][:]]
@test isapprox(0, sum(abs.(X_)); atol=1e-8 )

# gradient not the optimal point should be non-zero
g_FD!(X, e0)
@show X_ = [X.x[1][:]; X.x[2][:]]
@test 0.01 < sum(abs.(X_))

## do optimization
x0 = deepcopy(e0)
sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
Cq .= randn(3)
# Cq[
@show sol.minimizer
@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-5)

##
end


@testset "Modified ManifoldsWrapper for Optim.Manifolds, SpecialEuclidean(3)" begin
##


M = Manifolds.SpecialEuclidean(3)
e0 = ArrayPartition([0,0,0.], Matrix(_Rot.RotXYZ(0,0,0.)))

x0 = deepcopy(e0)
Cq = 0.5*randn(6)
q = exp(M,e0,hat(M,e0,Cq))

f(p) = distance(M, p, q)^2

## finitediff gradient (non-manual)
function g_FD!(X,p)
X .= ManifoldDiff.gradient(M, f, p, r_backend)
X
end

## sanity check gradients

X = hat(M, e0, zeros(6))
g_FD!(X, q)

@show X_ = [X.x[1][:]; X.x[2][:]]
# gradient at the optimal point should be zero
@test isapprox(0, sum(abs.(X_)); atol=1e-8 )

# gradient not the optimal point should be non-zero
g_FD!(X, e0)
@show X_ = [X.x[1][:]; X.x[2][:]]
@test 0.01 < sum(abs.(X_))

## do optimization
x0 = deepcopy(e0)
sol = Optim.optimize(f, g_FD!, x0, Optim.ConjugateGradient(; manifold=ManifoldWrapper(M)))
# Cq .= 0.5*randn(6)
# Cq[
@show sol.minimizer
@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-3)


##
end


@testset "Optim.Manifolds, SpecialEuclidean(3), using IIF.optimizeManifold_FD" begin
##

M = Manifolds.SpecialEuclidean(3)
e0 = ArrayPartition([0,0,0.], Matrix(_Rot.RotXYZ(0,0,0.)))

x0 = deepcopy(e0)
Cq = 0.5*randn(6)
q = exp(M,e0,hat(M,e0,Cq))

f(p) = distance(M, p, q)^2

sol = IncrementalInference.optimizeManifold_FD(M,f,x0)

@show sol.minimizer
@test isapprox( f(sol.minimizer), 0; atol=1e-3 )
@test isapprox( 0, sum(abs.(log(M, e0, compose(M, inv(M,q), sol.minimizer)))); atol=1e-5)


##
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ end

if TEST_GROUP in ["all", "basic_functional_group"]
# more frequent stochasic failures from numerics
include("manifolds/manifolddiff.jl")
include("testSpecialEuclidean2Mani.jl")
include("testEuclidDistance.jl")

Expand Down