From 1f92656f6024396f4b70561e193c448738d35a86 Mon Sep 17 00:00:00 2001
From: rahul-tuli
Date: Thu, 28 Mar 2024 14:40:43 +0000
Subject: [PATCH] Make reloading compatible with safetensors

---
 src/sparseml/pytorch/model_load/helpers.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/src/sparseml/pytorch/model_load/helpers.py b/src/sparseml/pytorch/model_load/helpers.py
index 9016583ddf3..2c8f4de9fed 100644
--- a/src/sparseml/pytorch/model_load/helpers.py
+++ b/src/sparseml/pytorch/model_load/helpers.py
@@ -22,6 +22,7 @@
 from torch.nn import Module
 
 import sparseml.core.session as session_manager
+from safetensors import safe_open
 from sparseml.core.framework import Framework
 from sparseml.pytorch.sparsification.quantization.helpers import (
     initialize_channel_wise_scale_zp,
@@ -143,7 +144,8 @@ def reload_model_state(
     weight_files = [
         os.path.join(load_path, os.path.basename(f))
         for f in files
-        if f.startswith("pytorch_model") and f.endswith("bin")
+        if (f.startswith("pytorch_model") and f.endswith("bin"))
+        or (f.endswith("safetensors"))
     ]
     if not weight_files:
         _LOGGER.warning(
@@ -168,7 +170,10 @@
     # change in keys due to architecture changes, reload statedict
     loaded_state_dict = {}
     for f in weight_files:
-        dd = torch.load(f, map_location="cpu")
+        if f.endswith("safetensors"):
+            dd = load_safetensors_state_dict(file_path=f)
+        else:
+            dd = torch.load(f, map_location="cpu")
         loaded_state_dict.update(dd)
 
     _, missing, unexpected, mismatched, _, _ = model._load_pretrained_model(
@@ -334,3 +339,14 @@ def save_completed_stages(checkpoint_dir: str, completed_stages: List[str]):
     stage_path = os.path.join(checkpoint_dir, COMPLETED_STAGES_FILENAME)
     with open(stage_path, "w") as out_file:
         json.dump({"completed": completed_stages}, out_file)
+
+
+def load_safetensors_state_dict(file_path: str) -> Dict[str, torch.Tensor]:
+    """
+    Load a safetensors file from disk
+
+    :param file_path: path to the safetensors file
+    :return: dictionary of safetensors data
+    """
+    with safe_open(file_path, framework="pt", device="cpu") as f:
+        return {key: f.get_tensor(key) for key in f.keys()}
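
Note for reviewers (not part of the commit): the indented sketch below is a
minimal, self-contained illustration of the behavior this patch gives
reload_model_state(): checkpoint shards named pytorch_model*.bin keep going
through torch.load, while *.safetensors shards are read with safetensors'
safe_open, and everything is merged into a single state dict. The helper name
load_checkpoint_state_dict and the example directory are hypothetical and
exist only for this sketch; the committed change only touches
reload_model_state() and adds load_safetensors_state_dict().

    # illustrative sketch only; not part of the patch
    import os
    from typing import Dict

    import torch
    from safetensors import safe_open


    def load_checkpoint_state_dict(load_path: str) -> Dict[str, torch.Tensor]:
        # same file filter the patch installs in reload_model_state()
        weight_files = [
            os.path.join(load_path, f)
            for f in os.listdir(load_path)
            if (f.startswith("pytorch_model") and f.endswith("bin"))
            or f.endswith("safetensors")
        ]

        state_dict: Dict[str, torch.Tensor] = {}
        for path in weight_files:
            if path.endswith("safetensors"):
                # mirrors load_safetensors_state_dict: open the shard lazily
                # and copy tensors to CPU one key at a time
                with safe_open(path, framework="pt", device="cpu") as f:
                    state_dict.update({key: f.get_tensor(key) for key in f.keys()})
            else:
                # legacy pickle-based shards still go through torch.load
                state_dict.update(torch.load(path, map_location="cpu"))
        return state_dict


    if __name__ == "__main__":
        # hypothetical local checkpoint directory with .bin and/or .safetensors shards
        print(len(load_checkpoint_state_dict("./checkpoint_dir")), "tensors loaded")

One design note: safe_open reads only the safetensors header up front and
memory-maps the file, so tensors are materialized when get_tensor() is called
and no pickle deserialization is involved, unlike torch.load on .bin shards.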