From f4db82913c494ca8028c719c9a26a50e3d267825 Mon Sep 17 00:00:00 2001
From: Peter Andreas Entschev <peter@entschev.com>
Date: Fri, 4 Oct 2024 08:40:52 -0700
Subject: [PATCH] Resolve thread-safety issues in distributed-ucxx

Instead of creating a per-UCXX context lock that may result in race
conditions during the creation of the locks, use a module-level lock
that is guaranteed to be thread-safe as it occurs at import time.
---
 .../distributed-ucxx/distributed_ucxx/ucxx.py | 80 ++++++++-----------
 1 file changed, 32 insertions(+), 48 deletions(-)

diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py
index 1f5fc1df..8fe8dd79 100644
--- a/python/distributed-ucxx/distributed_ucxx/ucxx.py
+++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py
@@ -14,6 +14,7 @@ import struct
 import weakref
 from collections.abc import Awaitable, Callable, Collection
+from threading import Lock
 from typing import TYPE_CHECKING, Any
 from unittest.mock import patch
 
@@ -49,6 +50,8 @@ pre_existing_cuda_context = False
 cuda_context_created = False
 multi_buffer = None
 
+_resources_lock = Lock()
+_resources = dict()
 
 _warning_suffix = (
@@ -95,13 +98,13 @@ def make_register():
     count = itertools.count()
 
     def register() -> int:
-        """Register a Dask resource with the UCXX context.
+        """Register a Dask resource with the resource tracker.
 
-        Register a Dask resource with the UCXX context and keep track of it with the
-        use of a unique ID for the resource. The resource ID is later used to
-        deregister the resource from the UCXX context calling
-        `_deregister_dask_resource(resource_id)`, which stops the notifier thread
-        and progress tasks when no more UCXX resources are alive.
+        Generate a unique ID for the resource and register it with the resource
+        tracker. The resource ID is later used to deregister the resource from
+        the tracker calling `_deregister_dask_resource(resource_id)`, which
+        stops the notifier thread and progress tasks when no more UCXX resources
+        are alive.
 
         Returns
         -------
@@ -110,9 +113,14 @@ def register() -> int:
             `_deregister_dask_resource` during stop/destruction of the resource.
         """
         ctx = ucxx.core._get_ctx()
-        with ctx._dask_resources_lock:
+        handle = ctx.context.handle
+        with _resources_lock:
+            handle = ctx.context.handle
+            if handle not in _resources:
+                _resources[handle] = set()
+
             resource_id = next(count)
-            ctx._dask_resources.add(resource_id)
+            _resources[handle].add(resource_id)
         ctx.start_notifier_thread()
         ctx.continuous_ucx_progress()
         return resource_id
@@ -126,11 +134,11 @@ def register() -> int:
 def _deregister_dask_resource(resource_id):
-    """Deregister a Dask resource with the UCXX context.
+    """Deregister a Dask resource from the resource tracker.
 
-    Deregister a Dask resource from the UCXX context with given ID, and if no
-    resources remain after deregistration, stop the notifier thread and progress
-    tasks.
+    Deregister a Dask resource from the resource tracker with given ID, and if
+    no resources remain after deregistration, stop the notifier thread and
+    progress tasks.
 
     Parameters
     ----------
@@ -144,40 +152,22 @@ def _deregister_dask_resource(resource_id):
         return
 
     ctx = ucxx.core._get_ctx()
+    handle = ctx.context.handle
 
     # Check if the attribute exists first, in tests the UCXX context may have
     # been reset before some resources are deregistered.
-    if hasattr(ctx, "_dask_resources_lock"):
-        with ctx._dask_resources_lock:
-            try:
-                ctx._dask_resources.remove(resource_id)
-            except KeyError:
-                pass
-
-            # Stop notifier thread and progress tasks if no Dask resources using
-            # UCXX communicators are running anymore.
-            if len(ctx._dask_resources) == 0:
-                ctx.stop_notifier_thread()
-                ctx.progress_tasks.clear()
-
-
-def _allocate_dask_resources_tracker() -> None:
-    """Allocate Dask resources tracker.
-
-    Allocate a Dask resources tracker in the UCXX context. This is useful to
-    track Distributed communicators so that progress and notifier threads can
-    be cleanly stopped when no UCXX communicators are alive anymore.
-    """
-    ctx = ucxx.core._get_ctx()
-    if not hasattr(ctx, "_dask_resources"):
-        # TODO: Move the `Lock` to a file/module-level variable for true
-        # lock-safety. The approach implemented below could cause race
-        # conditions if this function is called simultaneously by multiple
-        # threads.
-        from threading import Lock
+    with _resources_lock:
+        try:
+            _resources[handle].remove(resource_id)
+        except KeyError:
+            pass
 
-        ctx._dask_resources = set()
-        ctx._dask_resources_lock = Lock()
+        # Stop notifier thread and progress tasks if no Dask resources using
+        # UCXX communicators are running anymore.
+        if handle in _resources and len(_resources[handle]) == 0:
+            ctx.stop_notifier_thread()
+            ctx.progress_tasks.clear()
+            del _resources[handle]
 
 
 def init_once():
@@ -187,11 +177,6 @@ def init_once():
     global multi_buffer
 
     if ucxx is not None:
-        # Ensure reallocation of Dask resources tracker if the UCXX context was
-        # reset since the previous `init_once()` call. This may happen in tests,
-        # where the `ucxx_loop` fixture will reset the context after each test.
-        _allocate_dask_resources_tracker()
-
         return
 
     # remove/process dask.ucx flags for valid ucx options
@@ -254,7 +239,6 @@ def init_once():
     # environment, so the user's external environment can safely
    # override things here.
     ucxx.init(options=ucx_config, env_takes_precedence=True)
-    _allocate_dask_resources_tracker()
 
     pool_size_str = dask.config.get("distributed.rmm.pool-size")