Skip to content

Commit

Permalink
merge multidim multiscale Gibbs sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
dehann committed May 20, 2024
1 parent 357b989 commit 09b7da2
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 12 deletions.
8 changes: 6 additions & 2 deletions src/CommonUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ function calcProductGaussians(
Σ_::Union{Nothing,<:AbstractVector{S},<:NTuple{N,S}};
dim::Integer=manifold_dimension(M),
Λ_ = inv.(Σ_), # TODO these probably need to be transported to common tangent space `u0` -- FYI @Affie 24Q2
weight::Real = 1.0
) where {N,P,S<:AbstractMatrix{<:Real}}
#
# calc sum of covariances
Expand Down Expand Up @@ -159,6 +160,7 @@ calcProductGaussians(
Σ_::Union{<:AbstractVector{S},<:NTuple{N,S}};
dim::Integer=manifold_dimension(M),
Λ_ = map(s->diagm( 1.0 ./ s), Σ_),
weight::Real = 1.0
) where {N,P,S<:AbstractVector} = calcProductGaussians(M, μ_, nothing; dim, Λ_=Λ_ )
#

Expand All @@ -167,6 +169,7 @@ calcProductGaussians(
μ_::Union{<:AbstractVector{P},<:NTuple{N,P}};
dim::Integer=manifold_dimension(M),
Λ_ = diagm.( (1.0 ./ μ_) ),
weight::Real = 1.0,
) where {N,P} = calcProductGaussians(M, μ_, nothing; dim, Λ_=Λ_ )


