Skip to content

Commit

Permalink
wip parametric tree solve
Browse files Browse the repository at this point in the history
  • Loading branch information
Affie committed Oct 13, 2024
1 parent b4a6b50 commit dce4994
Show file tree
Hide file tree
Showing 7 changed files with 336 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/CliqueStateMachine/services/CliqueStateMachine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -960,7 +960,7 @@ function updateFromSubgraph_StateMachine(csmc::CliqStateMachineContainer)
logCSM(
csmc,
"CSM-5 Clique $(csmc.cliq.id) finished, solveKey=$(csmc.solveKey)";
loglevel = Logging.Info,
loglevel = Logging.Debug,
)
return IncrementalInference.exitStateMachine
end
Expand Down
76 changes: 76 additions & 0 deletions src/Factors/GenericFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,82 @@ function (cf::CalcFactor{<:ManifoldFactor})(X, p, q)
return distanceTangent2Point(cf.factor.M, X, p, q)
end


## ======================================================================================
## adjoint factor - adjoint action applied to the measurement
## ======================================================================================
function Ad(::Union{typeof(SpecialEuclidean(2)), typeof(SpecialEuclidean(3))}, p, X)
t = p.x[1]
R = p.x[2]
v = X.x[1]
Ω = X.x[2]
ArrayPartition(-R*Ω*R'*t + R*v, R*Ω*R')
end

function Ad(::typeof(SpecialEuclidean(3)), p)
t = p.x[1]
R = p.x[2]
vcat(
hcat(R, skew(t)*R),
hcat(zero(SMatrix{3,3,Float64}), R)
)
end

function Ad(::typeof(SpecialEuclidean(2)), p)
t = p.x[1]
R = p.x[2]
vcat(
hcat(R, -SA[0 -1; 1 0]*t),
SA[0 0 1]
)
end

struct AdFactor{F <: AbstractManifoldMinimize} <: AbstractManifoldMinimize
factor::F
end

function (cf::CalcFactor{<:AdFactor})(Xϵ, p, q)
# M = getManifold(cf.factor)
# p,q ∈ M
# Xϵ ∈ TϵM
# ϵ = identity_element(M)
# transform measurement from TϵM to TpM (global to local coordinates)
# Adₚ⁻¹ = AdjointMatrix(M, p)⁻¹ = AdjointMatrix(M, p⁻¹)
# Xp = Adₚ⁻¹ * Xϵᵛ
# ad = Ad(M, inv(M, p))
# Xp = Ad(M, inv(M, p), Xϵ)
# Xp = adjoint_action(M, inv(M, p), Xϵ)
#TODO is vector transport supposed to be the same?
# Xp = vector_transport_to(M, ϵ, Xϵ, p)

# Transform measurement covariance
# ᵉΣₚ = Adₚ ᵖΣₚ Adₚᵀ
#TODO test if transforming sqrt_iΣ is the same as Σ
# Σ = ad * inv(cf.sqrt_iΣ^2) * ad'
# sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), sqrt(inv(Σ)))
# sqrt_iΣ = convert(typeof(cf.sqrt_iΣ), ad * cf.sqrt_iΣ * ad')
Xp =

child_cf = CalcFactorResidual(
cf.faclbl,
cf.factor.factor,
cf.varOrder,
cf.varOrderIdxs,
cf.meas,
cf.sqrt_iΣ,
cf.cache,
)
return child_cf(Xp, p, q)
end

getMeasurementParametric(f::AdFactor) = getMeasurementParametric(f.factor)

getManifold(f::AdFactor) = getManifold(f.factor)
function getSample(cf::CalcFactor{<:AdFactor})
M = getManifold(cf)
return sampleTangent(M, cf.factor.factor.Z)
end

## ======================================================================================
## ManifoldPrior
## ======================================================================================
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/services/ManifoldSampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ function getSample(cf::CalcFactor{<:AbstractPrior})
end

