Skip to content

Commit

Permalink
feat(dcp): add support for S3 CopyObject API
Browse files Browse the repository at this point in the history
Add support for S3 `CopyObject` API, binding Python and Rust clients
together. Bump versions for mountpoint-s3-client and mountpoint-s3-crt
(required to use the `CopyObject` API).
  • Loading branch information
matthieu-d4r committed Oct 21, 2024
1 parent 712c2a3 commit 31246d9
Show file tree
Hide file tree
Showing 12 changed files with 267 additions and 57 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,8 @@ venv/
*.egg
multirun/

# Unit test / coverage reports
.hypothesis/

# Prevent publishing file with third party licenses
THIRD-PARTY-LICENSES
8 changes: 8 additions & 0 deletions s3torchconnector/src/s3torchconnector/_s3client/_s3client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,11 @@ def head_object(self, bucket: str, key: str) -> ObjectInfo:
def delete_object(self, bucket: str, key: str) -> None:
    """Delete the object at s3://bucket/key, delegating to the underlying client.

    Logs the operation at DEBUG level before issuing the request.
    """
    log.debug(f"DeleteObject s3://{bucket}/{key}")
    self._client.delete_object(bucket, key)

def copy_object(
    self, src_bucket: str, src_key: str, dst_bucket: str, dst_key: str
) -> None:
    """Copy an object from one S3 location to another via the underlying client.

    Args:
        src_bucket: source bucket name.
        src_key: source object key.
        dst_bucket: destination bucket name.
        dst_key: destination object key.
    """
    log.debug(
        f"CopyObject s3://{src_bucket}/{src_key} to s3://{dst_bucket}/{dst_key}"
    )
    # No `return`: the method is annotated `-> None`, and sibling methods such as
    # `delete_object` follow the same convention; the call's result is discarded.
    self._client.copy_object(src_bucket, src_key, dst_bucket, dst_key)
14 changes: 14 additions & 0 deletions s3torchconnector/tst/unit/test_s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ def test_list_objects_log(s3_client: S3Client, caplog):
assert f"ListObjects {S3_URI}" in caplog.messages


def test_delete_object_log(s3_client: S3Client, caplog):
    """DeleteObject emits a DEBUG log line naming the deleted URI."""
    with caplog.at_level(logging.DEBUG):
        s3_client.delete_object(TEST_BUCKET, TEST_KEY)
    expected_message = f"DeleteObject {S3_URI}"
    assert expected_message in caplog.messages


def test_copy_object_log(s3_client: S3Client, caplog):
    """CopyObject emits a DEBUG log line naming both source and destination URIs."""
    dst_bucket = "dst_bucket"
    dst_key = "dst_key"

    with caplog.at_level(logging.DEBUG):
        s3_client.copy_object(TEST_BUCKET, TEST_KEY, dst_bucket, dst_key)
    expected_message = f"CopyObject {S3_URI} to s3://{dst_bucket}/{dst_key}"
    assert expected_message in caplog.messages


def test_s3_client_default_user_agent():
s3_client = S3Client(region=TEST_REGION)
expected_user_agent = f"s3torchconnector/{__version__}"
Expand Down
14 changes: 7 additions & 7 deletions s3torchconnectorclient/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions s3torchconnectorclient/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ built = "0.7"
pyo3 = { version = "0.19.2" }
pyo3-log = "0.8.3"
futures = "0.3.28"
mountpoint-s3-client = { version = "0.11.0", features = ["mock"] }
mountpoint-s3-crt = "0.10.0"
log = "0.4.20"
tracing = { version = "0.1.40", default-features = false, features = ["std", "log"] }
tracing-subscriber = { version = "0.3.18", features = ["fmt", "env-filter"]}
Expand Down
4 changes: 3 additions & 1 deletion s3torchconnectorclient/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ dependencies = []

[project.optional-dependencies]
test = [
"boto3",
"pytest",
"pytest-timeout",
"hypothesis",
"flake8",
"black",
"mypy",
"Pillow"
]

[tool.setuptools.packages]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class MountpointS3Client:
) -> ListObjectStream: ...
def head_object(self, bucket: str, key: str) -> ObjectInfo: ...
def delete_object(self, bucket: str, key: str) -> None: ...
# Copy the object at src_bucket/src_key to dst_bucket/dst_key.
def copy_object(
    self, src_bucket: str, src_key: str, dst_bucket: str, dst_key: str
) -> None: ...

class MockMountpointS3Client:
throughput_target_gbps: float
Expand Down
91 changes: 53 additions & 38 deletions s3torchconnectorclient/python/tst/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import io
import os
import random
from dataclasses import dataclass, field
from typing import Optional

