Merge branch 'branch-0.41' into bug/fix-rmm-imports
Matt711 authored Oct 7, 2024
2 parents 3e2338b + a7d36f5 commit ece472c
Showing 8 changed files with 171 additions and 67 deletions.
35 changes: 30 additions & 5 deletions .github/workflows/pr.yaml
@@ -12,6 +12,7 @@ concurrency:
 jobs:
   pr-builder:
     needs:
+      - changed-files
       - checks
       - conda-cpp-build
       - docs-build
@@ -25,6 +26,25 @@ jobs:
       - wheel-tests-distributed-ucxx
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: always()
+    with:
+      needs: ${{ toJSON(needs) }}
+  changed-files:
+    secrets: inherit
+    uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    with:
+      files_yaml: |
+        test_cpp:
+          - '**'
+          - '!.pre-commit-config.yaml'
+          - '!README.md'
+          - '!docs/**'
+          - '!python/**'
+        test_python:
+          - '**'
+          - '!.pre-commit-config.yaml'
+          - '!README.md'
+          - '!docs/**'
   checks:
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
@@ -47,23 +67,26 @@ jobs:
       container_image: "rapidsai/ci-conda:latest"
       run_script: "ci/build_docs.sh"
   conda-cpp-tests:
-    needs: conda-cpp-build
+    needs: [conda-cpp-build, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_cpp
     with:
       build_type: pull-request
       container-options: "--cap-add CAP_SYS_PTRACE --shm-size=8g --ulimit=nofile=1000000:1000000"
   conda-python-tests:
-    needs: conda-cpp-build
+    needs: [conda-cpp-build, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python
    with:
       build_type: pull-request
       container-options: "--cap-add CAP_SYS_PTRACE --shm-size=8g --ulimit=nofile=1000000:1000000"
   conda-python-distributed-tests:
-    needs: conda-cpp-build
+    needs: [conda-cpp-build, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python
     with:
       build_type: pull-request
       script: "ci/test_python_distributed.sh"
@@ -83,9 +106,10 @@ jobs:
       build_type: pull-request
       script: ci/build_wheel_ucxx.sh
   wheel-tests-ucxx:
-    needs: wheel-build-ucxx
+    needs: [wheel-build-ucxx, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python
     with:
       build_type: pull-request
       container-options: "--cap-add CAP_SYS_PTRACE --shm-size=8g --ulimit=nofile=1000000:1000000"
@@ -98,9 +122,10 @@ jobs:
       build_type: pull-request
       script: ci/build_wheel_distributed_ucxx.sh
   wheel-tests-distributed-ucxx:
-    needs: [wheel-build-ucxx, wheel-build-distributed-ucxx]
+    needs: [wheel-build-ucxx, wheel-build-distributed-ucxx, changed-files]
     secrets: inherit
     uses: rapidsai/shared-workflows/.github/workflows/[email protected]
+    if: fromJSON(needs.changed-files.outputs.changed_file_groups).test_python
     with:
       build_type: pull-request
       container-options: "--cap-add CAP_SYS_PTRACE --shm-size=8g --ulimit=nofile=1000000:1000000"
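Note: with this change, docs-only or otherwise non-code PRs skip the expensive test jobs. The new changed-files job publishes per-group booleans that each test job consumes through fromJSON(needs.changed-files.outputs.changed_file_groups), and pr-builder now runs with if: always() and receives toJSON(needs), presumably so that skipped-but-not-failed jobs do not break the required status check. As a rough illustration of the file-group matching (a hypothetical helper; the real logic lives in the shared changed-files workflow, not in this repository):

# Illustrative sketch of the files_yaml gating above; not the shared
# workflow's actual implementation.
from fnmatch import fnmatch

FILE_GROUPS = {
    "test_cpp": ["**", "!.pre-commit-config.yaml", "!README.md", "!docs/**", "!python/**"],
    "test_python": ["**", "!.pre-commit-config.yaml", "!README.md", "!docs/**"],
}

def group_matches(changed_paths, patterns):
    # A path counts if it matches some include pattern and no exclude
    # pattern (prefixed with '!'); a group triggers its test jobs if
    # any changed path counts.
    includes = [p for p in patterns if not p.startswith("!")]
    excludes = [p[1:] for p in patterns if p.startswith("!")]

    def counts(path):
        return any(fnmatch(path, p) for p in includes) and not any(
            fnmatch(path, p) for p in excludes
        )

    return any(counts(p) for p in changed_paths)

# A docs-only PR skips both the C++ and the Python test jobs:
print({g: group_matches(["docs/index.rst"], pats) for g, pats in FILE_GROUPS.items()})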
2 changes: 1 addition & 1 deletion cpp/src/utils/callback_notifier.cpp
@@ -65,7 +65,7 @@ bool CallbackNotifier::wait(uint64_t period,
   bool ret = false;
   for (size_t i = 0; i < attempts; ++i) {
     ret = _conditionVariable.wait_for(
-      lock, std::chrono::duration<uint64_t, std::nano>(period), [this]() {
+      lock, std::chrono::duration<uint64_t, std::nano>(signalInterval), [this]() {
         return _flag.load(std::memory_order_relaxed) == true;
       });
     if (signalWorkerFunction) signalWorkerFunction();
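The one-line fix above makes each retry of the wait loop block for one signalInterval instead of the full period, so signalWorkerFunction runs at the intended cadence rather than once per period. A minimal Python analogue of the loop, assuming (as the names in the diff suggest) that attempts is derived from period and signalInterval in code not shown here:

# Python analogue of CallbackNotifier::wait; illustration only, the real
# implementation is the C++ above.
import threading

def wait(flag: threading.Event, cond: threading.Condition,
         period_ns: int, signal_interval_ns: int, signal_worker=None) -> bool:
    attempts = max(1, period_ns // signal_interval_ns)
    ret = False
    with cond:
        for _ in range(attempts):
            # The fix: block for one signal interval per attempt,
            # not the whole period.
            ret = cond.wait_for(flag.is_set, timeout=signal_interval_ns / 1e9)
            if signal_worker is not None:
                signal_worker()
            if ret:
                break
    return ret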
1 change: 0 additions & 1 deletion python/distributed-ucxx/distributed_ucxx/__init__.py
@@ -1,5 +1,4 @@
 from .ucxx import UCXXBackend, UCXXConnector, UCXXListener  # noqa: F401
-from . import distributed_patches  # noqa: F401
 
 
 from ._version import __git_commit__, __version__
41 changes: 0 additions & 41 deletions python/distributed-ucxx/distributed_ucxx/distributed_patches.py

This file was deleted.

96 changes: 96 additions & 0 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
@@ -8,11 +8,13 @@
 from __future__ import annotations
 
 import functools
+import itertools
 import logging
 import os
 import struct
 import weakref
 from collections.abc import Awaitable, Callable, Collection
+from threading import Lock
 from typing import TYPE_CHECKING, Any
 from unittest.mock import patch
 
@@ -48,6 +50,13 @@
 pre_existing_cuda_context = False
 cuda_context_created = False
 multi_buffer = None
+# Lock protecting access to _resources dict
+_resources_lock = Lock()
+# Mapping from UCXX context handles to sets of registered dask resource IDs
+# Used to track when there are no more users of the context, at which point
+# its progress task and notification thread can be shut down.
+# See _register_dask_resource and _deregister_dask_resource.
+_resources = dict()
 
 
 _warning_suffix = (
@@ -90,6 +99,81 @@ def synchronize_stream(stream=0):
     stream.synchronize()
 
 
+def make_register():
+    count = itertools.count()
+
+    def register() -> int:
+        """Register a Dask resource with the resource tracker.
+
+        Generate a unique ID for the resource and register it with the resource
+        tracker. The resource ID is later used to deregister the resource from
+        the tracker calling `_deregister_dask_resource(resource_id)`, which
+        stops the notifier thread and progress tasks when no more UCXX resources
+        are alive.
+
+        Returns
+        -------
+        resource_id: int
+            The ID of the registered resource that should be used with
+            `_deregister_dask_resource` during stop/destruction of the resource.
+        """
+        ctx = ucxx.core._get_ctx()
+        handle = ctx.context.handle
+        with _resources_lock:
+            if handle not in _resources:
+                _resources[handle] = set()
+
+            resource_id = next(count)
+            _resources[handle].add(resource_id)
+        ctx.start_notifier_thread()
+        ctx.continuous_ucx_progress()
+        return resource_id
+
+    return register
+
+
+_register_dask_resource = make_register()
+
+del make_register
+
+
+def _deregister_dask_resource(resource_id):
+    """Deregister a Dask resource from the resource tracker.
+
+    Deregister a Dask resource from the resource tracker with given ID, and if
+    no resources remain after deregistration, stop the notifier thread and
+    progress tasks.
+
+    Parameters
+    ----------
+    resource_id: int
+        The unique ID of the resource returned by `_register_dask_resource` upon
+        registration.
+    """
+    if ucxx.core._ctx is None:
+        # Prevent creation of context if it was already destroyed, all
+        # registered references are already gone.
+        return
+
+    ctx = ucxx.core._get_ctx()
+    handle = ctx.context.handle
+
+    # Check if the attribute exists first, in tests the UCXX context may have
+    # been reset before some resources are deregistered.
+    with _resources_lock:
+        try:
+            _resources[handle].remove(resource_id)
+        except KeyError:
+            pass
+
+        # Stop notifier thread and progress tasks if no Dask resources using
+        # UCXX communicators are running anymore.
+        if handle in _resources and len(_resources[handle]) == 0:
+            ctx.stop_notifier_thread()
+            ctx.progress_tasks.clear()
+            del _resources[handle]
+
+
 def init_once():
     global ucxx, device_array
     global ucx_create_endpoint, ucx_create_listener
@@ -279,8 +363,12 @@ def __init__(  # type: ignore[no-untyped-def]
         else:
             self._has_close_callback = False
 
+        self._resource_id = _register_dask_resource()
+
         logger.debug("UCX.__init__ %s", self)
 
+        weakref.finalize(self, _deregister_dask_resource, self._resource_id)
+
     def __del__(self) -> None:
         self.abort()
 
@@ -488,6 +576,7 @@ def abort(self):
         if self._ep is not None:
             self._ep.abort()
             self._ep = None
+            _deregister_dask_resource(self._resource_id)
 
     def closed(self):
         if self._has_close_callback is True:
@@ -522,15 +611,19 @@ async def connect(
         init_once()
 
         try:
+            self._resource_id = _register_dask_resource()
             ep = await ucxx.create_endpoint(ip, port)
         except (
             ucxx.exceptions.UCXCloseError,
             ucxx.exceptions.UCXCanceledError,
             ucxx.exceptions.UCXConnectionResetError,
             ucxx.exceptions.UCXMessageTruncatedError,
+            ucxx.exceptions.UCXNotConnectedError,
+            ucxx.exceptions.UCXUnreachableError,
         ):
             raise CommClosedError("Connection closed before handshake completed")
+        finally:
+            _deregister_dask_resource(self._resource_id)
         return self.comm_class(
             ep,
             local_addr="",
@@ -588,10 +681,13 @@ async def serve_forever(client_ep):
             await self.comm_handler(ucx)
 
         init_once()
+        self._resource_id = _register_dask_resource()
+        weakref.finalize(self, _deregister_dask_resource, self._resource_id)
         self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port)
 
     def stop(self):
         self.ucxx_server = None
+        _deregister_dask_resource(self._resource_id)
 
     def get_host_port(self):
         # TODO: TCP raises if this hasn't started yet.
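Taken together, the ucxx.py changes reference-count users of the UCXX context: every UCX comm and UCXXListener registers a resource ID when it starts and deregisters it in abort()/stop(), with weakref.finalize as a safety net when neither runs; once the last ID for a context handle is gone, the notifier thread and progress tasks are shut down. A self-contained sketch of the pattern (Comm and the "ctx0" handle are hypothetical stand-ins, not the real UCXX objects):

# Minimal, runnable sketch of the reference-counting pattern added above.
import itertools
import weakref
from threading import Lock

_resources_lock = Lock()
_resources = {}  # context handle -> set of live resource IDs
_counter = itertools.count()

def register(handle):
    # Track one more user of the context identified by `handle`.
    with _resources_lock:
        resource_id = next(_counter)
        _resources.setdefault(handle, set()).add(resource_id)
    return resource_id

def deregister(handle, resource_id):
    # Drop one user; tear down shared machinery when the last one goes.
    with _resources_lock:
        _resources.get(handle, set()).discard(resource_id)
        if handle in _resources and not _resources[handle]:
            del _resources[handle]
            print(f"last user of {handle!r} gone: stop notifier/progress here")

class Comm:
    def __init__(self, handle):
        self._resource_id = register(handle)
        # Safety net: fires even if the comm is never closed explicitly.
        weakref.finalize(self, deregister, handle, self._resource_id)

c1, c2 = Comm("ctx0"), Comm("ctx0")
del c1  # one user remains; nothing is torn down
del c2  # last user: the teardown message is printed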
25 changes: 18 additions & 7 deletions python/ucxx/ucxx/_lib/tests/test_endpoint.py
@@ -50,9 +50,13 @@ def _listener_handler(conn_request):
     while ep[0] is None:
         worker.progress()
 
-    wireup_msg = Array(bytearray(WireupMessageSize))
-    wireup_request = ep[0].tag_recv(wireup_msg, tag=ucx_api.UCXXTag(0))
-    wait_requests(worker, "blocking", wireup_request)
+    wireup_msg_recv = Array(bytearray(WireupMessageSize))
+    wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
+    wireup_requests = [
+        ep[0].tag_recv(wireup_msg_recv, tag=ucx_api.UCXXTag(0)),
+        ep[0].tag_send(wireup_msg_send, tag=ucx_api.UCXXTag(0)),
+    ]
+    wait_requests(worker, "blocking", wireup_requests)
 
     if server_close_callback is True:
         while closed[0] is False:
@@ -72,13 +76,20 @@ def _client(port, server_close_callback):
         port,
         endpoint_error_handling=True,
     )
-    worker.progress()
-    wireup_msg = Array(bytes(os.urandom(WireupMessageSize)))
-    wireup_request = ep.tag_send(wireup_msg, tag=ucx_api.UCXXTag(0))
-    wait_requests(worker, "blocking", wireup_request)
     if server_close_callback is False:
         closed = [False]
         ep.set_close_callback(_close_callback, cb_args=(closed,))
-        while closed[0] is False:
-            worker.progress()
+    worker.progress()
+
+    wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
+    wireup_msg_recv = Array(bytearray(WireupMessageSize))
+    wireup_requests = [
+        ep.tag_send(wireup_msg_send, tag=ucx_api.UCXXTag(0)),
+        ep.tag_recv(wireup_msg_recv, tag=ucx_api.UCXXTag(0)),
+    ]
+    wait_requests(worker, "blocking", wireup_requests)
+
+    if server_close_callback is False:
+        while closed[0] is False:
+            worker.progress()
 
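The test changes make wireup symmetric: server and client each post both a tag_recv and a tag_send and wait on the pair, so completion no longer depends on which side's message is posted first. A toy illustration of why the symmetric handshake is order-independent (thread-safe queues stand in for UCX tag messages; purely illustrative):

import queue
import threading

a_to_b, b_to_a = queue.Queue(), queue.Queue()

def peer(name, inbox, outbox, results):
    # Post the send first, then block on the recv; because both sides
    # do the same, neither can starve the other during wireup.
    outbox.put(f"wireup from {name}")
    results[name] = inbox.get(timeout=5)

results = {}
threads = [
    threading.Thread(target=peer, args=("server", a_to_b, b_to_a, results)),
    threading.Thread(target=peer, args=("client", b_to_a, a_to_b, results)),
]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(results)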