diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.h b/src/vt/collective/reduce/allreduce/rabenseifner.h index 1780e7fc72..b57d699d56 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.h @@ -88,7 +88,7 @@ struct Rabenseifner { static constexpr bool KokkosPaylod = ShouldUseView_v; template - Rabenseifner(detail::StrongVrtProxy proxy, Args&&... args); + Rabenseifner(detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems, Args&&... args); template Rabenseifner(detail::StrongGroup group, Args&&... args); @@ -96,8 +96,8 @@ struct Rabenseifner { template Rabenseifner(vt::objgroup::proxy::Proxy proxy, Args&&... args); - template - void localReduce(IdxT idx); + template + void localReduce(size_t id, Args&&... args); /** * \brief Initialize the allreduce algorithm. * @@ -271,7 +271,10 @@ struct Rabenseifner { vt::objgroup::proxy::Proxy proxy_ = {}; vt::objgroup::proxy::Proxy parent_proxy_ = {}; + VirtualProxyType collection_proxy_ = {}; + uint32_t local_col_wait_count_ = {}; + size_t local_num_elems_ = {}; size_t id_ = 0; std::unordered_map states_ = {}; diff --git a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h index fbdc7d77a7..b8cc4eed83 100644 --- a/src/vt/collective/reduce/allreduce/rabenseifner.impl.h +++ b/src/vt/collective/reduce/allreduce/rabenseifner.impl.h @@ -41,6 +41,7 @@ //@HEADER */ +#include "vt/configs/debug/debug_printconst.h" #if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_IMPL_H #define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_IMPL_H @@ -62,15 +63,39 @@ namespace vt::collective::reduce::allreduce { template class Op, auto finalHandler> -template -Rabenseifner::Rabenseifner(detail::StrongVrtProxy proxy, Args&&... data){ - vt_debug_print(terse, allreduce, "Rabenseifner: proxy={:x} \n", proxy.get()); +template +Rabenseifner::Rabenseifner( + detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems, + Args&&... data) + : Rabenseifner(group, std::forward(data)...) { + collection_proxy_ = proxy.get(); + local_num_elems_ = num_elems; + local_col_wait_count_++; + vt_debug_print( + terse, allreduce, + "Rabenseifner (this={}): proxy={:x} local_num_elems={} ID={} is_ready={}\n", + print_ptr(this), proxy.get(), local_num_elems_, id_, + local_col_wait_count_ == local_num_elems_); } template class Op, auto finalHandler> -template -void Rabenseifner::localReduce(IdxT idx){ - vt_debug_print(terse, allreduce, "Rabenseifner: idx={} \n", idx); +template +void Rabenseifner::localReduce( + size_t id, Args&&... ) { + local_col_wait_count_++; + + // auto& state = states_.at(id_); + // DataHelper::reduce(state.val_, std::forward(data)...); + auto const is_ready = local_col_wait_count_ == local_num_elems_; + vt_debug_print( + terse, allreduce, "Rabenseifner (this={}): local_col_wait_count_={} ID={} is_ready={} num_states={}\n", + print_ptr(this), local_col_wait_count_, id_, is_ready, states_.size() + ); + + if(is_ready){ + allreduce(id); + } + } template class Op, auto finalHandler> diff --git a/src/vt/vrt/collection/manager.h b/src/vt/vrt/collection/manager.h index 21166066e9..1c4f48f63d 100644 --- a/src/vt/vrt/collection/manager.h +++ b/src/vt/vrt/collection/manager.h @@ -1776,7 +1776,6 @@ struct CollectionManager // Allreduce stuff, probably should be moved elsewhere std::unordered_map rabenseifner_reducers_; - std::unordered_map waiting_count_ = {}; }; }}} /* end namespace vt::vrt::collection */ diff --git a/src/vt/vrt/collection/manager.impl.h b/src/vt/vrt/collection/manager.impl.h index cdce21e620..9caf98a247 100644 --- a/src/vt/vrt/collection/manager.impl.h +++ b/src/vt/vrt/collection/manager.impl.h @@ -893,43 +893,29 @@ messaging::PendingSend CollectionManager::reduceLocal( auto const group = elm_holder->group(); bool const use_group = group_ready && send_group; - // First time here - if (waiting_count_[col_proxy] == 0) { + if (auto reducer = rabenseifner_reducers_.find(col_proxy); + reducer == rabenseifner_reducers_.end()) { if (use_group) { // theGroup()->allreduce(group, ); } else { auto obj_proxy = theObjGroup()->makeCollective( "reducer", collective::reduce::detail::StrongVrtProxy{col_proxy}, - std::forward(args)... - ); + collective::reduce::detail::StrongGroup{group}, + num_elms, + std::forward(args)...); rabenseifner_reducers_[col_proxy] = obj_proxy.getProxy(); obj_proxy[theContext()->getNode()].get()->proxy_ = obj_proxy; - - obj_proxy[theContext()->getNode()].get()->localReduce(idx); } - }else{ + } else { if (use_group) { // theGroup()->allreduce(group, ); } else { auto obj_proxy = rabenseifner_reducers_.at(col_proxy); - auto typed_proxy = static_cast>(obj_proxy); - typed_proxy[theContext()->getNode()].get()->localReduce(idx); - } - } - - waiting_count_[col_proxy]++; - bool is_ready = waiting_count_[col_proxy] == num_elms; - vt_debug_print( - terse, allreduce, "reduceLocal: idx={} num_elms={} is_ready={}\n", idx, - num_elms, is_ready); - if (is_ready) { - if (use_group) { - // theGroup()->allreduce(group, ); - } else { - auto obj_proxy = rabenseifner_reducers_[col_proxy]; - auto typed_proxy = static_cast>(obj_proxy); - typed_proxy[theContext()->getNode()].get()->localReduce(idx); + auto typed_proxy = + static_cast>(obj_proxy); + auto* obj = typed_proxy[theContext()->getNode()].get(); + obj->localReduce(obj->id_ - 1); } } diff --git a/tests/perf/allreduce.cc b/tests/perf/allreduce.cc index 9166056f2d..7b9c806544 100644 --- a/tests/perf/allreduce.cc +++ b/tests/perf/allreduce.cc @@ -274,14 +274,8 @@ VT_PERF_TEST(MyTest, test_allreduce_group_rabenseifner) { struct Hello : vt::Collection { Hello() = default; - void FInalHan(NodeType result) { - fmt::print("Allreduce result is {} \n", result); - } - void AllredHandler(NodeType result) { - fmt::print("Allreduce result is {} \n", result); - - auto proxy = this->getCollectionProxy(); - proxy.allreduce_h<&Hello::FInalHan, collective::PlusOp>(theContext()->getNode()); + void FInalHan(std::vector result) { + fmt::print("Allreduce handler\n"); } void Handler() { @@ -289,7 +283,8 @@ struct Hello : vt::Collection { fmt::print("[{}] Hello from idx={} \n", theContext()->getNode(), getIndex()); // proxy.reduce<&Hello::AllredHandler, collective::PlusOp>(theContext()->getNode(), theContext()->getNode()); - proxy.allreduce_h<&Hello::FInalHan, collective::PlusOp>(theContext()->getNode()); + std::vector payload(100, theContext()->getNode()); + proxy.allreduce_h<&Hello::FInalHan, collective::PlusOp>(std::move(payload)); col_send_done_ = true; }