diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index dc4c1a52..a9a38b8b 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -527,6 +527,7 @@ async def connect( ucxx.exceptions.UCXCloseError, ucxx.exceptions.UCXCanceledError, ucxx.exceptions.UCXConnectionResetError, + ucxx.exceptions.UCXMessageTruncatedError, ucxx.exceptions.UCXNotConnectedError, ucxx.exceptions.UCXUnreachableError, ): diff --git a/python/ucxx/ucxx/_lib/tests/test_endpoint.py b/python/ucxx/ucxx/_lib/tests/test_endpoint.py index 71a3f926..d56df10b 100644 --- a/python/ucxx/ucxx/_lib/tests/test_endpoint.py +++ b/python/ucxx/ucxx/_lib/tests/test_endpoint.py @@ -50,9 +50,13 @@ def _listener_handler(conn_request): while ep[0] is None: worker.progress() - wireup_msg = Array(bytearray(WireupMessageSize)) - wireup_request = ep[0].tag_recv(wireup_msg, tag=ucx_api.UCXXTag(0)) - wait_requests(worker, "blocking", wireup_request) + wireup_msg_recv = Array(bytearray(WireupMessageSize)) + wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize))) + wireup_requests = [ + ep[0].tag_recv(wireup_msg_recv, tag=ucx_api.UCXXTag(0)), + ep[0].tag_send(wireup_msg_send, tag=ucx_api.UCXXTag(0)), + ] + wait_requests(worker, "blocking", wireup_requests) if server_close_callback is True: while closed[0] is False: @@ -72,13 +76,20 @@ def _client(port, server_close_callback): port, endpoint_error_handling=True, ) - worker.progress() - wireup_msg = Array(bytes(os.urandom(WireupMessageSize))) - wireup_request = ep.tag_send(wireup_msg, tag=ucx_api.UCXXTag(0)) - wait_requests(worker, "blocking", wireup_request) if server_close_callback is False: closed = [False] ep.set_close_callback(_close_callback, cb_args=(closed,)) + worker.progress() + + wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize))) + wireup_msg_recv = Array(bytearray(WireupMessageSize)) + wireup_requests = [ + ep.tag_send(wireup_msg_send, tag=ucx_api.UCXXTag(0)), + ep.tag_recv(wireup_msg_recv, tag=ucx_api.UCXXTag(0)), + ] + wait_requests(worker, "blocking", wireup_requests) + + if server_close_callback is False: while closed[0] is False: worker.progress() diff --git a/python/ucxx/ucxx/_lib/tests/test_server_client.py b/python/ucxx/ucxx/_lib/tests/test_server_client.py index 4970134d..b3665302 100644 --- a/python/ucxx/ucxx/_lib/tests/test_server_client.py +++ b/python/ucxx/ucxx/_lib/tests/test_server_client.py @@ -42,6 +42,7 @@ def _echo_server(get_queue, put_queue, transfer_api, msg_size, progress_mode): we keep a reference to the listener's endpoint and execute transfers outside of the callback function. """ + # TAG is always used for wireup feature_flags = [ucx_api.Feature.WAKEUP] if transfer_api == "am": feature_flags.append(ucx_api.Feature.AM) @@ -75,9 +76,13 @@ def _listener_handler(conn_request): if progress_mode == "blocking": worker.progress() - wireup_msg = Array(bytearray(WireupMessageSize)) - wireup_request = _recv(ep[0], transfer_api, wireup_msg) - wait_requests(worker, progress_mode, wireup_request) + wireup_msg_recv = Array(bytearray(WireupMessageSize)) + wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize))) + wireup_requests = [ + _recv(ep[0], transfer_api, wireup_msg_recv), + _send(ep[0], transfer_api, wireup_msg_send), + ] + wait_requests(worker, progress_mode, wireup_requests) msg = Array(bytearray(msg_size)) @@ -110,10 +115,11 @@ def _listener_handler(conn_request): def _echo_client(transfer_api, msg_size, progress_mode, port): + # TAG is always used for wireup feature_flags = [ucx_api.Feature.WAKEUP] if transfer_api == "am": feature_flags.append(ucx_api.Feature.AM) - if transfer_api == "stream": + elif transfer_api == "stream": feature_flags.append(ucx_api.Feature.STREAM) else: feature_flags.append(ucx_api.Feature.TAG) @@ -136,9 +142,13 @@ def _echo_client(transfer_api, msg_size, progress_mode, port): if progress_mode == "blocking": worker.progress() - wireup_msg = Array(bytes(os.urandom(WireupMessageSize))) - wireup_request = _send(ep, transfer_api, wireup_msg) - wait_requests(worker, progress_mode, wireup_request) + wireup_msg_send = Array(bytes(os.urandom(WireupMessageSize))) + wireup_msg_recv = Array(bytearray(WireupMessageSize)) + wireup_requests = [ + _send(ep, transfer_api, wireup_msg_send), + _recv(ep, transfer_api, wireup_msg_recv), + ] + wait_requests(worker, progress_mode, wireup_requests) send_msg = bytes(os.urandom(msg_size)) recv_msg = bytearray(msg_size) diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py index 555c9a01..dd94a112 100644 --- a/python/ucxx/ucxx/_lib_async/application_context.py +++ b/python/ucxx/ucxx/_lib_async/application_context.py @@ -365,11 +365,14 @@ async def create_endpoint( listener=False, stream_timeout=exchange_peer_info_timeout, ) - except UCXMessageTruncatedError: + except UCXMessageTruncatedError as e: # A truncated message occurs if the remote endpoint closed before # exchanging peer info, in that case we should raise the endpoint - # error instead. + # error, if available. ucx_ep.raise_on_error() + # If no endpoint error is available, re-raise exception. + raise e + tags = { "msg_send": peer_info["msg_tag"], "msg_recv": msg_tag,