Skip to content

Commit

Permalink
Merge pull request #297 from JuliaRobotics/24Q3/enh/liegauprod2
Browse files Browse the repository at this point in the history
Gaussian fusion with 2nd transport (Lie)
  • Loading branch information
dehann authored Jul 29, 2024
2 parents 3908b76 + 2155af3 commit cb26bea
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 49 deletions.
119 changes: 74 additions & 45 deletions src/CommonUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,33 @@ end
# return var(mkd.manifold, getPoints(mkd); kwargs...)
# end

_makevec(w::AbstractVector) = w
_makevec(w::Tuple) = [w...]


function calcProductGaussians_flat(
M::AbstractManifold,
μ_::Union{<:AbstractVector{P},<:NTuple{N,P}}, # point type commonly known as P (actually on-manifold)
Σ_::Union{<:AbstractVector{S},<:NTuple{N,S}};
μ0 = mean(M, _makevec(μ_)), # Tangent space reference around the evenly weighted mean of incoming points
Λ_ = inv.(Σ_),
weight::Real = 1.0,
do_transport_correction::Bool = true
) where {N,P<:AbstractArray,S<:AbstractMatrix{<:Real}}
# calc sum of covariances
Λ = +(Λ_...)

# calc the covariance weighted delta means of incoming points and covariances
ΛΔμc = mapreduce(+, zip(Λ_, μ_)) do (s,u)
Δuvee = vee(M, μ0, log(M, μ0, u))
s*Δuvee
end

# calculate the delta mean
Δμc = Λ \ ΛΔμc

return Δμc, inv(Λ)
end

