From 23f7581afc61cd390572a47dd7f15371232bb5bc Mon Sep 17 00:00:00 2001 From: Thomas Hahn Date: Fri, 27 Sep 2024 20:41:00 -0400 Subject: [PATCH] Use mpi::all_equal to simplify checks in the mpi interface --- c++/nda/mpi/utils.hpp | 28 +++------------------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/c++/nda/mpi/utils.hpp b/c++/nda/mpi/utils.hpp index 842f7ca8..a2dc2bd4 100644 --- a/c++/nda/mpi/utils.hpp +++ b/c++/nda/mpi/utils.hpp @@ -49,13 +49,7 @@ namespace nda::detail { template void expect_equal_ranks([[maybe_unused]] const A &a, [[maybe_unused]] const mpi::communicator &comm, [[maybe_unused]] int root) { #if !defined(NDEBUG) || defined(NDA_DEBUG) - if (not mpi::has_env) return; - int rank = a.rank; - int max_rank = 0; - int min_rank = 0; - mpi::check_mpi_call(MPI_Allreduce(&rank, &max_rank, 1, mpi::mpi_type::get(), MPI_MAX, comm.get()), "MPI_Allreduce"); - mpi::check_mpi_call(MPI_Allreduce(&rank, &min_rank, 1, mpi::mpi_type::get(), MPI_MIN, comm.get()), "MPI_Allreduce"); - if (max_rank != min_rank) NDA_RUNTIME_ERROR << "Error in expect_equal_ranks: nda::Array ranks are not equal on all processes"; + if (!mpi::all_equal(a.rank, comm)) NDA_RUNTIME_ERROR << "Error in expect_equal_ranks: nda::Array ranks are not equal on all processes"; #endif } @@ -63,16 +57,8 @@ namespace nda::detail { template void expect_equal_shapes([[maybe_unused]] const A &a, [[maybe_unused]] const mpi::communicator &comm, [[maybe_unused]] int root) { #if !defined(NDEBUG) || defined(NDA_DEBUG) - if (not mpi::has_env) return; expect_equal_ranks(a, comm, root); - auto shape = a.shape(); - auto max_shape = shape; - auto min_shape = shape; - mpi::check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MAX, comm.get()), - "MPI_Allreduce"); - mpi::check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MIN, comm.get()), - "MPI_Allreduce"); - if (max_shape != min_shape) NDA_RUNTIME_ERROR << "Error in expect_equal_shapes: nda::Array shapes are not equal on all processes"; + if (!mpi::all_equal(a.shape(), comm)) NDA_RUNTIME_ERROR << "Error in expect_equal_shapes: nda::Array shapes are not equal on all processes"; #endif } @@ -82,16 +68,8 @@ namespace nda::detail { void expect_equal_shapes_save_first([[maybe_unused]] const A &a, [[maybe_unused]] const mpi::communicator &comm, [[maybe_unused]] int root) { #if !defined(NDEBUG) || defined(NDA_DEBUG) if constexpr (A::rank == 1) return; - if (not mpi::has_env) return; expect_equal_ranks(a, comm, root); - auto shape = nda::stdutil::front_pop(a.shape()); - auto max_shape = shape; - auto min_shape = shape; - mpi::check_mpi_call(MPI_Allreduce(shape.data(), max_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MAX, comm.get()), - "MPI_Allreduce"); - mpi::check_mpi_call(MPI_Allreduce(shape.data(), min_shape.data(), shape.size(), mpi::mpi_type::get(), MPI_MIN, comm.get()), - "MPI_Allreduce"); - if (max_shape != min_shape) + if (!mpi::all_equal(nda::stdutil::front_pop(a.shape()), comm)) NDA_RUNTIME_ERROR << "Error in expect_equal_shapes_save_first: nda::Array shapes are not equal on all processes except for the first dimension"; #endif }