import boto3
import numpy as np
Expand All @@ -18,33 +20,28 @@ def getenv(var: str, optional: bool = False) -> str:
return v


@dataclass
class BucketPrefixFixture:
    """An S3 bucket/prefix and its contents for use in a single unit test. The prefix will be unique
    to this instance, so other concurrent tests won't affect its state."""

    # Test name; becomes part of the unique prefix (no default, so it must be supplied).
    name: str

    # Environment-driven configuration; optional entries default to None when unset.
    region: str = getenv("CI_REGION")
    bucket: str = getenv("CI_BUCKET")
    prefix: str = getenv("CI_PREFIX")
    storage_class: Optional[str] = getenv("CI_STORAGE_CLASS", optional=True)
    endpoint_url: Optional[str] = getenv("CI_CUSTOM_ENDPOINT_URL", optional=True)
    # Map of full S3 key -> object bytes written through this fixture.
    contents: dict = field(default_factory=dict)

    def __post_init__(self):
        # A non-empty prefix must be "directory-like" so key concatenation is safe.
        assert self.prefix == "" or self.prefix.endswith("/")
        session = boto3.Session(region_name=self.region)
        self.s3 = session.client("s3")

        # Randomize the prefix so concurrent runs of the same test don't collide.
        nonce = random.randrange(2**64)
        self.prefix = f"{self.prefix}{self.name}/{nonce}/"

    @property
    def s3_uri(self):
        """The s3:// URI of this fixture's unique prefix."""
        return f"s3://{self.bucket}/{self.prefix}"

    def add(self, key: str, contents: bytes, **kwargs):
        """Upload contents under `prefix + key` and record it in `self.contents`."""
        full_key = f"{self.prefix}{key}"
        self.s3.put_object(Bucket=self.bucket, Key=full_key, Body=contents, **kwargs)
        self.contents[full_key] = contents

    def remove(self, key: str):
        """Delete the object at `prefix + key` (no-op record-keeping: contents unchanged)."""
        full_key = f"{self.prefix}{key}"
        self.s3.delete_object(Bucket=self.bucket, Key=full_key)

    def __getitem__(self, index):
        # Lookup by FULL key (prefix included), mirroring how `add` records entries.
        return self.contents[index]

    def __iter__(self):
        return iter(self.contents)


def get_test_bucket_prefix(name: str) -> BucketPrefixFixture:
"""Create a new bucket/prefix fixture for the given test name."""
bucket = getenv("CI_BUCKET")
prefix = getenv("CI_PREFIX")
region = getenv("CI_REGION")
storage_class = getenv("CI_STORAGE_CLASS", optional=True)
endpoint_url = getenv("CI_CUSTOM_ENDPOINT_URL", optional=True)
assert prefix == "" or prefix.endswith("/")
@dataclass
class CopyBucketFixture(BucketPrefixFixture):
    """Bucket/prefix fixture specialized for CopyObject tests, carrying a fixed
    source key and destination key."""

    src_key: str = "src.txt"
    dst_key: str = "dst.txt"

    @property
    def full_src_key(self) -> str:
        """Source key with the fixture's unique prefix prepended."""
        return f"{self.prefix}{self.src_key}"

    @property
    def full_dst_key(self) -> str:
        """Destination key with the fixture's unique prefix prepended."""
        return f"{self.prefix}{self.dst_key}"


def get_test_copy_bucket_fixture(name: str) -> CopyBucketFixture:
    """Create a CopyBucketFixture for the given test name, pre-populated with the
    source object and with any stale destination object removed."""
    copy_bucket_fixture = CopyBucketFixture(name=name)

    # Set up / teardown: ensure the source exists and the destination does not.
    copy_bucket_fixture.add(copy_bucket_fixture.src_key, b"Hello, World!\n")
    copy_bucket_fixture.remove(copy_bucket_fixture.dst_key)

    return copy_bucket_fixture


@pytest.fixture
def image_directory(request) -> BucketPrefixFixture:
"""Create a bucket/prefix fixture that contains a directory of random JPG image files."""
NUM_IMAGES = 10
IMAGE_SIZE = 100
fixture = get_test_bucket_prefix(f"{request.node.name}/image_directory")
fixture = BucketPrefixFixture(f"{request.node.name}/image_directory")
for i in range(NUM_IMAGES):
data = np.random.randint(0, 256, IMAGE_SIZE * IMAGE_SIZE * 3, np.uint8)
data = data.reshape(IMAGE_SIZE, IMAGE_SIZE, 3)
Expand All @@ -100,23 +110,28 @@ def image_directory(request) -> BucketPrefixFixture:

