From c6c21518cc7369c3a0cf740c03a97fa2f984e78c Mon Sep 17 00:00:00 2001
From: Mateusz
Date: Wed, 13 Nov 2024 04:39:35 +0000
Subject: [PATCH] Add model upload and load

Add support for uploading HookedTransformer models to the Hugging Face
Hub and loading them back. utils.upload_model_to_hf saves the config as
tl_config.json and the state dict as state_dict.pth in a single commit,
and HookedTransformer.from_pretrained falls back to this TransformerLens
format when the model name is not a recognised HuggingFace model.
---
 transformer_lens/HookedTransformer.py       | 46 +++++++++++++--------
 transformer_lens/HookedTransformerConfig.py | 17 ++++++++
 transformer_lens/loading_from_pretrained.py | 17 ++++++++
 transformer_lens/utils.py                   | 25 ++++++++++-
 4 files changed, 86 insertions(+), 19 deletions(-)

diff --git a/transformer_lens/HookedTransformer.py b/transformer_lens/HookedTransformer.py
index 56096484c..275e1406e 100644
--- a/transformer_lens/HookedTransformer.py
+++ b/transformer_lens/HookedTransformer.py
@@ -1264,25 +1264,32 @@ def from_pretrained(
         ) and device in ["cpu", None]:
             logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")
 
-        # Get the model name used in HuggingFace, rather than the alias.
-        official_model_name = loading.get_official_model_name(model_name)
+        try:
+            # Get the model name used in HuggingFace, rather than the alias.
+            resolved_model_name = loading.get_official_model_name(model_name)
+            model_type = "hf"
+        except ValueError:
+            resolved_model_name = model_name
+            model_type = "tl"
+            cfg = loading.load_tl_model_config(model_name)
 
         # Load the config into an HookedTransformerConfig object. If loading from a
         # checkpoint, the config object will contain the information about the
         # checkpoint
-        cfg = loading.get_pretrained_model_config(
-            official_model_name,
-            hf_cfg=hf_cfg,
-            checkpoint_index=checkpoint_index,
-            checkpoint_value=checkpoint_value,
-            fold_ln=fold_ln,
-            device=device,
-            n_devices=n_devices,
-            default_prepend_bos=default_prepend_bos,
-            dtype=dtype,
-            first_n_layers=first_n_layers,
-            **from_pretrained_kwargs,
-        )
+        if model_type == "hf":
+            cfg = loading.get_pretrained_model_config(
+                resolved_model_name,
+                hf_cfg=hf_cfg,
+                checkpoint_index=checkpoint_index,
+                checkpoint_value=checkpoint_value,
+                fold_ln=fold_ln,
+                device=device,
+                n_devices=n_devices,
+                default_prepend_bos=default_prepend_bos,
+                dtype=dtype,
+                first_n_layers=first_n_layers,
+                **from_pretrained_kwargs,
+            )
 
         if cfg.positional_embedding_type == "shortformer":
             if fold_ln:
@@ -1312,9 +1319,12 @@ def from_pretrained(
 
         # Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
         # match the HookedTransformer parameter names.
-        state_dict = loading.get_pretrained_state_dict(
-            official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
-        )
+        if model_type == "hf":
+            state_dict = loading.get_pretrained_state_dict(
+                resolved_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
+            )
+        else:
+            state_dict = loading.load_tl_state_dict(model_name, dtype=dtype)
 
         # Create the HookedTransformer object
         model = cls(
diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py
index 4458705de..225c314a3 100644
--- a/transformer_lens/HookedTransformerConfig.py
+++ b/transformer_lens/HookedTransformerConfig.py
@@ -9,6 +9,7 @@
+import json
 import logging
 import pprint
 import random
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Union
 
@@ -355,10 +356,26 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> HookedTransformerConfig:
         Instantiates a `HookedTransformerConfig` from a Python dictionary of
         parameters.
         """
+        if isinstance(config_dict.get("dtype"), str):
+            config_dict = config_dict.copy()
+            config_dict["dtype"] = getattr(torch, config_dict["dtype"])
         return cls(**config_dict)
 
     def to_dict(self):
         return self.__dict__
+
+    def to_json(self, indent: Optional[int] = None) -> str:
+        def _serialize(obj):
+            if isinstance(obj, torch.dtype):
+                return str(obj).split(".")[1]
+            if hasattr(obj, "dtype"):
+                if "int" in str(obj.dtype):
+                    return int(obj)
+                if "float" in str(obj.dtype):
+                    return float(obj)
+            raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")
+
+        return json.dumps(self.to_dict(), default=_serialize, indent=indent)
 
     def __repr__(self):
         return "HookedTransformerConfig:\n" + pprint.pformat(self.to_dict())
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index 49dffbf04..2adee798a 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -711,6 +711,14 @@ def get_official_model_name(model_name: str):
     return official_model_name
 
 
+def load_tl_model_config(model_name: str) -> HookedTransformerConfig:
+    """
+    Loads the config for a model stored in TransformerLens format on the Hugging Face Hub.
+    """
+    config = utils.download_file_from_hf(model_name, "tl_config.json")
+    return HookedTransformerConfig.from_dict(config)
+
+
 def convert_hf_model_config(model_name: str, **kwargs):
     """
     Returns the model config for a HuggingFace model, converted to a dictionary
@@ -1836,6 +1844,15 @@ def get_pretrained_state_dict(
     return state_dict
 
 
+def load_tl_state_dict(model_name: str, dtype: torch.dtype = torch.float32):
+    """
+    Loads the state dict for a model stored in TransformerLens format on the Hugging Face Hub.
+    """
+    state_dict = utils.download_file_from_hf(model_name, "state_dict.pth")
+    state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
+    return state_dict
+
+
 def fill_missing_keys(model, state_dict):
     """Takes in a state dict from a pretrained model, and fills in any missing
     keys with the default initialization.
diff --git a/transformer_lens/utils.py b/transformer_lens/utils.py
index ae4fec5cf..78faf27a1 100644
--- a/transformer_lens/utils.py
+++ b/transformer_lens/utils.py
@@ -6,10 +6,12 @@
 from __future__ import annotations
 
 import inspect
+import io
 import json
 import os
 import re
 import shutil
+import tempfile
 from copy import deepcopy
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 
@@ -21,7 +23,7 @@
 import transformers
 from datasets.arrow_dataset import Dataset
 from datasets.load import load_dataset
-from huggingface_hub import hf_hub_download
+from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download
 from jaxtyping import Float, Int
 from rich import print as rprint
 from transformers import AutoTokenizer
@@ -67,6 +69,27 @@ def download_file_from_hf(
     return file_path
 
 
+def upload_model_to_hf(model: "HookedTransformer", repo_name: str, commit_message: str = "Upload model"):
+    """
+    Upload a HookedTransformer's config and state dict to an existing Hugging Face Hub repo.
+    """
+    api = HfApi()
+    config_buffer = io.BytesIO()
+    config_buffer.write(model.cfg.to_json(indent=2).encode("utf-8"))
+    config_buffer.seek(0)
+    add_config = CommitOperationAdd(path_or_fileobj=config_buffer, path_in_repo="tl_config.json")
+
+    with tempfile.TemporaryFile() as f:
+        torch.save(model.state_dict(), f)
+        f.seek(0)
+        add_model = CommitOperationAdd(path_or_fileobj=f, path_in_repo="state_dict.pth")
+        api.create_commit(
+            repo_id=repo_name,
+            operations=[add_config, add_model],
+            commit_message=commit_message,
+        )
+
+
 def clear_huggingface_cache():
     """
     Deletes the Hugging Face cache directory and all its contents.