Skip to content

Commit

Permalink
fix: address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
GuanLuo committed Oct 31, 2024
1 parent 6a87c5f commit 886d5b2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
11 changes: 4 additions & 7 deletions src/python/library/tests/test_shared_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,8 @@ def test_set_region_oversize(self):
shm.set_shared_memory_region(self.shm_handles[0], [large_tensor])

def test_duplicate_key(self):
# [NOTE] change in behavior:
# previous: okay to create shared memory region of the same key with different size
# and the behavior is not being study clearly.
# now: return the same handle if existed, warning will be print if size is different
# by default, return the same handle if existed, warning will be print
# if size is different
self.shm_handles.append(
shm.create_shared_memory_region("shm_name", "shm_key", 32)
)
Expand Down Expand Up @@ -133,9 +131,8 @@ def test_duplicate_key(self):
shm.set_shared_memory_region(self.shm_handles[-1], [large_tensor])

def test_destroy_duplicate(self):
# [NOTE] change in behavior:
# previous: raise exception if underlying shared memory has been unlinked
# now: no exception as unlink only happen when last managed handle is destroyed
# destruction of duplicate shared memory region will occur when the last
# managed handle is destroyed
self.assertEqual(len(shm.mapped_shared_memory_regions()), 0)
self.shm_handles.append(
shm.create_shared_memory_region("shm_name", "shm_key", 64)
Expand Down
22 changes: 13 additions & 9 deletions src/python/library/tritonclient/utils/shared_memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only
Whether a shared memory region must be created. If False and
a shared memory region of the same name exists, a handle to that
shared memory region will be returned and user must be aware that
the shared memory size can be different from the size requested.
the previously allocated shared memory size can be different from
the size requested.
Returns
-------
Expand All @@ -80,8 +81,11 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only
try:
shm_handle._mpsm_handle = mpshm.SharedMemory(shm_key)
if shm_key not in _key_mapping:
_key_mapping[shm_key] = [False, 0]
_key_mapping[shm_key][1] += 1
_key_mapping[shm_key] = {
"needs_unlink": False,
"active_handle_count": 0,
}
_key_mapping[shm_key]["active_handle_count"] += 1
except FileNotFoundError:
# File not found means the shared memory region has not been created,
# suppress the exception and attempt to create the region.
Expand All @@ -96,9 +100,9 @@ def create_shared_memory_region(triton_shm_name, shm_key, byte_size, create_only
"unable to create the shared memory region"
) from ex
if shm_key not in _key_mapping:
_key_mapping[shm_key] = [False, 0]
_key_mapping[shm_key][0] = True
_key_mapping[shm_key][1] += 1
_key_mapping[shm_key] = {"needs_unlink": False, "active_handle_count": 0}
_key_mapping[shm_key]["needs_unlink"] = True
_key_mapping[shm_key]["active_handle_count"] += 1

if byte_size > shm_handle._mpsm_handle.size:
warnings.warn(
Expand Down Expand Up @@ -238,10 +242,10 @@ def destroy_shared_memory_region(shm_handle):
# fail to delete a region, we should not report it back to the user
# as a valid memory region.
shm_handle._mpsm_handle.close()
_key_mapping[shm_handle._shm_key][1] -= 1
if _key_mapping[shm_handle._shm_key][1] == 0:
_key_mapping[shm_handle._shm_key]["active_handle_count"] -= 1
if _key_mapping[shm_handle._shm_key]["active_handle_count"] == 0:
try:
if _key_mapping[shm_handle._shm_key][0]:
if _key_mapping[shm_handle._shm_key]["needs_unlink"]:
shm_handle._mpsm_handle.unlink()
finally:
_key_mapping.pop(shm_handle._shm_key)
Expand Down

0 comments on commit 886d5b2

Please sign in to comment.