Skip to content

Commit

Permalink
s3: tested backwards compat loading w/ broken cache
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Oct 8, 2024
1 parent f7f514d commit bf34644
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions helpers/data_backend/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ def _detect_file_format(self, fileobj):
fileobj.seek(0)
magic_number = fileobj.read(4)
fileobj.seek(0)
if magic_number[:2] == b"\x80\x04":
logger.debug(f"Magic number: {magic_number}")
if magic_number[:2] == b"\x80\x04" or b"PK" in magic_number:
# This is likely a torch-saved object (Pickle protocol 4)
# Need to check whether it's the incorrectly saved compressed data
try:
Expand Down Expand Up @@ -324,23 +325,24 @@ def torch_load(self, s3_key):

# Determine if the file was saved incorrectly
file_format = self._detect_file_format(stored_data)

logger.debug(f"File format: {file_format}")
if file_format == "incorrect":
# Load the compressed bytes object serialized by torch.save
stored_data.seek(0)
compressed_data = torch.load(stored_data, map_location="cpu")
compressed_data = BytesIO(
torch.load(stored_data, map_location="cpu")
)
# Decompress the data
decompressed_data = self._decompress_torch(compressed_data)
stored_tensor = BytesIO(decompressed_data)
stored_tensor = self._decompress_torch(compressed_data)
elif file_format == "correct_compressed":
# Data is compressed but saved correctly
decompressed_data = self._decompress_torch(data)
stored_tensor = BytesIO(decompressed_data)
else:
# Data is uncompressed and saved correctly
stored_tensor = stored_data

stored_tensor.seek(0)
if hasattr(stored_tensor, "seek"):
stored_tensor.seek(0)
obj = torch.load(stored_tensor, map_location="cpu")

if isinstance(obj, tuple):
Expand Down

0 comments on commit bf34644

Please sign in to comment.