From 909087f2257468e740b6eb2b23b81d83010e2e8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fr=C3=A9d=C3=A9ric=20BRIOL?= Date: Sat, 17 Feb 2024 00:17:15 +0100 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Minor=20enhancements.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../pyinterp/detail/interpolation/bicubic.hpp | 97 +++++++------------ .../pyinterp/detail/interpolation/cspline.hpp | 4 +- .../detail/interpolation/cspline_base.hpp | 18 ++-- .../detail/interpolation/cspline_periodic.hpp | 5 +- 4 files changed, 49 insertions(+), 75 deletions(-) diff --git a/src/pyinterp/core/include/pyinterp/detail/interpolation/bicubic.hpp b/src/pyinterp/core/include/pyinterp/detail/interpolation/bicubic.hpp index a6369d0c..a36fd859 100644 --- a/src/pyinterp/core/include/pyinterp/detail/interpolation/bicubic.hpp +++ b/src/pyinterp/core/include/pyinterp/detail/interpolation/bicubic.hpp @@ -28,8 +28,8 @@ class Bicubic : public Interpolator2D { /// @param x The point where the interpolation must be calculated. /// @param y The point where the interpolation must be calculated. /// @return The interpolated value at the coordinates x, y. - auto interpolate_(const Vector &xa, const Vector &ya, - const Matrix &za, const T &x, const T &y) const + constexpr auto interpolate_(const Vector &xa, const Vector &ya, + const Matrix &za, const T &x, const T &y) const -> T override; /// Compute the coefficients of the bicubic interpolation @@ -70,9 +70,10 @@ auto Bicubic::compute_coefficients(const Vector &xa, const Vector &ya, } template -auto Bicubic::interpolate_(const Vector &xa, const Vector &ya, - const Matrix &za, const T &x, const T &y) const - -> T { +constexpr auto Bicubic::interpolate_(const Vector &xa, + const Vector &ya, + const Matrix &za, const T &x, + const T &y) const -> T { auto search_x = this->search(xa, x); auto search_y = this->search(ya, y); if (!search_x || !search_y) { @@ -90,6 +91,7 @@ auto Bicubic::interpolate_(const Vector &xa, const Vector &ya, const auto z11 = za(i1, j1); const auto dx = x1 - x0; const auto dy = y1 - y0; + const auto dxdy = dx * dy; const auto t = (x - x0) / dx; const auto u = (y - y0) / dy; const auto zx00 = zx_(i0, j0) * dx; @@ -100,10 +102,10 @@ auto Bicubic::interpolate_(const Vector &xa, const Vector &ya, const auto zy01 = zy_(i0, j1) * dy; const auto zy10 = zy_(i1, j0) * dy; const auto zy11 = zy_(i1, j1) * dy; - const auto zxy00 = zxy_(i0, j0) * (dx * dy); - const auto zxy01 = zxy_(i0, j1) * (dx * dy); - const auto zxy10 = zxy_(i1, j0) * (dx * dy); - const auto zxy11 = zxy_(i1, j1) * (dx * dy); + const auto zxy00 = zxy_(i0, j0) * dxdy; + const auto zxy01 = zxy_(i0, j1) * dxdy; + const auto zxy10 = zxy_(i1, j0) * dxdy; + const auto zxy11 = zxy_(i1, j1) * dxdy; const auto t0 = 1; const auto t1 = t; const auto t2 = t * t; @@ -113,60 +115,29 @@ auto Bicubic::interpolate_(const Vector &xa, const Vector &ya, const auto u2 = u * u; const auto u3 = u * u2; - auto v = z00; - auto z = v * t0 * u0; - - v = zy00; - z += v * t0 * u1; - - v = 3 * (-z00 + z01) - 2 * zy00 - zy01; - z += v * t0 * u2; - - v = 2 * (z00 - z01) + zy00 + zy01; - z += v * t0 * u3; - - v = zx00; - z += v * t1 * u0; - - v = zxy00; - z += v * t1 * u1; - - v = 3 * (-zx00 + zx01) - 2 * zxy00 - zxy01; - z += v * t1 * u2; - - v = 2 * (zx00 - zx01) + zxy00 + zxy01; - z += v * t1 * u3; - - v = 3 * (-z00 + z10) - 2 * zx00 - zx10; - z += v * t2 * u0; - - v = 3 * (-zy00 + zy10) - 2 * zxy00 - zxy10; - z += v * t2 * u1; - - v = 9 * (z00 - z10 + z11 - z01) + 6 * (zx00 - zx01 + zy00 - zy10) + - 3 * (zx10 - zx11 - zy11 + zy01) + 4 * zxy00 + 2 * (zxy10 + zxy01) + zxy11; - z += v * t2 * u2; - - v = 6 * (-z00 + z10 - z11 + z01) + 4 * (-zx00 + zx01) + - 3 * (-zy00 + zy10 + zy11 - zy01) + 2 * (-zx10 + zx11 - zxy00 - zxy01) - - zxy10 - zxy11; - z += v * t2 * u3; - - v = 2 * (z00 - z10) + zx00 + zx10; - z += v * t3 * u0; - - v = 2 * (zy00 - zy10) + zxy00 + zxy10; - z += v * t3 * u1; - - v = 6 * (-z00 + z10 - z11 + z01) + 3 * (-zx00 - zx10 + zx11 + zx01) + - 4 * (-zy00 + zy10) + 2 * (zy11 - zy01 - zxy00 - zxy10) - zxy11 - zxy01; - z += v * t3 * u2; - - v = 4 * (z00 - z10 + z11 - z01) + - 2 * (zx00 + zx10 - zx11 - zx01 + zy00 - zy10 - zy11 + zy01) + zxy00 + - zxy10 + zxy11 + zxy01; - z += v * t3 * u3; - return z; + return t0 * (u0 * z00 + u1 * zy00 + u2 * (3 * (z01 - z00) - 2 * zy00 - zy01) + + u3 * (2 * (z00 - z01) + zy00 + zy01)) + + t1 * (u0 * zx00 + u1 * zxy00 + + u2 * (3 * (zx01 - zx00) - 2 * zxy00 - zxy01) + + u3 * (2 * (zx00 - zx01) + zxy00 + zxy01)) + + t2 * (u0 * (3 * (z10 - z00) - 2 * zx00 - zx10) + + u1 * (3 * (zy10 - zy00) - 2 * zxy00 - zxy10) + + u2 * (9 * (z00 - z01 - z10 + z11) + + 6 * (zx00 - zx01 + zy00 - zy10) + + 3 * (zx10 - zx11 + zy01 - zy11) + 4 * zxy00 + + 2 * (zxy01 + zxy10) + zxy11) + + u3 * (6 * (z01 - z00 + z10 - z11) + 4 * (zx01 - zx00) + + 3 * (zy10 - zy00 - zy01 + zy11) + + 2 * (zx11 - zx10 - zxy00 - zxy01) - zxy10 - zxy11)) + + t3 * (u0 * (2 * (z00 - z10) + zx00 + zx10) + + u1 * (zxy00 + zxy10 + 2 * (zy00 - zy10)) + + u2 * (6 * (z01 - z00 + z10 - z11) + 4 * (-zy00 + zy10) + + 3 * (zx01 - zx00 - zx10 + zx11) + + 2 * (zy11 - zy01 - zxy00 - zxy10) - zxy01 - zxy11) + + u3 * (4 * (z00 - z01 - z10 + z11) + + 2 * (zx00 - zx01 + zx10 - zx11 + zy00 + zy01 - zy10 - + zy11) + + zxy00 + zxy01 + zxy10 + zxy11)); } } // namespace pyinterp::detail::interpolation diff --git a/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline.hpp b/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline.hpp index 693f883c..c8569a3c 100644 --- a/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline.hpp +++ b/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline.hpp @@ -29,14 +29,14 @@ class CSpline : public CSplineBase { /// @brief Solve a symmetric tridiagonal system /// @param x The solution of the system - auto solve_symmetric_tridiagonal(T *x) -> void; + constexpr auto solve_symmetric_tridiagonal(T *x) -> void; Vector c_; Vector d_; }; template -auto CSpline::solve_symmetric_tridiagonal(T *x) -> void { +constexpr auto CSpline::solve_symmetric_tridiagonal(T *x) -> void { const auto size = this->A_.rows(); const auto size_m1 = size - 1; const auto size_m2 = size - 2; diff --git a/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline_base.hpp b/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline_base.hpp index 8225d6cc..ad81467f 100644 --- a/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline_base.hpp +++ b/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline_base.hpp @@ -46,8 +46,8 @@ class CSplineBase : public Interpolator1D { /// @param ya Y-coordinates of the data points. /// @param x The point where the interpolation must be calculated. /// @return The interpolated value at the point x. - auto interpolate_(const Vector &xa, const Vector &ya, const T &x) const - -> T override; + constexpr auto interpolate_(const Vector &xa, const Vector &ya, + const T &x) const -> T override; /// @brief Returns the derivative of the interpolation function at the point /// x. @@ -55,8 +55,8 @@ class CSplineBase : public Interpolator1D { /// @param ya Y-coordinates of the data points. /// @param x The point where the derivative must be calculated. /// @return The derivative of the interpolation function at the point x. - auto derivative_(const Vector &xa, const Vector &ya, const T &x) const - -> T override; + constexpr auto derivative_(const Vector &xa, const Vector &ya, + const T &x) const -> T override; protected: Matrix A_; @@ -65,8 +65,9 @@ class CSplineBase : public Interpolator1D { }; template -auto CSplineBase::interpolate_(const Vector &xa, const Vector &ya, - const T &x) const -> T { +constexpr auto CSplineBase::interpolate_(const Vector &xa, + const Vector &ya, + const T &x) const -> T { auto where = this->search(xa, x); if (!where) { return std::numeric_limits::quiet_NaN(); @@ -85,8 +86,9 @@ auto CSplineBase::interpolate_(const Vector &xa, const Vector &ya, } template -auto CSplineBase::derivative_(const Vector &xa, const Vector &ya, - const T &x) const -> T { +constexpr auto CSplineBase::derivative_(const Vector &xa, + const Vector &ya, + const T &x) const -> T { auto where = this->search(xa, x); if (!where) { return std::numeric_limits::quiet_NaN(); diff --git a/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline_periodic.hpp b/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline_periodic.hpp index 38ec130b..28479616 100644 --- a/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline_periodic.hpp +++ b/src/pyinterp/core/include/pyinterp/detail/interpolation/cspline_periodic.hpp @@ -27,7 +27,7 @@ class CSplinePeriodic : public CSplineBase { /// Solve a symmetric cyclic tridiagonal system /// @param x The solution of the system - auto solve_symmetric_cyclic_tridiagonal(T *x) -> void; + constexpr auto solve_symmetric_cyclic_tridiagonal(T *x) -> void; Vector alpha_{}; Vector gamma_{}; @@ -37,7 +37,8 @@ class CSplinePeriodic : public CSplineBase { }; template -auto CSplinePeriodic::solve_symmetric_cyclic_tridiagonal(T *x) -> void { +constexpr auto CSplinePeriodic::solve_symmetric_cyclic_tridiagonal(T *x) + -> void { const auto size = this->A_.rows(); const auto size_m1 = size - 1; const auto size_m2 = size - 2;