Skip to content

Commit

Permalink
Merge pull request jbloomAus#95 from jbloomAus/load-state-dict-not-st…
Browse files Browse the repository at this point in the history
…rict

Make load_state_dict use strict=False
  • Loading branch information
jbloomAus authored Apr 21, 2024
2 parents c020263 + c9da015 commit 2784e34
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 2 deletions.
7 changes: 6 additions & 1 deletion sae_lens/training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,12 @@ def load_from_pretrained_legacy(cls, path: str):

# Create an instance of the class using the loaded configuration
instance = cls(cfg=state_dict["cfg"])
instance.load_state_dict(state_dict["state_dict"])
if "scaling_factor" not in state_dict["state_dict"]:
assert isinstance(instance.cfg.d_sae, int)
state_dict["state_dict"]["scaling_factor"] = torch.ones(
instance.cfg.d_sae, dtype=instance.cfg.dtype, device=instance.cfg.device
)
instance.load_state_dict(state_dict["state_dict"], strict=True)

return instance

Expand Down
2 changes: 1 addition & 1 deletion tests/unit/toolkit/test_pretrained_saes.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_convert_old_to_modern_saelens_format():

# convert file format
pretrained_saes.convert_old_to_modern_saelens_format(
legacy_out_file, new_out_folder
legacy_out_file, new_out_folder, force=True
)

# Load from new converted file
Expand Down
37 changes: 37 additions & 0 deletions tests/unit/training/test_sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,43 @@ def test_SparseAutoencoder_save_and_load_from_pretrained(tmp_path: Path) -> None
)


def test_SparseAutoencoder_save_and_load_from_pretrained_lacks_scaling_factor(
tmp_path: Path,
) -> None:
cfg = build_sae_cfg(device="cpu")
model_path = str(tmp_path)
sparse_autoencoder = SparseAutoencoder(cfg)
sparse_autoencoder_state_dict = sparse_autoencoder.state_dict()
# sometimes old state dicts will be missing the scaling factor
del sparse_autoencoder_state_dict["scaling_factor"] # = torch.tensor(0.0)
sparse_autoencoder.save_model(model_path)

assert os.path.exists(model_path)

sparse_autoencoder_loaded = SparseAutoencoder.load_from_pretrained(model_path)
sparse_autoencoder_loaded.cfg.verbose = True
sparse_autoencoder_loaded.cfg.checkpoint_path = cfg.checkpoint_path
sparse_autoencoder_loaded.cfg.device = "cpu" # might autoload onto mps
sparse_autoencoder_loaded = sparse_autoencoder_loaded.to("cpu")
sparse_autoencoder_loaded_state_dict = sparse_autoencoder_loaded.state_dict()
# check cfg matches the original
assert sparse_autoencoder_loaded.cfg == cfg

# check state_dict matches the original
for key in sparse_autoencoder.state_dict().keys():
if key == "scaling_factor":
assert isinstance(cfg.d_sae, int)
assert torch.allclose(
torch.ones(cfg.d_sae, dtype=cfg.dtype, device=cfg.device),
sparse_autoencoder_loaded_state_dict[key],
)
else:
assert torch.allclose(
sparse_autoencoder_state_dict[key],
sparse_autoencoder_loaded_state_dict[key],
)


def test_sparse_autoencoder_forward(sparse_autoencoder: SparseAutoencoder):
batch_size = 32
d_in = sparse_autoencoder.d_in
Expand Down

0 comments on commit 2784e34

Please sign in to comment.