Skip to content

Commit

Permalink
#2281: Small fixes and code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobDomagala committed Oct 10, 2024
1 parent c86244b commit 76c377c
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 55 deletions.
13 changes: 4 additions & 9 deletions src/vt/collective/reduce/allreduce/data_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,17 @@ class DataHandler<Kokkos::View<T*, Kokkos::HostSpace, Props...>> {
using Scalar = T;

static std::vector<T> toVec(const ViewType& data) {
std::vector<T> vec;
vec.resize(data.extent(0));
std::memcpy(vec.data(), data.data(), data.extent(0) * sizeof(T));
return vec;
return std::vector<T>(data.data(), data.data() + data.extent(0));
}

static ViewType fromMemory(T* data, size_t size) {
return ViewType(data, size);
}

static ViewType fromVec(const std::vector<T>& data) {
ViewType view("", data.size());
Kokkos::parallel_for(
"InitView", view.extent(0),
KOKKOS_LAMBDA(const int i) { view(i) = static_cast<float>(data[i]); });

ViewType view("view", data.size());
auto data_view = Kokkos::View<const T*, Kokkos::HostSpace, Kokkos::MemoryUnmanaged>(data.data(), data.size());
Kokkos::deep_copy(view, data_view);
return view;
}

Expand Down
39 changes: 6 additions & 33 deletions src/vt/collective/reduce/allreduce/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,61 +48,33 @@
#include "rabenseifner_msg.h"
#include "vt/messaging/message/shared_message.h"
#include <vector>
#include <type_traits>

namespace vt {
template <typename T>
using remove_cvref = std::remove_cv_t<std::remove_reference_t<T>>;
}

namespace vt::collective::reduce::allreduce {

template <typename T>
struct function_traits; // General template declaration.

// Specialization for function pointers.
template <typename Ret, typename... Args>
struct function_traits<Ret(*)(Args...)> {
using return_type = Ret;
static constexpr std::size_t arity = sizeof...(Args);
using args_tuple = std::tuple<Args...>;

template <std::size_t N>
using arg_type = typename std::tuple_element<N, std::tuple<Args...>>::type;
};

template <typename Ret, typename ObjT, typename... Args>
struct function_traits<Ret(ObjT::*)(Args...)> {
using return_type = Ret;
static constexpr std::size_t arity = sizeof...(Args);
using args_tuple = std::tuple<Args...>;

template <std::size_t N>
using arg_type = typename std::tuple_element<N, std::tuple<Args...>>::type;
};

// Primary template
template <typename Scalar, typename DataT>
struct ShouldUseView {
static constexpr bool Value = false;
};

#if MAGISTRATE_KOKKOS_ENABLED
// Partial specialization for Kokkos::View
template <typename Scalar>
struct ShouldUseView<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> {
static constexpr bool Value = true;
};
#endif // MAGISTRATE_KOKKOS_ENABLED

// Helper alias for cleaner usage
template <typename Scalar, typename DataT>
inline constexpr bool ShouldUseView_v = ShouldUseView<Scalar, DataT>::Value;

template <typename Scalar, typename DataT>
struct DataHelper {
using DataHan = DataHandler<DataT>;

template <typename... Args>
static void assignFromMem(std::vector<Scalar>& dest, const Scalar* data, size_t size) {
std::memcpy(dest.data(), data, size * sizeof(Scalar));
}

template <typename... Args>
static void assign(std::vector<Scalar>& dest, Args&&... data) {
dest = DataHan::toVec(std::forward<Args>(data)...);
Expand Down Expand Up @@ -199,4 +171,5 @@ struct DataHelper<Scalar, Kokkos::View<Scalar*, Kokkos::HostSpace>> {
#endif // MAGISTRATE_KOKKOS_ENABLED

} // namespace vt::collective::reduce::allreduce

#endif /*INCLUDED_VT_COLLECTIVE_REDUCE_ALLREDUCE_HELPERS_H*/
26 changes: 16 additions & 10 deletions src/vt/collective/reduce/allreduce/rabenseifner.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ void Rabenseifner::executeFinalHan(size_t id) {
vtAssert(state.final_handler_.valid(), "Final handler is not set!");

if constexpr (ShouldUseView_v<typename DataHandler<DataT>::Scalar, DataT>) {
state.final_handler_.send(std::move(state.val_));
state.final_handler_.send(state.val_);
} else {
state.final_handler_.send(
std::move(DataHandler<DataT>::fromVec(state.val_)));
Expand Down Expand Up @@ -234,7 +234,7 @@ void Rabenseifner::adjustForPowerOfTwo(size_t id) {

if (is_even_) {
proxy_[actual_partner]
.template sendMsg<&Rabenseifner::template adjustForPowerOfTwoRightHalf<
.template sendMsg<&Rabenseifner::adjustForPowerOfTwoRightHalf<
DataT, Scalar, Op>>(
DataHelperT::createMessage(
state.val_, state.size_ / 2, state.size_ - (state.size_ / 2), id));
Expand Down Expand Up @@ -342,7 +342,6 @@ void Rabenseifner::adjustForPowerOfTwoFinalPart(

template <typename DataT>
bool Rabenseifner::scatterAllMessagesReceived(size_t id) {
// auto const& state = states_.at(id);
auto& state = getState<RabenseifnerT, DataT>(
collection_proxy_, objgroup_proxy_, group_, id);

Expand All @@ -354,7 +353,6 @@ bool Rabenseifner::scatterAllMessagesReceived(size_t id) {

template <typename DataT>
bool Rabenseifner::scatterIsDone(size_t id) {
// auto const& state = states_.at(id);
auto& state = getState<RabenseifnerT, DataT>(
collection_proxy_, objgroup_proxy_, group_, id);
return (state.scatter_step_ == num_steps_) and
Expand All @@ -363,7 +361,6 @@ bool Rabenseifner::scatterIsDone(size_t id) {

template <typename DataT>
bool Rabenseifner::scatterIsReady(size_t id) {
// auto const& state = states_.at(id);
auto& state = getState<RabenseifnerT, DataT>(
collection_proxy_, objgroup_proxy_, group_, id);
return ((is_part_of_adjustment_group_ and state.finished_adjustment_part_) and
Expand Down Expand Up @@ -508,7 +505,6 @@ template <typename DataT>
bool Rabenseifner::gatherAllMessagesReceived(size_t id) {
auto& state = getState<RabenseifnerT, DataT>(
collection_proxy_, objgroup_proxy_, group_, id);
// auto& state = states_.at(id);
return std::all_of(
state.gather_steps_recv_.cbegin() + state.gather_step_ + 1,
state.gather_steps_recv_.cend(), [](auto const val) { return val; });
Expand All @@ -524,7 +520,6 @@ bool Rabenseifner::gatherIsDone(size_t id) {

template <typename DataT>
bool Rabenseifner::gatherIsReady(size_t id) {
// auto& state = states_.at(id);
auto& state = getState<RabenseifnerT, DataT>(
collection_proxy_, objgroup_proxy_, group_, id);
return (state.gather_step_ == num_steps_ - 1) or
Expand Down Expand Up @@ -622,7 +617,6 @@ void Rabenseifner::gatherIterHandler(RabenseifnerMsg<Scalar, DataT>* msg) {

template <typename DataT>
void Rabenseifner::finalPart(size_t id) {
// auto& state = states_.at(id);
auto& state = getState<RabenseifnerT, DataT>(
collection_proxy_, objgroup_proxy_, group_, id);
if (state.completed_) {
Expand Down Expand Up @@ -651,8 +645,8 @@ void Rabenseifner::sendToExcludedNodes(size_t id) {
auto const actual_partner = nodes_[this_node_ + 1];
vt_debug_print(
terse, allreduce,
"Rabenseifner::sendToExcludedNodes(): Sending to Node {} ID = {}\n",
actual_partner, id);
"Rabenseifner::sendToExcludedNodes(): Sending to Node {} ID = {} size={}\n",
actual_partner, id, state.size_);

proxy_[actual_partner]
.template sendMsg<
Expand All @@ -664,12 +658,24 @@ void Rabenseifner::sendToExcludedNodes(size_t id) {
template <typename DataT, typename Scalar>
void Rabenseifner::sendToExcludedNodesHandler(
RabenseifnerMsg<Scalar, DataT>* msg) {

vt_debug_print(
terse, allreduce,
"Rabenseifner::sendToExcludedNodesHandler(): Received allreduce result "
"with ID = {}\n",
msg->id_);

auto& state = getState<RabenseifnerT, DataT>(
collection_proxy_, objgroup_proxy_, group_, msg->id_
);
if constexpr (ShouldUseView_v<typename DataHandler<DataT>::Scalar, DataT>) {
state.val_ = msg->val_;
} else {
DataHelper<typename DataHandler<DataT>::Scalar, DataT>::assignFromMem(
state.val_, msg->val_, msg->size_
);
}

executeFinalHan<DataT>(msg->id_);
}

Expand Down
3 changes: 3 additions & 0 deletions src/vt/collective/reduce/allreduce/recursive_doubling.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ void RecursiveDoubling::sendToExcludedNodes(size_t id) {
template <typename DataT>
void RecursiveDoubling::sendToExcludedNodesHandler(
RecursiveDoublingMsg<DataT>* msg) {
auto& state = getState<RecursiveDoublingT, DataT>(
collection_proxy_, objgroup_proxy_, group_, msg->id_);
state.val_ = *msg->val_;
executeFinalHan<DataT>(msg->id_);
}

Expand Down
2 changes: 1 addition & 1 deletion src/vt/group/group_manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ void GroupManager::allreduce(GroupType group, Args&&... args) {
auto iter = local_collective_group_info_.find(group);
vtAssert(iter != local_collective_group_info_.end(), "Must exist");

using DataT = typename function_traits<decltype(f)>::template arg_type<0>;
using DataT = std::tuple_element_t<0, typename FuncTraits<decltype(f)>::TupleType>;

using Reducer = Rabenseifner;
auto const strong_group = collective::reduce::detail::StrongGroup{group};
Expand Down
2 changes: 1 addition & 1 deletion src/vt/vrt/collection/manager.impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -899,7 +899,7 @@ messaging::PendingSend CollectionManager::reduceLocal(
CollectionProxyWrapType<ColT> const& proxy, Args&&... args) {
using namespace collective::reduce::allreduce;

using DataT = typename function_traits<decltype(f)>::template arg_type<0>;
using DataT = std::tuple_element_t<0, typename FuncTraits<decltype(f)>::TupleType>;
using IndexT = typename ColT::IndexType;

// Get the current running index context
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/collection/test_allreduce_collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct RecursiveDoublingColl : vt::Collection<RecursiveDoublingColl, vt::Index1D
return val == expected_val;
});


ASSERT_TRUE(verify_result);
}

Expand Down Expand Up @@ -118,7 +119,7 @@ struct RecursiveDoublingColl : vt::Collection<RecursiveDoublingColl, vt::Index1D

struct TestAllreduceCollection : TestParallelHarness {};

TEST_F(TestAllreduceCollection, test_allreduce_recursive_doubling) {
TEST_F(TestAllreduceCollection, test_allreduce) {
using namespace vt::collective::reduce::allreduce;

auto const my_node = theContext()->getNode();
Expand Down

0 comments on commit 76c377c

Please sign in to comment.