Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

UCS/UCP/RNDV: RNDV pipeline invalidation #10204

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions src/ucp/core/ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,6 @@ static ucs_config_field_t ucp_context_config_table[] = {
"Use two stage pipeline rendezvous protocol for intra-node GPU to GPU transfers",
ucs_offsetof(ucp_context_config_t, rndv_shm_ppln_enable), UCS_CONFIG_TYPE_BOOL},

{"RNDV_PIPELINE_ERROR_HANDLING", "n",
"Allow using error handling protocol in the rendezvous pipeline protocol\n"
"even if invalidation workflow isn't supported",
ucs_offsetof(ucp_context_config_t, rndv_errh_ppln_enable), UCS_CONFIG_TYPE_BOOL},

{"FLUSH_WORKER_EPS", "y",
"Enable flushing the worker by flushing its endpoints. Allows completing\n"
"the flush operation in a bounded time even if there are new requests on\n"
Expand Down
2 changes: 0 additions & 2 deletions src/ucp/core/ucp_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ typedef struct ucp_context_config {
size_t rndv_pipeline_send_thresh;
/** Enabling 2-stage pipeline rndv protocol */
int rndv_shm_ppln_enable;
/** Enable error handling for rndv pipeline protocol */
int rndv_errh_ppln_enable;
/** Threshold for using tag matching offload capabilities. Smaller buffers
* will not be posted to the transport. */
size_t tm_thresh;
Expand Down
6 changes: 4 additions & 2 deletions src/ucp/core/ucp_ep.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <ucp/core/ucp_listener.h>
#include <ucp/rma/rma.inl>
#include <ucp/rma/rma.h>
#include <ucp/rndv/proto_rndv.inl>

#include <ucs/datastruct/queue.h>
#include <ucs/debug/memtrack_int.h>
Expand Down Expand Up @@ -3568,7 +3569,8 @@ static void ucp_ep_req_purge_send(ucp_request_t *req, ucs_status_t status)
ucs_assertv(UCS_STATUS_IS_ERR(status), "req %p: status %s", req,
ucs_status_string(status));

if (ucp_request_memh_invalidate(req, status)) {
if (ucp_request_memh_check_invalidate(req, 0)) {
ucp_request_memh_invalidate(req, status, 0);
Comment on lines +3572 to +3573
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

having one function looks cleaner to me

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, it looks cleaner. But the purpose was to make it safe. Here is the existing code:

    << When ucp_request_memh_invalidate succeeds, it populates the `invalidate`
    << structure within request.send union, which invalidates all the other structures like `rndv`
    << Basically after invalidation you cannot use other structures within union
    if (ucp_request_memh_invalidate(req, status)) {
        << This is unsafe, `rndv` struct is overwritten by this moment!
        << Well, it works by chance, because for now `invalidate` struct contains just one pointer,
        << so it does not overwrite much from `rndv`. With this commit we extend `invalidate` struct,
        << so it's not possible to access rndv after invalidation
        if (req->send.rndv.rkey != NULL) {
            ucp_proto_rndv_rkey_destroy(req);
        }
        ucp_proto_request_zcopy_id_reset(req);
        return;
    }

So the split is done, because we need to perform some cleanup tasks on invalidated requests, but essentially we can do that only before invalidation actually happens.

I can check whether we can move these conditional cleanup tasks earlier — in that case the split is not needed.

return;
}

Expand Down Expand Up @@ -3638,7 +3640,7 @@ void ucp_ep_req_purge(ucp_ep_h ucp_ep, ucp_request_t *req,

if (req->send.uct.func == ucp_proto_progress_rndv_rtr) {
if (req->send.rndv.mdesc != NULL) {
ucs_mpool_put_inline(req->send.rndv.mdesc);
ucs_mpool_rndv_put(req->send.rndv.mdesc);
}
} else {
/* SW RMA/PUT and AMO/Post operations don't allocate local request ID
Expand Down
158 changes: 139 additions & 19 deletions src/ucp/core/ucp_mm.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "ucp_worker.h"
#include "ucp_mm.inl"

#include <ucs/datastruct/mpool.inl>
#include <ucs/debug/log.h>
#include <ucs/debug/memtrack_int.h>
#include <ucs/sys/math.h>
Expand Down Expand Up @@ -326,8 +327,8 @@ ucp_mem_map_params2uct_flags(const ucp_context_h context,
return flags;
}

static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,
ucp_md_map_t md_map)
static void ucp_memh_dereg_internal(ucp_context_h context, ucp_mem_h memh,
ucp_md_map_t md_map, int rkey_only)
{
uct_completion_t comp = {
.count = 1,
Expand All @@ -351,41 +352,71 @@ static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,
continue;
}

ucs_trace("de-registering memh[%d]=%p", md_index, memh->uct[md_index]);
ucs_assert(context->tl_mds[md_index].attr.flags & UCT_MD_FLAG_REG);

params.memh = memh->uct[md_index];
if (memh->inv_md_map & UCS_BIT(md_index)) {
params.flags = UCT_MD_MEM_DEREG_FLAG_INVALIDATE;
if (rkey_only) {
params.flags |= UCT_MD_MEM_DEREG_FLAG_INVALIDATE_RKEY_ONLY;
}
Comment on lines +358 to +360
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can pass uct_flags instead of rkey_only and pass UCT_MD_MEM_DEREG_FLAG_INVALIDATE_RKEY_ONLY directly. This would simplify this code

comp.count++;
} else {
if (rkey_only) {
continue;
}
params.flags = 0;
}

ucs_trace("de-registering %smemh[%d]=%p", (rkey_only ? "rkey of " : ""),
md_index, memh->uct[md_index]);
ucs_assert(context->tl_mds[md_index].attr.flags & UCT_MD_FLAG_REG);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ucs_assertv() and print the flags


status = uct_md_mem_dereg_v2(context->tl_mds[md_index].md, &params);
if (status != UCS_OK) {
ucs_warn("failed to dereg from md[%d]=%s: %s", md_index,
ucs_warn("failed to dereg %sfrom md[%d]=%s: %s",
(rkey_only? "rkey " : ""), md_index,
context->tl_mds[md_index].rsc.md_name,
ucs_status_string(status));
if (params.flags & UCT_MD_MEM_DEREG_FLAG_INVALIDATE) {
comp.count--;
}
}

memh->uct[md_index] = NULL;
if (rkey_only) {
memh->inv_md_map &= ~UCS_BIT(md_index);
} else {
memh->uct[md_index] = NULL;
}
}

ucs_assert(comp.count == 1);
if (rkey_only) {
return;
}

if ((memh->flags & UCP_MEMH_FLAG_MLOCKED) &&
(context->gva_md_map[memh->mem_type] & memh->md_map) == 0) {
munlock(ucp_memh_address(memh), ucp_memh_length(memh));
memh->flags &= ~UCP_MEMH_FLAG_MLOCKED;
}
}

ucs_assert(comp.count == 1);
/* Fully deregister @md_map memory domains of @memh: UCT handles are released
 * and cleared (rkey_only=0 path of ucp_memh_dereg_internal), and the memory is
 * munlock'ed if it was the last MLOCKED user. */
static void ucp_memh_dereg(ucp_context_h context, ucp_mem_h memh,
                           ucp_md_map_t md_map)
{
    ucp_memh_dereg_internal(context, memh, md_map, 0);
}

/* Invalidate only the remote keys of @memh on the MDs in @md_map: the
 * rkey_only=1 path of ucp_memh_dereg_internal adds
 * UCT_MD_MEM_DEREG_FLAG_INVALIDATE_RKEY_ONLY, skips MDs that are not marked
 * in memh->inv_md_map, clears the handled inv_md_map bits, and keeps the
 * local UCT registrations (memh->uct[]) intact. */
static void ucp_memh_dereg_rkey(ucp_context_h context, ucp_mem_h memh,
                                ucp_md_map_t md_map)
{
    ucs_trace("memh %p: invalidate only rkeys: md_map %" PRIx64 " inv_md_map %"
              PRIx64, memh, memh->md_map, md_map);

    ucp_memh_dereg_internal(context, memh, md_map, 1);
}

void ucp_memh_invalidate(ucp_context_h context, ucp_mem_h memh,
ucs_rcache_invalidate_comp_func_t cb, void *arg,
ucs_rcache_comp_entry_t *comp,
ucp_md_map_t inv_md_map)
{
ucs_trace("memh %p: invalidate address %p length %zu md_map %" PRIx64
Expand All @@ -399,7 +430,7 @@ void ucp_memh_invalidate(ucp_context_h context, ucp_mem_h memh,
UCP_THREAD_CS_ENTER(&context->mt_lock);
memh->inv_md_map |= inv_md_map;
UCP_THREAD_CS_EXIT(&context->mt_lock);
ucs_rcache_region_invalidate(context->rcache, &memh->super, cb, arg);
ucs_rcache_region_invalidate(context->rcache, &memh->super, comp);
}

static void ucp_memh_put_rcache(ucp_context_h context, ucp_mem_h memh)
Expand Down Expand Up @@ -658,7 +689,6 @@ static void ucp_memh_init(ucp_mem_h memh, ucp_context_h context,
uint8_t memh_flags, unsigned uct_flags,
uct_alloc_method_t method, ucs_memory_type_t mem_type)
{

memh->md_map = 0;
memh->inv_md_map = 0;
memh->uct_flags = UCP_MM_UCT_ACCESS_FLAGS(uct_flags);
Expand Down Expand Up @@ -798,7 +828,8 @@ ucp_memh_init_uct_reg(ucp_context_h context, ucp_mem_h memh,

cache_md_map = context->cache_md_map[mem_type] & reg_md_map;

if ((context->rcache == NULL) || (cache_md_map == 0)) {
if ((context->rcache == NULL) || (cache_md_map == 0) ||
(uct_flags & UCT_MD_MEM_FLAG_NO_RCACHE)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should not be a new UCT flag because it has no meaning in UCT and is only used in UCP

status = ucp_memh_register(context, memh, reg_md_map, uct_flags,
alloc_name);
if (status != UCS_OK) {
Expand Down Expand Up @@ -909,8 +940,7 @@ ucp_memh_find_slow(ucp_context_h context, void *address, size_t length,
uct_flags |= UCP_MM_UCT_ACCESS_FLAGS(memh->uct_flags);

/* Invalidate the mismatching region and get a new one */
ucs_rcache_region_invalidate(context->rcache, &memh->super,
ucs_empty_function, NULL);
ucs_rcache_region_invalidate(context->rcache, &memh->super, NULL);
ucp_memh_put(memh);
}
}
Expand Down Expand Up @@ -1357,7 +1387,8 @@ void ucp_mpool_obj_init(ucs_mpool_t *mp, void *obj, void *chunk)
{
ucp_mem_desc_t *elem_hdr = obj;
ucp_mem_desc_t *chunk_hdr = (ucp_mem_desc_t*)((ucp_mem_desc_t*)chunk - 1);
elem_hdr->memh = chunk_hdr->memh;
elem_hdr->memh = chunk_hdr->memh;
elem_hdr->chunk = NULL;
}

static ucs_status_t
Expand All @@ -1367,6 +1398,9 @@ ucp_rndv_frag_malloc_mpools(ucs_mpool_t *mp, size_t *size_p, void **chunk_p)
ucp_context_h context = mpriv->worker->context;
ucs_memory_type_t mem_type = mpriv->mem_type;
size_t frag_size = context->config.ext.rndv_frag_size[mem_type];
unsigned uct_flags = UCT_MD_MEM_ACCESS_RMA |
UCT_MD_MEM_FLAG_LOCK |
UCT_MD_MEM_FLAG_NO_RCACHE;
ucp_rndv_frag_mp_chunk_hdr_t *chunk_hdr;
ucs_status_t status;
unsigned num_elems;
Expand All @@ -1382,12 +1416,14 @@ ucp_rndv_frag_malloc_mpools(ucs_mpool_t *mp, size_t *size_p, void **chunk_p)

/* payload; need to get default flags from ucp_mem_map_params2uct_flags() */
status = ucp_memh_alloc(context, NULL, frag_size * num_elems, mem_type,
UCT_MD_MEM_ACCESS_RMA | UCT_MD_MEM_FLAG_LOCK,
ucs_mpool_name(mp), &chunk_hdr->memh);
uct_flags, ucs_mpool_name(mp), &chunk_hdr->memh);
if (status != UCS_OK) {
return status;
}

/* We don't use rcache, but reuse rcache region part of memh */
ucs_list_head_init(&chunk_hdr->memh->super.comp_list);

chunk_hdr->next_frag_ptr = ucp_memh_address(chunk_hdr->memh);
*chunk_p = chunk_hdr + 1;
return UCS_OK;
Expand Down Expand Up @@ -1417,9 +1453,94 @@ void ucp_frag_mpool_obj_init(ucs_mpool_t *mp, void *obj, void *chunk)
frag_size = context->config.ext.rndv_frag_size[mem_type];
elem_hdr->memh = chunk_hdr->memh;
elem_hdr->ptr = next_frag_ptr;
elem_hdr->chunk = chunk;
chunk_hdr->next_frag_ptr = UCS_PTR_BYTE_OFFSET(next_frag_ptr, frag_size);
}

/* Take a rendezvous fragment descriptor from the worker mpool and hold an
 * extra reference on the rcache region of the chunk's memory handle, so the
 * chunk cannot be revived/freed while this fragment is in flight. */
ucp_mem_desc_t *ucp_frag_mpool_get(ucs_mpool_t *mpool)
{
    ucp_mem_desc_t *frag_desc;

    frag_desc = ucp_worker_mpool_get(mpool);
    frag_desc->memh->super.refcount++;

    return frag_desc;
}

/* Release one reference on a rendezvous fragment descriptor.
 *
 * If the owning chunk is not marked for invalidation, the descriptor simply
 * goes back to the mpool. Otherwise the descriptor stays out of the freelist;
 * once the last in-flight reference is dropped, the chunk's indirect rkeys
 * are deregistered, pending invalidation completions are signaled, and the
 * whole chunk is returned to the mpool freelist ("revived"). */
void ucp_frag_mpool_put(ucp_mem_desc_t *mdesc)
{
    ucs_mpool_t *mp            = ucs_mpool_obj_owner(mdesc);
    ucp_rndv_mpool_priv_t *mpp = ucs_mpool_priv(mp);
    ucp_mem_h memh             = mdesc->memh;
    ucs_rcache_region_t *rgn   = &memh->super;

    ucs_assert(rgn->refcount > 0);
    --rgn->refcount;

    if (!(rgn->flags & UCS_RCACHE_REGION_FLAG_INVALIDATE)) {
        /* Normal path: chunk is alive, recycle the element */
        ucs_mpool_put(mdesc);
    } else if (rgn->refcount == 0) {
        /* Chunk was marked for invalidation and all of its inflight
         * operations have completed: de-register its rkeys and revive the
         * chunk by returning it to the mpool freelist. */
        ucp_memh_dereg_rkey(mpp->worker->context, memh, memh->inv_md_map);
        ucs_assertv(0 == memh->inv_md_map, "inv_md_map=0x%lx",
                    memh->inv_md_map);

        ucs_rcache_region_completion(rgn);
        rgn->flags &= ~UCS_RCACHE_REGION_FLAG_INVALIDATE;
        ucs_mpool_add_chunk_to_freelist(mp, mdesc->chunk);
    }
}

/* Unlink every free element that belongs to the same chunk as @mdesc from the
 * mpool freelist, so an invalidated chunk's fragments can no longer be handed
 * out. The elements themselves are not freed — they live inside the chunk and
 * are restored when the chunk is re-added to the freelist. */
static void ucp_frag_mpool_remove_from_freelist(ucp_mem_desc_t *mdesc)
{
    ucs_mpool_t *mp = ucs_mpool_obj_owner(mdesc);
    ucs_mpool_elem_t **elem_p;
    ucs_mpool_elem_t *elem;
    ucp_mem_desc_t *elem_mdesc;
    unsigned removed = 0;

    for (elem_p = &mp->freelist; (elem = *elem_p) != NULL;) {
        /* descriptor payload follows the freelist element header */
        elem_mdesc = (void*)(elem + 1);
        if (elem_mdesc->chunk == mdesc->chunk) {
            *elem_p = elem->next; /* unlink, keep elem_p in place */
            ++removed;
        } else {
            elem_p = &elem->next;
        }
    }

    ucs_trace("mpool %s: removed %u elements of chunk %p from the freelist",
              ucs_mpool_name(mp), removed, mdesc->chunk);
}

/* Request invalidation of the rkeys of the chunk owning fragment @mdesc on
 * the MDs in @inv_md_map, and drop the caller's reference on the fragment.
 *
 * @comp, when non-NULL, is queued on the region's completion list and is
 * signaled once the invalidation actually takes place (when the last
 * in-flight reference is released in ucp_frag_mpool_put). */
void ucp_frag_mpool_invalidate(ucp_mem_desc_t *mdesc,
                               ucs_rcache_comp_entry_t *comp,
                               ucp_md_map_t inv_md_map)
{
    ucp_mem_h memh           = mdesc->memh;
    ucs_rcache_region_t *rgn = &memh->super;

    ucs_assert(rgn->refcount > 0);

    /* Accumulate the MDs whose rkeys must be invalidated */
    memh->inv_md_map |= inv_md_map;

    if (comp != NULL) {
        ucs_list_add_tail(&rgn->comp_list, &comp->list);
    }

    /* First invalidation request for this chunk: mark it and pull all of its
     * free elements out of the freelist so they cannot be reused. Indirect
     * rkeys of the chunk memory handle are invalidated once all inflight
     * operations have completed. */
    if (!(rgn->flags & UCS_RCACHE_REGION_FLAG_INVALIDATE)) {
        rgn->flags |= UCS_RCACHE_REGION_FLAG_INVALIDATE;
        ucp_frag_mpool_remove_from_freelist(mdesc);
    }

    ucp_frag_mpool_put(mdesc);
}

ucs_status_t ucp_reg_mpool_malloc(ucs_mpool_t *mp, size_t *size_p, void **chunk_p)
{
ucp_worker_h worker = ucs_container_of(mp, ucp_worker_t, reg_mp);
Expand Down Expand Up @@ -1948,8 +2069,7 @@ ucp_memh_import(ucp_context_h context, const void *export_mkey_buffer,
"This may indicate that exported memory handle was "
"destroyed, but imported memory handle was not",
rregion->refcount);
ucs_rcache_region_invalidate(rcache, rregion,
ucs_empty_function, NULL);
ucs_rcache_region_invalidate(rcache, rregion, NULL);
ucs_rcache_region_put_unsafe(rcache, rregion);
}
}
Expand Down
12 changes: 10 additions & 2 deletions src/ucp/core/ucp_mm.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ typedef struct ucp_mem {
struct ucp_mem_desc {
    ucp_mem_h  memh;  /* memory handle covering this element's payload */
    void      *ptr;   /* payload start; set for rndv fragments in
                         ucp_frag_mpool_obj_init */
    void      *chunk; /* owning mpool chunk for rndv fragments, used to revive
                         the chunk after invalidation; NULL for elements of
                         non-frag mpools (see ucp_mpool_obj_init) */
};


/**
* Memory descriptor details for rndv fragments.
*/
typedef struct ucp_rndv_frag_mp_chunk_hdr {
typedef struct {
ucp_mem_h memh;
void *next_frag_ptr;
} ucp_rndv_frag_mp_chunk_hdr_t;
Expand Down Expand Up @@ -128,6 +129,13 @@ void ucp_frag_mpool_free(ucs_mpool_t *mp, void *chunk);

void ucp_frag_mpool_obj_init(ucs_mpool_t *mp, void *obj, void *chunk);

/* Get a rndv fragment from the mpool; takes a reference on the rcache region
 * of the chunk's memory handle */
ucp_mem_desc_t *ucp_frag_mpool_get(ucs_mpool_t *mpool);

/* Release a rndv fragment; if its chunk is marked for invalidation and this
 * was the last reference, de-registers the chunk's rkeys and returns the
 * chunk to the mpool freelist */
void ucp_frag_mpool_put(ucp_mem_desc_t *mdesc);

/* Mark the fragment's chunk for rkey invalidation on the MDs in @inv_md_map
 * and release the fragment; @comp (optional) is signaled when the
 * invalidation completes */
void ucp_frag_mpool_invalidate(ucp_mem_desc_t *mdesc,
                               ucs_rcache_comp_entry_t *comp,
                               ucp_md_map_t inv_md_map);

/**
* Update memory registration to a specified set of memory domains.
Expand Down Expand Up @@ -177,7 +185,7 @@ ucs_status_t ucp_memh_register(ucp_context_h context, ucp_mem_h memh,
const char *alloc_name);

void ucp_memh_invalidate(ucp_context_h context, ucp_mem_h memh,
ucs_rcache_invalidate_comp_func_t cb, void *arg,
ucs_rcache_comp_entry_t *comp,
ucp_md_map_t inv_md_map);

void ucp_memh_put_slow(ucp_context_h context, ucp_mem_h memh);
Expand Down
Loading
Loading