diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4a473451..ba0a3d64 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -64,8 +64,8 @@ jobs: run: poetry run flake8 . - name: black code formatting run: poetry run black . --check - # - name: isort linting - # run: poetry run isort . --check-only --diff + - name: isort linting + run: poetry run isort . --check-only --diff - name: type checking run: poetry run pyright - name: Run Unit Tests diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a5ff3ed..c1e6e8c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,128 @@ +## v0.5.1 (2024-04-19) + +### Chore + +* chore: re-enabling isort in CI (#86) ([`9c44731`](https://github.com/jbloomAus/SAELens/commit/9c44731a9b7718c9f0913136ed9df42dac87c390)) + +### Fix + +* fix: pin pyzmq==26.0.1 temporarily ([`0094021`](https://github.com/jbloomAus/SAELens/commit/00940219754ddb1be6708e54cdd0ac6ed5dc3948)) + +* fix: typing issue, temporary ([`25cebf1`](https://github.com/jbloomAus/SAELens/commit/25cebf1e5e0630a377a5045c1b3571a5f181853f)) + +### Unknown + +* v0.5.1 ([`0ac218b`](https://github.com/jbloomAus/SAELens/commit/0ac218bf8068b8568310b40a1399f9eb3c8d992e)) + +* Merge pull request #91 from jbloomAus/decoder-fine-tuning + +Decoder fine tuning ([`1fc652c`](https://github.com/jbloomAus/SAELens/commit/1fc652c19e2a34172c1fd520565e9620366f565c)) + +* par update ([`2bb5975`](https://github.com/jbloomAus/SAELens/commit/2bb5975226807d352f2d3cf6b6dad7aefaf1b662)) + +* Merge pull request #89 from jbloomAus/fix_np + +Enhance + Fix Neuronpedia generation / upload ([`38d507c`](https://github.com/jbloomAus/SAELens/commit/38d507c052875cbd78f8fd9dae45a658e47c2b9d)) + +* minor changes ([`bc766e4`](https://github.com/jbloomAus/SAELens/commit/bc766e4f7a8d472f647408b8a5cd3c6140d856b7)) + +* reformat run.ipynb ([`822882c`](https://github.com/jbloomAus/SAELens/commit/822882cac9c05449b7237b7d42ce17297903da2f)) + +* get decoder fine tuning working ([`11a71e1`](https://github.com/jbloomAus/SAELens/commit/11a71e1b95576ef6dc3dbec7eb1c76ce7ca44dfd)) + +* format ([`040676d`](https://github.com/jbloomAus/SAELens/commit/040676db6814c1f64171a32344e0bed40528c8f9)) + +* Merge pull request #88 from jbloomAus/get_feature_from_neuronpedia + +FEAT: Add API for getting Neuronpedia feature ([`1666a68`](https://github.com/jbloomAus/SAELens/commit/1666a68bb7d7ee4837e03d95b203ee371ca9ea9e)) + +* Fix resuming from batch ([`145a407`](https://github.com/jbloomAus/SAELens/commit/145a407f8d57755301bf56c87efd4e775c59b980)) + +* Use original repo for sae_vis ([`1a7d636`](https://github.com/jbloomAus/SAELens/commit/1a7d636c95ef508dde8bd100ab6d9f241b0be977)) + +* Use correct model name for np runner ([`138d5d4`](https://github.com/jbloomAus/SAELens/commit/138d5d445878c0830c6c96a5fbe6b10a1d9644b0)) + +* Merge main, remove eindex ([`6578436`](https://github.com/jbloomAus/SAELens/commit/6578436891e71a0ef60fb2ed6d6a6b6279d71cbc)) + +* Add API for getting Neuronpedia feature ([`e78207d`](https://github.com/jbloomAus/SAELens/commit/e78207d086cb3372dc805cbb4c87b694749cd905)) + + +## v0.5.0 (2024-04-17) + +### Feature + +* feat: Mamba support vs mamba-lens (#79) + +* mamba support + +* added init + +* added optional model kwargs + +* Support transformers and mamba + +* forgot one model kwargs + +* failed opts + +* tokens input + +* hack to fix tokens, will look into fixing mambalens + +* fixed checkpoint + +* added sae group + +* removed some comments and fixed merge error + +* removed unneeded params since 
that issue is fixed in mambalens now + +* Unneded input param + +* removed debug checkpoing and eval + +* added refs to hookedrootmodule + +* feed linter + +* added example and fixed loading + +* made layer for eval change + +* fix linter issues + +* adding mamba-lens as optional dep, and fixing typing/linting + +* adding a test for loading mamba model + +* adding mamba-lens to dev for CI + +* updating min mamba-lens version + +* updating mamba-lens version + +--------- + +Co-authored-by: David Chanin <chanindav@gmail.com> ([`eea7db4`](https://github.com/jbloomAus/SAELens/commit/eea7db4b99098c33cd862e7e2280a32b630826bd)) + +### Unknown + +* update readme ([`440df7b`](https://github.com/jbloomAus/SAELens/commit/440df7b6c0ef55ba3d116054f81e1ee4a58f9089)) + +* update readme ([`3694fd2`](https://github.com/jbloomAus/SAELens/commit/3694fd2c4cc7438121e4549636508c45835a5d38)) + +* Fix upload skipped/dead features ([`932f380`](https://github.com/jbloomAus/SAELens/commit/932f380971ce3d431e6592c804d12f6df2b4ec78)) + +* Use python typer instead of shell script for neuronpedia jobs ([`b611e72`](https://github.com/jbloomAus/SAELens/commit/b611e721dd2620ab5a030cc0f6e37029c30711ca)) + +* Merge branch 'main' into fix_np ([`cc6cb6a`](https://github.com/jbloomAus/SAELens/commit/cc6cb6a96b793e41fb91f4ebbaf3bfa5e7c11b4e)) + +* convert sparsity to log sparsity if needed ([`8d7d404`](https://github.com/jbloomAus/SAELens/commit/8d7d4040033fb80c5b994cdc662b0f90b8fcc7aa)) + + ## v0.4.0 (2024-04-16) ### Feature @@ -22,8 +144,26 @@ * default orthogonal init false ([`a8b0113`](https://github.com/jbloomAus/SAELens/commit/a8b0113140bd2f9b97befccc8f158dace02a4810)) +* Formatting ([`1e3d53e`](https://github.com/jbloomAus/SAELens/commit/1e3d53ec2b72897bfebb6065f3b530fe65d3a97c)) + +* Eindex required by sae_vis ([`f769e7a`](https://github.com/jbloomAus/SAELens/commit/f769e7a65ab84d4073852931a86ff3b5076eea3c)) + +* Upload dead feature stubs ([`9067380`](https://github.com/jbloomAus/SAELens/commit/9067380bf67b89d8b2d235944f696016286f683e)) + +* Make feature sparsity an argument ([`8230570`](https://github.com/jbloomAus/SAELens/commit/8230570297d68e35cb614a63abf442e4a01174d2)) + +* Fix buffer" ([`dde2481`](https://github.com/jbloomAus/SAELens/commit/dde248162b70ff4311d4182333b7cce43aed78df)) + +* Merge branch 'main' into fix_np ([`6658392`](https://github.com/jbloomAus/SAELens/commit/66583923cd625bfc1c1ef152bc5f5beaa764b2d6)) + * notebook update ([`feca408`](https://github.com/jbloomAus/SAELens/commit/feca408cf003737cd4eb529ca7fea2f77984f5c6)) +* Merge branch 'main' into fix_np ([`f8fb3ef`](https://github.com/jbloomAus/SAELens/commit/f8fb3efbde7fc79e6fafe2d9b3324c9f0b2a337d)) + +* Final fixes ([`e87788d`](https://github.com/jbloomAus/SAELens/commit/e87788d63a9b767e34e497c85a318337ab8aabb8)) + +* Don't use buffer, fix anomalies ([`2c9ca64`](https://github.com/jbloomAus/SAELens/commit/2c9ca642b334b7a444544a4640c483229dc04c62)) + ## v0.3.0 (2024-04-15) @@ -44,6 +184,10 @@ * make dense_batch_mse_normalization optional ([`c41774e`](https://github.com/jbloomAus/SAELens/commit/c41774e5cfaeb195e3320e9e3fc93d60d921337d)) +* Runner is fixed, faster, cleaned up, and now gives whole sequences instead of buffer. 
([`3837884`](https://github.com/jbloomAus/SAELens/commit/383788485917cee114fba24e8ded944aefcfb568)) + +* Merge branch 'main' into fix_np ([`3ed30cf`](https://github.com/jbloomAus/SAELens/commit/3ed30cf2b84a2444c8ed030641214f0dbb65898a)) + * add warning in run script ([`9a772ca`](https://github.com/jbloomAus/SAELens/commit/9a772ca6da155b5e97bc3109da74457f5addfbfd)) * update sae loading code ([`356a8ef`](https://github.com/jbloomAus/SAELens/commit/356a8efba06e4f453d2f15afe9171b71d780819a)) @@ -74,6 +218,10 @@ ### Unknown +* Use legacy loader, add back histograms, logits. Fix anomaly characters. ([`ebbb622`](https://github.com/jbloomAus/SAELens/commit/ebbb622353bef21c953f844a108ea8d9fe31e9f9)) + +* Merge branch 'main' into fix_np ([`586e088`](https://github.com/jbloomAus/SAELens/commit/586e0881e08a9b013e2d4101878ef054c1f3dd8b)) + * Merge pull request #80 from wllgrnt/will-update-tutorial bugfix - minimum viable updates to tutorial notebook ([`e51016b`](https://github.com/jbloomAus/SAELens/commit/e51016b01f3b0f30c83365c54430908779671d87)) @@ -118,6 +266,8 @@ Fix artifact saving loading ([`8784c74`](https://github.com/jbloomAus/SAELens/co * add safetensors to project ([`0da48b0`](https://github.com/jbloomAus/SAELens/commit/0da48b044357eed17e5afffd3ce541e064185043)) +* Don't precompute background colors and tick values ([`271dbf0`](https://github.com/jbloomAus/SAELens/commit/271dbf05567b6e6ae4cfc1dab138132872038381)) + * Merge pull request #71 from weissercn/main Addressing notebook issues ([`8417505`](https://github.com/jbloomAus/SAELens/commit/84175055ba5876b335cbc0de38bf709d0b11cec1)) @@ -126,6 +276,8 @@ Addressing notebook issues ([`8417505`](https://github.com/jbloomAus/SAELens/com chore: updating README.md with pip install instructions and PyPI badge ([`4d7d1e7`](https://github.com/jbloomAus/SAELens/commit/4d7d1e7db5e952c7e9accf19c0ccce466cdcf6cf)) +* FIX: Add back correlated neurons, frac_nonzero ([`d532b82`](https://github.com/jbloomAus/SAELens/commit/d532b828bd77c18b73f495d6b42ca53b5148fd2f)) + * linting ([`1db0b5a`](https://github.com/jbloomAus/SAELens/commit/1db0b5ae7e091822c72bba0488d30fc16bc9a1c6)) * fixed graph name ([`ace4813`](https://github.com/jbloomAus/SAELens/commit/ace481322103737de2e80d688683d0c937ac5558)) diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/docs/training_saes.md b/docs/training_saes.md index 632eefbd..3d6a3ca8 100644 --- a/docs/training_saes.md +++ b/docs/training_saes.md @@ -49,7 +49,7 @@ cfg = LanguageModelSAERunnerConfig( # Activation Store Parameters n_batches_in_buffer = 128, - total_training_tokens = 1_000_000 * 300, + training_tokens = 1_000_000 * 300, store_batch_size = 32, # Dead Neurons and Sparsity @@ -60,7 +60,7 @@ cfg = LanguageModelSAERunnerConfig( # WANDB log_to_wandb = True, - wandb_project= "mats_sae_training_gpt2", + wandb_project= "gpt2", wandb_entity = None, wandb_log_frequency=100, diff --git a/pyproject.toml b/pyproject.toml index 373513f0..d7c283e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sae-lens" -version = "0.4.0" +version = "0.5.1" description = "Training and Analyzing Sparse Autoencoders (SAEs)" authors = ["Joseph Bloom"] readme = "README.md" @@ -20,7 +20,7 @@ matplotlib-inline = "^0.1.6" datasets = "^2.17.1" babe = "^0.0.7" nltk = "^3.8.1" -sae-vis = "0.2.6" +sae-vis = "^0.2.15" mkdocs = "^1.5.3" mkdocs-material = "^9.5.15" mkdocs-autorefs = "^1.0.1" @@ -28,6 +28,9 @@ mkdocs-section-index = "^0.3.8" mkdocstrings = "^0.24.1" 
mkdocstrings-python = "^1.9.0" safetensors = "^0.4.2" +typer = "^0.12.3" +mamba-lens = { version = "^0.0.4", optional = true } +pyzmq = "26.0.0" [tool.poetry.group.dev.dependencies] @@ -36,12 +39,17 @@ pytest = "^8.0.2" pytest-cov = "^4.1.0" pre-commit = "^3.6.2" flake8 = "^7.0.0" -isort = "^5.13.2" +isort = "5.13.2" pyright = "^1.1.351" +mamba-lens = "^0.0.4" + +[tool.poetry.extras] +mamba = ["mamba-lens"] [tool.isort] profile = "black" +src_paths = ["sae_lens", "tests"] [tool.pyright] typeCheckingMode = "strict" diff --git a/sae_lens/__init__.py b/sae_lens/__init__.py index 86d4e514..e5dff1f7 100644 --- a/sae_lens/__init__.py +++ b/sae_lens/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.0" +__version__ = "0.5.1" from .training.activations_store import ActivationsStore from .training.cache_activations_runner import cache_activations_runner diff --git a/sae_lens/analysis/dashboard_runner.py b/sae_lens/analysis/dashboard_runner.py index ebc25fc7..150b0a52 100644 --- a/sae_lens/analysis/dashboard_runner.py +++ b/sae_lens/analysis/dashboard_runner.py @@ -11,6 +11,7 @@ import plotly import plotly.express as px import torch +import wandb from sae_vis.data_config_classes import ( ActsHistogramConfig, Column, @@ -22,13 +23,15 @@ from sae_vis.data_fetching_fns import get_feature_data from torch.nn.functional import cosine_similarity from tqdm import tqdm +from transformer_lens import HookedTransformer -import wandb from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader class DashboardRunner: + model: HookedTransformer | None = None + def __init__( self, sae_path: Optional[str] = None, @@ -131,10 +134,14 @@ def get_dashboard_folder_name(self): def init_sae_session(self): ( - self.model, + model, sae_group, self.activation_store, ) = LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path) + assert isinstance( + model, HookedTransformer + ) # only HookedTransformer is allowed to be used in the dashboard + self.model = model # TODO: handle multiple autoencoders self.sparse_autoencoder = next(iter(sae_group))[1] @@ -316,6 +323,7 @@ def run(self): if self.use_wandb: wandb.log({"time/time_to_get_tokens": end - start}) + assert self.model is not None vocab_dict = cast(Any, self.model.tokenizer).vocab vocab_dict = { v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items() diff --git a/sae_lens/analysis/neuronpedia_integration.py b/sae_lens/analysis/neuronpedia_integration.py index 84f628ad..512bac32 100644 --- a/sae_lens/analysis/neuronpedia_integration.py +++ b/sae_lens/analysis/neuronpedia_integration.py @@ -2,6 +2,8 @@ import urllib.parse import webbrowser +import requests + def get_neuronpedia_quick_list( features: list[int], @@ -14,10 +16,29 @@ def get_neuronpedia_quick_list( name = urllib.parse.quote(name) url = url + "?name=" + name list_feature = [ - {"modelId": model, "layer": f"{layer}-{dataset}", "index": str(feature)} + { + "modelId": model, + "layer": f"{layer}-{dataset}", + "index": str(feature), + } for feature in features ] url = url + "&features=" + urllib.parse.quote(json.dumps(list_feature)) webbrowser.open(url) return url + + +def get_neuronpedia_feature( + feature: int, + layer: int, + model: str = "gpt2-small", + dataset: str = "res-jb", +): + url = "https://neuronpedia.org/api/feature/" + url = url + f"{model}/{layer}-{dataset}/{feature}" + + result = requests.get(url).json() + result["index"] = int(result["index"]) + + return result diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py 
index 96475d64..a34e970a 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -4,7 +4,6 @@ # set TOKENIZERS_PARALLELISM to false to avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" import json -import time import numpy as np import torch @@ -13,14 +12,19 @@ ActsHistogramConfig, Column, FeatureTablesConfig, + LogitsHistogramConfig, + LogitsTableConfig, SaeVisConfig, SaeVisLayoutConfig, SequencesConfig, ) -from sae_vis.data_fetching_fns import get_feature_data +from sae_vis.data_storing_fns import SaeVisData from tqdm import tqdm +from transformer_lens import HookedTransformer +from sae_lens.toolkit.pretrained_saes import load_sparsity from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader +from sae_lens.training.sparse_autoencoder import SparseAutoencoder OUT_OF_RANGE_TOKEN = "<|outofrange|>" @@ -28,6 +32,21 @@ "bg_color_map", ["white", "darkorange"] ) +DEFAULT_SPARSITY_THRESHOLD = -5 + +HTML_ANOMALIES = { + "âĢĶ": "—", + "âĢĵ": "–", + "âĢľ": "“", + "âĢĿ": "”", + "âĢĺ": "‘", + "âĢĻ": "’", + "âĢĭ": " ", # todo: this is actually zero width space + "Ġ": " ", + "Ċ": "\n", + "ĉ": "\t", +} + class NpEncoder(json.JSONEncoder): def default(self, o: Any): @@ -44,62 +63,49 @@ class NeuronpediaRunner: def __init__( self, + sae_id: str, sae_path: str, - feature_sparsity_path: Optional[str] = None, - neuronpedia_parent_folder: str = "./neuronpedia_outputs", - init_session: bool = True, + outputs_dir: str, + sparsity_threshold: int = DEFAULT_SPARSITY_THRESHOLD, # token pars n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6, - # sampling pars - n_features_at_a_time: int = 1024, - buffer_tokens_left: int = 8, - buffer_tokens_right: int = 8, - # start and end batch + # batching + n_features_at_a_time: int = 128, start_batch_inclusive: int = 0, end_batch_inclusive: Optional[int] = None, ): - self.sae_path = sae_path - if init_session: - self.init_sae_session() - self.feature_sparsity_path = feature_sparsity_path + self.device = "cpu" + if torch.backends.mps.is_available(): + self.device = "mps" + elif torch.cuda.is_available(): + self.device = "cuda" + + self.sae_path = sae_path + self.sparse_autoencoder = SparseAutoencoder.load_from_pretrained( + self.sae_path, device=self.device + ) + loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg) + self.model, _, self.activation_store = loader.load_sae_training_group_session() + self.model_id = self.sparse_autoencoder.cfg.model_name + self.layer = self.sparse_autoencoder.cfg.hook_point_layer + self.sae_id = sae_id + self.sparsity_threshold = sparsity_threshold self.n_features_at_a_time = n_features_at_a_time - self.buffer_tokens_left = buffer_tokens_left - self.buffer_tokens_right = buffer_tokens_right self.n_batches_to_sample_from = n_batches_to_sample_from self.n_prompts_to_select = n_prompts_to_select self.start_batch = start_batch_inclusive self.end_batch = end_batch_inclusive - # Deal with file structure - if not os.path.exists(neuronpedia_parent_folder): - os.makedirs(neuronpedia_parent_folder) - self.neuronpedia_folder = ( - f"{neuronpedia_parent_folder}/{self.get_folder_name()}" - ) - if not os.path.exists(self.neuronpedia_folder): - os.makedirs(self.neuronpedia_folder) - - def get_folder_name(self): - model = self.sparse_autoencoder.cfg.model_name - hook_point = self.sparse_autoencoder.cfg.hook_point - d_sae = self.sparse_autoencoder.cfg.d_sae - dashboard_folder_name = f"{model}_{hook_point}_{d_sae}" - - return 
dashboard_folder_name - - def init_sae_session(self): - ( - self.model, - sae_group, - self.activation_store, - ) = LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path) - # TODO: handle multiple autoencoders - self.sparse_autoencoder = next(iter(sae_group))[1] + if not os.path.exists(outputs_dir): + os.makedirs(outputs_dir) + self.outputs_dir = outputs_dir def get_tokens( - self, n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6 + self, + n_batches_to_sample_from: int = 2**12, + n_prompts_to_select: int = 4096 * 6, ): all_tokens_list = [] pbar = tqdm(range(n_batches_to_sample_from)) @@ -118,11 +124,14 @@ def round_list(self, to_round: list[float]): return list(np.round(to_round, 3)) def to_str_tokens_safe( - self, vocab_dict: Dict[int, str], tokens: Union[int, List[int], torch.Tensor] + self, + vocab_dict: Dict[int, str], + tokens: Union[int, List[int], torch.Tensor], ): """ does to_str_tokens, except handles out of range """ + assert self.model is not None vocab_max_index = self.model.cfg.d_vocab - 1 # Deal with the int case separately if isinstance(tokens, int): @@ -144,28 +153,20 @@ def to_str_tokens_safe( return np.reshape(str_tokens, tokens.shape).tolist() def run(self): - """ - Generate the Neuronpedia outputs. - """ - - if self.model is None: - self.init_sae_session() - self.n_features = self.sparse_autoencoder.cfg.d_sae assert self.n_features is not None # if we have feature sparsity, then use it to only generate outputs for non-dead features self.target_feature_indexes: list[int] = [] - if self.feature_sparsity_path: - loaded = torch.load( - self.feature_sparsity_path, map_location=self.sparse_autoencoder.device - ) - self.target_feature_indexes = ( - (loaded > -5).nonzero(as_tuple=True)[0].tolist() - ) - else: - self.target_feature_indexes = list(range(self.n_features)) - print("No feat sparsity path specified - doing all indexes.") + sparsity = load_sparsity(self.sae_path) + # convert sparsity to logged sparsity if it's not + # TODO: standardize the sparsity file format + if len(sparsity) > 0 and sparsity[0] >= 0: + sparsity = torch.log10(sparsity + 1e-10) + sparsity = sparsity.to(self.device) + self.target_feature_indexes = ( + (sparsity > self.sparsity_threshold).nonzero(as_tuple=True)[0].tolist() + ) # divide into batches feature_idx = torch.tensor(self.target_feature_indexes) @@ -173,9 +174,9 @@ def run(self): feature_idx = np.array_split(feature_idx, n_subarrays) feature_idx = [x.tolist() for x in feature_idx] - print(f"==== Starting at batch: {self.start_batch}") - if self.end_batch is not None: - print(f"==== Ending at batch: {self.end_batch}") + # print(f"==== Starting Batch: {self.start_batch}") + # if self.end_batch is not None and self.end_batch != self.start_batch: + # print(f"==== Ending at Batch: {self.end_batch}") if self.start_batch > len(feature_idx) + 1: print( @@ -185,31 +186,41 @@ def run(self): # write dead into file so we can create them as dead in Neuronpedia skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes) - skipped_indexes_json = json.dumps({"skipped_indexes": list(skipped_indexes)}) - with open(f"{self.neuronpedia_folder}/skipped_indexes.json", "w") as f: + skipped_indexes_json = json.dumps( + { + "model_id": self.model_id, + "layer": str(self.layer), + "sae_id": self.sae_id, + "skipped_indexes": list(skipped_indexes), + } + ) + with open(f"{self.outputs_dir}/skipped_indexes.json", "w") as f: f.write(skipped_indexes_json) - print(f"Total features to run: 
{len(self.target_feature_indexes)}") - print(f"Total skipped: {len(skipped_indexes)}") - print(f"Total batches: {len(feature_idx)}") + tokens_file = f"{self.outputs_dir}/tokens_{self.n_batches_to_sample_from}_{self.n_prompts_to_select}.pt" + if os.path.isfile(tokens_file): + print("Tokens exist, loading them.") + tokens = torch.load(tokens_file) + else: + print("Tokens don't exist, making them.") + tokens = self.get_tokens( + self.n_batches_to_sample_from, self.n_prompts_to_select + ) + torch.save( + tokens, + tokens_file, + ) - print(f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}") - print(f"Hook Point: {self.sparse_autoencoder.cfg.hook_point}") - print(f"Writing files to: {self.neuronpedia_folder}") + vocab_dict = self.model.tokenizer.vocab + new_vocab_dict = {} + # Replace substrings in the keys of vocab_dict using HTML_ANOMALIES + for k, v in vocab_dict.items(): + modified_key = k + for anomaly in HTML_ANOMALIES: + modified_key = modified_key.replace(anomaly, HTML_ANOMALIES[anomaly]) + new_vocab_dict[v] = modified_key + vocab_dict = new_vocab_dict - # get tokens: - start = time.time() - tokens = self.get_tokens( - self.n_batches_to_sample_from, self.n_prompts_to_select - ) - end = time.time() - print(f"Time to get tokens: {end - start}") - - vocab_dict = cast(Any, self.model.tokenizer).vocab - vocab_dict = { - v: k.replace("Ġ", " ").replace("\n", "\\n").replace("Ċ", "\n") - for k, v in vocab_dict.items() - } # pad with blank tokens to the actual vocab size for i in range(len(vocab_dict), self.model.cfg.d_vocab): vocab_dict[i] = OUT_OF_RANGE_TOKEN @@ -226,44 +237,37 @@ def run(self): # print(f"Skipping batch - it's after end_batch: {feature_batch_count}") continue - print(f"Doing batch: {feature_batch_count}") + print(f"========== Running Batch #{feature_batch_count} ==========") layout = SaeVisLayoutConfig( columns=[ Column( SequencesConfig( stack_mode="stack-all", - buffer=( - self.buffer_tokens_left, - self.buffer_tokens_right, - ), - compute_buffer=False, - n_quantiles=10, + buffer=None, # type: ignore + compute_buffer=True, + n_quantiles=5, top_acts_group_size=20, quantile_group_size=5, ), - width=650, - ), - Column( ActsHistogramConfig(), - FeatureTablesConfig(n_rows=5), - width=500, - ), - ], - height=1000, + LogitsHistogramConfig(), + LogitsTableConfig(), + FeatureTablesConfig(n_rows=3), + ) + ] ) feature_vis_params = SaeVisConfig( hook_point=self.sparse_autoencoder.cfg.hook_point, - minibatch_size_features=256, + minibatch_size_features=128, minibatch_size_tokens=64, features=features_to_process, - verbose=False, + verbose=True, feature_centric_layout=layout, ) - - feature_data = get_feature_data( + feature_data = SaeVisData.create( encoder=self.sparse_autoencoder, # type: ignore - model=self.model, + model=cast(HookedTransformer, self.model), tokens=tokens, cfg=feature_vis_params, ) @@ -280,20 +284,6 @@ def run(self): feature.logits_table_data.bottom_logits ) - # TODO: don't precompute/store these. 
should do it on the frontend - max_value = max( - np.absolute(bottom10_logits).max(), - np.absolute(top10_logits).max(), - ) - neg_bg_values = self.round_list( - np.absolute(bottom10_logits) / max_value - ) - pos_bg_values = self.round_list( - np.absolute(top10_logits) / max_value - ) - feature_output["neg_bg_values"] = neg_bg_values - feature_output["pos_bg_values"] = pos_bg_values - if feature.feature_tables_data: feature_output["neuron_alignment_indices"] = ( feature.feature_tables_data.neuron_alignment_indices @@ -307,23 +297,21 @@ def run(self): feature_output["correlated_neurons_indices"] = ( feature.feature_tables_data.correlated_neurons_indices ) - # TODO: this value doesn't exist in the new output type, commenting out for now - # there is a cossim value though - is that what's needed? - # feature_output["correlated_neurons_l1"] = self.round_list( - # feature.feature_tables_data.correlated_neurons_l1 - # ) + feature_output["correlated_neurons_l1"] = self.round_list( + feature.feature_tables_data.correlated_neurons_cossim + ) feature_output["correlated_neurons_pearson"] = self.round_list( feature.feature_tables_data.correlated_neurons_pearson ) - # feature_output["correlated_features_indices"] = ( - # feature.feature_tables_data.correlated_features_indices - # ) - # feature_output["correlated_features_l1"] = self.round_list( - # feature.feature_tables_data.correlated_features_l1 - # ) - # feature_output["correlated_features_pearson"] = self.round_list( - # feature.feature_tables_data.correlated_features_pearson - # ) + feature_output["correlated_features_indices"] = ( + feature.feature_tables_data.correlated_features_indices + ) + feature_output["correlated_features_l1"] = self.round_list( + feature.feature_tables_data.correlated_features_cossim + ) + feature_output["correlated_features_pearson"] = self.round_list( + feature.feature_tables_data.correlated_features_pearson + ) feature_output["neg_str"] = self.to_str_tokens_safe( vocab_dict, feature.logits_table_data.bottom_token_ids @@ -334,30 +322,23 @@ def run(self): ) feature_output["pos_values"] = top10_logits - # TODO: don't know what this should be in the new version - # feature_output["frac_nonzero"] = ( - # feature.middle_plots_data.frac_nonzero - # ) + feature_output["frac_nonzero"] = ( + float( + feature.acts_histogram_data.title.split(" = ")[1].split( + "%" + )[0] + ) + / 100 + if feature.acts_histogram_data.title is not None + else 0 + ) freq_hist_data = feature.acts_histogram_data freq_bar_values = self.round_list(freq_hist_data.bar_values) feature_output["freq_hist_data_bar_values"] = freq_bar_values - feature_output["freq_hist_data_tick_vals"] = self.round_list( - freq_hist_data.tick_vals - ) - - # TODO: don't precompute/store these. 
should do it on the frontend - freq_bar_values_clipped = [ - (0.4 * max(freq_bar_values) + 0.6 * v) / max(freq_bar_values) - for v in freq_bar_values - ] - freq_bar_colors = [ - colors.rgb2hex(BG_COLOR_MAP(v)) for v in freq_bar_values_clipped - ] feature_output["freq_hist_data_bar_heights"] = self.round_list( freq_hist_data.bar_heights ) - feature_output["freq_bar_colors"] = freq_bar_colors logits_hist_data = feature.logits_histogram_data feature_output["logits_hist_data_bar_heights"] = self.round_list( @@ -366,11 +347,7 @@ def run(self): feature_output["logits_hist_data_bar_values"] = self.round_list( logits_hist_data.bar_values ) - feature_output["logits_hist_data_tick_vals"] = self.round_list( - logits_hist_data.tick_vals - ) - # TODO: check this feature_output["num_tokens_for_dashboard"] = ( self.n_prompts_to_select ) @@ -433,10 +410,19 @@ def run(self): features_outputs.append(feature_output) - json_object = json.dumps(features_outputs, cls=NpEncoder) + to_write = { + "model_id": self.model_id, + "layer": str(self.layer), + "sae_id": self.sae_id, + "features": features_outputs, + "n_batches_to_sample_from": self.n_batches_to_sample_from, + "n_prompts_to_select": self.n_prompts_to_select, + } + json_object = json.dumps(to_write, cls=NpEncoder) with open( - f"{self.neuronpedia_folder}/batch-{feature_batch_count}.json", "w" + f"{self.outputs_dir}/batch-{feature_batch_count}.json", + "w", ) as f: f.write(json_object) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index f5c33ebb..f0efa3ff 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -12,7 +12,7 @@ load_dataset, ) from torch.utils.data import DataLoader -from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookedRootModule from sae_lens.training.config import ( CacheActivationsRunnerConfig, @@ -28,7 +28,7 @@ class ActivationsStore: while training SAEs. 
""" - model: HookedTransformer + model: HookedRootModule dataset: HfDataset cached_activations_path: str | None tokens_column: Literal["tokens", "input_ids", "text"] @@ -39,7 +39,7 @@ class ActivationsStore: @classmethod def from_config( cls, - model: HookedTransformer, + model: HookedRootModule, cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig, dataset: HfDataset | None = None, ) -> "ActivationsStore": @@ -59,18 +59,19 @@ def from_config( context_size=cfg.context_size, d_in=cfg.d_in, n_batches_in_buffer=cfg.n_batches_in_buffer, - total_training_tokens=cfg.total_training_tokens, + total_training_tokens=cfg.training_tokens, store_batch_size=cfg.store_batch_size, train_batch_size=cfg.train_batch_size, prepend_bos=cfg.prepend_bos, device=cfg.device, dtype=cfg.dtype, cached_activations_path=cached_activations_path, + model_kwargs=cfg.model_kwargs, ) def __init__( self, - model: HookedTransformer, + model: HookedRootModule, dataset: HfDataset | str, hook_point: str, hook_point_layers: list[int], @@ -85,8 +86,12 @@ def __init__( device: str | torch.device, dtype: str | torch.dtype, cached_activations_path: str | None = None, + model_kwargs: dict[str, Any] | None = None, ): self.model = model + if model_kwargs is None: + model_kwargs = {} + self.model_kwargs = model_kwargs self.dataset = ( load_dataset(dataset, split="train", streaming=True) if isinstance(dataset, str) @@ -248,6 +253,7 @@ def get_activations(self, batch_tokens: torch.Tensor): names_filter=act_names, stop_at_layer=hook_point_max_layer + 1, prepend_bos=self.prepend_bos, + **self.model_kwargs, )[1] activations_list = [layerwise_activations[act_name] for act_name in act_names] if self.hook_point_head_index is not None: diff --git a/sae_lens/training/cache_activations_runner.py b/sae_lens/training/cache_activations_runner.py index 2c939317..db9b470f 100644 --- a/sae_lens/training/cache_activations_runner.py +++ b/sae_lens/training/cache_activations_runner.py @@ -3,16 +3,19 @@ import torch from tqdm import tqdm -from transformer_lens import HookedTransformer from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.config import CacheActivationsRunnerConfig +from sae_lens.training.load_model import load_model from sae_lens.training.utils import shuffle_activations_pairwise def cache_activations_runner(cfg: CacheActivationsRunnerConfig): - model = HookedTransformer.from_pretrained(cfg.model_name) - model.to(cfg.device) + model = load_model( + model_class_name=cfg.model_class_name, + model_name=cfg.model_name, + device=cfg.device, + ) activations_store = ActivationsStore.from_config( model, cfg, @@ -28,11 +31,11 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig): else: os.makedirs(activations_store.cached_activations_path) - print(f"Started caching {cfg.total_training_tokens} activations") + print(f"Started caching {cfg.training_tokens} activations") tokens_per_buffer = ( cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer ) - n_buffers = math.ceil(cfg.total_training_tokens / tokens_per_buffer) + n_buffers = math.ceil(cfg.training_tokens / tokens_per_buffer) # for i in tqdm(range(n_buffers), desc="Caching activations"): for i in range(n_buffers): buffer = activations_store.get_buffer(cfg.n_batches_in_buffer) diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index ae9fa00a..276a9c75 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -1,8 +1,7 @@ -from dataclasses import dataclass +from dataclasses import dataclass, 
field from typing import Any, Optional, cast import torch - import wandb DTYPE_MAP = { @@ -21,7 +20,9 @@ class LanguageModelSAERunnerConfig: # Data Generating Function (Model + Training Distibuion) model_name: str = "gelu-2l" + model_class_name: str = "HookedTransformer" hook_point: str = "blocks.{layer}.hook_mlp_out" + hook_point_eval: str = "blocks.{layer}.attn.pattern" hook_point_layer: int | list[int] = 0 hook_point_head_index: Optional[int] = None dataset_path: str = "NeelNanda/c4-tokenized-2b" @@ -45,7 +46,8 @@ class LanguageModelSAERunnerConfig: # Activation Store Parameters n_batches_in_buffer: int = 20 - total_training_tokens: int = 2_000_000 + training_tokens: int = 2_000_000 + finetuning_tokens: int = 0 store_batch_size: int = 32 train_batch_size: int = 4096 @@ -56,11 +58,20 @@ class LanguageModelSAERunnerConfig: prepend_bos: bool = True # Training Parameters + + ## Batch size + train_batch_size: int = 4096 + + ## Adam adam_beta1: float | list[float] = 0 adam_beta2: float | list[float] = 0.999 + + ## Loss Function mse_loss_normalization: Optional[str] = None l1_coefficient: float | list[float] = 1e-3 lp_norm: float | list[float] = 1 + + ## Learning Rate Schedule lr: float | list[float] = 3e-4 lr_scheduler_name: str | list[str] = ( "constant" # constant, cosineannealing, cosineannealingwarmrestarts @@ -71,7 +82,9 @@ class LanguageModelSAERunnerConfig: ) lr_decay_steps: int | list[int] = 0 n_restart_cycles: int | list[int] = 1 # used only for cosineannealingwarmrestarts - train_batch_size: int = 4096 + + ## FineTuning + finetuning_method: Optional[str] = None # scale, decoder or unrotated_decoder # Resampling protocol args use_ghost_grads: bool | list[bool] = ( @@ -93,6 +106,7 @@ class LanguageModelSAERunnerConfig: n_checkpoints: int = 0 checkpoint_path: str = "checkpoints" verbose: bool = True + model_kwargs: dict[str, Any] = field(default_factory=dict) def __post_init__(self): if self.use_cached_activations and self.cached_activations_path is None: @@ -110,7 +124,7 @@ def __post_init__(self): ) if self.run_name is None: - self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" + self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}" if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]: raise ValueError( @@ -128,6 +142,12 @@ def __post_init__(self): elif isinstance(self.dtype, str): self.dtype: torch.dtype = DTYPE_MAP[self.dtype] + # if we use decoder fine tuning, we can't be applying b_dec to the input + if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input): + raise ValueError( + "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False." 
+ ) + self.device: str | torch.device = torch.device(self.device) if self.lr_end is None: @@ -143,7 +163,7 @@ def __post_init__(self): if self.verbose: print( - f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" + f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}" ) # Print out some useful info: n_tokens_per_buffer = ( @@ -155,7 +175,9 @@ def __post_init__(self): f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}" ) - total_training_steps = self.total_training_tokens // self.train_batch_size + total_training_steps = ( + self.training_tokens + self.finetuning_tokens + ) // self.train_batch_size print(f"Total training steps: {total_training_steps}") total_wandb_updates = total_training_steps // self.wandb_log_frequency @@ -192,6 +214,7 @@ class CacheActivationsRunnerConfig: # Data Generating Function (Model + Training Distibuion) model_name: str = "gelu-2l" + model_class_name: str = "HookedTransformer" hook_point: str = "blocks.{layer}.hook_mlp_out" hook_point_layer: int | list[int] = 0 hook_point_head_index: Optional[int] = None @@ -207,7 +230,7 @@ class CacheActivationsRunnerConfig: # Activation Store Parameters n_batches_in_buffer: int = 20 - total_training_tokens: int = 2_000_000 + training_tokens: int = 2_000_000 store_batch_size: int = 32 train_batch_size: int = 4096 @@ -222,6 +245,7 @@ class CacheActivationsRunnerConfig: n_shuffles_with_last_section: int = 10 n_shuffles_in_entire_dir: int = 10 n_shuffles_final: int = 100 + model_kwargs: dict[str, Any] = field(default_factory=dict) def __post_init__(self): # Autofill cached_activations_path unless the user overrode it diff --git a/sae_lens/training/evals.py b/sae_lens/training/evals.py index 25a59e3f..9742512d 100644 --- a/sae_lens/training/evals.py +++ b/sae_lens/training/evals.py @@ -3,10 +3,9 @@ import pandas as pd import torch -from transformer_lens import HookedTransformer -from transformer_lens.utils import get_act_name - import wandb +from transformer_lens.hook_points import HookedRootModule + from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.sparse_autoencoder import SparseAutoencoder @@ -15,14 +14,16 @@ def run_evals( sparse_autoencoder: SparseAutoencoder, activation_store: ActivationsStore, - model: HookedTransformer, + model: HookedRootModule, n_training_steps: int, suffix: str = "", ) -> Mapping[str, Any]: hook_point = sparse_autoencoder.cfg.hook_point hook_point_layer = sparse_autoencoder.hook_point_layer hook_point_head_index = sparse_autoencoder.cfg.hook_point_head_index - + hook_point_eval = sparse_autoencoder.cfg.hook_point_eval.format( + layer=hook_point_layer + ) ### Evals eval_tokens = activation_store.get_batch_tokens() @@ -43,7 +44,8 @@ def run_evals( _, cache = model.run_with_cache( eval_tokens, prepend_bos=False, - names_filter=[get_act_name("pattern", hook_point_layer), hook_point], + names_filter=[hook_point_eval, hook_point], + **sparse_autoencoder.cfg.model_kwargs, ) has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"] @@ -62,7 +64,9 @@ def run_evals( l2_norm_in = torch.norm(original_act, dim=-1) l2_norm_out = torch.norm(sae_out, dim=-1) - l2_norm_ratio = l2_norm_out / l2_norm_in + l2_norm_in_for_div = l2_norm_in.clone() + l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1 + l2_norm_ratio = l2_norm_out / l2_norm_in_for_div metrics = { # l2 norms @@ -86,7 +90,7 @@ def run_evals( def recons_loss_batched( 
sparse_autoencoder: SparseAutoencoder, - model: HookedTransformer, + model: HookedRootModule, activation_store: ActivationsStore, n_batches: int = 100, ): @@ -115,11 +119,13 @@ def recons_loss_batched( @torch.no_grad() def get_recons_loss( sparse_autoencoder: SparseAutoencoder, - model: HookedTransformer, + model: HookedRootModule, batch_tokens: torch.Tensor, ): hook_point = sparse_autoencoder.cfg.hook_point - loss = model(batch_tokens, return_type="loss") + loss = model( + batch_tokens, return_type="loss", **sparse_autoencoder.cfg.model_kwargs + ) head_index = sparse_autoencoder.cfg.hook_point_head_index def standard_replacement_hook(activations: torch.Tensor, hook: Any): @@ -157,13 +163,20 @@ def single_head_replacement_hook(activations: torch.Tensor, hook: Any): batch_tokens, return_type="loss", fwd_hooks=[(hook_point, partial(replacement_hook))], + **sparse_autoencoder.cfg.model_kwargs, ) zero_abl_loss = model.run_with_hooks( - batch_tokens, return_type="loss", fwd_hooks=[(hook_point, zero_ablate_hook)] + batch_tokens, + return_type="loss", + fwd_hooks=[(hook_point, zero_ablate_hook)], + **sparse_autoencoder.cfg.model_kwargs, ) - score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss) + div_val = zero_abl_loss - loss + div_val[torch.abs(div_val) < 0.0001] = 1.0 + + score = (zero_abl_loss - recons_loss) / div_val return score, loss, recons_loss, zero_abl_loss diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py index c6a4ce79..4d7ea440 100644 --- a/sae_lens/training/lm_runner.py +++ b/sae_lens/training/lm_runner.py @@ -1,6 +1,7 @@ from typing import Any, cast import wandb + from sae_lens.training.config import LanguageModelSAERunnerConfig from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader diff --git a/sae_lens/training/load_model.py b/sae_lens/training/load_model.py new file mode 100644 index 00000000..6a41c9df --- /dev/null +++ b/sae_lens/training/load_model.py @@ -0,0 +1,26 @@ +from typing import Any, cast + +import torch +from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookedRootModule + + +def load_model( + model_class_name: str, model_name: str, device: str | torch.device | None = None +) -> HookedRootModule: + if model_class_name == "HookedTransformer": + return HookedTransformer.from_pretrained(model_name=model_name, device=device) + elif model_class_name == "HookedMamba": + try: + from mamba_lens import HookedMamba + except ImportError: + raise ValueError( + "mamba-lens must be installed to work with mamba models. 
This can be added with `pip install sae-lens[mamba]`" + ) + # HookedMamba has incorrect typing information, so we need to cast the type here + return cast( + HookedRootModule, + HookedMamba.from_pretrained(model_name, device=cast(Any, device)), + ) + else: + raise ValueError(f"Unknown model class: {model_class_name}") diff --git a/sae_lens/training/sae_group.py b/sae_lens/training/sae_group.py index fb31cd1e..15b651ba 100644 --- a/sae_lens/training/sae_group.py +++ b/sae_lens/training/sae_group.py @@ -129,7 +129,14 @@ def load_from_pretrained_legacy(cls, path: str) -> "SparseAutoencoderDictionary" # handle loading old autoencoders where before SAEGroup existed, where we just save a dict if isinstance(group, dict): cfg = group["cfg"] + # need to add this field to old configs + if not hasattr(cfg, "model_kwargs"): + cfg.model_kwargs = {} sparse_autoencoder = SparseAutoencoder(cfg=cfg) + # add dummy scaling factor to the state dict + group["state_dict"]["scaling_factor"] = torch.ones( + cfg.d_sae, dtype=cfg.dtype, device=cfg.device + ) sparse_autoencoder.load_state_dict(group["state_dict"]) group = cls(cfg) for key in group.autoencoders: @@ -194,10 +201,7 @@ def save_saes(self, path: str): autoencoder.save_model(f"{path}/{i}") def get_name(self): - - sae_name = ( - f"sae_group_{self.cfg.model_name}_{self.cfg.hook_point}_{self.cfg.d_sae}" - ) + sae_name = f"sae_group_{self.cfg.model_name.replace('/', '_')}_{self.cfg.hook_point}_{self.cfg.d_sae}" return sae_name def eval(self): diff --git a/sae_lens/training/session_loader.py b/sae_lens/training/session_loader.py index 2d632551..8bbfb67c 100644 --- a/sae_lens/training/session_loader.py +++ b/sae_lens/training/session_loader.py @@ -1,9 +1,10 @@ from typing import Tuple -from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookedRootModule from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.config import LanguageModelSAERunnerConfig +from sae_lens.training.load_model import load_model from sae_lens.training.sae_group import SparseAutoencoderDictionary @@ -20,7 +21,7 @@ def __init__(self, cfg: LanguageModelSAERunnerConfig): def load_sae_training_group_session( self, - ) -> Tuple[HookedTransformer, SparseAutoencoderDictionary, ActivationsStore]: + ) -> Tuple[HookedRootModule, SparseAutoencoderDictionary, ActivationsStore]: """ Loads a session for training a sparse autoencoder on a language model. """ @@ -39,7 +40,7 @@ def load_sae_training_group_session( @classmethod def load_pretrained_sae( cls, path: str, device: str = "cpu" - ) -> Tuple[HookedTransformer, SparseAutoencoderDictionary, ActivationsStore]: + ) -> Tuple[HookedRootModule, SparseAutoencoderDictionary, ActivationsStore]: """ Loads a session for analysing a pretrained sparse autoencoder. """ @@ -56,7 +57,7 @@ def load_pretrained_sae( return model, sparse_autoencoders, activations_loader - def get_model(self, model_name: str) -> HookedTransformer: + def get_model(self, model_name: str) -> HookedRootModule: """ Loads a model from transformer lens. 
@@ -65,9 +66,7 @@ def get_model(self, model_name: str) -> HookedTransformer: # Todo: add check that model_name is valid - model = HookedTransformer.from_pretrained( - model_name, - device=self.cfg.device, + model = load_model( + self.cfg.model_class_name, model_name, device=self.cfg.device ) - return model diff --git a/sae_lens/training/sparse_autoencoder.py b/sae_lens/training/sparse_autoencoder.py index 43900133..c292e09b 100644 --- a/sae_lens/training/sparse_autoencoder.py +++ b/sae_lens/training/sparse_autoencoder.py @@ -122,6 +122,11 @@ def tanh_relu(input: torch.Tensor) -> torch.Tensor: torch.zeros(self.d_in, dtype=self.dtype, device=self.device) ) + # scaling factor for fine-tuning (not to be used in initial training) + self.scaling_factor = nn.Parameter( + torch.ones(self.d_sae, dtype=self.dtype, device=self.device) + ) + self.hook_sae_in = HookPoint() self.hook_hidden_pre = HookPoint() self.hook_hidden_post = HookPoint() @@ -150,7 +155,8 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None) sae_out = self.hook_sae_out( einops.einsum( - feature_acts, + feature_acts + * self.scaling_factor, # need to make sure this handled when loading old models. self.W_dec, "... d_sae, d_sae d_in -> ... d_in", ) @@ -356,6 +362,14 @@ def load_from_pretrained(cls, path: str, device: str = "cpu"): with safe_open(weight_path, framework="pt", device=device) as f: # type: ignore for k in f.keys(): tensors[k] = f.get_tensor(k) + + # old saves may not have scaling factors. + if "scaling_factor" not in tensors: + assert isinstance(config.d_sae, int) + tensors["scaling_factor"] = torch.ones( + config.d_sae, dtype=config.dtype, device=config.device + ) + sae.load_state_dict(tensors) return sae diff --git a/sae_lens/training/toy_model_runner.py b/sae_lens/training/toy_model_runner.py index 29b436a6..83044c3c 100644 --- a/sae_lens/training/toy_model_runner.py +++ b/sae_lens/training/toy_model_runner.py @@ -3,8 +3,8 @@ import einops import torch - import wandb + from sae_lens.training.sparse_autoencoder import SparseAutoencoder from sae_lens.training.toy_models import Config as ToyConfig from sae_lens.training.toy_models import Model as ToyModel diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index 54e49c0a..726491b5 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -3,13 +3,13 @@ from typing import Any, cast import torch +import wandb from safetensors.torch import save_file from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import LRScheduler from tqdm import tqdm -from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookedRootModule -import wandb from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.evals import run_evals from sae_lens.training.geometric_median import compute_geometric_median @@ -17,6 +17,13 @@ from sae_lens.training.sae_group import SparseAutoencoderDictionary from sae_lens.training.sparse_autoencoder import SparseAutoencoder +# used to map between parameters which are updated during finetuning and the config str. 
+FINETUNING_PARAMETERS = { + "scale": ["scaling_factor"], + "decoder": ["scaling_factor", "W_dec", "b_dec"], + "unrotated_decoder": ["scaling_factor", "b_dec"], +} + def _log_feature_sparsity( feature_sparsity: torch.Tensor, eps: float = 1e-10 @@ -35,6 +42,7 @@ class SAETrainContext: n_frac_active_tokens: int optimizer: Optimizer scheduler: LRScheduler + finetuning: bool = False @property def feature_sparsity(self) -> torch.Tensor: @@ -44,6 +52,21 @@ def feature_sparsity(self) -> torch.Tensor: def log_feature_sparsity(self) -> torch.Tensor: return _log_feature_sparsity(self.feature_sparsity) + def begin_finetuning(self, sae: SparseAutoencoder): + + # finetuning method should be set in the config + # if not, then we don't finetune + if not isinstance(sae.cfg.finetuning_method, str): + return + + for name, param in sae.named_parameters(): + if name in FINETUNING_PARAMETERS[sae.cfg.finetuning_method]: + param.requires_grad = True + else: + param.requires_grad = False + + self.finetuning = True + @dataclass class TrainSAEGroupOutput: @@ -53,7 +76,7 @@ class TrainSAEGroupOutput: def train_sae_on_language_model( - model: HookedTransformer, + model: HookedRootModule, sae_group: SparseAutoencoderDictionary, activation_store: ActivationsStore, batch_size: int = 1024, @@ -79,7 +102,7 @@ def train_sae_on_language_model( def train_sae_group_on_language_model( - model: HookedTransformer, + model: HookedRootModule, sae_group: SparseAutoencoderDictionary, activation_store: ActivationsStore, batch_size: int = 1024, @@ -88,10 +111,13 @@ def train_sae_group_on_language_model( use_wandb: bool = False, wandb_log_frequency: int = 50, ) -> TrainSAEGroupOutput: - total_training_tokens = sae_group.cfg.total_training_tokens + total_training_tokens = ( + sae_group.cfg.training_tokens + sae_group.cfg.finetuning_tokens + ) total_training_steps = total_training_tokens // batch_size n_training_steps = 0 n_training_tokens = 0 + started_fine_tuning = False checkpoint_thresholds = [] if n_checkpoints > 0: @@ -180,6 +206,16 @@ def train_sae_group_on_language_model( ) pbar.update(batch_size) + ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already) + if (not started_fine_tuning) and ( + n_training_tokens > sae_group.cfg.training_tokens + ): + started_fine_tuning = True + for name, sparse_autoencoder in sae_group.autoencoders.items(): + ctx = train_contexts[name] + # this should turn grads on for the scaling factor and other parameters. 
+ ctx.begin_finetuning(sae_group.autoencoders[name]) + # save final sae group to checkpoints folder final_checkpoint = _save_checkpoint( sae_group, @@ -248,6 +284,12 @@ def _build_train_context( ) n_frac_active_tokens = 0 + # we don't train the scaling factor (initially) + # set requires grad to false for the scaling factor + for name, param in sae.named_parameters(): + if "scaling_factor" in name: + param.requires_grad = False + optimizer = Adam( sae.parameters(), lr=sae.cfg.lr, diff --git a/sae_lens/training/train_sae_on_toy_model.py b/sae_lens/training/train_sae_on_toy_model.py index 029e9a70..58bb7cda 100644 --- a/sae_lens/training/train_sae_on_toy_model.py +++ b/sae_lens/training/train_sae_on_toy_model.py @@ -1,10 +1,10 @@ from typing import Any, cast import torch +import wandb from torch.utils.data import DataLoader from tqdm import tqdm -import wandb from sae_lens.training.sparse_autoencoder import SparseAutoencoder diff --git a/scripts/run.ipynb b/scripts/run.ipynb index 8e6cc3a4..f13aac59 100644 --- a/scripts/run.ipynb +++ b/scripts/run.ipynb @@ -24,14 +24,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Using device: mps\n" + "Using device: cuda\n" ] } ], @@ -60,273 +60,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Gelu-2L\n", - "\n", - "An example of a toy language model we're able to train on." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### MLP Out" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gelu-2l\",\n", - " hook_point=\"blocks.0.hook_mlp_out\",\n", - " hook_point_layer=0,\n", - " d_in=512,\n", - " dataset_path=\"NeelNanda/c4-tokenized-2b\",\n", - " is_dataset_tokenized=True,\n", - " # SAE Parameters\n", - " expansion_factor=[16, 32, 64],\n", - " b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n", - " # Training Parameters\n", - " lr=0.0012,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " l1_coefficient=0.00016,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 100,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " use_ghost_grads=True,\n", - " feature_sampling_window=5000,\n", - " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-4,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_models_gelu_2l_test\",\n", - " wandb_log_frequency=10,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## GPT2 - Small" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Residual Stream" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "layer = 3\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # 
Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gpt2-small\",\n", - " hook_point=f\"blocks.{layer}.hook_resid_pre\",\n", - " hook_point_layer=layer,\n", - " d_in=768,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=32, # determines the dimension of the SAE.\n", - " b_dec_init_method=\"mean\", # geometric median is better but slower to get started\n", - " # Training Parameters\n", - " lr=0.0004,\n", - " l1_coefficient=0.00008,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " lr_warm_up_steps=5000,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 300, # 200M tokens seems doable overnight.\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " use_ghost_grads=True,\n", - " feature_sampling_window=2500,\n", - " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-8,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_models_resid_pre_test\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=100,\n", - " # Misc\n", - " device=\"cuda\",\n", - " seed=42,\n", - " n_checkpoints=10,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pythia 70-M" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "import cProfile\n", - "\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"pythia-70m-deduped\",\n", - " hook_point=\"blocks.0.hook_mlp_out\",\n", - " hook_point_layer=0,\n", - " d_in=512,\n", - " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=64,\n", - " # Training Parameters\n", - " lr=3e-4,\n", - " l1_coefficient=4e-5,\n", - " train_batch_size=8192,\n", - " context_size=128,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " lr_warm_up_steps=10_000,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=64,\n", - " total_training_tokens=1_000_000 * 800,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " feature_sampling_window=2000, # Doesn't currently matter.\n", - " dead_feature_window=40000,\n", - " dead_feature_threshold=1e-8,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=20,\n", - " # Misc\n", - " device=\"cuda\",\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pythia 70M Hook Q" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - 
"import os\n", - "import sys\n", - "\n", - "sys.path.append(\"../\")\n", - "\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"pythia-70m-deduped\",\n", - " hook_point=\"blocks.2.attn.hook_q\",\n", - " hook_point_layer=2,\n", - " hook_point_head_index=7,\n", - " d_in=64,\n", - " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=16,\n", - " # Training Parameters\n", - " lr=0.0012,\n", - " l1_coefficient=0.003,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " lr_warm_up_steps=1000, # about 4 million tokens.\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 1500,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " feature_sampling_method=\"anthropic\",\n", - " feature_sampling_window=1000, # doesn't do anything currently.\n", - " feature_reinit_scale=0.2,\n", - " resample_batches=8,\n", - " dead_feature_window=60000,\n", - " dead_feature_threshold=1e-5,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_pythia_70M_hook_q_L2H7\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=100,\n", - " # Misc\n", - " device=\"mps\",\n", - " seed=42,\n", - " n_checkpoints=15,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Tiny Stories" + "# Tiny Stories - 1L" ] }, { @@ -338,49 +72,218 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.002048\n", + "Total training steps: 6103\n", + "Total wandb updates: 610\n", + "n_tokens_per_feature_sampling_window (millions): 524.288\n", + "n_tokens_per_dead_feature_window (millions): 524.288\n", + "We will reset the sparsity calculation 6 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n", + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n", + "Moving model to device: cuda\n", + "Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.002048\n", + "Total training steps: 6103\n", + "Total wandb updates: 610\n", + "n_tokens_per_feature_sampling_window (millions): 524.288\n", + "n_tokens_per_dead_feature_window (millions): 524.288\n", + "We will reset the sparsity calculation 6 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n", + "Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.002048\n", + "Total training steps: 6103\n", + "Total wandb updates: 610\n", + "n_tokens_per_feature_sampling_window (millions): 524.288\n", + "n_tokens_per_dead_feature_window (millions): 524.288\n", + "We will 
reset the sparsity calculation 6 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.6" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/paperspace/mats_sae_training/scripts/wandb/run-20240416_135218-opqs9dgl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/jbloom/sae_lens_tutorial" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/jbloom/sae_lens_tutorial/runs/opqs9dgl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Objective value: 1781464.6250: 4%|▍ | 4/100 [00:00<00:00, 206.25it/s]\n", + "/home/paperspace/mats_sae_training/sae_lens/training/sparse_autoencoder.py:176: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", + "135| MSE Loss 0.257 | L1 1.354: 1%| | 552960/50000000 [00:13<19:08, 43042.90it/s] /home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (v0kr8hz9) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n", + "6104| MSE Loss 0.072 | L1 0.024: : 25001984it [18:07, 22981.57it/s]\n", + "12208| MSE Loss 0.070 | L1 0.024: 100%|█████████▉| 49999872/50000000 [20:15<00:00, 30551.50it/s]" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bb94759b99e14133aece0058a423e305", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='128.448 MB of 128.448 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_learning_rate▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████████
details/n_training_tokens▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▄▄▅▆▆▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation▅▂▅▅█▆▆▇▅▅▆▅▄▅▅▄▅▃▄▄▄▄▃▆▆▄▁▄▆▃▆▃▅▆▂▃▆▄▃▅
metrics/ce_loss_with_sae█▅▅▄▃▃▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/ce_loss_without_sae▆▂▂▁▇▄▇▆▃▃▅▅▆▇▃▆▃▄▄▄█▂▂▄▄▃▁▅▄▄▂▅▄█▃▄▄▄▅▆
metrics/explained_variance▁▄▄▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/explained_variance_std▆▁▄▇███▇▇▆▆▆▆▆▆▆▆▅▅▅▅▅▄▅▄▄▅▄▄▄▄▄▄▄▄▄▄▄▄▄
metrics/l0█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm█▁▂▃▄▄▃▅▅▄▅▅▄▅▄▅▅▅▅▄▆▅▆▅▆▆▅▆▆▇█▇▇▆▆▇▆▆▇▇
metrics/l2_ratio█▁▂▃▃▃▃▄▄▃▄▄▄▄▄▄▄▄▄▄▅▅▅▅▆▆▅▆▆▆▆▆▆▅▆▆▅▆▆▆
metrics/mean_log10_feature_sparsity█▅▃▃▂▁▁▁▁▁▁▁
sparsity/below_1e-5▁▁▁▁▂▆███▅██
sparsity/below_1e-6▁▁▁▁▁▁█▇█▂▆▆
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██▅▅▁▁▁▁▅▁▁▁▁▁▅▅▁▁
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▂▂▂▃▂▄▄▅▇▇▇▇██▆▇▆███▆▅▅█▇▇▇▇██

Run summary:


details/current_learning_rate: 0.0008
details/n_training_tokens: 49971200
losses/ghost_grad_loss: 0.0
losses/l1_loss: 15.59199
losses/mse_loss: 0.07019
losses/overall_loss: 0.09358
metrics/CE_loss_score: 0.86351
metrics/ce_loss_with_ablation: 8.5168
metrics/ce_loss_with_sae: 3.00156
metrics/ce_loss_without_sae: 2.12988
metrics/explained_variance: 0.56934
metrics/explained_variance_std: 0.14386
metrics/l0: 19.32129
metrics/l2_norm: 15.93428
metrics/l2_ratio: 0.86545
metrics/mean_log10_feature_sparsity: -4.81775
sparsity/below_1e-5: 6329
sparsity/below_1e-6: 81
sparsity/dead_features: 0
sparsity/mean_passes_since_fired: 29.02307

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07 at: https://wandb.ai/jbloom/sae_lens_tutorial/runs/opqs9dgl
View project at: https://wandb.ai/jbloom/sae_lens_tutorial
Synced 7 W&B file(s), 0 media file(s), 3 artifact file(s) and 1 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240416_135218-opqs9dgl/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "12208| MSE Loss 0.070 | L1 0.024: : 50003968it [20:27, 30551.50it/s] /home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (opqs9dgl) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n" + ] + } + ], "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "if device == \"cpu\" and torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "cfg = LanguageModelSAERunnerConfig(\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"tiny-stories-1M\",\n", - " hook_point=\"blocks.1.mlp.hook_post\",\n", - " hook_point_layer=1,\n", - " d_in=256,\n", - " # dataset_path=\"roneneldan/TinyStories\",\n", - " # is_dataset_tokenized=False,\n", - " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", + " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + " hook_point=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", + " hook_point_layer=0, # Only one layer in the model.\n", + " d_in=1024, # the width of the mlp output.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", " is_dataset_tokenized=True,\n", " # SAE Parameters\n", - " expansion_factor=16,\n", + " mse_loss_normalization=None, # We won't normalize the mse loss,\n", + " expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training.\n", + " b_dec_init_method=\"geometric_median\", # The geometric median can be used to initialize the decoder weights.\n", + " apply_b_dec_to_input=False, # We won't apply the decoder to the input.\n", " # Training Parameters\n", - " lr=1e-4,\n", - " lp_norm=1.0,\n", - " l1_coefficient=2e-4,\n", + " lr=0.0008, # lower the better, we'll go fairly high to speed up the tutorial.\n", + " lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", + " lr_warm_up_steps=10000, # this can help avoid too many dead features initially.\n", + " l1_coefficient=0.0015, # will control how sparse the feature activations are\n", + " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", " train_batch_size=4096,\n", - " context_size=128,\n", + " context_size=128, # will control the lenght of the prompts we feed to the model. 
Larger is better but slower.\n", " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 20,\n", + " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", + " training_tokens=1_000_000\n", + " * 25, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", + " finetuning_method=\"decoder\",\n", + " finetuning_tokens=1_000_000 * 25,\n", " store_batch_size=32,\n", - " feature_sampling_window=500, # So we see the histograms.\n", - " dead_feature_window=250,\n", + " # Resampling protocol\n", + " use_ghost_grads=False,\n", + " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", + " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", + " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", + " log_to_wandb=True, # always use wandb unless you are just testing code.\n", + " wandb_project=\"sae_lens_tutorial\",\n", " wandb_log_frequency=10,\n", " # Misc\n", " device=device,\n", @@ -390,82 +293,145 @@ " dtype=torch.float32,\n", ")\n", "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" + "# look at the next cell to see some instruction for what to do while this is running.\n", + "sparse_autoencoder_dictionary = language_model_sae_runner(cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GPT2 - Small" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Hook Z\n", - "\n" + "### Residual Stream" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", + "Run name: 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08\n", + "n_tokens_per_buffer (millions): 1.048576\n", "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 4882\n", + "Total training steps: 48828\n", "Total wandb updates: 488\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 9 times.\n", - "Number tokens in sparsity calculation window: 2.05e+06\n", - "Loaded pretrained model tiny-stories-1M into HookedTransformer\n", - "Moving model to device: mps\n" + "n_tokens_per_feature_sampling_window (millions): 2621.44\n", + "n_tokens_per_dead_feature_window (millions): 5242.88\n", + "We will reset the sparsity calculation 19 times.\n", + "Number tokens in sparsity calculation window: 1.02e+07\n", + "Loaded pretrained model gpt2-small into HookedTransformer\n", + "Moving model to device: cuda\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fee8922d83f04003a2f1441eeb30200d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/73 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ea686292dff7449a9846fcfa29d6ff74", + 
"version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='0.064 MB of 0.064 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_learning_rate▁▁▂▂▃▄▄▅▅▆▆▇▇███████████████████████████
details/n_training_tokens▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▆▇▇▇█████████
metrics/ce_loss_with_ablation▄▂▁▄▄▂▄▄▁▄▁▄▂█
metrics/ce_loss_with_sae█▄▂▂▂▂▁▂▁▁▁▁▁▁
metrics/ce_loss_without_sae▂▇▄█▅▃▄▇▆▅▄▆▁▁
metrics/explained_variance▁▃▄▆▆▇▇▇▇▇▇▇████████████████████████████
metrics/explained_variance_std█▄█▇▅▄▄▄▃▃▂▂▂▂▂▂▂▂▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▅▆▇▇▇█▇██████
metrics/l2_ratio▁▅▆▇▇▇▇▇██████
metrics/mean_log10_feature_sparsity█▃▂▁▁
sparsity/below_1e-5▁▁▅██
sparsity/below_1e-6▁▁▁▄█
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃▄▅▆█
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▆▆▇▇█

Run summary:


details/current_learning_rate: 0.0004
details/n_training_tokens: 59801600
losses/ghost_grad_loss: 0.0
losses/l1_loss: 160.66861
losses/mse_loss: 1.68098
losses/overall_loss: 2.96633
metrics/CE_loss_score: 0.96258
metrics/ce_loss_with_ablation: 11.49633
metrics/ce_loss_with_sae: 3.62324
metrics/ce_loss_without_sae: 3.3166
metrics/explained_variance: 0.78709
metrics/explained_variance_std: 0.05978
metrics/l0: 50.03076
metrics/l2_norm: 102.32782
metrics/l2_ratio: 0.8864
metrics/mean_log10_feature_sparsity: -5.31744
sparsity/below_1e-5: 19194
sparsity/below_1e-6: 11736
sparsity/dead_features: 60
sparsity/mean_passes_since_fired: 640.44727

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "wandb version 0.16.5 is available! To upgrade, please run:\n", - " $ pip install wandb --upgrade" + " View run 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08 at: https://wandb.ai/jbloom/gpt2_small_experiments_april/runs/pq5q3x9s
View project at: https://wandb.ai/jbloom/gpt2_small_experiments_april
Synced 7 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)" ], "text/plain": [ "" @@ -477,7 +443,7 @@ { "data": { "text/html": [ - "Tracking run with wandb version 0.16.3" + "Find logs at: ./wandb/run-20240416_155117-pq5q3x9s/logs" ], "text/plain": [ "" @@ -489,7 +455,7 @@ { "data": { "text/html": [ - "Run data is saved locally in /Users/josephbloom/GithubRepositories/mats_sae_training/scripts/wandb/run-20240326_191703-ec6k6v87" + "Successfully finished last run (ID:pq5q3x9s). Initializing new run:
" ], "text/plain": [ "" @@ -498,10 +464,24 @@ "metadata": {}, "output_type": "display_data" }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fd02fd0295cc4afda9bb0e1367c87f84", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011112805799995032, max=1.0…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ - "Syncing run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 to Weights & Biases (docs)
" + "Tracking run with wandb version 0.16.6" ], "text/plain": [ "" @@ -513,7 +493,7 @@ { "data": { "text/html": [ - " View project at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests" + "Run data is saved locally in /home/paperspace/mats_sae_training/scripts/wandb/run-20240416_165827-vbwoyzi8" ], "text/plain": [ "" @@ -525,7 +505,7 @@ { "data": { "text/html": [ - " View run at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87" + "Syncing run 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08 to Weights & Biases (docs)
" ], "text/plain": [ "" @@ -535,91 +515,56 @@ "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Objective value: 116883.7422: 10%|█ | 10/100 [00:00<00:00, 128.72it/s]\n", - "/Users/josephbloom/GithubRepositories/mats_sae_training/sae_training/sparse_autoencoder.py:161: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", - "100%|██████████| 10/10 [00:02<00:00, 4.93it/s] 405504/20000000 [00:14<08:53, 36739.57it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 811008/20000000 [00:31<18:45, 17042.47it/s] \n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 1224704/20000000 [00:47<10:43, 29194.89it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s]| 1634304/20000000 [01:05<08:10, 37468.33it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.64it/s]| 2039808/20000000 [01:20<07:36, 39322.02it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.08it/s]| 2453504/20000000 [01:37<07:55, 36873.53it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 2863104/20000000 [01:52<07:16, 39292.24it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 3272704/20000000 [02:09<06:52, 40537.06it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 3678208/20000000 [02:26<27:40, 9829.56it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 4087808/20000000 [02:41<06:11, 42798.13it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.50it/s] | 4497408/20000000 [03:01<08:53, 29055.95it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 4911104/20000000 [03:16<06:55, 36330.89it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.57it/s] | 5316608/20000000 [03:34<06:31, 37461.30it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.87it/s] | 5726208/20000000 [03:50<05:45, 41309.20it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 6139904/20000000 [04:07<06:03, 38122.10it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s] | 6549504/20000000 [04:24<05:43, 39198.19it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.91it/s] | 6955008/20000000 [04:43<05:01, 43328.38it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.84it/s] | 7368704/20000000 [05:00<12:14, 17200.22it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 7778304/20000000 [05:14<04:44, 43005.09it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.78it/s] | 8183808/20000000 [05:32<06:31, 30153.11it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.80it/s] | 8597504/20000000 [05:47<04:22, 43375.86it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 5.00it/s] | 9007104/20000000 [06:09<05:16, 34784.52it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.55it/s] | 9416704/20000000 [06:24<04:36, 38252.78it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.75it/s] | 9822208/20000000 [06:42<03:58, 42593.01it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.99it/s] | 10235904/20000000 [06:59<19:05, 8524.91it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s] | 10645504/20000000 [07:14<03:30, 44384.65it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.89it/s] | 11055104/20000000 [07:31<05:24, 27562.66it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.83it/s] | 11464704/20000000 [07:45<03:26, 41316.56it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.81it/s] | 11870208/20000000 [08:02<03:44, 36217.25it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 
4.89it/s] | 12279808/20000000 [08:16<02:52, 44715.52it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.85it/s] | 12693504/20000000 [08:34<03:02, 40061.41it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.02it/s] | 13103104/20000000 [08:48<02:38, 43563.35it/s]\n", - "100%|██████████| 10/10 [00:04<00:00, 2.17it/s] | 13508608/20000000 [09:05<02:34, 41937.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 13922304/20000000 [09:24<05:07, 19779.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 14327808/20000000 [09:38<02:05, 45367.15it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 14741504/20000000 [09:54<02:49, 30943.53it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.05it/s] | 15147008/20000000 [10:08<01:46, 45610.98it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.06it/s] | 15556608/20000000 [10:24<01:49, 40440.85it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 15966208/20000000 [10:38<01:29, 45251.75it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16379904/20000000 [10:55<01:22, 43941.70it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16789504/20000000 [11:11<04:30, 11859.26it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 17195008/20000000 [11:25<01:02, 44607.68it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s] | 17608704/20000000 [11:41<01:38, 24188.35it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.00it/s] | 18018304/20000000 [11:54<00:42, 46425.69it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.06it/s]▏| 18423808/20000000 [12:13<00:44, 35420.18it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s]▍| 18837504/20000000 [12:27<00:26, 43914.73it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]▌| 19243008/20000000 [12:45<00:19, 38931.67it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.95it/s]▊| 19656704/20000000 [12:59<00:07, 43804.93it/s]\n", - "4883| MSE Loss 0.000 | L1 0.000: 100%|█████████▉| 19996672/20000000 [13:14<00:00, 37714.53it/s]" - ] + "data": { + "text/html": [ + " View project at https://wandb.ai/jbloom/gpt2_small_experiments_april" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "name": "stdout", + "data": { + "text/html": [ + " View run at https://wandb.ai/jbloom/gpt2_small_experiments_april/runs/vbwoyzi8" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", "output_type": "stream", "text": [ - "Saved model to checkpoints/sf7u2imk/final_sae_group_tiny-stories-1M_blocks.1.attn.hook_z_1024.pt\n" + "Objective value: 46608928.0000: 2%|▏ | 2/100 [00:00<00:01, 55.75it/s]\n", + "/home/paperspace/mats_sae_training/sae_lens/training/sparse_autoencoder.py:176: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", + "120| MSE Loss 31.151 | L1 65.750: 0%| | 487424/300000000 [00:15<1:28:10, 56617.16it/s]/home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (4elmsny3) is finished. The call to `_console_raw_callback` will be ignored. 
Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n", + "2407| MSE Loss 0.070 | L1 0.027: 20%|█▉ | 9859072/50000000 [3:33:05<14:27:36, 771.10it/s]\n", + "73243| MSE Loss 1.416 | L1 1.255: : 300003328it [2:44:02, 54947.70it/s] " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1dffd84a387d4cf48100fbe143287481", + "model_id": "6111ba99afb144ae82bab7723efb2c86", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "VBox(children=(Label(value='0.053 MB of 0.569 MB uploaded\\r'), FloatProgress(value=0.0935266880101429, max=1.0…" + "VBox(children=(Label(value='721.959 MB of 721.959 MB uploaded (0.005 MB deduped)\\r'), FloatProgress(value=1.0,…" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n" - ] - }, { "data": { "text/html": [ @@ -628,7 +573,7 @@ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n", " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n", " \n", - "

Run history:


details/current_learning_rate▁▃▅▆████████████████████████████████████
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss██▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▄▅▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation▂▃▂▅▃▆▅▃▄▆▇▆▅▇▅▄▇▅▁▆▄▅▆▄█▄▅▆▄▅▅▃▂▄▄▅▅█▆▆
metrics/ce_loss_with_sae█▅▄▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▁▂▂▁▁▂▁▂▁▂▂▂
metrics/ce_loss_without_sae▄▄▁▃▄▆▅▃▆▅█▆▅▆▅▄▅▆▁▇▆▅▆▃█▆▆▆▄▇▆▃▃▆▃▆▄█▇▅
metrics/explained_variance▁▅▇▇▇███████████████████████████████████
metrics/explained_variance_std██▆▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0██▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▄▆▆▇▇▇▆▆▆▇█▇▇▆▇▇▆▇▇▇▆▇████▇▇▇▇▇▇▇▇▇█▇▇▇
metrics/l2_ratio▁▃▁▂▄▃▂▄▆▅▅▅▅▆▅▆▇▆▆▇▇▆▆▆▇▆▆▇▆▇▆▇▇▇█▆▆▇▇▇
metrics/mean_log10_feature_sparsity█▇▅▄▃▃▂▁▁
sparsity/below_1e-5▁▁▁▁▁▁▁▁▁
sparsity/below_1e-6▁▁▁▁▁▁▁▁▁
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▂▂▄▇▄▃▅▆▄▇██

Run summary:


details/current_learning_rate: 0.0001
details/n_training_tokens: 19988480
losses/ghost_grad_loss: 0.0
losses/l1_loss: 1.41017
losses/mse_loss: 8e-05
losses/overall_loss: 0.00036
metrics/CE_loss_score: 0.98362
metrics/ce_loss_with_ablation: 5.49512
metrics/ce_loss_with_sae: 2.71813
metrics/ce_loss_without_sae: 2.67199
metrics/explained_variance: 0.98647
metrics/explained_variance_std: 0.00905
metrics/l0: 166.02246
metrics/l2_norm: 1.39317
metrics/l2_ratio: 0.99823
metrics/mean_log10_feature_sparsity: -1.53525
sparsity/below_1e-5: 0
sparsity/below_1e-6: 0
sparsity/dead_features: 0
sparsity/mean_passes_since_fired: 0.02051

" + "

Run history:


details/current_learning_rate▁▄▆█████████████████████████████████████
details/n_training_tokens▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▆▇▇████████████████████████████████████
metrics/ce_loss_with_ablation▅▃▄▃▅▄▄▃▄▄▂▂▂▃▄▃▂▆▃▆▃▆▃▆▅▁▁▅▂▃▃▄▃▅▄▃█▄▃▆
metrics/ce_loss_with_sae█▄▂▂▂▂▂▁▂▁▁▂▂▁▁▁▂▁▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/ce_loss_without_sae▁▅▆▂▆▄▅▁▄▄▄▇▆▄▃▄▇▄▂▆▆█▄▅▅▄▄▄▅▄▅▅▃▅▄▅▂▃▂▄
metrics/explained_variance▁▆▇▇▇▇▇█████████████████████████████████
metrics/explained_variance_std█▄▃▂▂▂▂▁▂▁▂▁▁▂▂▁▂▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▄▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▇▆▆▇▆▆▆▇▆█████████████
metrics/l2_ratio▁▄▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆█████████████
metrics/mean_log10_feature_sparsity█▅▄▄▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/below_1e-5▁▁▅██████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
sparsity/below_1e-6▁▁▁▃▅▆▇██████████████████████
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▂▂▃▄▅▆▆▇▇▇▇███████████████████
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███

Run summary:


details/current_learning_rate: 0.0004
details/n_training_tokens: 299827200
losses/ghost_grad_loss: 0.0
losses/l1_loss: 162.07342
losses/mse_loss: 1.42934
losses/overall_loss: 2.72593
metrics/CE_loss_score: 0.97257
metrics/ce_loss_with_ablation: 11.42603
metrics/ce_loss_with_sae: 3.61949
metrics/ce_loss_without_sae: 3.39944
metrics/explained_variance: 0.82112
metrics/explained_variance_std: 0.0526
metrics/l0: 50.53198
metrics/l2_norm: 108.35806
metrics/l2_ratio: 0.94604
metrics/mean_log10_feature_sparsity: -7.89094
sparsity/below_1e-5: 18079
sparsity/below_1e-6: 18075
sparsity/dead_features: 16912
sparsity/mean_passes_since_fired: 27024.85938

" ], "text/plain": [ "" @@ -640,7 +585,7 @@ { "data": { "text/html": [ - " View run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 at: https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87
Synced 7 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" + " View run 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08 at: https://wandb.ai/jbloom/gpt2_small_experiments_april/runs/vbwoyzi8
View project at: https://wandb.ai/jbloom/gpt2_small_experiments_april
Synced 7 W&B file(s), 0 media file(s), 15 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" @@ -652,7 +597,7 @@ { "data": { "text/html": [ - "Find logs at: ./wandb/run-20240326_191703-ec6k6v87/logs" + "Find logs at: ./wandb/run-20240416_165827-vbwoyzi8/logs" ], "text/plain": [ "" @@ -665,223 +610,52 @@ "name": "stderr", "output_type": "stream", "text": [ - "4883| MSE Loss 0.000 | L1 0.000: : 20000768it [13:29, 37714.53it/s] /Users/josephbloom/miniforge3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2171: UserWarning: Run (ec6k6v87) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + "73243| MSE Loss 1.416 | L1 1.255: : 300003328it [2:44:12, 54947.70it/s]/home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (vbwoyzi8) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", " lambda data: self._console_raw_callback(\"stderr\", data),\n" ] } ], "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "if device == \"cpu\" and torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"tiny-stories-1M\",\n", - " hook_point=\"blocks.1.attn.hook_z\",\n", - " hook_point_layer=1,\n", - " d_in=64,\n", - " # dataset_path=\"roneneldan/TinyStories\",\n", - " # is_dataset_tokenized=False,\n", - " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", - " is_dataset_tokenized=True,\n", - " # SAE Parameters\n", - " expansion_factor=16,\n", - " # Training Parameters\n", - " lr=1e-4,\n", - " lp_norm=1.0,\n", - " l1_coefficient=2e-4,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 20,\n", - " store_batch_size=32,\n", - " feature_sampling_window=500, # So we see the histograms.\n", - " dead_feature_window=250,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", - " wandb_log_frequency=10,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Toy Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sae_lens.training.toy_model_runner import (\n", - " SAEToyModelRunnerConfig,\n", - " toy_model_sae_runner,\n", - ")\n", - "\n", - "\n", - "cfg = SAEToyModelRunnerConfig(\n", - " # Model Details\n", - " n_features=200,\n", - " n_hidden=5,\n", - " n_correlated_pairs=0,\n", - " n_anticorrelated_pairs=0,\n", - " feature_probability=0.025,\n", - " model_training_steps=10_000,\n", - " # SAE Parameters\n", - " d_sae=240,\n", - " l1_coefficient=0.001,\n", - " # 
SAE Train Config\n", - " train_batch_size=1028,\n", - " feature_sampling_window=3_000,\n", - " dead_feature_window=1_000,\n", - " feature_reinit_scale=0.5,\n", - " total_training_tokens=4096 * 300,\n", - " # Other parameters\n", - " log_to_wandb=True,\n", - " wandb_project=\"sae-training-test\",\n", - " wandb_log_frequency=5,\n", - " device=\"mps\",\n", - ")\n", - "\n", - "trained_sae = toy_model_sae_runner(cfg)\n", - "\n", - "assert trained_sae is not None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Run caching of activations to disk" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", - "\n", - "from sae_lens.training.config import CacheActivationsRunnerConfig\n", - "from sae_lens.training.cache_activations_runner import cache_activations_runner\n", - "\n", - "cfg = CacheActivationsRunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gpt2-small\",\n", - " hook_point=\"blocks.10.attn.hook_q\",\n", - " hook_point_layer=10,\n", - " hook_point_head_index=7,\n", - " d_in=64,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " cached_activations_path=\"../activations/\",\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=16,\n", - " total_training_tokens=500_000_000,\n", - " store_batch_size=32,\n", - " # Activation caching shuffle parameters\n", - " n_shuffles_final=16,\n", - " # Misc\n", - " device=\"mps\",\n", - " seed=42,\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "cache_activations_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train an SAE using the cached activations stored on disk\n", - "Pass `use_cached_activations=True` into the config" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", "cfg = LanguageModelSAERunnerConfig(\n", " # Data Generating Function (Model + Training Distibuion)\n", " model_name=\"gpt2-small\",\n", - " hook_point=\"blocks.10.hook_resid_pre\",\n", - " hook_point_layer=11,\n", + " hook_point=\"blocks.8.hook_resid_pre\",\n", + " hook_point_layer=8,\n", " d_in=768,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " use_cached_activations=True,\n", + " dataset_path=\"apollo-research/Skylion007-openwebtext-tokenizer-gpt2\",\n", + " is_dataset_tokenized=True,\n", + " prepend_bos=True, # should experiment with turning this off.\n", " # SAE Parameters\n", - " expansion_factor=64, # determines the dimension of the SAE.\n", + " expansion_factor=32, # determines the dimension of the SAE.\n", + " b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n", + " apply_b_dec_to_input=False,\n", " # Training Parameters\n", - " lr=1e-5,\n", - " l1_coefficient=5e-4,\n", - " 
lr_scheduler_name=None,\n", + " adam_beta1=0,\n", + " adam_beta2=0.999,\n", + " lr=0.0004,\n", + " l1_coefficient=0.008,\n", + " lr_scheduler_name=\"constant\",\n", " train_batch_size=4096,\n", - " context_size=128,\n", + " context_size=256,\n", + " lr_warm_up_steps=5000,\n", " # Activation Store Parameters\n", - " n_batches_in_buffer=64,\n", - " total_training_tokens=200_000,\n", + " n_batches_in_buffer=128,\n", + " training_tokens=1_000_000 * 200, # 200M tokens seems doable overnight.\n", + " finetuning_method=\"decoder\",\n", + " finetuning_tokens=1_000_000 * 100,\n", " store_batch_size=32,\n", " # Resampling protocol\n", - " feature_sampling_method=\"l2\",\n", - " feature_sampling_window=1000,\n", - " feature_reinit_scale=0.2,\n", + " use_ghost_grads=False,\n", + " feature_sampling_window=2500,\n", " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-7,\n", + " dead_feature_threshold=1e-8,\n", " # WANDB\n", " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_gpt2_small\",\n", + " wandb_project=\"gpt2_small_experiments_april\",\n", " wandb_entity=None,\n", - " wandb_log_frequency=50,\n", + " wandb_log_frequency=100,\n", " # Misc\n", - " device=\"mps\",\n", + " device=device,\n", " seed=42,\n", " n_checkpoints=5,\n", " checkpoint_path=\"checkpoints\",\n", @@ -890,13 +664,6 @@ "\n", "sparse_autoencoder = language_model_sae_runner(cfg)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -915,7 +682,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py index e0030ce5..e1c878c0 100644 --- a/tests/benchmark/test_language_model_sae_runner.py +++ b/tests/benchmark/test_language_model_sae_runner.py @@ -30,7 +30,7 @@ def test_language_model_sae_runner(): context_size=128, # Activation Store Parameters n_batches_in_buffer=24, - total_training_tokens=1_000_000 * 10, + training_tokens=1_000_000 * 10, store_batch_size=32, # Resampling protocol use_ghost_grads=True, diff --git a/tests/unit/analysis/test_neuronpedia_integration.py b/tests/unit/analysis/test_neuronpedia_integration.py new file mode 100644 index 00000000..6b6c9c28 --- /dev/null +++ b/tests/unit/analysis/test_neuronpedia_integration.py @@ -0,0 +1,11 @@ +from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_feature + + +def test_get_neuronpedia_feature(): + result = get_neuronpedia_feature( + feature=0, layer=0, model="gpt2-small", dataset="res-jb" + ) + + assert result["modelId"] == "gpt2-small" + assert result["layer"] == "0-res-jb" + assert result["index"] == 0 diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index 4d18b6ab..0f172ab5 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -32,7 +32,7 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: feature_sampling_window=50, dead_feature_threshold=1e-7, n_batches_in_buffer=2, - total_training_tokens=1_000_000, + training_tokens=1_000_000, store_batch_size=4, log_to_wandb=False, wandb_project="test_project", diff --git a/tests/unit/training/test_load_model.py b/tests/unit/training/test_load_model.py new file mode 100644 index 00000000..4a42f3eb --- /dev/null +++ b/tests/unit/training/test_load_model.py @@ -0,0 +1,13 @@ +from mamba_lens import HookedMamba + +from sae_lens.training.load_model import load_model + + +def 
test_load_model_works_with_mamba(): + model = load_model( + model_class_name="HookedMamba", + model_name="state-spaces/mamba-370m", + device="cpu", + ) + assert model is not None + assert isinstance(model, HookedMamba) diff --git a/tests/unit/training/test_train_sae_on_language_model.py b/tests/unit/training/test_train_sae_on_language_model.py index eae5b35d..73d69430 100644 --- a/tests/unit/training/test_train_sae_on_language_model.py +++ b/tests/unit/training/test_train_sae_on_language_model.py @@ -5,11 +5,11 @@ import pytest import torch +import wandb from datasets import Dataset from torch import Tensor from transformer_lens import HookedTransformer -import wandb from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.optim import get_scheduler from sae_lens.training.sae_group import SparseAutoencoderDictionary @@ -26,6 +26,7 @@ from tests.unit.helpers import build_sae_cfg +# TODO: Address why we have this code here rather than importing it. def build_train_ctx( sae: SparseAutoencoder, act_freq_scores: Tensor | None = None, @@ -310,11 +311,11 @@ def test_train_sae_group_on_language_model__runs( cfg = build_sae_cfg( checkpoint_path=checkpoint_dir, train_batch_size=32, - total_training_tokens=100, + training_tokens=100, context_size=8, ) # just a tiny datast which will run quickly - dataset = Dataset.from_list([{"text": "hello world"}] * 1000) + dataset = Dataset.from_list([{"text": "hello world"}] * 2000) activation_store = ActivationsStore.from_config(ts_model, cfg, dataset=dataset) sae_group = SparseAutoencoderDictionary(cfg) res = train_sae_group_on_language_model( diff --git a/tutorials/mamba_train_example.py b/tutorials/mamba_train_example.py new file mode 100644 index 00000000..19ff5074 --- /dev/null +++ b/tutorials/mamba_train_example.py @@ -0,0 +1,60 @@ +# install from https://github.com/Phylliida/MambaLens +import os +import sys + +import torch + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +# run this as python3 tutorials/mamba_train_example.py +# i.e. 
from the root directory +from sae_lens.training.config import LanguageModelSAERunnerConfig + +cfg = LanguageModelSAERunnerConfig( + # Data Generating Function (Model + Training Distibuion) + model_name="state-spaces/mamba-370m", + model_class_name="HookedMamba", + hook_point="blocks.39.hook_ssm_input", + hook_point_layer=39, + hook_point_eval="blocks.39.hook_ssm_output", # we compare this when replace hook_point activations with autoencode.decode(autoencoder.encode( hook_point activations)) + d_in=2048, + dataset_path="NeelNanda/openwebtext-tokenized-9b", + is_dataset_tokenized=True, + # SAE Parameters + expansion_factor=64, + b_dec_init_method="geometric_median", + # Training Parameters + lr=0.0004, + l1_coefficient=0.00006 * 0.2, + lr_scheduler_name="cosineannealingwarmrestarts", + train_batch_size=4096, + context_size=128, + lr_warm_up_steps=5000, + # Activation Store Parameters + n_batches_in_buffer=128, + training_tokens=1_000_000 * 300, + store_batch_size=32, + # Dead Neurons and Sparsity + use_ghost_grads=True, + feature_sampling_window=1000, + dead_feature_window=5000, + dead_feature_threshold=1e-6, + # WANDB + log_to_wandb=True, + wandb_project="sae_training_mamba", + wandb_entity=None, + wandb_log_frequency=100, + # Misc + device="cuda", + seed=42, + checkpoint_path="checkpoints", + dtype=torch.float32, + model_kwargs={ + "fast_ssm": True, + "fast_conv": True, + }, +) + +from sae_lens.training.lm_runner import language_model_sae_runner + +language_model_sae_runner(cfg) diff --git a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb index 79474377..a208f77c 100644 --- a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb +++ b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb @@ -22,27 +22,17 @@ "metadata": {}, "outputs": [], "source": [ - "# from huggingface_hub import hf_hub_download\n", - "\n", - "# MODEL = \"gpt2-small\"\n", - "# LAYER = 0\n", - "# SOURCE = \"res-jb\"\n", - "# REPO_ID = \"jbloom/GPT2-Small-SAEs\"\n", - "# FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{LAYER}.hook_resid_pre_24576.pt\"\n", - "# SAE_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n", - "\n", - "# Change these\n", - "MODEL = \"pythia-70m-deduped\"\n", - "LAYER = 0\n", - "TYPE = \"resid\"\n", - "SOURCE_AUTHOR_SUFFIX = \"sm\"\n", - "SOURCE = \"res-sm\"\n", - "\n", - "# Change these depending on how your files are named\n", - "SAE_PATH = f\"../data/{SOURCE_AUTHOR_SUFFIX}/sae_{LAYER}_{TYPE}.pt\"\n", - "FEATURE_SPARSITY_PATH = (\n", - " f\"../data/{SOURCE_AUTHOR_SUFFIX}/feature_sparsity_{LAYER}_{TYPE}.pt\"\n", - ")" + "from sae_lens.toolkit.pretrained_saes import download_sae_from_hf\n", + "import os\n", + "\n", + "MODEL_ID = \"gpt2-small\"\n", + "SAE_ID = \"res-jb\"\n", + "\n", + "(_, SAE_WEIGHTS_PATH, _) = download_sae_from_hf(\n", + " \"jbloom/GPT2-Small-SAEs-Reformatted\", \"blocks.0.hook_resid_pre\"\n", + ")\n", + "\n", + "SAE_PATH = os.path.dirname(SAE_WEIGHTS_PATH)" ] }, { @@ -60,21 +50,21 @@ "source": [ "from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner\n", "\n", - "NP_OUTPUT_FOLDER = \"../neuronpedia_outputs\"\n", + "print(SAE_PATH)\n", + "NP_OUTPUT_FOLDER = \"../../neuronpedia_outputs/my_outputs\"\n", "\n", "runner = NeuronpediaRunner(\n", + " sae_id=SAE_ID,\n", " sae_path=SAE_PATH,\n", - " feature_sparsity_path=FEATURE_SPARSITY_PATH,\n", - " neuronpedia_parent_folder=NP_OUTPUT_FOLDER,\n", - " init_session=True,\n", + " outputs_dir=NP_OUTPUT_FOLDER,\n", + " 
sparsity_threshold=-5,\n", " n_batches_to_sample_from=2**12,\n", - " n_prompts_to_select=4096 * 6,\n", - " n_features_at_a_time=512,\n", - " buffer_tokens_left=64,\n", - " buffer_tokens_right=63,\n", - " start_batch_inclusive=22,\n", - " end_batch_inclusive=23,\n", + " n_prompts_to_select=4096*6,\n", + " n_features_at_a_time=24,\n", + " start_batch_inclusive=1,\n", + " end_batch_inclusive=1,\n", ")\n", + "\n", "runner.run()" ] }, @@ -100,8 +90,7 @@ "import os\n", "import requests\n", "\n", - "folder_path = runner.neuronpedia_folder\n", - "\n", + "FEATURE_OUTPUTS_FOLDER = runner.outputs_dir\n", "\n", "def nanToNeg999(obj: Any) -> Any:\n", " if isinstance(obj, dict):\n", @@ -120,13 +109,12 @@ "\n", "# Server info\n", "host = \"http://localhost:3000\"\n", - "sourceName = str(LAYER) + \"-\" + SOURCE\n", "\n", "# Upload alive features\n", - "for file_name in os.listdir(folder_path):\n", + "for file_name in os.listdir(FEATURE_OUTPUTS_FOLDER):\n", " if file_name.startswith(\"batch-\") and file_name.endswith(\".json\"):\n", " print(\"Uploading file: \" + file_name)\n", - " file_path = os.path.join(folder_path, file_name)\n", + " file_path = os.path.join(FEATURE_OUTPUTS_FOLDER, file_name)\n", " f = open(file_path, \"r\")\n", " data = json.load(f)\n", "\n", @@ -134,31 +122,21 @@ " data_fixed = json.dumps(data, cls=NanConverter)\n", " data = json.loads(data_fixed)\n", "\n", - " url = host + \"/api/internal/upload-features\"\n", + " url = host + \"/api/local/upload-features\"\n", " resp = requests.post(\n", " url,\n", - " json={\n", - " \"modelId\": MODEL,\n", - " \"layer\": sourceName,\n", - " \"features\": data,\n", - " },\n", + " json=data,\n", " )\n", "\n", - "# Upload dead features (just makes blanks features)\n", - "# We want this for completeness\n", - "# skipped_path = os.path.join(folder_path, \"skipped_indexes.json\")\n", - "# f = open(skipped_path, \"r\")\n", - "# data = json.load(f)\n", - "# skipped_indexes = data[\"skipped_indexes\"]\n", - "# url = host + \"/api/internal/upload-dead-features\"\n", - "# resp = requests.post(\n", - "# url,\n", - "# json={\n", - "# \"modelId\": MODEL,\n", - "# \"layer\": sourceName,\n", - "# \"deadIndexes\": skipped_indexes,\n", - "# },\n", - "# )" + "# Upload dead feature stubs\n", + "skipped_path = os.path.join(FEATURE_OUTPUTS_FOLDER, \"skipped_indexes.json\")\n", + "f = open(skipped_path, \"r\")\n", + "data = json.load(f)\n", + "url = host + \"/api/local/upload-dead-features\"\n", + "resp = requests.post(\n", + " url,\n", + " json=data,\n", + ")" ] }, { @@ -185,7 +163,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/tutorials/neuronpedia/make_batch.py b/tutorials/neuronpedia/make_batch.py new file mode 100644 index 00000000..945a2939 --- /dev/null +++ b/tutorials/neuronpedia/make_batch.py @@ -0,0 +1,29 @@ +import sys + +from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner + +# we use another python script to launch this using subprocess to work around OOM issues - this ensures every batch gets the whole system available memory +# better fix is to investigate and fix the memory issues + +SAE_ID = sys.argv[1] +SAE_PATH = sys.argv[2] +OUTPUTS_DIR = sys.argv[3] +SPARSITY_THRESHOLD = int(sys.argv[4]) +N_BATCHES_SAMPLE = int(sys.argv[5]) +N_PROMPTS_SELECT = int(sys.argv[6]) +FEATURES_AT_A_TIME = int(sys.argv[7]) +START_BATCH_INCLUSIVE = int(sys.argv[8]) +END_BATCH_INCLUSIVE = int(sys.argv[9]) + +runner = NeuronpediaRunner( + 
sae_id=SAE_ID, + sae_path=SAE_PATH, + outputs_dir=OUTPUTS_DIR, + sparsity_threshold=SPARSITY_THRESHOLD, + n_batches_to_sample_from=N_BATCHES_SAMPLE, + n_prompts_to_select=N_PROMPTS_SELECT, + n_features_at_a_time=FEATURES_AT_A_TIME, + start_batch_inclusive=START_BATCH_INCLUSIVE, + end_batch_inclusive=END_BATCH_INCLUSIVE, +) +runner.run() diff --git a/tutorials/neuronpedia/neuronpedia.py b/tutorials/neuronpedia/neuronpedia.py new file mode 100755 index 00000000..aa36a6b6 --- /dev/null +++ b/tutorials/neuronpedia/neuronpedia.py @@ -0,0 +1,404 @@ +# we use a script that launches separate python processes to work around OOM issues - this ensures every batch gets the whole system available memory +# better fix is to investigate and fix the memory issues + +import json +import math +import os +import subprocess +from decimal import Decimal +from pathlib import Path +from typing import Any + +import requests +import torch +import typer +from rich import print +from rich.align import Align +from rich.panel import Panel +from typing_extensions import Annotated + +from sae_lens.toolkit.pretrained_saes import load_sparsity +from sae_lens.training.sparse_autoencoder import SparseAutoencoder + +OUTPUT_DIR_BASE = Path("../../neuronpedia_outputs") + +app = typer.Typer( + add_completion=False, + no_args_is_help=True, + help="Tool that generates features (generate) and uploads features (upload) to Neuronpedia.", +) + + +@app.command() +def generate( + sae_id: Annotated[ + str, + typer.Option( + help="SAE ID to generate features for (must exactly match the one used on Neuronpedia). Example: res-jb", + prompt=""" +What is the SAE ID you want to generate features for? +This was set when you did 'Add SAEs' on Neuronpedia. This must exactly match that ID (including casing). +It's in the format [abbrev hook name]-[abbrev author name], like res-jb. +Enter SAE ID""", + ), + ], + sae_path: Annotated[ + Path, + typer.Option( + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + help="Absolute local path to the SAE directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors).", + prompt=""" +What is the absolute local path to your SAE's directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors)? +Enter path""", + ), + ], + log_sparsity: Annotated[ + int, + typer.Option( + min=-10, + max=0, + help="Desired feature log sparsity threshold. Range -10 to 0.", + prompt=""" +What is your desired feature log sparsity threshold? +Enter value from -10 to 0""", + ), + ] = -5, + feat_per_batch: Annotated[ + int, + typer.Option( + min=1, + max=2048, + help="Features to generate per batch. More requires more memory.", + prompt=""" +How many features do you want to generate per batch? More requires more memory. +Enter value""", + ), + ] = 128, + resume_from_batch: Annotated[ + int, + typer.Option( + min=1, + help="Batch number to resume from.", + prompt=""" +Do you want to resume from a specific batch number? +Enter 1 to start from the beginning""", + ), + ] = 1, + n_batches_to_sample: Annotated[ + int, + typer.Option( + min=1, + help="[Activation Text Generation] Number of batches to sample from.", + prompt=""" +[Activation Text Generation] How many batches do you want to sample from? +Enter value""", + ), + ] = 2 + ** 12, + n_prompts_to_select: Annotated[ + int, + typer.Option( + min=1, + help="[Activation Text Generation] Number of prompts to select from.", + prompt=""" +[Activation Text Generation] How many prompts do you want to select from? 
+Enter value""", + ), + ] = 4096 + * 6, +): + """ + This will start a batch job that generates features for Neuronpedia for a specific SAE. To upload those features, use the 'upload' command afterwards. + """ + + # Check arguments + if sae_path.is_dir() is not True: + print("Error: SAE path must be a directory.") + raise typer.Abort() + if sae_path.joinpath("cfg.json").is_file() is not True: + print("Error: cfg.json file not found in SAE directory.") + raise typer.Abort() + if sae_path.joinpath("sae_weights.safetensors").is_file() is not True: + print("Error: sae_weights.safetensors file not found in SAE directory.") + raise typer.Abort() + if sae_path.joinpath("sparsity.safetensors").is_file() is not True: + print("Error: sparsity.safetensors file not found in SAE directory.") + raise typer.Abort() + + sae_path_string = sae_path.as_posix() + + # Load SAE + device = "cpu" + if torch.backends.mps.is_available(): + device = "mps" + elif torch.cuda.is_available(): + device = "cuda" + sparse_autoencoder = SparseAutoencoder.load_from_pretrained( + sae_path_string, device=device + ) + model_id = sparse_autoencoder.cfg.model_name + + outputs_subdir = f"{model_id}_{sae_id}_{sparse_autoencoder.cfg.hook_point}" + outputs_dir = OUTPUT_DIR_BASE.joinpath(outputs_subdir) + if outputs_dir.exists() and outputs_dir.is_file(): + print(f"Error: Output directory {outputs_dir.as_posix()} exists and is a file.") + raise typer.Abort() + outputs_dir.mkdir(parents=True, exist_ok=True) + # Check if output_dir has any files starting with "batch_" + batch_files = list(outputs_dir.glob("batch-*.json")) + if len(batch_files) > 0 and resume_from_batch == 1: + print( + f"Error: Output directory {outputs_dir.as_posix()} has existing batch files. This is only allowed if you are resuming from a batch. Please delete or move the existing batch-*.json files." 
+ ) + raise typer.Abort() + + sparsity = load_sparsity(sae_path_string) + # convert sparsity to logged sparsity if it's not + # TODO: standardize the sparsity file format + if len(sparsity) > 0 and sparsity[0] >= 0: + sparsity = torch.log10(sparsity + 1e-10) + sparsity = sparsity.to(device) + alive_indexes = (sparsity > log_sparsity).nonzero(as_tuple=True)[0].tolist() + num_alive = len(alive_indexes) + num_dead = sparse_autoencoder.d_sae - num_alive + + print("\n") + print( + Align.center( + Panel.fit( + f""" +[white]SAE Path: [green]{sae_path.as_posix()} +[white]Model ID: [green]{model_id} +[white]Hook Point: [green]{sparse_autoencoder.cfg.hook_point} +[white]Using Device: [green]{device} +""", + title="SAE Info", + ) + ) + ) + num_batches = math.ceil(num_alive / feat_per_batch) + print( + Align.center( + Panel.fit( + f""" +[white]Total Features: [green]{sparse_autoencoder.d_sae} +[white]Log Sparsity Threshold: [green]{log_sparsity} +[white]Alive Features: [green]{num_alive} +[white]Dead Features: [red]{num_dead} +[white]Features per Batch: [green]{feat_per_batch} +[white]Number of Batches: [green]{num_batches} +{resume_from_batch != 1 and f"[white]Resuming from Batch: [green]{resume_from_batch}" or ""} +""", + title="Number of Features", + ) + ) + ) + print( + Align.center( + Panel.fit( + f""" +[white]Dataset: [green]{sparse_autoencoder.cfg.dataset_path} +[white]Batches to Sample From: [green]{n_batches_to_sample} +[white]Prompts to Select From: [green]{n_prompts_to_select} +""", + title="Activation Text Settings", + ) + ) + ) + print( + Align.center( + Panel.fit( + f""" +[green]{outputs_dir.absolute().as_posix()} +""", + title="Output Directory", + ) + ) + ) + + print( + Align.center( + "\n========== [yellow]Starting batch feature generations...[/yellow] ==========" + ) + ) + + # iterate from 1 to num_batches + for i in range(resume_from_batch, num_batches + 1): + command = [ + "python", + "make_batch.py", + sae_id, + sae_path.absolute().as_posix(), + outputs_dir.absolute().as_posix(), + str(log_sparsity), + str(n_batches_to_sample), + str(n_prompts_to_select), + str(feat_per_batch), + str(i), + str(i), + ] + print("\n") + print( + Align.center( + Panel.fit( + f""" +[yellow]{" ".join(command)} +""", + title="Running Command for Batch #" + str(i), + ) + ) + ) + # make a subprocess call to python make_batch.py + subprocess.run( + [ + "python", + "make_batch.py", + sae_id, + sae_path, + outputs_dir, + str(log_sparsity), + str(n_batches_to_sample), + str(n_prompts_to_select), + str(feat_per_batch), + str(i), + str(i), + ] + ) + + print( + Align.center( + Panel( + f""" +Your Features Are In: [green]{outputs_dir.absolute().as_posix()} +Use [yellow]'neuronpedia.py upload'[/yellow] to upload your features to Neuronpedia. +""", + title="Generation Complete", + ) + ) + ) + + +@app.command() +def upload( + outputs_dir: Annotated[ + Path, + typer.Option( + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + prompt="What is the absolute, full local file path to the feature outputs directory?", + ), + ], + host: Annotated[ + str, + typer.Option( + prompt="""Host to upload to? (Default: http://localhost:3000)""", + ), + ] = "http://localhost:3000", +): + """ + This will upload features that were generated to Neuronpedia. It currently only works if you have admin access to a Neuronpedia instance via localhost:3000. 
+ """ + + files_to_upload = list(outputs_dir.glob("batch-*.json")) + + # sort files by batch number + files_to_upload.sort(key=lambda x: int(x.stem.split("-")[1])) + + print("\n") + # Upload alive features + for file_path in files_to_upload: + print("===== Uploading file: " + os.path.basename(file_path)) + f = open(file_path, "r") + data = json.load(f) + + # Replace NaNs + data_fixed = json.dumps(data, cls=NanConverter) + data = json.loads(data_fixed) + + url = host + "/api/local/upload-features" + requests.post( + url, + json=data, + ) + + print( + Align.center( + Panel( + f""" +{len(files_to_upload)} batch files uploaded to Neuronpedia. +""", + title="Uploads Complete", + ) + ) + ) + + +@app.command() +def upload_dead_stubs( + outputs_dir: Annotated[ + Path, + typer.Option( + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + prompt="What is the absolute, full local file path to the feature outputs directory?", + ), + ], + host: Annotated[ + str, + typer.Option( + prompt="""Host to upload to? (Default: http://localhost:3000)""", + ), + ] = "http://localhost:3000", +): + """ + This will create "There are no activations for this feature" stubs for dead features on Neuronpedia. It currently only works if you have admin access to a Neuronpedia instance via localhost:3000. + """ + + skipped_path = os.path.join(outputs_dir, "skipped_indexes.json") + f = open(skipped_path, "r") + data = json.load(f) + url = host + "/api/local/upload-skipped-features" + requests.post( + url, + json=data, + ) + + print( + Align.center( + Panel( + """ +Dead feature stubs created. +""", + title="Complete", + ) + ) + ) + + +# Helper utilities that help fix weird NaNs in the feature outputs + + +def nanToNeg999(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: nanToNeg999(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [nanToNeg999(v) for v in obj] + elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan(obj): + return -999 + return obj + + +class NanConverter(json.JSONEncoder): + def encode(self, o: Any, *args: Any, **kwargs: Any): + return super().encode(nanToNeg999(o), *args, **kwargs) + + +if __name__ == "__main__": + app() diff --git a/tutorials/neuronpedia/np_runner.sh b/tutorials/neuronpedia/np_runner.sh deleted file mode 100755 index c005d48c..00000000 --- a/tutorials/neuronpedia/np_runner.sh +++ /dev/null @@ -1,15 +0,0 @@ -# script for working around memory issues - this ensures every batch gets the whole system available memory -# better fix is to investigate and fix the memory issues - -#!/bin/bash -LAYER=$1 -TYPE=$2 -SOURCE_AUTHOR_SUFFIX=$3 -FEATURES_AT_A_TIME=$4 -START_BATCH_INCLUSIVE=$5 -END_BATCH_INCLUSIVE=$6 -for j in $(seq $5 $6) - do - echo "Iteration: $j" - python np_runner_batch.py $1 $2 $3 $4 $j $j -done \ No newline at end of file diff --git a/tutorials/neuronpedia/np_runner_batch.py b/tutorials/neuronpedia/np_runner_batch.py deleted file mode 100644 index 23a72ec3..00000000 --- a/tutorials/neuronpedia/np_runner_batch.py +++ /dev/null @@ -1,35 +0,0 @@ -import sys - -LAYER = int(sys.argv[1]) # 0 -TYPE = sys.argv[2] # "resid" -SOURCE_AUTHOR_SUFFIX = sys.argv[3] # "sm" -FEATURES_AT_A_TIME = int( - sys.argv[4] -) # this must stay the same or your batching will be off -START_BATCH_INCLUSIVE = int(sys.argv[5]) -END_BATCH_INCLUSIVE = int(sys.argv[6]) if len(sys.argv) > 6 else None - -# Change these depending on how your files are named -SAE_PATH = f"../../data/{SOURCE_AUTHOR_SUFFIX}/sae_{LAYER}_{TYPE}.pt" -FEATURE_SPARSITY_PATH = ( - 
f"../../data/{SOURCE_AUTHOR_SUFFIX}/feature_sparsity_{LAYER}_{TYPE}.pt" -) - -from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner - -NP_OUTPUT_FOLDER = "../../neuronpedia_outputs" - -runner = NeuronpediaRunner( - sae_path=SAE_PATH, - feature_sparsity_path=FEATURE_SPARSITY_PATH, - neuronpedia_parent_folder=NP_OUTPUT_FOLDER, - init_session=True, - n_batches_to_sample_from=2**12, - n_prompts_to_select=4096 * 6, - n_features_at_a_time=FEATURES_AT_A_TIME, - buffer_tokens_left=64, - buffer_tokens_right=63, - start_batch_inclusive=START_BATCH_INCLUSIVE, - end_batch_inclusive=END_BATCH_INCLUSIVE, -) -runner.run() diff --git a/tutorials/training_a_sparse_autoencoder.ipynb b/tutorials/training_a_sparse_autoencoder.ipynb index 73f4ccdd..17eed89c 100644 --- a/tutorials/training_a_sparse_autoencoder.ipynb +++ b/tutorials/training_a_sparse_autoencoder.ipynb @@ -335,7 +335,7 @@ " context_size=512, # will control the lenght of the prompts we feed to the model. Larger is better but slower.\n", " # Activation Store Parameters\n", " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", - " total_training_tokens=1_000_000\n", + " training_tokens=1_000_000\n", " * 50, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", " store_batch_size=16,\n", " # Resampling protocol\n",