Simplify logic in mpi interface by replacing direct calls to the MPI C library
Thoemi09 committed Sep 28, 2024
1 parent 23f7581 commit 671f0a8
Showing 5 changed files with 47 additions and 81 deletions.
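
Throughout the commit, direct calls into the MPI C API wrapped in mpi::check_mpi_call are replaced by the corresponding high-level functions of the mpi library, most of them range-based helpers operating on a std::span. In broadcast.hpp, for example, the data broadcast changes from

  // before: direct call into the MPI C library
  mpi::check_mpi_call(MPI_Bcast(a.data(), a.size(), mpi::mpi_type<typename A::value_type>::get(), root, comm.get()), "MPI_Bcast");

to

  // after: range-based wrapper from the mpi library
  mpi::broadcast_range(std::span{a.data(), static_cast<std::size_t>(a.size())}, comm, root);
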
10 changes: 5 additions & 5 deletions c++/nda/mpi/broadcast.hpp
@@ -23,13 +23,13 @@

#include "./utils.hpp"
#include "../basic_functions.hpp"
#include "../concepts.hpp"
#include "../exceptions.hpp"
#include "../traits.hpp"

#include <mpi.h>
#include <mpi/mpi.hpp>

#include <cstddef>
#include <span>

namespace nda {

/**
@@ -68,9 +68,9 @@ namespace nda {
{
detail::check_mpi_contiguous_layout(a, "mpi_broadcast");
auto dims = a.shape();
mpi::check_mpi_call(MPI_Bcast(&dims[0], dims.size(), mpi::mpi_type<typename decltype(dims)::value_type>::get(), root, comm.get()), "MPI_Bcast");
mpi::broadcast(dims, comm, root);
if (comm.rank() != root) { resize_or_check_if_view(a, dims); }
mpi::check_mpi_call(MPI_Bcast(a.data(), a.size(), mpi::mpi_type<typename A::value_type>::get(), root, comm.get()), "MPI_Bcast");
mpi::broadcast_range(std::span{a.data(), static_cast<std::size_t>(a.size())}, comm, root);
}

} // namespace nda
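
For illustration, a minimal standalone sketch of the range-based call used above. It assumes the mpi library headers from the diff, with mpi::environment and mpi::communicator as the usual setup classes; the signature mpi::broadcast_range(span, comm, root) is taken from the diff, and the buffer itself is hypothetical.

#include <mpi/mpi.hpp>

#include <numeric>
#include <span>
#include <vector>

int main(int argc, char **argv) {
  mpi::environment env(argc, argv); // set up and tear down MPI
  mpi::communicator comm;           // world communicator

  // fill the buffer on the root rank only, then broadcast the contiguous range to all ranks
  std::vector<double> v(100, 0.0);
  if (comm.rank() == 0) std::iota(v.begin(), v.end(), 0.0);
  mpi::broadcast_range(std::span{v.data(), v.size()}, comm, 0);
}
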
52 changes: 19 additions & 33 deletions c++/nda/mpi/gather.hpp
@@ -24,16 +24,17 @@
#include "./utils.hpp"
#include "../basic_functions.hpp"
#include "../concepts.hpp"
#include "../exceptions.hpp"
#include "../stdutil/array.hpp"
#include "../traits.hpp"

#include <mpi.h>
#include <mpi/mpi.hpp>

#include <cstddef>
#include <functional>
#include <numeric>
#include <span>
#include <type_traits>
#include <utility>
#include <vector>

/**
* @ingroup av_mpi
@@ -69,6 +70,9 @@ struct mpi::lazy<mpi::tag::gather, A> {
/// Should all processes receive the result.
const bool all{false}; // NOLINT (const is fine here)

/// Size of the gathered array/view.
mutable long gathered_size{0};

/**
* @brief Compute the shape of the nda::ArrayInitializer object.
*
@@ -83,13 +87,10 @@
* @return Shape of the nda::ArrayInitializer object.
*/
[[nodiscard]] auto shape() const {
auto dims = rhs.shape();
if (!all) {
dims[0] = mpi::reduce(dims[0], comm, root);
if (comm.rank() != root) dims = nda::stdutil::make_initialized_array<dims.size()>(0l);
} else {
dims[0] = mpi::all_reduce(dims[0], comm);
}
auto dims = rhs.shape();
dims[0] = mpi::all_reduce(dims[0], comm);
gathered_size = std::accumulate(dims.begin(), dims.end(), 1l, std::multiplies<>());
if (!all && comm.rank() != root) dims = nda::stdutil::make_initialized_array<dims.size()>(0l);
return dims;
}

@@ -117,39 +118,24 @@ struct mpi::lazy<mpi::tag::gather, A> {
return;
}

// get target shape and resize or check the target array/view
auto dims = shape();
// get target shape, resize or check the target array/view and prepare output span
auto dims = shape();
auto target_span = std::span{target.data(), 0};
if (all || (comm.rank() == root)) {
// check if the target array/view can be used in the MPI call
check_mpi_contiguous_layout(target, "mpi_gather");
check_mpi_c_layout(target, "mpi_gather");

// resize/check the size of the target array/view
nda::resize_or_check_if_view(target, dims);
}

// gather receive counts and memory displacements
auto recvcounts = std::vector<int>(comm.size());
auto displs = std::vector<int>(comm.size() + 1, 0);
int sendcount = rhs.size();
auto mpi_int_type = mpi::mpi_type<int>::get();
if (!all)
mpi::check_mpi_call(MPI_Gather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, root, comm.get()), "MPI_Gather");
else
mpi::check_mpi_call(MPI_Allgather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, comm.get()), "MPI_Allgather");

for (int r = 0; r < comm.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];
// prepare the output span
target_span = std::span{target.data(), static_cast<std::size_t>(target.size())};
}

// gather the data
auto mpi_value_type = mpi::mpi_type<value_type>::get();
if (!all)
mpi::check_mpi_call(
MPI_Gatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, root, comm.get()),
"MPI_Gatherv");
else
mpi::check_mpi_call(
MPI_Allgatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, comm.get()),
"MPI_Allgatherv");
auto rhs_span = std::span{rhs.data(), static_cast<std::size_t>(rhs.size())};
mpi::gather_range(rhs_span, target_span, comm, root, all, gathered_size);
}
};

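A rough standalone sketch of the new gather call, with the argument order taken from the diff (mpi::gather_range(in, out, comm, root, all, gathered_size)); the buffers and sizes here are hypothetical.

#include <mpi/mpi.hpp>

#include <span>
#include <vector>

int main(int argc, char **argv) {
  mpi::environment env(argc, argv);
  mpi::communicator comm;

  // each rank contributes one element; only the root allocates the full output buffer
  std::vector<int> in{comm.rank()};
  std::vector<int> out(comm.rank() == 0 ? comm.size() : 0);
  long const gathered_size = comm.size();
  mpi::gather_range(std::span{in.data(), in.size()}, std::span{out.data(), out.size()}, comm,
                    /*root=*/0, /*all=*/false, gathered_size);
}
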
21 changes: 7 additions & 14 deletions c++/nda/mpi/reduce.hpp
@@ -33,6 +33,8 @@

#include <array>
#include <cmath>
#include <cstddef>
#include <span>
#include <type_traits>
#include <utility>

@@ -136,21 +138,12 @@ struct mpi::lazy<mpi::tag::reduce, A> {
}

// reduce the data
void *target_ptr = (void *)target.data();
void *rhs_ptr = (void *)rhs.data();
auto count = rhs.size();
auto mpi_value_type = mpi::mpi_type<value_type>::get();
if (!all) {
if (in_place)
mpi::check_mpi_call(MPI_Reduce((comm.rank() == root ? MPI_IN_PLACE : rhs_ptr), rhs_ptr, count, mpi_value_type, op, root, comm.get()),
"MPI_Reduce");
else
mpi::check_mpi_call(MPI_Reduce(rhs_ptr, target_ptr, count, mpi_value_type, op, root, comm.get()), "MPI_Reduce");
auto target_span = std::span{target.data(), static_cast<std::size_t>(target.size())};
if (in_place) {
mpi::reduce_in_place_range(target_span, comm, root, all, op);
} else {
if (in_place)
mpi::check_mpi_call(MPI_Allreduce(MPI_IN_PLACE, rhs_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
else
mpi::check_mpi_call(MPI_Allreduce(rhs_ptr, target_ptr, count, mpi_value_type, op, comm.get()), "MPI_Allreduce");
auto rhs_span = std::span{rhs.data(), static_cast<std::size_t>(rhs.size())};
mpi::reduce_range(rhs_span, target_span, comm, root, all, op);
}
}
}
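
A standalone sketch of the two reduce variants used above, with the signatures taken from the diff (mpi::reduce_range(in, out, comm, root, all, op) and mpi::reduce_in_place_range(inout, comm, root, all, op)); the buffers are hypothetical.

#include <mpi/mpi.hpp>
#include <mpi.h> // MPI_SUM

#include <span>
#include <vector>

int main(int argc, char **argv) {
  mpi::environment env(argc, argv);
  mpi::communicator comm;

  // out-of-place: every rank contributes, the root receives the element-wise sum
  std::vector<double> in(8, 1.0);
  std::vector<double> out(comm.rank() == 0 ? 8 : 0);
  mpi::reduce_range(std::span{in.data(), in.size()}, std::span{out.data(), out.size()}, comm,
                    /*root=*/0, /*all=*/false, MPI_SUM);

  // in-place: the buffer is reduced into itself, with the result on the root rank
  mpi::reduce_in_place_range(std::span{in.data(), in.size()}, comm, /*root=*/0, /*all=*/false, MPI_SUM);
}
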
39 changes: 12 additions & 27 deletions c++/nda/mpi/scatter.hpp
@@ -21,23 +21,19 @@

#pragma once

#include "./broadcast.hpp"
#include "./utils.hpp"
#include "../basic_functions.hpp"
#include "../concepts.hpp"
#include "../exceptions.hpp"
#include "../stdutil/array.hpp"
#include "../traits.hpp"

#include <mpi.h>
#include <mpi/mpi.hpp>
#include <mpi/vector.hpp>

#include <cstddef>
#include <functional>
#include <numeric>
#include <span>
#include <type_traits>
#include <utility>
#include <vector>

/**
* @ingroup av_mpi
@@ -73,6 +69,9 @@ struct mpi::lazy<mpi::tag::scatter, A> {
/// Should all processes receive the result. (doesn't make sense for scatter)
const bool all{false}; // NOLINT (const is fine here)

/// Size of the array/view to be scattered.
mutable long scatter_size{0};

/**
* @brief Compute the shape of the nda::ArrayInitializer object.
*
@@ -87,10 +86,10 @@ struct mpi::lazy<mpi::tag::scatter, A> {
* @return Shape of the nda::ArrayInitializer object.
*/
[[nodiscard]] auto shape() const {
auto dims = rhs.shape();
auto dims_v = nda::basic_array_view(dims);
mpi::broadcast(dims_v, comm, root);
dims[0] = mpi::chunk_length(dims[0], comm.size(), comm.rank());
auto dims = rhs.shape();
mpi::broadcast(dims, comm, root);
scatter_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
dims[0] = mpi::chunk_length(dims[0], comm.size(), comm.rank());
return dims;
}

@@ -121,26 +120,12 @@ struct mpi::lazy<mpi::tag::scatter, A> {

// get target shape and resize or check the target array/view
auto dims = shape();
auto size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
resize_or_check_if_view(target, dims);

// compute send counts, receive counts and memory displacements
auto sendcounts = std::vector<int>(comm.size());
auto displs = std::vector<int>(comm.size() + 1, 0);
if (comm.rank() == root) {
auto dim0 = rhs.extent(0);
auto stride0 = rhs.indexmap().strides()[0];
for (int r = 0; r < comm.size(); ++r) {
sendcounts[r] = mpi::chunk_length(dim0, comm.size(), r) * stride0;
displs[r + 1] = sendcounts[r] + displs[r];
}
}

// scatter the data
auto mpi_value_type = mpi::mpi_type<value_type>::get();
mpi::check_mpi_call(
MPI_Scatterv((void *)rhs.data(), &sendcounts[0], &displs[0], mpi_value_type, (void *)target.data(), size, mpi_value_type, root, comm.get()),
"MPI_Scatterv");
auto target_span = std::span{target.data(), static_cast<std::size_t>(target.size())};
auto rhs_span = std::span{rhs.data(), static_cast<std::size_t>(rhs.size())};
mpi::scatter_range(rhs_span, target_span, comm, root, scatter_size, rhs.indexmap().strides()[0]);
}
};

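A minimal standalone sketch of the new scatter call, with the argument order taken from the diff (mpi::scatter_range(in, out, comm, root, scatter_size, stride of the first dimension)); mpi::chunk_length is used to size the receive buffer as in shape() above, and the buffers themselves are hypothetical.

#include <mpi/mpi.hpp>

#include <numeric>
#include <span>
#include <vector>

int main(int argc, char **argv) {
  mpi::environment env(argc, argv);
  mpi::communicator comm;

  // the root holds the full 1D buffer; every rank receives its chunk of the first dimension
  long const total = 4l * comm.size();
  std::vector<int> in;
  if (comm.rank() == 0) {
    in.resize(total);
    std::iota(in.begin(), in.end(), 0);
  }
  std::vector<int> out(mpi::chunk_length(total, comm.size(), comm.rank()));

  mpi::scatter_range(std::span{in.data(), in.size()}, std::span{out.data(), out.size()}, comm,
                     /*root=*/0, /*scatter_size=*/total, /*stride of dim 0=*/1);
}
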
6 changes: 4 additions & 2 deletions c++/nda/mpi/utils.hpp
@@ -23,13 +23,15 @@

#include "../concepts.hpp"
#include "../exceptions.hpp"
#include "../macros.hpp"
#include "../stdutil/array.hpp"

#include <mpi/mpi.hpp>

#include <string>

#if !defined(NDEBUG) || defined(NDA_DEBUG)
#include "../stdutil/array.hpp"
#endif

namespace nda::detail {

// Check if the layout of an array/view is contiguous with positive strides, otherwise throw an exception.