-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added integration test for parallel access of mountpoint client
- Loading branch information
Showing
1 changed file
with
66 additions
and
0 deletions.
There are no files selected for viewing
66 changes: 66 additions & 0 deletions
66
s3torchconnector/tst/e2e/test_mountpoint_client_parallel_access.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import random | ||
import time | ||
import threading | ||
import pytest | ||
from s3torchconnector._s3client import S3Client | ||
from s3torchconnectorclient._mountpoint_s3_client import MountpointS3Client | ||
|
||
class S3ClientWithoutLock(S3Client): | ||
@property | ||
def _client(self) -> MountpointS3Client: | ||
if self._client_pid is None or self._client_pid != os.getpid(): | ||
self._client_pid = os.getpid() | ||
# `MountpointS3Client` does not survive forking, so re-create it if the PID has changed. | ||
self._real_client = self._client_builder() | ||
time.sleep(10) | ||
assert self._real_client is not None | ||
return self._real_client | ||
|
||
def invalidate_client(self): | ||
self._real_client = None | ||
|
||
def access_client(client, error_event): | ||
try: | ||
if not error_event.is_set(): | ||
client._client | ||
print(f"Successfully accessed by thread {threading.current_thread().name}") | ||
except AssertionError as e: | ||
print(f"AssertionError in thread {threading.current_thread().name}: {e}") | ||
error_event.set() | ||
|
||
def invalidate_client(client, error_event): | ||
if not error_event.is_set(): | ||
client.invalidate_client() | ||
print(f"Client invalidated by thread {threading.current_thread().name}") | ||
|
||
def test_multiple_thread_accessing_mountpoint_client_in_parallel(): | ||
print("Running test without lock...") | ||
client = S3ClientWithoutLock("us-west-2") | ||
error_event = threading.Event() | ||
|
||
# Start one accessor thread | ||
accessor_thread = threading.Thread(target=access_client, args=(client, error_event,), name="Accessor") | ||
accessor_thread.start() | ||
|
||
# Create and start multiple invalidator threads | ||
invalidator_threads = [] | ||
num_invalidators = 500 # Number of invalidator threads | ||
|
||
for i in range(num_invalidators): | ||
if error_event.is_set(): | ||
break | ||
invalidator_thread = threading.Thread(target=invalidate_client, args=(client, error_event,), | ||
name=f"Invalidator-{i + 1}") | ||
|
||
invalidator_threads.append(invalidator_thread) | ||
time.sleep(random.uniform(0.1, 0.5)) | ||
invalidator_thread.start() | ||
|
||
accessor_thread.join() | ||
|
||
for thread in invalidator_threads: | ||
thread.join(timeout=1) | ||
|
||
if error_event.is_set(): | ||
pytest.fail("Test failed due to AssertionError in one of the threads.") |