diff --git a/src/pyinterp/core/include/pyinterp/detail/interpolation/factory_1d.hpp b/src/pyinterp/core/include/pyinterp/detail/interpolation/factory_1d.hpp new file mode 100644 index 00000000..90da9f52 --- /dev/null +++ b/src/pyinterp/core/include/pyinterp/detail/interpolation/factory_1d.hpp @@ -0,0 +1,41 @@ +#pragma once +#include + +#include "pyinterp/detail/interpolation/akima.hpp" +#include "pyinterp/detail/interpolation/akima_periodic.hpp" +#include "pyinterp/detail/interpolation/cspline.hpp" +#include "pyinterp/detail/interpolation/cspline_periodic.hpp" +#include "pyinterp/detail/interpolation/linear.hpp" +#include "pyinterp/detail/interpolation/polynomial.hpp" +#include "pyinterp/detail/interpolation/steffen.hpp" + +namespace pyinterp::detail::interpolation { + +template +static inline auto factory_1d(const std::string &kind) + -> std::unique_ptr> { + if (kind == "linear") { + return std::make_unique>(); + } + if (kind == "polynomial") { + return std::make_unique>(); + } + if (kind == "c_spline") { + return std::make_unique>(); + } + if (kind == "c_spline_periodic") { + return std::make_unique>(); + } + if (kind == "akima") { + return std::make_unique>(); + } + if (kind == "akima_periodic") { + return std::make_unique>(); + } + if (kind == "steffen") { + return std::make_unique>(); + } + throw std::invalid_argument("Invalid interpolation type: " + kind); +} + +} // namespace pyinterp::detail::interpolation diff --git a/src/pyinterp/core/include/pyinterp/detail/math/spline2d.hpp b/src/pyinterp/core/include/pyinterp/detail/math/spline2d.hpp index af1adf54..a2ea1eb9 100644 --- a/src/pyinterp/core/include/pyinterp/detail/math/spline2d.hpp +++ b/src/pyinterp/core/include/pyinterp/detail/math/spline2d.hpp @@ -8,7 +8,7 @@ #include #include -#include "pyinterp/detail/gsl/interpolate1d.hpp" +#include "pyinterp/detail/interpolation/factory_1d.hpp" #include "pyinterp/detail/math/frame.hpp" namespace pyinterp::detail::math { @@ -22,53 +22,27 @@ class Spline2D { /// @param type method of calculation explicit Spline2D(const Frame2D &xr, const std::string &kind) : column_(xr.y()->size()), - x_interpolator_(xr.x()->size(), - gsl::Interpolate1D::parse_interp_type(kind), - gsl::Accelerator()), - y_interpolator_(xr.y()->size(), - gsl::Interpolate1D::parse_interp_type(kind), - gsl::Accelerator()) {} + x_interpolator_(interpolation::factory_1d(kind)), + y_interpolator_(interpolation::factory_1d(kind)) {} /// Return the interpolated value of y for a given point x auto interpolate(const double x, const double y, const Frame2D &xr) -> double { - return evaluate(&gsl::Interpolate1D::interpolate, x, y, xr); - } - - /// Return the derivative for a given point x - auto derivative(const double x, const double y, const Frame2D &xr) -> double { - return evaluate(&gsl::Interpolate1D::derivative, x, y, xr); - } - - /// Return the second derivative for a given point x - auto second_derivative(const double x, const double y, const Frame2D &xr) - -> double { - return evaluate(&gsl::Interpolate1D::second_derivative, x, y, xr); + // Spline interpolation as function of X-coordinate + for (Eigen::Index ix = 0; ix < xr.y()->size(); ++ix) { + column_(ix) = (*x_interpolator_)(*(xr.x()), xr.q()->col(ix), x); + } + return (*y_interpolator_)(*(xr.y()), column_, y); } private: - using InterpolateFunction = double (gsl::Interpolate1D::*)( - const Eigen::VectorXd &, const Eigen::VectorXd &, const double); /// Column of the interpolation window (interpolation according to Y /// coordinates) Eigen::VectorXd column_; /// GSL interpolators - gsl::Interpolate1D x_interpolator_; - gsl::Interpolate1D y_interpolator_; - - /// Evaluation of the GSL function performing the calculation. - auto evaluate( - const std::function - &function, - const double x, const double y, const Frame2D &xr) -> double { - // Spline interpolation as function of X-coordinate - for (Eigen::Index ix = 0; ix < xr.y()->size(); ++ix) { - column_(ix) = function(x_interpolator_, *(xr.x()), xr.q()->col(ix), x); - } - return function(y_interpolator_, *(xr.y()), column_, y); - } + std::unique_ptr> x_interpolator_; + std::unique_ptr> y_interpolator_; }; } // namespace pyinterp::detail::math