Skip to content

Commit

Permalink
Return std::shared_ptr<Buffer> in allocateBuffer()
Browse files Browse the repository at this point in the history
Returning raw pointers may cause leaks, and can be problematic
specifically in `ucxx::RequestTagMulti` as the `Buffer*` will not be
released unless the user takes ownership and releases it appropriately.
  • Loading branch information
pentschev committed Jul 31, 2023
1 parent ef711a9 commit 8809dde
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 22 deletions.
2 changes: 1 addition & 1 deletion cpp/include/ucxx/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,6 @@ class RMMBuffer : public Buffer {
};
#endif

Buffer* allocateBuffer(BufferType bufferType, const size_t size);
std::shared_ptr<Buffer> allocateBuffer(BufferType bufferType, const size_t size);

} // namespace ucxx
2 changes: 1 addition & 1 deletion cpp/include/ucxx/request_tag_multi.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class RequestTagMulti;
struct BufferRequest {
std::shared_ptr<Request> request{nullptr}; ///< The `ucxx::RequestTag` of a header or frame
std::shared_ptr<std::string> stringBuffer{nullptr}; ///< Serialized `Header`
Buffer* buffer{nullptr}; ///< Internally allocated buffer to receive a frame
std::shared_ptr<Buffer> buffer{nullptr}; ///< Internally allocated buffer to receive a frame

BufferRequest();
~BufferRequest();
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,17 +82,17 @@ void* RMMBuffer::data()
}
#endif

Buffer* allocateBuffer(const BufferType bufferType, const size_t size)
std::shared_ptr<Buffer> allocateBuffer(const BufferType bufferType, const size_t size)
{
#if UCXX_ENABLE_RMM
if (bufferType == BufferType::RMM)
return new RMMBuffer(size);
return std::make_shared<RMMBuffer>(size);
else
#else
if (bufferType == BufferType::RMM)
throw std::runtime_error("RMM support not enabled, please compile with -DUCXX_ENABLE_RMM=1");
#endif
return new HostBuffer(size);
return std::make_shared<HostBuffer>(size);
}

} // namespace ucxx
20 changes: 9 additions & 11 deletions cpp/tests/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class BufferAllocator : public ::testing::Test,
protected:
ucxx::BufferType _type;
size_t _size;
ucxx::Buffer* _buffer;
std::shared_ptr<ucxx::Buffer> _buffer;

void SetUp()
{
Expand All @@ -27,16 +27,14 @@ class BufferAllocator : public ::testing::Test,

_buffer = allocateBuffer(_type, _size);
}

void TearDown() { delete _buffer; }
};

