Skip to content

Commit

Permalink
chore: adding test that all config params pass to sae
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Nov 14, 2024
1 parent c07a3c7 commit 4b37fa2
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions tests/unit/training/test_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from dataclasses import fields
from typing import Optional

import pytest

from sae_lens import __version__
from sae_lens.config import CacheActivationsRunnerConfig, LanguageModelSAERunnerConfig
from sae_lens.sae import SAEConfig
from sae_lens.training.training_sae import TrainingSAEConfig

TINYSTORIES_MODEL = "tiny-stories-1M"
TINYSTORIES_DATASET = "roneneldan/TinyStories"
Expand All @@ -20,6 +23,26 @@ def test_get_training_sae_cfg_dict_passes_scale_sparsity_penalty_by_decoder_norm
assert not cfg.get_training_sae_cfg_dict()["scale_sparsity_penalty_by_decoder_norm"]


def test_get_training_sae_cfg_dict_has_all_relevant_options():
cfg = LanguageModelSAERunnerConfig()
cfg_dict = cfg.get_training_sae_cfg_dict()
training_sae_opts = fields(TrainingSAEConfig)
allowed_missing_fields = {"neuronpedia_id"}
training_sae_field_names = {opt.name for opt in training_sae_opts}
missing_fields = training_sae_field_names - allowed_missing_fields - cfg_dict.keys()
assert missing_fields == set()


def test_get_base_sae_cfg_dict_has_all_relevant_options():
cfg = LanguageModelSAERunnerConfig()
cfg_dict = cfg.get_base_sae_cfg_dict()
sae_opts = fields(SAEConfig)
allowed_missing_fields = {"neuronpedia_id"}
sae_field_names = {opt.name for opt in sae_opts}
missing_fields = sae_field_names - allowed_missing_fields - cfg_dict.keys()
assert missing_fields == set()


def test_sae_training_runner_config_runs_with_defaults():
"""
Helper to create a mock instance of LanguageModelSAERunnerConfig.
Expand Down

0 comments on commit 4b37fa2

Please sign in to comment.