Skip to content

Commit

Permalink
change name for better understanding and fixed first running errors
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffnvidia committed Aug 20, 2024
1 parent 8ef02c3 commit 3417737
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 23 deletions.
25 changes: 14 additions & 11 deletions src/components/tl/ucp/allgather/allgather_knomial.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "coll_patterns/sra_knomial.h"
#include "utils/ucc_math.h"
#include "utils/ucc_coll_utils.h"
#include <stdio.h>

#define SAVE_STATE(_phase) \
do { \
Expand Down Expand Up @@ -70,7 +71,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
if (KN_NODE_EXTRA == node_type) {
peer = ucc_knomial_pattern_get_proxy(p, rank);
if (p->type != KN_PATTERN_ALLGATHERX) {
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_2(task->allgather_kn.sbuf,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(task->allgather_kn.sbuf,
local * dt_size, mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer,broot,size)),
Expand All @@ -81,7 +82,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
goto out;
}
}
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_2(rbuf, data_size, mem_type,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(rbuf, data_size, mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer,broot,size)),
team, task, mh_list[count_mh++]),
Expand All @@ -95,7 +96,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
peer = ucc_knomial_pattern_get_extra(p, rank);
extra_count = GET_LOCAL_COUNT(args, size, peer);
peer = ucc_ep_map_eval(task->subset.map, peer);
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb_2(PTR_OFFSET(task->allgather_kn.sbuf,
UCPCHECK_GOTO(ucc_tl_ucp_recv_nb_with_mem(PTR_OFFSET(task->allgather_kn.sbuf,
local * dt_size), extra_count * dt_size,
mem_type, peer, team, task, mh_list[count_mh++]),
task, out);
Expand Down Expand Up @@ -131,7 +132,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
continue;
}
}
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_2(sbuf, local_seg_count * dt_size,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(sbuf, local_seg_count * dt_size,
mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
Expand All @@ -158,7 +159,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)
}
}
UCPCHECK_GOTO(
ucc_tl_ucp_recv_nb_2(PTR_OFFSET(rbuf, peer_seg_offset * dt_size),
ucc_tl_ucp_recv_nb_with_mem(PTR_OFFSET(rbuf, peer_seg_offset * dt_size),
peer_seg_count * dt_size, mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
Expand All @@ -179,7 +180,7 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task)

if (KN_NODE_PROXY == node_type) {
peer = ucc_knomial_pattern_get_extra(p, rank);
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_2(args->dst.info.buffer, data_size,
UCPCHECK_GOTO(ucc_tl_ucp_send_nb_with_mem(args->dst.info.buffer, data_size,
mem_type,
ucc_ep_map_eval(task->subset.map,
INV_VRANK(peer, broot, size)),
Expand Down Expand Up @@ -222,6 +223,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task)

UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_start", 0);
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
task->allgather_kn.etask_linked_list_head = (node_ucc_ee_executor_task_t *)malloc(sizeof(node_ucc_ee_executor_task_t));
task->allgather_kn.etask_linked_list_head->val = NULL;
task->allgather_kn.phase = UCC_KN_PHASE_INIT;
if (ct == UCC_COLL_TYPE_ALLGATHER) {
Expand Down Expand Up @@ -287,7 +289,7 @@ void register_memory(ucc_coll_task_t *coll_task){
ucc_rank_t peer, peer_dist;
ucc_kn_radix_t loop_step;
size_t peer_seg_count, local_seg_count;
ucc_status_t status;
// ucc_status_t status;
size_t extra_count;

ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team);
Expand All @@ -302,9 +304,9 @@ void register_memory(ucc_coll_task_t *coll_task){
UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
mmap_params.memory_type = ucc_memtype_to_ucs[mem_type];

EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test",
task->allgather_kn.etask_linked_list_head->val);
task->allgather_kn.etask_linked_list_head->val = NULL;
// EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test",
// task->allgather_kn.etask_linked_list_head->val);
// task->allgather_kn.etask_linked_list_head->val = NULL;
// UCC_KN_GOTO_PHASE(task->allgather_kn.phase);
if (KN_NODE_EXTRA == node_type) {
if (p->type != KN_PATTERN_ALLGATHERX) {
Expand Down Expand Up @@ -429,7 +431,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
ucc_tl_ucp_task_t *task;
ucc_sbgp_t *sbgp;

register_memory(*task_h);
printf("USING NEW KNOMIAL");

task = ucc_tl_ucp_init_task(coll_args, team);
ucc_mpool_init(&task->allgather_kn.etask_node_mpool, 0, sizeof(node_ucc_ee_executor_task_t),
Expand All @@ -444,6 +446,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r(
}
task->allgather_kn.p.radix = radix;
task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR;
register_memory(&task->super);
task->super.post = ucc_tl_ucp_allgather_knomial_start;
task->super.progress = ucc_tl_ucp_allgather_knomial_progress;
*task_h = &task->super;
Expand Down
12 changes: 6 additions & 6 deletions src/components/tl/ucp/tl_ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -134,14 +134,14 @@ ucc_tl_ucp_context_service_init(const char *prefix, ucp_params_t ucp_params,
return ucc_status;
}

static int memcpy_device_start(void *dest, const void *src, size_t size,
static int memcpy_device_start(void *dest, void *src, size_t size,
void *completion, void *user_data) {

ucc_status_t status;
ucc_ee_executor_task_args_t eargs;
ucc_ee_executor_t *exec;
ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *) user_data;
void *non_const_src = (void *) src;
// void *non_const_src = (void *) src;

status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
Expand All @@ -150,7 +150,7 @@ static int memcpy_device_start(void *dest, const void *src, size_t size,
}

eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.src = non_const_src;
eargs.copy.src = src;
eargs.copy.dst = dest;
eargs.copy.len = size;
node_ucc_ee_executor_task_t *new_node;
Expand All @@ -171,14 +171,14 @@ static int memcpy_device_start(void *dest, const void *src, size_t size,

}

static void memcpy_device(void *dest, const void *src, size_t size, void *user_data){
static void memcpy_device(void *dest, void *src, size_t size, void *user_data){

ucc_status_t status;
ucc_ee_executor_task_args_t eargs;
ucc_ee_executor_t *exec;
ucc_ee_executor_task_t *etask;
ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *) user_data;
void *non_const_src = (void *) src;
// void *non_const_src = (void *) src;

status = ucc_coll_task_get_executor(&task->super, &exec);
if (ucc_unlikely(status != UCC_OK)) {
Expand All @@ -187,7 +187,7 @@ static void memcpy_device(void *dest, const void *src, size_t size, void *user_d
}

eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY;
eargs.copy.src = non_const_src;
eargs.copy.src = src;
eargs.copy.dst = dest;
eargs.copy.len = size;

Expand Down
12 changes: 6 additions & 6 deletions src/components/tl/ucp/tl_ucp_sendrecv.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ ucc_tl_ucp_send_common(void *buffer, size_t msglen, ucc_memory_type_t mtype,
}

static inline ucs_status_ptr_t
ucc_tl_ucp_send_common_2(void *buffer, size_t msglen, ucc_memory_type_t mtype,
ucc_tl_ucp_send_common_with_mem(void *buffer, size_t msglen, ucc_memory_type_t mtype,
ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team,
ucc_tl_ucp_task_t *task, ucp_send_nbx_callback_t cb, void *user_data, ucp_mem_h mh)
{
Expand Down Expand Up @@ -143,13 +143,13 @@ ucc_tl_ucp_send_nb(void *buffer, size_t msglen, ucc_memory_type_t mtype,
}

static inline ucc_status_t
ucc_tl_ucp_send_nb_2(void *buffer, size_t msglen, ucc_memory_type_t mtype,
ucc_tl_ucp_send_nb_with_mem(void *buffer, size_t msglen, ucc_memory_type_t mtype,
ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team,
ucc_tl_ucp_task_t *task, ucp_mem_h mh)
{
ucs_status_ptr_t ucp_status;

ucp_status = ucc_tl_ucp_send_common_2(buffer, msglen, mtype, dest_group_rank,
ucp_status = ucc_tl_ucp_send_common_with_mem(buffer, msglen, mtype, dest_group_rank,
team, task, ucc_tl_ucp_send_completion_cb,
(void *)task, mh);
if (UCS_OK != ucp_status) {
Expand Down Expand Up @@ -206,7 +206,7 @@ ucc_tl_ucp_recv_common(void *buffer, size_t msglen, ucc_memory_type_t mtype,
}

static inline ucs_status_ptr_t
ucc_tl_ucp_recv_common_2(void *buffer, size_t msglen, ucc_memory_type_t mtype,
ucc_tl_ucp_recv_common_with_mem(void *buffer, size_t msglen, ucc_memory_type_t mtype,
ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team,
ucc_tl_ucp_task_t *task, ucp_tag_recv_nbx_callback_t cb, void *user_data, ucp_mem_h mh)
{
Expand Down Expand Up @@ -254,13 +254,13 @@ ucc_tl_ucp_recv_nb(void *buffer, size_t msglen, ucc_memory_type_t mtype,
}

static inline ucc_status_t
ucc_tl_ucp_recv_nb_2(void *buffer, size_t msglen, ucc_memory_type_t mtype,
ucc_tl_ucp_recv_nb_with_mem(void *buffer, size_t msglen, ucc_memory_type_t mtype,
ucc_rank_t dest_group_rank, ucc_tl_ucp_team_t *team,
ucc_tl_ucp_task_t *task, ucp_mem_h mh)
{
ucs_status_ptr_t ucp_status;

ucp_status = ucc_tl_ucp_recv_common_2(buffer, msglen, mtype, dest_group_rank,
ucp_status = ucc_tl_ucp_recv_common_with_mem(buffer, msglen, mtype, dest_group_rank,
team, task, ucc_tl_ucp_recv_completion_cb,
(void *)task, mh);
if (UCS_OK != ucp_status) {
Expand Down

0 comments on commit 3417737

Please sign in to comment.