Skip to content

Commit

Permalink
fix: hotfix — scale decoder norm was not being passed to the training SAE
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Nov 14, 2024
1 parent ad740d4 commit 240a78d
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
1 change: 1 addition & 0 deletions sae_lens/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]:
"normalize_activations": self.normalize_activations,
"jumprelu_init_threshold": self.jumprelu_init_threshold,
"jumprelu_bandwidth": self.jumprelu_bandwidth,
"scale_sparsity_penalty_by_decoder_norm": self.scale_sparsity_penalty_by_decoder_norm,
}

def to_dict(self) -> dict[str, Any]:
Expand Down
34 changes: 34 additions & 0 deletions tests/unit/training/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@
TINYSTORIES_DATASET = "roneneldan/TinyStories"


def test_get_training_sae_cfg_dict_passes_scale_sparsity_penalty_by_decoder_norm():
    """The flag must survive the runner-config -> training-SAE-config conversion
    for both of its possible values."""
    for flag in (True, False):
        runner_cfg = LanguageModelSAERunnerConfig(
            scale_sparsity_penalty_by_decoder_norm=flag,
            normalize_sae_decoder=False,
        )
        sae_cfg = runner_cfg.get_training_sae_cfg_dict()
        # Truthiness check, matching the original assert / assert-not pair.
        assert bool(sae_cfg["scale_sparsity_penalty_by_decoder_norm"]) is flag


def test_sae_training_runner_config_runs_with_defaults():
"""
Helper to create a mock instance of LanguageModelSAERunnerConfig.
Expand Down Expand Up @@ -138,3 +149,26 @@ def test_cache_activations_runner_config_seqpos(
seqpos_slice=seqpos_slice,
context_size=context_size,
)


def test_topk_architecture_requires_topk_activation():
    """A topk architecture paired with a non-topk activation must be rejected."""
    expected_message = "If using topk architecture, activation_fn must be topk."
    with pytest.raises(ValueError, match=expected_message):
        LanguageModelSAERunnerConfig(architecture="topk", activation_fn="relu")


def test_topk_architecture_requires_k_parameter():
    """Omitting `k` from activation_fn_kwargs must fail for the topk architecture."""
    expected_message = "activation_fn_kwargs.k must be provided for topk architecture."
    with pytest.raises(ValueError, match=expected_message):
        LanguageModelSAERunnerConfig(
            architecture="topk",
            activation_fn="topk",
            activation_fn_kwargs={},
        )


def test_topk_architecture_sets_topk_defaults():
    """Selecting the topk architecture should fill in the topk activation defaults."""
    cfg = LanguageModelSAERunnerConfig(architecture="topk")
    observed = (cfg.activation_fn, cfg.activation_fn_kwargs)
    assert observed == ("topk", {"k": 100})

0 comments on commit 240a78d

Please sign in to comment.