Skip to content

Commit

Permalink
concurrent-upload: add progress reporting when uploading chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
giacomo-alzetta-aiven committed Jul 24, 2023
1 parent be62d9f commit 7e8e7c3
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 20 deletions.
9 changes: 7 additions & 2 deletions rohmu/object_storage/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..errors import FileNotFoundFromStorageError, StorageError
from ..notifier.interface import Notifier
from ..typing import Metadata
from ..util import BinaryStreamsConcatenation
from ..util import BinaryStreamsConcatenation, ProgressStream
from .base import (
BaseTransfer,
ConcurrentUploadData,
Expand Down Expand Up @@ -284,8 +284,13 @@ def upload_concurrent_chunk(
chunks_dir = self.format_key_for_backend("concurrent_upload_" + concurrent_data.backend_id)
try:
with atomic_create_file_binary(os.path.join(chunks_dir, str(chunk_number))) as chunk_fp:
for data in iter(lambda: fd.read(CHUNK_SIZE), b""):
wrapped_fd = ProgressStream(fd) # type: ignore[abstract]
for data in iter(lambda: wrapped_fd.read(CHUNK_SIZE), b""):
chunk_fp.write(data)
bytes_read = wrapped_fd.bytes_read
if upload_progress_fn:
upload_progress_fn(bytes_read)
self.stats.operation(StorageOperation.store_file, size=bytes_read)
chunks[chunk_number] = "no-etag"
except OSError as ex:
raise StorageError(
Expand Down
16 changes: 13 additions & 3 deletions rohmu/object_storage/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
)
from botocore.response import StreamingBody
from enum import Enum, unique
from rohmu.util import batched
from functools import partial
from rohmu.util import batched, ProgressStream
from typing import Any, BinaryIO, cast, Collection, Iterator, Optional, Tuple, TYPE_CHECKING, Union

import botocore.client
Expand Down Expand Up @@ -638,6 +639,7 @@ def complete_concurrent_upload(self, upload_id: ConcurrentUploadId) -> None:
key=lambda part: cast(int, part["PartNumber"]),
)
try:
self.stats.operation(StorageOperation.multipart_complete)
self.s3_client.complete_multipart_upload(
Bucket=self.bucket_name,
Key=backend_key,
Expand All @@ -655,6 +657,7 @@ def abort_concurrent_upload(self, upload_id: ConcurrentUploadId) -> None:
concurrent_data, _, _ = self._get_upload_info(upload_id)
backend_key = self.format_key_for_backend(concurrent_data.key, remove_slash_prefix=True)
try:
self.stats.operation(StorageOperation.multipart_aborted)
self.s3_client.abort_multipart_upload(
Bucket=self.bucket_name,
Key=backend_key,
Expand All @@ -679,13 +682,20 @@ def upload_concurrent_chunk(
concurrent_data, _, chunks = self._get_upload_info(upload_id)
backend_key = self.format_key_for_backend(concurrent_data.key, remove_slash_prefix=True)
try:
response = self.s3_client.upload_part(
upload_func = partial(
self.s3_client.upload_part,
Bucket=self.bucket_name,
Key=backend_key,
UploadId=concurrent_data.backend_id,
Body=fd,
PartNumber=chunk_number,
)
body = ProgressStream(fd) # type: ignore[abstract]
response = upload_func(Body=body)
if upload_progress_fn:
upload_progress_fn(body.bytes_read)
else:
response = upload_func(Body=fd)
self.stats.operation(StorageOperation.store_file, size=body.bytes_read)
chunks[chunk_number] = response["ETag"]
except botocore.exceptions.ClientError as ex:
raise StorageError(
Expand Down
83 changes: 81 additions & 2 deletions rohmu/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
Copyright (c) 2022 Ohmu Ltd
See LICENSE for details
"""
from io import BytesIO
from io import BytesIO, UnsupportedOperation
from itertools import islice
from rohmu.typing import HasFileno
from typing import BinaryIO, Generator, Iterable, Optional, Tuple, TypeVar, Union
from typing import Any, BinaryIO, Generator, Iterable, Optional, Tuple, TypeVar, Union

import fcntl
import logging
Expand Down Expand Up @@ -109,3 +109,82 @@ def read(self, size: int = -1) -> bytes:
break

return result.getvalue()


# Stream-API methods that are safe to delegate unchanged to a wrapped
# read-only binary stream (they neither write nor reposition the cursor).
_FORWARD_TO_STREAM_METHODS = [
    "close",
    "fileno",
    "flush",
    "isatty",
    "tell",
    "__enter__",
    "__exit__",
]

# Stream-API methods that a read-only wrapper must reject: writing is out of
# scope, and allowing seeks would make a "bytes read so far" counter meaningless.
_FORBIDDEN_METHODS = [
    "seek",
    "truncate",
    "write",
    "writelines",
]


class ProgressStream(BinaryIO):  # pylint: disable=abstract-method
    """Wrapper for binary streams that can report the amount of bytes read through it.

    The wrapper is read-only: ``seek``, ``truncate``, ``write`` and
    ``writelines`` raise :class:`io.UnsupportedOperation`.  All other stream
    methods are forwarded to the wrapped stream.  The running total of bytes
    consumed via ``read``/``readline``/``readlines`` is exposed as
    ``bytes_read``.
    """

    def __init__(self, raw_stream: BinaryIO) -> None:
        # raw_stream: the underlying binary stream every call is delegated to.
        self.raw_stream = raw_stream
        # bytes_read: cumulative number of bytes handed out by the read methods.
        self.bytes_read = 0

    # --- methods forwarded verbatim to the wrapped stream -------------------
    # These are spelled out explicitly instead of being generated by mutating
    # ``locals()`` in the class body: that trick relies on a CPython
    # implementation detail and is invisible to type checkers, while the
    # explicit version is portable and just as short.

    def close(self) -> None:
        self.raw_stream.close()

    def fileno(self) -> int:
        return self.raw_stream.fileno()

    def flush(self) -> None:
        self.raw_stream.flush()

    def isatty(self) -> bool:
        return self.raw_stream.isatty()

    def tell(self) -> int:
        return self.raw_stream.tell()

    def __enter__(self) -> "ProgressStream":
        # Enter the underlying stream's context but return the wrapper, so
        # reads performed inside a ``with`` block are still counted.  (The
        # previous dynamically generated forwarding returned the raw stream,
        # silently bypassing progress accounting.)
        self.raw_stream.__enter__()
        return self

    def __exit__(self, *args: Any) -> Any:
        # Delegating __exit__ closes the wrapped stream on context exit.
        return self.raw_stream.__exit__(*args)

    # --- operations a read-only wrapper must reject --------------------------

    def seek(self, *args: Any, **kwargs: Any) -> int:
        raise UnsupportedOperation("seek")

    def truncate(self, *args: Any, **kwargs: Any) -> int:
        raise UnsupportedOperation("truncate")

    def write(self, *args: Any, **kwargs: Any) -> int:
        raise UnsupportedOperation("write")

    def writelines(self, *args: Any, **kwargs: Any) -> None:
        raise UnsupportedOperation("writelines")

    def seekable(self) -> bool:
        return False

    def writable(self) -> bool:
        return False

    def readable(self) -> bool:
        return True

    @property
    def closed(self) -> bool:
        return self.raw_stream.closed

    @property
    def name(self) -> str:
        return self.raw_stream.name

    @property
    def mode(self) -> str:
        return self.raw_stream.mode

    # --- counting read methods ----------------------------------------------

    def read(self, n: int = -1) -> bytes:
        # len(data) is the number of bytes actually produced, which may be
        # less than n on a short read.
        data = self.raw_stream.read(n)
        self.bytes_read += len(data)
        return data

    def readline(self, limit: int = -1) -> bytes:
        line = self.raw_stream.readline(limit)
        self.bytes_read += len(line)
        return line

    def readlines(self, hint: int = -1) -> list[bytes]:
        lines = self.raw_stream.readlines(hint)
        self.bytes_read += sum(map(len, lines))
        return lines
23 changes: 16 additions & 7 deletions test/test_object_storage_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,25 @@ def test_upload_files_concurrently_can_be_aborted() -> None:
notifier=notifier,
)
upload_id = transfer.create_concurrent_upload(key="test_key1", metadata={"some-key": "some-value"})

total = 0

def inc_progress(size: int) -> None:
nonlocal total
total += size

# should end up with b"Hello, World!\nHello, World!"
transfer.upload_concurrent_chunk(upload_id, 3, BytesIO(b"Hello"))
transfer.upload_concurrent_chunk(upload_id, 4, BytesIO(b", "))
transfer.upload_concurrent_chunk(upload_id, 1, BytesIO(b"Hello, World!"))
transfer.upload_concurrent_chunk(upload_id, 7, BytesIO(b"!"))
transfer.upload_concurrent_chunk(upload_id, 2, BytesIO(b"\n"))
transfer.upload_concurrent_chunk(upload_id, 6, BytesIO(b"ld"))
transfer.upload_concurrent_chunk(upload_id, 5, BytesIO(b"Wor"))
transfer.upload_concurrent_chunk(upload_id, 3, BytesIO(b"Hello"), upload_progress_fn=inc_progress)
transfer.upload_concurrent_chunk(upload_id, 4, BytesIO(b", "), upload_progress_fn=inc_progress)
transfer.upload_concurrent_chunk(upload_id, 1, BytesIO(b"Hello, World!"), upload_progress_fn=inc_progress)
transfer.upload_concurrent_chunk(upload_id, 7, BytesIO(b"!"), upload_progress_fn=inc_progress)
transfer.upload_concurrent_chunk(upload_id, 2, BytesIO(b"\n"), upload_progress_fn=inc_progress)
transfer.upload_concurrent_chunk(upload_id, 6, BytesIO(b"ld"), upload_progress_fn=inc_progress)
transfer.upload_concurrent_chunk(upload_id, 5, BytesIO(b"Wor"), upload_progress_fn=inc_progress)
transfer.abort_concurrent_upload(upload_id)

assert total == 27

# we should not be able to find this
with pytest.raises(FileNotFoundFromStorageError):
transfer.get_metadata_for_key("test_key1")
Expand Down
24 changes: 20 additions & 4 deletions test/test_object_storage_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from rohmu.errors import InvalidByteRangeError
from rohmu.object_storage.s3 import S3Transfer
from tempfile import NamedTemporaryFile
from typing import Any, Iterator, Optional
from typing import Any, BinaryIO, Iterator, Optional
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -138,12 +138,26 @@ def test_get_contents_to_fileobj_passes_the_correct_range_header(infra: S3Infra)
def test_concurrent_upload_complete(infra: S3Infra) -> None:
metadata = {"some-date": datetime(2022, 11, 15, 18, 30, 58, 486644)}
infra.s3_client.create_multipart_upload.return_value = {"UploadId": "<aws-mpu-id>"}

def upload_part_side_effect(Body: BinaryIO, **_kwargs: Any) -> dict[str, str]:
# to check the progress function we need to actually consume the body
Body.read()
return {"ETag": "some-etag"}

infra.s3_client.upload_part.side_effect = upload_part_side_effect
transfer = infra.transfer
upload_id = transfer.create_concurrent_upload("test_key", metadata=metadata)
transfer.upload_concurrent_chunk(upload_id, 1, BytesIO(b"Hello, "))

total = 0

def inc_progress(size: int) -> None:
nonlocal total
total += size

transfer.upload_concurrent_chunk(upload_id, 1, BytesIO(b"Hello, "), upload_progress_fn=inc_progress)
# we can upload chunks in non-monotonically increasing order
transfer.upload_concurrent_chunk(upload_id, 3, BytesIO(b"!"))
transfer.upload_concurrent_chunk(upload_id, 2, BytesIO(b"World"))
transfer.upload_concurrent_chunk(upload_id, 3, BytesIO(b"!"), upload_progress_fn=inc_progress)
transfer.upload_concurrent_chunk(upload_id, 2, BytesIO(b"World"), upload_progress_fn=inc_progress)
transfer.complete_concurrent_upload(upload_id)

notifier = infra.notifier
Expand All @@ -160,6 +174,8 @@ def test_concurrent_upload_complete(infra: S3Infra) -> None:
metadata={"some-date": "2022-11-15 18:30:58.486644"},
)

assert total == 13


def test_concurrent_upload_abort(infra: S3Infra) -> None:
infra.s3_client.create_multipart_upload.return_value = {"UploadId": "<aws-mpu-id>"}
Expand Down
32 changes: 30 additions & 2 deletions test/test_util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from io import BytesIO
from rohmu.util import BinaryStreamsConcatenation, get_total_size_from_content_range
from io import BytesIO, UnsupportedOperation
from rohmu.util import BinaryStreamsConcatenation, get_total_size_from_content_range, ProgressStream
from typing import Optional

import pytest
Expand Down Expand Up @@ -41,3 +41,31 @@ def test_binary_stream_concatenation(
for output_chunk in iter(lambda: concatenation.read(chunk_size), b""):
outputs.append(output_chunk)
assert outputs == expected_outputs


def test_progress_stream() -> None:
    """ProgressStream counts bytes, rejects mutation, and closes on exit."""
    raw = BytesIO(b"Hello, World!\nSecond line\nThis is a longer third line\n")
    progress_stream = ProgressStream(raw)  # type: ignore[abstract]

    # A fresh wrapper is a readable, non-writable, non-seekable stream.
    assert progress_stream.readable()
    assert not progress_stream.seekable()
    assert not progress_stream.writable()

    # bytes_read tracks both read() and readlines() consumption.
    assert progress_stream.read(14) == b"Hello, World!\n"
    assert progress_stream.bytes_read == 14
    remaining_lines = [b"Second line\n", b"This is a longer third line\n"]
    assert progress_stream.readlines() == remaining_lines
    assert progress_stream.bytes_read == 54

    # Every mutating/seeking operation must raise.
    forbidden_calls = (
        lambda: progress_stream.seek(0),  # pylint: disable=no-member
        lambda: progress_stream.truncate(0),  # pylint: disable=no-member
        lambda: progress_stream.write(b"Something"),  # pylint: disable=no-member
        lambda: progress_stream.writelines([b"Something"]),  # pylint: disable=no-member
    )
    for forbidden_call in forbidden_calls:
        with pytest.raises(UnsupportedOperation):
            forbidden_call()

    assert not progress_stream.closed
    with progress_stream:  # pylint: disable=not-context-manager
        pass  # leaving the context must close the file
    assert progress_stream.closed

0 comments on commit 7e8e7c3

Please sign in to comment.