From ccab45fe170ac11d66151c0ba01a8a93094d08eb Mon Sep 17 00:00:00 2001 From: Ross Brunton Date: Wed, 4 Sep 2024 17:25:11 +0100 Subject: [PATCH] Use reference counting on factories Previously the factories used by ur_ldrddi (used when there are multiple backends) would add newly created objects to a map, but never release them. This patch adds reference counting semantics to the allocation, retention and release methods. A lot of changes were also made to fix use-after-free issues, specifically: * The `release` functions now no longer use the handle after freeing it. * `urDeviceTest` no longer frees devices that it doesn't own. * Some tests for reference counting now explicitly retain an extra copy before releasing them. No tests were added; this should be covered by tests for urThingRetain. Closes: #1784 . --- scripts/templates/ldrddi.cpp.mako | 16 ++-- scripts/templates/valddi.cpp.mako | 20 +++-- source/common/ur_singleton.hpp | 29 +++++- source/loader/layers/validation/ur_valddi.cpp | 88 +++++++++---------- source/loader/ur_ldrddi.cpp | 85 ++++++++++++++++++ test/conformance/adapter/urAdapterRelease.cpp | 1 + test/conformance/device/urDeviceRelease.cpp | 2 + .../testing/include/uur/fixtures.h | 1 - 8 files changed, 182 insertions(+), 60 deletions(-) diff --git a/scripts/templates/ldrddi.cpp.mako b/scripts/templates/ldrddi.cpp.mako index 9c797a0ec3..1b7d19fa67 100644 --- a/scripts/templates/ldrddi.cpp.mako +++ b/scripts/templates/ldrddi.cpp.mako @@ -273,11 +273,17 @@ namespace ur_loader %endif %endif - ## Before we can re-enable the releases we will need ref-counted object_t. 
- ## See unified-runtime github issue #1784 - ##%if item['release']: - ##// release loader handle - ##${item['factory']}.release( ${item['name']} ); + ## Possibly handle release/retain ref counting - there are no ur_exp-image factories + %if 'factory' in item and '_exp_image_' not in item['factory']: + %if item['release']: + // release loader handle + context->factories.${item['factory']}.release( ${item['name']} ); + %endif + %if item['retain']: + // increment refcount of handle + context->factories.${item['factory']}.retain( ${item['name']} ); + %endif + %endif %if not item['release'] and not item['retain'] and not '_native_object_' in item['obj'] or th.make_func_name(n, tags, obj) == 'urPlatformCreateWithNativeHandle': try { diff --git a/scripts/templates/valddi.cpp.mako b/scripts/templates/valddi.cpp.mako index 8cc4a9dc0f..7a18860ba9 100644 --- a/scripts/templates/valddi.cpp.mako +++ b/scripts/templates/valddi.cpp.mako @@ -94,6 +94,19 @@ namespace ur_validation_layer %endif %endfor + %for tp in tracked_params: + <% + tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None) + is_handle_to_adapter = ("_adapter_handle_t" in tp['type']) + %> + %if func_name in tp_handle_funcs['release']: + if( getContext()->enableLeakChecking ) + { + getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()}); + } + %endif + %endfor + ${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} ); %for tp in tracked_params: @@ -114,15 +127,10 @@ namespace ur_validation_layer } } %elif func_name in tp_handle_funcs['retain']: - if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS ) + if( getContext()->enableLeakChecking ) { getContext()->refCountContext->incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()}); } - %elif func_name in 
tp_handle_funcs['release']: - if( getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS ) - { - getContext()->refCountContext->decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()}); - } %endif %endfor diff --git a/source/common/ur_singleton.hpp b/source/common/ur_singleton.hpp index b469c8b8a7..057d58c067 100644 --- a/source/common/ur_singleton.hpp +++ b/source/common/ur_singleton.hpp @@ -11,6 +11,7 @@ #ifndef UR_SINGLETON_H #define UR_SINGLETON_H 1 +#include #include #include #include @@ -18,13 +19,18 @@ ////////////////////////////////////////////////////////////////////////// /// a abstract factory for creation of singleton objects template class singleton_factory_t { + struct entry_t { + std::unique_ptr ptr; + size_t ref_count; + }; + protected: using singleton_t = singleton_tn; using key_t = typename std::conditional::value, size_t, key_tn>::type; using ptr_t = std::unique_ptr; - using map_t = std::unordered_map; + using map_t = std::unordered_map; std::mutex mut; ///< lock for thread-safety map_t map; ///< single instance of singleton for each unique key @@ -60,16 +66,31 @@ template class singleton_factory_t { if (map.end() == iter) { auto ptr = std::make_unique(std::forward(params)...); - iter = map.emplace(key, std::move(ptr)).first; + iter = map.emplace(key, entry_t{std::move(ptr), 0}).first; + } else { + iter->second.ref_count++; } - return iter->second.get(); + return iter->second.ptr.get(); + } + + void retain(key_tn key) { + std::lock_guard lk(mut); + auto iter = map.find(getKey(key)); + assert(iter != map.end()); + iter->second.ref_count++; } ////////////////////////////////////////////////////////////////////////// /// once the key is no longer valid, release the singleton void release(key_tn key) { std::lock_guard lk(mut); - map.erase(getKey(key)); + auto iter = map.find(getKey(key)); + assert(iter != map.end()); + if (iter->second.ref_count == 0) { + map.erase(iter); + } else { + iter->second.ref_count--; + } } void 
clear() { diff --git a/source/loader/layers/validation/ur_valddi.cpp b/source/loader/layers/validation/ur_valddi.cpp index fdfce7951b..a99d65c196 100644 --- a/source/loader/layers/validation/ur_valddi.cpp +++ b/source/loader/layers/validation/ur_valddi.cpp @@ -71,12 +71,12 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( } } - ur_result_t result = pfnAdapterRelease(hAdapter); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hAdapter, true); } + ur_result_t result = pfnAdapterRelease(hAdapter); + return result; } @@ -99,7 +99,7 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( ur_result_t result = pfnAdapterRetain(hAdapter); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hAdapter, true); } @@ -558,7 +558,7 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( ur_result_t result = pfnRetain(hDevice); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hDevice, false); } @@ -583,12 +583,12 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( } } - ur_result_t result = pfnRelease(hDevice); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hDevice, false); } + ur_result_t result = pfnRelease(hDevice); + return result; } @@ -861,7 +861,7 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( ur_result_t result = pfnRetain(hContext); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hContext, false); } @@ -886,12 +886,12 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( } } - 
ur_result_t result = pfnRelease(hContext); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hContext, false); } + ur_result_t result = pfnRelease(hContext); + return result; } @@ -1248,7 +1248,7 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( ur_result_t result = pfnRetain(hMem); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hMem, false); } @@ -1273,12 +1273,12 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( } } - ur_result_t result = pfnRelease(hMem); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hMem, false); } + ur_result_t result = pfnRelease(hMem); + return result; } @@ -1657,7 +1657,7 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( ur_result_t result = pfnRetain(hSampler); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hSampler, false); } @@ -1682,12 +1682,12 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( } } - ur_result_t result = pfnRelease(hSampler); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hSampler, false); } + ur_result_t result = pfnRelease(hSampler); + return result; } @@ -2154,7 +2154,7 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( ur_result_t result = pfnPoolRetain(pPool); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(pPool, false); } @@ -2178,12 +2178,12 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( } 
} - ur_result_t result = pfnPoolRelease(pPool); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(pPool, false); } + ur_result_t result = pfnPoolRelease(pPool); + return result; } @@ -2631,7 +2631,7 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( ur_result_t result = pfnRetain(hPhysicalMem); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hPhysicalMem, false); } @@ -2656,12 +2656,12 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( } } - ur_result_t result = pfnRelease(hPhysicalMem); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hPhysicalMem, false); } + ur_result_t result = pfnRelease(hPhysicalMem); + return result; } @@ -2952,7 +2952,7 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( ur_result_t result = pfnRetain(hProgram); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hProgram, false); } @@ -2977,12 +2977,12 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( } } - ur_result_t result = pfnRelease(hProgram); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hProgram, false); } + ur_result_t result = pfnRelease(hProgram); + return result; } @@ -3618,7 +3618,7 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( ur_result_t result = pfnRetain(hKernel); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hKernel, false); } @@ -3643,12 +3643,12 
@@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( } } - ur_result_t result = pfnRelease(hKernel); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hKernel, false); } + ur_result_t result = pfnRelease(hKernel); + return result; } @@ -4138,7 +4138,7 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( ur_result_t result = pfnRetain(hQueue); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hQueue, false); } @@ -4163,12 +4163,12 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( } } - ur_result_t result = pfnRelease(hQueue); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hQueue, false); } + ur_result_t result = pfnRelease(hQueue); + return result; } @@ -4454,7 +4454,7 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( ur_result_t result = pfnRetain(hEvent); - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->incrementRefCount(hEvent, false); } @@ -4478,12 +4478,12 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( } } - ur_result_t result = pfnRelease(hEvent); - - if (getContext()->enableLeakChecking && result == UR_RESULT_SUCCESS) { + if (getContext()->enableLeakChecking) { getContext()->refCountContext->decrementRefCount(hEvent, false); } + ur_result_t result = pfnRelease(hEvent); + return result; } diff --git a/source/loader/ur_ldrddi.cpp b/source/loader/ur_ldrddi.cpp index a67879a9eb..7560f6122d 100644 --- a/source/loader/ur_ldrddi.cpp +++ b/source/loader/ur_ldrddi.cpp @@ -85,6 +85,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRelease( // forward to device-platform result = pfnAdapterRelease(hAdapter); + // 
release loader handle + context->factories.ur_adapter_factory.release(hAdapter); + return result; } @@ -110,6 +113,9 @@ __urdlllocal ur_result_t UR_APICALL urAdapterRetain( // forward to device-platform result = pfnAdapterRetain(hAdapter); + // increment refcount of handle + context->factories.ur_adapter_factory.retain(hAdapter); + return result; } @@ -614,6 +620,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRetain( // forward to device-platform result = pfnRetain(hDevice); + // increment refcount of handle + context->factories.ur_device_factory.retain(hDevice); + return result; } @@ -640,6 +649,9 @@ __urdlllocal ur_result_t UR_APICALL urDeviceRelease( // forward to device-platform result = pfnRelease(hDevice); + // release loader handle + context->factories.ur_device_factory.release(hDevice); + return result; } @@ -910,6 +922,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRetain( // forward to device-platform result = pfnRetain(hContext); + // increment refcount of handle + context->factories.ur_context_factory.retain(hContext); + return result; } @@ -936,6 +951,9 @@ __urdlllocal ur_result_t UR_APICALL urContextRelease( // forward to device-platform result = pfnRelease(hContext); + // release loader handle + context->factories.ur_context_factory.release(hContext); + return result; } @@ -1238,6 +1256,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRetain( // forward to device-platform result = pfnRetain(hMem); + // increment refcount of handle + context->factories.ur_mem_factory.retain(hMem); + return result; } @@ -1264,6 +1285,9 @@ __urdlllocal ur_result_t UR_APICALL urMemRelease( // forward to device-platform result = pfnRelease(hMem); + // release loader handle + context->factories.ur_mem_factory.release(hMem); + return result; } @@ -1615,6 +1639,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRetain( // forward to device-platform result = pfnRetain(hSampler); + // increment refcount of handle + context->factories.ur_sampler_factory.retain(hSampler); + 
return result; } @@ -1641,6 +1668,9 @@ __urdlllocal ur_result_t UR_APICALL urSamplerRelease( // forward to device-platform result = pfnRelease(hSampler); + // release loader handle + context->factories.ur_sampler_factory.release(hSampler); + return result; } @@ -2074,6 +2104,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRetain( // forward to device-platform result = pfnPoolRetain(pPool); + // increment refcount of handle + context->factories.ur_usm_pool_factory.retain(pPool); + return result; } @@ -2099,6 +2132,9 @@ __urdlllocal ur_result_t UR_APICALL urUSMPoolRelease( // forward to device-platform result = pfnPoolRelease(pPool); + // release loader handle + context->factories.ur_usm_pool_factory.release(pPool); + return result; } @@ -2484,6 +2520,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRetain( // forward to device-platform result = pfnRetain(hPhysicalMem); + // increment refcount of handle + context->factories.ur_physical_mem_factory.retain(hPhysicalMem); + return result; } @@ -2512,6 +2551,9 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemRelease( // forward to device-platform result = pfnRelease(hPhysicalMem); + // release loader handle + context->factories.ur_physical_mem_factory.release(hPhysicalMem); + return result; } @@ -2759,6 +2801,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRetain( // forward to device-platform result = pfnRetain(hProgram); + // increment refcount of handle + context->factories.ur_program_factory.retain(hProgram); + return result; } @@ -2785,6 +2830,9 @@ __urdlllocal ur_result_t UR_APICALL urProgramRelease( // forward to device-platform result = pfnRelease(hProgram); + // release loader handle + context->factories.ur_program_factory.release(hProgram); + return result; } @@ -3382,6 +3430,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRetain( // forward to device-platform result = pfnRetain(hKernel); + // increment refcount of handle + context->factories.ur_kernel_factory.retain(hKernel); + return result; } @@ 
-3408,6 +3459,9 @@ __urdlllocal ur_result_t UR_APICALL urKernelRelease( // forward to device-platform result = pfnRelease(hKernel); + // release loader handle + context->factories.ur_kernel_factory.release(hKernel); + return result; } @@ -3858,6 +3912,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRetain( // forward to device-platform result = pfnRetain(hQueue); + // increment refcount of handle + context->factories.ur_queue_factory.retain(hQueue); + return result; } @@ -3884,6 +3941,9 @@ __urdlllocal ur_result_t UR_APICALL urQueueRelease( // forward to device-platform result = pfnRelease(hQueue); + // release loader handle + context->factories.ur_queue_factory.release(hQueue); + return result; } @@ -4188,6 +4248,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRetain( // forward to device-platform result = pfnRetain(hEvent); + // increment refcount of handle + context->factories.ur_event_factory.retain(hEvent); + return result; } @@ -4213,6 +4276,9 @@ __urdlllocal ur_result_t UR_APICALL urEventRelease( // forward to device-platform result = pfnRelease(hEvent); + // release loader handle + context->factories.ur_event_factory.release(hEvent); + return result; } @@ -6745,6 +6811,9 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalMemoryExp( // forward to device-platform result = pfnReleaseExternalMemoryExp(hContext, hDevice, hExternalMem); + // release loader handle + context->factories.ur_exp_external_mem_factory.release(hExternalMem); + return result; } @@ -6835,6 +6904,10 @@ __urdlllocal ur_result_t UR_APICALL urBindlessImagesReleaseExternalSemaphoreExp( result = pfnReleaseExternalSemaphoreExp(hContext, hDevice, hExternalSemaphore); + // release loader handle + context->factories.ur_exp_external_semaphore_factory.release( + hExternalSemaphore); + return result; } @@ -7062,6 +7135,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainExp( // forward to device-platform result = pfnRetainExp(hCommandBuffer); + // increment refcount of 
handle + context->factories.ur_exp_command_buffer_factory.retain(hCommandBuffer); + return result; } @@ -7092,6 +7168,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseExp( // forward to device-platform result = pfnReleaseExp(hCommandBuffer); + // release loader handle + context->factories.ur_exp_command_buffer_factory.release(hCommandBuffer); + return result; } @@ -8408,6 +8487,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferRetainCommandExp( // forward to device-platform result = pfnRetainCommandExp(hCommand); + // increment refcount of handle + context->factories.ur_exp_command_buffer_command_factory.retain(hCommand); + return result; } @@ -8439,6 +8521,9 @@ __urdlllocal ur_result_t UR_APICALL urCommandBufferReleaseCommandExp( // forward to device-platform result = pfnReleaseCommandExp(hCommand); + // release loader handle + context->factories.ur_exp_command_buffer_command_factory.release(hCommand); + return result; } diff --git a/test/conformance/adapter/urAdapterRelease.cpp b/test/conformance/adapter/urAdapterRelease.cpp index 8b29fa0f2c..0b28287aa7 100644 --- a/test/conformance/adapter/urAdapterRelease.cpp +++ b/test/conformance/adapter/urAdapterRelease.cpp @@ -16,6 +16,7 @@ struct urAdapterReleaseTest : uur::runtime::urAdapterTest { TEST_F(urAdapterReleaseTest, Success) { uint32_t referenceCountBefore = 0; + ASSERT_SUCCESS(urAdapterRetain(adapter)); ASSERT_SUCCESS(urAdapterGetInfo(adapter, UR_ADAPTER_INFO_REFERENCE_COUNT, sizeof(referenceCountBefore), diff --git a/test/conformance/device/urDeviceRelease.cpp b/test/conformance/device/urDeviceRelease.cpp index a8f6a3bc9d..dd5510394f 100644 --- a/test/conformance/device/urDeviceRelease.cpp +++ b/test/conformance/device/urDeviceRelease.cpp @@ -8,6 +8,8 @@ struct urDeviceReleaseTest : uur::urAllDevicesTest {}; TEST_F(urDeviceReleaseTest, Success) { for (auto device : devices) { + ASSERT_SUCCESS(urDeviceRetain(device)); + uint32_t prevRefCount = 0; 
ASSERT_SUCCESS(uur::GetObjectReferenceCount(device, prevRefCount)); diff --git a/test/conformance/testing/include/uur/fixtures.h b/test/conformance/testing/include/uur/fixtures.h index 436e7821a9..5c09f218d1 100644 --- a/test/conformance/testing/include/uur/fixtures.h +++ b/test/conformance/testing/include/uur/fixtures.h @@ -98,7 +98,6 @@ struct urDeviceTest : urPlatformTest, } void TearDown() override { - EXPECT_SUCCESS(urDeviceRelease(device)); UUR_RETURN_ON_FATAL_FAILURE(urPlatformTest::TearDown()); }