From 671f0a8a711253a655a65e4155bfc41bb2a5e595 Mon Sep 17 00:00:00 2001
From: Thomas Hahn
Date: Fri, 27 Sep 2024 21:54:02 -0400
Subject: [PATCH] Simplify logic in the mpi interface by replacing direct
 calls to the MPI C library with the generic range functions of the mpi
 library

---
 c++/nda/mpi/broadcast.hpp | 10 ++++----
 c++/nda/mpi/gather.hpp    | 52 ++++++++++++++-------------------------
 c++/nda/mpi/reduce.hpp    | 21 ++++++----------
 c++/nda/mpi/scatter.hpp   | 39 +++++++++--------------------
 c++/nda/mpi/utils.hpp     |  6 +++--
 5 files changed, 47 insertions(+), 81 deletions(-)

diff --git a/c++/nda/mpi/broadcast.hpp b/c++/nda/mpi/broadcast.hpp
index 1af32c69..69d12973 100644
--- a/c++/nda/mpi/broadcast.hpp
+++ b/c++/nda/mpi/broadcast.hpp
@@ -23,13 +23,13 @@
 
 #include "./utils.hpp"
 #include "../basic_functions.hpp"
-#include "../concepts.hpp"
-#include "../exceptions.hpp"
 #include "../traits.hpp"
 
-#include <mpi.h>
 #include <mpi/mpi.hpp>
+#include <cstddef>
+#include <span>
+
 namespace nda {
 
 /**
@@ -68,9 +68,9 @@
   {
     detail::check_mpi_contiguous_layout(a, "mpi_broadcast");
     auto dims = a.shape();
-    mpi::check_mpi_call(MPI_Bcast(&dims[0], dims.size(), mpi::mpi_type<long>::get(), root, comm.get()), "MPI_Bcast");
+    mpi::broadcast(dims, comm, root);
     if (comm.rank() != root) { resize_or_check_if_view(a, dims); }
-    mpi::check_mpi_call(MPI_Bcast(a.data(), a.size(), mpi::mpi_type<typename A::value_type>::get(), root, comm.get()), "MPI_Bcast");
+    mpi::broadcast_range(std::span{a.data(), static_cast<std::size_t>(a.size())}, comm, root);
   }
 
 } // namespace nda
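For reference, a minimal sketch of the user-facing call that now runs through the
simplified nda::mpi_broadcast, assuming the mpi and nda headers shown below, an MPI
launcher, and the mpi library's RAII init/finalize helper mpi::environment:

    // sketch: broadcast a 4x4 array from root to all ranks
    #include <mpi/mpi.hpp>
    #include <nda/nda.hpp>
    #include <nda/mpi.hpp>

    int main(int argc, char **argv) {
      mpi::environment env(argc, argv); // initializes and finalizes MPI
      mpi::communicator comm;           // defaults to MPI_COMM_WORLD

      auto a = nda::array<double, 2>::zeros({4, 4});
      if (comm.rank() == 0) a = 1.0; // only root holds the data

      // dispatches to nda::mpi_broadcast: the shape travels via
      // mpi::broadcast, the flat data via mpi::broadcast_range
      mpi::broadcast(a, comm, 0);
      return 0;
    }

Note that mpi::broadcast works directly on the std::array holding the shape, which
is what allows dropping the explicit MPI_Bcast of the dimensions.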
diff --git a/c++/nda/mpi/gather.hpp b/c++/nda/mpi/gather.hpp
index a96cae8b..725dc050 100644
--- a/c++/nda/mpi/gather.hpp
+++ b/c++/nda/mpi/gather.hpp
@@ -24,16 +24,17 @@
 #include "./utils.hpp"
 #include "../basic_functions.hpp"
 #include "../concepts.hpp"
-#include "../exceptions.hpp"
 #include "../stdutil/array.hpp"
 #include "../traits.hpp"
 
-#include <mpi.h>
 #include <mpi/mpi.hpp>
+#include <cstddef>
+#include <functional>
+#include <numeric>
+#include <span>
 #include <type_traits>
 #include <utility>
-#include <vector>
 
 /**
  * @ingroup av_mpi
@@ -69,6 +70,9 @@ struct mpi::lazy<mpi::tag::gather, A> {
   /// Should all processes receive the result.
   const bool all{false}; // NOLINT (const is fine here)
 
+  /// Size of the gathered array/view.
+  mutable long gathered_size{0};
+
   /**
    * @brief Compute the shape of the nda::ArrayInitializer object.
    *
@@ -83,13 +87,10 @@ struct mpi::lazy<mpi::tag::gather, A> {
    * @return Shape of the nda::ArrayInitializer object.
    */
   [[nodiscard]] auto shape() const {
-    auto dims = rhs.shape();
-    if (!all) {
-      dims[0] = mpi::reduce(dims[0], comm, root);
-      if (comm.rank() != root) dims = nda::stdutil::make_initialized_array<dims.size()>(0l);
-    } else {
-      dims[0] = mpi::all_reduce(dims[0], comm);
-    }
+    auto dims     = rhs.shape();
+    dims[0]       = mpi::all_reduce(dims[0], comm);
+    gathered_size = std::accumulate(dims.begin(), dims.end(), 1l, std::multiplies<>());
+    if (!all && comm.rank() != root) dims = nda::stdutil::make_initialized_array<dims.size()>(0l);
     return dims;
   }
 
@@ -117,8 +118,9 @@
       return;
     }
 
-    // get target shape and resize or check the target array/view
-    auto dims = shape();
+    // get target shape, resize or check the target array/view and prepare output span
+    auto dims        = shape();
+    auto target_span = std::span{target.data(), 0};
     if (all || (comm.rank() == root)) {
       // check if the target array/view can be used in the MPI call
       check_mpi_contiguous_layout(target, "mpi_gather");
@@ -126,30 +128,14 @@
 
       // resize/check the size of the target array/view
       nda::resize_or_check_if_view(target, dims);
-    }
 
-    // gather receive counts and memory displacements
-    auto recvcounts   = std::vector<int>(comm.size());
-    auto displs       = std::vector<int>(comm.size() + 1, 0);
-    int sendcount     = rhs.size();
-    auto mpi_int_type = mpi::mpi_type<int>::get();
-    if (!all)
-      mpi::check_mpi_call(MPI_Gather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, root, comm.get()), "MPI_Gather");
-    else
-      mpi::check_mpi_call(MPI_Allgather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, comm.get()), "MPI_Allgather");
-
-    for (int r = 0; r < comm.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];
+      // prepare the output span
+      target_span = std::span{target.data(), static_cast<std::size_t>(target.size())};
+    }
 
     // gather the data
-    auto mpi_value_type = mpi::mpi_type<value_type>::get();
-    if (!all)
-      mpi::check_mpi_call(
-         MPI_Gatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, root, comm.get()),
-         "MPI_Gatherv");
-    else
-      mpi::check_mpi_call(
-         MPI_Allgatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, comm.get()),
-         "MPI_Allgatherv");
+    auto rhs_span = std::span{rhs.data(), static_cast<std::size_t>(rhs.size())};
+    mpi::gather_range(rhs_span, target_span, comm, root, all, gathered_size);
   }
 };
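The new gathered_size member caches the total number of elements computed in
shape(), so invoke() no longer exchanges receive counts and displacements by hand;
mpi::gather_range derives them from the size it is given. A usage sketch, under the
same assumptions as the broadcast example above:

    // each rank contributes one row of length 3
    auto local = nda::array<long, 2>(1, 3);
    local = comm.rank();

    // mpi::gather returns the lazy object above; assigning it to an array
    // evaluates shape() and then invoke(), i.e. mpi::gather_range
    nda::array<long, 2> full = mpi::gather(local, comm, 0);
    // on root: full has shape (comm.size(), 3); on all other ranks it is empty

Replacing the mpi::reduce of the first dimension with a single mpi::all_reduce also
removes one branch: every rank now knows the full extent, and non-receiving ranks
simply reset their dims to zero afterwards.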
diff --git a/c++/nda/mpi/reduce.hpp b/c++/nda/mpi/reduce.hpp
index 879a708b..14bb7c99 100644
--- a/c++/nda/mpi/reduce.hpp
+++ b/c++/nda/mpi/reduce.hpp
@@ -33,6 +33,8 @@
 
 #include <mpi.h>
 #include <mpi/mpi.hpp>
+#include <cstddef>
+#include <span>
 #include <type_traits>
 #include <utility>
@@ -136,21 +138,12 @@
     }
 
     // reduce the data
-    void *target_ptr    = (void *)target.data();
-    void *rhs_ptr       = (void *)rhs.data();
-    auto count          = rhs.size();
-    auto mpi_value_type = mpi::mpi_type<value_type>::get();
-    if (!all) {
-      if (in_place)
-        mpi::check_mpi_call(MPI_Reduce((comm.rank() == root ? MPI_IN_PLACE : rhs_ptr), rhs_ptr, count, mpi_value_type, op, root, comm.get()),
-                            "MPI_Reduce");
-      else
-        mpi::check_mpi_call(MPI_Reduce(rhs_ptr, target_ptr, count, mpi_value_type, op, root, comm.get()), "MPI_Reduce");
+    auto target_span = std::span{target.data(), static_cast<std::size_t>(target.size())};
+    if (in_place) {
+      mpi::reduce_in_place_range(target_span, comm, root, all, op);
     } else {
-      if (in_place)
-        mpi::check_mpi_call(MPI_Allreduce(MPI_IN_PLACE, rhs_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
-      else
-        mpi::check_mpi_call(MPI_Allreduce(rhs_ptr, target_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
+      auto rhs_span = std::span{rhs.data(), static_cast<std::size_t>(rhs.size())};
+      mpi::reduce_range(rhs_span, target_span, comm, root, all, op);
     }
   }
 }

diff --git a/c++/nda/mpi/scatter.hpp b/c++/nda/mpi/scatter.hpp
index 900c744a..336130e4 100644
--- a/c++/nda/mpi/scatter.hpp
+++ b/c++/nda/mpi/scatter.hpp
@@ -21,23 +21,19 @@
 
 #pragma once
 
-#include "./broadcast.hpp"
 #include "./utils.hpp"
-#include "../basic_functions.hpp"
 #include "../concepts.hpp"
-#include "../exceptions.hpp"
-#include "../stdutil/array.hpp"
 #include "../traits.hpp"
 
 #include <mpi.h>
 #include <mpi/mpi.hpp>
-#include <algorithm>
 #include <functional>
 #include <numeric>
+#include <span>
 #include <type_traits>
 #include <utility>
-#include <vector>
 
 /**
  * @ingroup av_mpi
@@ -73,6 +69,9 @@ struct mpi::lazy<mpi::tag::scatter, A> {
   /// Should all processes receive the result. (doesn't make sense for scatter)
   const bool all{false}; // NOLINT (const is fine here)
 
+  /// Size of the array/view to be scattered.
+  mutable long scatter_size{0};
+
   /**
    * @brief Compute the shape of the nda::ArrayInitializer object.
    *
@@ -87,10 +86,10 @@ struct mpi::lazy<mpi::tag::scatter, A> {
    * @return Shape of the nda::ArrayInitializer object.
    */
   [[nodiscard]] auto shape() const {
-    auto dims   = rhs.shape();
-    auto dims_v = nda::basic_array_view(dims);
-    mpi::broadcast(dims_v, comm, root);
-    dims[0] = mpi::chunk_length(dims[0], comm.size(), comm.rank());
+    auto dims = rhs.shape();
+    mpi::broadcast(dims, comm, root);
+    scatter_size = std::accumulate(dims.begin(), dims.end(), 1l, std::multiplies<>());
+    dims[0]      = mpi::chunk_length(dims[0], comm.size(), comm.rank());
     return dims;
   }
 
@@ -121,26 +120,12 @@
 
     // get target shape and resize or check the target array/view
     auto dims = shape();
-    auto size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
    resize_or_check_if_view(target, dims);
 
-    // compute send counts, receive counts and memory displacements
-    auto sendcounts = std::vector<int>(comm.size());
-    auto displs     = std::vector<int>(comm.size() + 1, 0);
-    if (comm.rank() == root) {
-      auto dim0    = rhs.extent(0);
-      auto stride0 = rhs.indexmap().strides()[0];
-      for (int r = 0; r < comm.size(); ++r) {
-        sendcounts[r] = mpi::chunk_length(dim0, comm.size(), r) * stride0;
-        displs[r + 1] = sendcounts[r] + displs[r];
-      }
-    }
-
     // scatter the data
-    auto mpi_value_type = mpi::mpi_type<value_type>::get();
-    mpi::check_mpi_call(
-       MPI_Scatterv((void *)rhs.data(), &sendcounts[0], &displs[0], mpi_value_type, (void *)target.data(), size, mpi_value_type, root, comm.get()),
-       "MPI_Scatterv");
+    auto target_span = std::span{target.data(), static_cast<std::size_t>(target.size())};
+    auto rhs_span    = std::span{rhs.data(), static_cast<std::size_t>(rhs.size())};
+    mpi::scatter_range(rhs_span, target_span, comm, root, scatter_size, rhs.indexmap().strides()[0]);
   }
 };
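Reduce and scatter follow the same pattern: wrap the flat data in std::span and let
the mpi library do the bookkeeping. mpi::reduce_range and mpi::reduce_in_place_range
cover the four MPI_Reduce/MPI_Allreduce branches through their root/all arguments,
and mpi::scatter_range chunks the first dimension using the cached scatter_size and
the stride it is passed. A sketch, same assumptions as above:

    auto v = nda::array<double, 1>(6);
    v = 1.0;

    // element-wise reduction onto root (MPI_SUM is the default operation)
    nda::array<double, 1> total = mpi::reduce(v, comm, 0);

    // split the result along its first dimension: rank r receives
    // mpi::chunk_length(6, comm.size(), r) elements
    nda::array<double, 1> chunk = mpi::scatter(total, comm, 0);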
"../macros.hpp" -#include "../stdutil/array.hpp" #include #include +#if !defined(NDEBUG) || defined(NDA_DEBUG) +#include "../stdutil/array.hpp" +#endif + namespace nda::detail { // Check if the layout of an array/view is contiguous with positive strides, otherwise throw an exception.