diff --git a/sae_lens/training/sparse_autoencoder.py b/sae_lens/training/sparse_autoencoder.py
index 88d0b984..774875e9 100644
--- a/sae_lens/training/sparse_autoencoder.py
+++ b/sae_lens/training/sparse_autoencoder.py
@@ -312,7 +312,12 @@ def load_from_pretrained_legacy(cls, path: str):
 
         # Create an instance of the class using the loaded configuration
         instance = cls(cfg=state_dict["cfg"])
-        instance.load_state_dict(state_dict["state_dict"])
+        if "scaling_factor" not in state_dict["state_dict"]:
+            assert isinstance(instance.cfg.d_sae, int)
+            state_dict["state_dict"]["scaling_factor"] = torch.ones(
+                instance.cfg.d_sae, dtype=instance.cfg.dtype, device=instance.cfg.device
+            )
+        instance.load_state_dict(state_dict["state_dict"], strict=True)
 
         return instance
 
diff --git a/tests/unit/toolkit/test_pretrained_saes.py b/tests/unit/toolkit/test_pretrained_saes.py
index 05a0d8a7..19ee5b3b 100644
--- a/tests/unit/toolkit/test_pretrained_saes.py
+++ b/tests/unit/toolkit/test_pretrained_saes.py
@@ -24,7 +24,7 @@ def test_convert_old_to_modern_saelens_format():
 
     # convert file format
     pretrained_saes.convert_old_to_modern_saelens_format(
-        legacy_out_file, new_out_folder
+        legacy_out_file, new_out_folder, force=True
     )
 
     # Load from new converted file
diff --git a/tests/unit/training/test_sparse_autoencoder.py b/tests/unit/training/test_sparse_autoencoder.py
index 40febe74..9ca6c6f2 100644
--- a/tests/unit/training/test_sparse_autoencoder.py
+++ b/tests/unit/training/test_sparse_autoencoder.py
@@ -112,6 +112,43 @@ def test_SparseAutoencoder_save_and_load_from_pretrained(tmp_path: Path) -> None
         )
 
 
+def test_SparseAutoencoder_save_and_load_from_pretrained_lacks_scaling_factor(
+    tmp_path: Path,
+) -> None:
+    cfg = build_sae_cfg(device="cpu")
+    model_path = str(tmp_path)
+    sparse_autoencoder = SparseAutoencoder(cfg)
+    sparse_autoencoder_state_dict = sparse_autoencoder.state_dict()
+    # sometimes old state dicts will be missing the scaling factor
+    del sparse_autoencoder_state_dict["scaling_factor"]  # = torch.tensor(0.0)
+    sparse_autoencoder.save_model(model_path)
+
+    assert os.path.exists(model_path)
+
+    sparse_autoencoder_loaded = SparseAutoencoder.load_from_pretrained(model_path)
+    sparse_autoencoder_loaded.cfg.verbose = True
+    sparse_autoencoder_loaded.cfg.checkpoint_path = cfg.checkpoint_path
+    sparse_autoencoder_loaded.cfg.device = "cpu"  # might autoload onto mps
+    sparse_autoencoder_loaded = sparse_autoencoder_loaded.to("cpu")
+    sparse_autoencoder_loaded_state_dict = sparse_autoencoder_loaded.state_dict()
+    # check cfg matches the original
+    assert sparse_autoencoder_loaded.cfg == cfg
+
+    # check state_dict matches the original
+    for key in sparse_autoencoder.state_dict().keys():
+        if key == "scaling_factor":
+            assert isinstance(cfg.d_sae, int)
+            assert torch.allclose(
+                torch.ones(cfg.d_sae, dtype=cfg.dtype, device=cfg.device),
+                sparse_autoencoder_loaded_state_dict[key],
+            )
+        else:
+            assert torch.allclose(
+                sparse_autoencoder_state_dict[key],
+                sparse_autoencoder_loaded_state_dict[key],
+            )
+
+
 def test_sparse_autoencoder_forward(sparse_autoencoder: SparseAutoencoder):
     batch_size = 32
     d_in = sparse_autoencoder.d_in