Skip to content

Commit

Permalink
enforce irpa extension on save
Browse files Browse the repository at this point in the history
Signed-off-by: dan <[email protected]>
  • Loading branch information
dan-garvey committed Oct 4, 2024
1 parent 64b7d27 commit 9cd848b
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion shark_turbine/aot/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,10 @@ def index(self) -> ParameterIndex:

def save(self, file_path: Union[str, Path]):
"""Saves the archive."""
self._index.create_archive_file(str(file_path))
str_file_path = str(file_path)
if not str_file_path.endswith(".irpa"):
file_path = str_file_path + ".irpa"
self._index.create_archive_file(str_file_path)

def add_tensor(self, name: str, tensor: torch.Tensor):
"""Adds an named tensor to the archive."""
Expand Down

0 comments on commit 9cd848b

Please sign in to comment.