diff --git a/docs/training_saes.md b/docs/training_saes.md
index 2b061445..2fa60086 100644
--- a/docs/training_saes.md
+++ b/docs/training_saes.md
@@ -92,6 +92,26 @@ sparse_autoencoder = SAETrainingRunner(cfg).run()
 
 As you can see, the training setup provides a large number of options to explore. The full list of options can be found in the [LanguageModelSAERunnerConfig][sae_lens.LanguageModelSAERunnerConfig] class.
 
+## CLI Runner
+
+The SAE training runner can also be run from the command line via the `sae_lens.sae_training_runner` module. This can be useful for quickly testing different hyperparameters or running training on a remote server. The CLI options mirror the fields of the [LanguageModelSAERunnerConfig][sae_lens.LanguageModelSAERunnerConfig], each with a `--` prefix: e.g., `--model_name` sets `model_name` in the config. Run the following to see the full list of options:
+
+```bash
+python -m sae_lens.sae_training_runner --help
+```
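+
+For example, a hypothetical invocation might look like the following (the model name, dataset path, and sizes here are placeholders; substitute values for your own run):
+
+```bash
+python -m sae_lens.sae_training_runner \
+    --model_name test-model \
+    --dataset_path my/dataset \
+    --d_in 1024 \
+    --d_sae 4096 \
+    --training_tokens 1000000 \
+    --log_to_wandb False
+```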
+
 ## Logging to Weights and Biases
 
 For any real training run, you should be logging to Weights and Biases (WandB). This will allow you to track your training progress and compare different runs. To enable WandB, set `log_to_wandb=True`. The `wandb_project` parameter in the config controls the project name in WandB. You can also control the logging frequency with `wandb_log_frequency` and `eval_every_n_wandb_logs`.
diff --git a/pyproject.toml b/pyproject.toml
index 34c0f0ab..f1a46da0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ pyyaml = "^6.0.1"
 pytest-profiling = "^1.7.0"
 zstandard = "^0.22.0"
 typing-extensions = "^4.10.0"
+simple-parsing = "^0.1.6"
 
 
 [tool.poetry.group.dev.dependencies]
diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py
index 26c90db6..f951a79e 100644
--- a/sae_lens/sae_training_runner.py
+++ b/sae_lens/sae_training_runner.py
@@ -2,11 +2,13 @@
 import logging
 import os
 import signal
-from typing import Any, cast
+import sys
+from typing import Any, Sequence, cast
 
 import torch
 import wandb
 from safetensors.torch import save_file
+from simple_parsing import ArgumentParser
 from transformer_lens.hook_points import HookedRootModule
 
 from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig
@@ -235,3 +237,21 @@ def save_checkpoint(
         wandb.log_artifact(sparsity_artifact)  # type: ignore
 
     return checkpoint_path
+
+
+def _parse_cfg_args(args: Sequence[str]) -> LanguageModelSAERunnerConfig:
+    if len(args) == 0:
+        args = ["--help"]
+    parser = ArgumentParser()
+    parser.add_arguments(LanguageModelSAERunnerConfig, dest="cfg")
+    return parser.parse_args(args).cfg
+
+
+# moved into its own function to make it easier to test
+def _run_cli(args: Sequence[str]):
+    cfg = _parse_cfg_args(args)
+    SAETrainingRunner(cfg=cfg).run()
+
+
+if __name__ == "__main__":
+    _run_cli(args=sys.argv[1:])
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 5082ecc3..12ec89ef 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -8,6 +8,8 @@
 
 from tests.unit.helpers import TINYSTORIES_MODEL, load_model_cached
 
+torch.set_grad_enabled(True)
+
 
 @pytest.fixture(autouse=True)
 def reproducibility():
diff --git a/tests/unit/training/test_sae_training_runner.py b/tests/unit/training/test_sae_training_runner.py
index 3a6a4cdb..4987c456 100644
--- a/tests/unit/training/test_sae_training_runner.py
+++ b/tests/unit/training/test_sae_training_runner.py
@@ -1,3 +1,4 @@
+import json
 import os
 from pathlib import Path
 
@@ -8,11 +9,16 @@
 
 from sae_lens.config import LanguageModelSAERunnerConfig
 from sae_lens.sae import SAE
-from sae_lens.sae_training_runner import SAETrainingRunner
+from sae_lens.sae_training_runner import SAETrainingRunner, _parse_cfg_args, _run_cli
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.sae_trainer import SAETrainer
 from sae_lens.training.training_sae import TrainingSAE
-from tests.unit.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached
+from tests.unit.helpers import (
+    TINYSTORIES_DATASET,
+    TINYSTORIES_MODEL,
+    build_sae_cfg,
+    load_model_cached,
+)
 
 
 @pytest.fixture
@@ -105,3 +111,109 @@ def test_training_runner_works_with_from_pretrained_path(
     assert torch.allclose(orig_sae.W_enc, new_sae.W_enc)
     assert torch.allclose(orig_sae.b_enc, new_sae.b_enc)
     assert torch.allclose(orig_sae.b_dec, new_sae.b_dec)
+
+
+def test_parse_cfg_args_prints_help_if_no_args():
+    args = []
+    with pytest.raises(SystemExit):
+        _parse_cfg_args(args)
+
+
+def test_parse_cfg_args_override():
+    args = [
+        "--model_name",
+        "test-model",
+        "--d_in",
+        "1024",
+        "--d_sae",
+        "4096",
+        "--activation_fn",
+        "tanh-relu",
+        "--normalize_sae_decoder",
+        "False",
+        "--dataset_path",
+        "my/dataset",
+    ]
+    cfg = _parse_cfg_args(args)
+
+    assert cfg.model_name == "test-model"
+    assert cfg.d_in == 1024
+    assert cfg.d_sae == 4096
+    assert cfg.activation_fn == "tanh-relu"
+    assert cfg.normalize_sae_decoder is False
+    assert cfg.dataset_path == "my/dataset"
+
+
+def test_parse_cfg_args_expansion_factor():
+    # Test that we can't set both d_sae and expansion_factor
+    args = ["--d_sae", "1024", "--expansion_factor", "8"]
+    with pytest.raises(ValueError):
+        _parse_cfg_args(args)
+
+
+def test_parse_cfg_args_b_dec_init_method():
+    # Test validation of b_dec_init_method
+    args = ["--b_dec_init_method", "invalid"]
+    with pytest.raises(ValueError):
+        _parse_cfg_args(args)
+
+    valid_methods = ["geometric_median", "mean", "zeros"]
+    for method in valid_methods:
+        args = ["--b_dec_init_method", method]
+        cfg = _parse_cfg_args(args)
+        assert cfg.b_dec_init_method == method
+
+
+def test_run_cli_saves_config(tmp_path: Path):
+    # Set up args for a minimal training run
+    args = [
+        "--model_name",
+        TINYSTORIES_MODEL,
+        "--dataset_path",
+        TINYSTORIES_DATASET,
+        "--checkpoint_path",
+        str(tmp_path),
+        "--n_checkpoints",
+        "1",  # Save one checkpoint
+        "--training_tokens",
+        "128",
+        "--train_batch_size_tokens",
+        "4",
+        "--store_batch_size_prompts",
+        "4",
+        "--log_to_wandb",
+        "False",  # Don't log to wandb in test
+        "--d_in",
+        "64",  # Match the tinystories model's hidden size
+        "--d_sae",
+        "128",  # Small SAE for test
+        "--activation_fn",
+        "relu",
+        "--normalize_sae_decoder",
+        "False",
+    ]
+
+    # Run training
+    _run_cli(args)
+
+    # Check that a single run directory with a single checkpoint was saved
+    run_dirs = list(tmp_path.glob("*"))
+    assert len(run_dirs) == 1
+    checkpoint_dirs = list(run_dirs[0].glob("*"))
+    assert len(checkpoint_dirs) == 1
+
+    # Load and verify saved config
+    with open(checkpoint_dirs[0] / "cfg.json") as f:
+        saved_cfg = json.load(f)
+
+    # Verify key config values were saved correctly
+    assert saved_cfg["model_name"] == TINYSTORIES_MODEL
+    assert saved_cfg["d_in"] == 64
+    assert saved_cfg["d_sae"] == 128
+    assert saved_cfg["activation_fn"] == "relu"
+    assert saved_cfg["normalize_sae_decoder"] is False
+    assert saved_cfg["dataset_path"] == TINYSTORIES_DATASET
+    assert saved_cfg["n_checkpoints"] == 1
+    assert saved_cfg["training_tokens"] == 128
+    assert saved_cfg["train_batch_size_tokens"] == 4
+    assert saved_cfg["store_batch_size_prompts"] == 4