Skip to content

Commit

Permalink
#2281: Initial work for using new allreduce within collections
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Aug 21, 2024
1 parent d48e90d commit 7c7f870
Show file tree
Hide file tree
Showing 11 changed files with 273 additions and 70 deletions.
24 changes: 24 additions & 0 deletions src/vt/collective/reduce/allreduce/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,30 @@ using remove_cvref = std::remove_cv_t<std::remove_reference_t<T>>;

namespace vt::collective::reduce::allreduce {

/// Compile-time introspection of a callable's signature.
///
/// The primary template is declared but never defined, so instantiating it
/// with a type that matches none of the specializations below is a
/// compile-time error.
template <typename T>
struct function_traits; // General template declaration.

/// Specialization for function pointers.
///
/// Exposes:
///  - \c return_type : the function's return type
///  - \c arity       : number of parameters
///  - \c args_tuple  : \c std::tuple of the parameter types
///  - \c arg_type<N> : type of the N-th parameter (0-based)
template <typename Ret, typename... Args>
struct function_traits<Ret(*)(Args...)> {
  using return_type = Ret;
  static constexpr std::size_t arity = sizeof...(Args);
  using args_tuple = std::tuple<Args...>;

  template <std::size_t N>
  using arg_type = std::tuple_element_t<N, args_tuple>;
};

/// In C++17 \c noexcept is part of the function type, so a noexcept function
/// pointer does not match the specialization above and needs its own.
template <typename Ret, typename... Args>
struct function_traits<Ret(*)(Args...) noexcept>
  : function_traits<Ret(*)(Args...)> {};

/// Specialization for pointers to non-const member functions.
/// The object type is not counted as an argument.
template <typename Ret, typename ObjT, typename... Args>
struct function_traits<Ret(ObjT::*)(Args...)>
  : function_traits<Ret(*)(Args...)> {};

/// Specialization for pointers to const member functions.
template <typename Ret, typename ObjT, typename... Args>
struct function_traits<Ret(ObjT::*)(Args...) const>
  : function_traits<Ret(*)(Args...)> {};

/// Specialization for pointers to noexcept member functions.
template <typename Ret, typename ObjT, typename... Args>
struct function_traits<Ret(ObjT::*)(Args...) noexcept>
  : function_traits<Ret(*)(Args...)> {};

/// Specialization for pointers to const noexcept member functions.
template <typename Ret, typename ObjT, typename... Args>
struct function_traits<Ret(ObjT::*)(Args...) const noexcept>
  : function_traits<Ret(*)(Args...)> {};

// Primary template
template <typename Scalar, typename DataT>
struct ShouldUseView {
Expand Down
23 changes: 20 additions & 3 deletions src/vt/collective/reduce/allreduce/rabenseifner.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
//@HEADER
*/

#include "vt/configs/types/types_type.h"
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_H

Expand Down Expand Up @@ -87,11 +88,16 @@ struct Rabenseifner {
static constexpr bool KokkosPaylod = ShouldUseView_v<Scalar, DataT>;

template <typename ...Args>
Rabenseifner(GroupType group, Args&&... args);
Rabenseifner(detail::StrongVrtProxy proxy, Args&&... args);

template <typename ...Args>
Rabenseifner(detail::StrongGroup group, Args&&... args);

template <typename ...Args>
Rabenseifner(vt::objgroup::proxy::Proxy<ObjT> proxy, Args&&... args);

template <typename IdxT>
void localReduce(IdxT idx);
/**
* \brief Initialize the allreduce algorithm.
*
Expand Down Expand Up @@ -265,25 +271,36 @@ struct Rabenseifner {

vt::objgroup::proxy::Proxy<Rabenseifner> proxy_ = {};
vt::objgroup::proxy::Proxy<ObjT> parent_proxy_ = {};
VirtualProxyType collection_proxy_ = {};

size_t id_ = 0;
std::unordered_map<size_t, StateT> states_ = {};

/// Only used when a non-default group is being used
/// Sorted list of Nodes that take part in allreduce
std::vector<NodeType> nodes_ = {};

NodeType num_nodes_ = {};

/// Represents an index inside nodes_
NodeType this_node_ = {};

bool is_even_ = false;

/// Num steps for each scatter/gather phase
int32_t num_steps_ = {};

/// 2^num_steps_
int32_t nprocs_pof2_ = {};
int32_t nprocs_rem_ = {};

/// For non-power-of-2 number of nodes this represents whether current Node
/// is excluded (has value of -1) from computation
NodeType vrt_node_ = {};

bool is_part_of_adjustment_group_ = false;

static inline const std::string name_ = "Rabenseifner";
static inline const ReducerType type_ = ReducerType::Rabenseifner;
static inline constexpr ReducerType type_ = ReducerType::Rabenseifner;
};

} // namespace vt::collective::reduce::allreduce
Expand Down
26 changes: 19 additions & 7 deletions src/vt/collective/reduce/allreduce/rabenseifner.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@
//@HEADER
*/

#include "vt/configs/debug/debug_print.h"
#include <string>
#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_IMPL_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_IMPL_H

Expand All @@ -56,16 +54,30 @@
#include "vt/configs/types/types_sentinels.h"
#include "vt/registry/auto/auto_registry.h"
#include "vt/utils/fntraits/fntraits.h"
#include "vt/configs/debug/debug_print.h"
#include <string>

#include <type_traits>

namespace vt::collective::reduce::allreduce {

/// Constructor overload taking a strongly-typed virtual (collection) proxy.
///
/// Currently a stub: it only emits a debug print of the proxy value. The
/// forwarded arguments are accepted but not used by this overload.
template <typename DataT, template <typename Arg> class Op, auto finalHandler>
template <typename... Args>
Rabenseifner<DataT, Op, finalHandler>::Rabenseifner(
  detail::StrongVrtProxy proxy, Args&&... /*data*/
) {
  vt_debug_print(
    terse, allreduce,
    "Rabenseifner: proxy={:x} \n", proxy.get()
  );
}

/// Local reduction step for the collection element identified by \c idx.
///
/// Currently a stub: it only emits a debug print of the index.
///
/// \param[in] idx index of the collection element
template <typename DataT, template <typename Arg> class Op, auto finalHandler>
template <typename IdxT>
void Rabenseifner<DataT, Op, finalHandler>::localReduce(IdxT idx) {
  vt_debug_print(
    terse, allreduce,
    "Rabenseifner: idx={} \n", idx
  );
}

template <typename DataT, template <typename Arg> class Op, auto finalHandler>
template <typename... Args>
Rabenseifner<DataT, Op, finalHandler>::Rabenseifner(
GroupType group, Args&&... data)
: nodes_(theGroup()->GetGroupNodes(group)),
detail::StrongGroup group, Args&&... data)
: nodes_(theGroup()->GetGroupNodes(group.get())),
num_nodes_(nodes_.size()),
this_node_(theContext()->getNode()),
num_steps_(static_cast<int32_t>(log2(num_nodes_))),
Expand All @@ -77,8 +89,8 @@ Rabenseifner<DataT, Op, finalHandler>::Rabenseifner(
for(auto& node : nodes_){
nodes_info += fmt::format("{} ", node);
}
auto const is_default_group = group == default_group;
auto const is_part_of_allreduce = (not is_default_group and theGroup()->inGroup(group)) or is_default_group;
auto const is_default_group = group.get() == default_group;
auto const is_part_of_allreduce = (not is_default_group and theGroup()->inGroup(group.get())) or is_default_group;

vt_debug_print(
terse, allreduce,
Expand All @@ -87,7 +99,7 @@ Rabenseifner<DataT, Op, finalHandler>::Rabenseifner(
is_default_group, is_part_of_allreduce, num_nodes_, nodes_info
);

if (not is_default_group and theGroup()->inGroup(group)) {
if (not is_default_group and theGroup()->inGroup(group.get())) {
// vtAssert(theGroup()->inGroup(group), fmt::format("This node is not part of group {:x}!", group));

auto it = std::find(nodes_.begin(), nodes_.end(), theContext()->getNode());
Expand Down
13 changes: 0 additions & 13 deletions src/vt/collective/reduce/allreduce/rabenseifner_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,6 @@ struct StateHolder : StateHolderBase {
// }
};

/// Signature introspection for callables. Only declared here; a type that
/// matches no specialization yields an incomplete-type error.
template <typename Fn>
struct function_traits;

/// Function-pointer specialization: exposes the return type, the parameter
/// count, the parameter types as a tuple, and per-index parameter lookup.
template <typename R, typename... Params>
struct function_traits<R (*)(Params...)> {
  using return_type = R;
  using args_tuple = std::tuple<Params...>;

  static constexpr std::size_t arity = sizeof...(Params);

  /// Type of the I-th parameter (0-based).
  template <std::size_t I>
  using arg_type = std::tuple_element_t<I, args_tuple>;
};

/**
* \struct Rabenseifner
Expand Down
57 changes: 56 additions & 1 deletion src/vt/collective/reduce/allreduce/rabenseifner_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@
//@HEADER
*/


#if !defined INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_MSG_H
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_MSG_H
#include "vt/config.h"
#include "vt/messaging/active.h"

#include "vt/configs/debug/debug_print.h"
#include "vt/collective/reduce/operators/default_msg.h"
namespace vt::collective::reduce::allreduce {

template <typename Scalar, typename DataT>
Expand Down Expand Up @@ -86,6 +88,59 @@ struct RabenseifnerMsg : Message {
s | step_;
}

/// Empty tag type.
/// NOTE(review): not referenced by any code visible in this change;
/// presumably selects a "no combine step" path - confirm intended use.
struct NoCombine {};

/// Type trait: \c value is \c true when the argument is a \c std::tuple
/// instantiation and \c false for every other type.
template <typename T>
struct IsTuple : std::bool_constant<false> {};

/// Matches any \c std::tuple, including the empty tuple.
template <typename... Elems>
struct IsTuple<std::tuple<Elems...>> : std::bool_constant<true> {};

/// Combines two messages by invoking a default-constructed \c Op on the
/// value of \c m1 and the const value of \c m2.
///
/// \tparam MsgT  message type providing \c getVal() and \c getConstVal()
/// \tparam Op    binary functor applied to the two values
/// \tparam ActOp not used by this function
///
/// \param[in,out] m1 message whose value is passed as the first operand
/// \param[in]     m2 message whose const value is passed as the second operand
template <typename MsgT, typename Op, typename ActOp>
static void combine(MsgT* m1, MsgT* m2) {
  Op reducer = Op();
  reducer(m1->getVal(), m2->getConstVal());
}

/// Handler taking a \c ReduceTMsg<Tuple> reduction message.
///
/// Work in progress: the active code only emits a debug print of the message
/// pointer. The root/leaf handling is kept below as commented-out reference
/// code and is not executed. \c Op and \c ActOp are only referenced from that
/// disabled code.
///
/// \param[in] msg the reduction message
template <typename Tuple, typename Op, typename ActOp>
static void FinalHandler(ReduceTMsg<Tuple>* msg) {
// using MsgT = ReduceTMsg<Tuple>;
vt_debug_print(
terse, reduce,
"FinalHandler: reduce root: ptr={}\n", print_ptr(msg)
);
// Disabled reference implementation (root: fire callback / root handler;
// non-root: fold the chained messages together with combine()).
// if (msg->isRoot()) {
// vt_debug_print(
// terse, reduce,
// "FinalHandler::ROOT: reduce root: ptr={}\n", print_ptr(msg)
// );
// if (msg->hasValidCallback()) {
// envelopeUnlockForForwarding(msg->env);
// if (msg->isParamCallback()) {
// if constexpr (IsTuple<typename MsgT::DataT>::value) {
// msg->getParamCallback().sendTuple(std::move(msg->getVal()));
// }
// } else {
// // We need to force the type to the more specific one here
// auto cb = msg->getMsgCallback();
// auto typed_cb = reinterpret_cast<Callback<MsgT>*>(&cb);
// typed_cb->sendMsg(msg);
// }
// } else if (msg->root_handler_ != uninitialized_handler) {
// auto_registry::getAutoHandler(msg->root_handler_)->dispatch(msg, nullptr);
// }
// } else {
// MsgT* fst_msg = msg;
// MsgT* cur_msg = msg->template getNext<MsgT>();
// vt_debug_print(
// terse, reduce,
// "FinalHandler::leaf: fst ptr={}\n", print_ptr(fst_msg)
// );
// while (cur_msg != nullptr) {
// RabenseifnerMsg<Scalar, DataT>::combine<MsgT,Op,ActOp>(fst_msg, cur_msg);
// cur_msg = cur_msg->template getNext<MsgT>();
// }
// }
}

const Scalar* val_ = {};
size_t size_ = {};
size_t id_ = {};
Expand Down
6 changes: 5 additions & 1 deletion src/vt/group/group_manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,11 @@ void GroupManager::allreduce(GroupType group, Args&&... args) {
using Reducer = collective::reduce::allreduce::Rabenseifner<DataT, Op, f>;

// TODO; Save the proxy so it can be deleted afterwards
auto proxy = theObjGroup()->makeCollective<Reducer>("reducer", group, std::forward<Args>(args)...);
auto proxy = theObjGroup()->makeCollective<Reducer>(
"reducer", collective::reduce::detail::StrongGroup{group},
std::forward<Args>(args)...
);

if (iter->second->is_in_group) {
auto const this_node = theContext()->getNode();
auto id = proxy[this_node].get()->id_ - 1;
Expand Down
11 changes: 11 additions & 0 deletions src/vt/vrt/collection/manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
//@HEADER
*/

#include "vt/configs/types/types_type.h"
#include <cstdint>
#if !defined INCLUDED_VT_VRT_COLLECTION_MANAGER_H
#define INCLUDED_VT_VRT_COLLECTION_MANAGER_H

Expand Down Expand Up @@ -743,6 +745,11 @@ struct CollectionManager
bool instrument
);

/**
 * \brief Contribute the currently running collection element to a
 * node-local allreduce (work in progress)
 *
 * \tparam f handler; the type of its first argument is used as the data type
 * of the reducer
 * \tparam ColT the collection type
 * \tparam Op the reduction operator template
 *
 * \param[in] proxy the collection proxy
 * \param[in] args arguments forwarded to the reducer when it is created
 *
 * \return a pending send
 */
template <auto f, typename ColT, template <typename Arg> class Op, typename ...Args>
messaging::PendingSend reduceLocal(
CollectionProxyWrapType<ColT> const& proxy, Args &&... args
);

/**
* \brief Reduce over a collection
*
Expand Down Expand Up @@ -1766,6 +1773,10 @@ struct CollectionManager
VirtualIDType next_rooted_id_ = 0;
TypelessHolder typeless_holder_;
std::unordered_map<VirtualProxyType, SequentialIDType> reduce_stamp_;

// Allreduce stuff, probably should be moved elsewhere
std::unordered_map<VirtualProxyType, ObjGroupProxyType> rabenseifner_reducers_;
std::unordered_map<VirtualProxyType, uint32_t> waiting_count_ = {};
};

}}} /* end namespace vt::vrt::collection */
Expand Down
68 changes: 67 additions & 1 deletion src/vt/vrt/collection/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
//@HEADER
*/

#include "vt/collective/reduce/scoping/strong_types.h"
#if !defined INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H
#define INCLUDED_VT_VRT_COLLECTION_MANAGER_IMPL_H

Expand Down Expand Up @@ -184,7 +185,7 @@ GroupType CollectionManager::createGroupCollection(
}

vt_debug_print(
normal, allreduce,
normal, vrt_coll,
"group finished construction: proxy={:x}, new_group={:x}, use_group={}, "
"ready={}, root={}, is_group_default={}\n",
proxy, new_group, elm_holder->useGroup(), elm_holder->groupReady(),
Expand Down Expand Up @@ -870,6 +871,71 @@ messaging::PendingSend CollectionManager::broadcastMsgUntypedHandler(
}
}

/// Contributes the currently running collection element to a node-local
/// allreduce (work in progress).
///
/// When no group is in use, the first contribution for a collection creates
/// a Rabenseifner reducer objgroup and remembers its proxy; every
/// contribution then runs the reducer's local step for the element's index.
/// Once the number of contributions equals the number of elements reported by
/// the element holder, the local step is run once more. The group path is
/// not implemented yet.
///
/// NOTE(review): the per-collection waiting counter is never reset, so it
/// can equal the element count only once per collection - confirm this is
/// intended.
/// NOTE(review): for the last contributing element the local step runs
/// twice (once as a contribution, once when ready) - presumably a
/// placeholder for starting the allreduce; confirm.
///
/// \param[in] proxy the collection proxy
/// \param[in] args arguments forwarded to the reducer on creation
///
/// \return an empty pending send
template <
  auto f, typename ColT, template <typename Arg> class Op, typename... Args>
messaging::PendingSend CollectionManager::reduceLocal(
  CollectionProxyWrapType<ColT> const& proxy, Args&&... args) {
  using namespace collective::reduce::allreduce;
  using DataT = typename function_traits<decltype(f)>::template arg_type<0>;
  using Reducer = collective::reduce::allreduce::Rabenseifner<DataT, Op, f>;
  using ReducerProxy = vt::objgroup::proxy::Proxy<Reducer>;
  using IndexT = typename ColT::IndexType;

  // Get the current running index context
  IndexT idx = *queryIndexContext<IndexT>();
  auto const col_proxy = proxy.getProxy();
  auto elm_holder = findElmHolder<IndexT>(col_proxy);
  std::size_t num_elms = elm_holder->numElements();

  auto const group_ready = elm_holder->groupReady();
  auto const send_group = elm_holder->useGroup();
  // Only needed by the (not yet implemented) group path below.
  [[maybe_unused]] auto const group = elm_holder->group();
  bool const use_group = group_ready && send_group;

  // Fetches the reducer stored for this collection and runs the local step
  // on this node's instance.
  auto const reduce_with_stored_reducer = [&] {
    auto untyped_proxy = rabenseifner_reducers_.at(col_proxy);
    auto typed_proxy = static_cast<ReducerProxy>(untyped_proxy);
    typed_proxy[theContext()->getNode()].get()->localReduce(idx);
  };

  // Number of contributions seen so far for this collection (inserted as 0
  // on first access).
  auto& waiting = waiting_count_[col_proxy];
  bool const first_contribution = (waiting == 0);

  if (use_group) {
    // theGroup()->allreduce<f, Op>(group, );
  } else if (first_contribution) {
    auto obj_proxy = theObjGroup()->makeCollective<Reducer>(
      "reducer", collective::reduce::detail::StrongVrtProxy{col_proxy},
      std::forward<Args>(args)...
    );

    rabenseifner_reducers_[col_proxy] = obj_proxy.getProxy();
    obj_proxy[theContext()->getNode()].get()->proxy_ = obj_proxy;

    obj_proxy[theContext()->getNode()].get()->localReduce(idx);
  } else {
    reduce_with_stored_reducer();
  }

  ++waiting;
  bool const is_ready = (waiting == num_elms);
  vt_debug_print(
    terse, allreduce, "reduceLocal: idx={} num_elms={} is_ready={}\n", idx,
    num_elms, is_ready);

  if (is_ready and not use_group) {
    // With use_group set this would be: theGroup()->allreduce<f, Op>(group, );
    reduce_with_stored_reducer();
  }

  return messaging::PendingSend{nullptr};
}

template <typename ColT, typename MsgT, ActiveTypedFnType<MsgT> *f>
messaging::PendingSend CollectionManager::reduceMsgExpr(
CollectionProxyWrapType<ColT> const& proxy,
Expand Down
Loading

0 comments on commit 7c7f870

Please sign in to comment.