
Commit 10f4773
Merge pull request #324 from jbloomAus/improving-evals
chore: Misc basic evals improvements (e.g. consistent activation heuristic, CLI args)
curt-tigges authored Oct 18, 2024
2 parents 36e1d86 + d1b4f5d commit 10f4773
Showing 8 changed files with 1,092 additions and 178 deletions.
377 changes: 330 additions & 47 deletions sae_lens/evals.py

Large diffs are not rendered by default.
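Since the `sae_lens/evals.py` diff is collapsed, here is a minimal sketch of how the reworked evaluation entry points appear to be driven, based only on the `run_evaluations` / `process_results` calls and the argument fixture in the benchmark test further down this page. The field names mirror that fixture; the values, defaults, and any extra behaviour are assumptions, not taken from `evals.py` itself.

```python
# Hypothetical driver mirroring the argparse.Namespace fixture used in
# tests/benchmark/test_eval_all_loadable_saes.py below. Field names come from
# that fixture; the values and printing are illustrative only.
from argparse import Namespace

from sae_lens.evals import process_results, run_evaluations

args = Namespace(
    sae_regex_pattern="gpt2-small-res-jb",
    sae_block_pattern="blocks.0.hook_resid_pre",
    num_eval_batches=1,
    n_eval_reconstruction_batches=1,
    n_eval_sparsity_variance_batches=1,
    eval_batch_size_prompts=2,
    datasets=["Skylion007/openwebtext"],
    ctx_lens=[128],
    output_dir="eval_outputs",
    verbose=True,
)

eval_results = run_evaluations(args)                      # one result per (SAE, dataset, ctx_len)
output_files = process_results(eval_results, args.output_dir)
print(f"Combined JSON: {output_files['combined_json']}")
print(f"CSV: {output_files['csv']}")
```

The commit message mentions CLI args, so the command-line interface presumably exposes the same fields as flags, but the exact flag spelling is not visible on this page.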

2 changes: 1 addition & 1 deletion sae_lens/toolkit/pretrained_sae_loaders.py
@@ -442,10 +442,10 @@ def get_dictionary_learning_config_1(
"sae_lens_training_version": None,
"prepend_bos": True,
"dataset_path": "monology/pile-uncopyrighted",
"dataset_trust_remote_code": False,
"context_size": buffer["ctx_len"],
"normalize_activations": "none",
"neuronpedia_id": None,
"dataset_trust_remote_code": True,
}


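The hunk above flips `dataset_trust_remote_code` from `False` to `True` for configs recovered from `dictionary_learning` releases. That value presumably ends up as the `trust_remote_code` argument to the Hugging Face `datasets` loader when the configured dataset is streamed; a minimal sketch of that downstream call, with the `split` and `streaming` arguments assumed for illustration:

```python
# Sketch of where dataset_trust_remote_code presumably lands: the Hugging Face
# datasets loader. split/streaming are assumptions, not taken from this diff.
from datasets import load_dataset

dataset = load_dataset(
    "monology/pile-uncopyrighted",   # dataset_path from the config above
    split="train",
    streaming=True,
    trust_remote_code=True,          # mirrors the flipped config value
)
print(next(iter(dataset))["text"][:200])
```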
10 changes: 7 additions & 3 deletions sae_lens/training/sae_trainer.py
@@ -136,9 +136,13 @@ def __init__(
self.trainer_eval_config = EvalConfig(
batch_size_prompts=self.cfg.eval_batch_size_prompts,
n_eval_reconstruction_batches=self.cfg.n_eval_batches,
n_eval_sparsity_variance_batches=self.cfg.n_eval_batches,
compute_ce_loss=True,
n_eval_sparsity_variance_batches=1,
compute_l2_norms=True,
compute_sparsity_metrics=True,
compute_variance_metrics=True,
compute_kl=False,
compute_featurewise_weight_based_metrics=False,
)

@property
@@ -330,13 +334,13 @@ def _run_and_log_evals(self):
self.cfg.wandb_log_frequency * self.cfg.eval_every_n_wandb_logs
) == 0:
self.sae.eval()
eval_metrics = run_evals(
eval_metrics, _ = run_evals(
sae=self.sae,
activation_store=self.activation_store,
model=self.model,
eval_config=self.trainer_eval_config,
model_kwargs=self.cfg.model_kwargs,
)
) # not calculating featurwise metrics here.

# Remove eval metrics that are already logged during training
eval_metrics.pop("metrics/explained_variance", None)
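As the second hunk shows, `run_evals` now returns a pair, and the trainer discards the second element (the per-feature metrics, per the added comment). A sketch of calling the updated API directly, stitched together from the calls visible in this diff; the pretrained release/ID strings and the `HookedTransformer` loading step are assumptions for illustration:

```python
# Standalone sketch of the (metrics, feature_metrics) return convention.
# The release/sae_id pair and the model loading are assumed; the rest mirrors
# the ActivationsStore.from_sae / get_eval_everything_config / run_evals usage
# shown in the benchmark test below.
import torch
from transformer_lens import HookedTransformer

from sae_lens import SAE, ActivationsStore
from sae_lens.evals import get_eval_everything_config, run_evals

device = "cuda" if torch.cuda.is_available() else "cpu"

sae, _, _ = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.0.hook_resid_pre",
    device=device,
)
model = HookedTransformer.from_pretrained(sae.cfg.model_name, device=device)

activation_store = ActivationsStore.from_sae(
    model=model,
    sae=sae,
    streaming=True,
    store_batch_size_prompts=8,
    train_batch_size_tokens=4096,
    n_batches_in_buffer=4,
    device=device,
)

metrics, feature_metrics = run_evals(
    sae=sae,
    activation_store=activation_store,
    model=model,
    eval_config=get_eval_everything_config(
        batch_size_prompts=8,
        n_eval_reconstruction_batches=3,
        n_eval_sparsity_variance_batches=3,
    ),
)
print(sorted(metrics))  # aggregate metrics; feature_metrics holds per-feature values
```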
142 changes: 80 additions & 62 deletions tests/benchmark/test_eval_all_loadable_saes.py
@@ -1,18 +1,26 @@
# import pandas as pd
# import plotly.express as px
# import numpy as np
import argparse
import json
from pathlib import Path

import pytest
import torch

from sae_lens import SAE, ActivationsStore
from sae_lens.analysis.neuronpedia_integration import open_neuronpedia_feature_dashboard
from sae_lens.evals import all_loadable_saes
from sae_lens.sae import SAE
from sae_lens.evals import (
all_loadable_saes,
get_eval_everything_config,
process_results,
run_evals,
run_evaluations,
)
from sae_lens.toolkit.pretrained_sae_loaders import (
SAEConfigLoadOptions,
get_sae_config_from_hf,
)

# from sae_lens.training.activations_store import ActivationsStore
from tests.unit.helpers import load_model_cached

# from sae_lens.evals import run_evals
@@ -136,67 +144,77 @@ def test_eval_all_loadable_saes(
model = load_model_cached(sae.cfg.model_name)
model.to(device)

# activation_store = ActivationsStore.from_sae(
# model=model,
# sae=sae,
# streaming=True,
# # fairly conservative parameters here so can use same for larger
# # models without running out of memory.
# store_batch_size_prompts=8,
# train_batch_size_tokens=4096,
# n_batches_in_buffer=4,
# device=device,
# )

# if sae.cfg.normalize_activations == "expected_average_only_in":
# norm_scaling_factor = activation_store.estimate_norm_scaling_factor(
# n_batches_for_norm_estimate=100
# )
# sae.fold_activation_norm_scaling_factor(norm_scaling_factor)
# activation_store.normalize_activations = "none"

metrics = {}
# eval_metrics = run_evals(
# sae=sae,
# activation_store=activation_store,
# model=model,
# n_eval_batches=10,
# eval_batch_size_prompts=8,
# ) #
# eval_metrics = dict(eval_metrics)
# eval_metrics["ce_loss_diff"] = (
# eval_metrics["metrics/ce_loss_with_sae"].item()
# - eval_metrics["metrics/ce_loss_without_sae"].item()
# )
# assert eval_metrics["ce_loss_diff"] < 0.1, "CE Loss Difference is too high"

# CE Loss Difference
_, cache = model.run_with_cache(
example_text, names_filter=[sae.cfg.hook_name], prepend_bos=sae.cfg.prepend_bos
activation_store = ActivationsStore.from_sae(
model=model,
sae=sae,
streaming=True,
# fairly conservative parameters here so can use same for larger
# models without running out of memory.
store_batch_size_prompts=8,
train_batch_size_tokens=4096,
n_batches_in_buffer=4,
device=device,
)

# Use the SAE
sae_in = cache[sae.cfg.hook_name].squeeze()[1:]

feature_acts = sae.encode(sae_in)
sae_out = sae.decode(feature_acts)

mean_l0 = (feature_acts[1:] > 0).float().sum(-1).detach().cpu().numpy().mean()
eval_config = get_eval_everything_config(
batch_size_prompts=8,
n_eval_reconstruction_batches=3,
n_eval_sparsity_variance_batches=100,
)

# # get the FVE of teh SAE
per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze()
total_variance = (sae_in - sae_in.mean(0)).pow(2).sum(-1)
explained_variance = 1 - per_token_l2_loss / total_variance
metrics, _ = run_evals(
sae=sae,
activation_store=activation_store,
model=model,
eval_config=eval_config,
ignore_tokens={
model.tokenizer.pad_token_id, # type: ignore
model.tokenizer.eos_token_id, # type: ignore
model.tokenizer.bos_token_id, # type: ignore
},
)

metrics["l0"] = mean_l0
metrics["var_explained"] = explained_variance.mean().cpu().item()
assert pytest.approx(metrics["l0"], abs=5) == expected_l0
assert (
pytest.approx(metrics["explained_variance"], abs=0.1) == expected_var_explained
)

assert metrics == {
"l0": pytest.approx(expected_l0, abs=5),
"var_explained": pytest.approx(expected_var_explained, abs=0.1),
}

# assert mean_l0 == pytest.approx(expected_l0, abs=5)
# assert explained_variance.mean().cpu() == pytest.approx(
# expected_var_explained, abs=0.1
# )
@pytest.fixture
def mock_evals_simple_args(tmp_path: Path):
class Args:
sae_regex_pattern = "gpt2-small-res-jb"
sae_block_pattern = "blocks.0.hook_resid_pre"
num_eval_batches = 1
n_eval_reconstruction_batches = 1
n_eval_sparsity_variance_batches = 1

eval_batch_size_prompts = 2
datasets = ["Skylion007/openwebtext"]
ctx_lens = [128]
output_dir = str(tmp_path)
verbose = False

return Args()


def test_run_evaluations_process_results(mock_evals_simple_args: argparse.Namespace):
"""
This test is more like an acceptance test for the evals code than a benchmark.
"""
eval_results = run_evaluations(mock_evals_simple_args)
output_files = process_results(eval_results, mock_evals_simple_args.output_dir)

print("Evaluation complete. Output files:")
print(f"Individual JSONs: {len(output_files['individual_jsons'])}") # type: ignore
print(f"Combined JSON: {output_files['combined_json']}")
print(f"CSV: {output_files['csv']}")

# open and validate the files
combined_json_path = output_files["combined_json"]
assert isinstance(combined_json_path, Path)
assert combined_json_path.exists()
with open(combined_json_path, "r") as f:
data = json.load(f)[0]
assert "metrics" in data
assert "feature_metrics" in data

0 comments on commit 10f4773
