Skip to content

Commit

Permalink
util: make ProgressStream seekable
Browse files Browse the repository at this point in the history
botocore wants a seekable stream apparently.
  • Loading branch information
giacomo-alzetta-aiven committed Aug 2, 2023
1 parent ef9ecd2 commit f2c6500
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
13 changes: 11 additions & 2 deletions rohmu/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def __init__(self, raw_stream: BinaryIO) -> None:
self.bytes_read = 0

def seekable(self) -> bool:
return False
"""A progress stream is seekable if the underlying stream is."""
return self.raw_stream.seekable()

def writable(self) -> bool:
return False
Expand Down Expand Up @@ -191,7 +192,15 @@ def tell(self) -> int:
return self.raw_stream.tell()

def seek(self, offset: int, whence: int = 0) -> int:
raise UnsupportedOperation("seek")
"""Seek the underlying file if this operation is supported.
NOTE: Calling this method will reset the bytes_read field!
"""
result = self.raw_stream.seek(offset, whence)

self.bytes_read = 0
return result

def truncate(self, size: Optional[int] = None) -> int:
raise UnsupportedOperation("truncate")
Expand Down
12 changes: 9 additions & 3 deletions test/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,14 @@ def test_progress_stream() -> None:
progress_stream = ProgressStream(stream)
assert progress_stream.readable()
assert not progress_stream.writable()
assert not progress_stream.seekable()
# stream is seekable if underlying stream is
assert progress_stream.seekable()

assert progress_stream.read(14) == b"Hello, World!\n"
assert progress_stream.bytes_read == 14
assert progress_stream.readlines() == [b"Second line\n", b"This is a longer third line\n"]
assert progress_stream.bytes_read == 54

with pytest.raises(UnsupportedOperation):
progress_stream.seek(0)
with pytest.raises(UnsupportedOperation):
progress_stream.truncate(0)
with pytest.raises(UnsupportedOperation):
Expand All @@ -66,6 +65,13 @@ def test_progress_stream() -> None:
with pytest.raises(UnsupportedOperation):
progress_stream.fileno()

# seeking the stream, in any position, resets the bytes_read counter
progress_stream.seek(10)
assert progress_stream.bytes_read == 0
# the seek works as expected on the stream
assert progress_stream.read(10) == b"ld!\nSecond"
assert progress_stream.bytes_read == 10

assert not progress_stream.closed
with progress_stream:
# check that __exit__ closes the file
Expand Down

0 comments on commit f2c6500

Please sign in to comment.