Skip to content

Commit

Permalink
#2281: Don't use explicit Kokkos::HostSpace for Kokkos::View in Raben…
Browse files Browse the repository at this point in the history
…seifnerMsg and ShouldUseView
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent 46c73af commit 33c5e98
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 23 deletions.
4 changes: 2 additions & 2 deletions src/vt/collective/reduce/allreduce/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ struct ShouldUseView {
};

#if MAGISTRATE_KOKKOS_ENABLED
template <typename Scalar>
struct ShouldUseView<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> {
template <typename Scalar, typename... Properties>
struct ShouldUseView<Scalar, Kokkos::View<Scalar*, Properties...>> {
static constexpr bool Value = true;
};

Expand Down
8 changes: 8 additions & 0 deletions src/vt/collective/reduce/allreduce/rabenseifner.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

#include "vt/collective/reduce/allreduce/rabenseifner.h"
#include "vt/collective/reduce/allreduce/data_handler.h"
#include "vt/collective/reduce/allreduce/helpers.h"
#include "vt/config.h"
#include "vt/context/context.h"
#include "vt/configs/error/config_assert.h"
Expand Down Expand Up @@ -469,6 +470,13 @@ template <typename DataT, typename Scalar, template <typename Arg> class Op>
/*static*/ void
Rabenseifner::scatterReduceIterHandler(RabenseifnerMsg<Scalar, DataT>* msg) {
auto* reducer = getAllreducer<RabenseifnerT>(msg->info_);

vt_debug_print(
terse, allreduce,
"Rabenseifner::scatterReduceIterHandler reducer={} ID={} \n",
(reducer != nullptr), msg->id_
);

if (reducer) {
reducer->template scatterHandler<DataT, Scalar, Op>(msg);
} else {
Expand Down
3 changes: 1 addition & 2 deletions src/vt/collective/reduce/allreduce/rabenseifner_msg.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,9 @@
#define INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_RABENSEIFNER_MSG_H

#include "vt/config.h"
#include "vt/messaging/active.h"
#include "vt/configs/debug/debug_print.h"
#include "vt/collective/reduce/operators/default_msg.h"
#include "type.h"

namespace vt::collective::reduce::allreduce {

template <typename Scalar, typename DataT>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void RecursiveDoubling::storeData(size_t id, Args&&... data) {

vt_debug_print(
terse, allreduce,
"RecursiveDoubling (this={}): local_col_wait_count_={} ID={} "
"RecursiveDoubling::storeData (this={}): local_col_wait_count_={} ID={} "
"initialized={}\n",
print_ptr(this), state.local_col_wait_count_, id, state.initialized_);

Expand Down
31 changes: 15 additions & 16 deletions src/vt/collective/reduce/allreduce/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,35 +103,34 @@ struct RecursiveDoublingState : StateBase {

template <typename Scalar, typename DataT>
struct RabenseifnerState : RabenseifnerBase {
using RabenseifnerMsgT = MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>;

~RabenseifnerState() override = default;

std::vector<Scalar> val_ = {};

MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> left_adjust_message_ = nullptr;
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> right_adjust_message_ = nullptr;
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> scatter_messages_ =
{};
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> gather_messages_ =
{};
RabenseifnerMsgT left_adjust_message_ = nullptr;
RabenseifnerMsgT right_adjust_message_ = nullptr;
std::vector<RabenseifnerMsgT> scatter_messages_ = {};
std::vector<RabenseifnerMsgT> gather_messages_ = {};

vt::pipe::callback::cbunion::CallbackTyped<DataT> final_handler_ = {};
};

#if MAGISTRATE_KOKKOS_ENABLED
template <typename Scalar>
struct RabenseifnerState<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>>
template <typename Scalar, typename... Properties>
struct RabenseifnerState<Scalar, Kokkos::View<Scalar*, Properties...>>
: RabenseifnerBase {
using DataT = Kokkos::View<Scalar*, Kokkos::HostSpace>;
using DataT = Kokkos::View<Scalar*, Properties...>;
using RabenseifnerMsgT = MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>;
~RabenseifnerState() override = default;

Kokkos::View<Scalar*, Kokkos::HostSpace> val_ = {};
DataT val_ = {};

MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> left_adjust_message_ = nullptr;
MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>> right_adjust_message_ = nullptr;
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> scatter_messages_ =
{};
std::vector<MsgSharedPtr<RabenseifnerMsg<Scalar, DataT>>> gather_messages_ =
{};
RabenseifnerMsgT left_adjust_message_ = nullptr;
RabenseifnerMsgT right_adjust_message_ = nullptr;
std::vector<RabenseifnerMsgT> scatter_messages_ = {};
std::vector<RabenseifnerMsgT> gather_messages_ = {};

vt::pipe::callback::cbunion::CallbackTyped<DataT> final_handler_ = {};
};
Expand Down
1 change: 1 addition & 0 deletions src/vt/group/group_manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ void GroupManager::allreduce(GroupType group, Args&&... args) {

auto const this_node = theContext()->getNode();
auto id = StateHolder::getNextID(strong_group);

reducer->template setFinalHandler<DataT>(
theCB()->makeSend<f>(this_node), id);
reducer->template storeData<DataT, Op>(id, std::forward<Args>(args)...);
Expand Down
3 changes: 1 addition & 2 deletions src/vt/objgroup/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,7 @@ ObjGroupManager::allreduce(ProxyType<ObjT> proxy, Args&&... data) {

auto const this_node = vt::theContext()->getNode();
auto const strong_proxy = vt::collective::reduce::detail::StrongObjGroup{proxy.getProxy()};

auto const id = StateHolder::getNextID<Type>(strong_proxy);
auto const id = StateHolder::getNextID(strong_proxy);

auto* reducer = AllreduceHolder::getOrCreateAllreducer<Type>(strong_proxy);

Expand Down

0 comments on commit 33c5e98

Please sign in to comment.