From 1b0bc3f8c7ccf88fdca65167612c2bbc7186736d Mon Sep 17 00:00:00 2001
From: Thomas Hahn
Date: Fri, 27 Sep 2024 20:34:06 -0400
Subject: [PATCH] Use check_mpi_call from mpi library and remove local implementation

---
 c++/nda/mpi/broadcast.hpp |  5 ++---
 c++/nda/mpi/gather.hpp    |  8 ++++----
 c++/nda/mpi/reduce.hpp    | 10 +++++-----
 c++/nda/mpi/scatter.hpp   |  2 +-
 c++/nda/mpi/utils.hpp     | 21 ++++++++++-----------
 test/c++/nda_mpi.cpp      |  7 -------
 6 files changed, 22 insertions(+), 31 deletions(-)

diff --git a/c++/nda/mpi/broadcast.hpp b/c++/nda/mpi/broadcast.hpp
index d932c8e4..1af32c69 100644
--- a/c++/nda/mpi/broadcast.hpp
+++ b/c++/nda/mpi/broadcast.hpp
@@ -68,10 +68,9 @@ namespace nda {
   {
     detail::check_mpi_contiguous_layout(a, "mpi_broadcast");
     auto dims = a.shape();
-    detail::check_mpi_call(MPI_Bcast(&dims[0], dims.size(), mpi::mpi_type::get(), root, comm.get()),
-                           "MPI_Bcast");
+    mpi::check_mpi_call(MPI_Bcast(&dims[0], dims.size(), mpi::mpi_type::get(), root, comm.get()), "MPI_Bcast");
     if (comm.rank() != root) { resize_or_check_if_view(a, dims); }
-    detail::check_mpi_call(MPI_Bcast(a.data(), a.size(), mpi::mpi_type::get(), root, comm.get()), "MPI_Bcast");
+    mpi::check_mpi_call(MPI_Bcast(a.data(), a.size(), mpi::mpi_type::get(), root, comm.get()), "MPI_Bcast");
   }

 } // namespace nda
diff --git a/c++/nda/mpi/gather.hpp b/c++/nda/mpi/gather.hpp
index 5727303c..a96cae8b 100644
--- a/c++/nda/mpi/gather.hpp
+++ b/c++/nda/mpi/gather.hpp
@@ -134,20 +134,20 @@ struct mpi::lazy {
      int sendcount = rhs.size();
      auto mpi_int_type = mpi::mpi_type::get();
      if (!all)
-       check_mpi_call(MPI_Gather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, root, comm.get()), "MPI_Gather");
+       mpi::check_mpi_call(MPI_Gather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, root, comm.get()), "MPI_Gather");
      else
-       check_mpi_call(MPI_Allgather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, comm.get()), "MPI_Allgather");
+       mpi::check_mpi_call(MPI_Allgather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, comm.get()), "MPI_Allgather");
      for (int r = 0; r < comm.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];

      // gather the data
      auto mpi_value_type = mpi::mpi_type::get();
      if (!all)
-       check_mpi_call(
+       mpi::check_mpi_call(
          MPI_Gatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, root, comm.get()),
          "MPI_Gatherv");
      else
-       check_mpi_call(
+       mpi::check_mpi_call(
          MPI_Allgatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, comm.get()),
          "MPI_Allgatherv");
    }
diff --git a/c++/nda/mpi/reduce.hpp b/c++/nda/mpi/reduce.hpp
index 55c05f25..879a708b 100644
--- a/c++/nda/mpi/reduce.hpp
+++ b/c++/nda/mpi/reduce.hpp
@@ -142,15 +142,15 @@ struct mpi::lazy {
      auto mpi_value_type = mpi::mpi_type::get();
      if (!all) {
        if (in_place)
-         check_mpi_call(MPI_Reduce((comm.rank() == root ? MPI_IN_PLACE : rhs_ptr), rhs_ptr, count, mpi_value_type, op, root, comm.get()),
-                        "MPI_Reduce");
+         mpi::check_mpi_call(MPI_Reduce((comm.rank() == root ? MPI_IN_PLACE : rhs_ptr), rhs_ptr, count, mpi_value_type, op, root, comm.get()),
+                             "MPI_Reduce");
        else
-         check_mpi_call(MPI_Reduce(rhs_ptr, target_ptr, count, mpi_value_type, op, root, comm.get()), "MPI_Reduce");
+         mpi::check_mpi_call(MPI_Reduce(rhs_ptr, target_ptr, count, mpi_value_type, op, root, comm.get()), "MPI_Reduce");
      } else {
        if (in_place)
-         check_mpi_call(MPI_Allreduce(MPI_IN_PLACE, rhs_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
+         mpi::check_mpi_call(MPI_Allreduce(MPI_IN_PLACE, rhs_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
        else
-         check_mpi_call(MPI_Allreduce(rhs_ptr, target_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
+         mpi::check_mpi_call(MPI_Allreduce(rhs_ptr, target_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
      }
    }
  }
diff --git a/c++/nda/mpi/scatter.hpp b/c++/nda/mpi/scatter.hpp
index 37df53dc..900c744a 100644
--- a/c++/nda/mpi/scatter.hpp
+++ b/c++/nda/mpi/scatter.hpp
@@ -138,7 +138,7 @@ struct mpi::lazy {

      // scatter the data
      auto mpi_value_type = mpi::mpi_type::get();
-     check_mpi_call(
+     mpi::check_mpi_call(
        MPI_Scatterv((void *)rhs.data(), &sendcounts[0], &displs[0], mpi_value_type, (void *)target.data(), size, mpi_value_type, root, comm.get()),
        "MPI_Scatterv");
    }
diff --git a/c++/nda/mpi/utils.hpp b/c++/nda/mpi/utils.hpp
index 4247d48f..842f7ca8 100644
--- a/c++/nda/mpi/utils.hpp
+++ b/c++/nda/mpi/utils.hpp
@@ -32,11 +32,6 @@

 namespace nda::detail {

-  // Check the success of an MPI call, otherwise throw an exception.
-  inline void check_mpi_call(int errcode, const std::string &mpi_routine) {
-    if (errcode != MPI_SUCCESS) NDA_RUNTIME_ERROR << "MPI error " << errcode << " in MPI routine " << mpi_routine;
-  }
-
   // Check if the layout of an array/view is contiguous with positive strides, otherwise throw an exception.
   template void check_mpi_contiguous_layout(const A &a, const std::string &func) {
@@ -58,8 +53,8 @@ namespace nda::detail {
     int rank = a.rank;
     int max_rank = 0;
     int min_rank = 0;
-    check_mpi_call(MPI_Allreduce(&rank, &max_rank, 1, mpi::mpi_type::get(), MPI_MAX, comm.get()), "MPI_Allreduce");
-    check_mpi_call(MPI_Allreduce(&rank, &min_rank, 1, mpi::mpi_type::get(), MPI_MIN, comm.get()), "MPI_Allreduce");
+    mpi::check_mpi_call(MPI_Allreduce(&rank, &max_rank, 1, mpi::mpi_type::get(), MPI_MAX, comm.get()), "MPI_Allreduce");
+    mpi::check_mpi_call(MPI_Allreduce(&rank, &min_rank, 1, mpi::mpi_type::get(), MPI_MIN, comm.get()), "MPI_Allreduce");
     if (max_rank != min_rank) NDA_RUNTIME_ERROR << "Error in expect_equal_ranks: nda::Array ranks are not equal on all processes";
 #endif
   }
@@ -73,8 +68,10 @@ namespace nda::detail {
     auto shape = a.shape();
     auto max_shape = shape;
     auto min_shape = shape;
-    check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MAX, comm.get()), "MPI_Allreduce");
-    check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MIN, comm.get()), "MPI_Allreduce");
+    mpi::check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MAX, comm.get()),
+                        "MPI_Allreduce");
+    mpi::check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MIN, comm.get()),
+                        "MPI_Allreduce");
     if (max_shape != min_shape) NDA_RUNTIME_ERROR << "Error in expect_equal_shapes: nda::Array shapes are not equal on all processes";
 #endif
   }
@@ -90,8 +87,10 @@ namespace nda::detail {
     auto shape = nda::stdutil::front_pop(a.shape());
     auto max_shape = shape;
     auto min_shape = shape;
-    check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MAX, comm.get()), "MPI_Allreduce");
-    check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MIN, comm.get()), "MPI_Allreduce");
+    mpi::check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MAX, comm.get()),
+                        "MPI_Allreduce");
+    mpi::check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MIN, comm.get()),
+                        "MPI_Allreduce");
     if (max_shape != min_shape)
       NDA_RUNTIME_ERROR << "Error in expect_equal_shapes_save_first: nda::Array shapes are not equal on all processes except for the first dimension";
 #endif
diff --git a/test/c++/nda_mpi.cpp b/test/c++/nda_mpi.cpp
index 615a46b8..2b12635a 100644
--- a/test/c++/nda_mpi.cpp
+++ b/test/c++/nda_mpi.cpp
@@ -52,13 +52,6 @@ struct NDAMpi : public ::testing::Test {
   mpi::communicator comm;
 };

-TEST_F(NDAMpi, CheckMPICall) {
-  // test if check_mpi_call throws an exception
-  try {
-    nda::detail::check_mpi_call(MPI_SUCCESS - 1, "not_a_real_mpi_call");
-  } catch (nda::runtime_error const &e) { std::cout << e.what() << std::endl; }
-}
-
 TEST_F(NDAMpi, ExpectEqualRanks) {
   // test the expect_equal_ranks function
   if (comm.size() > 1) {
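
For context, the pattern the patch standardizes on is: wrap every raw MPI call in a checker that throws when the return code is not MPI_SUCCESS, now using the helper provided by the mpi library instead of the file-local one. The snippet below is a minimal, self-contained sketch of that pattern, not part of the patch. It mirrors the removed nda::detail::check_mpi_call (with std::runtime_error standing in for NDA_RUNTIME_ERROR) and only assumes that mpi::check_mpi_call takes the same (error code, routine name) arguments seen at the call sites above.

// check_mpi_call_sketch.cpp -- illustrative only, not part of the patch.
#include <mpi.h>
#include <stdexcept>
#include <string>

// Mirrors the removed nda::detail::check_mpi_call: throw if an MPI routine
// did not return MPI_SUCCESS (std::runtime_error replaces NDA_RUNTIME_ERROR).
inline void check_mpi_call(int errcode, const std::string &mpi_routine) {
  if (errcode != MPI_SUCCESS)
    throw std::runtime_error("MPI error " + std::to_string(errcode) + " in MPI routine " + mpi_routine);
}

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);
  int value = 0;
  // Wrap a collective call exactly like the patched nda code wraps MPI_Bcast and friends.
  check_mpi_call(MPI_Bcast(&value, 1, MPI_INT, /*root=*/0, MPI_COMM_WORLD), "MPI_Bcast");
  MPI_Finalize();
}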