Skip to content

Commit

Permalink
Make reloading compatible with safetensors (#2201)
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli authored Mar 28, 2024
1 parent 85b0e72 commit d3498c2
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/sparseml/pytorch/model_load/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from torch.nn import Module

import sparseml.core.session as session_manager
from safetensors import safe_open
from sparseml.core.framework import Framework
from sparseml.pytorch.sparsification.quantization.helpers import (
initialize_channel_wise_scale_zp,
Expand Down Expand Up @@ -143,7 +144,8 @@ def reload_model_state(
weight_files = [
os.path.join(load_path, os.path.basename(f))
for f in files
if f.startswith("pytorch_model") and f.endswith("bin")
if (f.startswith("pytorch_model") and f.endswith("bin"))
or (f.endswith("safetensors"))
]
if not weight_files:
_LOGGER.warning(
Expand All @@ -168,7 +170,10 @@ def reload_model_state(
# change in keys due to architecture changes, reload statedict
loaded_state_dict = {}
for f in weight_files:
dd = torch.load(f, map_location="cpu")
if f.endswith("safetensors"):
dd = load_safetensors_state_dict(file_path=f)
else:
dd = torch.load(f, map_location="cpu")
loaded_state_dict.update(dd)

_, missing, unexpected, mismatched, _, _ = model._load_pretrained_model(
Expand Down Expand Up @@ -334,3 +339,14 @@ def save_completed_stages(checkpoint_dir: str, completed_stages: List[str]):
stage_path = os.path.join(checkpoint_dir, COMPLETED_STAGES_FILENAME)
with open(stage_path, "w") as out_file:
json.dump({"completed": completed_stages}, out_file)


def load_safetensors_state_dict(file_path: str) -> Dict[str, torch.Tensor]:
"""
Load a safetensors file from disk
:param file_path: path to the safetensors file
:return: dictionary of safetensors data
"""
with safe_open(file_path, framework="pt", device="cpu") as f:
return {key: f.get_tensor(key) for key in f.keys()}

0 comments on commit d3498c2

Please sign in to comment.