Skip to content

Commit

Permalink
added concatenate impl
Browse files Browse the repository at this point in the history
  • Loading branch information
dominikkiese committed Mar 30, 2024
1 parent 6ef40ec commit dbbfdf5
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 0 deletions.
40 changes: 40 additions & 0 deletions c++/nda/basic_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,4 +376,44 @@ namespace nda {
return std::make_tuple(n_blocks, block_size, block_str);
}

// ------------------------------- concatenate --------------------------------------------
// slice in all dimensions but Axis
template <auto Axis, Array A>
auto all_view_except(A const &a, auto const &arg) {
auto slice_or_arg = [&arg](auto x) {
if constexpr (Axis == decltype(x)::value)
return arg;
else
return range::all;
};

return [&]<auto... Is>(std::index_sequence<Is...>) { return a(slice_or_arg(std::integral_constant<size_t, Is>{})...); }
(std::make_index_sequence<A::rank>{});
};

// numpy style concatenation
template <auto Axis, Array A0, Array... A>
auto concatenate(A0 const &a0, A const &...a) {
// sanity checks
static_assert(A0::rank >= Axis);
static_assert(((A0::rank == A::rank) and ... and true));
static_assert(((std::is_same_v<get_value_t<A0>, get_value_t<A>>) and ... and true));

for (auto const ax : range(A0::rank)) {
if (not (ax == Axis)) { assert(((a0.shape()[ax] == a.shape()[ax]) and ... and true)); }
}

// build concatenated array
auto new_shape = a0.shape();
long offset = 0;
new_shape[Axis] = new_shape[Axis] + ((a.shape()[Axis] + ... + 0));
array<get_value_t<A0>, A0::rank> new_array(new_shape);

for (auto const &a_view : {basic_array_view(a0), basic_array_view(a)...}) {
all_view_except<Axis>(new_array, range(offset, offset + a_view.shape()[Axis])) = a_view;
offset += a_view.shape()[Axis];
}

return new_array;
};
} // namespace nda
41 changes: 41 additions & 0 deletions test/c++/nda_basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,44 @@ TEST(Assign, CrossStrideOrder) { //NOLINT
for (int j = 0; j < 3; ++j)
for (int k = 0; k < 4; ++k) { EXPECT_EQ(af(i, j, k), i + 10 * j + 100 * k); }
}

// =============================================================

TEST(Array, Concatenate) { //NOLINT

// some dummy arrays
nda::array<long, 3> a(2, 3, 4);
nda::array<long, 3> b(2, 3, 5);
nda::array<long, 3> c(2, 3, 6);

for (int i = 0; i < 2; ++i)
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 4; ++k) { a(i, j, k) = i + 10 * j + 100 * k; }
for (int k = 0; k < 5; ++k) { b(i, j, k) = i + 10 * j + 101 * k; }
for (int k = 0; k < 6; ++k) { c(i, j, k) = i + 10 * j + 102 * k; }
}

// test all_view_except
auto const a_view_except = all_view_except<1>(a, range(1, 3));
EXPECT_EQ(a_view_except.shape()[1], 2);

for (int i = 0; i < 2; ++i)
for (int j = 0; j < 2; ++j)
for (int k = 0; k < 4; ++k) { EXPECT_EQ(a_view_except(i, j, k), a(i, j + 1, k)); }

// test concatenate
auto const abc_axis2_concat = concatenate<2>(a, b, c);
EXPECT_EQ(abc_axis2_concat.shape()[2], 15);

for (int i = 0; i < 2; ++i)
for (int j = 0; j < 3; ++j) {
for (int k = 0; k < 15; ++k) {
if (k < 4)
EXPECT_EQ(abc_axis2_concat(i, j, k), a(i, j, k));
else if (k < 9)
EXPECT_EQ(abc_axis2_concat(i, j, k), b(i, j, k - 4));
else if (k < 15)
EXPECT_EQ(abc_axis2_concat(i, j, k), c(i, j, k - 9));
}
}
}

0 comments on commit dbbfdf5

Please sign in to comment.