From f3772cd8865b3b77c1211db7b509fbe82b180289 Mon Sep 17 00:00:00 2001 From: Thomas Hahn Date: Sat, 28 Sep 2024 20:11:41 -0400 Subject: [PATCH] Change the type stored in lazy mpi objects --- c++/nda/mpi/gather.hpp | 8 ++++---- c++/nda/mpi/reduce.hpp | 10 +++++----- c++/nda/mpi/scatter.hpp | 8 ++++---- test/c++/nda_mpi.cpp | 17 +++++++++++++++++ 4 files changed, 30 insertions(+), 13 deletions(-) diff --git a/c++/nda/mpi/gather.hpp b/c++/nda/mpi/gather.hpp index 725dc050..a15b2f1a 100644 --- a/c++/nda/mpi/gather.hpp +++ b/c++/nda/mpi/gather.hpp @@ -55,11 +55,11 @@ struct mpi::lazy { /// Value type of the array/view. using value_type = typename std::decay_t::value_type; - /// Const view type of the array/view stored in the lazy object. - using const_view_type = decltype(std::declval()()); + /// Type of the array/view stored in the lazy object. + using stored_type = A; - /// View of the array/view to be gathered. - const_view_type rhs; + /// Array/View to be gathered. + stored_type rhs; /// MPI communicator. mpi::communicator comm; diff --git a/c++/nda/mpi/reduce.hpp b/c++/nda/mpi/reduce.hpp index 14bb7c99..2f053610 100644 --- a/c++/nda/mpi/reduce.hpp +++ b/c++/nda/mpi/reduce.hpp @@ -57,11 +57,11 @@ struct mpi::lazy { /// Value type of the array/view. using value_type = typename std::decay_t::value_type; - /// Const view type of the array/view stored in the lazy object. - using const_view_type = decltype(std::declval()()); + /// Type of the array/view stored in the lazy object. + using stored_type = A; - /// View of the array/view to be reduced. - const_view_type rhs; + /// Array/View to be reduced. + stored_type rhs; /// MPI communicator. mpi::communicator comm; @@ -87,7 +87,7 @@ struct mpi::lazy { */ [[nodiscard]] auto shape() const { if ((comm.rank() == root) || all) return rhs.shape(); - return std::array{}; + return std::array::rank>{}; } /** diff --git a/c++/nda/mpi/scatter.hpp b/c++/nda/mpi/scatter.hpp index 336130e4..6df18c31 100644 --- a/c++/nda/mpi/scatter.hpp +++ b/c++/nda/mpi/scatter.hpp @@ -54,11 +54,11 @@ struct mpi::lazy { /// Value type of the array/view. using value_type = typename std::decay_t::value_type; - /// Const view type of the array/view stored in the lazy object. - using const_view_type = decltype(std::declval()()); + /// Type of the array/view stored in the lazy object. + using stored_type = A; - /// View of the array/view to be scattered. - const_view_type rhs; + /// Array/View to be scattered. + stored_type rhs; /// MPI communicator. mpi::communicator comm; diff --git a/test/c++/nda_mpi.cpp b/test/c++/nda_mpi.cpp index 2b12635a..a9117f9f 100644 --- a/test/c++/nda_mpi.cpp +++ b/test/c++/nda_mpi.cpp @@ -339,4 +339,21 @@ TEST_F(NDAMpi, VariousCollectiveCommunications) { EXPECT_ARRAY_NEAR(R2, comm.size() * A); } +TEST_F(NDAMpi, PassingTemporaryObjects) { + auto A = nda::array{1, 2, 3}; + auto lazy_arr = mpi::gather(nda::array{1, 2, 3}, comm); + auto res_arr = nda::array(lazy_arr); + auto lazy_view = mpi::gather(A(), comm); + auto res_view = nda::array(lazy_view); + if (comm.rank() == 0) { + for (long i = 0; i < comm.size(); ++i) { + EXPECT_ARRAY_EQ(res_arr(nda::range(i * 3, (i + 1) * 3)), A); + EXPECT_ARRAY_EQ(res_view(nda::range(i * 3, (i + 1) * 3)), A); + } + } else { + EXPECT_TRUE(res_arr.empty()); + EXPECT_TRUE(res_view.empty()); + } +} + MPI_TEST_MAIN