diff --git a/src/python/library/tests/test_shared_memory.py b/src/python/library/tests/test_shared_memory.py index f765bc781..36c64f090 100644 --- a/src/python/library/tests/test_shared_memory.py +++ b/src/python/library/tests/test_shared_memory.py @@ -101,10 +101,8 @@ def test_set_region_oversize(self): shm.set_shared_memory_region(self.shm_handles[0], [large_tensor]) def test_duplicate_key(self): - # [NOTE] change in behavior: - # previous: okay to create shared memory region of the same key with different size - # and the behavior is not being study clearly. - # now: return the same handle if existed, warning will be print if size is different + # by default, return the same handle if existed, warning will be print + # if size is different self.shm_handles.append( shm.create_shared_memory_region("shm_name", "shm_key", 32) ) @@ -133,9 +131,8 @@ def test_duplicate_key(self): shm.set_shared_memory_region(self.shm_handles[-1], [large_tensor]) def test_destroy_duplicate(self): - # [NOTE] change in behavior: - # previous: raise exception if underlying shared memory has been unlinked - # now: no exception as unlink only happen when last managed handle is destroyed + # destruction of duplicate shared memory region will occur when the last + # managed handle is destroyed self.assertEqual(len(shm.mapped_shared_memory_regions()), 0) self.shm_handles.append( shm.create_shared_memory_region("shm_name", "shm_key", 64) diff --git a/src/python/library/tritonclient/utils/shared_memory/__init__.py b/src/python/library/tritonclient/utils/shared_memory/__init__.py index 364b58b99..12904445e 100755 --- a/src/python/library/tritonclient/utils/shared_memory/__init__.py +++ b/src/python/library/tritonclient/utils/shared_memory/__init__.py @@ -62,7 +62,8 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only Whether a shared memory region must be created. If False and a shared memory region of the same name exists, a handle to that shared memory region will be returned and user must be aware that - the shared memory size can be different from the size requested. + the previously allocated shared memory size can be different from + the size requested. Returns ------- @@ -80,8 +81,11 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only try: shm_handle._mpsm_handle = mpshm.SharedMemory(shm_key) if shm_key not in _key_mapping: - _key_mapping[shm_key] = [False, 0] - _key_mapping[shm_key][1] += 1 + _key_mapping[shm_key] = { + "needs_unlink": False, + "active_handle_count": 0, + } + _key_mapping[shm_key]["active_handle_count"] += 1 except FileNotFoundError: # File not found means the shared memory region has not been created, # suppress the exception and attempt to create the region. @@ -96,9 +100,9 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only "unable to create the shared memory region" ) from ex if shm_key not in _key_mapping: - _key_mapping[shm_key] = [False, 0] - _key_mapping[shm_key][0] = True - _key_mapping[shm_key][1] += 1 + _key_mapping[shm_key] = {"needs_unlink": False, "active_handle_count": 0} + _key_mapping[shm_key]["needs_unlink"] = True + _key_mapping[shm_key]["active_handle_count"] += 1 if byte_size > shm_handle._mpsm_handle.size: warnings.warn( @@ -238,10 +242,10 @@ def destroy_shared_memory_region(shm_handle): # fail to delete a region, we should not report it back to the user # as a valid memory region. shm_handle._mpsm_handle.close() - _key_mapping[shm_handle._shm_key][1] -= 1 - if _key_mapping[shm_handle._shm_key][1] == 0: + _key_mapping[shm_handle._shm_key]["active_handle_count"] -= 1 + if _key_mapping[shm_handle._shm_key]["active_handle_count"] == 0: try: - if _key_mapping[shm_handle._shm_key][0]: + if _key_mapping[shm_handle._shm_key]["needs_unlink"]: shm_handle._mpsm_handle.unlink() finally: _key_mapping.pop(shm_handle._shm_key)