Skip to content

Commit

Permalink
Restore _slack in mumerical calcs (#1790)
Browse files Browse the repository at this point in the history
  • Loading branch information
Affie authored Oct 19, 2023
1 parent 0757330 commit 4bcc7b4
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/Factors/Mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ function sampleFactor(cf::CalcFactor{<:Mixture}, N::Int = 1)
cf.fullvariables,
cf.solvefor,
cf.manifold,
cf.measurement
cf.measurement,
nothing,
)
smpls = [getSample(cf_) for _ = 1:N]
# smpls = Array{Float64,2}(undef,s.dims,N)
Expand Down
4 changes: 3 additions & 1 deletion src/entities/CalcFactor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ struct CalcFactorNormSq{
C,
VT <: Tuple,
M <: AbstractManifold,
MEAS
MEAS,
S
} <: CalcFactor{FT}
""" the interface compliant user object functor containing the data and logic """
factor::FT
Expand All @@ -58,6 +59,7 @@ struct CalcFactorNormSq{
solvefor::Int
manifold::M
measurement::MEAS
slack::S
end

#TODO deprecate after CalcFactor is updated to CalcFactorNormSq
Expand Down
6 changes: 5 additions & 1 deletion src/services/CalcFactor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ function CalcFactorNormSq(
cache = ccwl.dummyCache,
fullvariables = ccwl.fullvariables,
solvefor = ccwl.varidx[],
manifold = getManifold(ccwl)
manifold = getManifold(ccwl),
slack=nothing,
)
#
# FIXME using ccwl.dummyCache is not thread-safe
Expand All @@ -45,6 +46,7 @@ function CalcFactorNormSq(
solvefor,
manifold,
ccwl.measurement,
slack
)
end

Expand Down Expand Up @@ -407,6 +409,7 @@ function _createCCW(
solvefor,
manifold,
nothing,
nothing,
)

# get a measurement sample
Expand All @@ -426,6 +429,7 @@ function _createCCW(
solvefor,
manifold,
measurement,
nothing,
)


Expand Down
10 changes: 7 additions & 3 deletions src/services/NumericalCalculations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ function _buildCalcFactor(
smpid,
varParams,
activehypo,
_slack = nothing,
)
#
# FIXME, make thread safe (cache)
Expand All @@ -210,6 +211,7 @@ function _buildCalcFactor(
solveforidx, #solvefor
getManifold(ccwl), #manifold
ccwl.measurement,
_slack,
)
end

Expand Down Expand Up @@ -391,10 +393,11 @@ function (cf::CalcFactorNormSq)(::Type{CalcConv}, x)
varValsHypo[cf.solvefor][sampleIdx] = x

res = cf(cf.measurement[sampleIdx], map(vvh -> _getindex_anyn(vvh, sampleIdx), varValsHypo)...)
res = isnothing(cf.slack) ? res : res .- cf.slack
return sum(x->x^2, res)
end

function _buildHypoCalcFactor(ccwl::CommonConvWrapper, smpid::Integer)
function _buildHypoCalcFactor(ccwl::CommonConvWrapper, smpid::Integer, _slack)
# build a view to the decision variable memory
varValsHypo = ccwl.varValsAll[][ccwl.hyporecipe.activehypo]
# create calc factor selected hypo and samples
Expand All @@ -403,7 +406,8 @@ function _buildHypoCalcFactor(ccwl::CommonConvWrapper, smpid::Integer)
ccwl, #
smpid, # ends in _sampleIdx
varValsHypo, # ends in _legacyParams
ccwl.hyporecipe.activehypo # ends in solvefor::Int
ccwl.hyporecipe.activehypo, # ends in solvefor::Int
_slack,
)
return cf
end
Expand All @@ -427,7 +431,7 @@ function _solveCCWNumeric!(
# target = view(ccwl.varValsAll[][ccwl.varidx[]], smpid)

# SUPER IMPORTANT ON PARTIALS, RESIDUAL FUNCTION MUST DEAL WITH PARTIAL AND WILL GET FULL VARIABLE POINTS REGARDLESS
_hypoCalcFactor = _buildHypoCalcFactor(ccwl, smpid)
_hypoCalcFactor = _buildHypoCalcFactor(ccwl, smpid, _slack)

# do the parameter search over defined decision variables using Minimization
sfidx = ccwl.varidx[]
Expand Down

0 comments on commit 4bcc7b4

Please sign in to comment.