From 0e3c8aa81765d856e18ba823b7162ba680c380a5 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Wed, 26 Jul 2023 08:11:19 -0500 Subject: [PATCH 1/2] Use gtest version from conda_build_config.yaml. (#66) This PR adds appropriate pinnings for `gtest` and `gmock` in `libucxx-tests`. Authors: - Bradley Dice (https://github.com/bdice) Approvers: - Peter Andreas Entschev (https://github.com/pentschev) - Ray Douglass (https://github.com/raydouglass) URL: https://github.com/rapidsai/ucxx/pull/66 --- conda/recipes/ucxx/conda_build_config.yaml | 5 ++++- conda/recipes/ucxx/meta.yaml | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/conda/recipes/ucxx/conda_build_config.yaml b/conda/recipes/ucxx/conda_build_config.yaml index 3f1343a3..1bd2a4c7 100644 --- a/conda/recipes/ucxx/conda_build_config.yaml +++ b/conda/recipes/ucxx/conda_build_config.yaml @@ -23,5 +23,8 @@ python: ucx: - 1.14.0 -gtest_version: +gmock: + - ">=1.13.0" + +gtest: - ">=1.13.0" diff --git a/conda/recipes/ucxx/meta.yaml b/conda/recipes/ucxx/meta.yaml index 045890ea..d0169ec5 100644 --- a/conda/recipes/ucxx/meta.yaml +++ b/conda/recipes/ucxx/meta.yaml @@ -52,7 +52,7 @@ requirements: - ucx - python - librmm =23.08 - - gtest {{ gtest_version }} + - gtest outputs: - name: libucxx From 65deb675557eb03345b99346ee50232394cca9bb Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 3 Aug 2023 00:03:18 +0200 Subject: [PATCH 2/2] Fix destruction of `ucxx::RequestTagMulti` and final status (#72) This PR introduces various fixes: * Fix destruction of `ucxx::RequestTagMulti`, which after moving to inheriting from `ucxx::Request`, using `setStatus()` was missing, thus causing the `ucxx::RequestTagMulti` to never cleanup its inflight request status. Calling `setStatus()` instead of writing the `_status` attribute directly resolves this. * The status of `ucxx::RequestTagMulti` was until now always `UCS_INPROGRESS` or `UCS_OK` after completing. This is not right but there is no good way to combine the statuses of all underlying requests. Therefore we now set the status of the first failing request as the final status instead of `UCS_OK` if at least one of the requests failed. The user can still check each underlying's request status if granular information is required. * Since `ucxx::RequestTagMulti` now inherits from `ucxx::Request`, the latter's trace logging for creation/destruction are sufficient, thus removing the redundant trace logs is appropriate. * Add lifetime trace logs for `ucxx::BufferRequest`. * Execute user-defined callback within the scope fo `ucxx::Request::setStatus()` to prevent accidental reordering mistakes, as the callback needs to always execute after the status is set, because the callback itself might make use of that information. Additionally it will prevent missing execution of the callback should `setStatus()` be called elsewhere in the future. * Return `std::shared_ptr in allocateBuffer()` as returning raw pointers may cause leaks, and can be problematic specifically in `ucxx::RequestTagMulti` as the `Buffer*` will not be released unless the user takes ownership and releases it appropriately. * Introduce mutex in `ucxx::Request` because of potential concurrency issues when running with the progress thread, which a mutex can resolve in `ucxx::Request`. A common case is when `ucxx::Endpoint` is registering an inflight request and `setStatus()` has just begun running, in which case the inflight request will be removed (a no-op as it hasn't been previously registered) and then actually registered, which will cause the `ucxx::Request` to never cleanup. Authors: - Peter Andreas Entschev (https://github.com/pentschev) Approvers: - Lawrence Mitchell (https://github.com/wence-) URL: https://github.com/rapidsai/ucxx/pull/72 --- cpp/include/ucxx/buffer.h | 2 +- cpp/include/ucxx/request.h | 13 +-- cpp/include/ucxx/request_tag_multi.h | 19 +++- cpp/src/buffer.cpp | 6 +- cpp/src/request.cpp | 91 ++++++++++--------- cpp/src/request_am.cpp | 17 ++-- cpp/src/request_stream.cpp | 16 ++-- cpp/src/request_tag.cpp | 26 +++--- cpp/src/request_tag_multi.cpp | 73 +++++++-------- cpp/tests/buffer.cpp | 20 ++-- cpp/tests/worker.cpp | 2 +- dependencies.yaml | 1 + python/ucxx/_lib/libucxx.pyx | 16 ++-- python/ucxx/_lib/ucxx_api.pxd | 2 +- .../_lib_async/tests/test_send_recv_multi.py | 4 +- 15 files changed, 165 insertions(+), 143 deletions(-) diff --git a/cpp/include/ucxx/buffer.h b/cpp/include/ucxx/buffer.h index 7102169e..17471148 100644 --- a/cpp/include/ucxx/buffer.h +++ b/cpp/include/ucxx/buffer.h @@ -245,6 +245,6 @@ class RMMBuffer : public Buffer { }; #endif -Buffer* allocateBuffer(BufferType bufferType, const size_t size); +std::shared_ptr allocateBuffer(BufferType bufferType, const size_t size); } // namespace ucxx diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index bedf3aa3..1374aafc 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -23,12 +23,12 @@ namespace ucxx { class Request : public Component { protected: - std::atomic _status{UCS_INPROGRESS}; ///< Requests status - std::string _status_msg{}; ///< Human-readable status message - void* _request{nullptr}; ///< Pointer to UCP request - std::shared_ptr _future{nullptr}; ///< Future to notify upon completion - RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback - RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data + ucs_status_t _status{UCS_INPROGRESS}; ///< Requests status + std::string _status_msg{}; ///< Human-readable status message + void* _request{nullptr}; ///< Pointer to UCP request + std::shared_ptr _future{nullptr}; ///< Future to notify upon completion + RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback + RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data std::shared_ptr _worker{ nullptr}; ///< Worker that generated request (if not from endpoint) std::shared_ptr _endpoint{ @@ -39,6 +39,7 @@ class Request : public Component { nullptr}; ///< The submission object that will dispatch the request std::string _operationName{ "request_undefined"}; ///< Human-readable operation name, mostly used for log messages + std::recursive_mutex _mutex{}; ///< Mutex to prevent checking status while it's being set bool _enablePythonFuture{true}; ///< Whether Python future is enabled for this request /** diff --git a/cpp/include/ucxx/request_tag_multi.h b/cpp/include/ucxx/request_tag_multi.h index 57db19ce..ff8d9b2b 100644 --- a/cpp/include/ucxx/request_tag_multi.h +++ b/cpp/include/ucxx/request_tag_multi.h @@ -23,7 +23,15 @@ class RequestTagMulti; struct BufferRequest { std::shared_ptr request{nullptr}; ///< The `ucxx::RequestTag` of a header or frame std::shared_ptr stringBuffer{nullptr}; ///< Serialized `Header` - Buffer* buffer{nullptr}; ///< Internally allocated buffer to receive a frame + std::shared_ptr buffer{nullptr}; ///< Internally allocated buffer to receive a frame + + BufferRequest(); + ~BufferRequest(); + + BufferRequest(const BufferRequest&) = delete; + BufferRequest& operator=(BufferRequest const&) = delete; + BufferRequest(BufferRequest&& o) = delete; + BufferRequest& operator=(BufferRequest&& o) = delete; }; typedef std::shared_ptr BufferRequestPtr; @@ -34,8 +42,8 @@ class RequestTagMulti : public Request { ucp_tag_t _tag{0}; ///< Tag to match size_t _totalFrames{0}; ///< The total number of frames handled by this request std::mutex - _completedRequestsMutex{}; ///< Mutex to control access to completed requests container - std::vector _completedRequests{}; ///< Requests that already completed + _completedRequestsMutex{}; ///< Mutex to control access to completed requests container + size_t _completedRequests{0}; ///< Count requests that already completed public: std::vector _bufferRequests{}; ///< Container of all requests posted @@ -191,7 +199,10 @@ class RequestTagMulti : public Request { * * When this method is called, the request that completed will be pushed into a container * which will be later used to evaluate if all frames completed and set the final status - * of the multi-transfer request and the Python future, if enabled. + * of the multi-transfer request and the Python future, if enabled. The final status is + * either `UCS_OK` if all underlying requests completed successfully, otherwise it will + * contain the status of the first failing request, for granular information the user + * may still verify each of the underlying requests individually. * * @param[in] status the status of the request being completed. * @param[in] request the `ucxx::BufferRequest` object containing a single tag . diff --git a/cpp/src/buffer.cpp b/cpp/src/buffer.cpp index ebbc8fdc..bd63d09a 100644 --- a/cpp/src/buffer.cpp +++ b/cpp/src/buffer.cpp @@ -82,17 +82,17 @@ void* RMMBuffer::data() } #endif -Buffer* allocateBuffer(const BufferType bufferType, const size_t size) +std::shared_ptr allocateBuffer(const BufferType bufferType, const size_t size) { #if UCXX_ENABLE_RMM if (bufferType == BufferType::RMM) - return new RMMBuffer(size); + return std::make_shared(size); else #else if (bufferType == BufferType::RMM) throw std::runtime_error("RMM support not enabled, please compile with -DUCXX_ENABLE_RMM=1"); #endif - return new HostBuffer(size); + return std::make_shared(size); } } // namespace ucxx diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 13ab2b54..80365717 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -58,6 +58,7 @@ Request::~Request() { ucxx_trace("Request destroyed: %p, %s", this, _operationNa void Request::cancel() { + std::lock_guard lock(_mutex); if (_status == UCS_INPROGRESS) { if (UCS_PTR_IS_ERR(_request)) { ucs_status_t status = UCS_PTR_STATUS(_request); @@ -72,29 +73,39 @@ void Request::cancel() if (_request != nullptr) ucp_request_cancel(_worker->getHandle(), _request); } } else { - auto status = _status.load(); ucxx_trace_req_f(_ownerString.c_str(), _request, _operationName.c_str(), "already completed with status: %d (%s)", - status, - ucs_status_string(status)); + _status, + ucs_status_string(_status)); } } -ucs_status_t Request::getStatus() { return _status; } +ucs_status_t Request::getStatus() +{ + std::lock_guard lock(_mutex); + return _status; +} -void* Request::getFuture() { return _future ? _future->getHandle() : nullptr; } +void* Request::getFuture() +{ + std::lock_guard lock(_mutex); + return _future ? _future->getHandle() : nullptr; +} void Request::checkError() { - // Only load the atomic variable once - auto status = _status.load(); + std::lock_guard lock(_mutex); - utils::ucsErrorThrow(status, status == UCS_ERR_MESSAGE_TRUNCATED ? _status_msg : std::string()); + utils::ucsErrorThrow(_status, _status == UCS_ERR_MESSAGE_TRUNCATED ? _status_msg : std::string()); } -bool Request::isCompleted() { return _status != UCS_INPROGRESS; } +bool Request::isCompleted() +{ + std::lock_guard lock(_mutex); + return _status != UCS_INPROGRESS; +} void Request::callback(void* request, ucs_status_t status) { @@ -109,12 +120,11 @@ void Request::callback(void* request, ucs_status_t status) ucxx_debug("Request %p destroyed before callback() was executed", this); return; } - auto statusAttr = _status.load(); - if (statusAttr != UCS_INPROGRESS) + if (_status != UCS_INPROGRESS) ucxx_trace("Request %p has status already set to %d (%s), callback setting %d (%s)", this, - statusAttr, - ucs_status_string(statusAttr), + _status, + ucs_status_string(_status), status, ucs_status_string(status)); @@ -123,18 +133,13 @@ void Request::callback(void* request, ucs_status_t status) ucxx_trace("Request completed: %p, handle: %p", this, request); setStatus(status); ucxx_trace("Request %p, isCompleted: %d", this, isCompleted()); - - ucxx_trace_req_f(_ownerString.c_str(), - request, - _operationName.c_str(), - "callback %p", - _callback.target()); - if (_callback) _callback(status, _callbackData); } void Request::process() { - ucs_status_t status = _status.load(); + std::lock_guard lock(_mutex); + + ucs_status_t status = UCS_INPROGRESS; if (UCS_PTR_IS_ERR(_request)) { // Operation errored immediately @@ -167,34 +172,38 @@ void Request::process() _ownerString.c_str(), _request, _operationName.c_str(), "completed immediately"); } - ucxx_trace_req_f(_ownerString.c_str(), - _request, - _operationName.c_str(), - "callback %p", - _callback.target()); - if (_callback) _callback(status, _callbackData); - setStatus(status); } void Request::setStatus(ucs_status_t status) { - if (_endpoint != nullptr) _endpoint->removeInflightRequest(this); - _worker->removeInflightRequest(this); + { + std::lock_guard lock(_mutex); - ucxx_trace_req_f(_ownerString.c_str(), - _request, - _operationName.c_str(), - "callback called with status %d (%s)", - status, - ucs_status_string(status)); + if (_endpoint != nullptr) _endpoint->removeInflightRequest(this); + _worker->removeInflightRequest(this); + + ucxx_trace_req_f(_ownerString.c_str(), + _request, + _operationName.c_str(), + "callback called with status %d (%s)", + status, + ucs_status_string(status)); - if (_status != UCS_INPROGRESS) ucxx_error("setStatus called but the status was already set"); - _status.store(status); + if (_status != UCS_INPROGRESS) ucxx_error("setStatus called but the status was already set"); + _status = status; - if (_enablePythonFuture) { - auto future = std::static_pointer_cast(_future); - future->notify(status); + if (_enablePythonFuture) { + auto future = std::static_pointer_cast(_future); + future->notify(status); + } + + ucxx_trace_req_f(_ownerString.c_str(), + _request, + _operationName.c_str(), + "callback %p", + _callback.target()); + if (_callback) _callback(status, _callbackData); } } diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index c09bfcab..ba29374d 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -247,13 +247,16 @@ void RequestAm::request() if (_delayedSubmission->_send) { param.cb.send = _amSendCallback; - _request = ucp_am_send_nbx(_endpoint->getHandle(), - 0, - &_sendHeader, - sizeof(_sendHeader), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - ¶m); + void* request = ucp_am_send_nbx(_endpoint->getHandle(), + 0, + &_sendHeader, + sizeof(_sendHeader), + _delayedSubmission->_buffer, + _delayedSubmission->_length, + ¶m); + + std::lock_guard lock(_mutex); + _request = request; } else { throw ucxx::UnsupportedError( "Receiving active messages must be handled by the worker's callback"); diff --git a/cpp/src/request_stream.cpp b/cpp/src/request_stream.cpp index e44d6772..624e3734 100644 --- a/cpp/src/request_stream.cpp +++ b/cpp/src/request_stream.cpp @@ -50,21 +50,25 @@ void RequestStream::request() UCP_OP_ATTR_FIELD_USER_DATA, .datatype = ucp_dt_make_contig(1), .user_data = this}; + void* request = nullptr; if (_delayedSubmission->_send) { param.cb.send = streamSendCallback; - _request = ucp_stream_send_nbx( + request = ucp_stream_send_nbx( _endpoint->getHandle(), _delayedSubmission->_buffer, _delayedSubmission->_length, ¶m); } else { param.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS; param.flags = UCP_STREAM_RECV_FLAG_WAITALL; param.cb.recv_stream = streamRecvCallback; - _request = ucp_stream_recv_nbx(_endpoint->getHandle(), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - &_delayedSubmission->_length, - ¶m); + request = ucp_stream_recv_nbx(_endpoint->getHandle(), + _delayedSubmission->_buffer, + _delayedSubmission->_length, + &_delayedSubmission->_length, + ¶m); } + + std::lock_guard lock(_mutex); + _request = request; } void RequestStream::populateDelayedSubmission() diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp index 9c1493d2..f100ad13 100644 --- a/cpp/src/request_tag.cpp +++ b/cpp/src/request_tag.cpp @@ -99,23 +99,27 @@ void RequestTag::request() UCP_OP_ATTR_FIELD_USER_DATA, .datatype = ucp_dt_make_contig(1), .user_data = this}; + void* request = nullptr; if (_delayedSubmission->_send) { param.cb.send = tagSendCallback; - _request = ucp_tag_send_nbx(_endpoint->getHandle(), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _delayedSubmission->_tag, - ¶m); + request = ucp_tag_send_nbx(_endpoint->getHandle(), + _delayedSubmission->_buffer, + _delayedSubmission->_length, + _delayedSubmission->_tag, + ¶m); } else { param.cb.recv = tagRecvCallback; - _request = ucp_tag_recv_nbx(_worker->getHandle(), - _delayedSubmission->_buffer, - _delayedSubmission->_length, - _delayedSubmission->_tag, - tagMask, - ¶m); + request = ucp_tag_recv_nbx(_worker->getHandle(), + _delayedSubmission->_buffer, + _delayedSubmission->_length, + _delayedSubmission->_tag, + tagMask, + ¶m); } + + std::lock_guard lock(_mutex); + _request = request; } void RequestTag::populateDelayedSubmission() diff --git a/cpp/src/request_tag_multi.cpp b/cpp/src/request_tag_multi.cpp index 0fb916a9..986bbf51 100644 --- a/cpp/src/request_tag_multi.cpp +++ b/cpp/src/request_tag_multi.cpp @@ -17,6 +17,10 @@ namespace ucxx { +BufferRequest::BufferRequest() { ucxx_trace("BufferRequest created: %p", this); } + +BufferRequest::~BufferRequest() { ucxx_trace("BufferRequest destroyed: %p", this); } + RequestTagMulti::RequestTagMulti(std::shared_ptr endpoint, const bool send, const ucp_tag_t tag, @@ -28,8 +32,6 @@ RequestTagMulti::RequestTagMulti(std::shared_ptr endpoint, _send(send), _tag(tag) { - ucxx_trace_req("RequestTagMulti::RequestTagMulti: %p, send: %d, tag: %lx", this, send, _tag); - auto worker = endpoint->getWorker(); if (enablePythonFuture) _future = worker->getFuture(); } @@ -50,7 +52,6 @@ RequestTagMulti::~RequestTagMulti() */ br->request = nullptr; } - ucxx_trace("RequestTagMulti destroyed: %p", this); } std::shared_ptr createRequestTagMultiSend(std::shared_ptr endpoint, @@ -60,15 +61,12 @@ std::shared_ptr createRequestTagMultiSend(std::shared_ptr(new RequestTagMulti(endpoint, true, tag, enablePythonFuture)); if (size.size() != buffer.size() || isCUDA.size() != buffer.size()) throw std::runtime_error("All input vectors should be of equal size"); - ucxx_trace("RequestTagMulti created: %p", ret.get()); - ret->send(buffer, size, isCUDA); return ret; @@ -78,12 +76,9 @@ std::shared_ptr createRequestTagMultiRecv(std::shared_ptr(new RequestTagMulti(endpoint, false, tag, enablePythonFuture)); - ucxx_trace("RequestTagMulti created: %p", ret.get()); - ret->recvCallback(UCS_OK); return ret; @@ -159,22 +154,24 @@ void RequestTagMulti::markCompleted(ucs_status_t status, RequestCallbackUserData ucxx_trace_req("RequestTagMulti::markCompleted request: %p, tag: %lx", this, _tag); std::lock_guard lock(_completedRequestsMutex); - /* TODO: Move away from std::shared_ptr to avoid casting void* to - * BufferRequest*, or remove pointer holding entirely here since it - * is not currently used for anything besides counting completed transfers. - */ - _completedRequests.push_back(reinterpret_cast(request.get())); + if (++_completedRequests == _totalFrames) { + auto s = UCS_OK; - if (_completedRequests.size() == _totalFrames) { - // TODO: Actually handle errors - _status = UCS_OK; - if (_future) _future->notify(UCS_OK); + // Get the first non-UCS_OK status and set that as complete status + for (const auto& br : _bufferRequests) { + if (br->request) { + s = br->request->getStatus(); + if (s != UCS_OK) break; + } + } + + setStatus(s); } ucxx_trace_req("RequestTagMulti::markCompleted request: %p, tag: %lx, completed: %lu/%lu", this, _tag, - _completedRequests.size(), + _completedRequests, _totalFrames); } @@ -187,13 +184,14 @@ void RequestTagMulti::recvHeader() auto bufferRequest = std::make_shared(); _bufferRequests.push_back(bufferRequest); bufferRequest->stringBuffer = std::make_shared(Header::dataSize(), 0); - bufferRequest->request = _endpoint->tagRecv( - &bufferRequest->stringBuffer->front(), - bufferRequest->stringBuffer->size(), - _tag, - false, - [this](ucs_status_t status, RequestCallbackUserData arg) { return this->recvCallback(status); }, - nullptr); + bufferRequest->request = + _endpoint->tagRecv(&bufferRequest->stringBuffer->front(), + bufferRequest->stringBuffer->size(), + _tag, + false, + [this](ucs_status_t status, RequestCallbackUserData arg) { + return this->recvCallback(status); + }); if (bufferRequest->request->isCompleted()) { // TODO: Errors may not be raisable within callback @@ -215,8 +213,6 @@ void RequestTagMulti::recvCallback(ucs_status_t status) if (_bufferRequests.empty()) { recvHeader(); } else { - const auto request = _bufferRequests.back(); - if (status == UCS_OK) { ucxx_trace_req( "RequestTagMulti::recvCallback header received, multi request: %p, tag: %lx", this, _tag); @@ -258,26 +254,19 @@ void RequestTagMulti::send(const std::vector& buffer, for (const auto& header : headers) { auto serializedHeader = std::make_shared(header.serialize()); - auto r = _endpoint->tagSend(&serializedHeader->front(), serializedHeader->size(), _tag, false); - - auto bufferRequest = std::make_shared(); - bufferRequest->request = r; + auto bufferRequest = std::make_shared(); + bufferRequest->request = + _endpoint->tagSend(&serializedHeader->front(), serializedHeader->size(), _tag, false); bufferRequest->stringBuffer = serializedHeader; _bufferRequests.push_back(bufferRequest); } for (size_t i = 0; i < _totalFrames; ++i) { - auto bufferRequest = std::make_shared(); - auto r = _endpoint->tagSend( - buffer[i], - size[i], - _tag, - false, - [this](ucs_status_t status, RequestCallbackUserData arg) { + auto bufferRequest = std::make_shared(); + bufferRequest->request = _endpoint->tagSend( + buffer[i], size[i], _tag, false, [this](ucs_status_t status, RequestCallbackUserData arg) { return this->markCompleted(status, arg); - }, - bufferRequest); - bufferRequest->request = r; + }); _bufferRequests.push_back(bufferRequest); } diff --git a/cpp/tests/buffer.cpp b/cpp/tests/buffer.cpp index fb74c321..942a846f 100644 --- a/cpp/tests/buffer.cpp +++ b/cpp/tests/buffer.cpp @@ -17,7 +17,7 @@ class BufferAllocator : public ::testing::Test, protected: ucxx::BufferType _type; size_t _size; - ucxx::Buffer* _buffer; + std::shared_ptr _buffer; void SetUp() { @@ -27,8 +27,6 @@ class BufferAllocator : public ::testing::Test, _buffer = allocateBuffer(_type, _size); } - - void TearDown() { delete _buffer; } }; TEST_P(BufferAllocator, TestType) @@ -36,7 +34,7 @@ TEST_P(BufferAllocator, TestType) ASSERT_EQ(_buffer->getType(), _type); if (_type == ucxx::BufferType::Host) { - auto buffer = dynamic_cast(_buffer); + auto buffer = std::dynamic_pointer_cast(_buffer); ASSERT_EQ(buffer->getType(), _type); auto releasedBuffer = buffer->release(); @@ -46,7 +44,7 @@ TEST_P(BufferAllocator, TestType) free(releasedBuffer); } else if (_type == ucxx::BufferType::RMM) { #if UCXX_ENABLE_RMM - auto buffer = dynamic_cast(_buffer); + auto buffer = std::dynamic_pointer_cast(_buffer); ASSERT_EQ(buffer->getType(), _type); auto releasedBuffer = buffer->release(); @@ -65,7 +63,7 @@ TEST_P(BufferAllocator, TestSize) ASSERT_EQ(_buffer->getSize(), _size); if (_type == ucxx::BufferType::Host) { - auto buffer = dynamic_cast(_buffer); + auto buffer = std::dynamic_pointer_cast(_buffer); ASSERT_EQ(buffer->getSize(), _size); auto releasedBuffer = buffer->release(); @@ -75,7 +73,7 @@ TEST_P(BufferAllocator, TestSize) free(releasedBuffer); } else if (_type == ucxx::BufferType::RMM) { #if UCXX_ENABLE_RMM - auto buffer = dynamic_cast(_buffer); + auto buffer = std::dynamic_pointer_cast(_buffer); ASSERT_EQ(buffer->getSize(), _size); auto releasedBuffer = buffer->release(); @@ -94,7 +92,7 @@ TEST_P(BufferAllocator, TestData) ASSERT_NE(_buffer->data(), nullptr); if (_type == ucxx::BufferType::Host) { - auto buffer = dynamic_cast(_buffer); + auto buffer = std::dynamic_pointer_cast(_buffer); ASSERT_EQ(buffer->data(), _buffer->data()); auto releasedBuffer = buffer->release(); @@ -104,7 +102,7 @@ TEST_P(BufferAllocator, TestData) free(releasedBuffer); } else if (_type == ucxx::BufferType::RMM) { #if UCXX_ENABLE_RMM - auto buffer = dynamic_cast(_buffer); + auto buffer = std::dynamic_pointer_cast(_buffer); ASSERT_EQ(buffer->data(), _buffer->data()); auto releasedBuffer = buffer->release(); @@ -123,7 +121,7 @@ TEST_P(BufferAllocator, TestData) TEST_P(BufferAllocator, TestThrowAfterRelease) { if (_type == ucxx::BufferType::Host) { - auto buffer = dynamic_cast(_buffer); + auto buffer = std::dynamic_pointer_cast(_buffer); auto releasedBuffer = buffer->release(); EXPECT_THROW(buffer->data(), std::runtime_error); @@ -132,7 +130,7 @@ TEST_P(BufferAllocator, TestThrowAfterRelease) free(releasedBuffer); } else if (_type == ucxx::BufferType::RMM) { #if UCXX_ENABLE_RMM - auto buffer = dynamic_cast(_buffer); + auto buffer = std::dynamic_pointer_cast(_buffer); auto releasedBuffer = buffer->release(); EXPECT_THROW(buffer->data(), std::runtime_error); diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index a85d996f..abc99d3a 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -208,7 +208,7 @@ TEST_P(WorkerProgressTest, ProgressTagMulti) reinterpret_cast(br->buffer->data()) + send.size()); ASSERT_EQ(recvAbstract[0], send[0]); - const auto& recvConcretePtr = dynamic_cast(br->buffer); + const auto& recvConcretePtr = std::dynamic_pointer_cast(br->buffer); ASSERT_EQ(recvConcretePtr->getType(), ucxx::BufferType::Host); ASSERT_EQ(recvConcretePtr->getSize(), send.size() * sizeof(int)); diff --git a/dependencies.yaml b/dependencies.yaml index 2e69f8e8..320a7d46 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -121,3 +121,4 @@ dependencies: - numba>=0.57.0 - pytest - pytest-asyncio + - pytest-rerunfailures diff --git a/python/ucxx/_lib/libucxx.pyx b/python/ucxx/_lib/libucxx.pyx index d8a17af5..027bd1d6 100644 --- a/python/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/_lib/libucxx.pyx @@ -679,14 +679,14 @@ cdef class UCXRequest(): with nogil: buf = self._request.get().getRecvBuffer() - bufType = buf.get().getType() + bufType = buf.get().getType() if buf != nullptr else BufferType.Invalid # If buf == NULL, it's not allocated by the request but rather the user if buf == NULL: return None elif bufType == BufferType.RMM: return _get_rmm_buffer(buf.get()) - else: + elif bufType == BufferType.Host: return _get_host_buffer(buf.get()) @@ -706,18 +706,20 @@ cdef class UCXBufferRequest: ) def get_py_buffer(self): - cdef Buffer* buf + cdef shared_ptr[Buffer] buf + cdef BufferType bufType with nogil: buf = self._buffer_request.get().buffer + bufType = buf.get().getType() if buf != nullptr else BufferType.Invalid # If buf == NULL, it holds a header if buf == NULL: return None - elif buf.getType() == BufferType.RMM: - return _get_rmm_buffer(buf) - else: - return _get_host_buffer(buf) + elif bufType == BufferType.RMM: + return _get_rmm_buffer(buf.get()) + elif bufType == BufferType.Host: + return _get_host_buffer(buf.get()) cdef class UCXBufferRequests: diff --git a/python/ucxx/_lib/ucxx_api.pxd b/python/ucxx/_lib/ucxx_api.pxd index 5e1a0b8f..2f93288e 100644 --- a/python/ucxx/_lib/ucxx_api.pxd +++ b/python/ucxx/_lib/ucxx_api.pxd @@ -324,7 +324,7 @@ cdef extern from "" namespace "ucxx" nogil: ctypedef struct BufferRequest: shared_ptr[Request] request shared_ptr[string] stringBuffer - Buffer* buffer + shared_ptr[Buffer] buffer ctypedef shared_ptr[BufferRequest] BufferRequestPtr diff --git a/python/ucxx/_lib_async/tests/test_send_recv_multi.py b/python/ucxx/_lib_async/tests/test_send_recv_multi.py index a1dc7521..3c528d98 100644 --- a/python/ucxx/_lib_async/tests/test_send_recv_multi.py +++ b/python/ucxx/_lib_async/tests/test_send_recv_multi.py @@ -72,7 +72,7 @@ async def test_send_recv_numpy(size, multi_size, dtype): @pytest.mark.parametrize("size", msg_sizes) @pytest.mark.parametrize("multi_size", multi_sizes) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.rerun_on_failure(3) +@pytest.mark.flaky(reruns=3) async def test_send_recv_cupy(size, multi_size, dtype): cupy = pytest.importorskip("cupy") @@ -91,7 +91,7 @@ async def test_send_recv_cupy(size, multi_size, dtype): @pytest.mark.parametrize("size", msg_sizes) @pytest.mark.parametrize("multi_size", multi_sizes) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.rerun_on_failure(3) +@pytest.mark.flaky(reruns=3) async def test_send_recv_numba(size, multi_size, dtype): cuda = pytest.importorskip("numba.cuda")