Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use reference counting on factories #2048

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions scripts/templates/ldrddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,17 @@ namespace ur_loader

%endif
%endif
## Before we can re-enable the releases we will need ref-counted object_t.
## See unified-runtime github issue #1784
##%if item['release']:
##// release loader handle
##${item['factory']}.release( ${item['name']} );
## Possibly handle release/retain ref counting - there are no ur_exp-image factories
%if 'factory' in item and '_exp_image_' not in item['factory']:
%if item['release']:
// release loader handle
context->factories.${item['factory']}.release( ${item['name']} );
%endif
%if item['retain']:
// increment refcount of handle
context->factories.${item['factory']}.retain( ${item['name']} );
%endif
%endif
%if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle':
try
{
Expand Down
20 changes: 14 additions & 6 deletions scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,19 @@ namespace ur_validation_layer
%endif
%endfor

%for tp in tracked_params:
<%
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
%>
%if func_name in tp_handle_funcs['release']:
if( getContext()->enableLeakChecking )
{
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor

${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );

%for tp in tracked_params:
Expand All @@ -114,15 +127,10 @@ namespace ur_validation_layer
}
}
%elif func_name in tp_handle_funcs['retain']:
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
if( getContext()->enableLeakChecking )
{
getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%elif func_name in tp_handle_funcs['release']:
if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS )
{
getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor

Expand Down
29 changes: 25 additions & 4 deletions source/common/ur_singleton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,26 @@
#ifndef UR_SINGLETON_H
#define UR_SINGLETON_H 1

#include <cassert>
#include <memory>
#include <mutex>
#include <unordered_map>

//////////////////////////////////////////////////////////////////////////
/// a abstract factory for creation of singleton objects
template <typename singleton_tn, typename key_tn> class singleton_factory_t {
struct entry_t {
std::unique_ptr<singleton_tn> ptr;
size_t ref_count;
};

protected:
using singleton_t = singleton_tn;
using key_t = typename std::conditional<std::is_pointer<key_tn>::value,
size_t, key_tn>::type;

using ptr_t = std::unique_ptr<singleton_t>;
using map_t = std::unordered_map<key_t, ptr_t>;
using map_t = std::unordered_map<key_t, entry_t>;

std::mutex mut; ///< lock for thread-safety
map_t map; ///< single instance of singleton for each unique key
Expand Down Expand Up @@ -60,16 +66,31 @@ template <typename singleton_tn, typename key_tn> class singleton_factory_t {
if (map.end() == iter) {
auto ptr =
std::make_unique<singleton_t>(std::forward<Ts>(params)...);
iter = map.emplace(key, std::move(ptr)).first;
iter = map.emplace(key, entry_t{std::move(ptr), 0}).first;
} else {
iter->second.ref_count++;
}
return iter->second.get();
return iter->second.ptr.get();
}

void retain(key_tn key) {
std::lock_guard<std::mutex> lk(mut);
auto iter = map.find(getKey(key));
assert(iter != map.end());
iter->second.ref_count++;
}

//////////////////////////////////////////////////////////////////////////
/// once the key is no longer valid, release the singleton
void release(key_tn key) {
std::lock_guard<std::mutex> lk(mut);
map.erase(getKey(key));
auto iter = map.find(getKey(key));
assert(iter != map.end());
if (iter->second.ref_count == 0) {
map.erase(iter);
} else {
iter->second.ref_count--;
}
}

void clear() {
Expand Down
88 changes: 44 additions & 44 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease(
}
}

ur_result_t result = pfnAdapterRelease(hAdapter);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hAdapter, true);
}

ur_result_t result = pfnAdapterRelease(hAdapter);

return result;
}

Expand All @@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain(

ur_result_t result = pfnAdapterRetain(hAdapter);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hAdapter, true);
}

Expand Down Expand Up @@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain(

ur_result_t result = pfnRetain(hDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hDevice, false);
}

Expand All @@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease(
}
}

ur_result_t result = pfnRelease(hDevice);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hDevice, false);
}

ur_result_t result = pfnRelease(hDevice);

return result;
}

Expand Down Expand Up @@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain(

ur_result_t result = pfnRetain(hContext);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hContext, false);
}

Expand All @@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease(
}
}

ur_result_t result = pfnRelease(hContext);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hContext, false);
}

ur_result_t result = pfnRelease(hContext);

return result;
}

Expand Down Expand Up @@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain(

ur_result_t result = pfnRetain(hMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hMem, false);
}

Expand All @@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease(
}
}

ur_result_t result = pfnRelease(hMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hMem, false);
}

ur_result_t result = pfnRelease(hMem);

return result;
}

Expand Down Expand Up @@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain(

ur_result_t result = pfnRetain(hSampler);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hSampler, false);
}

Expand All @@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease(
}
}

ur_result_t result = pfnRelease(hSampler);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hSampler, false);
}

ur_result_t result = pfnRelease(hSampler);

return result;
}

Expand Down Expand Up @@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain(

ur_result_t result = pfnPoolRetain(pPool);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(pPool, false);
}

Expand All @@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease(
}
}

ur_result_t result = pfnPoolRelease(pPool);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(pPool, false);
}

ur_result_t result = pfnPoolRelease(pPool);

return result;
}

Expand Down Expand Up @@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain(

ur_result_t result = pfnRetain(hPhysicalMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hPhysicalMem, false);
}

Expand All @@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease(
}
}

ur_result_t result = pfnRelease(hPhysicalMem);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hPhysicalMem, false);
}

ur_result_t result = pfnRelease(hPhysicalMem);

return result;
}

Expand Down Expand Up @@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain(

ur_result_t result = pfnRetain(hProgram);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hProgram, false);
}

Expand All @@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease(
}
}

ur_result_t result = pfnRelease(hProgram);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hProgram, false);
}

ur_result_t result = pfnRelease(hProgram);

return result;
}

Expand Down Expand Up @@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain(

ur_result_t result = pfnRetain(hKernel);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hKernel, false);
}

Expand All @@ -3643,12 +3643,12 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease(
}
}

ur_result_t result = pfnRelease(hKernel);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hKernel, false);
}

ur_result_t result = pfnRelease(hKernel);

return result;
}

Expand Down Expand Up @@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain(

ur_result_t result = pfnRetain(hQueue);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hQueue, false);
}

Expand All @@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease(
}
}

ur_result_t result = pfnRelease(hQueue);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hQueue, false);
}

ur_result_t result = pfnRelease(hQueue);

return result;
}

Expand Down Expand Up @@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain(

ur_result_t result = pfnRetain(hEvent);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->incrementRefCount(hEvent, false);
}

Expand All @@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease(
}
}

ur_result_t result = pfnRelease(hEvent);

if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) {
if (getContext()->enableLeakChecking) {
getContext()->refCountContext->decrementRefCount(hEvent, false);
}

ur_result_t result = pfnRelease(hEvent);

return result;
}

Expand Down
Loading
Loading