"""
$SIGNATURES
Expand All @@ -126,35 +151,40 @@ DevNotes:
function calcProductGaussians(
M::AbstractManifold,
μ_::Union{<:AbstractVector{P},<:NTuple{N,P}}, # point type commonly known as P (actually on-manifold)
Σ_::Union{Nothing,<:AbstractVector{S},<:NTuple{N,S}};
dim::Integer=manifold_dimension(M),
Λ_ = inv.(Σ_), # TODO these probably need to be transported to common tangent space `u0` -- FYI @Affie 24Q2
weight::Real = 1.0
) where {N,P,S<:AbstractMatrix{<:Real}}
#
# calc sum of covariances
Λ = zeros(MMatrix{dim,dim})
# FIXME first transport (push forward) covariances to common coordinates
Σ_::Union{<:AbstractVector{S},<:NTuple{N,S}};
μ0 = mean(M, _makevec(μ_)), # Tangent space reference around the evenly weighted mean of incoming points
Λ_ = inv.(Σ_),
weight::Real = 1.0,
do_transport_correction::Bool = true
) where {N,P<:AbstractArray,S<:AbstractMatrix{<:Real}}
# step 1, basic/naive Gaussian product (ignoring disjointed covariance coordinates)
Δμn, Σn = calcProductGaussians_flat(M, μ_, Σ_; μ0, Λ_, weight)
Δμ = exp(M, μ0, hat(M, μ0, Δμn))

# for development and testing cases return without doing transport
do_transport_correction ? nothing : (return Δμ, Σn)

# first transport (push forward) covariances to common coordinates
# see [Ge, van Goor, Mahony, 2024]
Λ .= sum(Λ_)
iΔμ = inv(M, Δμ)
μi_ = map(u->Manifolds.compose(M,iΔμ,u), μ_)
μi_̂ = map(u->log(M,μ0,u), μi_)
# μi = map(u->vee(M,μ0,u), μi_̂ )
Ji = ApproxManifoldProducts.parallel_transport_curvature_2nd_lie.(Ref(M), μi_̂ )
iJi = inv.(Ji)
Σi_hat = map((J,S)->J*S*(J'), iJi, Σ_)

# Reset step to absorb extended μ+ coordinates into kernel on-manifold μ
# consider using Δμ in place of μ0
Δμplusc, Σdiam = ApproxManifoldProducts.calcProductGaussians_flat(M, μi_, Σi_hat; μ0, weight)
Δμplus_̂ = hat(M, μ0, Δμplusc)
Δμplus = exp(M, μ0, Δμplus_̂ )
μ_plus = Manifolds.compose(M,Δμ,Δμplus)
= ApproxManifoldProducts.parallel_transport_curvature_2nd_lie(M, Δμplus_̂ )
Σ_plus =*Σdiam*(Jμ')

# naive mean
# Tangent space reference around the evenly weighted mean of incoming points
u0 = mean(M, μ_)

# calc the covariance weighted delta means of incoming points and covariances
ΛΔμ = zeros(MVector{dim})
for (s,u) in zip(Λ_, μ_)
# require vee as per Pennec, Caesar Ref [3.6]
Δuvee = vee(M, u0, log(M, u0, u))
ΛΔμ += s*Δuvee
end
Λ
# calculate the delta mean
Δμ = Λ \ ΛΔμ

# return new mean and covariance
return exp(M, u0, hat(M, u0, Δμ)), inv(Λ)
return μ_plus, Σ_plus
end

# additional support case where covariances are passed as diagonal-only vectors
Expand All @@ -165,17 +195,11 @@ calcProductGaussians(
Σ_::Union{<:AbstractVector{S},<:NTuple{N,S}};
dim::Integer=manifold_dimension(M),
Λ_ = map(s->diagm( 1.0 ./ s), Σ_),
weight::Real = 1.0
) where {N,P,S<:AbstractVector} = calcProductGaussians(M, μ_, nothing; dim, Λ_=Λ_ )
weight::Real = 1.0,
do_transport_correction::Bool = true
) where {N,P,S<:AbstractVector} = calcProductGaussians(M, μ_, nothing; dim, Λ_, do_transport_correction )
#

calcProductGaussians(
M::AbstractManifold,
μ_::Union{<:AbstractVector{P},<:NTuple{N,P}};
dim::Integer=manifold_dimension(M),
Λ_ = diagm.( (1.0 ./ μ_) ),
weight::Real = 1.0,
) where {N,P} = calcProductGaussians(M, μ_, nothing; dim, Λ_=Λ_ )


"""
Expand All @@ -191,17 +215,22 @@ DevNotes
"""
function calcProductGaussians(
M::AbstractManifold,
comps::AbstractVector{<:MvNormalKernel};
weight::Real = 1.0
)
kernels::Union{<:AbstractVector{K},NTuple{N,K}};
μ0 = nothing,
weight::Real = 1.0,
do_transport_correction::Bool = true
) where {N,K <: MvNormalKernel}
# CHECK this should be on-manifold for points
μ_ = mean.(comps) # This is a ArrayPartition which IS DEFINITELY ON MANIFOLD (we dispatch on mean)
Σ_ = cov.(comps) # on tangent

# FIXME is parallel transport needed here for covariances from different tangent spaces?

_μ, _Σ = calcProductGaussians(M, μ_, Σ_)

μ_ = mean.(kernels) # This is a ArrayPartition which IS DEFINITELY ON MANIFOLD (we dispatch on mean)
Σ_ = cov.(kernels) # on tangent

# parallel transport needed for covariances from different tangent spaces
_μ, _Σ = if isnothing(μ0)
calcProductGaussians(M, μ_, Σ_; do_transport_correction)
else
calcProductGaussians(M, μ_, Σ_; μ0, do_transport_correction)
end

return MvNormalKernel(_μ, _Σ, weight)
end

Expand Down
13 changes: 13 additions & 0 deletions src/services/ManifoldsOverloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@ const _UPSTREAM_MANIFOLDS_ADJOINT_ACTION = false

# local union definition during development -- TODO consolidate upstream
LieGroupManifoldsPirate = Union{
typeof(TranslationGroup(1)),
typeof(TranslationGroup(2)),
typeof(TranslationGroup(3)),
typeof(TranslationGroup(4)),
typeof(TranslationGroup(5)),
typeof(TranslationGroup(6)),
typeof(SpecialOrthogonal(2)),
typeof(SpecialOrthogonal(3)),
typeof(SpecialEuclidean(2)),
Expand Down Expand Up @@ -170,6 +176,13 @@ function ad_lie(
)
end

# basic fallback
# X is tangent vector (Lie algebra element)
ad(
M::LieGroupManifoldsPirate,
X
) = ad_lie(M,X)


Ad(
M::Union{typeof(SpecialOrthogonal(2)), typeof(SpecialOrthogonal(3))},
Expand Down
96 changes: 92 additions & 4 deletions test/testLieFundamentals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
using Test
using Manifolds
using ApproxManifoldProducts
using StaticArrays
using LinearAlgebra
using Distributions


##
Expand All @@ -27,10 +29,10 @@ d = hat(M, p, d̂) # direction in algebra


ApproxManifoldProducts.ad_lie(M, X)
# @test isapprox(
# ApproxManifoldProducts.ad_lie(M, X),
# ApproxManifoldProducts.ad(M, X)
# )
@test isapprox(
ApproxManifoldProducts.ad_lie(M, X),
ApproxManifoldProducts.ad(M, X)
)

# ptcMat = ApproxManifoldProducts.parallel_transport_curvature_2nd_lie(M, d)

Expand Down Expand Up @@ -234,4 +236,90 @@ ptcMat = ApproxManifoldProducts.parallel_transport_curvature_2nd_lie(M, d)
##
end


@testset "(Lie Group) on-manifold Gaussian product SO(3), [Ge, van Goor, Mahony, 2024]" begin
##

γ = 1
ξ = 1

M = SpecialOrthogonal(3)
ε = Identity(M)

X1c = γ/sqrt(3) .* SA[1; 1; -1.0]
X1 = hat(M, ε, X1c) # random algebra element
w_R1 = exp(M, ε, X1)
Σ1 = ξ .* SA[1 0 0; 0 0.75 0; 0 0 0.5]

X2c = γ/sqrt(2) .* SA[1; -1; 0.0]
X2 = hat(M, ε, X2c)
w_R2 = exp(M, ε, X2)
Σ2 = ξ .* SA[0.5 0 0; 0 1 0; 0 0 0.75]

#

p1 = ApproxManifoldProducts.MvNormalKernel(;
μ = w_R1,
p = MvNormal(SA[0; 0; 0.0], Σ1)
)

p2 = ApproxManifoldProducts.MvNormalKernel(;
μ = w_R2,
p = MvNormal(SA[0; 0; 0.0], Σ2)
)


# Naive product (standard linear product of Gaussians) -- reference implementation around group identity
Xcs = (X1c,X2c)
_Σs = map(s->inv(cov(s)), (p1,p2))
_Σn = +(_Σs...)
_Σnμn = mapreduce(+, zip(_Σs, Xcs)) do (s,c)
s*c
end
μn = _Σn\_Σnμn
Σn = inv(_Σn)

# verify calcProductGaussians utility function
= calcProductGaussians(
M,
(p1,p2);
μ0 = ε,
do_transport_correction = false
)

@test isapprox(
μn,
vee(M, ε, log(M, ε, mean(p̂)))
)

# approx match for even-mean-mean rather than naive-identity-mean
= calcProductGaussians(
M,
(p1,p2);
# μ0 = ε,
do_transport_correction = false
)
@test isapprox(
μn,
vee(M, ε, log(M, ε, mean(p̂)));
atol=1e-1 # NOTE looser bound for even-mean-mean case vs naive-identity-mean case
)

##

= calcProductGaussians(
M,
(p1,p2);
# μ0 = ε,
do_transport_correction = true
)


##
end





#

0 comments on commit cb26bea

Please sign in to comment.