Skip to content

Commit

Permalink
#2281: Working Rabenseifner (without final handler) for collection
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Aug 29, 2024
1 parent a3a4b1a commit 83a80c9
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 43 deletions.
9 changes: 6 additions & 3 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ struct Rabenseifner {
static constexpr bool KokkosPaylod = ShouldUseView_v<Scalar, DataT>;

template <typename ...Args>
Rabenseifner(detail::StrongVrtProxy proxy, Args&&... args);
Rabenseifner(detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems, Args&&... args);

template <typename ...Args>
Rabenseifner(detail::StrongGroup group, Args&&... args);

template <typename ...Args>
Rabenseifner(vt::objgroup::proxy::Proxy<ObjT> proxy, Args&&... args);

template <typename IdxT>
void localReduce(IdxT idx);
template <typename ...Args>
void localReduce(size_t id, Args&&... args);
/**
* \brief Initialize the allreduce algorithm.
*
Expand Down Expand Up @@ -271,7 +271,10 @@ struct Rabenseifner {

vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};

VirtualProxyType collection_proxy_ = {};
uint32_t local_col_wait_count_ = {};
size_t local_num_elems_ = {};

size_t id_ = 0;
std::unordered_map<size_t, StateT> states_ = {};
Expand Down
37 changes: 31 additions & 6 deletions src/vt/collective/reduce/allreduce/rabenseifner.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
//@HEADER
*/

#include "vt/configs/debug/debug_printconst.h"
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_IMPL_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_IMPL_H

Expand All @@ -62,15 +63,39 @@
namespace vt::collective::reduce::allreduce {

template <typename DataT, template <typename Arg> class Op, auto finalHandler>
template <typename ...Args>
Rabenseifner<DataT, Op, finalHandler>::Rabenseifner(detail::StrongVrtProxy proxy, Args&&... data){
vt_debug_print(terse, allreduce, "Rabenseifner: proxy={:x} \n", proxy.get());
template <typename... Args>
Rabenseifner<DataT, Op, finalHandler>::Rabenseifner(
detail::StrongVrtProxy proxy, detail::StrongGroup group, size_t num_elems,
Args&&... data)
: Rabenseifner<DataT, Op, finalHandler>(group, std::forward<Args>(data)...) {
collection_proxy_ = proxy.get();
local_num_elems_ = num_elems;
local_col_wait_count_++;
vt_debug_print(
terse, allreduce,
"Rabenseifner (this={}): proxy={:x} local_num_elems={} ID={} is_ready={}\n",
print_ptr(this), proxy.get(), local_num_elems_, id_,
local_col_wait_count_ == local_num_elems_);
}

template <typename DataT, template <typename Arg> class Op, auto finalHandler>
template <typename IdxT>
void Rabenseifner<DataT, Op, finalHandler>::localReduce(IdxT idx){
vt_debug_print(terse, allreduce, "Rabenseifner: idx={} \n", idx);
template <typename... Args>
void Rabenseifner<DataT, Op, finalHandler>::localReduce(
size_t id, Args&&... ) {
local_col_wait_count_++;

// auto& state = states_.at(id_);
// DataHelper<Scalar, DataT>::reduce(state.val_, std::forward<Args>(data)...);
auto const is_ready = local_col_wait_count_ == local_num_elems_;
vt_debug_print(
terse, allreduce, "Rabenseifner (this={}): local_col_wait_count_={} ID={} is_ready={} num_states={}\n",
print_ptr(this), local_col_wait_count_, id_, is_ready, states_.size()
);

if(is_ready){
allreduce(id);
}

}

template <typename DataT, template <typename Arg> class Op, auto finalHandler>
Expand Down
1 change: 0 additions & 1 deletion src/vt/vrt/collection/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -1776,7 +1776,6 @@ struct CollectionManager

// Allreduce stuff, probably should be moved elsewhere
std::unordered_map<VirtualProxyType, ObjGroupProxyType> rabenseifner_reducers_;
std::unordered_map<VirtualProxyType, uint32_t> waiting_count_ = {};
};

}}} /* end namespace vt::vrt::collection */
Expand Down
34 changes: 10 additions & 24 deletions src/vt/vrt/collection/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -893,43 +893,29 @@ messaging::PendingSend CollectionManager::reduceLocal(
auto const group = elm_holder->group();
bool const use_group = group_ready && send_group;

// First time here
if (waiting_count_[col_proxy] == 0) {
if (auto reducer = rabenseifner_reducers_.find(col_proxy);
reducer == rabenseifner_reducers_.end()) {
if (use_group) {
// theGroup()->allreduce<f, Op>(group, );
} else {
auto obj_proxy = theObjGroup()->makeCollective<Reducer>(
"reducer", collective::reduce::detail::StrongVrtProxy{col_proxy},
std::forward<Args>(args)...
);
collective::reduce::detail::StrongGroup{group},
num_elms,
std::forward<Args>(args)...);

rabenseifner_reducers_[col_proxy] = obj_proxy.getProxy();
obj_proxy[theContext()->getNode()].get()->proxy_ = obj_proxy;

obj_proxy[theContext()->getNode()].get()->localReduce(idx);
}
}else{
} else {
if (use_group) {
// theGroup()->allreduce<f, Op>(group, );
} else {
auto obj_proxy = rabenseifner_reducers_.at(col_proxy);
auto typed_proxy = static_cast<vt::objgroup::proxy::Proxy<Reducer>>(obj_proxy);
typed_proxy[theContext()->getNode()].get()->localReduce(idx);
}
}

waiting_count_[col_proxy]++;
bool is_ready = waiting_count_[col_proxy] == num_elms;
vt_debug_print(
terse, allreduce, "reduceLocal: idx={} num_elms={} is_ready={}\n", idx,
num_elms, is_ready);
if (is_ready) {
if (use_group) {
// theGroup()->allreduce<f, Op>(group, );
} else {
auto obj_proxy = rabenseifner_reducers_[col_proxy];
auto typed_proxy = static_cast<vt::objgroup::proxy::Proxy<Reducer>>(obj_proxy);
typed_proxy[theContext()->getNode()].get()->localReduce(idx);
auto typed_proxy =
static_cast<vt::objgroup::proxy::Proxy<Reducer>>(obj_proxy);
auto* obj = typed_proxy[theContext()->getNode()].get();
obj->localReduce(obj->id_ - 1);
}
}

Expand Down
13 changes: 4 additions & 9 deletions tests/perf/allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,22 +274,17 @@ VT_PERF_TEST(MyTest, test_allreduce_group_rabenseifner) {

struct Hello : vt::Collection<Hello, vt::Index1D> {
Hello() = default;
void FInalHan(NodeType result) {
fmt::print("Allreduce result is {} \n", result);
}
void AllredHandler(NodeType result) {
fmt::print("Allreduce result is {} \n", result);

auto proxy = this->getCollectionProxy();
proxy.allreduce_h<&Hello::FInalHan, collective::PlusOp>(theContext()->getNode());
void FInalHan(std::vector<int32_t> result) {
fmt::print("Allreduce handler\n");
}

void Handler() {
auto proxy = this->getCollectionProxy();

fmt::print("[{}] Hello from idx={} \n", theContext()->getNode(), getIndex());
// proxy.reduce<&Hello::AllredHandler, collective::PlusOp>(theContext()->getNode(), theContext()->getNode());
proxy.allreduce_h<&Hello::FInalHan, collective::PlusOp>(theContext()->getNode());
std::vector<int32_t> payload(100, theContext()->getNode());
proxy.allreduce_h<&Hello::FInalHan, collective::PlusOp>(std::move(payload));
col_send_done_ = true;
}

Expand Down

0 comments on commit 83a80c9

Please sign in to comment.