Skip to content

Commit

Permalink
♻️ Refactor spline interpolation in Spline2D class
Browse files Browse the repository at this point in the history
  • Loading branch information
fbriol committed Feb 15, 2024
1 parent a33fce0 commit bbe8168
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 36 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#pragma once
#include <stdexcept>

#include "pyinterp/detail/interpolation/akima.hpp"
#include "pyinterp/detail/interpolation/akima_periodic.hpp"
#include "pyinterp/detail/interpolation/cspline.hpp"
#include "pyinterp/detail/interpolation/cspline_periodic.hpp"
#include "pyinterp/detail/interpolation/linear.hpp"
#include "pyinterp/detail/interpolation/polynomial.hpp"
#include "pyinterp/detail/interpolation/steffen.hpp"

namespace pyinterp::detail::interpolation {

template <typename T>
static inline auto factory_1d(const std::string &kind)
-> std::unique_ptr<Interpolator1D<T>> {
if (kind == "linear") {
return std::make_unique<Linear<T>>();
}
if (kind == "polynomial") {
return std::make_unique<Polynomial<T>>();
}
if (kind == "c_spline") {
return std::make_unique<CSpline<T>>();
}
if (kind == "c_spline_periodic") {
return std::make_unique<CSplinePeriodic<T>>();
}
if (kind == "akima") {
return std::make_unique<Akima<T>>();
}
if (kind == "akima_periodic") {
return std::make_unique<AkimaPeriodic<T>>();
}
if (kind == "steffen") {
return std::make_unique<Steffen<T>>();
}
throw std::invalid_argument("Invalid interpolation type: " + kind);
}

} // namespace pyinterp::detail::interpolation
46 changes: 10 additions & 36 deletions src/pyinterp/core/include/pyinterp/detail/math/spline2d.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include <Eigen/Core>
#include <string>

#include "pyinterp/detail/gsl/interpolate1d.hpp"
#include "pyinterp/detail/interpolation/factory_1d.hpp"
#include "pyinterp/detail/math/frame.hpp"

namespace pyinterp::detail::math {
Expand All @@ -22,53 +22,27 @@ class Spline2D {
/// @param type method of calculation
explicit Spline2D(const Frame2D &xr, const std::string &kind)
: column_(xr.y()->size()),
x_interpolator_(xr.x()->size(),
gsl::Interpolate1D::parse_interp_type(kind),
gsl::Accelerator()),
y_interpolator_(xr.y()->size(),
gsl::Interpolate1D::parse_interp_type(kind),
gsl::Accelerator()) {}
x_interpolator_(interpolation::factory_1d<double>(kind)),
y_interpolator_(interpolation::factory_1d<double>(kind)) {}

/// Return the interpolated value of y for a given point x
auto interpolate(const double x, const double y, const Frame2D &xr)
-> double {
return evaluate(&gsl::Interpolate1D::interpolate, x, y, xr);
}

/// Return the derivative for a given point x
auto derivative(const double x, const double y, const Frame2D &xr) -> double {
return evaluate(&gsl::Interpolate1D::derivative, x, y, xr);
}

/// Return the second derivative for a given point x
auto second_derivative(const double x, const double y, const Frame2D &xr)
-> double {
return evaluate(&gsl::Interpolate1D::second_derivative, x, y, xr);
// Spline interpolation as function of X-coordinate
for (Eigen::Index ix = 0; ix < xr.y()->size(); ++ix) {
column_(ix) = (*x_interpolator_)(*(xr.x()), xr.q()->col(ix), x);
}
return (*y_interpolator_)(*(xr.y()), column_, y);
}

private:
using InterpolateFunction = double (gsl::Interpolate1D::*)(
const Eigen::VectorXd &, const Eigen::VectorXd &, const double);
/// Column of the interpolation window (interpolation according to Y
/// coordinates)
Eigen::VectorXd column_;

/// GSL interpolators
gsl::Interpolate1D x_interpolator_;
gsl::Interpolate1D y_interpolator_;

/// Evaluation of the GSL function performing the calculation.
auto evaluate(
const std::function<double(gsl::Interpolate1D &, const Eigen::VectorXd &,
const Eigen::VectorXd &, const double)>
&function,
const double x, const double y, const Frame2D &xr) -> double {
// Spline interpolation as function of X-coordinate
for (Eigen::Index ix = 0; ix < xr.y()->size(); ++ix) {
column_(ix) = function(x_interpolator_, *(xr.x()), xr.q()->col(ix), x);
}
return function(y_interpolator_, *(xr.y()), column_, y);
}
std::unique_ptr<interpolation::Interpolator1D<double>> x_interpolator_;
std::unique_ptr<interpolation::Interpolator1D<double>> y_interpolator_;
};

} // namespace pyinterp::detail::math

0 comments on commit bbe8168

Please sign in to comment.