Skip to content

Commit

Permalink
Merge pull request #10053 from Artemy-Mellanox/topic/gdr_rcache-8
Browse files Browse the repository at this point in the history
UCP/MM: Selectively exclude mem type MDs from rcache
  • Loading branch information
yosefe authored Sep 13, 2024
2 parents 5b7fb19 + 4572856 commit 27ab197
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 70 deletions.
130 changes: 67 additions & 63 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,6 @@ static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,

memh->uct[md_index] = NULL;
}
memh->md_map &= ~md_map;

if ((memh->flags & UCP_MEMH_FLAG_MLOCKED) &&
(context->gva_md_map[memh->mem_type] & memh->md_map) == 0) {
Expand Down Expand Up @@ -417,9 +416,19 @@ static void ucp_memh_put_rcache(ucp_context_h context, ucp_mem_h memh)
UCP_THREAD_CS_EXIT(&context->mt_lock);
}

static void ucp_memh_dereg_all(ucp_context_h context, ucp_mem_h memh)
{
if (memh->parent == memh) {
ucp_memh_dereg(context, memh, memh->md_map);
} else {
/* Have a parent memory handle from rcache */
ucp_memh_dereg(context, memh, memh->md_map & ~memh->parent->md_map);
ucp_memh_put_rcache(context, memh->parent);
}
}