TEST_P(BufferAllocator, TestType)
{
ASSERT_EQ(_buffer->getType(), _type);

if (_type == ucxx::BufferType::Host) {
auto buffer = dynamic_cast<ucxx::HostBuffer*>(_buffer);
auto buffer = std::dynamic_pointer_cast<ucxx::HostBuffer>(_buffer);
ASSERT_EQ(buffer->getType(), _type);

auto releasedBuffer = buffer->release();
Expand All @@ -46,7 +44,7 @@ TEST_P(BufferAllocator, TestType)
free(releasedBuffer);
} else if (_type == ucxx::BufferType::RMM) {
#if UCXX_ENABLE_RMM
auto buffer = dynamic_cast<ucxx::RMMBuffer*>(_buffer);
auto buffer = std::dynamic_pointer_cast<ucxx::RMMBuffer>(_buffer);
ASSERT_EQ(buffer->getType(), _type);

auto releasedBuffer = buffer->release();
Expand All @@ -65,7 +63,7 @@ TEST_P(BufferAllocator, TestSize)
ASSERT_EQ(_buffer->getSize(), _size);

if (_type == ucxx::BufferType::Host) {
auto buffer = dynamic_cast<ucxx::HostBuffer*>(_buffer);
auto buffer = std::dynamic_pointer_cast<ucxx::HostBuffer>(_buffer);
ASSERT_EQ(buffer->getSize(), _size);

auto releasedBuffer = buffer->release();
Expand All @@ -75,7 +73,7 @@ TEST_P(BufferAllocator, TestSize)
free(releasedBuffer);
} else if (_type == ucxx::BufferType::RMM) {
#if UCXX_ENABLE_RMM
auto buffer = dynamic_cast<ucxx::RMMBuffer*>(_buffer);
auto buffer = std::dynamic_pointer_cast<ucxx::RMMBuffer>(_buffer);
ASSERT_EQ(buffer->getSize(), _size);

auto releasedBuffer = buffer->release();
Expand All @@ -94,7 +92,7 @@ TEST_P(BufferAllocator, TestData)
ASSERT_NE(_buffer->data(), nullptr);

if (_type == ucxx::BufferType::Host) {
auto buffer = dynamic_cast<ucxx::HostBuffer*>(_buffer);
auto buffer = std::dynamic_pointer_cast<ucxx::HostBuffer>(_buffer);
ASSERT_EQ(buffer->data(), _buffer->data());

auto releasedBuffer = buffer->release();
Expand All @@ -104,7 +102,7 @@ TEST_P(BufferAllocator, TestData)
free(releasedBuffer);
} else if (_type == ucxx::BufferType::RMM) {
#if UCXX_ENABLE_RMM
auto buffer = dynamic_cast<ucxx::RMMBuffer*>(_buffer);
auto buffer = std::dynamic_pointer_cast<ucxx::RMMBuffer>(_buffer);
ASSERT_EQ(buffer->data(), _buffer->data());

auto releasedBuffer = buffer->release();
Expand All @@ -123,7 +121,7 @@ TEST_P(BufferAllocator, TestData)
TEST_P(BufferAllocator, TestThrowAfterRelease)
{
if (_type == ucxx::BufferType::Host) {
auto buffer = dynamic_cast<ucxx::HostBuffer*>(_buffer);
auto buffer = std::dynamic_pointer_cast<ucxx::HostBuffer>(_buffer);
auto releasedBuffer = buffer->release();

EXPECT_THROW(buffer->data(), std::runtime_error);
Expand All @@ -132,7 +130,7 @@ TEST_P(BufferAllocator, TestThrowAfterRelease)
free(releasedBuffer);
} else if (_type == ucxx::BufferType::RMM) {
#if UCXX_ENABLE_RMM
auto buffer = dynamic_cast<ucxx::RMMBuffer*>(_buffer);
auto buffer = std::dynamic_pointer_cast<ucxx::RMMBuffer>(_buffer);
auto releasedBuffer = buffer->release();

EXPECT_THROW(buffer->data(), std::runtime_error);
Expand Down
2 changes: 1 addition & 1 deletion cpp/tests/worker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ TEST_P(WorkerProgressTest, ProgressTagMulti)
reinterpret_cast<int*>(br->buffer->data()) + send.size());
ASSERT_EQ(recvAbstract[0], send[0]);

const auto& recvConcretePtr = dynamic_cast<ucxx::HostBuffer*>(br->buffer);
const auto& recvConcretePtr = std::dynamic_pointer_cast<ucxx::HostBuffer>(br->buffer);
ASSERT_EQ(recvConcretePtr->getType(), ucxx::BufferType::Host);
ASSERT_EQ(recvConcretePtr->getSize(), send.size() * sizeof(int));

Expand Down
10 changes: 6 additions & 4 deletions python/ucxx/_lib/libucxx.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -706,18 +706,20 @@ cdef class UCXBufferRequest:
)

def get_py_buffer(self):
cdef Buffer* buf
cdef shared_ptr[Buffer] buf
cdef BufferType bufType

with nogil:
buf = self._buffer_request.get().buffer
bufType = buf.get().getType()

# If buf == NULL, it holds a header
if buf == NULL:
return None
elif buf.getType() == BufferType.RMM:
return _get_rmm_buffer(<uintptr_t><void*>buf)
elif bufType == BufferType.RMM:
return _get_rmm_buffer(<uintptr_t><void*>buf.get())
else:
return _get_host_buffer(<uintptr_t><void*>buf)
return _get_host_buffer(<uintptr_t><void*>buf.get())


cdef class UCXBufferRequests:
Expand Down
2 changes: 1 addition & 1 deletion python/ucxx/_lib/ucxx_api.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ cdef extern from "<ucxx/request_tag_multi.h>" namespace "ucxx" nogil:
ctypedef struct BufferRequest:
shared_ptr[Request] request
shared_ptr[string] stringBuffer
Buffer* buffer
shared_ptr[Buffer] buffer

ctypedef shared_ptr[BufferRequest] BufferRequestPtr

Expand Down

0 comments on commit 8809dde

Please sign in to comment.