# Patch for tests/unit/test_evals.py (unified diff, four hunks).
#
# NOTE(review): in this copy the patch's line breaks and indentation are
# collapsed onto two physical lines; presumably the original layout has to be
# restored before the patch can be applied -- confirm against the upstream diff.
#
# Hunk 1 (-10,12 +10,15), imports:
#   adds `all_loadable_saes` and `get_saes_from_regex` to the names imported
#   from `sae_lens.evals`, and imports `PretrainedSAELookup` from
#   `sae_lens.toolkit.pretrained_saes_directory`.
#
# Hunks 2 and 3 (-67,6 +70,14 and -74,6 +85,7), lists above `def cfg(request)`:
#   adds one more parameter dict (hook_name "blocks.1.attn.hook_q", d_in 4,
#   hook_head_index 2) and the matching id
#   "tiny-stories-1M-hook-q-head-index-2".
#   NOTE(review): the decorator that owns these lists is outside the hunks;
#   presumably a parametrized pytest fixture -- confirm in the full file.
#
# Hunk 4 (-327,3 +339,81), added after `assert csv_path.exists()`:
#   * test_all_loadable_saes: patches
#     `sae_lens.evals.get_pretrained_saes_directory` to return two
#     `PretrainedSAELookup` entries and asserts that `all_loadable_saes()`
#     returns one (release, sae id, expected_var_explained, expected_l0)
#     tuple per `saes_map` key, in the order release1/sae1, release1/sae2,
#     release2/sae3.
#   * mock_all_saes: module-level list of four such tuples shared by the
#     three tests below (the three above plus ("release2", "block1", ...)).
#   * test_get_saes_from_regex_no_match: with `sae_lens.evals.all_loadable_saes`
#     patched to return mock_all_saes, ("release1", "sae3") yields a falsy
#     result.
#   * test_get_saes_from_regex_single_match: ("release1", "sae1") yields
#     exactly the release1/sae1 tuple.
#   * test_get_saes_from_regex_multiple_matches: ("release.*", "sae.*")
#     yields the three "sae*" tuples and leaves out the "block1" entry.
#   NOTE(review): the two arguments look like a release pattern followed by an
#   SAE-id pattern -- confirm against `get_saes_from_regex` itself.
#   NOTE(review): `patch`, `MagicMock`, `pytest` and `Path` are used but their
#   imports are not part of any visible hunk; presumably the file already
#   imports them -- confirm.
diff --git a/tests/unit/test_evals.py b/tests/unit/test_evals.py index 535ddb17..435defdf 100644 --- a/tests/unit/test_evals.py +++ b/tests/unit/test_evals.py @@ -10,12 +10,15 @@ from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.evals import ( EvalConfig, + all_loadable_saes, get_eval_everything_config, + get_saes_from_regex, process_results, run_evals, run_evaluations, ) from sae_lens.sae import SAE +from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.training_sae import TrainingSAE from tests.unit.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached @@ -67,6 +70,14 @@ "hook_layer": 1, "d_in": 16 * 4, }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "hook_name": "blocks.1.attn.hook_q", + "hook_layer": 1, + "d_in": 4, + "hook_head_index": 2, + }, ], ids=[ "tiny-stories-1M-resid-pre", @@ -74,6 +85,7 @@ "tiny-stories-1M-resid-pre-pretokenized", "tiny-stories-1M-hook-z", "tiny-stories-1M-hook-q", + "tiny-stories-1M-hook-q-head-index-2", ], ) def cfg(request: pytest.FixtureRequest): @@ -327,3 +339,81 @@ def test_process_results(tmp_path: Path): # Check if CSV file is created csv_path = output_dir / "all_eval_results.csv" assert csv_path.exists() + + +@patch("sae_lens.evals.get_pretrained_saes_directory") +def test_all_loadable_saes(mock_get_pretrained_saes_directory: MagicMock): + mock_get_pretrained_saes_directory.return_value = { + "release1": PretrainedSAELookup( + release="release1", + repo_id="repo1", + model="model1", + conversion_func=None, + saes_map={"sae1": "path1", "sae2": "path2"}, + expected_var_explained={"sae1": 0.9, "sae2": 0.85}, + expected_l0={"sae1": 0.1, "sae2": 0.15}, + neuronpedia_id={}, + config_overrides=None, + ), + "release2": PretrainedSAELookup( + release="release2", + repo_id="repo2", + model="model2", + conversion_func=None, + saes_map={"sae3": "path3"}, + 
expected_var_explained={"sae3": 0.8}, + expected_l0={"sae3": 0.2}, + neuronpedia_id={}, + config_overrides=None, + ), + } + + result = all_loadable_saes() + + expected = [ + ("release1", "sae1", 0.9, 0.1), + ("release1", "sae2", 0.85, 0.15), + ("release2", "sae3", 0.8, 0.2), + ] + assert result == expected + + +mock_all_saes = [ + ("release1", "sae1", 0.9, 0.1), + ("release1", "sae2", 0.85, 0.15), + ("release2", "sae3", 0.8, 0.2), + ("release2", "block1", 0.95, 0.05), +] + + +@patch("sae_lens.evals.all_loadable_saes") +def test_get_saes_from_regex_no_match(mock_all_loadable_saes: MagicMock): + mock_all_loadable_saes.return_value = mock_all_saes + + result = get_saes_from_regex("release1", "sae3") + + assert not result + + +@patch("sae_lens.evals.all_loadable_saes") +def test_get_saes_from_regex_single_match(mock_all_loadable_saes: MagicMock): + mock_all_loadable_saes.return_value = mock_all_saes + + result = get_saes_from_regex("release1", "sae1") + + expected = [("release1", "sae1", 0.9, 0.1)] + assert result == expected + + +@patch("sae_lens.evals.all_loadable_saes") +def test_get_saes_from_regex_multiple_matches(mock_all_loadable_saes: MagicMock): + mock_all_loadable_saes.return_value = mock_all_saes + + result = get_saes_from_regex("release.*", "sae.*") + + expected = [ + ("release1", "sae1", 0.9, 0.1), + ("release1", "sae2", 0.85, 0.15), + ("release2", "sae3", 0.8, 0.2), + ] + assert result == expected