From 445e7458fff6174bfdc186205d654e8f5ee55097 Mon Sep 17 00:00:00 2001
From: anthonyduong
Date: Sat, 19 Oct 2024 23:26:07 -0700
Subject: [PATCH 1/3] add unit tests for untested functions

---
Review notes (this region sits after the three-dash line, so it is not part
of the commit message):

* Imports all_loadable_saes and get_saes_from_regex from sae_lens.evals, and
  PretrainedSAELookup from sae_lens.toolkit.pretrained_saes_directory.
* test_all_loadable_saes patches sae_lens.evals.get_pretrained_saes_directory
  with a two-release directory and asserts that all_loadable_saes() returns
  one tuple per entry of saes_map. Going by the fixture values, each tuple is
  (release, sae id, expected_var_explained, expected_l0).
* mock_all_saes is a module-level list shared by the three
  get_saes_from_regex tests (no match, single match, multiple matches), each
  of which patches sae_lens.evals.all_loadable_saes to return it.
* The ("release2", "block1", ...) entry is absent from the expected result of
  the "release.*" / "sae.*" test; presumably it exists to show that
  non-matching SAE ids are filtered out. Confirm against get_saes_from_regex.

 tests/unit/test_evals.py | 81 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 81 insertions(+)

diff --git a/tests/unit/test_evals.py b/tests/unit/test_evals.py
index 535ddb17..727d4426 100644
--- a/tests/unit/test_evals.py
+++ b/tests/unit/test_evals.py
@@ -10,12 +10,15 @@
 from sae_lens.config import LanguageModelSAERunnerConfig
 from sae_lens.evals import (
     EvalConfig,
+    all_loadable_saes,
     get_eval_everything_config,
+    get_saes_from_regex,
     process_results,
     run_evals,
     run_evaluations,
 )
 from sae_lens.sae import SAE
+from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.training_sae import TrainingSAE
 from tests.unit.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached
@@ -327,3 +330,81 @@ def test_process_results(tmp_path: Path):
     # Check if CSV file is created
     csv_path = output_dir / "all_eval_results.csv"
     assert csv_path.exists()
+
+
+@patch("sae_lens.evals.get_pretrained_saes_directory")
+def test_all_loadable_saes(mock_get_pretrained_saes_directory: MagicMock):
+    mock_get_pretrained_saes_directory.return_value = {
+        "release1": PretrainedSAELookup(
+            release="release1",
+            repo_id="repo1",
+            model="model1",
+            conversion_func=None,
+            saes_map={"sae1": "path1", "sae2": "path2"},
+            expected_var_explained={"sae1": 0.9, "sae2": 0.85},
+            expected_l0={"sae1": 0.1, "sae2": 0.15},
+            neuronpedia_id={},
+            config_overrides=None,
+        ),
+        "release2": PretrainedSAELookup(
+            release="release2",
+            repo_id="repo2",
+            model="model2",
+            conversion_func=None,
+            saes_map={"sae3": "path3"},
+            expected_var_explained={"sae3": 0.8},
+            expected_l0={"sae3": 0.2},
+            neuronpedia_id={},
+            config_overrides=None,
+        ),
+    }
+
+    result = all_loadable_saes()
+
+    expected = [
+        ("release1", "sae1", 0.9, 0.1),
+        ("release1", "sae2", 0.85, 0.15),
+        ("release2", "sae3", 0.8, 0.2),
+    ]
+    assert result == expected
+
+
+mock_all_saes = [
+    ("release1", "sae1", 0.9, 0.1),
+    ("release1", "sae2", 0.85, 0.15),
+    ("release2", "sae3", 0.8, 0.2),
+    ("release2", "block1", 0.95, 0.05),
+]
+
+
+@patch("sae_lens.evals.all_loadable_saes")
+def test_get_saes_from_regex_no_match(mock_all_loadable_saes: MagicMock):
+    mock_all_loadable_saes.return_value = mock_all_saes
+
+    result = get_saes_from_regex("release1", "sae3")
+
+    assert not result
+
+
+@patch("sae_lens.evals.all_loadable_saes")
+def test_get_saes_from_regex_single_match(mock_all_loadable_saes: MagicMock):
+    mock_all_loadable_saes.return_value = mock_all_saes
+
+    result = get_saes_from_regex("release1", "sae1")
+
+    expected = [("release1", "sae1", 0.9, 0.1)]
+    assert result == expected
+
+
+@patch("sae_lens.evals.all_loadable_saes")
+def test_get_saes_from_regex_multiple_matches(mock_all_loadable_saes: MagicMock):
+    mock_all_loadable_saes.return_value = mock_all_saes
+
+    result = get_saes_from_regex("release.*", "sae.*")
+
+    expected = [
+        ("release1", "sae1", 0.9, 0.1),
+        ("release1", "sae2", 0.85, 0.15),
+        ("release2", "sae3", 0.8, 0.2),
+    ]
+    assert result == expected

