Skip to content

Commit

Permalink
Use mpi::all_equal to simplify checks in the mpi interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Thoemi09 committed Sep 28, 2024
1 parent 1b0bc3f commit 23f7581
Showing 1 changed file with 3 additions and 25 deletions.
28 changes: 3 additions & 25 deletions c++/nda/mpi/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,30 +49,16 @@ namespace nda::detail {
template <nda::Array A>
void expect_equal_ranks([[maybe_unused]] const A &a, [[maybe_unused]] const mpi::communicator &comm, [[maybe_unused]] int root) {
#if !defined(NDEBUG) || defined(NDA_DEBUG)
if (not mpi::has_env) return;
int rank = a.rank;
int max_rank = 0;
int min_rank = 0;
mpi::check_mpi_call(MPI_Allreduce(&rank, &max_rank, 1, mpi::mpi_type<int>::get(), MPI_MAX, comm.get()), "MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(&rank, &min_rank, 1, mpi::mpi_type<int>::get(), MPI_MIN, comm.get()), "MPI_Allreduce");
if (max_rank != min_rank) NDA_RUNTIME_ERROR << "Error in expect_equal_ranks: nda::Array ranks are not equal on all processes";
if (!mpi::all_equal(a.rank, comm)) NDA_RUNTIME_ERROR << "Error in expect_equal_ranks: nda::Array ranks are not equal on all processes";
#endif
}

// Check if the shape of arrays/views are the same on all processes, otherwise throw an exception.
template <nda::Array A>
void expect_equal_shapes([[maybe_unused]] const A &a, [[maybe_unused]] const mpi::communicator &comm, [[maybe_unused]] int root) {
#if !defined(NDEBUG) || defined(NDA_DEBUG)
if (not mpi::has_env) return;
expect_equal_ranks(a, comm, root);
auto shape = a.shape();
auto max_shape = shape;
auto min_shape = shape;
mpi::check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MAX, comm.get()),
"MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MIN, comm.get()),
"MPI_Allreduce");
if (max_shape != min_shape) NDA_RUNTIME_ERROR << "Error in expect_equal_shapes: nda::Array shapes are not equal on all processes";
if (!mpi::all_equal(a.shape(), comm)) NDA_RUNTIME_ERROR << "Error in expect_equal_shapes: nda::Array shapes are not equal on all processes";
#endif
}

Expand All @@ -82,16 +68,8 @@ namespace nda::detail {
void expect_equal_shapes_save_first([[maybe_unused]] const A &a, [[maybe_unused]] const mpi::communicator &comm, [[maybe_unused]] int root) {
#if !defined(NDEBUG) || defined(NDA_DEBUG)
if constexpr (A::rank == 1) return;
if (not mpi::has_env) return;
expect_equal_ranks(a, comm, root);
auto shape = nda::stdutil::front_pop(a.shape());
auto max_shape = shape;
auto min_shape = shape;
mpi::check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MAX, comm.get()),
"MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MIN, comm.get()),
"MPI_Allreduce");
if (max_shape != min_shape)
if (!mpi::all_equal(nda::stdutil::front_pop(a.shape()), comm))
NDA_RUNTIME_ERROR << "Error in expect_equal_shapes_save_first: nda::Array shapes are not equal on all processes except for the first dimension";
#endif
}
Expand Down

0 comments on commit 23f7581

Please sign in to comment.