
Commit

add get_contents_iterator API
This is slightly lower level than get_contents_to_fileobj: it avoids creating
an (almost always fake) file object and returns the metadata first, instead of
only after the file has been entirely downloaded.

This helps with implementing streaming fetches.
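
A minimal usage sketch of the new API (the storage object, key name, and handle()
consumer are illustrative; only get_contents_iterator and its byte_range and
progress_callback parameters come from this change):

    # Stream a download without a file object; metadata arrives first.
    metadata, chunks = storage.get_contents_iterator("example/key", byte_range=(0, 1023))
    print(metadata)       # available before any body bytes are read
    for chunk in chunks:  # chunks is an Iterator[bytes]
        handle(chunk)     # hypothetical consumer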
kmichel-aiven committed Jul 13, 2023
1 parent 2e6fee9 commit 4f10eb7
Showing 6 changed files with 124 additions and 87 deletions.
43 changes: 24 additions & 19 deletions rohmu/object_storage/azure.py
@@ -254,34 +254,33 @@ def _parse_length_from_content_range(cls, content_range: str) -> int:

return int(content_range.split(" ", 1)[1].split("/", 1)[1])

def _stream_blob(
def _iter_blob(
self,
key: str,
fileobj: BinaryIO,
byte_range: Optional[tuple[int, int]],
progress_callback: ProgressProportionCallbackType,
) -> None:
) -> Iterator[bytes]:
"""Streams contents of given key to given fileobj. Data is read sequentially in chunks
without any seeks. This requires duplicating some functionality of the Azure SDK, which only
allows reading entire blob into memory at once or returning data from random offsets"""
file_size = None
start_range = byte_range[0] if byte_range else 0
chunk_size = self.conn._config.max_chunk_get_size # type: ignore [attr-defined] # pylint: disable=protected-access
chunk_size = self.conn._config.max_chunk_get_size # type: ignore[attr-defined] # pylint: disable=protected-access
end_range = chunk_size - 1
blob = self.conn.get_blob_client(self.container_name, key)
while True:
try:
# pylint: disable=protected-access
if byte_range:
length = min(byte_range[1] - start_range + 1, chunk_size)
else:
length = chunk_size
download_stream = blob.download_blob(offset=start_range, length=length)
if file_size is None:
file_size = download_stream._file_size
file_size = download_stream._file_size # pylint: disable=protected-access
if byte_range:
file_size = min(file_size, byte_range[1] + 1)
download_stream.readinto(fileobj)
for chunk in download_stream.chunks():
yield chunk
start_range += download_stream.size
if start_range >= file_size:
break
@@ -296,6 +295,8 @@ def _stream_blob(
if ex.status_code == 416: # Empty file
return
raise FileNotFoundFromStorageError(key) from ex
if progress_callback:
progress_callback(1, 1)

def get_contents_to_fileobj(
self,
@@ -305,18 +306,22 @@ def get_contents_to_fileobj
byte_range: Optional[Tuple[int, int]] = None,
progress_callback: ProgressProportionCallbackType = None,
) -> Metadata:
metadata, chunks = self.get_contents_iterator(key, byte_range=byte_range, progress_callback=progress_callback)
for chunk in chunks:
fileobj_to_store_to.write(chunk)
return metadata

def get_contents_iterator(
self,
key: str,
*,
byte_range: Optional[Tuple[int, int]] = None,
progress_callback: ProgressProportionCallbackType = None,
) -> tuple[Metadata, Iterator[bytes]]:
path = self.format_key_for_backend(key, remove_slash_prefix=True)
self._validate_byte_range(byte_range)

self.log.debug("Starting to fetch the contents of: %r", path)
try:
self._stream_blob(path, fileobj_to_store_to, byte_range, progress_callback)
except azure.core.exceptions.ResourceNotFoundError as ex: # pylint: disable=no-member
raise FileNotFoundFromStorageError(path) from ex

if progress_callback:
progress_callback(1, 1)
return self._metadata_for_key(path)
return self._metadata_for_key(path), self._iter_blob(path, byte_range, progress_callback)

def get_file_size(self, key: str) -> int:
path = self.format_key_for_backend(key, remove_slash_prefix=True)
@@ -357,13 +362,13 @@ def progress_callback(pipeline_response: Any) -> None:
seekable = hasattr(fd, "seekable") and fd.seekable()
if not seekable:
original_tell = getattr(fd, "tell", None)
fd.tell = lambda: None # type: ignore [assignment,method-assign,return-value]
fd.tell = lambda: None # type: ignore[assignment,method-assign,return-value]
sanitized_metadata = self.sanitize_metadata(metadata, replace_hyphen_with="_")
try:
blob_client = self.conn.get_blob_client(self.container_name, path)
blob_client.upload_blob(
fd,
blob_type=BlobType.BlockBlob, # type: ignore [arg-type]
blob_type=BlobType.BlockBlob, # type: ignore[arg-type]
content_settings=content_settings,
metadata=sanitized_metadata,
raw_response_hook=progress_callback,
Expand All @@ -373,7 +378,7 @@ def progress_callback(pipeline_response: Any) -> None:
finally:
if not seekable:
if original_tell is not None:
fd.tell = original_tell # type: ignore [method-assign]
fd.tell = original_tell # type: ignore[method-assign]
else:
delattr(fd, "tell")

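For reference, a standalone sketch of the sequential ranged-read pattern that
_iter_blob implements, with a hypothetical fetch_range(offset, length) callable
standing in for blob.download_blob(...); the fixed-size range loop mirrors the diff:

    from typing import Callable, Iterator

    def iter_ranged(fetch_range: Callable[[int, int], bytes], file_size: int, chunk_size: int) -> Iterator[bytes]:
        # Read the object front to back in fixed-size ranges; no seeks needed.
        offset = 0
        while offset < file_size:
            data = fetch_range(offset, min(chunk_size, file_size - offset))
            if not data:
                break
            yield data
            offset += len(data)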
9 changes: 9 additions & 0 deletions rohmu/object_storage/base.py
@@ -191,6 +191,15 @@ def get_contents_to_fileobj(
"""Like `get_contents_to_file()` but writes to an open file-like object."""
raise NotImplementedError

def get_contents_iterator(
self,
key: str,
*,
byte_range: Optional[Tuple[int, int]] = None,
progress_callback: ProgressProportionCallbackType = None,
) -> tuple[Metadata, Iterator[bytes]]:
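"""Like `get_contents_to_fileobj()` but returns the metadata and a chunk iterator instead of writing to a file object."""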
raise NotImplementedError

def _validate_byte_range(self, byte_range: Optional[Tuple[int, int]]) -> None:
if byte_range is not None and byte_range[0] > byte_range[1]:
raise InvalidByteRangeError(f"Invalid byte_range: {byte_range}. Start must be <= end.")
55 changes: 27 additions & 28 deletions rohmu/object_storage/google.py
@@ -428,13 +428,7 @@ def get_contents_to_fileobj(
)
# https://googleapis.github.io/google-api-python-client/docs/dyn/storage_v1.objects.html#get_media
req: HttpRequest = clob.get_media(bucket=self.bucket_name, object=path)
download: MediaDownloadProtocol
if byte_range is None:
download = MediaIoBaseDownload(fileobj_to_store_to, req, chunksize=DOWNLOAD_CHUNK_SIZE)
else:
download = MediaIoBaseDownloadWithByteRange(
fileobj_to_store_to, req, chunksize=DOWNLOAD_CHUNK_SIZE, byte_range=byte_range
)
download = MediaIoBaseDownloadWithByteRange(req, chunksize=DOWNLOAD_CHUNK_SIZE, byte_range=byte_range)

done = False
while not done:
@@ -458,6 +452,16 @@
progress_callback(100, 100)
return metadata

def get_contents_iterator(
self,
key: str,
*,
byte_range: Optional[Tuple[int, int]] = None,
progress_callback: ProgressProportionCallbackType = None,
) -> tuple[Metadata, Iterator[bytes]]:
# The annoying part is MediaIoBaseDownload, which requires a file object but does not otherwise deal with files...
raise NotImplementedError

def get_file_size(self, key: str) -> int:
path = self.format_key_for_backend(key)
reporter = Reporter(StorageOperation.get_file_size)
Expand Down Expand Up @@ -723,24 +727,19 @@ def _read_bytes(self, length: int, *, initial_data: Optional[bytes] = None) -> b
return b"".join(read_results)


class MediaDownloadProtocol(Protocol):
def next_chunk(self) -> tuple[MediaDownloadProgress, bool]:
...


class MediaIoBaseDownloadWithByteRange:
"""This class is mostly a copy of the googleapiclient's MediaIOBaseDownload class,
but with the addition of the support for fetching a specific byte_range.
And the content is returned instead of written to a file object.
"""

def __init__(
self,
fd: BinaryIO,
request: HttpRequest,
chunksize: int = DOWNLOAD_CHUNK_SIZE,
*,
byte_range: tuple[int, int],
byte_range: tuple[int, int] | None,
) -> None:
"""Constructor.
@@ -752,14 +751,13 @@ def __init__(
chunksize: int, File will be downloaded in chunks of this many bytes.
byte_range: tuple[int, int] | None, the byte range to fetch, or None for the whole object
"""
self._fd = fd
self._http = request.http
self._uri = request.uri
self._chunksize = chunksize
self._start_position, self._end_position = byte_range
self._start_position, self._end_position = byte_range if byte_range is not None else (0, None)
self._num_bytes_downloaded = 0
self._range_size = self._end_position - self._start_position + 1
if self._range_size < 0:
self._range_size = self._end_position - self._start_position + 1 if byte_range is not None else None
if self._range_size is not None and self._range_size < 0:
raise InvalidByteRangeError(f"Invalid byte_range: {byte_range}. Start must be <= end.")
self._done = False

@@ -771,11 +769,11 @@ def __init__(
if k.lower() not in ("accept", "accept-encoding", "user-agent"):
self._headers[k] = v

def next_chunk(self) -> tuple[MediaDownloadProgress, bool]:
def next_chunk(self) -> tuple[bytes, MediaDownloadProgress, bool]:
"""Get the next chunk of the download.
Returns:
(status, done): The value of done will be True when the media has been fully
(chunk, status, done): The value of done will be True when the media has been fully
downloaded or the total size of the media is unknown.
Raises:
@@ -785,7 +783,7 @@ def next_chunk(self) -> tuple[MediaDownloadProgress, bool]:
headers = self._headers.copy()
chunk_start = self._num_bytes_downloaded + self._start_position
chunk_end = chunk_start + self._chunksize - 1
if self._end_position < chunk_end:
if self._end_position is not None and self._end_position < chunk_end:
chunk_end = self._end_position
headers["range"] = f"bytes={chunk_start}-{chunk_end}"
resp, content = self._http.request(self._uri, "GET", headers=headers)
@@ -795,25 +793,26 @@
if "content-location" in resp and resp["content-location"] != self._uri:
self._uri = resp["content-location"]
self._num_bytes_downloaded += len(content)
self._fd.write(content)

if "content-range" in resp:
total_size = get_total_size_from_content_range(resp["content-range"])
elif "content-length" in resp:
# By RFC 9110 if we end up here this is a 200 OK response and this is the total size of the object
total_size = int(resp["content-length"])

size_to_download = (
self._range_size if total_size is None else min(self._range_size, total_size - self._start_position)
)
if self._range_size is None:
size_to_download = total_size
else:
size_to_download = (
self._range_size if total_size is None else min(self._range_size, total_size - self._start_position)
)
if self._num_bytes_downloaded == size_to_download:
self._done = True
return MediaDownloadProgress(self._num_bytes_downloaded, size_to_download), self._done
return content, MediaDownloadProgress(self._num_bytes_downloaded, size_to_download), self._done
elif resp.status == 416:
# 416 is Range Not Satisfiable
# This typically occurs with a zero byte file
total_size = get_total_size_from_content_range(resp["content-range"])
if total_size == 0:
self._done = True
return MediaDownloadProgress(self._num_bytes_downloaded, total_size), self._done
return b"", MediaDownloadProgress(self._num_bytes_downloaded, total_size), self._done
raise HttpError(resp, content, uri=self._uri)
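
One way the google backend's iterator could be built on top of this class,
sketched under the assumption that a prepared HttpRequest req is at hand (the
commit leaves get_contents_iterator as a stub in this backend):

    def iter_media(req: HttpRequest, byte_range: Optional[Tuple[int, int]] = None) -> Iterator[bytes]:
        download = MediaIoBaseDownloadWithByteRange(req, chunksize=DOWNLOAD_CHUNK_SIZE, byte_range=byte_range)
        done = False
        while not done:
            # next_chunk() now returns the chunk itself along with progress state.
            chunk, _status, done = download.next_chunk()
            if chunk:
                yield chunk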
50 changes: 33 additions & 17 deletions rohmu/object_storage/local.py
@@ -178,28 +178,44 @@ def get_contents_to_fileobj(
byte_range: Optional[Tuple[int, int]] = None,
progress_callback: ProgressProportionCallbackType = None,
) -> Metadata:
metadata, chunks = self.get_contents_iterator(key, byte_range=byte_range, progress_callback=progress_callback)
for chunk in chunks:
fileobj_to_store_to.write(chunk)
return metadata

def get_contents_iterator(
self,
key: str,
*,
byte_range: Optional[Tuple[int, int]] = None,
progress_callback: ProgressProportionCallbackType = None,
) -> tuple[Metadata, Iterator[bytes]]:
self._validate_byte_range(byte_range)
metadata = self.get_metadata_for_key(key)
source_path = self.format_key_for_backend(key.strip("/"))
if not os.path.exists(source_path):
try:
input_file = open(source_path, "rb")
except FileNotFoundError:
raise FileNotFoundFromStorageError(key)

input_size = os.stat(source_path).st_size
bytes_written = 0
with open(source_path, "rb") as fp:
if byte_range:
fp.seek(byte_range[0])
input_size = byte_range[1] - byte_range[0] + 1
while bytes_written <= input_size:
left = min(input_size - bytes_written, CHUNK_SIZE)
buf = fp.read(left)
if not buf:
break
fileobj_to_store_to.write(buf)
bytes_written += len(buf)
if progress_callback:
progress_callback(bytes_written, input_size)
def iter_chunks() -> Iterator[bytes]:
with input_file as fp:
input_size = os.fstat(input_file.fileno()).st_size
bytes_written = 0
if byte_range:
fp.seek(byte_range[0])
input_size = byte_range[1] - byte_range[0] + 1
while bytes_written <= input_size:
left = min(input_size - bytes_written, CHUNK_SIZE)
buf = fp.read(left)
if not buf:
break
yield buf
bytes_written += len(buf)
if progress_callback:
progress_callback(bytes_written, input_size)

return self.get_metadata_for_key(key)
return metadata, iter_chunks()

def get_file_size(self, key: str) -> int:
source_path = self.format_key_for_backend(key.strip("/"))
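Note that byte ranges are inclusive on both ends, matching the validation in
base.py; a small example (storage object and key are illustrative):

    # byte_range=(0, 9) covers offsets 0 through 9, i.e. the first 10 bytes.
    metadata, chunks = storage.get_contents_iterator("example/key", byte_range=(0, 9))
    assert len(b"".join(chunks)) <= 10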
25 changes: 18 additions & 7 deletions rohmu/object_storage/s3.py
@@ -325,7 +325,7 @@ def _get_object_stream(self, key: str, byte_range: Optional[tuple[int, int]]) ->
kwargs["Range"] = f"bytes {byte_range[0]}-{byte_range[1]}"
try:
# Actual usage is accounted for in
# _read_object_to_fileobj, although that omits the initial
# _iter_object, although that omits the initial
# get_object call if it fails.
response = self.s3_client.get_object(Bucket=self.bucket_name, Key=path, **kwargs)
except botocore.exceptions.ClientError as ex:
@@ -336,17 +336,17 @@
raise StorageError("Fetching the remote object {} failed".format(path)) from ex
return response["Body"], response["ContentLength"], response["Metadata"]

def _read_object_to_fileobj(
self, fileobj: BinaryIO, streaming_body: StreamingBody, body_length: int, cb: ProgressProportionCallbackType = None
) -> None:
def _iter_object(
self, streaming_body: StreamingBody, body_length: int, cb: ProgressProportionCallbackType = None
) -> Iterator[bytes]:
data_read = 0
while data_read < body_length:
read_amount = body_length - data_read
if read_amount > READ_BLOCK_SIZE:
read_amount = READ_BLOCK_SIZE
data = streaming_body.read(amt=read_amount)
fileobj.write(data)
data_read += len(data)
yield data
if cb:
cb(data_read, body_length)
self.stats.operation(operation=StorageOperation.get_file, size=len(data))
@@ -361,10 +361,21 @@ def get_contents_to_fileobj(
byte_range: Optional[Tuple[int, int]] = None,
progress_callback: ProgressProportionCallbackType = None,
) -> Metadata:
metadata, chunks = self.get_contents_iterator(key, byte_range=byte_range, progress_callback=progress_callback)
for chunk in chunks:
fileobj_to_store_to.write(chunk)
return metadata

def get_contents_iterator(
self,
key: str,
*,
byte_range: Optional[Tuple[int, int]] = None,
progress_callback: ProgressProportionCallbackType = None,
) -> tuple[Metadata, Iterator[bytes]]:
self._validate_byte_range(byte_range)
stream, length, metadata = self._get_object_stream(key, byte_range)
self._read_object_to_fileobj(fileobj_to_store_to, stream, length, cb=progress_callback)
return metadata
return metadata, self._iter_object(stream, length, cb=progress_callback)

def get_file_size(self, key: str) -> int:
path = self.format_key_for_backend(key, remove_slash_prefix=True)
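The progress callback contract is unchanged: it is invoked with a (done, total)
pair as the download proceeds, so with the iterator API it only fires while the
caller is consuming chunks. A small sketch (the printer and storage object are
illustrative):

    def report_progress(done: int, total: int) -> None:
        # done/total are proportions of the download; guard against total == 0.
        pct = 100 * done // total if total else 100
        print(f"progress: {done}/{total} ({pct}%)")

    metadata, chunks = storage.get_contents_iterator("example/key", progress_callback=report_progress)
    for chunk in chunks:
        pass  # consume; progress is reported from inside the iterator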