function getSample(cf::CalcFactor{<:AbstractRelative})
M =getManifold(cf)
M = getManifold(cf)
if hasfield(typeof(cf.factor), :Z)
X = sampleTangent(M, cf.factor.Z)
else
Expand Down
195 changes: 184 additions & 11 deletions src/parametric/services/ParametricCSMFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Notes
- Parametric state machine function nr. 3
"""
function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
function solveUp_ParametricStateMachine_Old(csmc::CliqStateMachineContainer)
infocsm(csmc, "Par-3, Solving Up")

setCliqueDrawColor!(csmc.cliq, "red")
Expand Down Expand Up @@ -96,6 +96,145 @@ function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
return waitForDown_StateMachine
end

# solve relatives ignoring any priors keeping `from` at ϵ
# if clique has priors : solve to get a prior on `from`
# send messages as factors or just the beliefs? for now factors
function solveUp_ParametricStateMachine(csmc::CliqStateMachineContainer)
infocsm(csmc, "Par-3, Solving Up")

setCliqueDrawColor!(csmc.cliq, "red")
# csmc.drawtree ? drawTree(csmc.tree, show=false, filepath=joinpath(getSolverParams(csmc.dfg).logpath,"bt.pdf")) : nothing

msgfcts = Symbol[]

for (idx, upmsg) in getMessageBuffer(csmc.cliq).upRx #get cached messages taken from children saved in this clique
child_factors = addMsgFactors_Parametric!(csmc.cliqSubFg, upmsg, UpwardPass)
append!(msgfcts, getLabel.(child_factors)) # addMsgFactors_Parametric!
end
logCSM(csmc, "length mgsfcts=$(length(msgfcts))")
infocsm(csmc, "length mgsfcts=$(length(msgfcts))")

# store the cliqSubFg for later debugging
_dbgCSMSaveSubFG(csmc, "fg_beforeupsolve")

subfg = csmc.cliqSubFg

frontals = getCliqFrontalVarIds(csmc.cliq)
separators = getCliqSeparatorVarIds(csmc.cliq)

# if its a root do full solve
if length(getParent(csmc.tree, csmc.cliq)) == 0
# M, vartypeslist, lm_r, Σ = solve_RLM(subfg; is_sparse=false, finiteDiffCovariance=true)
autoinitParametric!(subfg)
M, vartypeslist, lm_r, Σ = solveGraphParametric!(subfg; is_sparse=false, finiteDiffCovariance=true, damping_term_min=1e-18)

else

# select first seperator as constant reference at the identity element
isempty(separators) && @warn "empty separators solving cliq $(csmc.cliq.id.value)" ls(subfg) lsf(subfg)
from = first(separators)
from_v = getVariable(subfg, from)
getSolverData(from_v, :parametric).val[1] = getPointIdentity(getVariableType(from_v))

#TODO handle priors
# Variables that are free to move
free_vars = [frontals; separators[2:end]]
# Solve for the free variables

@assert !isempty(lsf(subfg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(subfg)) lsf=$(lsf(subfg))"

# M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from];)
M, vartypeslist, lm_r, Σ = solve_RLM_conditional(subfg, free_vars, [from]; finiteDiffCovariance=false, damping_term_min=1e-18)

end

# FIXME check solve convergence
if !true
@error "Par-3, clique $(csmc.cliq.id) failed to converge in upsolve" result
# propagate error to cleanly exit all cliques
putErrorUp(csmc)
if length(getParent(csmc.tree, csmc.cliq)) == 0
putErrorDown(csmc)
return IncrementalInference.exitStateMachine
end

return waitForDown_StateMachine
end

logCSM(csmc, "$(csmc.cliq.id): subfg solve converged sending messages")

# Pack results in massage factors

sigmas = extractMarginalsAP(M, vartypeslist, Σ)

# FIXME fix MsgRelativeType
relative_message_factors = MsgRelativeType();
for (i, to) in enumerate(vartypeslist)
if to in separators
#assume full dim factor
factype = selectFactorType(subfg, from, to)
# make S symetrical
# S = sigmas[i] # FIXME for some reason SMatrix is not invertable even though it is!!!!!!!!
S = Matrix(sigmas[i])# FIXME
S = (S + S') / 2
# @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical"
!all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical")
# @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag"
!all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag")


M_to = getManifold(getVariableType(subfg, to))
ϵ = getPointIdentity(M_to)
μ = vee(M_to, ϵ, log(M_to, ϵ, lm_r[i]))

message_factor = AdFactor(factype(MvNormal(μ, S)))


# logCSM(csmc, "$(csmc.cliq.id): Z=$(getMeasurementParametric(message_factor))"; loglevel = Logging.Warn)

push!(relative_message_factors, (variables=[from, to], likelihood=message_factor))
end
end

# Done with solve delete factors
#TODO confirm, maybe don't delete mesage factors on subgraph, maybe delete if its priors, but not conditionals
# deleteMsgFactors!(csmc.cliqSubFg)

# store the cliqSubFg for later debugging
_dbgCSMSaveSubFG(csmc, "fg_afterupsolve")

# cliqueLikelihood = calculateMarginalCliqueLikelihood(vardict, Σ, varIds, cliqSeparatorVarIds)

#Fill in CliqueLikelihood
beliefMsg = LikelihoodMessage(;
sender = (; id = csmc.cliq.id.value, step = csmc._csm_iter),
status = UPSOLVED,
variableOrder = separators,
# cliqueLikelihood,
jointmsg = _MsgJointLikelihood(;relatives=relative_message_factors),
msgType = ParametricMessage(),
)

# @assert length(separators) <= 2 "TODO length(separators) = $(length(separators)) > 2 in clique $(csmc.cliq.id.value)"
@assert isempty(lsfPriors(csmc.cliqSubFg)) || csmc.cliq.id.value == 1 "TODO priors in clique $(csmc.cliq.id.value)"
# if length(lsfPriors(csmc.cliqSubFg)) > 0 || length(separators) > 2
# for si in cliqSeparatorVarIds
# vnd = getSolverData(getVariable(csmc.cliqSubFg, si), :parametric)
# beliefMsg.belief[si] = TreeBelief(deepcopy(vnd))
# end
# end

for e in getEdgesParent(csmc.tree, csmc.cliq)
logCSM(csmc, "$(csmc.cliq.id): put! on edge $(e)")
getMessageBuffer(csmc.cliq).upTx = deepcopy(beliefMsg)
putBeliefMessageUp!(csmc.tree, e, beliefMsg)
end

return waitForDown_StateMachine
end

global g_n = nothing

"""
$SIGNATURES
Expand All @@ -120,6 +259,15 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
logCSM(csmc, "$(csmc.cliq.id): Updating separator $msym from message $(belief.val)")
vnd.val .= belief.val
vnd.bw .= belief.bw

