concurrent-upload: add progress reporting when uploading chunks
giacomo-alzetta-aiven committed Jul 27, 2023
1 parent 795d841 commit df85b54
Showing 6 changed files with 186 additions and 19 deletions.
9 changes: 7 additions & 2 deletions rohmu/object_storage/local.py
@@ -12,7 +12,7 @@
 from ..errors import FileNotFoundFromStorageError, StorageError
 from ..notifier.interface import Notifier
 from ..typing import Metadata
-from ..util import BinaryStreamsConcatenation
+from ..util import BinaryStreamsConcatenation, ProgressStream
 from .base import (
     BaseTransfer,
     ConcurrentUploadData,
@@ -275,8 +275,13 @@ def upload_concurrent_chunk(
         chunks_dir = self.format_key_for_backend("concurrent_upload_" + concurrent_data.backend_id)
         try:
             with atomic_create_file_binary(os.path.join(chunks_dir, str(chunk_number))) as chunk_fp:
-                for data in iter(lambda: fd.read(CHUNK_SIZE), b""):
+                wrapped_fd = ProgressStream(fd)
+                for data in iter(lambda: wrapped_fd.read(CHUNK_SIZE), b""):
                     chunk_fp.write(data)
+                bytes_read = wrapped_fd.bytes_read
+                if upload_progress_fn:
+                    upload_progress_fn(bytes_read)
+                self.stats.operation(StorageOperation.store_file, size=bytes_read)
             chunks[chunk_number] = "no-etag"
         except OSError as ex:
             raise StorageError(
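Any callable that accepts a byte count can serve as upload_progress_fn. A minimal usage sketch against the local backend — the constructor and metadata arguments are assumed here, following the shape of rohmu's tests, and are not part of this diff:

    from io import BytesIO
    from tempfile import TemporaryDirectory

    from rohmu.object_storage.local import LocalTransfer

    progress = []  # per-chunk byte counts reported through the callback

    def inc_progress(size: int) -> None:
        progress.append(size)

    with TemporaryDirectory() as tmp:
        transfer = LocalTransfer(directory=tmp)  # constructor args assumed
        upload_id = transfer.create_concurrent_upload(key="example-key", metadata={})
        transfer.upload_concurrent_chunk(upload_id, 1, BytesIO(b"Hello, "), upload_progress_fn=inc_progress)
        transfer.upload_concurrent_chunk(upload_id, 2, BytesIO(b"World!"), upload_progress_fn=inc_progress)
        transfer.complete_concurrent_upload(upload_id)

    assert sum(progress) == 13  # 7 + 6 bytes streamed through the two chunks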
16 changes: 13 additions & 3 deletions rohmu/object_storage/s3.py
@@ -26,7 +26,8 @@
 )
 from botocore.response import StreamingBody
 from enum import Enum, unique
-from rohmu.util import batched
+from functools import partial
+from rohmu.util import batched, ProgressStream
 from typing import Any, BinaryIO, cast, Collection, Iterator, Optional, Tuple, TYPE_CHECKING, Union
 
 import botocore.client
@@ -629,6 +630,7 @@ def complete_concurrent_upload(self, upload_id: ConcurrentUploadId) -> None:
             key=lambda part: cast(int, part["PartNumber"]),
         )
         try:
+            self.stats.operation(StorageOperation.multipart_complete)
             self.s3_client.complete_multipart_upload(
                 Bucket=self.bucket_name,
                 Key=backend_key,
@@ -646,6 +648,7 @@ def abort_concurrent_upload(self, upload_id: ConcurrentUploadId) -> None:
         concurrent_data, _, _ = self._get_concurrent_upload(upload_id)
         backend_key = self.format_key_for_backend(concurrent_data.key, remove_slash_prefix=True)
         try:
+            self.stats.operation(StorageOperation.multipart_aborted)
             self.s3_client.abort_multipart_upload(
                 Bucket=self.bucket_name,
                 Key=backend_key,
@@ -670,13 +673,20 @@ def upload_concurrent_chunk(
         concurrent_data, _, chunks = self._get_concurrent_upload(upload_id)
         backend_key = self.format_key_for_backend(concurrent_data.key, remove_slash_prefix=True)
         try:
-            response = self.s3_client.upload_part(
+            upload_func = partial(
+                self.s3_client.upload_part,
                 Bucket=self.bucket_name,
                 Key=backend_key,
                 UploadId=concurrent_data.backend_id,
-                Body=fd,
                 PartNumber=chunk_number,
             )
+            body = ProgressStream(fd)
+            if upload_progress_fn:
+                response = upload_func(Body=body)
+                upload_progress_fn(body.bytes_read)
+            else:
+                response = upload_func(Body=fd)
+            self.stats.operation(StorageOperation.store_file, size=body.bytes_read)
             chunks[chunk_number] = response["ETag"]
         except botocore.exceptions.ClientError as ex:
             raise StorageError(
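functools.partial here fixes the repeated upload_part keyword arguments once, leaving only Body to vary per call. The same pattern in isolation, with a stand-in function so no boto3 is required:

    from functools import partial

    def upload_part(*, Bucket: str, Key: str, UploadId: str, PartNumber: int, Body: bytes) -> dict:
        # stand-in for s3_client.upload_part
        return {"ETag": f"etag-{PartNumber}-{len(Body)}"}

    upload_func = partial(upload_part, Bucket="bucket", Key="key", UploadId="mpu-1", PartNumber=1)
    print(upload_func(Body=b"chunk data"))  # {'ETag': 'etag-1-10'}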
Expand Down
97 changes: 96 additions & 1 deletion rohmu/util.py
@@ -4,15 +4,19 @@
 Copyright (c) 2022 Ohmu Ltd
 See LICENSE for details
 """
-from io import BytesIO
+from __future__ import annotations
+
+from io import BytesIO, UnsupportedOperation
 from itertools import islice
 from rohmu.typing import HasFileno
 from typing import BinaryIO, Generator, Iterable, Optional, Tuple, TypeVar, Union
+from typing_extensions import Buffer
 
 import fcntl
 import logging
 import os
 import platform
+import types
 
 LOG = logging.getLogger("rohmu.util")
 
@@ -109,3 +113,94 @@ def read(self, size: int = -1) -> bytes:
                 break
 
         return result.getvalue()
+
+
+class ProgressStream(BinaryIO):
+    """Wrapper for binary streams that can report the amount of bytes read through it."""
+
+    def __init__(self, raw_stream: BinaryIO) -> None:
+        self.raw_stream = raw_stream
+        self.bytes_read = 0
+
+    def seekable(self) -> bool:
+        return False
+
+    def writable(self) -> bool:
+        return False
+
+    def readable(self) -> bool:
+        return True
+
+    @property
+    def closed(self) -> bool:
+        return self.raw_stream.closed
+
+    @property
+    def name(self) -> str:
+        return self.raw_stream.name
+
+    @property
+    def mode(self) -> str:
+        return self.raw_stream.mode
+
+    def read(self, n: int = -1) -> bytes:
+        data = self.raw_stream.read(n)
+        self.bytes_read += len(data)
+        return data
+
+    def readline(self, limit: int = -1) -> bytes:
+        line = self.raw_stream.readline(limit)
+        self.bytes_read += len(line)
+        return line
+
+    def readlines(self, hint: int = -1) -> list[bytes]:
+        lines = self.raw_stream.readlines(hint)
+        self.bytes_read += sum(map(len, lines))
+        return lines
+
+    def __iter__(self) -> "ProgressStream":
+        return self
+
+    def __next__(self) -> bytes:
+        data = next(self.raw_stream)
+        self.bytes_read += len(data)
+        return data
+
+    def __enter__(self) -> "ProgressStream":
+        self.raw_stream.__enter__()
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[types.TracebackType],
+    ) -> None:
+        return self.raw_stream.__exit__(exc_type, exc_val, exc_tb)
+
+    def close(self) -> None:
+        self.raw_stream.close()
+
+    def flush(self) -> None:
+        self.raw_stream.flush()
+
+    def isatty(self) -> bool:
+        return self.raw_stream.isatty()
+
+    def tell(self) -> int:
+        return self.raw_stream.tell()
+
+    def seek(self, offset: int, whence: int = 0) -> int:
+        raise UnsupportedOperation("seek")
+
+    def truncate(self, size: Optional[int] = None) -> int:
+        raise UnsupportedOperation("truncate")
+
+    def write(self, s: Union[bytes, Buffer]) -> int:
+        raise UnsupportedOperation("write")
+
+    def writelines(self, __lines: Union[Iterable[bytes], Iterable[Buffer]]) -> None:
+        raise UnsupportedOperation("writelines")
+
+    def fileno(self) -> int:
+        raise UnsupportedOperation("fileno")
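ProgressStream counts only the bytes actually read through it, so any consumer that treats it as a readable binary stream gets accurate accounting for free. A quick sketch, grounded in the class above:

    from io import BytesIO

    from rohmu.util import ProgressStream

    ps = ProgressStream(BytesIO(b"Hello, World!"))
    assert ps.read(5) == b"Hello"
    assert ps.bytes_read == 5
    ps.read()  # drain the remainder
    assert ps.bytes_read == 13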
23 changes: 16 additions & 7 deletions test/test_object_storage_local.py
@@ -211,16 +211,25 @@ def test_upload_files_concurrently_can_be_aborted() -> None:
         notifier=notifier,
     )
     upload_id = transfer.create_concurrent_upload(key="test_key1", metadata={"some-key": "some-value"})
 
+    total = 0
+
+    def inc_progress(size: int) -> None:
+        nonlocal total
+        total += size
+
     # should end up with b"Hello, World!\nHello, World!"
-    transfer.upload_concurrent_chunk(upload_id, 3, BytesIO(b"Hello"))
-    transfer.upload_concurrent_chunk(upload_id, 4, BytesIO(b", "))
-    transfer.upload_concurrent_chunk(upload_id, 1, BytesIO(b"Hello, World!"))
-    transfer.upload_concurrent_chunk(upload_id, 7, BytesIO(b"!"))
-    transfer.upload_concurrent_chunk(upload_id, 2, BytesIO(b"\n"))
-    transfer.upload_concurrent_chunk(upload_id, 6, BytesIO(b"ld"))
-    transfer.upload_concurrent_chunk(upload_id, 5, BytesIO(b"Wor"))
+    transfer.upload_concurrent_chunk(upload_id, 3, BytesIO(b"Hello"), upload_progress_fn=inc_progress)
+    transfer.upload_concurrent_chunk(upload_id, 4, BytesIO(b", "), upload_progress_fn=inc_progress)
+    transfer.upload_concurrent_chunk(upload_id, 1, BytesIO(b"Hello, World!"), upload_progress_fn=inc_progress)
+    transfer.upload_concurrent_chunk(upload_id, 7, BytesIO(b"!"), upload_progress_fn=inc_progress)
+    transfer.upload_concurrent_chunk(upload_id, 2, BytesIO(b"\n"), upload_progress_fn=inc_progress)
+    transfer.upload_concurrent_chunk(upload_id, 6, BytesIO(b"ld"), upload_progress_fn=inc_progress)
+    transfer.upload_concurrent_chunk(upload_id, 5, BytesIO(b"Wor"), upload_progress_fn=inc_progress)
     transfer.abort_concurrent_upload(upload_id)
+
+    assert total == 27
+
     # we should not be able to find this
     with pytest.raises(FileNotFoundFromStorageError):
         transfer.get_metadata_for_key("test_key1")
26 changes: 22 additions & 4 deletions test/test_object_storage_s3.py
@@ -1,4 +1,6 @@
 """Copyright (c) 2022 Aiven, Helsinki, Finland. https://aiven.io/"""
+from __future__ import annotations
+
 from botocore.response import StreamingBody
 from dataclasses import dataclass
 from datetime import datetime
@@ -7,7 +9,7 @@
 from rohmu.errors import InvalidByteRangeError
 from rohmu.object_storage.s3 import S3Transfer
 from tempfile import NamedTemporaryFile
-from typing import Any, Iterator, Optional
+from typing import Any, BinaryIO, Iterator, Optional
 from unittest.mock import MagicMock
 
 import pytest
@@ -138,12 +140,26 @@ def test_get_contents_to_fileobj_passes_the_correct_range_header(infra: S3Infra)
 def test_concurrent_upload_complete(infra: S3Infra) -> None:
     metadata = {"some-date": datetime(2022, 11, 15, 18, 30, 58, 486644)}
     infra.s3_client.create_multipart_upload.return_value = {"UploadId": "<aws-mpu-id>"}
+
+    def upload_part_side_effect(Body: BinaryIO, **_kwargs: Any) -> dict[str, str]:
+        # to check the progress function we need to actually consume the body
+        Body.read()
+        return {"ETag": "some-etag"}
+
+    infra.s3_client.upload_part.side_effect = upload_part_side_effect
     transfer = infra.transfer
     upload_id = transfer.create_concurrent_upload("test_key", metadata=metadata)
-    transfer.upload_concurrent_chunk(upload_id, 1, BytesIO(b"Hello, "))
+
+    total = 0
+
+    def inc_progress(size: int) -> None:
+        nonlocal total
+        total += size
+
+    transfer.upload_concurrent_chunk(upload_id, 1, BytesIO(b"Hello, "), upload_progress_fn=inc_progress)
     # we can upload chunks in non-monotonically increasing order
-    transfer.upload_concurrent_chunk(upload_id, 3, BytesIO(b"!"))
-    transfer.upload_concurrent_chunk(upload_id, 2, BytesIO(b"World"))
+    transfer.upload_concurrent_chunk(upload_id, 3, BytesIO(b"!"), upload_progress_fn=inc_progress)
+    transfer.upload_concurrent_chunk(upload_id, 2, BytesIO(b"World"), upload_progress_fn=inc_progress)
     transfer.complete_concurrent_upload(upload_id)
 
     notifier = infra.notifier
@@ -160,6 +176,8 @@ def test_concurrent_upload_complete(infra: S3Infra) -> None:
         metadata={"some-date": "2022-11-15 18:30:58.486644"},
     )
 
+    assert total == 13
+
 
 def test_concurrent_upload_abort(infra: S3Infra) -> None:
     infra.s3_client.create_multipart_upload.return_value = {"UploadId": "<aws-mpu-id>"}
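One detail worth noting in the test above: since ProgressStream only counts bytes that are actually consumed, the mocked upload_part must drain the Body or the progress callback would report nothing. The same behaviour in isolation (the names here are illustrative):

    from io import BytesIO
    from unittest.mock import MagicMock

    from rohmu.util import ProgressStream

    client = MagicMock()

    def fake_upload_part(Body, **_kwargs):
        Body.read()  # consume the stream, as the real S3 client would
        return {"ETag": "some-etag"}

    client.upload_part.side_effect = fake_upload_part

    body = ProgressStream(BytesIO(b"Hello, "))
    client.upload_part(Body=body)
    assert body.bytes_read == 7  # non-zero only because the mock drained the body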
34 changes: 32 additions & 2 deletions test/test_util.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
-from io import BytesIO
-from rohmu.util import BinaryStreamsConcatenation, get_total_size_from_content_range
+from io import BytesIO, UnsupportedOperation
+from rohmu.util import BinaryStreamsConcatenation, get_total_size_from_content_range, ProgressStream
 from typing import Optional
 
 import pytest
@@ -41,3 +41,33 @@ def test_binary_stream_concatenation(
     for output_chunk in iter(lambda: concatenation.read(chunk_size), b""):
         outputs.append(output_chunk)
     assert outputs == expected_outputs
+
+
+def test_progress_stream() -> None:
+    stream = BytesIO(b"Hello, World!\nSecond line\nThis is a longer third line\n")
+    progress_stream = ProgressStream(stream)
+    assert progress_stream.readable()
+    assert not progress_stream.writable()
+    assert not progress_stream.seekable()
+
+    assert progress_stream.read(14) == b"Hello, World!\n"
+    assert progress_stream.bytes_read == 14
+    assert progress_stream.readlines() == [b"Second line\n", b"This is a longer third line\n"]
+    assert progress_stream.bytes_read == 54
+
+    with pytest.raises(UnsupportedOperation):
+        progress_stream.seek(0)
+    with pytest.raises(UnsupportedOperation):
+        progress_stream.truncate(0)
+    with pytest.raises(UnsupportedOperation):
+        progress_stream.write(b"Something")
+    with pytest.raises(UnsupportedOperation):
+        progress_stream.writelines([b"Something"])
+    with pytest.raises(UnsupportedOperation):
+        progress_stream.fileno()
+
+    assert not progress_stream.closed
+    with progress_stream:
+        # check that __exit__ closes the file
+        pass
+    assert progress_stream.closed