From 09b7da2f0ef4829a689ba7120f55bc7e0da2d676 Mon Sep 17 00:00:00 2001 From: dehann Date: Mon, 20 May 2024 00:32:35 -0700 Subject: [PATCH] merge multidim multiscale Gibbs sampling --- src/CommonUtils.jl | 8 +++- src/services/ManellicTree.jl | 77 ++++++++++++++++++++++++++++--- test/manellic/testManellicTree.jl | 48 +++++++++++++++++-- 3 files changed, 121 insertions(+), 12 deletions(-) diff --git a/src/CommonUtils.jl b/src/CommonUtils.jl index a0f7574..e7cca5e 100644 --- a/src/CommonUtils.jl +++ b/src/CommonUtils.jl @@ -127,6 +127,7 @@ function calcProductGaussians( Σ_::Union{Nothing,<:AbstractVector{S},<:NTuple{N,S}}; dim::Integer=manifold_dimension(M), Λ_ = inv.(Σ_), # TODO these probably need to be transported to common tangent space `u0` -- FYI @Affie 24Q2 + weight::Real = 1.0 ) where {N,P,S<:AbstractMatrix{<:Real}} # # calc sum of covariances @@ -159,6 +160,7 @@ calcProductGaussians( Σ_::Union{<:AbstractVector{S},<:NTuple{N,S}}; dim::Integer=manifold_dimension(M), Λ_ = map(s->diagm( 1.0 ./ s), Σ_), + weight::Real = 1.0 ) where {N,P,S<:AbstractVector} = calcProductGaussians(M, μ_, nothing; dim, Λ_=Λ_ ) # @@ -167,6 +169,7 @@ calcProductGaussians( μ_::Union{<:AbstractVector{P},<:NTuple{N,P}}; dim::Integer=manifold_dimension(M), Λ_ = diagm.( (1.0 ./ μ_) ), + weight::Real = 1.0, ) where {N,P} = calcProductGaussians(M, μ_, nothing; dim, Λ_=Λ_ ) @@ -182,7 +185,8 @@ DevNotes """ function calcProductGaussians( M::AbstractManifold, - comps::AbstractVector{<:MvNormalKernel}, + comps::AbstractVector{<:MvNormalKernel}; + weight::Real = 1.0 ) # CHECK this should be on-manifold for points μ_ = mean.(comps) # This is a ArrayPartition which IS DEFINITELY ON MANIFOLD (we dispatch on mean) @@ -192,7 +196,7 @@ function calcProductGaussians( _μ, _Σ = calcProductGaussians(M, μ_, Σ_) - return MvNormalKernel(_μ, _Σ) + return MvNormalKernel(_μ, _Σ, weight) end diff --git a/src/services/ManellicTree.jl b/src/services/ManellicTree.jl index aa2ba8e..3fd8994 100644 --- a/src/services/ManellicTree.jl +++ b/src/services/ManellicTree.jl @@ -516,6 +516,66 @@ function buildTree_Manellic!( return tosort_leaves end +function buildTree_Manellic!( + M::AbstractManifold, + r_ker::AbstractVector{KL}; # vector of points referenced to the r_frame + N = length(r_ker), + weights::AbstractVector{<:Real} = ones(N).*(1/N), + kernel = KL, + kernel_bw = nothing, # TODO +) where {KL <: MvNormalKernel} + # + _μT() = typeof(r_ker[1].μ) + D = manifold_dimension(M) + CV = SMatrix{D,D,Float64,D*D}(collect(cov(r_ker[1]))) + KLT = getfield(ApproxManifoldProducts,kernel.name.name) + KT = KLT( + r_ker[1].μ, + CV + ) |> typeof + + r_PP = SizedVector{N,_μT()}(undef) + + # leaf kernels + lkern = SizedVector{N,KL}(undef) + _workaround_isdef_leafkernel = Set{Int}() + for i in 1:N + r_PP[i] = r_ker[i].μ + lkern[i] = if isnothing(kernel_bw) + r_ker[i] + else + updateKernelBW(r_ker[i], kernel_bw) # TODO handle vector of kernel_bws + end + push!(_workaround_isdef_leafkernel, i + N) + end + + mtree = ManellicTree( + M, + r_PP, + MVector{N,Float64}(weights), + MVector{N,Int}(1:N), + lkern, + SizedVector{N,KT}(undef), + SizedVector{N,Set{Int}}(undef), + _workaround_isdef_leafkernel, + Set{Int}(), + ) + + # + tosort_leaves = buildTree_Manellic!( + mtree, + 1, # start at root + 1, # spanning all data + N; # to end of data + kernel=KLT, + kernel_bw + ) + + # manual reset leaves in the order discovered + permute!(tosort_leaves.leaf_kernels, tosort_leaves.permute) + + return tosort_leaves +end function updateBandwidths( mtr::ManellicTree{M,D,N,HL}, @@ -719,7 +779,8 @@ function calcProductKernelBTLabels( labels_sampled, LOOidx::Union{Int, Nothing} = nothing, gibbsSeq = 1:length(proposals); - permute::Bool = true # true because signature is BTLabels + permute::Bool = true, # true because signature is BTLabels + weight::Real = 1.0 ) # select a density label from the other proposals prop_and_label = Tuple{Int,Int}[] @@ -732,7 +793,7 @@ function calcProductKernelBTLabels( components = map(pr_lb->getKernelTree(proposals[pr_lb[1]], pr_lb[2], permute, true), prop_and_label) # TODO upgrade to tuples - return calcProductGaussians(M, [components...]) + return calcProductGaussians(M, [components...]; weight) end @@ -740,14 +801,16 @@ function calcProductKernelsBTLabels( M::AbstractManifold, proposals::AbstractVector, N_lbl_sets::AbstractVector{<:NTuple}, - permute::Bool = true # true because signature is BTLabels + permute::Bool = true; # true because signature is BTLabels + weights = 1/length(N_lbl_sets) .* ones(length(N_lbl_sets)) ) # T = typeof(getKernelTree(proposals[1],1)) - post = Vector{T}(undef, length(N_lbl_sets)) + N = length(N_lbl_sets) + post = Vector{T}(undef, N) for (i,lbs) in enumerate(N_lbl_sets) - post[i] = calcProductKernelBTLabels(M, proposals, lbs; permute) + post[i] = calcProductKernelBTLabels(M, proposals, lbs; permute, weight=weights[i]) end return post @@ -852,12 +915,14 @@ function sampleProductSeqGibbsBTLabel( return labels_sampled end +Base.length(mkd::ManifoldKernelDensity) = length(mkd.belief) + function sampleProductSeqGibbsBTLabels( M::AbstractManifold, proposals::AbstractVector, MC = 3, - N::Int = round(Int, mean(length.(getPoints.(proposals)))), # FIXME use getLength or length of proposal (not getPoints) + N::Int = round(Int, mean(length.(proposals))), # FIXME use getLength or length of proposal (not getPoints) label_pools=[[1:1;] for _ in proposals] ) # diff --git a/test/manellic/testManellicTree.jl b/test/manellic/testManellicTree.jl index 3675ae2..7db68fd 100755 --- a/test/manellic/testManellicTree.jl +++ b/test/manellic/testManellicTree.jl @@ -999,7 +999,7 @@ end # end -@testset "Product of two Manellic beliefs, Sequential Gibbs" begin +@testset "Product of two Manellic beliefs, Sequential Gibbs, TranslationGroup(1)" begin ## M = TranslationGroup(1) @@ -1034,7 +1034,7 @@ bt_label_pool = [ ] # leaves only version -@info "Leaves only label sampling version (Gibbs)" +@info "Leaves only label sampling version (Gibbs), TranslationGroup(1)" ApproxManifoldProducts.sampleProductSeqGibbsBTLabel(M, [p1; p2], 3, bt_label_pool) @@ -1047,10 +1047,10 @@ mtr = ApproxManifoldProducts.buildTree_Manellic!(M, pts; kernel_bw, kernel=Appro @test isapprox( 0, mean(ApproxManifoldProducts.getKernelTree(mtr,1))[1]; atol=0.75) +@test all( (s->isapprox(1/N, s.weight;atol=1e-6)).(post) ) - -@info "Multi-scale label sampling version (Gibbs)" +@info "Multi-scale label sampling version (Gibbs), TranslationGroup(1)" # test label pool creation child_label_pools, all_leaves = ApproxManifoldProducts.generateLabelPoolRecursive([p1;p2], [1; 1]) @@ -1102,4 +1102,44 @@ end # lines!((s->s[1]).(XX),YY, color=:red) + +@testset "Multi-scale label sampling version (Gibbs), TranslationGroup(2)" begin +## + +M = TranslationGroup(2) +N = 64 + +pts1 = [1*randn(2) for _ in 1:N] +p1 = ApproxManifoldProducts.manikde!_manellic(M,pts1) + +pts2 = [1*randn(2) for _ in 1:N] +p2 = ApproxManifoldProducts.manikde!_manellic(M,pts2) + + +# test sampling +lbls = ApproxManifoldProducts.sampleProductSeqGibbsBTLabels(M, [p1.belief; p2.belief]) +lbls_ = unique(lbls) +N_ = length(lbls_) +weights = 1/N .* ones(N_) +# increase weight of duplicates +if N_ < N + for (i,lb_) in enumerate(lbls_) + idxs = findall(==(lb_),lbls) + weights[i] = weights[i]*length(idxs) + end +end +post = ApproxManifoldProducts.calcProductKernelsBTLabels(M, [p1.belief; p2.belief], lbls_, false; weights) # ?? was permute=false? +# check that any duplicates resulted in a height weight +@test isapprox( weights, (s->s.weight).(post); atol=1e-6 ) + +# NOTE, resulting tree might not have N number of data points +mtr12 = ApproxManifoldProducts.buildTree_Manellic!(M,post) + +## +end + + + + + #