From bbe8168ecd7d0817f9169faa516f3ab4b7f01b2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20BRIOL?= Date: Thu, 15 Feb 2024 22:06:03 +0100 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20spline=20interp?= =?UTF-8?q?olation=20in=20Spline2D=20class?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../detail/interpolation/factory_1d.hpp | 41 +++++++++++++++++ .../include/pyinterp/detail/math/spline2d.hpp | 46 ++++--------------- 2 files changed, 51 insertions(+), 36 deletions(-) create mode 100644 src/pyinterp/core/include/pyinterp/detail/interpolation/factory_1d.hpp diff --git a/src/pyinterp/core/include/pyinterp/detail/interpolation/factory_1d.hpp b/src/pyinterp/core/include/pyinterp/detail/interpolation/factory_1d.hpp new file mode 100644 index 00000000..90da9f52 --- /dev/null +++ b/src/pyinterp/core/include/pyinterp/detail/interpolation/factory_1d.hpp @@ -0,0 +1,41 @@ +#pragma once +#include + +#include "pyinterp/detail/interpolation/akima.hpp" +#include "pyinterp/detail/interpolation/akima_periodic.hpp" +#include "pyinterp/detail/interpolation/cspline.hpp" +#include "pyinterp/detail/interpolation/cspline_periodic.hpp" +#include "pyinterp/detail/interpolation/linear.hpp" +#include "pyinterp/detail/interpolation/polynomial.hpp" +#include "pyinterp/detail/interpolation/steffen.hpp" + +namespace pyinterp::detail::interpolation { + +template +static inline auto factory_1d(const std::string &kind) + -> std::unique_ptr> { + if (kind == "linear") { + return std::make_unique>(); + } + if (kind == "polynomial") { + return std::make_unique>(); + } + if (kind == "c_spline") { + return std::make_unique>(); + } + if (kind == "c_spline_periodic") { + return std::make_unique>(); + } + if (kind == "akima") { + return std::make_unique>(); + } + if (kind == "akima_periodic") { + return std::make_unique>(); + } + if (kind == "steffen") { + return std::make_unique>(); + } + throw std::invalid_argument("Invalid interpolation type: " + kind); +} + +} // namespace pyinterp::detail::interpolation diff --git a/src/pyinterp/core/include/pyinterp/detail/math/spline2d.hpp b/src/pyinterp/core/include/pyinterp/detail/math/spline2d.hpp index af1adf54..a2ea1eb9 100644 --- a/src/pyinterp/core/include/pyinterp/detail/math/spline2d.hpp +++ b/src/pyinterp/core/include/pyinterp/detail/math/spline2d.hpp @@ -8,7 +8,7 @@ #include #include -#include "pyinterp/detail/gsl/interpolate1d.hpp" +#include "pyinterp/detail/interpolation/factory_1d.hpp" #include "pyinterp/detail/math/frame.hpp" namespace pyinterp::detail::math { @@ -22,53 +22,27 @@ class Spline2D { /// @param type method of calculation explicit Spline2D(const Frame2D &xr, const std::string &kind) : column_(xr.y()->size()), - x_interpolator_(xr.x()->size(), - gsl::Interpolate1D::parse_interp_type(kind), - gsl::Accelerator()), - y_interpolator_(xr.y()->size(), - gsl::Interpolate1D::parse_interp_type(kind), - gsl::Accelerator()) {} + x_interpolator_(interpolation::factory_1d(kind)), + y_interpolator_(interpolation::factory_1d(kind)) {} /// Return the interpolated value of y for a given point x auto interpolate(const double x, const double y, const Frame2D &xr) -> double { - return evaluate(&gsl::Interpolate1D::interpolate, x, y, xr); - } - - /// Return the derivative for a given point x - auto derivative(const double x, const double y, const Frame2D &xr) -> double { - return evaluate(&gsl::Interpolate1D::derivative, x, y, xr); - } - - /// Return the second derivative for a given point x - auto second_derivative(const double x, const double y, const Frame2D &xr) - -> double { - return evaluate(&gsl::Interpolate1D::second_derivative, x, y, xr); + // Spline interpolation as function of X-coordinate + for (Eigen::Index ix = 0; ix < xr.y()->size(); ++ix) { + column_(ix) = (*x_interpolator_)(*(xr.x()), xr.q()->col(ix), x); + } + return (*y_interpolator_)(*(xr.y()), column_, y); } private: - using InterpolateFunction = double (gsl::Interpolate1D::*)( - const Eigen::VectorXd &, const Eigen::VectorXd &, const double); /// Column of the interpolation window (interpolation according to Y /// coordinates) Eigen::VectorXd column_; /// GSL interpolators - gsl::Interpolate1D x_interpolator_; - gsl::Interpolate1D y_interpolator_; - - /// Evaluation of the GSL function performing the calculation. - auto evaluate( - const std::function - &function, - const double x, const double y, const Frame2D &xr) -> double { - // Spline interpolation as function of X-coordinate - for (Eigen::Index ix = 0; ix < xr.y()->size(); ++ix) { - column_(ix) = function(x_interpolator_, *(xr.x()), xr.q()->col(ix), x); - } - return function(y_interpolator_, *(xr.y()), column_, y); - } + std::unique_ptr> x_interpolator_; + std::unique_ptr> y_interpolator_; }; } // namespace pyinterp::detail::math