From c214a44c59da4e1a6efe404383e7420abea1f599 Mon Sep 17 00:00:00 2001
From: anthonyduong
Date: Sun, 20 Oct 2024 13:22:26 -0700
Subject: [PATCH 2/3] adds test to increase coverage

---
Review notes (this region sits after the three-dash line, so it is not part
of the commit message):

* Appends one parameter dict to the params list of the cfg fixture. It is the
  only dict visible in this hunk that sets "hook_head_index". A matching id,
  "tiny-stories-1M-hook-q-head-index-2", is appended to ids.
* NOTE(review): as committed here, "hook_name" is "blocks..attn.hook_q", with
  the layer index missing between the dots. Patch 3/3 corrects it, so this
  commit should not be applied without that one.

 tests/unit/test_evals.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/tests/unit/test_evals.py b/tests/unit/test_evals.py
index 727d4426..3e7fe0ea 100644
--- a/tests/unit/test_evals.py
+++ b/tests/unit/test_evals.py
@@ -70,6 +70,14 @@
             "hook_layer": 1,
             "d_in": 16 * 4,
         },
+        {
+            "model_name": "tiny-stories-1M",
+            "dataset_path": "roneneldan/TinyStories",
+            "hook_name": "blocks..attn.hook_q",
+            "hook_layer": 1,
+            "d_in": 4,
+            "hook_head_index": 2,
+        },
     ],
     ids=[
         "tiny-stories-1M-resid-pre",
@@ -77,6 +85,7 @@
         "tiny-stories-1M-resid-pre-pretokenized",
         "tiny-stories-1M-hook-z",
         "tiny-stories-1M-hook-q",
+        "tiny-stories-1M-hook-q-head-index-2",
     ],
 )
 def cfg(request: pytest.FixtureRequest):

From 335badb2186e4e0ca7f4563db58ca92d70858c8b Mon Sep 17 00:00:00 2001
From: anthonyduong
Date: Mon, 21 Oct 2024 23:32:19 -0700
Subject: [PATCH 3/3] fixes typo

---
Review notes (this region sits after the three-dash line, so it is not part
of the commit message):

* Changes "hook_name" in the parameter dict added by patch 2/3 from
  "blocks..attn.hook_q" to "blocks.1.attn.hook_q", so the layer index in the
  hook name agrees with "hook_layer": 1 in the same dict.

 tests/unit/test_evals.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/test_evals.py b/tests/unit/test_evals.py
index 3e7fe0ea..435defdf 100644
--- a/tests/unit/test_evals.py
+++ b/tests/unit/test_evals.py
@@ -73,7 +73,7 @@
         {
             "model_name": "tiny-stories-1M",
             "dataset_path": "roneneldan/TinyStories",
-            "hook_name": "blocks..attn.hook_q",
+            "hook_name": "blocks.1.attn.hook_q",
             "hook_layer": 1,
             "d_in": 4,
             "hook_head_index": 2,