Skip to content

Commit

Permalink
Support residual evaluation from ceres::Problem (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
B1ueber2y authored Jun 28, 2024
1 parent 4fce758 commit f56d0fb
Showing 1 changed file with 43 additions and 9 deletions.
52 changes: 43 additions & 9 deletions _pyceres/core/problem.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@

namespace py = pybind11;

namespace {

// Set residual blocks for Ceres::Problem::EvaluateOptions
void SetResidualBlocks(
ceres::Problem::EvaluateOptions& self,
std::vector<ResidualBlockIDWrapper>& residual_block_ids) {
self.residual_blocks.clear();
self.residual_blocks.reserve(residual_block_ids.size());
for (auto it = residual_block_ids.begin(); it != residual_block_ids.end();
++it) {
self.residual_blocks.push_back(it->id);
}
}

} // namespace

// Function to create Problem::Options with DO_NOT_TAKE_OWNERSHIP
// This is cause we want Python to manage our memory not Ceres
ceres::Problem::Options CreateProblemOptions() {
Expand Down Expand Up @@ -42,15 +58,21 @@ void BindProblem(py::module& m) {

py::class_<ceres::Problem::EvaluateOptions>(m, "EvaluateOptions")
.def(py::init<>())
.def("set_parameter_blocks",
[](ceres::Problem::EvaluateOptions& self,
std::vector<py::array_t<double>>& blocks) {
self.parameter_blocks.clear();
for (auto it = blocks.begin(); it != blocks.end(); ++it) {
py::buffer_info info = it->request();
self.parameter_blocks.push_back(static_cast<double*>(info.ptr));
}
})
.def(
"set_parameter_blocks",
[](ceres::Problem::EvaluateOptions& self,
std::vector<py::array_t<double>>& blocks) {
self.parameter_blocks.clear();
self.parameter_blocks.reserve(blocks.size());
for (auto it = blocks.begin(); it != blocks.end(); ++it) {
py::buffer_info info = it->request();
self.parameter_blocks.push_back(static_cast<double*>(info.ptr));
}
},
py::arg("parameter_blocks"))
.def("set_residual_blocks",
&SetResidualBlocks,
py::arg("residual_block_ids"))
.def_readwrite("apply_loss_function",
&ceres::Problem::EvaluateOptions::apply_loss_function)
.def_readwrite("num_threads",
Expand Down Expand Up @@ -249,6 +271,18 @@ void BindProblem(py::module& m) {
return residuals;
},
py::arg("options") = ceres::Problem::EvaluateOptions())
.def(
"evaluate_residuals",
[](ceres::Problem& self,
std::vector<ResidualBlockIDWrapper>& residual_block_ids) {
ceres::Problem::EvaluateOptions eval_options =
ceres::Problem::EvaluateOptions();
SetResidualBlocks(eval_options, residual_block_ids);
std::vector<double> residuals;
self.Evaluate(eval_options, nullptr, &residuals, nullptr, nullptr);
return residuals;
},
py::arg("residual_block_ids"))
.def(
"evaluate_jacobian",
[](ceres::Problem& self,
Expand Down

0 comments on commit f56d0fb

Please sign in to comment.