diff --git a/c++/nda/basic_functions.hpp b/c++/nda/basic_functions.hpp index 558266413..450ec6fe5 100644 --- a/c++/nda/basic_functions.hpp +++ b/c++/nda/basic_functions.hpp @@ -376,4 +376,44 @@ namespace nda { return std::make_tuple(n_blocks, block_size, block_str); } + // ------------------------------- concatenate -------------------------------------------- + // slice in all dimensions but Axis + template + auto all_view_except(A const &a, auto const &arg) { + auto slice_or_arg = [&arg](auto x) { + if constexpr (Axis == decltype(x)::value) + return arg; + else + return range::all; + }; + + return [&](std::index_sequence) { return a(slice_or_arg(std::integral_constant{})...); } + (std::make_index_sequence{}); + }; + + // numpy style concatenation + template + auto concatenate(A0 const &a0, A const &...a) { + // sanity checks + static_assert(A0::rank >= Axis); + static_assert(((A0::rank == A::rank) and ... and true)); + static_assert(((std::is_same_v, get_value_t>) and ... and true)); + + for (auto const ax : range(A0::rank)) { + if (not (ax == Axis)) { assert(((a0.shape()[ax] == a.shape()[ax]) and ... and true)); } + } + + // build concatenated array + auto new_shape = a0.shape(); + long offset = 0; + new_shape[Axis] = new_shape[Axis] + ((a.shape()[Axis] + ... + 0)); + array, A0::rank> new_array(new_shape); + + for (auto const &a_view : {basic_array_view(a0), basic_array_view(a)...}) { + all_view_except(new_array, range(offset, offset + a_view.shape()[Axis])) = a_view; + offset += a_view.shape()[Axis]; + } + + return new_array; + }; } // namespace nda diff --git a/test/c++/nda_basic.cpp b/test/c++/nda_basic.cpp index 09194b8e4..3bf133eea 100644 --- a/test/c++/nda_basic.cpp +++ b/test/c++/nda_basic.cpp @@ -401,3 +401,44 @@ TEST(Assign, CrossStrideOrder) { //NOLINT for (int j = 0; j < 3; ++j) for (int k = 0; k < 4; ++k) { EXPECT_EQ(af(i, j, k), i + 10 * j + 100 * k); } } + +// ============================================================= + +TEST(Array, Concatenate) { //NOLINT + + // some dummy arrays + nda::array a(2, 3, 4); + nda::array b(2, 3, 5); + nda::array c(2, 3, 6); + + for (int i = 0; i < 2; ++i) + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 4; ++k) { a(i, j, k) = i + 10 * j + 100 * k; } + for (int k = 0; k < 5; ++k) { b(i, j, k) = i + 10 * j + 101 * k; } + for (int k = 0; k < 6; ++k) { c(i, j, k) = i + 10 * j + 102 * k; } + } + + // test all_view_except + auto const a_view_except = all_view_except<1>(a, range(1, 3)); + EXPECT_EQ(a_view_except.shape()[1], 2); + + for (int i = 0; i < 2; ++i) + for (int j = 0; j < 2; ++j) + for (int k = 0; k < 4; ++k) { EXPECT_EQ(a_view_except(i, j, k), a(i, j + 1, k)); } + + // test concatenate + auto const abc_axis2_concat = concatenate<2>(a, b, c); + EXPECT_EQ(abc_axis2_concat.shape()[2], 15); + + for (int i = 0; i < 2; ++i) + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 15; ++k) { + if (k < 4) + EXPECT_EQ(abc_axis2_concat(i, j, k), a(i, j, k)); + else if (k < 9) + EXPECT_EQ(abc_axis2_concat(i, j, k), b(i, j, k - 4)); + else if (k < 15) + EXPECT_EQ(abc_axis2_concat(i, j, k), c(i, j, k - 9)); + } + } +} \ No newline at end of file