Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for evals #346

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 90 additions & 0 deletions tests/unit/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,15 @@
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.evals import (
EvalConfig,
all_loadable_saes,
get_eval_everything_config,
get_saes_from_regex,
process_results,
run_evals,
run_evaluations,
)
from sae_lens.sae import SAE
from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.training_sae import TrainingSAE
from tests.unit.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached
Expand Down Expand Up @@ -67,13 +70,22 @@
"hook_layer": 1,
"d_in": 16 * 4,
},
{
"model_name": "tiny-stories-1M",
"dataset_path": "roneneldan/TinyStories",
"hook_name": "blocks.1.attn.hook_q",
"hook_layer": 1,
"d_in": 4,
"hook_head_index": 2,
},
],
ids=[
"tiny-stories-1M-resid-pre",
"tiny-stories-1M-resid-pre-L1-W-dec-Norm",
"tiny-stories-1M-resid-pre-pretokenized",
"tiny-stories-1M-hook-z",
"tiny-stories-1M-hook-q",
"tiny-stories-1M-hook-q-head-index-2",
],
)
def cfg(request: pytest.FixtureRequest):
Expand Down Expand Up @@ -327,3 +339,81 @@ def test_process_results(tmp_path: Path):
# Check if CSV file is created
csv_path = output_dir / "all_eval_results.csv"
assert csv_path.exists()


@patch("sae_lens.evals.get_pretrained_saes_directory")
def test_all_loadable_saes(mock_get_pretrained_saes_directory: MagicMock):
    """Check that all_loadable_saes flattens the pretrained SAE directory.

    The directory lookup is patched with two fake releases. Each SAE id in a
    release's ``saes_map`` is expected to yield one
    ``(release, sae_id, expected_var_explained, expected_l0)`` tuple.
    """
    first_release = PretrainedSAELookup(
        release="release1",
        repo_id="repo1",
        model="model1",
        conversion_func=None,
        saes_map={"sae1": "path1", "sae2": "path2"},
        expected_var_explained={"sae1": 0.9, "sae2": 0.85},
        expected_l0={"sae1": 0.1, "sae2": 0.15},
        neuronpedia_id={},
        config_overrides=None,
    )
    second_release = PretrainedSAELookup(
        release="release2",
        repo_id="repo2",
        model="model2",
        conversion_func=None,
        saes_map={"sae3": "path3"},
        expected_var_explained={"sae3": 0.8},
        expected_l0={"sae3": 0.2},
        neuronpedia_id={},
        config_overrides=None,
    )
    # Insertion order is release1 then release2; the assertion below compares
    # against a list in that same order.
    mock_get_pretrained_saes_directory.return_value = {
        "release1": first_release,
        "release2": second_release,
    }

    loadable = all_loadable_saes()

    assert loadable == [
        ("release1", "sae1", 0.9, 0.1),
        ("release1", "sae2", 0.85, 0.15),
        ("release2", "sae3", 0.8, 0.2),
    ]


# Shared fake return value for the patched ``all_loadable_saes`` in the
# ``get_saes_from_regex`` tests below. Each entry is a 4-tuple whose first two
# fields are the release name and the SAE id. The last entry uses the id
# "block1" rather than a "sae*" id.
mock_all_saes = [
    ("release1", "sae1", 0.9, 0.1),
    ("release1", "sae2", 0.85, 0.15),
    ("release2", "sae3", 0.8, 0.2),
    ("release2", "block1", 0.95, 0.05),
]


@patch("sae_lens.evals.all_loadable_saes")
def test_get_saes_from_regex_no_match(mock_all_loadable_saes: MagicMock):
    """A release/SAE pattern pair that matches no entry gives a falsy result."""
    mock_all_loadable_saes.return_value = mock_all_saes

    # "sae3" only exists under "release2" in the fixture data, so pairing it
    # with "release1" should select nothing.
    matched = get_saes_from_regex("release1", "sae3")

    assert not matched


@patch("sae_lens.evals.all_loadable_saes")
def test_get_saes_from_regex_single_match(mock_all_loadable_saes: MagicMock):
    """Exact release and SAE names select exactly the one matching entry."""
    mock_all_loadable_saes.return_value = mock_all_saes

    matched = get_saes_from_regex("release1", "sae1")

    assert matched == [("release1", "sae1", 0.9, 0.1)]


@patch("sae_lens.evals.all_loadable_saes")
def test_get_saes_from_regex_multiple_matches(mock_all_loadable_saes: MagicMock):
    """Wildcard patterns select every entry whose release and SAE id both match."""
    mock_all_loadable_saes.return_value = mock_all_saes

    matched = get_saes_from_regex("release.*", "sae.*")

    # The ("release2", "block1", ...) fixture entry is expected to be excluded
    # because its id does not match "sae.*".
    assert matched == [
        ("release1", "sae1", 0.9, 0.1),
        ("release1", "sae2", 0.85, 0.15),
        ("release2", "sae3", 0.8, 0.2),
    ]
Loading