Skip to content

Commit

Permalink
Fix issues creating endpoints and with wireup (#293)
Browse files Browse the repository at this point in the history
Fix bug with `create_endpoint`, where a `UCXMessageTruncatedError` may not be properly raised, as well as its use in `distributed-ucxx` not raising the exception expected by `distributed`.

Additionally, fix wireup issues in tests, where only sending (or receiving) a wireup message does not seem to suffice, but sending _and_ receiving suffices to prevent errors in tests.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: #293
  • Loading branch information
pentschev authored Oct 4, 2024
1 parent d8d5ca7 commit c5f5583
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 16 deletions.
1 change: 1 addition & 0 deletions python/distributed-ucxx/distributed_ucxx/ucxx.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,7 @@ async def connect(
ucxx.exceptions.UCXCloseError,
ucxx.exceptions.UCXCanceledError,
ucxx.exceptions.UCXConnectionResetError,
ucxx.exceptions.UCXMessageTruncatedError,
ucxx.exceptions.UCXNotConnectedError,
ucxx.exceptions.UCXUnreachableError,
):
Expand Down
25 changes: 18 additions & 7 deletions python/ucxx/ucxx/_lib/tests/test_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,13 @@ def _listener_handler(conn_request):
while ep[0] is None:
worker.progress()

wireup_msg = Array(bytearray(WireupMessageSize))
wireup_request = ep[0].tag_recv(wireup_msg, tag=ucx_api.UCXXTag(0))
wait_requests(worker, "blocking", wireup_request)
wireup_msg_recv = Array(bytearray(WireupMessageSize))
wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
wireup_requests = [
ep[0].tag_recv(wireup_msg_recv, tag=ucx_api.UCXXTag(0)),
ep[0].tag_send(wireup_msg_send, tag=ucx_api.UCXXTag(0)),
]
wait_requests(worker, "blocking", wireup_requests)

if server_close_callback is True:
while closed[0] is False:
Expand All @@ -72,13 +76,20 @@ def _client(port, server_close_callback):
port,
endpoint_error_handling=True,
)
worker.progress()
wireup_msg = Array(bytes(os.urandom(WireupMessageSize)))
wireup_request = ep.tag_send(wireup_msg, tag=ucx_api.UCXXTag(0))
wait_requests(worker, "blocking", wireup_request)
if server_close_callback is False:
closed = [False]
ep.set_close_callback(_close_callback, cb_args=(closed,))
worker.progress()

wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
wireup_msg_recv = Array(bytearray(WireupMessageSize))
wireup_requests = [
ep.tag_send(wireup_msg_send, tag=ucx_api.UCXXTag(0)),
ep.tag_recv(wireup_msg_recv, tag=ucx_api.UCXXTag(0)),
]
wait_requests(worker, "blocking", wireup_requests)

if server_close_callback is False:
while closed[0] is False:
worker.progress()

Expand Down
24 changes: 17 additions & 7 deletions python/ucxx/ucxx/_lib/tests/test_server_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _echo_server(get_queue, put_queue, transfer_api, msg_size, progress_mode):
we keep a reference to the listener's endpoint and execute transfers
outside of the callback function.
"""
# TAG is always used for wireup
feature_flags = [ucx_api.Feature.WAKEUP]
if transfer_api == "am":
feature_flags.append(ucx_api.Feature.AM)
Expand Down Expand Up @@ -75,9 +76,13 @@ def _listener_handler(conn_request):
if progress_mode == "blocking":
worker.progress()

wireup_msg = Array(bytearray(WireupMessageSize))
wireup_request = _recv(ep[0], transfer_api, wireup_msg)
wait_requests(worker, progress_mode, wireup_request)
wireup_msg_recv = Array(bytearray(WireupMessageSize))
wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
wireup_requests = [
_recv(ep[0], transfer_api, wireup_msg_recv),
_send(ep[0], transfer_api, wireup_msg_send),
]
wait_requests(worker, progress_mode, wireup_requests)

msg = Array(bytearray(msg_size))

Expand Down Expand Up @@ -110,10 +115,11 @@ def _listener_handler(conn_request):


def _echo_client(transfer_api, msg_size, progress_mode, port):
# TAG is always used for wireup
feature_flags = [ucx_api.Feature.WAKEUP]
if transfer_api == "am":
feature_flags.append(ucx_api.Feature.AM)
if transfer_api == "stream":
elif transfer_api == "stream":
feature_flags.append(ucx_api.Feature.STREAM)
else:
feature_flags.append(ucx_api.Feature.TAG)
Expand All @@ -136,9 +142,13 @@ def _echo_client(transfer_api, msg_size, progress_mode, port):
if progress_mode == "blocking":
worker.progress()

wireup_msg = Array(bytes(os.urandom(WireupMessageSize)))
wireup_request = _send(ep, transfer_api, wireup_msg)
wait_requests(worker, progress_mode, wireup_request)
wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize)))
wireup_msg_recv = Array(bytearray(WireupMessageSize))
wireup_requests = [
_send(ep, transfer_api, wireup_msg_send),
_recv(ep, transfer_api, wireup_msg_recv),
]
wait_requests(worker, progress_mode, wireup_requests)

send_msg = bytes(os.urandom(msg_size))
recv_msg = bytearray(msg_size)
Expand Down
7 changes: 5 additions & 2 deletions python/ucxx/ucxx/_lib_async/application_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,14 @@ async def create_endpoint(
listener=False,
stream_timeout=exchange_peer_info_timeout,
)
except UCXMessageTruncatedError:
except UCXMessageTruncatedError as e:
# A truncated message occurs if the remote endpoint closed before
# exchanging peer info, in that case we should raise the endpoint
# error instead.
# error, if available.
ucx_ep.raise_on_error()
# If no endpoint error is available, re-raise exception.
raise e

tags = {
"msg_send": peer_info["msg_tag"],
"msg_recv": msg_tag,
Expand Down

0 comments on commit c5f5583

Please sign in to comment.