feat: Add model_from_pretrained_kwargs as config parameter #122

Merged · 4 commits · May 9, 2024
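This PR adds a `model_from_pretrained_kwargs` dict to both runner configs and threads it through `load_model`, so arbitrary keyword arguments (e.g. `checkpoint_index`) reach `HookedTransformer.from_pretrained` / `HookedMamba.from_pretrained`. A minimal sketch of the intended use, assuming the remaining config fields keep their defaults:

```python
from sae_lens.training.config import LanguageModelSAERunnerConfig

# Load an early Pythia training checkpoint instead of the final weights.
# checkpoint_index is a transformer_lens from_pretrained kwarg; any other
# kwarg that from_pretrained accepts can be passed the same way.
cfg = LanguageModelSAERunnerConfig(
    model_name="pythia-14m",
    model_from_pretrained_kwargs={"checkpoint_index": 0},
)
```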
1 change: 1 addition & 0 deletions sae_lens/training/cache_activations_runner.py
@@ -18,6 +18,7 @@ def __init__(self, cfg: CacheActivationsRunnerConfig):
             model_class_name=cfg.model_class_name,
             model_name=cfg.model_name,
             device=cfg.device,
+            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs,
         )
         self.activations_store = ActivationsStore.from_config(
             self.model,
2 changes: 2 additions & 0 deletions sae_lens/training/config.py
@@ -133,6 +133,7 @@ class LanguageModelSAERunnerConfig:
     checkpoint_path: str = "checkpoints"
     verbose: bool = True
     model_kwargs: dict[str, Any] = field(default_factory=dict)
+    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
     sae_lens_version: str = field(default_factory=lambda: __version__)
     sae_lens_training_version: str = field(default_factory=lambda: __version__)

@@ -328,6 +329,7 @@ class CacheActivationsRunnerConfig:
     n_shuffles_in_entire_dir: int = 10
     n_shuffles_final: int = 100
     model_kwargs: dict[str, Any] = field(default_factory=dict)
+    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)

     def __post_init__(self):
         # Autofill cached_activations_path unless the user overrode it
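The cache-activations path gets the same field. A sketch of setting it there, under the assumption that `CacheActivationsRunnerConfig` exposes `model_name` the way `cache_activations_runner.py` above consumes it and that its remaining fields have workable defaults:

```python
from sae_lens.training.config import CacheActivationsRunnerConfig

# The runner in cache_activations_runner.py reads cfg.model_name and
# cfg.model_from_pretrained_kwargs and hands both to load_model.
cfg = CacheActivationsRunnerConfig(
    model_name="pythia-14m",
    model_from_pretrained_kwargs={"checkpoint_index": 0},
)
```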
15 changes: 12 additions & 3 deletions sae_lens/training/load_model.py
@@ -6,10 +6,17 @@


 def load_model(
-    model_class_name: str, model_name: str, device: str | torch.device | None = None
+    model_class_name: str,
+    model_name: str,
+    device: str | torch.device | None = None,
+    model_from_pretrained_kwargs: dict[str, Any] | None = None,
 ) -> HookedRootModule:
+    model_from_pretrained_kwargs = model_from_pretrained_kwargs or {}
+
     if model_class_name == "HookedTransformer":
-        return HookedTransformer.from_pretrained(model_name=model_name, device=device)
+        return HookedTransformer.from_pretrained(
+            model_name=model_name, device=device, **model_from_pretrained_kwargs
+        )
     elif model_class_name == "HookedMamba":
         try:
             from mamba_lens import HookedMamba
@@ -20,7 +27,9 @@ def load_model(
         # HookedMamba has incorrect typing information, so we need to cast the type here
         return cast(
             HookedRootModule,
-            HookedMamba.from_pretrained(model_name, device=cast(Any, device)),
+            HookedMamba.from_pretrained(
+                model_name, device=cast(Any, device), **model_from_pretrained_kwargs
+            ),
         )
     else:
         raise ValueError(f"Unknown model class: {model_class_name}")
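Because the new parameter defaults to `None` and is normalized with `or {}`, existing call sites keep their behavior. A quick sketch of both call styles (the `checkpoint_index` kwarg mirrors the unit tests below):

```python
from sae_lens.training.load_model import load_model

# Old-style call: no kwargs dict, identical behavior to before this PR.
model = load_model("HookedTransformer", "pythia-14m", device="cpu")

# New-style call: the dict is splatted into from_pretrained.
early = load_model(
    "HookedTransformer",
    "pythia-14m",
    device="cpu",
    model_from_pretrained_kwargs={"checkpoint_index": 0},
)
assert early.cfg.checkpoint_index == 0
```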
2 changes: 2 additions & 0 deletions sae_lens/training/sae_group.py
@@ -136,6 +136,8 @@ def load_from_pretrained_legacy(cls, path: str) -> "SparseAutoencoderDictionary":
             cfg.sae_lens_version = "0.0.0"
         if not hasattr(cfg, "sae_lens_training_version"):
             cfg.sae_lens_training_version = "0.0.0"
+        if not hasattr(cfg, "model_from_pretrained_kwargs"):
+            cfg.model_from_pretrained_kwargs = {}
         sparse_autoencoder = SparseAutoencoder(cfg=cfg)
         # add dummy scaling factor to the state dict
         group["state_dict"]["scaling_factor"] = torch.ones(
5 changes: 4 additions & 1 deletion sae_lens/training/session_loader.py
@@ -67,6 +67,9 @@ def get_model(self, model_name: str) -> HookedRootModule:
         # Todo: add check that model_name is valid

         model = load_model(
-            self.cfg.model_class_name, model_name, device=self.cfg.device
+            self.cfg.model_class_name,
+            model_name,
+            device=self.cfg.device,
+            model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
         )
         return model
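End to end, a config carrying the kwargs now produces a session whose model was loaded with them. A sketch following the new session-loader test below (the import path for the `build_sae_cfg` test helper is an assumption):

```python
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
from tests.unit.helpers import build_sae_cfg  # assumed helper location

cfg = build_sae_cfg(
    model_name="pythia-14m",
    model_from_pretrained_kwargs={"checkpoint_index": 0},
)
loader = LMSparseAutoencoderSessionloader(cfg)
model, sae_group, activations_loader = loader.load_sae_training_group_session()
assert model.cfg.checkpoint_index == 0
```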
24 changes: 24 additions & 0 deletions tests/unit/training/test_load_model.py
@@ -1,4 +1,5 @@
 from mamba_lens import HookedMamba
+from transformer_lens import HookedTransformer

 from sae_lens.training.load_model import load_model

@@ -11,3 +12,26 @@ def test_load_model_works_with_mamba():
     )
     assert model is not None
     assert isinstance(model, HookedMamba)
+
+
+def test_load_model_works_without_model_kwargs():
+    model = load_model(
+        model_class_name="HookedTransformer",
+        model_name="pythia-14m",
+        device="cpu",
+    )
+    assert model is not None
+    assert isinstance(model, HookedTransformer)
+    assert model.cfg.checkpoint_index is None
+
+
+def test_load_model_works_with_model_kwargs():
+    model = load_model(
+        model_class_name="HookedTransformer",
+        model_name="pythia-14m",
+        device="cpu",
+        model_from_pretrained_kwargs={"checkpoint_index": 0},
+    )
+    assert model is not None
+    assert isinstance(model, HookedTransformer)
+    assert model.cfg.checkpoint_index == 0
22 changes: 22 additions & 0 deletions tests/unit/training/test_session_loader.py
@@ -48,6 +48,28 @@ def test_LMSparseAutoencoderSessionloader_load_session(
     assert isinstance(model, HookedTransformer)
     assert isinstance(next(iter(sae_group))[1], SparseAutoencoder)
     assert isinstance(activations_loader, ActivationsStore)
+    assert model.cfg.checkpoint_index is None
+
+
+def test_LMSparseAutoencoderSessionloader_load_session_can_load_model_with_kwargs():
+    cfg = build_sae_cfg(
+        model_name="pythia-14m",
+        hook_point="blocks.0.hook_mlp_out",
+        hook_point_layer=0,
+        dataset_path="roneneldan/TinyStories",
+        is_dataset_tokenized=False,
+        model_from_pretrained_kwargs={"checkpoint_index": 0},
+    )
+    loader = LMSparseAutoencoderSessionloader(cfg)
+    model, sae_group, activations_loader = loader.load_sae_training_group_session()
+
+    assert isinstance(model, HookedTransformer)
+    assert isinstance(next(iter(sae_group))[1], SparseAutoencoder)
+    assert isinstance(activations_loader, ActivationsStore)
+    assert (
+        model.cfg.checkpoint_index
+        == cfg.model_from_pretrained_kwargs["checkpoint_index"]
+    )


 def test_LMSparseAutoencoderSessionloader_load_sae_session_from_pretrained(