Skip to content

Commit

Permalink
CUDA: mpool for pipeline staging for domain memory
Browse files Browse the repository at this point in the history
  • Loading branch information
bureddy committed Sep 28, 2017
1 parent 369d416 commit 2915cdb
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 2 deletions.
4 changes: 4 additions & 0 deletions src/ucp/core/ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,10 @@ static ucs_config_field_t ucp_config_table[] = {
"Also the value has to be bigger than UCX_TM_THRESH to take an effect." ,
ucs_offsetof(ucp_config_t, ctx.tm_max_bcopy), UCS_CONFIG_TYPE_MEMUNITS},

{"RNDV_FRAG_SIZE", "65536",
"RNDV fragment size \n",
ucs_offsetof(ucp_config_t, ctx.rndv_frag_size), UCS_CONFIG_TYPE_MEMUNITS},

{NULL}
};

Expand Down
2 changes: 2 additions & 0 deletions src/ucp/core/ucp_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ typedef struct ucp_context_config {
ucp_atomic_mode_t atomic_mode;
/** If use mutex for MT support or not */
int use_mt_mutex;
/** RNDV pipeline fragment size */
size_t rndv_frag_size;
/** On-demand progress */
int adaptive_progress;
} ucp_context_config_t;
Expand Down
102 changes: 102 additions & 0 deletions src/ucp/core/ucp_worker.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,88 @@ static ucs_stats_class_t ucp_worker_stats_class = {
#endif


static ucs_status_t ucp_mpool_dereg_mds(ucp_context_h context, ucp_mem_h memh) {
unsigned md_index, uct_index;
ucs_status_t status;

uct_index = 0;

for (md_index = 0; md_index < context->num_mds; ++md_index) {
if (!(memh->md_map & UCS_BIT(md_index))) {
continue;
}

status = uct_md_mem_dereg(context->tl_mds[md_index].md,
memh->uct[uct_index]);
if (status != UCS_OK) {
ucs_error("Failed to dereg address %p with md %s", memh->address,
context->tl_mds[md_index].rsc.md_name);
return status;
}

++uct_index;
}

return UCS_OK;
}

static ucs_status_t ucp_mpool_reg_mds(ucp_context_h context, ucp_mem_h memh) {
unsigned md_index, uct_memh_count;
ucs_status_t status;

uct_memh_count = 0;
memh->md_map = 0;

for (md_index = 0; md_index < context->num_mds; ++md_index) {
if (context->tl_mds[md_index].attr.cap.flags & UCT_MD_FLAG_REG) {
status = uct_md_mem_reg(context->tl_mds[md_index].md, memh->address,
memh->length, 0, memh->uct[uct_memh_count]);
if (status != UCS_OK) {
ucs_error("Failed to register memory pool chunk %p with md %s",
memh->address, context->tl_mds[md_index].rsc.md_name);
return status;
}

memh->md_map |= UCS_BIT(md_index);
uct_memh_count++;
}
}

return UCS_OK;
}


static ucs_status_t ucp_mpool_rndv_malloc(ucs_mpool_t *mp, size_t *size_p, void **chunk_p) {
ucp_worker_h worker = ucs_container_of(mp, ucp_worker_t, reg_mp);
ucp_mem_desc_t *chunk_hdr;
ucs_status_t status;

status = ucp_mpool_malloc(mp, size_p, chunk_p);
if (status != UCS_OK) {
ucs_error("Failed to allocate memory pool chunk: %s", ucs_status_string(status));
return UCS_ERR_NO_MEMORY;
}

chunk_hdr = (ucp_mem_desc_t *)(*chunk_p) - 1;

status = ucp_mpool_reg_mds(worker->context, chunk_hdr->memh);
if (status != UCS_OK) {
ucp_mpool_dereg_mds(worker->context, chunk_hdr->memh);
return status;
}

return UCS_OK;
}


static void ucp_mpool_rndv_free(ucs_mpool_t *mp, void *chunk) {
ucp_worker_h worker = ucs_container_of(mp, ucp_worker_t, reg_mp);
ucp_mem_desc_t *chunk_hdr = (ucp_mem_desc_t *)chunk - 1;
ucp_mpool_dereg_mds(worker->context, chunk_hdr->memh);
ucp_mpool_free(mp, chunk);
}


ucs_mpool_ops_t ucp_am_mpool_ops = {
.chunk_alloc = ucs_mpool_hugetlb_malloc,
.chunk_release = ucs_mpool_hugetlb_free,
Expand All @@ -52,6 +134,14 @@ ucs_mpool_ops_t ucp_reg_mpool_ops = {
};


ucs_mpool_ops_t ucp_rndv_frag_mpool_ops = {
.chunk_alloc = ucp_mpool_rndv_malloc,
.chunk_release = ucp_mpool_rndv_free,
.obj_init = ucs_empty_function,
.obj_cleanup = ucs_empty_function
};


void ucp_worker_iface_check_events(ucp_worker_iface_t *wiface, int force);


Expand Down Expand Up @@ -909,8 +999,19 @@ static ucs_status_t ucp_worker_init_mpools(ucp_worker_h worker,
goto err_release_am_mpool;
}


status = ucs_mpool_init(&worker->rndv_frag_mp, 0,
context->config.ext.rndv_frag_size,
0, 128, 128, UINT_MAX,
&ucp_rndv_frag_mpool_ops, "ucp_rndv_frags");
if (status != UCS_OK) {
goto err_release_reg_mpool;
}

return UCS_OK;

err_release_reg_mpool:
ucs_mpool_cleanup(&worker->reg_mp, 0);
err_release_am_mpool:
ucs_mpool_cleanup(&worker->am_mp, 0);
out:
Expand Down Expand Up @@ -1120,6 +1221,7 @@ void ucp_worker_destroy(ucp_worker_h worker)
ucp_worker_destroy_eps(worker);
ucs_mpool_cleanup(&worker->am_mp, 1);
ucs_mpool_cleanup(&worker->reg_mp, 1);
ucs_mpool_cleanup(&worker->rndv_frag_mp, 1);
ucp_worker_close_ifaces(worker);
ucp_worker_wakeup_cleanup(worker);
ucs_mpool_cleanup(&worker->req_mp, 1);
Expand Down
1 change: 1 addition & 0 deletions src/ucp/core/ucp_worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ typedef struct ucp_worker {
ucs_mpool_t am_mp; /* Memory pool for AM receives */
ucs_mpool_t reg_mp; /* Registered memory pool */
ucp_mt_lock_t mt_lock; /* Configuration of multi-threading support */
ucs_mpool_t rndv_frag_mp; /* Memory pool for RNDV fragments */

UCS_STATS_NODE_DECLARE(stats);

Expand Down
2 changes: 1 addition & 1 deletion src/uct/cuda/cuda_copy/cuda_copy_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ static ucs_config_field_t uct_cuda_copy_md_config_table[] = {

static ucs_status_t uct_cuda_copy_md_query(uct_md_h md, uct_md_attr_t *md_attr)
{
md_attr->cap.flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_RNDV_REG | UCT_MD_FLAG_ADDR_DN;
md_attr->cap.flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_ADDR_DN;
md_attr->cap.addr_dn = UCT_MD_ADDR_DOMAIN_CUDA;
md_attr->cap.max_alloc = 0;
md_attr->cap.max_reg = ULONG_MAX;
Expand Down
1 change: 0 additions & 1 deletion src/uct/ib/base/ib_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,6 @@ static ucs_status_t uct_ib_md_query(uct_md_h uct_md, uct_md_attr_t *md_attr)
md_attr->cap.max_alloc = ULONG_MAX; /* TODO query device */
md_attr->cap.max_reg = ULONG_MAX; /* TODO query device */
md_attr->cap.flags = UCT_MD_FLAG_REG |
UCT_MD_FLAG_RNDV_REG |
UCT_MD_FLAG_NEED_MEMH |
UCT_MD_FLAG_NEED_RKEY |
UCT_MD_FLAG_ADVISE;
Expand Down

0 comments on commit 2915cdb

Please sign in to comment.