Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model upload and load #779

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 28 additions & 18 deletions transformer_lens/HookedTransformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,25 +1264,32 @@ def from_pretrained(
) and device in ["cpu", None]:
logging.warning("float16 models may not work on CPU. Consider using a GPU or bfloat16.")

# Get the model name used in HuggingFace, rather than the alias.
official_model_name = loading.get_official_model_name(model_name)
try:
# Get the model name used in HuggingFace, rather than the alias.
resolved_model_name = loading.get_official_model_name(model_name)
model_type = "hf"
except ValueError:
resolved_model_name = model_name
model_type = "tl"
cfg = loading.load_tl_model_config(model_name)

# Load the config into an HookedTransformerConfig object. If loading from a
# checkpoint, the config object will contain the information about the
# checkpoint
cfg = loading.get_pretrained_model_config(
official_model_name,
hf_cfg=hf_cfg,
checkpoint_index=checkpoint_index,
checkpoint_value=checkpoint_value,
fold_ln=fold_ln,
device=device,
n_devices=n_devices,
default_prepend_bos=default_prepend_bos,
dtype=dtype,
first_n_layers=first_n_layers,
**from_pretrained_kwargs,
)
if model_type == "hf":
cfg = loading.get_pretrained_model_config(
resolved_model_name,
hf_cfg=hf_cfg,
checkpoint_index=checkpoint_index,
checkpoint_value=checkpoint_value,
fold_ln=fold_ln,
device=device,
n_devices=n_devices,
default_prepend_bos=default_prepend_bos,
dtype=dtype,
first_n_layers=first_n_layers,
**from_pretrained_kwargs,
)

