Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deconv using CalcFactor dispatch #1789

Merged
merged 7 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 99 additions & 11 deletions src/services/NumericalCalculations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,6 @@
return hypoCalcFactor(CalcConv, p)
end

#TODO untested and unused
# for deconv with the measurement a tangent vector
# function (hypoCalcFactor::CalcFactorNormSq)(M::AbstractManifold, Xc::AbstractVector)
# # M = hypoCalcFactor.manifold # calc factor has factor manifold in not variable that is needed here
# ϵ = getPointIdentity(M)
# X = get_vector(M, ϵ, Xc, DefaultOrthogonalBasis())
# return hypoCalcFactor(CalcDeconv, X)
# end

function _solveLambdaNumeric(
fcttype::Union{F, <:Mixture{N_, F, S, T}},
hypoCalcFactor,
Expand All @@ -109,7 +100,6 @@
# the variable is a manifold point, we are working on the tangent plane in optim for now.
#
#TODO this is not general to all manifolds, should work for lie groups.
# ϵ = identity_element(M, u0)
ϵ = getPointIdentity(variableType)

X0c = zero(MVector{getDimension(M),Float64})
Expand Down Expand Up @@ -177,6 +167,105 @@
return hat(M, ϵ, r.minimizer)
end

## deconvolution with calcfactor wip
struct CalcDeconv end

function (cf::CalcFactorNormSq)(::Type{CalcDeconv}, meas)
res = cf(meas, map(vvh -> _getindex_anyn(vvh, cf._sampleIdx), cf._legacyParams)...)
return sum(x->x^2, res)

Check warning on line 175 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L173-L175

Added lines #L173 - L175 were not covered by tests
end

# for deconv with the measurement a tangent vector, can dispatch for other measurement types.
function (hypoCalcFactor::CalcFactorNormSq)(::Type{CalcDeconv}, M::AbstractManifold, Xc::AbstractVector)
ϵ = getPointIdentity(M)
X = get_vector(M, ϵ, Xc, DefaultOrthogonalBasis())
return hypoCalcFactor(CalcDeconv, X)

Check warning on line 182 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L179-L182

Added lines #L179 - L182 were not covered by tests
end

#NOTE Optim.jl version that assumes measurement is on the tangent
function _solveLambdaNumericMeas_v2(

Check warning on line 186 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L186

Added line #L186 was not covered by tests
Affie marked this conversation as resolved.
Show resolved Hide resolved
fcttype::Union{F, <:Mixture{N_, F, S, T}},
hypoCalcFactor,
X0,#::AbstractVector{<:Real},
islen1::Bool = false,
) where {N_, F <: AbstractManifoldMinimize, S, T}
#
M = getManifold(fcttype)
ϵ = getPointIdentity(M)
X0c = zeros(manifold_dimension(M))
X0c .= vee(M, ϵ, X0)

Check warning on line 196 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L193-L196

Added lines #L193 - L196 were not covered by tests

alg = islen1 ? Optim.BFGS() : Optim.NelderMead()

Check warning on line 198 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L198

Added line #L198 was not covered by tests

r = Optim.optimize(
x->hypoCalcFactor(CalcDeconv, M, x),

Check warning on line 201 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L200-L201

Added lines #L200 - L201 were not covered by tests
X0c,
alg
)
if !Optim.converged(r)
@debug "Optim did not converge:" r

Check warning on line 206 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L205-L206

Added lines #L205 - L206 were not covered by tests
end

return hat(M, ϵ, r.minimizer)

Check warning on line 209 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L209

Added line #L209 was not covered by tests
end

function approxDeconv_v2(

Check warning on line 212 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L212

Added line #L212 was not covered by tests
fcto::DFGFactor,
ccw::CommonConvWrapper = _getCCW(fcto);
N::Int = 100,
measurement::AbstractVector = sampleFactor(ccw, N),
)
# NOTE comments kept from original...
# FIXME needs xDim for all variables at once? xDim = 0 likely to break?
Affie marked this conversation as resolved.
Show resolved Hide resolved

# but what if this is a partial factor -- is that important for general cases in deconv?
_setCCWDecisionDimsConv!(ccw, 0) # ccwl.xDim used to hold the last forward solve getDimension(getVariableType(Xi[sfidx]))

Check warning on line 222 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L222

Added line #L222 was not covered by tests
Affie marked this conversation as resolved.
Show resolved Hide resolved

# FIXME This does not incorporate multihypo??
Affie marked this conversation as resolved.
Show resolved Hide resolved
varsyms = getVariableOrder(fcto)

Check warning on line 225 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L225

Added line #L225 was not covered by tests
# vars = getPoints.(getBelief.(dfg, varsyms, solveKey) )

# TODO, consolidate fmd with getSample/sampleFactor and _buildLambda
# TODO assuming vector on only first container in measurement::Tuple

# NOTE
# build a lambda that incorporates the multihypo selections
# set these first
# ccw.cpt[].activehypo / .p / .params # params should already be set from construction
hyporecipe = _prepareHypoRecipe!(nothing, N, 0, length(varsyms))

Check warning on line 235 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L235

Added line #L235 was not covered by tests
# only doing the current active hypo
@assert hyporecipe.activehypo[2][1] == 1 "deconv was expecting hypothesis nr == (1, 1:d)"

Check warning on line 237 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L237

Added line #L237 was not covered by tests

# get measurement dimension
zDim = _getZDim(fcto)
islen1 = zDim == 1

Check warning on line 241 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L240-L241

Added lines #L240 - L241 were not covered by tests

#make a copy of the original measurement before mutating it
sampled_meas = deepcopy(measurement)

Check warning on line 244 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L244

Added line #L244 was not covered by tests

fcttype = getFactorType(fcto)
if !(fcttype isa AbstractManifoldMinimize)
error("Only AbstractManifoldMinimize is currently supported in approxDeconv_v2")

Check warning on line 248 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L246-L248

Added lines #L246 - L248 were not covered by tests
end

for idx = 1:N

Check warning on line 251 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L251

Added line #L251 was not covered by tests

# TODO must first resolve hypothesis selection before unrolling them -- deferred #1096
resize!(ccw.hyporecipe.activehypo, length(hyporecipe.activehypo[2][2]))
ccw.hyporecipe.activehypo[:] = hyporecipe.activehypo[2][2]

Check warning on line 255 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L254-L255

Added lines #L254 - L255 were not covered by tests
#TODO why is this resize in the loop?

# Create a CalcFactor functor of the correct hypo,. TODO don't know what setup above is still needed
_hypoCalcFactor = _buildHypoCalcFactor(ccw, idx)

Check warning on line 259 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L259

Added line #L259 was not covered by tests

ts = _solveLambdaNumericMeas_v2(fcttype, _hypoCalcFactor, measurement[idx], islen1)
measurement[idx] = ts

Check warning on line 262 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L261-L262

Added lines #L261 - L262 were not covered by tests

end

Check warning on line 264 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L264

Added line #L264 was not covered by tests

return measurement, sampled_meas

Check warning on line 266 in src/services/NumericalCalculations.jl

View check run for this annotation

Codecov / codecov/patch

src/services/NumericalCalculations.jl#L266

Added line #L266 was not covered by tests
end

## ================================================================================================
## Heavy dispatch for all AbstractFactor / Mixture cases below
## ================================================================================================
Expand Down Expand Up @@ -374,7 +463,6 @@
#

struct CalcConv end
struct CalcDeconv end

_getindex_anyn(vec, n) = begin
len = length(vec)
Expand Down
19 changes: 19 additions & 0 deletions test/testSpecialEuclidean2Mani.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,25 @@ m_θ = map(x->x.x[2][2], meas)
@test isapprox(std(p_t), std(m_t), atol=0.3)

##
pred, meas = IIF.approxDeconv_v2(fg[:x0x1f1])

p_t = map(x->x.x[1], pred)
m_t = map(x->x.x[1], meas)
p_θ = map(x->x.x[2][2], pred)
m_θ = map(x->x.x[2][2], meas)

@test isapprox(mean(p_θ), 0.1, atol=0.02)
@test isapprox(std(p_θ), 0.05, atol=0.02)

@test isapprox(mean(p_t), [10,0], atol=0.3)
@test isapprox(std(p_t), [0.5,0.5], atol=0.3)

@test isapprox(mean(p_θ), mean(m_θ), atol=0.03)
@test isapprox(std(p_θ), std(m_θ), atol=0.03)

@test isapprox(mean(p_t), mean(m_t), atol=0.3)
@test isapprox(std(p_t), std(m_t), atol=0.3)

end


Expand Down
Loading