Skip to content

Commit

Permalink
Review and simplify nda::concatenate implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Wentzell committed Apr 3, 2024
1 parent 20b5056 commit 36cb7b5
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 36 deletions.
58 changes: 31 additions & 27 deletions c++/nda/basic_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,41 +377,45 @@ namespace nda {
}

// ------------------------------- concatenate --------------------------------------------
// slice in all dimensions but Axis
template <auto Axis, Array A>
auto all_view_except(A const &a, auto const &arg) {
auto slice_or_arg = [&arg](auto x) {
if constexpr (Axis == decltype(x)::value)
return arg;
else
return range::all;
};

return [&]<auto... Is>(std::index_sequence<Is...>) { return a(slice_or_arg(std::integral_constant<size_t, Is>{})...); }
(std::make_index_sequence<A::rank>{});
};

// numpy style concatenation
template <auto Axis, Array A0, Array... A>
/**
* Join a sequence of arrays along an existing axis.
*
* The arrays must have the same value_type and also shape,
* except in the dimension corresponding to axis (the first, by default).
*
* @tparam Axis The axis along which to concatenate (default: 0)
* @tparam A0 Type of the first array
* @tparam A Types of the subsequent arrays
* @param a0 The first array
* @param a The subsequent arrays
* @return New array with the concatenated data
*/
template <size_t Axis = 0, Array A0, Array... A>
auto concatenate(A0 const &a0, A const &...a) {
// sanity checks
static_assert(A0::rank >= Axis);
static_assert(((A0::rank == A::rank) and ... and true));
auto constexpr rank = A0::rank;
static_assert(Axis < rank);
static_assert(((rank == A::rank) and ... and true));
static_assert(((std::is_same_v<get_value_t<A0>, get_value_t<A>>) and ... and true));
for (auto ax [[maybe_unused]] : range(rank)) { EXPECTS(ax == Axis or ((a0.extent(ax) == a.extent(ax)) and ... and true)); }

for (auto const ax : range(A0::rank)) {
if (not (ax == Axis)) { assert(((a0.shape()[ax] == a.shape()[ax]) and ... and true)); }
}

// build concatenated array
// construct concatenated array
auto new_shape = a0.shape();
long offset = 0;
new_shape[Axis] = new_shape[Axis] + ((a.shape()[Axis] + ... + 0));
array<get_value_t<A0>, A0::rank> new_array(new_shape);
new_shape[Axis] = (a.extent(Axis) + ... + new_shape[Axis]);
auto new_array = array<get_value_t<A0>, rank>(new_shape);

// slicing helper function
auto slice_Axis = [](Array auto &a, range r) {
auto all_or_range = std::make_tuple(range::all, r);
return [&]<auto... Is>(std::index_sequence<Is...>) { return a(std::get<Is == Axis>(all_or_range)...); }(std::make_index_sequence<rank>{});
};

// initialize concatenated array
long offset = 0;
for (auto const &a_view : {basic_array_view(a0), basic_array_view(a)...}) {
all_view_except<Axis>(new_array, range(offset, offset + a_view.shape()[Axis])) = a_view;
offset += a_view.shape()[Axis];
slice_Axis(new_array, range(offset, offset + a_view.extent(Axis))) = a_view;
offset += a_view.extent(Axis);
}

return new_array;
Expand Down
10 changes: 1 addition & 9 deletions test/c++/nda_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -418,14 +418,6 @@ TEST(Array, Concatenate) { //NOLINT
for (int k = 0; k < 6; ++k) { c(i, j, k) = i + 10 * j + 102 * k; }
}

// test all_view_except
auto const a_view_except = all_view_except<1>(a, range(1, 3));
EXPECT_EQ(a_view_except.shape()[1], 2);

for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
for (int k = 0; k < 4; ++k) { EXPECT_EQ(a_view_except(i, j, k), a(i, j + 1, k)); }

// test concatenate
auto const abc_axis2_concat = concatenate<2>(a, b, c);
EXPECT_EQ(abc_axis2_concat.shape()[2], 15);
Expand All @@ -441,4 +433,4 @@ TEST(Array, Concatenate) { //NOLINT
EXPECT_EQ(abc_axis2_concat(i, j, k), c(i, j, k - 9));
}
}
}
}

0 comments on commit 36cb7b5

Please sign in to comment.