static void ucp_memh_cleanup(ucp_context_h context, ucp_mem_h memh)
{
ucp_md_map_t md_map = memh->md_map;
uct_allocated_memory_t mem;
ucs_status_t status;

Expand All @@ -433,18 +442,12 @@ static void ucp_memh_cleanup(ucp_context_h context, ucp_mem_h memh)

if (mem.method == UCT_ALLOC_METHOD_MD) {
ucs_assert(memh->alloc_md_index != UCP_NULL_RESOURCE);
mem.md = context->tl_mds[memh->alloc_md_index].md;
mem.memh = memh->uct[memh->alloc_md_index];
md_map &= ~UCS_BIT(memh->alloc_md_index);
mem.md = context->tl_mds[memh->alloc_md_index].md;
mem.memh = memh->uct[memh->alloc_md_index];
memh->md_map &= ~UCS_BIT(memh->alloc_md_index);
}

/* Have a parent memory handle from rcache */
if (memh->parent != memh) {
ucp_memh_dereg(context, memh, md_map & ~memh->parent->md_map);
ucp_memh_put_rcache(context, memh->parent);
} else {
ucp_memh_dereg(context, memh, md_map);
}
ucp_memh_dereg_all(context, memh);

/* If the memory was also allocated, release it */
if (memh->alloc_method != UCT_ALLOC_METHOD_LAST) {
Expand Down Expand Up @@ -645,44 +648,27 @@ ucs_status_t ucp_memh_register(ucp_context_h context, ucp_mem_h memh,
alloc_name, err_level, 1);
}

static size_t ucp_memh_size(ucp_context_h context)
{
return sizeof(ucp_mem_t) + (sizeof(uct_mem_h) * context->num_mds);
}

static void ucp_memh_set_uct_flags(ucp_mem_h memh, unsigned uct_flags)
{
/* When changing memh->uct_flags, must not have any existing registrations,
since those may not support the new flags */
ucs_assertv(memh->md_map == 0,
"memh=%p memh->md_map=0x%" PRIx64
" memh->uct_flags=0x%x uct_flags=0x%x",
memh, memh->md_map, memh->uct_flags, uct_flags);
memh->uct_flags = UCP_MM_UCT_ACCESS_FLAGS(uct_flags);
}

static void ucp_memh_init(ucp_mem_h memh, ucp_context_h context,
uint8_t memh_flags, unsigned uct_flags,
uct_alloc_method_t method, ucs_memory_type_t mem_type)
{
ucp_memory_info_t info;

ucp_memory_detect(context, ucp_memh_address(memh), ucp_memh_length(memh),
&info);
ucp_memh_set_uct_flags(memh, uct_flags);
memh->context = context;
memh->flags = memh_flags;
memh->alloc_md_index = UCP_NULL_RESOURCE;
memh->alloc_method = method;
memh->mem_type = mem_type;
memh->sys_dev = info.sys_dev;
memh->md_map = 0;
memh->inv_md_map = 0;
memh->uct_flags = UCP_MM_UCT_ACCESS_FLAGS(uct_flags);
memh->context = context;
memh->flags = memh_flags;
memh->alloc_md_index = UCP_NULL_RESOURCE;
memh->alloc_method = method;
memh->mem_type = mem_type;
}

static ucs_status_t
ucp_memh_create(ucp_context_h context, void *address, size_t length,
ucs_memory_type_t mem_type, uct_alloc_method_t method,
uint8_t memh_flags, unsigned uct_flags, ucp_mem_h *memh_p)
{
ucp_memory_info_t info;
ucp_mem_h memh;

memh = ucs_calloc(1, ucp_memh_size(context), "ucp_memh");
Expand All @@ -694,6 +680,10 @@ ucp_memh_create(ucp_context_h context, void *address, size_t length,
memh->super.super.end = (uintptr_t)address + length;
ucp_memh_init(memh, context, memh_flags, uct_flags, method, mem_type);

ucp_memory_detect(context, ucp_memh_address(memh), ucp_memh_length(memh),
&info);
memh->sys_dev = info.sys_dev;

*memh_p = memh;
return UCS_OK;
}
Expand Down Expand Up @@ -785,30 +775,24 @@ ucp_memh_init_from_parent(ucp_mem_h memh, ucp_md_map_t parent_md_map)
memh->flags = memh->parent->flags;

ucs_for_each_bit(md_index, parent_md_map) {
ucs_assert(memh->uct[md_index] == NULL);
memh->uct[md_index] = memh->parent->uct[md_index];
}
}

static ucs_status_t ucp_memh_init_uct_reg(ucp_context_h context, ucp_mem_h memh,
unsigned uct_flags,
const char *alloc_name)
static ucs_status_t
ucp_memh_init_uct_reg(ucp_context_h context, ucp_mem_h memh,
ucp_md_map_t reg_md_map, unsigned uct_flags,
const char *alloc_name)
{
ucs_memory_type_t mem_type = memh->mem_type;
ucp_md_map_t reg_md_map = context->reg_md_map[mem_type];
void *address = ucp_memh_address(memh);
size_t length = ucp_memh_length(memh);
ucp_md_map_t cache_md_map;
ucs_status_t status;

if (uct_flags & UCT_MD_MEM_FLAG_LOCK) {
reg_md_map |= context->reg_block_md_map[mem_type];
}

reg_md_map &= ~memh->md_map;
cache_md_map = context->cache_md_map[mem_type] & reg_md_map;

if (context->rcache == NULL) {
if ((context->rcache == NULL) || (cache_md_map == 0)) {
status = ucp_memh_register(context, memh, reg_md_map, uct_flags,
alloc_name);
if (status != UCS_OK) {
Expand All @@ -832,8 +816,6 @@ static ucs_status_t ucp_memh_init_uct_reg(ucp_context_h context, ucp_mem_h memh,
goto err_put_rcache;
}
}

ucs_assert(ucp_memh_is_user_memh(memh));
return UCS_OK;

err_put_rcache:
Expand All @@ -843,6 +825,28 @@ static ucs_status_t ucp_memh_init_uct_reg(ucp_context_h context, ucp_mem_h memh,
return status;
}

static ucs_status_t
ucp_memh_init_all_uct_reg(ucp_context_h context, ucp_mem_h memh,
unsigned uct_flags, const char *alloc_name)
{
ucs_memory_type_t mem_type = memh->mem_type;
ucp_md_map_t reg_md_map = context->reg_md_map[mem_type];
ucs_status_t status;

if (uct_flags & UCT_MD_MEM_FLAG_LOCK) {
reg_md_map |= context->reg_block_md_map[mem_type];
}

status = ucp_memh_init_uct_reg(context, memh, reg_md_map & ~memh->md_map,
uct_flags, alloc_name);
if (status != UCS_OK) {
return status;
}

ucs_assert(ucp_memh_is_user_memh(memh));
return UCS_OK;
}

static size_t ucp_memh_reg_align(ucp_context_h context, ucp_md_map_t reg_md_map)
{
size_t reg_align = UCS_RCACHE_MIN_ALIGNMENT;
Expand Down Expand Up @@ -983,7 +987,7 @@ ucp_memh_alloc(ucp_context_h context, void *address, size_t length,
goto err_dealloc;
}

status = ucp_memh_init_uct_reg(context, memh, uct_flags, alloc_name);
status = ucp_memh_init_all_uct_reg(context, memh, uct_flags, alloc_name);
if (status != UCS_OK) {
goto err_free_memh;
}
Expand Down Expand Up @@ -1117,7 +1121,7 @@ ucs_status_t ucp_mem_map(ucp_context_h context, const ucp_mem_map_params_t *para
goto out;
}

status = ucp_memh_init_uct_reg(context, memh, uct_flags, alloc_name);
status = ucp_memh_init_all_uct_reg(context, memh, uct_flags, alloc_name);
if (status != UCS_OK) {
ucs_free(memh);
}
Expand Down Expand Up @@ -1159,12 +1163,11 @@ ucs_status_t ucp_mem_unmap(ucp_context_h context, ucp_mem_h memh)

ucs_status_t ucp_mem_type_reg_buffers(ucp_worker_h worker, void *remote_addr,
size_t length, ucs_memory_type_t mem_type,
ucp_md_index_t md_index, ucp_mem_h *memh_p,
ucp_md_index_t md_index, ucp_mem_h memh,
uct_rkey_bundle_t *rkey_bundle)
{
ucp_context_h context = worker->context;
const uct_md_attr_v2_t *md_attr = &context->tl_mds[md_index].attr;
ucp_mem_h memh = NULL; /* To suppress compiler warning */
uct_md_mkey_pack_params_t params = { .field_mask = 0 };
uct_component_h cmpt;
ucp_tl_md_t *tl_md;
Expand All @@ -1178,12 +1181,15 @@ ucs_status_t ucp_mem_type_reg_buffers(ucp_worker_h worker, void *remote_addr,
goto out;
}

memh->super.super.start = (uintptr_t)remote_addr;
memh->super.super.end = (uintptr_t)remote_addr + length;
ucp_memh_init(memh, context, 0, UCT_MD_MEM_ACCESS_ALL,
UCT_ALLOC_METHOD_LAST, mem_type);

tl_md = &context->tl_mds[md_index];
cmpt = context->tl_cmpts[tl_md->cmpt_index].cmpt;

status = ucp_memh_get(context, remote_addr, length, mem_type,
UCS_BIT(md_index), UCT_MD_MEM_ACCESS_ALL, "mem_type",
&memh);
status = ucp_memh_init_uct_reg(context, memh, UCS_BIT(md_index),
UCT_MD_MEM_ACCESS_ALL, "mem_type");
if (status != UCS_OK) {
goto out;
}
Expand All @@ -1204,12 +1210,10 @@ ucs_status_t ucp_mem_type_reg_buffers(ucp_worker_h worker, void *remote_addr,
md_index, ucs_status_string(status));
goto out_dereg_mem;
}

*memh_p = memh;
return UCS_OK;

out_dereg_mem:
ucp_memh_put(memh);
ucp_memh_dereg_all(context, memh);
out:
return status;
}
Expand All @@ -1223,7 +1227,7 @@ void ucp_mem_type_unreg_buffers(ucp_worker_h worker, ucp_md_index_t md_index,
if (rkey_bundle->rkey != UCT_INVALID_RKEY) {
cmpt_index = context->tl_mds[md_index].cmpt_index;
uct_rkey_release(context->tl_cmpts[cmpt_index].cmpt, rkey_bundle);
ucp_memh_put(memh);
ucp_memh_dereg_all(context, memh);
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/ucp/core/ucp_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ ucs_status_t ucp_mem_rereg_mds(ucp_context_h context, ucp_md_map_t reg_md_map,

ucs_status_t ucp_mem_type_reg_buffers(ucp_worker_h worker, void *remote_addr,
size_t length, ucs_memory_type_t mem_type,
ucp_md_index_t md_index, ucp_mem_h *memh_p,
ucp_md_index_t md_index, ucp_mem_h memh,
uct_rkey_bundle_t *rkey_bundle);

void ucp_mem_type_unreg_buffers(ucp_worker_h worker, ucp_md_index_t md_index,
Expand Down
5 changes: 5 additions & 0 deletions src/ucp/core/ucp_mm.inl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ ucp_memh_is_zero_length(const ucp_mem_h memh)
return memh == &ucp_mem_dummy_handle.memh;
}

static UCS_F_ALWAYS_INLINE size_t ucp_memh_size(ucp_context_h context)
{
return sizeof(ucp_mem_t) + (sizeof(uct_mem_h) * context->num_mds);
}

static UCS_F_ALWAYS_INLINE void
ucp_memh_rcache_print(ucp_mem_h memh, void *address, size_t length)
{
Expand Down
12 changes: 7 additions & 5 deletions src/ucp/dt/dt.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

#include <ucp/core/ucp_ep.inl>
#include <ucp/core/ucp_request.h>
#include <ucp/core/ucp_mm.h>
#include <ucp/core/ucp_mm.inl>
#include <ucs/profile/profile.h>


Expand All @@ -31,7 +31,7 @@ UCS_PROFILE_FUNC_VOID(ucp_mem_type_unpack,
ucp_worker_h worker, void *buffer, const void *recv_data,
size_t recv_length, ucs_memory_type_t mem_type)
{
ucp_ep_h ep = worker->mem_type_ep[mem_type];
ucp_ep_h ep = worker->mem_type_ep[mem_type];
ucp_lane_index_t lane;
unsigned md_index;
ucp_mem_h memh;
Expand All @@ -42,11 +42,12 @@ UCS_PROFILE_FUNC_VOID(ucp_mem_type_unpack,
return;
}

memh = ucs_alloca(ucp_memh_size(worker->context));
lane = ucp_ep_config(ep)->key.rma_lanes[0];
md_index = ucp_ep_md_index(ep, lane);

status = ucp_mem_type_reg_buffers(worker, buffer, recv_length, mem_type,
md_index, &memh, &rkey_bundle);
md_index, memh, &rkey_bundle);
if (status != UCS_OK) {
ucs_fatal("failed to register buffer with mem type domain %s",
ucs_memory_type_names[mem_type]);
Expand All @@ -67,7 +68,7 @@ UCS_PROFILE_FUNC_VOID(ucp_mem_type_pack,
ucp_worker_h worker, void *dest, const void *src,
size_t length, ucs_memory_type_t mem_type)
{
ucp_ep_h ep = worker->mem_type_ep[mem_type];
ucp_ep_h ep = worker->mem_type_ep[mem_type];
ucp_lane_index_t lane;
ucp_md_index_t md_index;
ucs_status_t status;
Expand All @@ -78,11 +79,12 @@ UCS_PROFILE_FUNC_VOID(ucp_mem_type_pack,
return;
}

memh = ucs_alloca(ucp_memh_size(worker->context));
lane = ucp_ep_config(ep)->key.rma_lanes[0];
md_index = ucp_ep_md_index(ep, lane);

status = ucp_mem_type_reg_buffers(worker, (void *)src, length, mem_type,
md_index, &memh, &rkey_bundle);
md_index, memh, &rkey_bundle);
if (status != UCS_OK) {
ucs_fatal("failed to register buffer with mem type domain %s",
ucs_memory_type_names[mem_type]);
Expand Down
1 change: 0 additions & 1 deletion src/uct/cuda/gdr_copy/gdr_copy_md.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ uct_gdr_copy_md_query(uct_md_h uct_md, uct_md_attr_v2_t *md_attr)
uct_md_base_md_query(md_attr);
md_attr->flags = UCT_MD_FLAG_REG | UCT_MD_FLAG_NEED_RKEY;
md_attr->reg_mem_types = UCS_BIT(UCS_MEMORY_TYPE_CUDA);
md_attr->cache_mem_types = UCS_BIT(UCS_MEMORY_TYPE_CUDA);
md_attr->access_mem_types = UCS_BIT(UCS_MEMORY_TYPE_CUDA);
md_attr->rkey_packed_size = sizeof(uct_gdr_copy_key_t);

Expand Down

0 comments on commit 27ab197

Please sign in to comment.