feat: Add model_from_pretrained_kwargs as config parameter (#122)
* add a model_from_pretrained_kwargs config parameter to allow full control over the model used to extract activations. Update tests to cover the new cases (a usage sketch follows the change summary below).

* tweaking test style

---------

Co-authored-by: David Chanin <[email protected]>
RoganInglis and chanind authored May 9, 2024
1 parent 5154d29 commit 094b1e8
Showing 7 changed files with 67 additions and 4 deletions.
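
The new parameter forwards arbitrary keyword arguments to the underlying from_pretrained call. Below is a minimal usage sketch mirroring the tests added in this commit (pythia-14m, CPU, and checkpoint_index=0 are the exact values those tests exercise; any other HookedTransformer.from_pretrained argument is assumed to pass through the same way):

from sae_lens.training.load_model import load_model

# Load pythia-14m at its earliest training checkpoint by forwarding
# checkpoint_index through model_from_pretrained_kwargs.
model = load_model(
    model_class_name="HookedTransformer",
    model_name="pythia-14m",
    device="cpu",
    model_from_pretrained_kwargs={"checkpoint_index": 0},
)
assert model.cfg.checkpoint_index == 0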
1 change: 1 addition & 0 deletions sae_lens/training/cache_activations_runner.py
@@ -18,6 +18,7 @@ def __init__(self, cfg: CacheActivationsRunnerConfig):
             model_class_name=cfg.model_class_name,
             model_name=cfg.model_name,
             device=cfg.device,
+            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs,
         )
         self.activations_store = ActivationsStore.from_config(
             self.model,
2 changes: 2 additions & 0 deletions sae_lens/training/config.py
@@ -133,6 +133,7 @@ class LanguageModelSAERunnerConfig:
     checkpoint_path: str = "checkpoints"
     verbose: bool = True
     model_kwargs: dict[str, Any] = field(default_factory=dict)
+    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
     sae_lens_version: str = field(default_factory=lambda: __version__)
     sae_lens_training_version: str = field(default_factory=lambda: __version__)

@@ -328,6 +329,7 @@ class CacheActivationsRunnerConfig:
     n_shuffles_in_entire_dir: int = 10
     n_shuffles_final: int = 100
     model_kwargs: dict[str, Any] = field(default_factory=dict)
+    model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
 
     def __post_init__(self):
         # Autofill cached_activations_path unless the user overrode it
15 changes: 12 additions & 3 deletions sae_lens/training/load_model.py
@@ -6,10 +6,17 @@
 
 
 def load_model(
-    model_class_name: str, model_name: str, device: str | torch.device | None = None
+    model_class_name: str,
+    model_name: str,
+    device: str | torch.device | None = None,
+    model_from_pretrained_kwargs: dict[str, Any] | None = None,
 ) -> HookedRootModule:
+    model_from_pretrained_kwargs = model_from_pretrained_kwargs or {}
+
     if model_class_name == "HookedTransformer":
-        return HookedTransformer.from_pretrained(model_name=model_name, device=device)
+        return HookedTransformer.from_pretrained(
+            model_name=model_name, device=device, **model_from_pretrained_kwargs
+        )
     elif model_class_name == "HookedMamba":
         try:
             from mamba_lens import HookedMamba
@@ -20,7 +27,9 @@ def load_model(
         # HookedMamba has incorrect typing information, so we need to cast the type here
         return cast(
             HookedRootModule,
-            HookedMamba.from_pretrained(model_name, device=cast(Any, device)),
+            HookedMamba.from_pretrained(
+                model_name, device=cast(Any, device), **model_from_pretrained_kwargs
+            ),
         )
     else:
         raise ValueError(f"Unknown model class: {model_class_name}")
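
For reference, a sketch of what the HookedTransformer branch above reduces to when the config carries model_from_pretrained_kwargs={"checkpoint_index": 0}; the model name and device are illustrative values taken from the new tests:

from transformer_lens import HookedTransformer

# The extra kwargs are splatted straight into from_pretrained,
# here selecting the earliest published Pythia training checkpoint.
model = HookedTransformer.from_pretrained(
    model_name="pythia-14m",
    device="cpu",
    checkpoint_index=0,
)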
2 changes: 2 additions & 0 deletions sae_lens/training/sae_group.py
@@ -136,6 +136,8 @@ def load_from_pretrained_legacy(cls, path: str) -> "SparseAutoencoderDictionary"
             cfg.sae_lens_version = "0.0.0"
         if not hasattr(cfg, "sae_lens_training_version"):
             cfg.sae_lens_training_version = "0.0.0"
+        if not hasattr(cfg, "model_from_pretrained_kwargs"):
+            cfg.model_from_pretrained_kwargs = {}
         sparse_autoencoder = SparseAutoencoder(cfg=cfg)
         # add dummy scaling factor to the state dict
         group["state_dict"]["scaling_factor"] = torch.ones(
5 changes: 4 additions & 1 deletion sae_lens/training/session_loader.py
@@ -67,6 +67,9 @@ def get_model(self, model_name: str) -> HookedRootModule:
         # Todo: add check that model_name is valid
 
         model = load_model(
-            self.cfg.model_class_name, model_name, device=self.cfg.device
+            self.cfg.model_class_name,
+            model_name,
+            device=self.cfg.device,
+            model_from_pretrained_kwargs=self.cfg.model_from_pretrained_kwargs,
         )
         return model
24 changes: 24 additions & 0 deletions tests/unit/training/test_load_model.py
@@ -1,4 +1,5 @@
 from mamba_lens import HookedMamba
+from transformer_lens import HookedTransformer
 
 from sae_lens.training.load_model import load_model

@@ -11,3 +12,26 @@ def test_load_model_works_with_mamba():
     )
     assert model is not None
     assert isinstance(model, HookedMamba)
+
+
+def test_load_model_works_without_model_kwargs():
+    model = load_model(
+        model_class_name="HookedTransformer",
+        model_name="pythia-14m",
+        device="cpu",
+    )
+    assert model is not None
+    assert isinstance(model, HookedTransformer)
+    assert model.cfg.checkpoint_index is None
+
+
+def test_load_model_works_with_model_kwargs():
+    model = load_model(
+        model_class_name="HookedTransformer",
+        model_name="pythia-14m",
+        device="cpu",
+        model_from_pretrained_kwargs={"checkpoint_index": 0},
+    )
+    assert model is not None
+    assert isinstance(model, HookedTransformer)
+    assert model.cfg.checkpoint_index == 0
22 changes: 22 additions & 0 deletions tests/unit/training/test_session_loader.py
@@ -48,6 +48,28 @@ def test_LMSparseAutoencoderSessionloader_load_session(
     assert isinstance(model, HookedTransformer)
     assert isinstance(next(iter(sae_group))[1], SparseAutoencoder)
     assert isinstance(activations_loader, ActivationsStore)
+    assert model.cfg.checkpoint_index is None
+
+
+def test_LMSparseAutoencoderSessionloader_load_session_can_load_model_with_kwargs():
+    cfg = build_sae_cfg(
+        model_name="pythia-14m",
+        hook_point="blocks.0.hook_mlp_out",
+        hook_point_layer=0,
+        dataset_path="roneneldan/TinyStories",
+        is_dataset_tokenized=False,
+        model_from_pretrained_kwargs={"checkpoint_index": 0},
+    )
+    loader = LMSparseAutoencoderSessionloader(cfg)
+    model, sae_group, activations_loader = loader.load_sae_training_group_session()
+
+    assert isinstance(model, HookedTransformer)
+    assert isinstance(next(iter(sae_group))[1], SparseAutoencoder)
+    assert isinstance(activations_loader, ActivationsStore)
+    assert (
+        model.cfg.checkpoint_index
+        == cfg.model_from_pretrained_kwargs["checkpoint_index"]
+    )
 
 
 def test_LMSparseAutoencoderSessionloader_load_sae_session_from_pretrained(
