
Commit

chore: Add tests for evals (#346)
* add unit tests for untested functions

* adds test to increase coverage

* fixes typo
anthonyduong9 authored Oct 22, 2024
1 parent d9db786 commit 06594f9
Showing 1 changed file with 90 additions and 0 deletions.
tests/unit/test_evals.py: 90 additions & 0 deletions
@@ -10,12 +10,15 @@
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.evals import (
    EvalConfig,
    all_loadable_saes,
    get_eval_everything_config,
    get_saes_from_regex,
    process_results,
    run_evals,
    run_evaluations,
)
from sae_lens.sae import SAE
from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.training_sae import TrainingSAE
from tests.unit.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached
@@ -67,13 +70,22 @@
"hook_layer": 1,
"d_in": 16 * 4,
},
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"hook_name": "blocks.1.attn.hook_q",
"hook_layer": 1,
"d_in": 4,
"hook_head_index": 2,
},
],
ids=[
"tiny-stories-1M-resid-pre",
"tiny-stories-1M-resid-pre-L1-W-dec-Norm",
"tiny-stories-1M-resid-pre-pretokenized",
"tiny-stories-1M-hook-z",
"tiny-stories-1M-hook-q",
"tiny-stories-1M-hook-q-head-index-2",
],
)
def cfg(request: pytest.FixtureRequest):
@@ -327,3 +339,81 @@ def test_process_results(tmp_path: Path):
    # Check if CSV file is created
    csv_path = output_dir / "all_eval_results.csv"
    assert csv_path.exists()


@patch("sae_lens.evals.get_pretrained_saes_directory")
def test_all_loadable_saes(mock_get_pretrained_saes_directory: MagicMock):
mock_get_pretrained_saes_directory.return_value = {
"release1": PretrainedSAELookup(
release="release1",
repo_id="repo1",
model="model1",
conversion_func=None,
saes_map={"sae1": "path1", "sae2": "path2"},
expected_var_explained={"sae1": 0.9, "sae2": 0.85},
expected_l0={"sae1": 0.1, "sae2": 0.15},
neuronpedia_id={},
config_overrides=None,
),
"release2": PretrainedSAELookup(
release="release2",
repo_id="repo2",
model="model2",
conversion_func=None,
saes_map={"sae3": "path3"},
expected_var_explained={"sae3": 0.8},
expected_l0={"sae3": 0.2},
neuronpedia_id={},
config_overrides=None,
),
}

result = all_loadable_saes()

expected = [
("release1", "sae1", 0.9, 0.1),
("release1", "sae2", 0.85, 0.15),
("release2", "sae3", 0.8, 0.2),
]
assert result == expected
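
For readers skimming the diff, here is a minimal sketch of the flattening this test pins down. It is reconstructed from the mocked directory and the expected tuples above, not taken from the library's actual code; the helper name all_loadable_saes_sketch is hypothetical, and the import location matches the patch target used in the test.

# Sketch only: a plausible shape for all_loadable_saes, inferred from
# the mock and assertions above. Each (release, sae_id) pair in the
# pretrained directory is flattened into a tuple carrying its expected
# variance-explained and L0 statistics.
from sae_lens.evals import get_pretrained_saes_directory


def all_loadable_saes_sketch() -> list[tuple[str, str, float, float]]:
    saes: list[tuple[str, str, float, float]] = []
    for release, lookup in get_pretrained_saes_directory().items():
        for sae_id in lookup.saes_map:
            saes.append(
                (
                    release,
                    sae_id,
                    lookup.expected_var_explained[sae_id],
                    lookup.expected_l0[sae_id],
                )
            )
    return saes

Under the mocked directory above, this sketch would yield exactly the three tuples in the test's expected list.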


mock_all_saes = [
    ("release1", "sae1", 0.9, 0.1),
    ("release1", "sae2", 0.85, 0.15),
    ("release2", "sae3", 0.8, 0.2),
    ("release2", "block1", 0.95, 0.05),
]


@patch("sae_lens.evals.all_loadable_saes")
def test_get_saes_from_regex_no_match(mock_all_loadable_saes: MagicMock):
mock_all_loadable_saes.return_value = mock_all_saes

result = get_saes_from_regex("release1", "sae3")

assert not result


@patch("sae_lens.evals.all_loadable_saes")
def test_get_saes_from_regex_single_match(mock_all_loadable_saes: MagicMock):
mock_all_loadable_saes.return_value = mock_all_saes

result = get_saes_from_regex("release1", "sae1")

expected = [("release1", "sae1", 0.9, 0.1)]
assert result == expected


@patch("sae_lens.evals.all_loadable_saes")
def test_get_saes_from_regex_multiple_matches(mock_all_loadable_saes: MagicMock):
mock_all_loadable_saes.return_value = mock_all_saes

result = get_saes_from_regex("release.*", "sae.*")

expected = [
("release1", "sae1", 0.9, 0.1),
("release1", "sae2", 0.85, 0.15),
("release2", "sae3", 0.8, 0.2),
]
assert result == expected
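
And a similarly hedged sketch of the regex filtering the three tests above exercise: both patterns must match in full for an entry to survive, which is why "sae.*" keeps sae1, sae2, and sae3 but drops block1, and why the ("release1", "sae3") query yields nothing. Whether the real get_saes_from_regex uses re.fullmatch or some other anchoring is an assumption here, as is the helper name.

# Sketch only: filtering behavior consistent with the three assertions
# above. re.fullmatch is an assumption; the real function may anchor
# its patterns differently.
import re

from sae_lens.evals import all_loadable_saes


def get_saes_from_regex_sketch(
    sae_regex_pattern: str, sae_block_pattern: str
) -> list[tuple[str, str, float, float]]:
    return [
        (release, sae_id, var_explained, l0)
        for release, sae_id, var_explained, l0 in all_loadable_saes()
        if re.fullmatch(sae_regex_pattern, release)
        and re.fullmatch(sae_block_pattern, sae_id)
    ]

Applied to mock_all_saes, this returns the three expected tuples for ("release.*", "sae.*"), a single tuple for ("release1", "sae1"), and an empty list for ("release1", "sae3"), matching the tests.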
