Skip to content

Commit

Permalink
#2281: Initial work to make collective group info contain the information about all the nodes from that group
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Aug 29, 2024
1 parent 975dcb7 commit 907d0ef
Show file tree
Hide file tree
Showing 5 changed files with 132 additions and 38 deletions.
2 changes: 1 addition & 1 deletion src/vt/group/collective/group_collective_finished.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ void InfoColl::CollSetupFinished::operator()(FinishedReduceMsg* msg) {
info->known_root_node_ != this_node and
info->known_root_node_ != uninitialized_destination
) {
auto nmsg = makeMessage<GroupOnlyMsg>(
auto nmsg = makeMessage<GroupCollectiveFinalMsg>(
msg->getGroup(),info->new_tree_cont_
);
theMsg()->sendMsg<InfoColl::newTreeHan>(
Expand Down
48 changes: 45 additions & 3 deletions src/vt/group/collective/group_collective_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
#include "vt/group/group_common.h"
#include "vt/group/msg/group_msg.h"
#include "vt/messaging/message.h"
#include "vt/configs/types/types_type.h"
#include "vt/messaging/message/message_serialize.h"

#include <cstdlib>

Expand All @@ -56,7 +58,7 @@ namespace vt { namespace group {
template <typename MsgT>
struct GroupCollectiveInfoMsg : MsgT {
using MessageParentType = MsgT;
vt_msg_serialize_prohibited(); // no existing serialization function
vt_msg_serialize_required();
static_assert(
std::is_base_of<BaseMessage, MsgT>::value,
"Base must derive from Message."
Expand All @@ -69,10 +71,11 @@ struct GroupCollectiveInfoMsg : MsgT {
GroupType const& in_group, RemoteOperationIDType in_op, bool in_is_in_group,
NodeType const& in_subtree,
NodeType const& in_child_node = uninitialized_destination,
CountType const& level = 0, CountType const& extra_nodes = 0
CountType const& level = 0, CountType const& extra_nodes = 0,
std::vector<NodeType> nodes = {}
) : MsgT(in_group, in_op), is_in_group(in_is_in_group),
child_node_(in_child_node), subtree_size_(in_subtree),
extra_nodes_(extra_nodes), level_(level)
extra_nodes_(extra_nodes), level_(level), nodes_(nodes)
{ }

NodeType getChild() const { return child_node_; }
Expand All @@ -81,6 +84,19 @@ struct GroupCollectiveInfoMsg : MsgT {
bool isInGroup() const { return is_in_group; }
CountType getExtraNodes() const { return extra_nodes_; }
CountType getLevel() const { return level_; }
std::vector<NodeType> const& getNodes() {return nodes_;}

template <typename SerializerT>
void serialize(SerializerT& s) {
MessageParentType::serialize(s);
s | is_in_group;
s | is_static_;
s | child_node_;
s | subtree_size_;
s | extra_nodes_;
s | level_;
s | nodes_;
}

private:
bool is_in_group = false;
Expand All @@ -89,9 +105,35 @@ struct GroupCollectiveInfoMsg : MsgT {
NodeType subtree_size_ = 0;
CountType extra_nodes_ = 0;
CountType level_ = 0;
std::vector<NodeType> nodes_ = {};
};

/**
 * \struct GroupCollectiveFinalInfoMsg
 *
 * \brief Message that extends a base message type with a list of nodes.
 *
 * Derives from \c MsgT, which is constructed from a group and an operation ID,
 * and adds a vector of nodes that is included when the message is serialized.
 *
 * \tparam MsgT the base message type being extended
 */
template <typename MsgT>
struct GroupCollectiveFinalInfoMsg : MsgT {
  using MessageParentType = MsgT;
  vt_msg_serialize_required(); // this type provides serialize() below

  GroupCollectiveFinalInfoMsg() = default;

  /**
   * \brief Construct the message with a group, an operation ID and nodes.
   *
   * Only \c in_group and \c in_op are forwarded to \c MsgT; every other field
   * of the base is left to whatever that two-argument constructor sets.
   *
   * \param[in] in_group group forwarded to the base message
   * \param[in] in_op operation ID forwarded to the base message
   * \param[in] nodes nodes to store in the message (copied); empty by default
   */
  GroupCollectiveFinalInfoMsg(
    GroupType const& in_group, RemoteOperationIDType in_op,
    std::vector<NodeType> const& nodes = {}
  ) : MsgT(in_group, in_op), nodes_(nodes)
  { }

  /**
   * \brief Get the nodes stored in this message.
   *
   * Const-qualified so it can be called on const messages as well.
   *
   * \return read-only reference to the stored nodes
   */
  std::vector<NodeType> const& getNodes() const { return nodes_; }

  /**
   * \brief Serialize the base message followed by the stored nodes.
   *
   * \param[in] s the serializer
   */
  template <typename SerializerT>
  void serialize(SerializerT& s) {
    MessageParentType::serialize(s);
    s | nodes_;
  }

private:
  /// Nodes carried by this message
  std::vector<NodeType> nodes_ = {};
};

using GroupCollectiveMsg = GroupCollectiveInfoMsg<GroupMsg<::vt::Message>>;
using GroupCollectiveFinalMsg = GroupCollectiveFinalInfoMsg<GroupMsg<::vt::Message>>;

}} /* end namespace vt::group */

Expand Down
95 changes: 67 additions & 28 deletions src/vt/group/collective/group_info_collective.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
*/

#include "vt/config.h"
#include "vt/configs/types/types_type.h"
#include "vt/group/group_common.h"
#include "vt/group/base/group_info_base.h"
#include "vt/group/collective/group_info_collective.h"
Expand Down Expand Up @@ -157,8 +158,9 @@ void InfoColl::setupCollective() {
new_tree_cont_ = theGroup()->nextCollectiveID();
new_root_cont_ = theGroup()->nextCollectiveID();

using GroupCollectiveTMsg = GroupManagerT<MsgSharedPtr<GroupCollectiveMsg>>;
using GroupOnlyTMsg = GroupManagerT<MsgSharedPtr<GroupOnlyMsg>>;
using GroupCollectiveTMsg = GroupManagerT<MsgSharedPtr<GroupCollectiveMsg>>;
using GroupOnlyTMsg = GroupManagerT<MsgSharedPtr<GroupOnlyMsg>>;
using GroupCollectiveFinalTMsg = GroupManagerT<MsgSharedPtr<GroupCollectiveFinalMsg>>;

GroupCollectiveTMsg::registerContinuationT(
down_tree_cont_,
Expand All @@ -168,25 +170,25 @@ void InfoColl::setupCollective() {
iter->second->downTree(msg.get());
}
);
GroupOnlyTMsg::registerContinuationT(
GroupCollectiveFinalTMsg::registerContinuationT(
down_tree_fin_cont_,
[group_](MsgSharedPtr<GroupOnlyMsg> msg){
[group_](MsgSharedPtr<GroupCollectiveFinalMsg> msg){
auto iter = theGroup()->local_collective_group_info_.find(group_);
vtAssertExpr(iter != theGroup()->local_collective_group_info_.end());
iter->second->downTreeFinished(msg.get());
}
);
GroupOnlyTMsg::registerContinuationT(
GroupCollectiveFinalTMsg::registerContinuationT(
finalize_cont_,
[group_](MsgSharedPtr<GroupOnlyMsg> msg){
[group_](MsgSharedPtr<GroupCollectiveFinalMsg> msg){
auto iter = theGroup()->local_collective_group_info_.find(group_);
vtAssertExpr(iter != theGroup()->local_collective_group_info_.end());
iter->second->finalizeTree(msg.get());
}
);
GroupOnlyTMsg::registerContinuationT(
GroupCollectiveFinalTMsg::registerContinuationT(
new_tree_cont_,
[group_]([[maybe_unused]] MsgSharedPtr<GroupOnlyMsg> msg){
[group_]([[maybe_unused]] MsgSharedPtr<GroupCollectiveFinalMsg> msg){
auto iter = theGroup()->local_collective_group_info_.find(group_);
vtAssertExpr(iter != theGroup()->local_collective_group_info_.end());
auto const& from = theContext()->getFromNodeCurrentTask();
Expand Down Expand Up @@ -217,15 +219,15 @@ void InfoColl::setupCollective() {

vt_debug_print(
normal, group,
"InfoColl::setupCollective: is_in_group={}, parent={}: up tree\n",
is_in_group, parent
"InfoColl::setupCollective: group={:x}, is_in_group={}, parent={} num_children={}: up tree\n",
group_, is_in_group, parent, collective_->getInitialChildren()
);

if (collective_->getInitialChildren() == 0) {
auto const& size = static_cast<NodeType>(is_in_group ? 1 : 0);
auto const& child = theContext()->getNode();
auto msg = makeMessage<GroupCollectiveMsg>(
group_, up_tree_cont_, in_group, size, child
group_, up_tree_cont_, in_group, size, child, 0, 0, std::vector<NodeType>{theContext()->getNode()}
);
theMsg()->sendMsg<upHan>(parent, msg);
}
Expand All @@ -239,6 +241,19 @@ void InfoColl::atRoot() {
"InfoColl::atRoot: is_in_group={}, group={:x}, root={}, children={}\n",
is_in_group, group, is_root, collective_->span_children_.size()
);

std::string nodes_info = "Nodes: ";
nodes_info.reserve(1024);
for(auto& node : nodes_){
nodes_info += fmt::format("{} ", node);
}
nodes_info += "\n";

vt_debug_print(
terse, group,
"InfoColl::atRoot: group={:x} {}\n",
group, nodes_info
);
}

void InfoColl::upTree() {
Expand All @@ -248,10 +263,14 @@ void InfoColl::upTree() {
);
decltype(msgs_) msg_in_group = {};
subtree_ = 0;
nodes_.push_back(theContext()->getNode());
for (auto&& msg : msgs_) {
if (msg->isInGroup()) {
msg_in_group.push_back(msg);
subtree_ += msg->getSubtreeSize();
for(auto& node : msg->getNodes()){
nodes_.push_back(node);
}
}
}

Expand All @@ -269,6 +288,15 @@ void InfoColl::upTree() {
coll_wait_count_, group, op, is_root, subtree_
);

std::string nodes_info = "Nodes: ";
nodes_info.reserve(1024);
for (auto& node : nodes_) {
nodes_info += fmt::format("{} ", node);
}
nodes_info += "\n";
vt_debug_print(
terse, group, "InfoColl::upTree: group={:x}, {}\n", group, nodes_info);

if (is_root) {
if (is_in_group) {
auto const& this_node = theContext()->getNode();
Expand Down Expand Up @@ -375,7 +403,7 @@ void InfoColl::upTree() {
);

auto cmsg = makeMessage<GroupCollectiveMsg>(
group,op,is_in_group,total_subtree,child,level
group,op,is_in_group,total_subtree,child,level, 0, nodes_
);
theMsg()->sendMsg<upHan>(p, cmsg);

Expand Down Expand Up @@ -404,7 +432,7 @@ void InfoColl::upTree() {
);

auto msg = makeMessage<GroupCollectiveMsg>(
group,op,is_in_group,static_cast<NodeType>(subtree_),child,0,extra
group,op,is_in_group,static_cast<NodeType>(subtree_),child,0,extra, nodes_
);
theMsg()->sendMsg<upHan>(p, msg);
/*
Expand Down Expand Up @@ -436,7 +464,7 @@ void InfoColl::upTree() {
);

auto msg = makeMessage<GroupCollectiveMsg>(
group,op,is_in_group,total_subtree,child,0,extra
group,op,is_in_group,total_subtree,child,0,extra, nodes_
);
theMsg()->sendMsg<upHan>(p, msg);
// new MsgPtr to avoid theft of the original in the collection
Expand Down Expand Up @@ -476,7 +504,7 @@ void InfoColl::upTree() {
);

auto msg = makeMessage<GroupCollectiveMsg>(
group,op,is_in_group,total_subtree,child,0,extra
group,op,is_in_group,total_subtree,child,0,extra, nodes_
);
theMsg()->sendMsg<upHan>(p, msg);

Expand Down Expand Up @@ -599,11 +627,11 @@ void InfoColl::collectiveFn(MsgSharedPtr<GroupCollectiveMsg> msg) {
}
}

/*static*/ void InfoColl::tree(GroupOnlyMsg* msg) {
/*static*/ void InfoColl::tree(GroupCollectiveFinalMsg* msg) {
auto const& op_id = msg->getOpID();
vtAssert(op_id != no_op_id, "Must have valid op");
auto msg_ptr = promoteMsg(msg);
GroupManagerT<MsgSharedPtr<GroupOnlyMsg>>::triggerContinuationT(op_id,msg_ptr);
GroupManagerT<MsgSharedPtr<GroupCollectiveFinalMsg>>::triggerContinuationT(op_id,msg_ptr);
}

/*static*/ void InfoColl::upHan(GroupCollectiveMsg* msg) {
Expand Down Expand Up @@ -644,7 +672,7 @@ void InfoColl::downTree(GroupCollectiveMsg* msg) {
}

auto const& group_ = getGroupID();
auto nmsg = makeMessage<GroupOnlyMsg>(group_,down_tree_fin_cont_);
auto nmsg = makeMessage<GroupCollectiveFinalMsg>(group_,down_tree_fin_cont_, nodes_);
theMsg()->sendMsg<downFinishedHan>(from, nmsg);
}

Expand Down Expand Up @@ -685,7 +713,7 @@ void InfoColl::sendDownNewTree() {
"InfoColl::sendDownNewTree: group={:x}, sending to child={}\n",
group_, c
);
auto msg = makeMessage<GroupOnlyMsg>(group_,new_tree_cont_);
auto msg = makeMessage<GroupCollectiveFinalMsg>(group_,new_tree_cont_, nodes_);
theMsg()->sendMsg<newTreeHan>(c, msg);
}
}
Expand Down Expand Up @@ -716,17 +744,27 @@ void InfoColl::finalize() {
);
}

std::string nodes_info = "Nodes: ";
nodes_info.reserve(1024);
for (auto& node : nodes_) {
nodes_info += fmt::format("{} ", node);
}
nodes_info += "\n";
vt_debug_print(
terse, group, "InfoColl::finalize: group={:x}, {}\n", group_, nodes_info
);

auto const& children = collective_->getChildren();
for (auto&& c : children) {

vt_debug_print(
verbose, group,
terse, group,
"InfoColl::finalize: group={:x}, sending to child={}\n",
group_, c
);

auto msg = makeMessage<GroupOnlyMsg>(
group_,finalize_cont_,known_root_node_,is_default_group_
auto msg = makeMessage<GroupCollectiveFinalMsg>(
group_,finalize_cont_,nodes_
);
theMsg()->sendMsg<finalizeHan>(c, msg);
}
Expand All @@ -752,21 +790,22 @@ void InfoColl::finalize() {
}
}

void InfoColl::finalizeTree(GroupOnlyMsg* msg) {
void InfoColl::finalizeTree(GroupCollectiveFinalMsg* msg) {
auto const& new_root = msg->getRoot();
vt_debug_print(
verbose, group,
normal, group,
"InfoColl::finalizeTree: group={:x}, new_root={}\n",
msg->getGroup(), new_root
);
in_phase_two_ = true;
known_root_node_ = new_root;
has_root_ = true;
is_default_group_ = msg->isDefault();
nodes_ = msg->getNodes();
finalize();
}

void InfoColl::downTreeFinished([[maybe_unused]] GroupOnlyMsg* msg) {
void InfoColl::downTreeFinished([[maybe_unused]] GroupCollectiveFinalMsg* msg) {
send_down_finished_++;
finalize();
}
Expand All @@ -779,15 +818,15 @@ void InfoColl::downTreeFinished([[maybe_unused]] GroupOnlyMsg* msg) {
return upHan(msg);
}

/*static*/ void InfoColl::downFinishedHan(GroupOnlyMsg* msg) {
/*static*/ void InfoColl::downFinishedHan(GroupCollectiveFinalMsg* msg) {
return tree(msg);
}

/*static*/ void InfoColl::finalizeHan(GroupOnlyMsg* msg) {
/*static*/ void InfoColl::finalizeHan(GroupCollectiveFinalMsg* msg) {
return tree(msg);
}

/*static*/ void InfoColl::newTreeHan(GroupOnlyMsg* msg) {
/*static*/ void InfoColl::newTreeHan(GroupCollectiveFinalMsg* msg) {
return tree(msg);
}

Expand Down
Loading

0 comments on commit 907d0ef

Please sign in to comment.