feat: adding a CLI training runner (#359)
chanind authored Nov 9, 2024
1 parent 1866aa7 commit 998c277
Showing 5 changed files with 147 additions and 3 deletions.
8 changes: 8 additions & 0 deletions docs/training_saes.md
@@ -92,6 +92,14 @@ sparse_autoencoder = SAETrainingRunner(cfg).run()

As you can see, the training setup provides a large number of options to explore. The full list of options can be found in the [LanguageModelSAERunnerConfig][sae_lens.LanguageModelSAERunnerConfig] class.

## CLI Runner

The SAE training runner can also be invoked from the command line via the `sae_lens.sae_training_runner` module. This is useful for quickly testing different hyperparameters or for running training on a remote server. The CLI accepts the same options as [LanguageModelSAERunnerConfig][sae_lens.LanguageModelSAERunnerConfig], each with a `--` prefix; e.g., `--model_name` corresponds to `model_name` in the config. View the full list of options with:

```bash
python -m sae_lens.sae_training_runner --help
```
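
For example, a small training run might be launched as follows. This is an illustrative sketch: every flag mirrors a config field shown elsewhere in this commit, and the model and dataset names are placeholder values borrowed from the test suite, so substitute your own.

```bash
# Illustrative invocation; each flag maps to a LanguageModelSAERunnerConfig field.
python -m sae_lens.sae_training_runner \
    --model_name tiny-stories-1M \
    --dataset_path roneneldan/TinyStories \
    --d_in 64 \
    --expansion_factor 4 \
    --training_tokens 1000000 \
    --log_to_wandb False
```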

## Logging to Weights and Biases

For any real training run, you should be logging to Weights and Biases (WandB). This will allow you to track your training progress and compare different runs. To enable WandB, set `log_to_wandb=True`. The `wandb_project` parameter in the config controls the project name in WandB. You can also control the logging frequency with `wandb_log_frequency` and `eval_every_n_wandb_logs`.
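
As a minimal sketch, a WandB-enabled config might look like the following; the project name and frequencies are illustrative values rather than recommendations, and all other options are left at their defaults.

```python
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.sae_training_runner import SAETrainingRunner

# Sketch: enable WandB logging (values are illustrative).
cfg = LanguageModelSAERunnerConfig(
    log_to_wandb=True,
    wandb_project="sae-training",  # project name in WandB
    wandb_log_frequency=10,  # log metrics every 10 training steps
    eval_every_n_wandb_logs=20,  # run evals every 20 WandB log events
)
sparse_autoencoder = SAETrainingRunner(cfg).run()
```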
1 change: 1 addition & 0 deletions pyproject.toml
@@ -38,6 +38,7 @@ pyyaml = "^6.0.1"
pytest-profiling = "^1.7.0"
zstandard = "^0.22.0"
typing-extensions = "^4.10.0"
simple-parsing = "^0.1.6"


[tool.poetry.group.dev.dependencies]
22 changes: 21 additions & 1 deletion sae_lens/sae_training_runner.py
@@ -2,11 +2,13 @@
import logging
import os
import signal
from typing import Any, cast
import sys
from typing import Any, Sequence, cast

import torch
import wandb
from safetensors.torch import save_file
from simple_parsing import ArgumentParser
from transformer_lens.hook_points import HookedRootModule

from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
@@ -235,3 +237,21 @@ def save_checkpoint(
        wandb.log_artifact(sparsity_artifact)  # type: ignore

    return checkpoint_path


def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
    if len(args) == 0:
        args = ["--help"]
    parser = ArgumentParser()
    parser.add_arguments(LanguageModelSAERunnerConfig, dest="cfg")
    return parser.parse_args(args).cfg


# moved into its own function to make it easier to test
def _run_cli(args: Sequence[str]):
    cfg = _parse_cfg_args(args)
    SAETrainingRunner(cfg=cfg).run()


if __name__ == "__main__":
    _run_cli(args=sys.argv[1:])
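
For context, the CLI above is powered by `simple_parsing`, which generates a `--` flag for each field of the config dataclass. A minimal standalone sketch of that mechanism (the `DemoConfig` dataclass here is hypothetical, not part of sae_lens):

```python
from dataclasses import dataclass

from simple_parsing import ArgumentParser


@dataclass
class DemoConfig:
    model_name: str = "gpt2"  # exposed as --model_name
    d_in: int = 512  # exposed as --d_in


parser = ArgumentParser()
parser.add_arguments(DemoConfig, dest="cfg")
cfg = parser.parse_args(["--model_name", "test-model", "--d_in", "64"]).cfg
assert cfg.model_name == "test-model" and cfg.d_in == 64
```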
2 changes: 2 additions & 0 deletions tests/unit/conftest.py
@@ -8,6 +8,8 @@

from tests.unit.helpers import TINYSTORIES_MODEL, load_model_cached

torch.set_grad_enabled(True)


@pytest.fixture(autouse=True)
def reproducibility():
117 changes: 115 additions & 2 deletions tests/unit/training/test_sae_training_runner.py
@@ -1,3 +1,4 @@
import json
import os
from pathlib import Path

@@ -8,11 +9,16 @@

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.sae import SAE
from sae_lens.sae_training_runner import SAETrainingRunner
from sae_lens.sae_training_runner import SAETrainingRunner, _parse_cfg_args, _run_cli
from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.sae_trainer import SAETrainer
from sae_lens.training.training_sae import TrainingSAE
from tests.unit.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached
from tests.unit.helpers import (
    TINYSTORIES_DATASET,
    TINYSTORIES_MODEL,
    build_sae_cfg,
    load_model_cached,
)


@pytest.fixture
@@ -105,3 +111,110 @@ def test_training_runner_works_with_from_pretrained_path(
    assert torch.allclose(orig_sae.W_enc, new_sae.W_enc)
    assert torch.allclose(orig_sae.b_enc, new_sae.b_enc)
    assert torch.allclose(orig_sae.b_dec, new_sae.b_dec)


def test_parse_cfg_args_prints_help_if_no_args():
    args = []
    with pytest.raises(SystemExit):
        _parse_cfg_args(args)


def test_parse_cfg_args_override():
    args = [
        "--model_name",
        "test-model",
        "--d_in",
        "1024",
        "--d_sae",
        "4096",
        "--activation_fn",
        "tanh-relu",
        "--normalize_sae_decoder",
        "False",
        "--dataset_path",
        "my/dataset",
    ]
    cfg = _parse_cfg_args(args)

    assert cfg.model_name == "test-model"
    assert cfg.d_in == 1024
    assert cfg.d_sae == 4096
    assert cfg.activation_fn == "tanh-relu"
    assert cfg.normalize_sae_decoder is False
    assert cfg.dataset_path == "my/dataset"


def test_parse_cfg_args_expansion_factor():
    # Test that we can't set both d_sae and expansion_factor
    args = ["--d_sae", "1024", "--expansion_factor", "8"]
    with pytest.raises(ValueError):
        _parse_cfg_args(args)


def test_parse_cfg_args_b_dec_init_method():
    # Test validation of b_dec_init_method
    args = ["--b_dec_init_method", "invalid"]
    with pytest.raises(ValueError):
        _parse_cfg_args(args)

    valid_methods = ["geometric_median", "mean", "zeros"]
    for method in valid_methods:
        args = ["--b_dec_init_method", method]
        cfg = _parse_cfg_args(args)
        assert cfg.b_dec_init_method == method


def test_run_cli_saves_config(tmp_path: Path):
    # Set up args for a minimal training run
    args = [
        "--model_name",
        TINYSTORIES_MODEL,
        "--dataset_path",
        TINYSTORIES_DATASET,
        "--checkpoint_path",
        str(tmp_path),
        "--n_checkpoints",
        "1",  # Save one checkpoint
        "--training_tokens",
        "128",
        "--train_batch_size_tokens",
        "4",
        "--store_batch_size_prompts",
        "4",
        "--log_to_wandb",
        "False",  # Don't log to wandb in test
        "--d_in",
        "64",  # Match the tiny-stories model's hidden size
        "--d_sae",
        "128",  # Small SAE for test
        "--activation_fn",
        "relu",
        "--normalize_sae_decoder",
        "False",
    ]

    # Run training
    _run_cli(args)

    # Check that checkpoint was saved
    run_dirs = list(tmp_path.glob("*"))  # run dirs
    assert len(run_dirs) == 1
    checkpoint_dirs = list(run_dirs[0].glob("*"))
    assert len(checkpoint_dirs) == 1

    # Load and verify saved config
    with open(checkpoint_dirs[0] / "cfg.json") as f:
        saved_cfg = json.load(f)

    # Verify key config values were saved correctly
    assert saved_cfg["model_name"] == TINYSTORIES_MODEL
    assert saved_cfg["d_in"] == 64
    assert saved_cfg["d_sae"] == 128
    assert saved_cfg["activation_fn"] == "relu"
    assert saved_cfg["normalize_sae_decoder"] is False
    assert saved_cfg["dataset_path"] == TINYSTORIES_DATASET
    assert saved_cfg["n_checkpoints"] == 1
    assert saved_cfg["training_tokens"] == 128
    assert saved_cfg["train_batch_size_tokens"] == 4
    assert saved_cfg["store_batch_size_prompts"] == 4
