Skip to content

Commit

Permalink
Evaluation of residuals and Jacobian from pyceres.Problem (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
B1ueber2y authored Jun 26, 2024
1 parent c3b9f42 commit 4fce758
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 4 deletions.
2 changes: 2 additions & 0 deletions _pyceres/core/bindings.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "_pyceres/core/callbacks.h"
#include "_pyceres/core/cost_functions.h"
#include "_pyceres/core/covariance.h"
#include "_pyceres/core/crs_matrix.h"
#include "_pyceres/core/loss_functions.h"
#include "_pyceres/core/manifold.h"
#include "_pyceres/core/problem.h"
Expand All @@ -17,6 +18,7 @@ void BindCore(py::module& m) {
BindTypes(m);
BindCallbacks(m);
BindCovariance(m);
BindCRSMatrix(m);
BindSolver(m);
BindLossFunctions(m);
BindCostFunctions(m);
Expand Down
49 changes: 49 additions & 0 deletions _pyceres/core/crs_matrix.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#pragma once

#include "_pyceres/core/wrappers.h"
#include "_pyceres/helpers.h"
#include "_pyceres/logging.h"

#include <Eigen/Sparse>
#include <ceres/ceres.h>
#include <ceres/crs_matrix.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

namespace {
py::tuple ConvertCRSToPyTuple(const ceres::CRSMatrix& crsMatrix) {
const size_t n_values = crsMatrix.values.size();
py::array_t<int> rows(n_values), cols(n_values);
py::array_t<double> values(n_values);

int* const rows_data = static_cast<int*>(rows.request().ptr);
int* const cols_data = static_cast<int*>(cols.request().ptr);
double* const values_data = static_cast<double*>(values.request().ptr);

int counter = 0;
for (int row = 0; row < crsMatrix.num_rows; ++row) {
for (int k = crsMatrix.rows[row]; k < crsMatrix.rows[row + 1]; ++k) {
rows_data[counter] = row;
cols_data[counter] = crsMatrix.cols[k];
values_data[counter] = crsMatrix.values[k];
counter++;
}
}

// return as a tuple
return py::make_tuple(rows, cols, values);
}
} // namespace

void BindCRSMatrix(py::module& m) {
using CRSMatrix = ceres::CRSMatrix;
py::class_<CRSMatrix> PyCRSMatrix(m, "CRSMatrix");
PyCRSMatrix.def(py::init<>())
.def_readonly("num_rows", &CRSMatrix::num_rows)
.def_readonly("num_cols", &CRSMatrix::num_cols)
.def_readonly("rows", &CRSMatrix::rows)
.def_readonly("cols", &CRSMatrix::cols)
.def_readonly("values", &CRSMatrix::values)
.def("to_tuple", &ConvertCRSToPyTuple);
}
32 changes: 28 additions & 4 deletions _pyceres/core/problem.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,17 @@ void BindProblem(py::module& m) {
.def_readwrite("disable_all_safety_checks",
&options::disable_all_safety_checks);

// TODO: bind Problem::Evaluate
py::class_<ceres::Problem::EvaluateOptions>(m, "EvaluateOptions")
.def(py::init<>())
// Doesn't make sense to wrap this as you can't see the pointers in python
//.def_readwrite("parameter_blocks",&ceres::Problem::EvaluateOptions)
.def("set_parameter_blocks",
[](ceres::Problem::EvaluateOptions& self,
std::vector<py::array_t<double>>& blocks) {
self.parameter_blocks.clear();
for (auto it = blocks.begin(); it != blocks.end(); ++it) {
py::buffer_info info = it->request();
self.parameter_blocks.push_back(static_cast<double*>(info.ptr));
}
})
.def_readwrite("apply_loss_function",
&ceres::Problem::EvaluateOptions::apply_loss_function)
.def_readwrite("num_threads",
Expand Down Expand Up @@ -233,5 +239,23 @@ void BindProblem(py::module& m) {
.def("remove_residual_block",
[](ceres::Problem& self, ResidualBlockIDWrapper& residual_block_id) {
self.RemoveResidualBlock(residual_block_id.id);
});
})
.def(
"evaluate_residuals",
[](ceres::Problem& self,
const ceres::Problem::EvaluateOptions& options) {
std::vector<double> residuals;
self.Evaluate(options, nullptr, &residuals, nullptr, nullptr);
return residuals;
},
py::arg("options") = ceres::Problem::EvaluateOptions())
.def(
"evaluate_jacobian",
[](ceres::Problem& self,
const ceres::Problem::EvaluateOptions& options) {
ceres::CRSMatrix jacobian;
self.Evaluate(options, nullptr, nullptr, nullptr, &jacobian);
return jacobian;
},
py::arg("options") = ceres::Problem::EvaluateOptions());
}

0 comments on commit 4fce758

Please sign in to comment.