@pytest.fixture
def sample_directory(request) -> BucketPrefixFixture:
    """Bucket/prefix fixture containing a single sample text object."""
    fixture = BucketPrefixFixture(f"{request.node.name}/sample_files")
    fixture.add("hello_world.txt", b"Hello, World!\n")
    return fixture


@pytest.fixture
def put_object_tests_directory(request) -> BucketPrefixFixture:
    """Bucket/prefix fixture pre-seeded with an object for overwrite tests."""
    fixture = BucketPrefixFixture(f"{request.node.name}/put_integration_tests")
    fixture.add("to_overwrite.txt", b"before")
    return fixture


@pytest.fixture
def checkpoint_directory(request) -> BucketPrefixFixture:
    """Fresh, empty bucket/prefix fixture for checkpoint tests."""
    return BucketPrefixFixture(f"{request.node.name}/checkpoint_directory")


@pytest.fixture
def empty_directory(request) -> BucketPrefixFixture:
    """Fresh bucket/prefix fixture with no objects under its prefix."""
    return BucketPrefixFixture(f"{request.node.name}/empty_directory")


@pytest.fixture
def copy_directory(request) -> CopyBucketFixture:
    """Fixture providing a source object and a clear destination key for CopyObject tests."""
    return get_test_copy_bucket_fixture(f"{request.node.name}/copy_directory")
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
ListObjectStream,
)

from conftest import BucketPrefixFixture
from conftest import BucketPrefixFixture, CopyBucketFixture

logging.basicConfig(
format="%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s"
Expand Down Expand Up @@ -404,6 +404,91 @@ def test_delete_object_invalid_bucket(
)


def test_copy_object(copy_directory: CopyBucketFixture):
    """CopyObject produces a destination object whose contents match the source."""
    bucket = copy_directory.bucket
    src_key = copy_directory.full_src_key
    dst_key = copy_directory.full_dst_key

    client = MountpointS3Client(copy_directory.region, TEST_USER_AGENT_PREFIX)
    client.copy_object(
        src_bucket=bucket, src_key=src_key, dst_bucket=bucket, dst_key=dst_key
    )

    expected_body = b"".join(client.get_object(bucket, src_key))
    copied = client.get_object(bucket, dst_key)

    assert copied.key == dst_key
    assert b"".join(copied) == expected_body


def test_copy_object_raises_when_source_bucket_does_not_exist(
    copy_directory: CopyBucketFixture,
):
    """Copying from a non-existent bucket surfaces an S3Exception."""
    client = MountpointS3Client(copy_directory.region, TEST_USER_AGENT_PREFIX)

    # TODO: error message looks unexpected for Express One Zone, compared to the
    # other tests for non-existing bucket or key (see below)
    if copy_directory.storage_class == "EXPRESS_ONEZONE":
        error_message = "Client error: Forbidden: <no message>"
    else:
        error_message = "Service error: The object was not found"

    with pytest.raises(S3Exception, match=error_message):
        client.copy_object(
            src_bucket=str(uuid.uuid4()),
            src_key=copy_directory.full_src_key,
            dst_bucket=copy_directory.bucket,
            dst_key=copy_directory.full_dst_key,
        )


def test_copy_object_raises_when_destination_bucket_does_not_exist(
    copy_directory: CopyBucketFixture,
):
    """Copying into a non-existent bucket surfaces an S3Exception."""
    client = MountpointS3Client(copy_directory.region, TEST_USER_AGENT_PREFIX)

    # NOTE: `copy_object` and its underlying implementation does not
    # differentiate between `NoSuchBucket` and `NoSuchKey` errors.
    with pytest.raises(S3Exception, match="Service error: The object was not found"):
        client.copy_object(
            src_bucket=copy_directory.bucket,
            src_key=copy_directory.full_src_key,
            dst_bucket=str(uuid.uuid4()),
            dst_key=copy_directory.full_dst_key,
        )


def test_copy_object_raises_when_source_key_does_not_exist(
    copy_directory: CopyBucketFixture,
):
    """Copying a non-existent source key surfaces an S3Exception."""
    client = MountpointS3Client(copy_directory.region, TEST_USER_AGENT_PREFIX)
    bucket = copy_directory.bucket
    missing_src_key = str(uuid.uuid4())

    with pytest.raises(S3Exception, match="Service error: The object was not found"):
        client.copy_object(
            src_bucket=bucket,
            src_key=missing_src_key,
            dst_bucket=bucket,
            dst_key=copy_directory.full_dst_key,
        )


def _parse_list_result(stream: ListObjectStream, max_keys: int):
object_infos = []
i = 0
Expand Down
Loading

0 comments on commit 31246d9

Please sign in to comment.