Skip to content

Commit

Permalink
Added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rajdchak committed Sep 25, 2024
1 parent 952616f commit 3e6ca8c
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
4 changes: 4 additions & 0 deletions s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@
def _identity(obj: Any) -> Any:
return obj


_client_lock = threading.Lock()


class S3Client:
def __init__(
self,
Expand All @@ -53,8 +55,10 @@ def __init__(

@property
def _client(self) -> MountpointS3Client:
# This is a fast check to avoid acquiring the lock unnecessarily.
if self._client_pid is None or self._client_pid != os.getpid():
with _client_lock:
# This double-check ensures that the client is only created once.
if self._client_pid is None or self._client_pid != os.getpid():
# `MountpointS3Client` does not survive forking, so re-create it if the PID has changed.
self._real_client = self._client_builder()
Expand Down
19 changes: 17 additions & 2 deletions s3torchconnector/tst/e2e/test_mountpoint_client_parallel_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from s3torchconnector._s3client import S3Client
from s3torchconnectorclient._mountpoint_s3_client import MountpointS3Client


class S3ClientWithoutLock(S3Client):
@property
def _client(self) -> MountpointS3Client:
Expand All @@ -20,11 +21,13 @@ def _client_builder(self):
time.sleep(1)
return super()._client_builder()


class S3ClientWithLock(S3Client):
def _client_builder(self):
time.sleep(1)
return super()._client_builder()


def access_client(client, error_event):
try:
if not error_event.is_set():
Expand All @@ -34,18 +37,23 @@ def access_client(client, error_event):
print(f"AssertionError in thread {threading.current_thread().name}: {e}")
error_event.set()


def test_multiple_thread_accessing_mountpoint_client_in_parallel_without_lock():
print("Running test without lock...")
client = S3ClientWithoutLock("us-west-2")
if not access_mountpoint_client_in_parallel(client):
pytest.fail("Test failed as AssertionError did not happen in one of the threads.")
pytest.fail(
"Test failed as AssertionError did not happen in one of the threads."
)


def test_multiple_thread_accessing_mountpoint_client_in_parallel_with_lock():
print("Running test with lock...")
client = S3ClientWithLock("us-west-2")
if access_mountpoint_client_in_parallel(client):
pytest.fail("Test failed as AssertionError happened in one of the threads.")


def access_mountpoint_client_in_parallel(client):

error_event = threading.Event()
Expand All @@ -56,7 +64,14 @@ def access_mountpoint_client_in_parallel(client):
for i in range(num_accessor_threads):
if error_event.is_set():
break
accessor_thread = threading.Thread(target=access_client, args=(client, error_event,), name=f"Accessor-{i + 1}")
accessor_thread = threading.Thread(
target=access_client,
args=(
client,
error_event,
),
name=f"Accessor-{i + 1}",
)
accessor_threads.append(accessor_thread)
accessor_thread.start()

Expand Down

0 comments on commit 3e6ca8c

Please sign in to comment.