Expand All @@ -182,7 +185,8 @@ DevNotes
"""
function calcProductGaussians(
M::AbstractManifold,
comps::AbstractVector{<:MvNormalKernel},
comps::AbstractVector{<:MvNormalKernel};
weight::Real = 1.0
)
# CHECK this should be on-manifold for points
μ_ = mean.(comps) # This is a ArrayPartition which IS DEFINITELY ON MANIFOLD (we dispatch on mean)
Expand All @@ -192,7 +196,7 @@ function calcProductGaussians(

_μ, _Σ = calcProductGaussians(M, μ_, Σ_)

return MvNormalKernel(_μ, _Σ)
return MvNormalKernel(_μ, _Σ, weight)
end


Expand Down
77 changes: 71 additions & 6 deletions src/services/ManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,66 @@ function buildTree_Manellic!(
return tosort_leaves
end

function buildTree_Manellic!(
M::AbstractManifold,
r_ker::AbstractVector{KL}; # vector of points referenced to the r_frame
N = length(r_ker),
weights::AbstractVector{<:Real} = ones(N).*(1/N),
kernel = KL,
kernel_bw = nothing, # TODO
) where {KL <: MvNormalKernel}
#
_μT() = typeof(r_ker[1].μ)
D = manifold_dimension(M)
CV = SMatrix{D,D,Float64,D*D}(collect(cov(r_ker[1])))
KLT = getfield(ApproxManifoldProducts,kernel.name.name)
KT = KLT(
r_ker[1].μ,
CV
) |> typeof

r_PP = SizedVector{N,_μT()}(undef)

# leaf kernels
lkern = SizedVector{N,KL}(undef)
_workaround_isdef_leafkernel = Set{Int}()
for i in 1:N
r_PP[i] = r_ker[i].μ
lkern[i] = if isnothing(kernel_bw)
r_ker[i]
else
updateKernelBW(r_ker[i], kernel_bw) # TODO handle vector of kernel_bws
end
push!(_workaround_isdef_leafkernel, i + N)
end

mtree = ManellicTree(
M,
r_PP,
MVector{N,Float64}(weights),
MVector{N,Int}(1:N),
lkern,
SizedVector{N,KT}(undef),
SizedVector{N,Set{Int}}(undef),
_workaround_isdef_leafkernel,
Set{Int}(),
)

#
tosort_leaves = buildTree_Manellic!(
mtree,
1, # start at root
1, # spanning all data
N; # to end of data
kernel=KLT,
kernel_bw
)

# manual reset leaves in the order discovered
permute!(tosort_leaves.leaf_kernels, tosort_leaves.permute)

return tosort_leaves
end

function updateBandwidths(
mtr::ManellicTree{M,D,N,HL},
Expand Down Expand Up @@ -719,7 +779,8 @@ function calcProductKernelBTLabels(
labels_sampled,
LOOidx::Union{Int, Nothing} = nothing,
gibbsSeq = 1:length(proposals);
permute::Bool = true # true because signature is BTLabels
permute::Bool = true, # true because signature is BTLabels
weight::Real = 1.0
)
# select a density label from the other proposals
prop_and_label = Tuple{Int,Int}[]
Expand All @@ -732,22 +793,24 @@ function calcProductKernelBTLabels(
components = map(pr_lb->getKernelTree(proposals[pr_lb[1]], pr_lb[2], permute, true), prop_and_label)

# TODO upgrade to tuples
return calcProductGaussians(M, [components...])
return calcProductGaussians(M, [components...]; weight)
end


function calcProductKernelsBTLabels(
M::AbstractManifold,
proposals::AbstractVector,
N_lbl_sets::AbstractVector{<:NTuple},
permute::Bool = true # true because signature is BTLabels
permute::Bool = true; # true because signature is BTLabels
weights = 1/length(N_lbl_sets) .* ones(length(N_lbl_sets))
)
#
T = typeof(getKernelTree(proposals[1],1))
post = Vector{T}(undef, length(N_lbl_sets))
N = length(N_lbl_sets)
post = Vector{T}(undef, N)

for (i,lbs) in enumerate(N_lbl_sets)
post[i] = calcProductKernelBTLabels(M, proposals, lbs; permute)
post[i] = calcProductKernelBTLabels(M, proposals, lbs; permute, weight=weights[i])
end

return post
Expand Down Expand Up @@ -852,12 +915,14 @@ function sampleProductSeqGibbsBTLabel(
return labels_sampled
end

Base.length(mkd::ManifoldKernelDensity) = length(mkd.belief)


function sampleProductSeqGibbsBTLabels(
M::AbstractManifold,
proposals::AbstractVector,
MC = 3,
N::Int = round(Int, mean(length.(getPoints.(proposals)))), # FIXME use getLength or length of proposal (not getPoints)
N::Int = round(Int, mean(length.(proposals))), # FIXME use getLength or length of proposal (not getPoints)
label_pools=[[1:1;] for _ in proposals]
)
#
Expand Down
48 changes: 44 additions & 4 deletions test/manellic/testManellicTree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ end
# end


@testset "Product of two Manellic beliefs, Sequential Gibbs" begin
@testset "Product of two Manellic beliefs, Sequential Gibbs, TranslationGroup(1)" begin
##

M = TranslationGroup(1)
Expand Down Expand Up @@ -1034,7 +1034,7 @@ bt_label_pool = [
]

# leaves only version
@info "Leaves only label sampling version (Gibbs)"
@info "Leaves only label sampling version (Gibbs), TranslationGroup(1)"

ApproxManifoldProducts.sampleProductSeqGibbsBTLabel(M, [p1; p2], 3, bt_label_pool)

Expand All @@ -1047,10 +1047,10 @@ mtr = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw, kernel=Appro

@test isapprox( 0, mean(ApproxManifoldProducts.getKernelTree(mtr,1))[1]; atol=0.75)

@test all( (s->isapprox(1/N, s.weight;atol=1e-6)).(post) )



@info "Multi-scale label sampling version (Gibbs)"
@info "Multi-scale label sampling version (Gibbs), TranslationGroup(1)"

# test label pool creation
child_label_pools, all_leaves = ApproxManifoldProducts.generateLabelPoolRecursive([p1;p2], [1; 1])
Expand Down Expand Up @@ -1102,4 +1102,44 @@ end
# lines!((s->s[1]).(XX),YY, color=:red)



@testset "Multi-scale label sampling version (Gibbs), TranslationGroup(2)" begin
##

M = TranslationGroup(2)
N = 64

pts1 = [1*randn(2) for _ in 1:N]
p1 = ApproxManifoldProducts.manikde!_manellic(M,pts1)

pts2 = [1*randn(2) for _ in 1:N]
p2 = ApproxManifoldProducts.manikde!_manellic(M,pts2)


# test sampling
lbls = ApproxManifoldProducts.sampleProductSeqGibbsBTLabels(M, [p1.belief; p2.belief])
lbls_ = unique(lbls)
N_ = length(lbls_)
weights = 1/N .* ones(N_)
# increase weight of duplicates
if N_ < N
for (i,lb_) in enumerate(lbls_)
idxs = findall(==(lb_),lbls)
weights[i] = weights[i]*length(idxs)
end
end
post = ApproxManifoldProducts.calcProductKernelsBTLabels(M, [p1.belief; p2.belief], lbls_, false; weights) # ?? was permute=false?
# check that any duplicates resulted in a height weight
@test isapprox( weights, (s->s.weight).(post); atol=1e-6 )

# NOTE, resulting tree might not have N number of data points
mtr12 = ApproxManifoldProducts.buildTree_Manellic!(M,post)

##
end





#

0 comments on commit 09b7da2

Please sign in to comment.