Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/branch-0.33' into ucs-spinlock…
Browse files Browse the repository at this point in the history
…-functions-progress-thread
  • Loading branch information
pentschev committed Aug 2, 2023
2 parents df4c397 + 65deb67 commit 762545c
Show file tree
Hide file tree
Showing 17 changed files with 170 additions and 145 deletions.
5 changes: 4 additions & 1 deletion conda/recipes/ucxx/conda_build_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,8 @@ python:
ucx:
- 1.14.0

gtest_version:
gmock:
- ">=1.13.0"

gtest:
- ">=1.13.0"
2 changes: 1 addition & 1 deletion conda/recipes/ucxx/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ requirements:
- ucx
- python
- librmm =23.08
- gtest {{ gtest_version }}
- gtest

outputs:
- name: libucxx
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/ucxx/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,6 @@ class RMMBuffer : public Buffer {
};
#endif

Buffer* allocateBuffer(BufferType bufferType, const size_t size);
std::shared_ptr<Buffer> allocateBuffer(BufferType bufferType, const size_t size);

} // namespace ucxx
13 changes: 7 additions & 6 deletions cpp/include/ucxx/request.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ namespace ucxx {

class Request : public Component {
protected:
std::atomic<ucs_status_t> _status{UCS_INPROGRESS}; ///< Requests status
std::string _status_msg{}; ///< Human-readable status message
void* _request{nullptr}; ///< Pointer to UCP request
std::shared_ptr<Future> _future{nullptr}; ///< Future to notify upon completion
RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback
RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data
ucs_status_t _status{UCS_INPROGRESS}; ///< Requests status
std::string _status_msg{}; ///< Human-readable status message
void* _request{nullptr}; ///< Pointer to UCP request
std::shared_ptr<Future> _future{nullptr}; ///< Future to notify upon completion
RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback
RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data
std::shared_ptr<Worker> _worker{
nullptr}; ///< Worker that generated request (if not from endpoint)
std::shared_ptr<Endpoint> _endpoint{
Expand All @@ -39,6 +39,7 @@ class Request : public Component {
nullptr}; ///< The submission object that will dispatch the request
std::string _operationName{
"request_undefined"}; ///< Human-readable operation name, mostly used for log messages
std::recursive_mutex _mutex{}; ///< Mutex to prevent checking status while it's being set
bool _enablePythonFuture{true}; ///< Whether Python future is enabled for this request

/**
Expand Down
19 changes: 15 additions & 4 deletions cpp/include/ucxx/request_tag_multi.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,15 @@ class RequestTagMulti;
struct BufferRequest {
std::shared_ptr<Request> request{nullptr}; ///< The `ucxx::RequestTag` of a header or frame
std::shared_ptr<std::string> stringBuffer{nullptr}; ///< Serialized `Header`
Buffer* buffer{nullptr}; ///< Internally allocated buffer to receive a frame
std::shared_ptr<Buffer> buffer{nullptr}; ///< Internally allocated buffer to receive a frame

BufferRequest();
~BufferRequest();

BufferRequest(const BufferRequest&) = delete;
BufferRequest& operator=(BufferRequest const&) = delete;
BufferRequest(BufferRequest&& o) = delete;
BufferRequest& operator=(BufferRequest&& o) = delete;
};

typedef std::shared_ptr<BufferRequest> BufferRequestPtr;
Expand All @@ -34,8 +42,8 @@ class RequestTagMulti : public Request {
ucp_tag_t _tag{0}; ///< Tag to match
size_t _totalFrames{0}; ///< The total number of frames handled by this request
std::mutex
_completedRequestsMutex{}; ///< Mutex to control access to completed requests container
std::vector<BufferRequest*> _completedRequests{}; ///< Requests that already completed
_completedRequestsMutex{}; ///< Mutex to control access to completed requests container
size_t _completedRequests{0}; ///< Count requests that already completed

public:
std::vector<BufferRequestPtr> _bufferRequests{}; ///< Container of all requests posted
Expand Down Expand Up @@ -191,7 +199,10 @@ class RequestTagMulti : public Request {
*
* When this method is called, the request that completed will be pushed into a container
* which will be later used to evaluate if all frames completed and set the final status
* of the multi-transfer request and the Python future, if enabled.
* of the multi-transfer request and the Python future, if enabled. The final status is
* either `UCS_OK` if all underlying requests completed successfully, otherwise it will
* contain the status of the first failing request, for granular information the user
* may still verify each of the underlying requests individually.
*
* @param[in] status the status of the request being completed.
* @param[in] request the `ucxx::BufferRequest` object containing a single tag .
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,17 @@ void* RMMBuffer::data()
}
#endif

Buffer* allocateBuffer(const BufferType bufferType, const size_t size)
std::shared_ptr<Buffer> allocateBuffer(const BufferType bufferType, const size_t size)
{
#if UCXX_ENABLE_RMM
if (bufferType == BufferType::RMM)
return new RMMBuffer(size);
return std::make_shared<RMMBuffer>(size);
else
#else
if (bufferType == BufferType::RMM)
throw std::runtime_error("RMM support not enabled, please compile with -DUCXX_ENABLE_RMM=1");
#endif
return new HostBuffer(size);
return std::make_shared<HostBuffer>(size);
}

} // namespace ucxx
91 changes: 50 additions & 41 deletions cpp/src/request.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Request::~Request() { ucxx_trace("Request destroyed: %p, %s", this, _operationNa

void Request::cancel()
{
std::lock_guard<std::recursive_mutex> lock(_mutex);
if (_status == UCS_INPROGRESS) {
if (UCS_PTR_IS_ERR(_request)) {
ucs_status_t status = UCS_PTR_STATUS(_request);
Expand All @@ -72,29 +73,39 @@ void Request::cancel()
if (_request != nullptr) ucp_request_cancel(_worker->getHandle(), _request);
}
} else {
auto status = _status.load();
ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"already completed with status: %d (%s)",
status,
ucs_status_string(status));
_status,
ucs_status_string(_status));
}
}

ucs_status_t Request::getStatus() { return _status; }
ucs_status_t Request::getStatus()
{
std::lock_guard<std::recursive_mutex> lock(_mutex);
return _status;
}

void* Request::getFuture() { return _future ? _future->getHandle() : nullptr; }
void* Request::getFuture()
{
std::lock_guard<std::recursive_mutex> lock(_mutex);
return _future ? _future->getHandle() : nullptr;
}

void Request::checkError()
{
// Only load the atomic variable once
auto status = _status.load();
std::lock_guard<std::recursive_mutex> lock(_mutex);

utils::ucsErrorThrow(status, status == UCS_ERR_MESSAGE_TRUNCATED ? _status_msg : std::string());
utils::ucsErrorThrow(_status, _status == UCS_ERR_MESSAGE_TRUNCATED ? _status_msg : std::string());
}

bool Request::isCompleted() { return _status != UCS_INPROGRESS; }
bool Request::isCompleted()
{
std::lock_guard<std::recursive_mutex> lock(_mutex);
return _status != UCS_INPROGRESS;
}

void Request::callback(void* request, ucs_status_t status)
{
Expand All @@ -109,12 +120,11 @@ void Request::callback(void* request, ucs_status_t status)
ucxx_debug("Request %p destroyed before callback() was executed", this);
return;
}
auto statusAttr = _status.load();
if (statusAttr != UCS_INPROGRESS)
if (_status != UCS_INPROGRESS)
ucxx_trace("Request %p has status already set to %d (%s), callback setting %d (%s)",
this,
statusAttr,
ucs_status_string(statusAttr),
_status,
ucs_status_string(_status),
status,
ucs_status_string(status));

Expand All @@ -123,18 +133,13 @@ void Request::callback(void* request, ucs_status_t status)
ucxx_trace("Request completed: %p, handle: %p", this, request);
setStatus(status);
ucxx_trace("Request %p, isCompleted: %d", this, isCompleted());

ucxx_trace_req_f(_ownerString.c_str(),
request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(status, _callbackData);
}

void Request::process()
{
ucs_status_t status = _status.load();
std::lock_guard<std::recursive_mutex> lock(_mutex);

ucs_status_t status = UCS_INPROGRESS;

if (UCS_PTR_IS_ERR(_request)) {
// Operation errored immediately
Expand Down Expand Up @@ -167,34 +172,38 @@ void Request::process()
_ownerString.c_str(), _request, _operationName.c_str(), "completed immediately");
}

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(status, _callbackData);

setStatus(status);
}

void Request::setStatus(ucs_status_t status)
{
if (_endpoint != nullptr) _endpoint->removeInflightRequest(this);
_worker->removeInflightRequest(this);
{
std::lock_guard<std::recursive_mutex> lock(_mutex);

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback called with status %d (%s)",
status,
ucs_status_string(status));
if (_endpoint != nullptr) _endpoint->removeInflightRequest(this);
_worker->removeInflightRequest(this);

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback called with status %d (%s)",
status,
ucs_status_string(status));

if (_status != UCS_INPROGRESS) ucxx_error("setStatus called but the status was already set");
_status.store(status);
if (_status != UCS_INPROGRESS) ucxx_error("setStatus called but the status was already set");
_status = status;

if (_enablePythonFuture) {
auto future = std::static_pointer_cast<ucxx::Future>(_future);
future->notify(status);
if (_enablePythonFuture) {
auto future = std::static_pointer_cast<ucxx::Future>(_future);
future->notify(status);
}

ucxx_trace_req_f(_ownerString.c_str(),
_request,
_operationName.c_str(),
"callback %p",
_callback.target<void (*)(void)>());
if (_callback) _callback(status, _callbackData);
}
}

Expand Down
17 changes: 10 additions & 7 deletions cpp/src/request_am.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,16 @@ void RequestAm::request()

if (_delayedSubmission->_send) {
param.cb.send = _amSendCallback;
_request = ucp_am_send_nbx(_endpoint->getHandle(),
0,
&_sendHeader,
sizeof(_sendHeader),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
&param);
void* request = ucp_am_send_nbx(_endpoint->getHandle(),
0,
&_sendHeader,
sizeof(_sendHeader),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
&param);

std::lock_guard<std::recursive_mutex> lock(_mutex);
_request = request;
} else {
throw ucxx::UnsupportedError(
"Receiving active messages must be handled by the worker's callback");
Expand Down
16 changes: 10 additions & 6 deletions cpp/src/request_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,21 +50,25 @@ void RequestStream::request()
UCP_OP_ATTR_FIELD_USER_DATA,
.datatype = ucp_dt_make_contig(1),
.user_data = this};
void* request = nullptr;

if (_delayedSubmission->_send) {
param.cb.send = streamSendCallback;
_request = ucp_stream_send_nbx(
request = ucp_stream_send_nbx(
_endpoint->getHandle(), _delayedSubmission->_buffer, _delayedSubmission->_length, &param);
} else {
param.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS;
param.flags = UCP_STREAM_RECV_FLAG_WAITALL;
param.cb.recv_stream = streamRecvCallback;
_request = ucp_stream_recv_nbx(_endpoint->getHandle(),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
&_delayedSubmission->_length,
&param);
request = ucp_stream_recv_nbx(_endpoint->getHandle(),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
&_delayedSubmission->_length,
&param);
}

std::lock_guard<std::recursive_mutex> lock(_mutex);
_request = request;
}

void RequestStream::populateDelayedSubmission()
Expand Down
26 changes: 15 additions & 11 deletions cpp/src/request_tag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,27 @@ void RequestTag::request()
UCP_OP_ATTR_FIELD_USER_DATA,
.datatype = ucp_dt_make_contig(1),
.user_data = this};
void* request = nullptr;

if (_delayedSubmission->_send) {
param.cb.send = tagSendCallback;
_request = ucp_tag_send_nbx(_endpoint->getHandle(),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
_delayedSubmission->_tag,
&param);
request = ucp_tag_send_nbx(_endpoint->getHandle(),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
_delayedSubmission->_tag,
&param);
} else {
param.cb.recv = tagRecvCallback;
_request = ucp_tag_recv_nbx(_worker->getHandle(),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
_delayedSubmission->_tag,
tagMask,
&param);
request = ucp_tag_recv_nbx(_worker->getHandle(),
_delayedSubmission->_buffer,
_delayedSubmission->_length,
_delayedSubmission->_tag,
tagMask,
&param);
}

std::lock_guard<std::recursive_mutex> lock(_mutex);
_request = request;
}

void RequestTag::populateDelayedSubmission()
Expand Down
Loading

0 comments on commit 762545c

Please sign in to comment.