diff --git a/sae_lens/config.py b/sae_lens/config.py
index f3478533..5f01aa8c 100644
--- a/sae_lens/config.py
+++ b/sae_lens/config.py
@@ -420,6 +420,7 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]:
             "normalize_activations": self.normalize_activations,
             "jumprelu_init_threshold": self.jumprelu_init_threshold,
             "jumprelu_bandwidth": self.jumprelu_bandwidth,
+            "scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
         }
 
     def to_dict(self) -> dict[str, Any]:
diff --git a/tests/unit/training/test_config.py b/tests/unit/training/test_config.py
index 6b1a93d9..88b42222 100644
--- a/tests/unit/training/test_config.py
+++ b/tests/unit/training/test_config.py
@@ -9,6 +9,17 @@
 TINYSTORIES_DATASET = "roneneldan/TinyStories"
 
 
+def test_get_training_sae_cfg_dict_passes_scale_sparsity_penalty_by_decoder_norm():
+    cfg = LanguageModelSAERunnerConfig(
+        scale_sparsity_penalty_by_decoder_norm=True, normalize_sae_decoder=False
+    )
+    assert cfg.get_training_sae_cfg_dict()["scale_sparsity_penalty_by_decoder_norm"]
+    cfg = LanguageModelSAERunnerConfig(
+        scale_sparsity_penalty_by_decoder_norm=False, normalize_sae_decoder=False
+    )
+    assert not cfg.get_training_sae_cfg_dict()["scale_sparsity_penalty_by_decoder_norm"]
+
+
 def test_sae_training_runner_config_runs_with_defaults():
     """
     Helper to create a mock instance of LanguageModelSAERunnerConfig.
@@ -138,3 +149,26 @@
         seqpos_slice=seqpos_slice,
         context_size=context_size,
     )
+
+
+def test_topk_architecture_requires_topk_activation():
+    with pytest.raises(
+        ValueError, match="If using topk architecture, activation_fn must be topk."
+    ):
+        LanguageModelSAERunnerConfig(architecture="topk", activation_fn="relu")
+
+
+def test_topk_architecture_requires_k_parameter():
+    with pytest.raises(
+        ValueError,
+        match="activation_fn_kwargs.k must be provided for topk architecture.",
+    ):
+        LanguageModelSAERunnerConfig(
+            architecture="topk", activation_fn="topk", activation_fn_kwargs={}
+        )
+
+
+def test_topk_architecture_sets_topk_defaults():
+    cfg = LanguageModelSAERunnerConfig(architecture="topk")
+    assert cfg.activation_fn == "topk"
+    assert cfg.activation_fn_kwargs == {"k": 100}