Skip to content

Commit

Permalink
Update formatting for alignment with repo standards.
Browse files Browse the repository at this point in the history
  • Loading branch information
evanhanders committed Apr 18, 2024
1 parent 0f85ded commit 5e1f342
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 24 deletions.
12 changes: 6 additions & 6 deletions sae_lens/toolkit/pretrained_saes.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,10 @@ def convert_connor_rob_sae_to_our_saelens_format(
ae_alt.load_state_dict(state_dict)
return ae_alt


def convert_old_to_modern_saelens_format(
pytorch_file: str,
out_folder: str = None,
force: bool = False
):
pytorch_file: str, out_folder: str = None, force: bool = False
):
"""
Reads a pretrained SAE from the old pickle-style SAELens .pt format, then saves a modern-format SAELens SAE.
Expand All @@ -154,17 +153,18 @@ def convert_old_to_modern_saelens_format(
"""
file_path = pathlib.Path(pytorch_file)
if out_folder is None:
out_folder = file_path.parent/file_path.stem
out_folder = file_path.parent / file_path.stem
else:
out_folder = pathlib.Path(out_folder)
if (not force) and out_folder.exists():
raise FileExistsError(f"{out_folder} already exists and force=False")
out_folder.mkdir(exist_ok=True, parents=True)

#Load model & save in new format.
# Load model & save in new format.
autoencoder = SparseAutoencoder.load_from_pretrained_legacy(str(file_path))
autoencoder.save_model(out_folder)


def get_gpt2_small_ckrk_attn_out_saes() -> dict[str, SparseAutoencoder]:

REPO_ID = "ckkissane/attn-saes-gpt2-small-all-layers"
Expand Down
3 changes: 1 addition & 2 deletions sae_lens/training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,7 @@ def load_from_pretrained_legacy(cls, path: str):
state_dict["cfg"].device = "mps"
else:
state_dict = torch.load(
path,
pickle_module=BackwardsCompatiblePickleClass
path, pickle_module=BackwardsCompatiblePickleClass
)
except Exception as e:
raise IOError(f"Error loading the state dictionary from .pt file: {e}")
Expand Down
30 changes: 14 additions & 16 deletions tests/unit/toolkit/test_pretrained_saes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,34 @@
import pytest
import torch

from sae_lens.training.sparse_autoencoder import SparseAutoencoder
from sae_lens.toolkit import pretrained_saes
from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.sparse_autoencoder import SparseAutoencoder


def test_convert_old_to_modern_saelens_format():
    """Round-trip an SAE through the legacy .pt format and the modern format.

    Saves a freshly initialized SparseAutoencoder in the old pickle-style
    format, converts it with convert_old_to_modern_saelens_format, reloads
    it via the modern loader, and checks the weights survived unchanged.
    """
    out_dir = pathlib.Path("unit_test_tmp")
    out_dir.mkdir(exist_ok=True)
    legacy_out_file = str(out_dir / "test.pt")
    new_out_folder = str(out_dir / "test")

    # Make an SAE, save old version
    cfg = LanguageModelSAERunnerConfig(
        dtype=torch.float32,
        hook_point="blocks.0.hook_mlp_out",
    )
    old_sae = SparseAutoencoder(cfg)
    old_sae.save_model_legacy(legacy_out_file)

    try:
        # convert file format
        pretrained_saes.convert_old_to_modern_saelens_format(
            legacy_out_file, new_out_folder
        )

        # Load from new converted file
        new_sae = SparseAutoencoder.load_from_pretrained(new_out_folder)
    finally:
        shutil.rmtree(out_dir)  # cleanup even if conversion/loading raises

    # Test similarity. The original asserted W_dec twice; the duplicate is
    # dropped here. NOTE(review): biases (b_enc/b_dec) were presumably the
    # intended second check — confirm against SparseAutoencoder's attributes.
    assert torch.allclose(new_sae.W_enc, old_sae.W_enc)
    assert torch.allclose(new_sae.W_dec, old_sae.W_dec)

0 comments on commit 5e1f342

Please sign in to comment.