p = belief.val[1]

S = belief.bw
S = (S + S') / 2
vnd.bw .= S

nd = MvNormal(getCoordinates(Main.Pose2, p), S)
addFactor!(csmc.cliqSubFg, [msym], Main.PriorPose2(nd))
end
end
end
Expand All @@ -132,23 +280,48 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
#only down solve if its not a root
if length(getParent(csmc.tree, csmc.cliq)) != 0
frontals = getCliqFrontalVarIds(csmc.cliq)
vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals)
# vardict, result, flatvars, Σ = solveConditionalsParametric(csmc.cliqSubFg, frontals)
#TEMP testing difference
# vardict, result = solveGraphParametric(csmc.cliqSubFg)
# Pack all results in variables
if result.g_converged || result.f_converged
@assert !isempty(lsf(csmc.cliqSubFg)) "No factors in clique $(csmc.cliq.id.value) ls=$(ls(csmc.cliqSubFg)) lsf=$(lsf(csmc.cliqSubFg))"

# M, vartypeslist, lm_r, Σ = solve_RLM_conditional(csmc.cliqSubFg, frontals; finiteDiffCovariance=false, damping_term_min=1e-18)
M, vartypeslist, lm_r, Σ = solve_RLM(csmc.cliqSubFg; finiteDiffCovariance=false, damping_term_min=1e-18)
sigmas = extractMarginalsAP(M, vartypeslist, Σ)

if true # TODO check for convergence result.g_converged || result.f_converged
logCSM(
csmc,
"$(csmc.cliq.id): subfg optim converged updating variables";
loglevel = Logging.Info,
loglevel = Logging.Debug,
)
for (v, val) in vardict
logCSM(csmc, "$(csmc.cliq.id) down: updating $v : $val"; loglevel = Logging.Info)
vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)
#Update subfg variables
vnd.val[1] = val.val
vnd.bw .= val.cov
for (i, v) in enumerate(vartypeslist)
if v in frontals
# logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug)
vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)

S = Matrix(sigmas[i])# FIXME
S = (S + S') / 2
# @assert all(isapprox.(S, sigmas[i], rtol=1e-3)) "Bad covariance matrix - not symetrical"
!all(isapprox.(S, sigmas[i], rtol=1e-3)) && @error("Bad covariance matrix - not symetrical")
# @assert all(diag(S) .> 0) "Bad covariance matrix - not positive diag"
!all(diag(S) .> 0) && @error("Bad covariance matrix - not positive diag")


#Update subfg variables
vnd.val[1] = lm_r[i]
vnd.bw .= S
end
end
# for (v, val) in vardict
# logCSM(csmc, "$(csmc.cliq.id) down: updating $v"; val, loglevel = Logging.Debug)
# vnd = getSolverData(getVariable(csmc.cliqSubFg, v), :parametric)

# #Update subfg variables
# vnd.val[1] = val.val
# vnd.bw .= val.cov
# end
else
@error "Par-5, clique $(csmc.cliq.id) failed to converge in down solve" result
#propagate error to cleanly exit all cliques
Expand All @@ -169,7 +342,7 @@ function solveDown_ParametricStateMachine(csmc::CliqStateMachineContainer)
for fi in cliqFrontalVarIds
vnd = getSolverData(getVariable(csmc.cliqSubFg, fi), :parametric)
beliefMsg.belief[fi] = TreeBelief(vnd)
logCSM(csmc, "$(csmc.cliq.id): down message $fi : $beliefMsg"; loglevel = Logging.Info)
logCSM(csmc, "$(csmc.cliq.id): down message $fi"; beliefMsg=beliefMsg.belief[fi], loglevel = Logging.Debug)
end

# pass through the frontal variables that were sent from above
Expand Down
Loading

0 comments on commit dce4994

Please sign in to comment.