Skip to content

Commit

Permalink
Use check_mpi_call from mpi library and remove local implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Thoemi09 committed Sep 28, 2024
1 parent eccc56b commit 1b0bc3f
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 31 deletions.
5 changes: 2 additions & 3 deletions c++/nda/mpi/broadcast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ namespace nda {
{
detail::check_mpi_contiguous_layout(a, "mpi_broadcast");
auto dims = a.shape();
detail::check_mpi_call(MPI_Bcast(&dims[0], dims.size(), mpi::mpi_type<typename decltype(dims)::value_type>::get(), root, comm.get()),
"MPI_Bcast");
mpi::check_mpi_call(MPI_Bcast(&dims[0], dims.size(), mpi::mpi_type<typename decltype(dims)::value_type>::get(), root, comm.get()), "MPI_Bcast");
if (comm.rank() != root) { resize_or_check_if_view(a, dims); }
detail::check_mpi_call(MPI_Bcast(a.data(), a.size(), mpi::mpi_type<typename A::value_type>::get(), root, comm.get()), "MPI_Bcast");
mpi::check_mpi_call(MPI_Bcast(a.data(), a.size(), mpi::mpi_type<typename A::value_type>::get(), root, comm.get()), "MPI_Bcast");
}

} // namespace nda
8 changes: 4 additions & 4 deletions c++/nda/mpi/gather.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,20 +134,20 @@ struct mpi::lazy<mpi::tag::gather, A> {
int sendcount = rhs.size();
auto mpi_int_type = mpi::mpi_type<int>::get();
if (!all)
check_mpi_call(MPI_Gather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, root, comm.get()), "MPI_Gather");
mpi::check_mpi_call(MPI_Gather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, root, comm.get()), "MPI_Gather");
else
check_mpi_call(MPI_Allgather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, comm.get()), "MPI_Allgather");
mpi::check_mpi_call(MPI_Allgather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, comm.get()), "MPI_Allgather");

for (int r = 0; r < comm.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];

// gather the data
auto mpi_value_type = mpi::mpi_type<value_type>::get();
if (!all)
check_mpi_call(
mpi::check_mpi_call(
MPI_Gatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, root, comm.get()),
"MPI_Gatherv");
else
check_mpi_call(
mpi::check_mpi_call(
MPI_Allgatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, comm.get()),
"MPI_Allgatherv");
}
Expand Down
10 changes: 5 additions & 5 deletions c++/nda/mpi/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,15 @@ struct mpi::lazy<mpi::tag::reduce, A> {
auto mpi_value_type = mpi::mpi_type<value_type>::get();
if (!all) {
if (in_place)
check_mpi_call(MPI_Reduce((comm.rank() == root ? MPI_IN_PLACE : rhs_ptr), rhs_ptr, count, mpi_value_type, op, root, comm.get()),
"MPI_Reduce");
mpi::check_mpi_call(MPI_Reduce((comm.rank() == root ? MPI_IN_PLACE : rhs_ptr), rhs_ptr, count, mpi_value_type, op, root, comm.get()),
"MPI_Reduce");
else
check_mpi_call(MPI_Reduce(rhs_ptr, target_ptr, count, mpi_value_type, op, root, comm.get()), "MPI_Reduce");
mpi::check_mpi_call(MPI_Reduce(rhs_ptr, target_ptr, count, mpi_value_type, op, root, comm.get()), "MPI_Reduce");
} else {
if (in_place)
check_mpi_call(MPI_Allreduce(MPI_IN_PLACE, rhs_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(MPI_IN_PLACE, rhs_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
else
check_mpi_call(MPI_Allreduce(rhs_ptr, target_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(rhs_ptr, target_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion c++/nda/mpi/scatter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ struct mpi::lazy<mpi::tag::scatter, A> {

// scatter the data
auto mpi_value_type = mpi::mpi_type<value_type>::get();
check_mpi_call(
mpi::check_mpi_call(
MPI_Scatterv((void *)rhs.data(), &sendcounts[0], &displs[0], mpi_value_type, (void *)target.data(), size, mpi_value_type, root, comm.get()),
"MPI_Scatterv");
}
Expand Down
21 changes: 10 additions & 11 deletions c++/nda/mpi/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,6 @@

namespace nda::detail {

// Check the success of an MPI call, otherwise throw an exception.
inline void check_mpi_call(int errcode, const std::string &mpi_routine) {
if (errcode != MPI_SUCCESS) NDA_RUNTIME_ERROR << "MPI error " << errcode << " in MPI routine " << mpi_routine;
}

// Check if the layout of an array/view is contiguous with positive strides, otherwise throw an exception.
template <nda::Array A>
void check_mpi_contiguous_layout(const A &a, const std::string &func) {
Expand All @@ -58,8 +53,8 @@ namespace nda::detail {
int rank = a.rank;
int max_rank = 0;
int min_rank = 0;
check_mpi_call(MPI_Allreduce(&rank, &max_rank, 1, mpi::mpi_type<int>::get(), MPI_MAX, comm.get()), "MPI_Allreduce");
check_mpi_call(MPI_Allreduce(&rank, &min_rank, 1, mpi::mpi_type<int>::get(), MPI_MIN, comm.get()), "MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(&rank, &max_rank, 1, mpi::mpi_type<int>::get(), MPI_MAX, comm.get()), "MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(&rank, &min_rank, 1, mpi::mpi_type<int>::get(), MPI_MIN, comm.get()), "MPI_Allreduce");
if (max_rank != min_rank) NDA_RUNTIME_ERROR << "Error in expect_equal_ranks: nda::Array ranks are not equal on all processes";
#endif
}
Expand All @@ -73,8 +68,10 @@ namespace nda::detail {
auto shape = a.shape();
auto max_shape = shape;
auto min_shape = shape;
check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MAX, comm.get()), "MPI_Allreduce");
check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MIN, comm.get()), "MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MAX, comm.get()),
"MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MIN, comm.get()),
"MPI_Allreduce");
if (max_shape != min_shape) NDA_RUNTIME_ERROR << "Error in expect_equal_shapes: nda::Array shapes are not equal on all processes";
#endif
}
Expand All @@ -90,8 +87,10 @@ namespace nda::detail {
auto shape = nda::stdutil::front_pop(a.shape());
auto max_shape = shape;
auto min_shape = shape;
check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MAX, comm.get()), "MPI_Allreduce");
check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MIN, comm.get()), "MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MAX, comm.get()),
"MPI_Allreduce");
mpi::check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type<long>::get(), MPI_MIN, comm.get()),
"MPI_Allreduce");
if (max_shape != min_shape)
NDA_RUNTIME_ERROR << "Error in expect_equal_shapes_save_first: nda::Array shapes are not equal on all processes except for the first dimension";
#endif
Expand Down
7 changes: 0 additions & 7 deletions test/c++/nda_mpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ struct NDAMpi : public ::testing::Test {
mpi::communicator comm;
};

TEST_F(NDAMpi, CheckMPICall) {
  // check_mpi_call must throw an nda::runtime_error for any error code != MPI_SUCCESS;
  // fail explicitly if no exception is thrown (the original test passed silently in that case)
  try {
    nda::detail::check_mpi_call(MPI_SUCCESS - 1, "not_a_real_mpi_call");
    FAIL() << "check_mpi_call did not throw for a non-success error code";
  } catch (nda::runtime_error const &e) { std::cout << e.what() << std::endl; }
}

TEST_F(NDAMpi, ExpectEqualRanks) {
// test the expect_equal_ranks function
if (comm.size() > 1) {
Expand Down

0 comments on commit 1b0bc3f

Please sign in to comment.