From 36cb7b56d28874e09bcf56fa6258e7a8d993bc2e Mon Sep 17 00:00:00 2001 From: Nils Wentzell Date: Wed, 3 Apr 2024 17:03:47 -0400 Subject: [PATCH] Review and simplify nda::concatenate implementation --- c++/nda/basic_functions.hpp | 58 ++++++++++++++++++++----------------- test/c++/nda_basic.cpp | 10 +------ 2 files changed, 32 insertions(+), 36 deletions(-) diff --git a/c++/nda/basic_functions.hpp b/c++/nda/basic_functions.hpp index 450ec6fe5..fdbae4a90 100644 --- a/c++/nda/basic_functions.hpp +++ b/c++/nda/basic_functions.hpp @@ -377,41 +377,45 @@ namespace nda { } // ------------------------------- concatenate -------------------------------------------- - // slice in all dimensions but Axis - template - auto all_view_except(A const &a, auto const &arg) { - auto slice_or_arg = [&arg](auto x) { - if constexpr (Axis == decltype(x)::value) - return arg; - else - return range::all; - }; - - return [&](std::index_sequence) { return a(slice_or_arg(std::integral_constant{})...); } - (std::make_index_sequence{}); - }; - // numpy style concatenation - template + /** + * Join a sequence of arrays along an existing axis. + * + * The arrays must have the same value_type and also shape, + * except in the dimension corresponding to axis (the first, by default). + * + * @tparam Axis The axis along which to concatenate (default: 0) + * @tparam A0 Type of the first array + * @tparam A Types of the subsequent arrays + * @param a0 The first array + * @param a The subsequent arrays + * @return New array with the concatenated data + */ + template auto concatenate(A0 const &a0, A const &...a) { // sanity checks - static_assert(A0::rank >= Axis); - static_assert(((A0::rank == A::rank) and ... and true)); + auto constexpr rank = A0::rank; + static_assert(Axis < rank); + static_assert(((rank == A::rank) and ... and true)); static_assert(((std::is_same_v, get_value_t>) and ... and true)); + for (auto ax [[maybe_unused]] : range(rank)) { EXPECTS(ax == Axis or ((a0.extent(ax) == a.extent(ax)) and ... and true)); } - for (auto const ax : range(A0::rank)) { - if (not (ax == Axis)) { assert(((a0.shape()[ax] == a.shape()[ax]) and ... and true)); } - } - - // build concatenated array + // construct concatenated array auto new_shape = a0.shape(); - long offset = 0; - new_shape[Axis] = new_shape[Axis] + ((a.shape()[Axis] + ... + 0)); - array, A0::rank> new_array(new_shape); + new_shape[Axis] = (a.extent(Axis) + ... + new_shape[Axis]); + auto new_array = array, rank>(new_shape); + + // slicing helper function + auto slice_Axis = [](Array auto &a, range r) { + auto all_or_range = std::make_tuple(range::all, r); + return [&](std::index_sequence) { return a(std::get(all_or_range)...); }(std::make_index_sequence{}); + }; + // initialize concatenated array + long offset = 0; for (auto const &a_view : {basic_array_view(a0), basic_array_view(a)...}) { - all_view_except(new_array, range(offset, offset + a_view.shape()[Axis])) = a_view; - offset += a_view.shape()[Axis]; + slice_Axis(new_array, range(offset, offset + a_view.extent(Axis))) = a_view; + offset += a_view.extent(Axis); } return new_array; diff --git a/test/c++/nda_basic.cpp b/test/c++/nda_basic.cpp index 3bf133eea..2d00e7d3e 100644 --- a/test/c++/nda_basic.cpp +++ b/test/c++/nda_basic.cpp @@ -418,14 +418,6 @@ TEST(Array, Concatenate) { //NOLINT for (int k = 0; k < 6; ++k) { c(i, j, k) = i + 10 * j + 102 * k; } } - // test all_view_except - auto const a_view_except = all_view_except<1>(a, range(1, 3)); - EXPECT_EQ(a_view_except.shape()[1], 2); - - for (int i = 0; i < 2; ++i) - for (int j = 0; j < 2; ++j) - for (int k = 0; k < 4; ++k) { EXPECT_EQ(a_view_except(i, j, k), a(i, j + 1, k)); } - // test concatenate auto const abc_axis2_concat = concatenate<2>(a, b, c); EXPECT_EQ(abc_axis2_concat.shape()[2], 15); @@ -441,4 +433,4 @@ TEST(Array, Concatenate) { //NOLINT EXPECT_EQ(abc_axis2_concat(i, j, k), c(i, j, k - 9)); } } -} \ No newline at end of file +}