diff --git a/src/components/ec/base/ucc_ec_base.h b/src/components/ec/base/ucc_ec_base.h index 9e753104e2..da76d61140 100644 --- a/src/components/ec/base/ucc_ec_base.h +++ b/src/components/ec/base/ucc_ec_base.h @@ -181,7 +181,7 @@ typedef struct ucc_ee_executor_task { typedef struct node_ucc_ee_executor_task node_ucc_ee_executor_task_t; typedef struct node_ucc_ee_executor_task { - ucc_ee_executor_task_t *val; + ucc_ee_executor_task_t *etask; node_ucc_ee_executor_task_t *next; } node_ucc_ee_executor_task_t; diff --git a/src/components/tl/ucp/allgather/allgather_knomial.c b/src/components/tl/ucp/allgather/allgather_knomial.c index dd7158b3cc..826a43f476 100644 --- a/src/components/tl/ucp/allgather/allgather_knomial.c +++ b/src/components/tl/ucp/allgather/allgather_knomial.c @@ -51,8 +51,8 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) args->root : 0; ucc_rank_t rank = VRANK(task->subset.myrank, broot, size); size_t local = GET_LOCAL_COUNT(args, size, rank); - ucp_mem_h *mh_list = task->super.mh_list; - int max_count = task->super.count_mh; + ucp_mem_h *mh_list = task->mh_list; + int max_count = task->count_mh; int count_mh = 0; void *sbuf; ptrdiff_t peer_seg_offset, local_seg_offset; @@ -63,10 +63,9 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) size_t extra_count; EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test", - task->allgather_kn.etask_linked_list_head->val); - // EXEC_TASK_TEST_2("failed to copy data to user buffer", - // task->allgather_kn.etask); - task->allgather_kn.etask_linked_list_head->val = NULL; + task->allgather_kn.etask_linked_list_head->etask); + + // task->allgather_kn.etask_linked_list_head = NULL; UCC_KN_GOTO_PHASE(task->allgather_kn.phase); if (KN_NODE_EXTRA == node_type) { peer = ucc_knomial_pattern_get_proxy(p, rank); @@ -223,8 +222,6 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task) UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_start", 0); ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); - task->allgather_kn.etask_linked_list_head = (node_ucc_ee_executor_task_t *)malloc(sizeof(node_ucc_ee_executor_task_t)); - task->allgather_kn.etask_linked_list_head->val = NULL; task->allgather_kn.phase = UCC_KN_PHASE_INIT; if (ct == UCC_COLL_TYPE_ALLGATHER) { ucc_kn_ag_pattern_init(size, rank, radix, args->dst.info.count, @@ -243,7 +240,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task) eargs.copy.len = args->src.info.count * ucc_dt_size(args->src.info.datatype); status = ucc_ee_executor_task_post(exec, &eargs, - &task->allgather_kn.etask_linked_list_head->val); + &task->allgather_kn.etask_linked_list_head->etask); if (ucc_unlikely(status != UCC_OK)) { task->super.status = status; return status; @@ -289,13 +286,13 @@ void register_memory(ucc_coll_task_t *coll_task){ ucc_rank_t peer, peer_dist; ucc_kn_radix_t loop_step; size_t peer_seg_count, local_seg_count; - // ucc_status_t status; + ucc_status_t status; size_t extra_count; ucc_tl_ucp_context_t *ctx = UCC_TL_UCP_TEAM_CTX(team); ucp_mem_map_params_t mmap_params; ucp_mem_h mh; - int size_of_list = 6; + int size_of_list = 1; int count_mh = 0; ucp_mem_h *mh_list = (ucp_mem_h *)malloc(size_of_list * sizeof(ucp_mem_h)); @@ -304,17 +301,16 @@ void register_memory(ucc_coll_task_t *coll_task){ UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE; mmap_params.memory_type = ucc_memtype_to_ucs[mem_type]; - // EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test", - // task->allgather_kn.etask_linked_list_head->val); - // task->allgather_kn.etask_linked_list_head->val = NULL; - // UCC_KN_GOTO_PHASE(task->allgather_kn.phase); if (KN_NODE_EXTRA == node_type) { if (p->type != KN_PATTERN_ALLGATHERX) { mmap_params.address = task->allgather_kn.sbuf; mmap_params.length = local * dt_size; - UCPCHECK_GOTO(ucs_status_to_ucc_status(ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh)), - task, out); + status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); + if (UCC_OK != status) { + task->super.status = status; + return; + } if (count_mh == size_of_list){ size_of_list *= 2; mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); @@ -324,8 +320,11 @@ void register_memory(ucc_coll_task_t *coll_task){ mmap_params.address = rbuf; mmap_params.length = data_size; - UCPCHECK_GOTO(ucs_status_to_ucc_status(ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh)), - task, out); + status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); + if (UCC_OK != status) { + task->super.status = status; + return; + } if (count_mh == size_of_list){ size_of_list *= 2; mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); @@ -339,8 +338,11 @@ void register_memory(ucc_coll_task_t *coll_task){ mmap_params.address = PTR_OFFSET(task->allgather_kn.sbuf, local * dt_size); mmap_params.length = extra_count * dt_size; - UCPCHECK_GOTO(ucs_status_to_ucc_status(ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh)), - task, out); + status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); + if (UCC_OK != status) { + task->super.status = status; + return; + } if (count_mh == size_of_list){ size_of_list *= 2; mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); @@ -348,6 +350,11 @@ void register_memory(ucc_coll_task_t *coll_task){ mh_list[count_mh++] = mh; } + if ((KN_NODE_EXTRA == node_type) || (KN_NODE_PROXY == node_type)) { + if (KN_NODE_EXTRA == node_type) { + goto out; + } + } while (!ucc_knomial_pattern_loop_done(p)) { ucc_kn_ag_pattern_peer_seg(rank, p, &local_seg_count, &local_seg_offset); @@ -366,8 +373,11 @@ void register_memory(ucc_coll_task_t *coll_task){ } mmap_params.address = sbuf; mmap_params.length = local_seg_count * dt_size; - UCPCHECK_GOTO(ucs_status_to_ucc_status(ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh)), - task, out); + status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); + if (UCC_OK != status) { + task->super.status = status; + return; + } if (count_mh == size_of_list){ size_of_list *= 2; mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); @@ -389,10 +399,13 @@ void register_memory(ucc_coll_task_t *coll_task){ continue; } } - mmap_params.address = PTR_OFFSET(rbuf, peer_seg_offset * dt_size); - mmap_params.length = peer_seg_count * dt_size; - UCPCHECK_GOTO(ucs_status_to_ucc_status(ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh)), - task, out); + mmap_params.address = PTR_OFFSET(rbuf, peer_seg_offset * dt_size); + mmap_params.length = peer_seg_count * dt_size; + status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); + if (UCC_OK != status) { + task->super.status = status; + return; + } if (count_mh == size_of_list){ size_of_list *= 2; mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); @@ -405,8 +418,11 @@ void register_memory(ucc_coll_task_t *coll_task){ if (KN_NODE_PROXY == node_type) { mmap_params.address = args->dst.info.buffer; mmap_params.length = data_size; - UCPCHECK_GOTO(ucs_status_to_ucc_status(ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh)), - task, out); + status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); + if (UCC_OK != status) { + task->super.status = status; + return; + } if (count_mh == size_of_list){ size_of_list *= 2; mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); @@ -417,8 +433,8 @@ void register_memory(ucc_coll_task_t *coll_task){ out: ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); task->super.status = UCC_OK; - coll_task->mh_list = mh_list; - coll_task->count_mh = count_mh-1; + task->mh_list = mh_list; + task->count_mh = count_mh-1; UCC_TL_UCP_PROFILE_REQUEST_EVENT(coll_task, "ucp_allgather_kn_done", 0); } @@ -444,9 +460,10 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_init_r( task->subset.myrank = sbgp->group_rank; task->subset.map = sbgp->map; } + register_memory(&task->super); + task->allgather_kn.etask_linked_list_head = NULL; task->allgather_kn.p.radix = radix; task->super.flags |= UCC_COLL_TASK_FLAG_EXECUTOR; - register_memory(&task->super); task->super.post = ucc_tl_ucp_allgather_knomial_start; task->super.progress = ucc_tl_ucp_allgather_knomial_progress; *task_h = &task->super; diff --git a/src/components/tl/ucp/tl_ucp_coll.h b/src/components/tl/ucp/tl_ucp_coll.h index dc864523a1..78fd3642da 100644 --- a/src/components/tl/ucp/tl_ucp_coll.h +++ b/src/components/tl/ucp/tl_ucp_coll.h @@ -57,21 +57,17 @@ void ucc_tl_ucp_team_default_score_str_free( } \ } while(0) -// #define EXEC_TASK_TEST_2(_errmsg, _etask) do { -// if (_etask != NULL) { -// status = ucc_ee_executor_task_test(_etask); -// if (status > 0) { -// task->super.status = UCC_INPROGRESS; -// return; -// } -// ucc_ee_executor_task_finalize(_etask); -// _etask = NULL; -// if (ucc_unlikely(status < 0)) { -// tl_error(UCC_TASK_LIB(task), _errmsg); -// task->super.status = status; -// return; -// } -// } +// #define MEM_MAP() do { + // status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); + // if (UCC_OK != status) { + // task->super.status = status; + // return; + // } +// if (count_mh == size_of_list){ +// size_of_list *= 2; +// mh_list = (ucp_mem_h *)realloc(mh_list, size_of_list * sizeof(ucp_mem_h)); +// } +// mh_list[count_mh++] = mh; // } while(0) #define EXEC_TASK_WAIT(_etask, ...) \ @@ -113,6 +109,8 @@ typedef struct ucc_tl_ucp_allreduce_sw_host_allgather typedef struct ucc_tl_ucp_task { ucc_coll_task_t super; uint32_t flags; + ucp_mem_h *mh_list; + int count_mh; union { struct { uint32_t send_posted; @@ -428,25 +426,30 @@ static inline ucc_status_t ucc_tl_ucp_test_with_etasks(ucc_tl_ucp_task_t *task) { int polls = 0; ucc_status_t status; - int all_tests_positive = 1; while (polls++ < task->n_polls) { node_ucc_ee_executor_task_t *current_node; + node_ucc_ee_executor_task_t *prev_node; current_node = task->allgather_kn.etask_linked_list_head; + prev_node = NULL; while(current_node != NULL) { - if (current_node->val != NULL) { - status = ucc_ee_executor_task_test(current_node->val); \ - if (status > 0) { \ - ucc_ee_executor_task_finalize(current_node->val); \ - ucp_memcpy_device_complete(current_node->val->completion, status); - current_node->val = NULL; \ - } \ - else { - all_tests_positive = 0; + status = ucc_ee_executor_task_test(current_node->etask); + if (status > 0) { + ucp_memcpy_device_complete(current_node->etask->completion, status); + ucc_ee_executor_task_finalize(current_node->etask); + if (prev_node != NULL){ + prev_node->next = current_node->next; //to remove from list + } + else{ //i'm on first node + task->allgather_kn.etask_linked_list_head = current_node->next; } } + else { + prev_node = current_node; + } + current_node = current_node->next; //to iterate to next node } - if (UCC_TL_UCP_TASK_P2P_COMPLETE(task) && all_tests_positive==1) { + if (UCC_TL_UCP_TASK_P2P_COMPLETE(task) && task->allgather_kn.etask_linked_list_head == NULL) { return UCC_OK; } ucp_worker_progress(UCC_TL_UCP_TASK_TEAM(task)->worker->ucp_worker); @@ -479,22 +482,30 @@ static inline ucc_status_t ucc_tl_ucp_test_recv(ucc_tl_ucp_task_t *task) static inline ucc_status_t ucc_tl_ucp_test_recv_with_etasks(ucc_tl_ucp_task_t *task) { int polls = 0; ucc_status_t status; - int all_tests_positive = 1; while (polls++ < task->n_polls) { node_ucc_ee_executor_task_t *current_node; + node_ucc_ee_executor_task_t *prev_node; current_node = task->allgather_kn.etask_linked_list_head; + prev_node = NULL; while(current_node != NULL) { - status = ucc_ee_executor_task_test(current_node->val); \ - if (status > 0) { \ - return UCC_INPROGRESS; \ - } \ - ucc_ee_executor_task_finalize(current_node->val); \ - ucp_memcpy_device_complete(current_node->val->completion, status); - current_node->val = NULL; \ - all_tests_positive = 0; + status = ucc_ee_executor_task_test(current_node->etask); \ + if (status > 0) { + ucp_memcpy_device_complete(current_node->etask->completion, status); \ + ucc_ee_executor_task_finalize(current_node->etask); \ + if (prev_node != NULL){ + prev_node->next = current_node->next; //to remove from list + } + else{ //i'm on first node + task->allgather_kn.etask_linked_list_head = current_node->next; + } + } + else { + prev_node = current_node; + } + current_node = current_node->next; //to iterate to next node } - if (UCC_TL_UCP_TASK_RECV_COMPLETE(task) && all_tests_positive==1) { + if (UCC_TL_UCP_TASK_RECV_COMPLETE(task) && task->allgather_kn.etask_linked_list_head==NULL) { return UCC_OK; } ucp_worker_progress(UCC_TL_UCP_TASK_TEAM(task)->worker->ucp_worker); diff --git a/src/components/tl/ucp/tl_ucp_context.c b/src/components/tl/ucp/tl_ucp_context.c index 6a86570e5b..530152fc47 100644 --- a/src/components/tl/ucp/tl_ucp_context.c +++ b/src/components/tl/ucp/tl_ucp_context.c @@ -146,7 +146,7 @@ static int memcpy_device_start(void *dest, void *src, size_t size, status = ucc_coll_task_get_executor(&task->super, &exec); if (ucc_unlikely(status != UCC_OK)) { task->super.status = status; - return 0; + return -1; } eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; @@ -156,17 +156,16 @@ static int memcpy_device_start(void *dest, void *src, size_t size, node_ucc_ee_executor_task_t *new_node; new_node = ucc_mpool_get(&task->allgather_kn.etask_node_mpool); status = ucc_ee_executor_task_post(exec, &eargs, - &new_node->val); + &new_node->etask); - new_node->next = task->allgather_kn.etask_linked_list_head; - task->allgather_kn.etask_linked_list_head = new_node; - if (ucc_unlikely(status != UCC_OK)) { task->super.status = status; - return 0; + return -1; } + new_node->next = task->allgather_kn.etask_linked_list_head; + task->allgather_kn.etask_linked_list_head = new_node; - task->allgather_kn.etask_linked_list_head->val->completion = completion; + task->allgather_kn.etask_linked_list_head->etask->completion = completion; return 1; } @@ -178,7 +177,6 @@ static void memcpy_device(void *dest, void *src, size_t size, void *user_data){ ucc_ee_executor_t *exec; ucc_ee_executor_task_t *etask; ucc_tl_ucp_task_t *task = (ucc_tl_ucp_task_t *) user_data; - // void *non_const_src = (void *) src; status = ucc_coll_task_get_executor(&task->super, &exec); if (ucc_unlikely(status != UCC_OK)) { @@ -200,17 +198,15 @@ static void memcpy_device(void *dest, void *src, size_t size, void *user_data){ // user_data->super.status = status; // return; // } - continue; } ucc_ee_executor_task_finalize(etask); return; } -ucp_worker_mem_callbacks_t copy_callback = { - +ucp_worker_mem_callbacks_t copy_callback = +{ .memcpy_device_start = memcpy_device_start, .memcpy_device = memcpy_device - }; UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t, diff --git a/src/schedule/ucc_schedule.h b/src/schedule/ucc_schedule.h index 3f071152f9..2706572e63 100644 --- a/src/schedule/ucc_schedule.h +++ b/src/schedule/ucc_schedule.h @@ -113,8 +113,6 @@ typedef struct ucc_coll_task { /* timestamp of the start time: either post or triggered_post */ double start_time; uint32_t seq_num; - ucp_mem_h *mh_list; - int count_mh; } ucc_coll_task_t; extern struct ucc_mpool_ops ucc_coll_task_mpool_ops;