Skip to content

Commit

Permalink
#2281: Sort out the PendingSends in new allreduce methods
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent e24cacc commit 46c73af
Show file tree
Hide file tree
Showing 14 changed files with 90 additions and 162 deletions.
6 changes: 5 additions & 1 deletion src/vt/collective/reduce/allreduce/allreduce_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
*/
#include "allreduce_holder.h"
#include "vt/objgroup/manager.h"
#include "state_holder.h"

namespace vt::collective::reduce::allreduce {

Expand Down Expand Up @@ -145,7 +146,7 @@ AllreduceHolder::addRabensifnerAllreducer(detail::StrongObjGroup strong_objgroup

vt_debug_print(
verbose, allreduce,
"Adding new Rabenseifner reducer for objgroup={:x} Size={}\n", objgroup
"Adding new Rabenseifner reducer for objgroup={:x}\n", objgroup
);

return obj_proxy;
Expand All @@ -171,16 +172,19 @@ AllreduceHolder::addRecursiveDoublingAllreducer(

void AllreduceHolder::remove(detail::StrongVrtProxy strong_proxy) {
auto const key = strong_proxy.get();
StateHolder::clearAll(strong_proxy);
removeImpl(col_reducers_, key);
}

void AllreduceHolder::remove(detail::StrongGroup strong_group) {
auto const key = strong_group.get();
StateHolder::clearAll(strong_group);
removeImpl(group_reducers_, key);
}

void AllreduceHolder::remove(detail::StrongObjGroup strong_objgroup) {
auto const key = strong_objgroup.get();
StateHolder::clearAll(strong_objgroup);
removeImpl(objgroup_reducers_, key);
}

Expand Down
11 changes: 3 additions & 8 deletions src/vt/collective/reduce/allreduce/rabenseifner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,20 +92,15 @@ Rabenseifner::Rabenseifner(detail::StrongGroup group)
auto const is_part_of_allreduce =
(not is_default_group and in_group) or is_default_group;

std::string nodes_info;
for(auto& node : nodes_){
nodes_info += fmt::format("{} ", node);
}

vt_debug_print(
terse, allreduce,
"Rabenseifner: is_default_group={} is_part_of_allreduce={} nodes=[{}] \n",
is_default_group, is_part_of_allreduce, nodes_info);
"Rabenseifner: is_default_group={} is_part_of_allreduce={} \n",
is_default_group, is_part_of_allreduce);

if (not is_default_group and in_group) {
auto it = std::find(nodes_.begin(), nodes_.end(), theContext()->getNode());

vtAssert(it != nodes_.end(), fmt::format("This node was not found in group nodes! Nodes=[{}]", nodes_info));
vtAssert(it != nodes_.end(), "This node was not found in group nodes!");

// index in group list
this_node_ = it - nodes_.begin();
Expand Down
7 changes: 5 additions & 2 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,17 @@ struct Rabenseifner {
template <typename DataT, typename CallbackType>
void setFinalHandler(const CallbackType& fin, size_t id);

template <typename DataT, template <typename Arg> class Op, typename... Args>
void storeData(size_t id, Args&&... args);

/**
* \brief Performs local reduce, and once the local one is done it starts up the global allreduce
*
* \param id Allreduce ID
* \param args Data to be allreduced
*/
template <typename DataT, template <typename Arg> class Op, typename... Args>
void localReduce(size_t id, Args&&... args);
template <typename DataT, template <typename Arg> class Op>
void run(size_t id);

/**
* \brief Initialize the allreduce algorithm.
Expand Down
12 changes: 9 additions & 3 deletions src/vt/collective/reduce/allreduce/rabenseifner.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ void Rabenseifner::setFinalHandler(const CallbackType& fin, size_t id) {
}

template <typename DataT, template <typename Arg> class Op, typename... Args>
void Rabenseifner::localReduce(size_t id, Args&&... data) {
void Rabenseifner::storeData(size_t id, Args&&... data) {
using DataHelperT = DataHelper<typename DataHandler<DataT>::Scalar, DataT>;
auto& state = getState<RabenseifnerT, DataT>(info_, id);

vt_debug_print(
terse, allreduce,
"Rabenseifner(ID = {}) localReduce (this={}): local_col_wait_count_={} "
"Rabenseifner(ID = {}) storeData: local_col_wait_count_={} "
"initialized={}\n",
id, print_ptr(this), state.local_col_wait_count_, state.initialized_);
id, state.local_col_wait_count_, state.initialized_);

if (DataHelperT::empty(state.val_)) {
initialize<DataT>(id, std::forward<Args>(data)...);
Expand All @@ -93,6 +93,12 @@ void Rabenseifner::localReduce(size_t id, Args&&... data) {

state.local_col_wait_count_++;
auto const is_ready = state.local_col_wait_count_ == local_num_elems_;
}

template <typename DataT, template <typename Arg> class Op>
void Rabenseifner::run(size_t id) {
auto& state = getState<RabenseifnerT, DataT>(info_, id);
auto const is_ready = state.local_col_wait_count_ == local_num_elems_;

if (is_ready) {
// Execute early in case we're the only node
Expand Down
7 changes: 5 additions & 2 deletions src/vt/collective/reduce/allreduce/recursive_doubling.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,17 @@ struct RecursiveDoubling {
template <typename DataT, typename CallbackType>
void setFinalHandler(const CallbackType& fin, size_t id);

template <typename DataT, template <typename Arg> class Op, typename... Args>
void storeData(size_t id, Args&&... args);

/**
* \brief Performs local reduce, and once the local one is done it starts up the global allreduce
*
* \param id Allreduce ID
* \param args Data to be allreduced
*/
template <typename DataT, template <typename Arg> class Op, typename... Args>
void localReduce(size_t id, Args&&... args);
template <typename DataT, template <typename Arg> class Op>
void run(size_t id);

/**
* \brief Start the allreduce operation.
Expand Down
21 changes: 14 additions & 7 deletions src/vt/collective/reduce/allreduce/recursive_doubling.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void RecursiveDoubling::setFinalHandler(const CallbackType& fin, size_t id) {
}

template <typename DataT, template <typename Arg> class Op, typename... Args>
void RecursiveDoubling::localReduce(size_t id, Args&&... data) {
void RecursiveDoubling::storeData(size_t id, Args&&... data) {
auto& state = getState<RecursiveDoublingT, DataT>(info_, id);

vt_debug_print(
Expand All @@ -79,6 +79,11 @@ void RecursiveDoubling::localReduce(size_t id, Args&&... data) {
}

state.local_col_wait_count_++;
}

template <typename DataT, template <typename Arg> class Op>
void RecursiveDoubling::run(size_t id) {
auto& state = getState<RecursiveDoublingT, DataT>(info_, id);
auto const is_ready = state.local_col_wait_count_ == local_num_elems_;

if (is_ready) {
Expand Down Expand Up @@ -174,7 +179,7 @@ void RecursiveDoubling::adjustForPowerOfTwoHan(
RecursiveDoublingMsg<DataT>* msg) {
using DataType = DataHandler<DataT>;
auto& state = getState<RecursiveDoublingT, DataT>(info_, msg->id_);
if (DataType::size(state.val_) == 0) {
if (not state.value_assigned_) {
if (not state.initialized_) {
initializeState<DataT>(msg->id_);
}
Expand Down Expand Up @@ -282,6 +287,12 @@ template <typename DataT, template <typename Arg> class Op>
RecursiveDoubling::reduceIterHandler(RecursiveDoublingMsg<DataT>* msg) {
auto* reducer = getAllreducer<RecursiveDoublingT>(msg->info_);

vt_debug_print(
terse, allreduce,
"RecursiveDoubling::reduceIterHandler reducer={} ID={} \n",
(reducer != nullptr), msg->id_
);

if(reducer){
reducer->template reduceIterHan<DataT, Op>(msg);
}else{
Expand All @@ -303,7 +314,7 @@ void RecursiveDoubling::reduceIterHan(RecursiveDoublingMsg<DataT>* msg) {
using DataType = DataHandler<DataT>;
auto& state = getState<RecursiveDoublingT, DataT>(info_, msg->id_);

if (DataType::size(state.val_) == 0) {
if (not state.value_assigned_) {
if (not state.initialized_) {
initializeState<DataT>(msg->id_);
}
Expand Down Expand Up @@ -378,10 +389,6 @@ void RecursiveDoubling::finalPart(size_t id) {
return;
}

vt_debug_print(
terse, allreduce,
"RecursiveDoubling Part4: Executing final handler ID = {}\n", id);

if (nprocs_rem_) {
sendToExcludedNodes<DataT>(id);
}
Expand Down
8 changes: 4 additions & 4 deletions src/vt/collective/reduce/allreduce/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ struct StateBase {
bool active_ = false;
};

struct RabensiferBase : StateBase {
~RabensiferBase() override = default;
struct RabenseifnerBase : StateBase {
~RabenseifnerBase() override = default;
// Scatter
int32_t scatter_mask_ = 1;
int32_t scatter_step_ = 0;
Expand Down Expand Up @@ -102,7 +102,7 @@ struct RecursiveDoublingState : StateBase {
};

template <typename Scalar, typename DataT>
struct RabenseifnerState : RabensiferBase {
struct RabenseifnerState : RabenseifnerBase {
~RabenseifnerState() override = default;

std::vector<Scalar> val_ = {};
Expand All @@ -120,7 +120,7 @@ struct RabenseifnerState : RabensiferBase {
#if MAGISTRATE_KOKKOS_ENABLED
template <typename Scalar>
struct RabenseifnerState<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>>
: RabensiferBase {
: RabenseifnerBase {
using DataT = Kokkos::View<Scalar*, Kokkos::HostSpace>;
~RabenseifnerState() override = default;

Expand Down
2 changes: 1 addition & 1 deletion src/vt/group/group_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ struct GroupManager : runtime::component::Component<GroupManager> {
return sendMsg<MsgT, f>(group, msg);
}

template <auto f, template <typename Arg> typename Op, typename ...Args>
template <typename ReducerT, auto f, template <typename Arg> typename Op, typename ...Args>
void
allreduce(GroupType group, Args &&... args);

Expand Down
32 changes: 18 additions & 14 deletions src/vt/group/group_manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,30 +162,34 @@ void GroupManagerT<T>::triggerContinuationT(
}
}

template <auto f, template <typename Arg> typename Op, typename... Args>
template <
typename ReducerT, auto f, template <typename Arg> typename Op,
typename... Args>
void GroupManager::allreduce(GroupType group, Args&&... args) {
using namespace collective::reduce::allreduce;

auto iter = local_collective_group_info_.find(group);
vtAssert(iter != local_collective_group_info_.end(), "Must exist");
vtAssert(
iter != local_collective_group_info_.end(),
"allreduce for groups is only supported for collective ones!");

using DataT = std::tuple_element_t<0, typename FuncTraits<decltype(f)>::TupleType>;

// using Reducer = Rabenseifner;
auto const strong_group = collective::reduce::detail::StrongGroup{group};
auto* reducer =
AllreduceHolder::getOrCreateAllreducer<RabenseifnerT>(strong_group);
using DataT =
std::tuple_element_t<0, typename FuncTraits<decltype(f)>::TupleType>;

if (iter->second->is_in_group) {
auto const strong_group = collective::reduce::detail::StrongGroup{group};
auto* reducer =
AllreduceHolder::getOrCreateAllreducer<ReducerT>(strong_group);

auto const this_node = theContext()->getNode();
auto id = StateHolder::getNextID(strong_group);
reducer->template setFinalHandler<DataT>(theCB()->makeSend<f>(this_node), id);
reducer->template localReduce<DataT, Op>(id, std::forward<Args>(args)...);
}
reducer->template setFinalHandler<DataT>(
theCB()->makeSend<f>(this_node), id);
reducer->template storeData<DataT, Op>(id, std::forward<Args>(args)...);
reducer->template run<DataT, Op>(id);

addCleanupAction([strong_group] {
AllreduceHolder::remove(strong_group);
});
addCleanupAction([strong_group] { AllreduceHolder::remove(strong_group); });
}
}

}} /* end namespace vt::group */
Expand Down
4 changes: 0 additions & 4 deletions src/vt/objgroup/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,10 +298,6 @@ struct ObjGroupManager : runtime::component::Component<ObjGroupManager> {
ProxyType<ObjT> proxy, std::string const& name, std::string const& parent = ""
);

template <typename Reducer, typename ObjT, typename CbT, typename... Args>
ObjGroupManager::PendingSendType allreduce(
ProxyType<ObjT> proxy, CbT cb, size_t id, Args&&... data);

template <
typename Type, auto f, template <typename Arg> class Op, typename ObjT,
typename... Args
Expand Down
25 changes: 7 additions & 18 deletions src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -280,20 +280,6 @@ ObjGroupManager::PendingSendType
ObjGroupManager::allreduce(ProxyType<ObjT> proxy, Args&&... data) {
using namespace collective::reduce::allreduce;

auto cb = theCB()->makeSend<f>(proxy[theContext()->getNode()]);
if (theContext()->getNumNodes() < 2) {
return PendingSendType{
theTerm()->getEpoch(),
[cb = std::move(cb),
args_tuple = std::make_tuple(std::forward<Args>(data)...)]() mutable {
std::apply(
[&cb](auto&&... args) {
cb.send(std::forward<decltype(args)>(args)...);
},
args_tuple);
}};
}

using Trait = ObjFuncTraits<decltype(f)>;

// We only support allreduce with a single data type
Expand All @@ -302,15 +288,18 @@ ObjGroupManager::allreduce(ProxyType<ObjT> proxy, Args&&... data) {
auto const this_node = vt::theContext()->getNode();
auto const strong_proxy = vt::collective::reduce::detail::StrongObjGroup{proxy.getProxy()};

auto const id = StateHolder::getNextID<RabenseifnerT>(strong_proxy);
auto const id = StateHolder::getNextID<Type>(strong_proxy);

auto* reducer = AllreduceHolder::getOrCreateAllreducer<Type>(strong_proxy);

auto cb = theCB()->makeSend<f>(proxy[theContext()->getNode()]);
reducer->template setFinalHandler<DataT>(cb, id);
reducer->template localReduce<DataT, Op>(
reducer->template storeData<DataT, Op>(
id, std::forward<Args>(data)...);

// Silence nvcc warning
return PendingSendType{nullptr};
return PendingSendType{theTerm()->getEpoch(), [reducer, id] {
reducer->template run<DataT, Op>(id);
}};
}

template <typename ObjT, typename MsgT, ActiveTypedFnType<MsgT> *f>
Expand Down
Loading

0 comments on commit 46c73af

Please sign in to comment.