Skip to content

Commit

Permalink
Resolve thread-safety issues in distributed-ucxx
Browse files Browse the repository at this point in the history
Instead of creating a per-UCXX-context lock, whose lazy creation could
itself race when attempted from multiple threads, use a module-level lock
that is guaranteed to be thread-safe because it is created once at import
time.
  • Loading branch information
pentschev committed Oct 4, 2024
1 parent 1f4e508 commit f4db829
Showing 1 changed file with 32 additions and 48 deletions.
80 changes: 32 additions & 48 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import struct
import weakref
from collections.abc import Awaitable, Callable, Collection
from threading import Lock
from typing import TYPE_CHECKING, Any
from unittest.mock import patch

Expand Down Expand Up @@ -49,6 +50,8 @@
pre_existing_cuda_context = False
cuda_context_created = False
multi_buffer = None
_resources_lock = Lock()
_resources = dict()


_warning_suffix = (
Expand Down Expand Up @@ -95,13 +98,13 @@ def make_register():
count = itertools.count()

def register() -> int:
"""Register a Dask resource with the UCXX context.
"""Register a Dask resource with the resource tracker.
Register a Dask resource with the UCXX context and keep track of it with the
use of a unique ID for the resource. The resource ID is later used to
deregister the resource from the UCXX context calling
`_deregister_dask_resource(resource_id)`, which stops the notifier thread
and progress tasks when no more UCXX resources are alive.
Generate a unique ID for the resource and register it with the resource
tracker. The resource ID is later used to deregister the resource from
the tracker by calling `_deregister_dask_resource(resource_id)`, which
stops the notifier thread and progress tasks when no more UCXX resources
are alive.
Returns
-------
Expand All @@ -110,9 +113,14 @@ def register() -> int:
`_deregister_dask_resource` during stop/destruction of the resource.
"""
ctx = ucxx.core._get_ctx()
with ctx._dask_resources_lock:
handle = ctx.context.handle
with _resources_lock:
handle = ctx.context.handle
if handle not in _resources:
_resources[handle] = set()

resource_id = next(count)
ctx._dask_resources.add(resource_id)
_resources[handle].add(resource_id)
ctx.start_notifier_thread()
ctx.continuous_ucx_progress()
return resource_id
Expand All @@ -126,11 +134,11 @@ def register() -> int:


def _deregister_dask_resource(resource_id):
"""Deregister a Dask resource with the UCXX context.
"""Deregister a Dask resource from the resource tracker.
Deregister a Dask resource from the UCXX context with given ID, and if no
resources remain after deregistration, stop the notifier thread and progress
tasks.
Deregister a Dask resource from the resource tracker with the given ID,
and if no resources remain after deregistration, stop the notifier thread
and progress tasks.
Parameters
----------
Expand All @@ -144,40 +152,22 @@ def _deregister_dask_resource(resource_id):
return

ctx = ucxx.core._get_ctx()
handle = ctx.context.handle

# Check if the attribute exists first, in tests the UCXX context may have
# been reset before some resources are deregistered.
if hasattr(ctx, "_dask_resources_lock"):
with ctx._dask_resources_lock:
try:
ctx._dask_resources.remove(resource_id)
except KeyError:
pass

# Stop notifier thread and progress tasks if no Dask resources using
# UCXX communicators are running anymore.
if len(ctx._dask_resources) == 0:
ctx.stop_notifier_thread()
ctx.progress_tasks.clear()


def _allocate_dask_resources_tracker() -> None:
"""Allocate Dask resources tracker.
Allocate a Dask resources tracker in the UCXX context. This is useful to
track Distributed communicators so that progress and notifier threads can
be cleanly stopped when no UCXX communicators are alive anymore.
"""
ctx = ucxx.core._get_ctx()
if not hasattr(ctx, "_dask_resources"):
# TODO: Move the `Lock` to a file/module-level variable for true
# lock-safety. The approach implemented below could cause race
# conditions if this function is called simultaneously by multiple
# threads.
from threading import Lock
with _resources_lock:
try:
_resources[handle].remove(resource_id)
except KeyError:
pass

ctx._dask_resources = set()
ctx._dask_resources_lock = Lock()
# Stop notifier thread and progress tasks if no Dask resources using
# UCXX communicators are running anymore.
if handle in _resources and len(_resources[handle]) == 0:
ctx.stop_notifier_thread()
ctx.progress_tasks.clear()
del _resources[handle]


def init_once():
Expand All @@ -187,11 +177,6 @@ def init_once():
global multi_buffer

if ucxx is not None:
# Ensure reallocation of Dask resources tracker if the UCXX context was
# reset since the previous `init_once()` call. This may happen in tests,
# where the `ucxx_loop` fixture will reset the context after each test.
_allocate_dask_resources_tracker()

return

# remove/process dask.ucx flags for valid ucx options
Expand Down Expand Up @@ -254,7 +239,6 @@ def init_once():
# environment, so the user's external environment can safely
# override things here.
ucxx.init(options=ucx_config, env_takes_precedence=True)
_allocate_dask_resources_tracker()

pool_size_str = dask.config.get("distributed.rmm.pool-size")

Expand Down

0 comments on commit f4db829

Please sign in to comment.