if cfg.positional_embedding_type == "shortformer":
if fold_ln:
Expand Down Expand Up @@ -1312,9 +1319,12 @@ def from_pretrained(

# Get the state dict of the model (ie a mapping of parameter names to tensors), processed to
# match the HookedTransformer parameter names.
state_dict = loading.get_pretrained_state_dict(
official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
)
if model_type == "hf":
state_dict = loading.get_pretrained_state_dict(
resolved_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs
)
else:
state_dict = loading.load_tl_state_dict(model_name, dtype=dtype)

# Create the HookedTransformer object
model = cls(
Expand Down
17 changes: 17 additions & 0 deletions transformer_lens/HookedTransformerConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import pprint
import random
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

Expand Down Expand Up @@ -355,10 +356,26 @@ def from_dict(cls, config_dict: Dict[str, Any]) -> HookedTransformerConfig:
Instantiates a `HookedTransformerConfig` from a Python dictionary of
parameters.
"""
if isinstance(config_dict.get("dtype"), str):
config_dict = config_dict.copy()
config_dict["dtype"] = getattr(torch, config_dict["dtype"])
return cls(**config_dict)

def to_dict(self):
    """Return the config as a plain dictionary of its attributes.

    Note: this returns the live ``__dict__``, not a copy — mutating the
    returned dict mutates the config instance itself.
    """
    return self.__dict__

def to_json(self, indent=None):
    """Serialize the config to a JSON string.

    Handles values the ``json`` module cannot encode natively:
    ``torch.dtype`` values are written as their short name (e.g.
    ``"float32"``), and numeric scalars that carry a ``dtype`` attribute
    (e.g. numpy/torch 0-d scalars) are unwrapped to builtin ints/floats.

    Args:
        indent: Passed through to ``json.dumps`` for pretty-printing.

    Returns:
        str: The JSON representation of ``self.to_dict()``.

    Raises:
        TypeError: If a config value is not JSON-serializable and not one
            of the handled special cases.
    """

    def _serialize(obj):
        if isinstance(obj, torch.dtype):
            # "torch.float32" -> "float32"
            return str(obj).split(".")[1]
        if hasattr(obj, "dtype"):
            if "int" in str(obj.dtype):
                return int(obj)
            if "float" in str(obj.dtype):
                return float(obj)
        # Returning the object unchanged would make json.dumps call this
        # default hook on the same object again, recursing forever; raise
        # the conventional TypeError instead so failures are diagnosable.
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    return json.dumps(self.to_dict(), default=_serialize, indent=indent)

def __repr__(self):
return "HookedTransformerConfig:\n" + pprint.pformat(self.to_dict())
Expand Down
17 changes: 17 additions & 0 deletions transformer_lens/loading_from_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,6 +711,14 @@ def get_official_model_name(model_name: str):
return official_model_name


def load_tl_model_config(model_name: str):
    """Fetch the stored config for a TransformerLens-format model.

    Downloads ``tl_config.json`` from the Hugging Face repo ``model_name``
    (presumably ``download_file_from_hf`` parses the JSON into a dict —
    verify against its implementation) and builds a
    ``HookedTransformerConfig`` from it.
    """
    raw_config = utils.download_file_from_hf(model_name, "tl_config.json")
    return HookedTransformerConfig.from_dict(raw_config)


def convert_hf_model_config(model_name: str, **kwargs):
"""
Returns the model config for a HuggingFace model, converted to a dictionary
Expand Down Expand Up @@ -1836,6 +1844,15 @@ def get_pretrained_state_dict(
return state_dict


def load_tl_state_dict(model_name: str, dtype: torch.dtype = torch.float32):
    """Fetch the saved state dict for a TransformerLens-format model.

    Downloads ``state_dict.pth`` from the Hugging Face repo ``model_name``
    and casts every tensor in it to ``dtype``.
    """
    raw_state_dict = utils.download_file_from_hf(model_name, "state_dict.pth")
    return {name: tensor.to(dtype) for name, tensor in raw_state_dict.items()}


def fill_missing_keys(model, state_dict):
"""Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.

Expand Down
25 changes: 24 additions & 1 deletion transformer_lens/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
from __future__ import annotations

import inspect
import io
import json
import os
import re
import shutil
import tempfile
from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast

Expand All @@ -21,7 +23,7 @@
import transformers
from datasets.arrow_dataset import Dataset
from datasets.load import load_dataset
from huggingface_hub import hf_hub_download
from huggingface_hub import CommitOperationAdd, HfApi, hf_hub_download
from jaxtyping import Float, Int
from rich import print as rprint
from transformers import AutoTokenizer
Expand Down Expand Up @@ -67,6 +69,27 @@ def download_file_from_hf(
return file_path


def upload_model_to_hf(model: "HookedTransformer", repo_name: str, commit_message: Optional[str] = None):
    """
    Upload a model to the Hugging Face Hub.

    Creates a single commit in ``repo_name`` containing two files:
    ``tl_config.json`` (the model's config, serialized via ``cfg.to_json``)
    and ``state_dict.pth`` (the model's state dict, saved with
    ``torch.save``).

    Args:
        model: The HookedTransformer to upload.
        repo_name: Target repo id on the Hub (e.g. ``"user/my-model"``).
            The repo should already exist — ``create_commit`` does not
            create it.
        commit_message: Optional message for the created commit.
    """
    api = HfApi()

    # The config is small, so stage it from an in-memory buffer.
    config_buffer = io.BytesIO(model.cfg.to_json(indent=2).encode("utf-8"))
    add_config = CommitOperationAdd(path_or_fileobj=config_buffer, path_in_repo="tl_config.json")

    # The state dict can be large, so spool it through a temporary file.
    # The file object must stay open while the upload runs, which is why
    # create_commit is called inside the `with` block.
    with tempfile.TemporaryFile() as f:
        torch.save(model.state_dict(), f)
        f.seek(0)
        add_model = CommitOperationAdd(path_or_fileobj=f, path_in_repo="state_dict.pth")
        api.create_commit(
            repo_id=repo_name,
            operations=[add_config, add_model],
            commit_message=commit_message,
        )


def clear_huggingface_cache():
"""
Deletes the Hugging Face cache directory and all its contents.
Expand Down
Loading