From d532b828bd77c18b73f495d6b42ca53b5148fd2f Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Sun, 7 Apr 2024 16:57:17 -0700 Subject: [PATCH 01/30] FIX: Add back correlated neurons, frac_nonzero --- sae_lens/analysis/neuronpedia_runner.py | 34 ++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index ecfc4cd1..2e9e85f7 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -307,23 +307,21 @@ def run(self): feature_output["correlated_neurons_indices"] = ( feature.feature_tables_data.correlated_neurons_indices ) - # TODO: this value doesn't exist in the new output type, commenting out for now - # there is a cossim value though - is that what's needed? - # feature_output["correlated_neurons_l1"] = self.round_list( - # feature.feature_tables_data.correlated_neurons_l1 - # ) + feature_output["correlated_neurons_l1"] = self.round_list( + feature.feature_tables_data.correlated_neurons_cossim + ) feature_output["correlated_neurons_pearson"] = self.round_list( feature.feature_tables_data.correlated_neurons_pearson ) - # feature_output["correlated_features_indices"] = ( - # feature.feature_tables_data.correlated_features_indices - # ) - # feature_output["correlated_features_l1"] = self.round_list( - # feature.feature_tables_data.correlated_features_l1 - # ) - # feature_output["correlated_features_pearson"] = self.round_list( - # feature.feature_tables_data.correlated_features_pearson - # ) + feature_output["correlated_features_indices"] = ( + feature.feature_tables_data.correlated_features_indices + ) + feature_output["correlated_features_l1"] = self.round_list( + feature.feature_tables_data.correlated_features_cossim + ) + feature_output["correlated_features_pearson"] = self.round_list( + feature.feature_tables_data.correlated_features_pearson + ) feature_output["neg_str"] = self.to_str_tokens_safe( vocab_dict, feature.logits_table_data.bottom_token_ids @@ -335,9 +333,11 @@ def run(self): feature_output["pos_values"] = top10_logits # TODO: don't know what this should be in the new version - # feature_output["frac_nonzero"] = ( - # feature.middle_plots_data.frac_nonzero - # ) + feature_output["frac_nonzero"] = ( + feature.acts_histogram_data.title.split(" = ")[1] + if feature.acts_histogram_data.title is not None + else 0 + ) freq_hist_data = feature.acts_histogram_data freq_bar_values = self.round_list(freq_hist_data.bar_values) From 271dbf05567b6e6ae4cfc1dab138132872038381 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Tue, 9 Apr 2024 23:23:55 -0700 Subject: [PATCH 02/30] Don't precompute background colors and tick values --- sae_lens/analysis/neuronpedia_runner.py | 32 ------------------------- 1 file changed, 32 deletions(-) diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 2e9e85f7..2387bf2b 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -280,20 +280,6 @@ def run(self): feature.logits_table_data.bottom_logits ) - # TODO: don't precompute/store these. 
should do it on the frontend - max_value = max( - np.absolute(bottom10_logits).max(), - np.absolute(top10_logits).max(), - ) - neg_bg_values = self.round_list( - np.absolute(bottom10_logits) / max_value - ) - pos_bg_values = self.round_list( - np.absolute(top10_logits) / max_value - ) - feature_output["neg_bg_values"] = neg_bg_values - feature_output["pos_bg_values"] = pos_bg_values - if feature.feature_tables_data: feature_output["neuron_alignment_indices"] = ( feature.feature_tables_data.neuron_alignment_indices @@ -332,7 +318,6 @@ def run(self): ) feature_output["pos_values"] = top10_logits - # TODO: don't know what this should be in the new version feature_output["frac_nonzero"] = ( feature.acts_histogram_data.title.split(" = ")[1] if feature.acts_histogram_data.title is not None @@ -342,22 +327,9 @@ def run(self): freq_hist_data = feature.acts_histogram_data freq_bar_values = self.round_list(freq_hist_data.bar_values) feature_output["freq_hist_data_bar_values"] = freq_bar_values - feature_output["freq_hist_data_tick_vals"] = self.round_list( - freq_hist_data.tick_vals - ) - - # TODO: don't precompute/store these. should do it on the frontend - freq_bar_values_clipped = [ - (0.4 * max(freq_bar_values) + 0.6 * v) / max(freq_bar_values) - for v in freq_bar_values - ] - freq_bar_colors = [ - colors.rgb2hex(BG_COLOR_MAP(v)) for v in freq_bar_values_clipped - ] feature_output["freq_hist_data_bar_heights"] = self.round_list( freq_hist_data.bar_heights ) - feature_output["freq_bar_colors"] = freq_bar_colors logits_hist_data = feature.logits_histogram_data feature_output["logits_hist_data_bar_heights"] = self.round_list( @@ -366,11 +338,7 @@ def run(self): feature_output["logits_hist_data_bar_values"] = self.round_list( logits_hist_data.bar_values ) - feature_output["logits_hist_data_tick_vals"] = self.round_list( - logits_hist_data.tick_vals - ) - # TODO: check this feature_output["num_tokens_for_dashboard"] = ( self.n_prompts_to_select ) From ebbb622353bef21c953f844a108ea8d9fe31e9f9 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Fri, 12 Apr 2024 23:39:51 -0700 Subject: [PATCH 03/30] Use legacy loader, add back histograms, logits. Fix anomaly characters. 
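
The "anomaly characters" are artifacts of GPT-2's byte-level BPE vocab: tokens are stored with
markers such as "Ġ" (leading space) and "Ċ" (newline), and some punctuation shows up as mangled
byte sequences (e.g. "âĢĻ" for "’"). This patch cleans the tokenizer vocab via the HTML_ANOMALIES
mapping imported from sae_vis.utils_fns instead of the previous hand-rolled replace() chain. A
minimal sketch of that substitution, using only a small illustrative subset of the mapping:

    # illustrative subset; the real mapping comes from sae_vis.utils_fns.HTML_ANOMALIES
    HTML_ANOMALIES = {"Ġ": " ", "Ċ": "\n", "âĢĻ": "’"}

    def clean_token(token: str) -> str:
        # apply every anomaly -> replacement substitution to one vocab entry
        for anomaly, replacement in HTML_ANOMALIES.items():
            token = token.replace(anomaly, replacement)
        return token

    assert clean_token("Ġworld") == " world"
    assert clean_token("donâĢĻt") == "don’t"
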
--- sae_lens/analysis/neuronpedia_runner.py | 67 +++++++++++++------ .../generating_neuronpedia_outputs.ipynb | 56 +++++++++------- tutorials/neuronpedia/np_runner_batch.py | 4 +- 3 files changed, 79 insertions(+), 48 deletions(-) diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index c2893249..e7346ca3 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -1,6 +1,8 @@ import os from typing import Any, Dict, List, Optional, Union, cast +from sae_lens.training.sparse_autoencoder import SparseAutoencoder + # set TOKENIZERS_PARALLELISM to false to avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" import json @@ -12,11 +14,14 @@ from sae_vis.data_config_classes import ( ActsHistogramConfig, Column, + LogitsHistogramConfig, + LogitsTableConfig, FeatureTablesConfig, SaeVisConfig, SaeVisLayoutConfig, SequencesConfig, ) +from sae_vis.utils_fns import HTML_ANOMALIES from sae_vis.data_fetching_fns import get_feature_data from tqdm import tqdm @@ -45,6 +50,7 @@ class NeuronpediaRunner: def __init__( self, sae_path: str, + use_legacy: bool = False, feature_sparsity_path: Optional[str] = None, neuronpedia_parent_folder: str = "./neuronpedia_outputs", init_session: bool = True, @@ -60,6 +66,7 @@ def __init__( end_batch_inclusive: Optional[int] = None, ): self.sae_path = sae_path + self.use_legacy = use_legacy if init_session: self.init_sae_session() @@ -90,11 +97,22 @@ def get_folder_name(self): return dashboard_folder_name def init_sae_session(self): - ( - self.model, - sae_group, - self.activation_store, - ) = LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path) + + if self.use_legacy: + # load the SAE + sparse_autoencoder = SparseAutoencoder.load_from_pretrained_legacy( + self.sae_path + ) + # load the model, SAE and activations loader with it. 
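+            # the cfg stored on the SAE is enough to rebuild the model and activations loader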
+ session_loader = LMSparseAutoencoderSessionloader(sparse_autoencoder.cfg) + (self.model, sae_group, self.activation_store) = ( + session_loader.load_sae_training_group_session() + ) + else: + (self.model, sae_group, self.activation_store) = ( + LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path) + ) + # TODO: handle multiple autoencoders self.sparse_autoencoder = next(iter(sae_group))[1] @@ -206,10 +224,15 @@ def run(self): print(f"Time to get tokens: {end - start}") vocab_dict = cast(Any, self.model.tokenizer).vocab - vocab_dict = { - v: k.replace("Ġ", " ").replace("\n", "\\n").replace("Ċ", "\n") - for k, v in vocab_dict.items() - } + new_vocab_dict = {} + # Replace substrings in the keys of vocab_dict using HTML_ANOMALIES + for k, v in vocab_dict.items(): + modified_key = k + for anomaly in HTML_ANOMALIES: + modified_key = modified_key.replace(anomaly, HTML_ANOMALIES[anomaly]) + new_vocab_dict[modified_key] = v + vocab_dict = new_vocab_dict + # pad with blank tokens to the actual vocab size for i in range(len(vocab_dict), self.model.cfg.d_vocab): vocab_dict[i] = OUT_OF_RANGE_TOKEN @@ -227,6 +250,7 @@ def run(self): continue print(f"Doing batch: {feature_batch_count}") + print(f"{features_to_process}") layout = SaeVisLayoutConfig( columns=[ @@ -237,27 +261,24 @@ def run(self): self.buffer_tokens_left, self.buffer_tokens_right, ), - compute_buffer=False, + compute_buffer=True, n_quantiles=10, top_acts_group_size=20, quantile_group_size=5, ), - width=650, - ), - Column( ActsHistogramConfig(), - FeatureTablesConfig(n_rows=5), - width=500, - ), - ], - height=1000, + LogitsHistogramConfig(), + LogitsTableConfig(), + FeatureTablesConfig(n_rows=3), + ) + ] ) feature_vis_params = SaeVisConfig( hook_point=self.sparse_autoencoder.cfg.hook_point, - minibatch_size_features=256, + minibatch_size_features=128, minibatch_size_tokens=64, features=features_to_process, - verbose=False, + verbose=True, feature_centric_layout=layout, ) @@ -319,7 +340,11 @@ def run(self): feature_output["pos_values"] = top10_logits feature_output["frac_nonzero"] = ( - feature.acts_histogram_data.title.split(" = ")[1] + float( + feature.acts_histogram_data.title.split(" = ")[1].split( + "%" + )[0] + ) if feature.acts_histogram_data.title is not None else 0 ) diff --git a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb index 79474377..adc66c3e 100644 --- a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb +++ b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb @@ -22,27 +22,30 @@ "metadata": {}, "outputs": [], "source": [ - "# from huggingface_hub import hf_hub_download\n", + "from huggingface_hub import hf_hub_download\n", "\n", - "# MODEL = \"gpt2-small\"\n", - "# LAYER = 0\n", - "# SOURCE = \"res-jb\"\n", - "# REPO_ID = \"jbloom/GPT2-Small-SAEs\"\n", - "# FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{LAYER}.hook_resid_pre_24576.pt\"\n", - "# SAE_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n", - "\n", - "# Change these\n", - "MODEL = \"pythia-70m-deduped\"\n", + "MODEL = \"gpt2-small\"\n", "LAYER = 0\n", - "TYPE = \"resid\"\n", - "SOURCE_AUTHOR_SUFFIX = \"sm\"\n", - "SOURCE = \"res-sm\"\n", - "\n", - "# Change these depending on how your files are named\n", - "SAE_PATH = f\"../data/{SOURCE_AUTHOR_SUFFIX}/sae_{LAYER}_{TYPE}.pt\"\n", - "FEATURE_SPARSITY_PATH = (\n", - " f\"../data/{SOURCE_AUTHOR_SUFFIX}/feature_sparsity_{LAYER}_{TYPE}.pt\"\n", - ")" + "SOURCE = \"res-jb\"\n", + "REPO_ID = 
\"jbloom/GPT2-Small-SAEs\"\n", + "FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{LAYER}.hook_resid_pre_24576.pt\"\n", + "USE_LEGACY = True\n", + "\n", + "SAE_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n", + "FEATURE_SPARSITY_PATH = None\n", + "\n", + "# # Change these\n", + "# MODEL = \"pythia-70m-deduped\"\n", + "# LAYER = 0\n", + "# TYPE = \"resid\"\n", + "# SOURCE_AUTHOR_SUFFIX = \"sm\"\n", + "# SOURCE = \"res-sm\"\n", + "\n", + "# # Change these depending on how your files are named\n", + "# SAE_PATH = f\"../data/{SOURCE_AUTHOR_SUFFIX}/sae_{LAYER}_{TYPE}.pt\"\n", + "# FEATURE_SPARSITY_PATH = (\n", + "# f\"../data/{SOURCE_AUTHOR_SUFFIX}/feature_sparsity_{LAYER}_{TYPE}.pt\"\n", + "# )" ] }, { @@ -60,20 +63,21 @@ "source": [ "from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner\n", "\n", - "NP_OUTPUT_FOLDER = \"../neuronpedia_outputs\"\n", + "NP_OUTPUT_FOLDER = \"../../neuronpedia_outputs\"\n", "\n", "runner = NeuronpediaRunner(\n", " sae_path=SAE_PATH,\n", + " use_legacy=USE_LEGACY,\n", " feature_sparsity_path=FEATURE_SPARSITY_PATH,\n", " neuronpedia_parent_folder=NP_OUTPUT_FOLDER,\n", " init_session=True,\n", " n_batches_to_sample_from=2**12,\n", " n_prompts_to_select=4096 * 6,\n", - " n_features_at_a_time=512,\n", + " n_features_at_a_time=128,\n", " buffer_tokens_left=64,\n", - " buffer_tokens_right=63,\n", - " start_batch_inclusive=22,\n", - " end_batch_inclusive=23,\n", + " buffer_tokens_right=62,\n", + " start_batch_inclusive=0,\n", + " end_batch_inclusive=1,\n", ")\n", "runner.run()" ] @@ -100,7 +104,7 @@ "import os\n", "import requests\n", "\n", - "folder_path = runner.neuronpedia_folder\n", + "folder_path = \"../../neuronpedia_outputs/gpt2-small_blocks.0.hook_resid_pre_24576\" #runner.neuronpedia_folder\n", "\n", "\n", "def nanToNeg999(obj: Any) -> Any:\n", @@ -134,7 +138,7 @@ " data_fixed = json.dumps(data, cls=NanConverter)\n", " data = json.loads(data_fixed)\n", "\n", - " url = host + \"/api/internal/upload-features\"\n", + " url = host + \"/api/local/upload-features\"\n", " resp = requests.post(\n", " url,\n", " json={\n", diff --git a/tutorials/neuronpedia/np_runner_batch.py b/tutorials/neuronpedia/np_runner_batch.py index 23a72ec3..99c72588 100644 --- a/tutorials/neuronpedia/np_runner_batch.py +++ b/tutorials/neuronpedia/np_runner_batch.py @@ -8,6 +8,7 @@ ) # this must stay the same or your batching will be off START_BATCH_INCLUSIVE = int(sys.argv[5]) END_BATCH_INCLUSIVE = int(sys.argv[6]) if len(sys.argv) > 6 else None +USE_LEGACY = True # Change these depending on how your files are named SAE_PATH = f"../../data/{SOURCE_AUTHOR_SUFFIX}/sae_{LAYER}_{TYPE}.pt" @@ -21,6 +22,7 @@ runner = NeuronpediaRunner( sae_path=SAE_PATH, + use_legacy=USE_LEGACY, feature_sparsity_path=FEATURE_SPARSITY_PATH, neuronpedia_parent_folder=NP_OUTPUT_FOLDER, init_session=True, @@ -28,7 +30,7 @@ n_prompts_to_select=4096 * 6, n_features_at_a_time=FEATURES_AT_A_TIME, buffer_tokens_left=64, - buffer_tokens_right=63, + buffer_tokens_right=62, start_batch_inclusive=START_BATCH_INCLUSIVE, end_batch_inclusive=END_BATCH_INCLUSIVE, ) From 383788485917cee114fba24e8ded944aefcfb568 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Mon, 15 Apr 2024 00:43:34 -0700 Subject: [PATCH 04/30] Runner is fixed, faster, cleaned up, and now gives whole sequences instead of buffer. 
--- .../generating_neuronpedia_outputs.ipynb | 108 ++++++++++++------ tutorials/neuronpedia/make_batch.py | 31 +++++ tutorials/neuronpedia/make_features.sh | 90 +++++++++++++++ tutorials/neuronpedia/np_runner.sh | 15 --- tutorials/neuronpedia/np_runner_batch.py | 37 ------ tutorials/neuronpedia/upload_batch.py | 63 ++++++++++ tutorials/neuronpedia/upload_features.sh | 22 ++++ 7 files changed, 281 insertions(+), 85 deletions(-) create mode 100644 tutorials/neuronpedia/make_batch.py create mode 100755 tutorials/neuronpedia/make_features.sh delete mode 100755 tutorials/neuronpedia/np_runner.sh delete mode 100644 tutorials/neuronpedia/np_runner_batch.py create mode 100644 tutorials/neuronpedia/upload_batch.py create mode 100755 tutorials/neuronpedia/upload_features.sh diff --git a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb index adc66c3e..b82ace04 100644 --- a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb +++ b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb @@ -18,34 +18,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ - "from huggingface_hub import hf_hub_download\n", - "\n", - "MODEL = \"gpt2-small\"\n", - "LAYER = 0\n", - "SOURCE = \"res-jb\"\n", - "REPO_ID = \"jbloom/GPT2-Small-SAEs\"\n", - "FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{LAYER}.hook_resid_pre_24576.pt\"\n", - "USE_LEGACY = True\n", - "\n", - "SAE_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n", - "FEATURE_SPARSITY_PATH = None\n", - "\n", - "# # Change these\n", - "# MODEL = \"pythia-70m-deduped\"\n", - "# LAYER = 0\n", - "# TYPE = \"resid\"\n", - "# SOURCE_AUTHOR_SUFFIX = \"sm\"\n", - "# SOURCE = \"res-sm\"\n", - "\n", - "# # Change these depending on how your files are named\n", - "# SAE_PATH = f\"../data/{SOURCE_AUTHOR_SUFFIX}/sae_{LAYER}_{TYPE}.pt\"\n", - "# FEATURE_SPARSITY_PATH = (\n", - "# f\"../data/{SOURCE_AUTHOR_SUFFIX}/feature_sparsity_{LAYER}_{TYPE}.pt\"\n", - "# )" + "from sae_lens.toolkit.pretrained_saes import download_sae_from_hf\n", + "import os\n", + "\n", + "MODEL_ID = \"gpt2-small\"\n", + "SAE_ID = \"res-jb\"\n", + "\n", + "(_, SAE_WEIGHTS_PATH, _) = download_sae_from_hf(\n", + " \"jbloom/GPT2-Small-SAEs-Reformatted\", \"blocks.0.hook_resid_pre\"\n", + ")\n", + "\n", + "SAE_PATH = os.path.dirname(SAE_WEIGHTS_PATH)" ] }, { @@ -57,26 +44,81 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/Users/johnnylin/.cache/huggingface/hub/models--jbloom--GPT2-Small-SAEs-Reformatted/snapshots/5bd69d8ccac6b19d91934c5aeed4866f8b6e50c7/blocks.0.hook_resid_pre\n", + "Loaded pretrained model gpt2-small into HookedTransformer\n", + "Moving model to device: mps\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/johnnylin/Documents/Projects/SAELens/.venv/lib/python3.12/site-packages/datasets/load.py:1461: FutureWarning: The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. 
You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext\n", + "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n", + "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==== Starting at batch: 1\n", + "==== Ending at batch: 1\n", + "Total features to run: 19321\n", + "Total skipped: 5255\n", + "Total batches: 806\n", + "Hook Point Layer: 0\n", + "Hook Point: blocks.0.hook_resid_pre\n", + "Writing files to: ../../neuronpedia_outputs/gpt2-small_res-jb_blocks.0.hook_resid_pre\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 84%|████████▍ | 3435/4096 [02:41<00:31, 21.29it/s]\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[2], line 19\u001b[0m\n\u001b[1;32m 4\u001b[0m NP_OUTPUT_FOLDER \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m../../neuronpedia_outputs\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 5\u001b[0m runner \u001b[38;5;241m=\u001b[39m NeuronpediaRunner(\n\u001b[1;32m 6\u001b[0m sae_path\u001b[38;5;241m=\u001b[39mSAE_PATH,\n\u001b[1;32m 7\u001b[0m model_id\u001b[38;5;241m=\u001b[39mMODEL_ID,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m end_batch_inclusive\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 18\u001b[0m )\n\u001b[0;32m---> 19\u001b[0m \u001b[43mrunner\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/Projects/SAELens/sae_lens/analysis/neuronpedia_runner.py:219\u001b[0m, in \u001b[0;36mNeuronpediaRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 217\u001b[0m \u001b[38;5;66;03m# get tokens:\u001b[39;00m\n\u001b[1;32m 218\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 219\u001b[0m tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_tokens\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 220\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_batches_to_sample_from\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_prompts_to_select\u001b[49m\n\u001b[1;32m 221\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 222\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 223\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTime to get tokens: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mend\u001b[38;5;250m \u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;250m \u001b[39mstart\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/Documents/Projects/SAELens/sae_lens/analysis/neuronpedia_runner.py:123\u001b[0m, in \u001b[0;36mNeuronpediaRunner.get_tokens\u001b[0;34m(self, n_batches_to_sample_from, n_prompts_to_select)\u001b[0m\n\u001b[1;32m 121\u001b[0m pbar 
\u001b[38;5;241m=\u001b[39m tqdm(\u001b[38;5;28mrange\u001b[39m(n_batches_to_sample_from))\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m pbar:\n\u001b[0;32m--> 123\u001b[0m batch_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactivation_store\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_batch_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 124\u001b[0m batch_tokens \u001b[38;5;241m=\u001b[39m batch_tokens[torch\u001b[38;5;241m.\u001b[39mrandperm(batch_tokens\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m])][\n\u001b[1;32m 125\u001b[0m : batch_tokens\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 126\u001b[0m ]\n\u001b[1;32m 127\u001b[0m all_tokens_list\u001b[38;5;241m.\u001b[39mappend(batch_tokens)\n", + "File \u001b[0;32m~/Documents/Projects/SAELens/sae_lens/training/activations_store.py:227\u001b[0m, in \u001b[0;36mActivationsStore.get_batch_tokens\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m current_length \u001b[38;5;241m==\u001b[39m context_size:\n\u001b[1;32m 226\u001b[0m full_batch \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat(current_batch, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m--> 227\u001b[0m batch_tokens \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfull_batch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\n\u001b[1;32m 229\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 230\u001b[0m current_batch \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 231\u001b[0m current_length \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ "from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner\n", "\n", + "print(SAE_PATH)\n", "NP_OUTPUT_FOLDER = \"../../neuronpedia_outputs\"\n", - "\n", "runner = NeuronpediaRunner(\n", " sae_path=SAE_PATH,\n", - " use_legacy=USE_LEGACY,\n", - " feature_sparsity_path=FEATURE_SPARSITY_PATH,\n", - " neuronpedia_parent_folder=NP_OUTPUT_FOLDER,\n", + " model_id=MODEL_ID,\n", + " sae_id=SAE_ID,\n", + " neuronpedia_outputs_folder=NP_OUTPUT_FOLDER,\n", " init_session=True,\n", " n_batches_to_sample_from=2**12,\n", " n_prompts_to_select=4096 * 6,\n", - " n_features_at_a_time=128,\n", + " n_features_at_a_time=24,\n", " buffer_tokens_left=64,\n", " buffer_tokens_right=62,\n", - " start_batch_inclusive=0,\n", + " start_batch_inclusive=1,\n", " end_batch_inclusive=1,\n", ")\n", "runner.run()" diff --git a/tutorials/neuronpedia/make_batch.py b/tutorials/neuronpedia/make_batch.py new file mode 100644 index 00000000..d347ebca --- /dev/null +++ b/tutorials/neuronpedia/make_batch.py @@ -0,0 +1,31 @@ +import sys +from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner + +SAE_PATH = sys.argv[1] +MODEL_ID = sys.argv[2] +SAE_ID = sys.argv[3] +LEFT_BUFFER = int(sys.argv[4]) +RIGHT_BUFFER = 
int(sys.argv[5]) +N_BATCHES_SAMPLE = int(sys.argv[6]) +N_PROMPTS_SELECT = int(sys.argv[7]) +FEATURES_AT_A_TIME = int(sys.argv[8]) +START_BATCH_INCLUSIVE = int(sys.argv[9]) +END_BATCH_INCLUSIVE = int(sys.argv[10]) + +NP_OUTPUT_FOLDER = "../../neuronpedia_outputs" + +runner = NeuronpediaRunner( + sae_path=SAE_PATH, + model_id=MODEL_ID, + sae_id=SAE_ID, + neuronpedia_outputs_folder=NP_OUTPUT_FOLDER, + init_session=True, + n_batches_to_sample_from=N_BATCHES_SAMPLE, + n_prompts_to_select=N_PROMPTS_SELECT, + n_features_at_a_time=FEATURES_AT_A_TIME, + buffer_tokens_left=LEFT_BUFFER, + buffer_tokens_right=RIGHT_BUFFER, + start_batch_inclusive=START_BATCH_INCLUSIVE, + end_batch_inclusive=END_BATCH_INCLUSIVE, +) +runner.run() diff --git a/tutorials/neuronpedia/make_features.sh b/tutorials/neuronpedia/make_features.sh new file mode 100755 index 00000000..a1e23f58 --- /dev/null +++ b/tutorials/neuronpedia/make_features.sh @@ -0,0 +1,90 @@ +#!/bin/bash + +# we use a script around python to work around OOM issues - this ensures every batch gets the whole system available memory +# better fix is to investigate and fix the memory issues + +echo "===== This will start a batch job that generates features to upload to Neuronpedia." +echo "===== This takes input of one SAE directory at a time." +echo "===== Features will be output into ./neuronpedia_outputs/{model}_{hook_point}_{d_sae}/batch-{batch_num}.json" + +echo "" +echo "(Step 1 of 10)" +echo "What is the absolute, full local file path to your SAE's directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors)?" +read saepath +# TODO: support huggingface directories + +echo "" +echo "(Step 2 of 10)" +echo "What's the model ID? This must exactly match (including casing) the model ID you created on Neuronpedia." +read modelid + +echo "" +echo "(Step 3 of 10)" +echo "What's the SAE ID?" +echo "This was set when you did 'Add SAEs' on Neuronpedia. This must exactly match that ID (including casing). It's in the format [abbrev hook name]-[abbrev author name], like res-jb." +read saeid + +echo "" +echo "(Step 4 of 10)" +echo "How many features are in this SAE?" +read numfeatures + +echo "" +echo "(Step 5 of 10)" +read -p "How many features do you want generate per batch file? More requires more RAM. (default: 128): " perbatch +[ -z "${perbatch}" ] && perbatch='128' + +echo "" +echo "(Step 6 of 10)" +echo "For each activating text sequence, how many tokens to the LEFT of the top activating token do you want?" +echo "If your text sequences are 128 tokens long, then you might put 64. (default: 64)" +read leftbuffer +[ -z "${leftbuffer}" ] && leftbuffer='64' + +echo "" +echo "(Step 7 of 10)" +echo "For each activating text sequence, how many tokens to the RIGHT of the top activating token do you want?" +echo "Left Buffer + Right Buffer must be < Total Text Length - 1" +echo "For example, text sequences of 128 can have at most buffers of 64 + 62 = 126" +echo "If your text sequences are 128 tokens long, then you might put 62. 
(default: 62)" +read rightbuffer +[ -z "${rightbuffer}" ] && rightbuffer='62' + +echo "" +echo "(Step 8 of 10)" +read -p "Enter number of batches to sample from (default: 4096): " batches +[ -z "${batches}" ] && batches='4096' + +echo "" +echo "(Step 9 of 10)" +read -p "Enter number of prompts to select from (default: 24576): " prompts +[ -z "${prompts}" ] && prompts='24576' + +echo "" +numbatches=$(expr $numfeatures / $perbatch) +echo "===== INFO: We'll generate $numbatches batches of $perbatch features per batch = $numfeatures total features" + +echo "" +echo "(Step 10 of 10)" +read -p "Do you want to resume from a specific batch number? Enter 1 to start from the beginning (default: 1): " startbatch +[ -z "${startbatch}" ] && startbatch='1' + +endbatch=$(expr $numbatches) + + +echo "" +echo "===== Features will be output into [repo_dir]/neuronpedia_outputs/{modelId}_{saeId}_{hook_point}/batch-{batch_num}.json" +read -p "===== Hit ENTER to start!" start + +for j in $(seq $startbatch $endbatch) + do + echo "" + echo "===== BATCH: $j" + echo "RUNNING: python make_batch.py $saepath $modelid $saeid $leftbuffer $rightbuffer $batches $prompts $perbatch $j $j" + python make_batch.py $saepath $modelid $saeid $leftbuffer $rightbuffer $batches $prompts $perbatch $j $j +done + +echo "" +echo "===== ALL DONE." +echo "===== Your features are under: [repo_dir]/neuronpedia_outputs/{model}_{hook_point}_{d_sae}" +echo "===== Use upload_features.sh to upload your features. Be sure to have the localhost server running first." \ No newline at end of file diff --git a/tutorials/neuronpedia/np_runner.sh b/tutorials/neuronpedia/np_runner.sh deleted file mode 100755 index c005d48c..00000000 --- a/tutorials/neuronpedia/np_runner.sh +++ /dev/null @@ -1,15 +0,0 @@ -# script for working around memory issues - this ensures every batch gets the whole system available memory -# better fix is to investigate and fix the memory issues - -#!/bin/bash -LAYER=$1 -TYPE=$2 -SOURCE_AUTHOR_SUFFIX=$3 -FEATURES_AT_A_TIME=$4 -START_BATCH_INCLUSIVE=$5 -END_BATCH_INCLUSIVE=$6 -for j in $(seq $5 $6) - do - echo "Iteration: $j" - python np_runner_batch.py $1 $2 $3 $4 $j $j -done \ No newline at end of file diff --git a/tutorials/neuronpedia/np_runner_batch.py b/tutorials/neuronpedia/np_runner_batch.py deleted file mode 100644 index 99c72588..00000000 --- a/tutorials/neuronpedia/np_runner_batch.py +++ /dev/null @@ -1,37 +0,0 @@ -import sys - -LAYER = int(sys.argv[1]) # 0 -TYPE = sys.argv[2] # "resid" -SOURCE_AUTHOR_SUFFIX = sys.argv[3] # "sm" -FEATURES_AT_A_TIME = int( - sys.argv[4] -) # this must stay the same or your batching will be off -START_BATCH_INCLUSIVE = int(sys.argv[5]) -END_BATCH_INCLUSIVE = int(sys.argv[6]) if len(sys.argv) > 6 else None -USE_LEGACY = True - -# Change these depending on how your files are named -SAE_PATH = f"../../data/{SOURCE_AUTHOR_SUFFIX}/sae_{LAYER}_{TYPE}.pt" -FEATURE_SPARSITY_PATH = ( - f"../../data/{SOURCE_AUTHOR_SUFFIX}/feature_sparsity_{LAYER}_{TYPE}.pt" -) - -from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner - -NP_OUTPUT_FOLDER = "../../neuronpedia_outputs" - -runner = NeuronpediaRunner( - sae_path=SAE_PATH, - use_legacy=USE_LEGACY, - feature_sparsity_path=FEATURE_SPARSITY_PATH, - neuronpedia_parent_folder=NP_OUTPUT_FOLDER, - init_session=True, - n_batches_to_sample_from=2**12, - n_prompts_to_select=4096 * 6, - n_features_at_a_time=FEATURES_AT_A_TIME, - buffer_tokens_left=64, - buffer_tokens_right=62, - start_batch_inclusive=START_BATCH_INCLUSIVE, - 
end_batch_inclusive=END_BATCH_INCLUSIVE, -) -runner.run() diff --git a/tutorials/neuronpedia/upload_batch.py b/tutorials/neuronpedia/upload_batch.py new file mode 100644 index 00000000..f599799e --- /dev/null +++ b/tutorials/neuronpedia/upload_batch.py @@ -0,0 +1,63 @@ +# Helpers that fix weird NaN stuff +from decimal import Decimal +from typing import Any +import math +import json +import os +import requests +import sys + +FEATURE_OUTPUTS_FOLDER = sys.argv[1] + + +def nanToNeg999(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: nanToNeg999(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [nanToNeg999(v) for v in obj] + elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan(obj): + return -999 + return obj + + +class NanConverter(json.JSONEncoder): + def encode(self, o: Any, *args: Any, **kwargs: Any): + return super().encode(nanToNeg999(o), *args, **kwargs) + + +# Server info +host = "http://localhost:3000" + +# Upload alive features +for file_name in os.listdir(FEATURE_OUTPUTS_FOLDER): + if file_name.startswith("batch-") and file_name.endswith(".json"): + print("Uploading file: " + file_name) + file_path = os.path.join(FEATURE_OUTPUTS_FOLDER, file_name) + f = open(file_path, "r") + data = json.load(f) + + # Replace NaNs + data_fixed = json.dumps(data, cls=NanConverter) + data = json.loads(data_fixed) + + url = host + "/api/local/upload-features" + resp = requests.post( + url, + json=data, + ) + +# Upload dead features (just makes blanks features) +# We want this for completeness +# skipped_path = os.path.join(folder_path, "skipped_indexes.json") +# f = open(skipped_path, "r") +# data = json.load(f) +# skipped_indexes = data["skipped_indexes"] +# url = host + "/api/internal/upload-dead-features" +# resp = requests.post( +# url, +# json={ +# "modelId": MODEL, +# "layer": sourceName, +# "deadIndexes": skipped_indexes, +# }, +# ) diff --git a/tutorials/neuronpedia/upload_features.sh b/tutorials/neuronpedia/upload_features.sh new file mode 100755 index 00000000..92cc08a7 --- /dev/null +++ b/tutorials/neuronpedia/upload_features.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# we use a script around python to work around OOM issues - this ensures every batch gets the whole system available memory +# better fix is to investigate and fix the memory issues + +echo "===== This will start upload the feature batch files to Neuronpedia." +echo "===== You'll need Neuronpedia running at localhost:3000 for this to work." + +echo "" +echo "(Step 1 of 1)" +echo "What is the absolute, full local DIRECTORY PATH to your Neuronpedia batch outputs?" +read outputfilesdir + +echo "" +read -p "===== Hit ENTER to start uploading!" start + +echo "RUNNING: python upload_batch.py $outputfilesdir" +python upload_batch.py $outputfilesdir + +echo "" +echo "===== ALL DONE." 
+echo "===== Go to http://localhost:3000 to browse your features" \ No newline at end of file From 2c9ca642b334b7a444544a4640c483229dc04c62 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Mon, 15 Apr 2024 09:40:24 -0700 Subject: [PATCH 05/30] Don't use buffer, fix anomalies --- tutorials/neuronpedia/make_batch.py | 14 ++++------- tutorials/neuronpedia/make_features.sh | 32 +++++++------------------- 2 files changed, 13 insertions(+), 33 deletions(-) diff --git a/tutorials/neuronpedia/make_batch.py b/tutorials/neuronpedia/make_batch.py index d347ebca..79dbe5d9 100644 --- a/tutorials/neuronpedia/make_batch.py +++ b/tutorials/neuronpedia/make_batch.py @@ -4,13 +4,11 @@ SAE_PATH = sys.argv[1] MODEL_ID = sys.argv[2] SAE_ID = sys.argv[3] -LEFT_BUFFER = int(sys.argv[4]) -RIGHT_BUFFER = int(sys.argv[5]) -N_BATCHES_SAMPLE = int(sys.argv[6]) -N_PROMPTS_SELECT = int(sys.argv[7]) -FEATURES_AT_A_TIME = int(sys.argv[8]) -START_BATCH_INCLUSIVE = int(sys.argv[9]) -END_BATCH_INCLUSIVE = int(sys.argv[10]) +N_BATCHES_SAMPLE = int(sys.argv[4]) +N_PROMPTS_SELECT = int(sys.argv[5]) +FEATURES_AT_A_TIME = int(sys.argv[6]) +START_BATCH_INCLUSIVE = int(sys.argv[7]) +END_BATCH_INCLUSIVE = int(sys.argv[8]) NP_OUTPUT_FOLDER = "../../neuronpedia_outputs" @@ -23,8 +21,6 @@ n_batches_to_sample_from=N_BATCHES_SAMPLE, n_prompts_to_select=N_PROMPTS_SELECT, n_features_at_a_time=FEATURES_AT_A_TIME, - buffer_tokens_left=LEFT_BUFFER, - buffer_tokens_right=RIGHT_BUFFER, start_batch_inclusive=START_BATCH_INCLUSIVE, end_batch_inclusive=END_BATCH_INCLUSIVE, ) diff --git a/tutorials/neuronpedia/make_features.sh b/tutorials/neuronpedia/make_features.sh index a1e23f58..149147b7 100755 --- a/tutorials/neuronpedia/make_features.sh +++ b/tutorials/neuronpedia/make_features.sh @@ -8,55 +8,39 @@ echo "===== This takes input of one SAE directory at a time." echo "===== Features will be output into ./neuronpedia_outputs/{model}_{hook_point}_{d_sae}/batch-{batch_num}.json" echo "" -echo "(Step 1 of 10)" +echo "(Step 1 of 8)" echo "What is the absolute, full local file path to your SAE's directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors)?" read saepath # TODO: support huggingface directories echo "" -echo "(Step 2 of 10)" +echo "(Step 2 of 8)" echo "What's the model ID? This must exactly match (including casing) the model ID you created on Neuronpedia." read modelid echo "" -echo "(Step 3 of 10)" +echo "(Step 3 of 8)" echo "What's the SAE ID?" echo "This was set when you did 'Add SAEs' on Neuronpedia. This must exactly match that ID (including casing). It's in the format [abbrev hook name]-[abbrev author name], like res-jb." read saeid echo "" -echo "(Step 4 of 10)" +echo "(Step 4 of 8)" echo "How many features are in this SAE?" read numfeatures echo "" -echo "(Step 5 of 10)" +echo "(Step 5 of 8)" read -p "How many features do you want generate per batch file? More requires more RAM. (default: 128): " perbatch [ -z "${perbatch}" ] && perbatch='128' echo "" -echo "(Step 6 of 10)" -echo "For each activating text sequence, how many tokens to the LEFT of the top activating token do you want?" -echo "If your text sequences are 128 tokens long, then you might put 64. (default: 64)" -read leftbuffer -[ -z "${leftbuffer}" ] && leftbuffer='64' - -echo "" -echo "(Step 7 of 10)" -echo "For each activating text sequence, how many tokens to the RIGHT of the top activating token do you want?" 
-echo "Left Buffer + Right Buffer must be < Total Text Length - 1" -echo "For example, text sequences of 128 can have at most buffers of 64 + 62 = 126" -echo "If your text sequences are 128 tokens long, then you might put 62. (default: 62)" -read rightbuffer -[ -z "${rightbuffer}" ] && rightbuffer='62' - -echo "" -echo "(Step 8 of 10)" +echo "(Step 6 of 8)" read -p "Enter number of batches to sample from (default: 4096): " batches [ -z "${batches}" ] && batches='4096' echo "" -echo "(Step 9 of 10)" +echo "(Step 7 of 8)" read -p "Enter number of prompts to select from (default: 24576): " prompts [ -z "${prompts}" ] && prompts='24576' @@ -65,7 +49,7 @@ numbatches=$(expr $numfeatures / $perbatch) echo "===== INFO: We'll generate $numbatches batches of $perbatch features per batch = $numfeatures total features" echo "" -echo "(Step 10 of 10)" +echo "(Step 8 of 8)" read -p "Do you want to resume from a specific batch number? Enter 1 to start from the beginning (default: 1): " startbatch [ -z "${startbatch}" ] && startbatch='1' From e87788d63a9b767e34e497c85a318337ab8aabb8 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Mon, 15 Apr 2024 09:41:56 -0700 Subject: [PATCH 06/30] Final fixes --- pyproject.toml | 2 +- sae_lens/analysis/neuronpedia_runner.py | 166 +++++++++++++----------- 2 files changed, 89 insertions(+), 79 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 517e53ec..66e53ff8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ matplotlib-inline = "^0.1.6" datasets = "^2.17.1" babe = "^0.0.7" nltk = "^3.8.1" -sae-vis = "0.2.6" +sae-vis = { git = "https://github.com/hijohnnylin/sae_vis.git", branch = "allow_disable_buffer" } mkdocs = "^1.5.3" mkdocs-material = "^9.5.15" mkdocs-autorefs = "^1.0.1" diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index e7346ca3..5b14b4de 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -1,13 +1,10 @@ import os from typing import Any, Dict, List, Optional, Union, cast -from sae_lens.training.sparse_autoencoder import SparseAutoencoder - # set TOKENIZERS_PARALLELISM to false to avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" import json import time - import numpy as np import torch from matplotlib import colors @@ -21,11 +18,11 @@ SaeVisLayoutConfig, SequencesConfig, ) -from sae_vis.utils_fns import HTML_ANOMALIES -from sae_vis.data_fetching_fns import get_feature_data from tqdm import tqdm - +from sae_vis.data_storing_fns import SaeVisData from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader +from sae_lens.training.sparse_autoencoder import SparseAutoencoder +from sae_lens.toolkit.pretrained_saes import load_sparsity OUT_OF_RANGE_TOKEN = "<|outofrange|>" @@ -33,6 +30,21 @@ "bg_color_map", ["white", "darkorange"] ) +SPARSITY_THRESHOLD = -5 + +HTML_ANOMALIES = { + "âĢĶ": "—", + "âĢĵ": "–", + "âĢľ": "“", + "âĢĿ": "”", + "âĢĺ": "‘", + "âĢĻ": "’", + "âĢĭ": " ", # todo: this is actually zero width space + "Ġ": " ", + "Ċ": "\n", + "ĉ": "\t", +} + class NpEncoder(json.JSONEncoder): def default(self, o: Any): @@ -50,71 +62,53 @@ class NeuronpediaRunner: def __init__( self, sae_path: str, - use_legacy: bool = False, - feature_sparsity_path: Optional[str] = None, - neuronpedia_parent_folder: str = "./neuronpedia_outputs", + model_id: str, + sae_id: str, + neuronpedia_outputs_folder: str = "../../neuronpedia_outputs", init_session: bool = True, # token pars n_batches_to_sample_from: int = 2**12, 
n_prompts_to_select: int = 4096 * 6, # sampling pars n_features_at_a_time: int = 1024, - buffer_tokens_left: int = 8, - buffer_tokens_right: int = 8, # start and end batch start_batch_inclusive: int = 0, end_batch_inclusive: Optional[int] = None, ): + + self.device = "cpu" + if torch.backends.mps.is_available(): + self.device = "mps" + elif torch.cuda.is_available(): + self.device = "cuda" + self.sae_path = sae_path - self.use_legacy = use_legacy if init_session: self.init_sae_session() - self.feature_sparsity_path = feature_sparsity_path + self.model_id = model_id + self.layer = self.sparse_autoencoder.cfg.hook_point_layer + self.sae_id = sae_id self.n_features_at_a_time = n_features_at_a_time - self.buffer_tokens_left = buffer_tokens_left - self.buffer_tokens_right = buffer_tokens_right self.n_batches_to_sample_from = n_batches_to_sample_from self.n_prompts_to_select = n_prompts_to_select self.start_batch = start_batch_inclusive self.end_batch = end_batch_inclusive - # Deal with file structure - if not os.path.exists(neuronpedia_parent_folder): - os.makedirs(neuronpedia_parent_folder) - self.neuronpedia_folder = ( - f"{neuronpedia_parent_folder}/{self.get_folder_name()}" - ) - if not os.path.exists(self.neuronpedia_folder): - os.makedirs(self.neuronpedia_folder) - - def get_folder_name(self): - model = self.sparse_autoencoder.cfg.model_name - hook_point = self.sparse_autoencoder.cfg.hook_point - d_sae = self.sparse_autoencoder.cfg.d_sae - dashboard_folder_name = f"{model}_{hook_point}_{d_sae}" + if not os.path.exists(neuronpedia_outputs_folder): + os.makedirs(neuronpedia_outputs_folder) + self.neuronpedia_outputs_folder = neuronpedia_outputs_folder - return dashboard_folder_name + self.outputs_folder = f"{neuronpedia_outputs_folder}/{self.sparse_autoencoder.cfg.model_name}_{self.sae_id}_{self.sparse_autoencoder.cfg.hook_point}" + if not os.path.exists(self.outputs_folder): + os.makedirs(self.outputs_folder) def init_sae_session(self): - - if self.use_legacy: - # load the SAE - sparse_autoencoder = SparseAutoencoder.load_from_pretrained_legacy( - self.sae_path - ) - # load the model, SAE and activations loader with it. 
- session_loader = LMSparseAutoencoderSessionloader(sparse_autoencoder.cfg) - (self.model, sae_group, self.activation_store) = ( - session_loader.load_sae_training_group_session() - ) - else: - (self.model, sae_group, self.activation_store) = ( - LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path) - ) - - # TODO: handle multiple autoencoders - self.sparse_autoencoder = next(iter(sae_group))[1] + self.sparse_autoencoder = SparseAutoencoder.load_from_pretrained( + self.sae_path, device=self.device + ) + loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg) + self.model, _, self.activation_store = loader.load_sae_training_group_session() def get_tokens( self, n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6 @@ -174,16 +168,11 @@ def run(self): # if we have feature sparsity, then use it to only generate outputs for non-dead features self.target_feature_indexes: list[int] = [] - if self.feature_sparsity_path: - loaded = torch.load( - self.feature_sparsity_path, map_location=self.sparse_autoencoder.device - ) - self.target_feature_indexes = ( - (loaded > -5).nonzero(as_tuple=True)[0].tolist() - ) - else: - self.target_feature_indexes = list(range(self.n_features)) - print("No feat sparsity path specified - doing all indexes.") + sparsity = load_sparsity(self.sae_path) + sparsity = sparsity.to(self.device) + self.target_feature_indexes = ( + (sparsity > SPARSITY_THRESHOLD).nonzero(as_tuple=True)[0].tolist() + ) # divide into batches feature_idx = torch.tensor(self.target_feature_indexes) @@ -203,8 +192,15 @@ def run(self): # write dead into file so we can create them as dead in Neuronpedia skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes) - skipped_indexes_json = json.dumps({"skipped_indexes": list(skipped_indexes)}) - with open(f"{self.neuronpedia_folder}/skipped_indexes.json", "w") as f: + skipped_indexes_json = json.dumps( + { + "model_id": self.model_id, + "layer": str(self.layer), + "sae_id": self.sae_id, + "skipped_indexes": list(skipped_indexes), + } + ) + with open(f"{self.outputs_folder}/skipped_indexes.json", "w") as f: f.write(skipped_indexes_json) print(f"Total features to run: {len(self.target_feature_indexes)}") @@ -213,15 +209,25 @@ def run(self): print(f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}") print(f"Hook Point: {self.sparse_autoencoder.cfg.hook_point}") - print(f"Writing files to: {self.neuronpedia_folder}") + print(f"Writing files to: {self.outputs_folder}") - # get tokens: - start = time.time() - tokens = self.get_tokens( - self.n_batches_to_sample_from, self.n_prompts_to_select - ) - end = time.time() - print(f"Time to get tokens: {end - start}") + tokens_file = f"{self.outputs_folder}/tokens_{self.n_batches_to_sample_from}_{self.n_prompts_to_select}.pt" + + if os.path.isfile(tokens_file): + print("Tokens exist, loading it") + tokens = torch.load(tokens_file) + else: + start = time.time() + tokens = self.get_tokens( + self.n_batches_to_sample_from, self.n_prompts_to_select + ) + end = time.time() + print(f"Time to get tokens: {end - start}") + print("Saved tokens to: " + tokens_file) + torch.save( + tokens, + tokens_file, + ) vocab_dict = cast(Any, self.model.tokenizer).vocab new_vocab_dict = {} @@ -230,7 +236,7 @@ def run(self): modified_key = k for anomaly in HTML_ANOMALIES: modified_key = modified_key.replace(anomaly, HTML_ANOMALIES[anomaly]) - new_vocab_dict[modified_key] = v + new_vocab_dict[v] = modified_key vocab_dict = new_vocab_dict # pad with blank 
tokens to the actual vocab size @@ -257,12 +263,9 @@ def run(self): Column( SequencesConfig( stack_mode="stack-all", - buffer=( - self.buffer_tokens_left, - self.buffer_tokens_right, - ), + buffer=None, compute_buffer=True, - n_quantiles=10, + n_quantiles=5, top_acts_group_size=20, quantile_group_size=5, ), @@ -281,8 +284,7 @@ def run(self): verbose=True, feature_centric_layout=layout, ) - - feature_data = get_feature_data( + feature_data = SaeVisData.create( encoder=self.sparse_autoencoder, # type: ignore model=self.model, tokens=tokens, @@ -345,6 +347,7 @@ def run(self): "%" )[0] ) + / 100 if feature.acts_histogram_data.title is not None else 0 ) @@ -426,10 +429,17 @@ def run(self): features_outputs.append(feature_output) - json_object = json.dumps(features_outputs, cls=NpEncoder) + to_write = { + "model_id": self.model_id, + "layer": str(self.layer), + "sae_id": self.sae_id, + "features": features_outputs, + } + json_object = json.dumps(to_write, cls=NpEncoder) with open( - f"{self.neuronpedia_folder}/batch-{feature_batch_count}.json", "w" + f"{self.outputs_folder}/batch-{feature_batch_count}.json", + "w", ) as f: f.write(json_object) From dde248162b70ff4311d4182333b7cce43aed78df Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Mon, 15 Apr 2024 09:59:48 -0700 Subject: [PATCH 07/30] Fix buffer" --- tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb | 2 -- 1 file changed, 2 deletions(-) diff --git a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb index b82ace04..c04ca0e8 100644 --- a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb +++ b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb @@ -116,8 +116,6 @@ " n_batches_to_sample_from=2**12,\n", " n_prompts_to_select=4096 * 6,\n", " n_features_at_a_time=24,\n", - " buffer_tokens_left=64,\n", - " buffer_tokens_right=62,\n", " start_batch_inclusive=1,\n", " end_batch_inclusive=1,\n", ")\n", From 8230570297d68e35cb614a63abf442e4a01174d2 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Mon, 15 Apr 2024 14:51:07 -0700 Subject: [PATCH 08/30] Make feature sparsity an argument --- sae_lens/analysis/neuronpedia_runner.py | 6 +- .../generating_neuronpedia_outputs.ipynb | 62 ++----------------- tutorials/neuronpedia/make_batch.py | 12 ++-- tutorials/neuronpedia/make_features.sh | 29 +++++---- 4 files changed, 32 insertions(+), 77 deletions(-) diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 5b14b4de..0cd8ef16 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -30,7 +30,7 @@ "bg_color_map", ["white", "darkorange"] ) -SPARSITY_THRESHOLD = -5 +DEFAULT_SPARSITY_THRESHOLD = -5 HTML_ANOMALIES = { "âĢĶ": "—", @@ -64,6 +64,7 @@ def __init__( sae_path: str, model_id: str, sae_id: str, + sparsity_threshold: int = DEFAULT_SPARSITY_THRESHOLD, neuronpedia_outputs_folder: str = "../../neuronpedia_outputs", init_session: bool = True, # token pars @@ -89,6 +90,7 @@ def __init__( self.model_id = model_id self.layer = self.sparse_autoencoder.cfg.hook_point_layer self.sae_id = sae_id + self.sparsity_threshold = sparsity_threshold self.n_features_at_a_time = n_features_at_a_time self.n_batches_to_sample_from = n_batches_to_sample_from self.n_prompts_to_select = n_prompts_to_select @@ -171,7 +173,7 @@ def run(self): sparsity = load_sparsity(self.sae_path) sparsity = sparsity.to(self.device) self.target_feature_indexes = ( - (sparsity > 
SPARSITY_THRESHOLD).nonzero(as_tuple=True)[0].tolist() + (sparsity > self.sparsity_threshold).nonzero(as_tuple=True)[0].tolist() ) # divide into batches diff --git a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb index c04ca0e8..2346ad9a 100644 --- a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb +++ b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -44,64 +44,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/Users/johnnylin/.cache/huggingface/hub/models--jbloom--GPT2-Small-SAEs-Reformatted/snapshots/5bd69d8ccac6b19d91934c5aeed4866f8b6e50c7/blocks.0.hook_resid_pre\n", - "Loaded pretrained model gpt2-small into HookedTransformer\n", - "Moving model to device: mps\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/johnnylin/Documents/Projects/SAELens/.venv/lib/python3.12/site-packages/datasets/load.py:1461: FutureWarning: The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext\n", - "You can avoid this message in future by passing the argument `trust_remote_code=True`.\n", - "Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "==== Starting at batch: 1\n", - "==== Ending at batch: 1\n", - "Total features to run: 19321\n", - "Total skipped: 5255\n", - "Total batches: 806\n", - "Hook Point Layer: 0\n", - "Hook Point: blocks.0.hook_resid_pre\n", - "Writing files to: ../../neuronpedia_outputs/gpt2-small_res-jb_blocks.0.hook_resid_pre\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 84%|████████▍ | 3435/4096 [02:41<00:31, 21.29it/s]\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[2], line 19\u001b[0m\n\u001b[1;32m 4\u001b[0m NP_OUTPUT_FOLDER \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m../../neuronpedia_outputs\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 5\u001b[0m runner \u001b[38;5;241m=\u001b[39m NeuronpediaRunner(\n\u001b[1;32m 6\u001b[0m sae_path\u001b[38;5;241m=\u001b[39mSAE_PATH,\n\u001b[1;32m 7\u001b[0m model_id\u001b[38;5;241m=\u001b[39mMODEL_ID,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 17\u001b[0m end_batch_inclusive\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m,\n\u001b[1;32m 18\u001b[0m )\n\u001b[0;32m---> 19\u001b[0m \u001b[43mrunner\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/Documents/Projects/SAELens/sae_lens/analysis/neuronpedia_runner.py:219\u001b[0m, in \u001b[0;36mNeuronpediaRunner.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 216\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 217\u001b[0m \u001b[38;5;66;03m# get 
tokens:\u001b[39;00m\n\u001b[1;32m 218\u001b[0m start \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[0;32m--> 219\u001b[0m tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_tokens\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 220\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_batches_to_sample_from\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_prompts_to_select\u001b[49m\n\u001b[1;32m 221\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 222\u001b[0m end \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m 223\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTime to get tokens: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mend\u001b[38;5;250m \u001b[39m\u001b[38;5;241m-\u001b[39m\u001b[38;5;250m \u001b[39mstart\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/Documents/Projects/SAELens/sae_lens/analysis/neuronpedia_runner.py:123\u001b[0m, in \u001b[0;36mNeuronpediaRunner.get_tokens\u001b[0;34m(self, n_batches_to_sample_from, n_prompts_to_select)\u001b[0m\n\u001b[1;32m 121\u001b[0m pbar \u001b[38;5;241m=\u001b[39m tqdm(\u001b[38;5;28mrange\u001b[39m(n_batches_to_sample_from))\n\u001b[1;32m 122\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m pbar:\n\u001b[0;32m--> 123\u001b[0m batch_tokens \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactivation_store\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_batch_tokens\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 124\u001b[0m batch_tokens \u001b[38;5;241m=\u001b[39m batch_tokens[torch\u001b[38;5;241m.\u001b[39mrandperm(batch_tokens\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m])][\n\u001b[1;32m 125\u001b[0m : batch_tokens\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]\n\u001b[1;32m 126\u001b[0m ]\n\u001b[1;32m 127\u001b[0m all_tokens_list\u001b[38;5;241m.\u001b[39mappend(batch_tokens)\n", - "File \u001b[0;32m~/Documents/Projects/SAELens/sae_lens/training/activations_store.py:227\u001b[0m, in \u001b[0;36mActivationsStore.get_batch_tokens\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m current_length \u001b[38;5;241m==\u001b[39m context_size:\n\u001b[1;32m 226\u001b[0m full_batch \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat(current_batch, dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m)\n\u001b[0;32m--> 227\u001b[0m batch_tokens \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcat\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_tokens\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfull_batch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\n\u001b[1;32m 229\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 230\u001b[0m current_batch \u001b[38;5;241m=\u001b[39m 
[]\n\u001b[1;32m 231\u001b[0m current_length \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner\n", "\n", @@ -111,6 +56,7 @@ " sae_path=SAE_PATH,\n", " model_id=MODEL_ID,\n", " sae_id=SAE_ID,\n", + " sparsity_threshold=-5,\n", " neuronpedia_outputs_folder=NP_OUTPUT_FOLDER,\n", " init_session=True,\n", " n_batches_to_sample_from=2**12,\n", diff --git a/tutorials/neuronpedia/make_batch.py b/tutorials/neuronpedia/make_batch.py index 79dbe5d9..6003742b 100644 --- a/tutorials/neuronpedia/make_batch.py +++ b/tutorials/neuronpedia/make_batch.py @@ -4,11 +4,12 @@ SAE_PATH = sys.argv[1] MODEL_ID = sys.argv[2] SAE_ID = sys.argv[3] -N_BATCHES_SAMPLE = int(sys.argv[4]) -N_PROMPTS_SELECT = int(sys.argv[5]) -FEATURES_AT_A_TIME = int(sys.argv[6]) -START_BATCH_INCLUSIVE = int(sys.argv[7]) -END_BATCH_INCLUSIVE = int(sys.argv[8]) +SPARSITY_THRESHOLD = int(sys.argv[4]) +N_BATCHES_SAMPLE = int(sys.argv[5]) +N_PROMPTS_SELECT = int(sys.argv[6]) +FEATURES_AT_A_TIME = int(sys.argv[7]) +START_BATCH_INCLUSIVE = int(sys.argv[8]) +END_BATCH_INCLUSIVE = int(sys.argv[9]) NP_OUTPUT_FOLDER = "../../neuronpedia_outputs" @@ -16,6 +17,7 @@ sae_path=SAE_PATH, model_id=MODEL_ID, sae_id=SAE_ID, + sparsity_threshold=SPARSITY_THRESHOLD, neuronpedia_outputs_folder=NP_OUTPUT_FOLDER, init_session=True, n_batches_to_sample_from=N_BATCHES_SAMPLE, diff --git a/tutorials/neuronpedia/make_features.sh b/tutorials/neuronpedia/make_features.sh index 149147b7..acb79148 100755 --- a/tutorials/neuronpedia/make_features.sh +++ b/tutorials/neuronpedia/make_features.sh @@ -5,42 +5,47 @@ echo "===== This will start a batch job that generates features to upload to Neuronpedia." echo "===== This takes input of one SAE directory at a time." -echo "===== Features will be output into ./neuronpedia_outputs/{model}_{hook_point}_{d_sae}/batch-{batch_num}.json" +echo "===== Features will be output into [repo_dir]/neuronpedia_outputs/{modelId}_{saeId}_{hook_point}/" echo "" -echo "(Step 1 of 8)" +echo "(Step 1 of 9)" echo "What is the absolute, full local file path to your SAE's directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors)?" read saepath # TODO: support huggingface directories echo "" -echo "(Step 2 of 8)" +echo "(Step 2 of 9)" echo "What's the model ID? This must exactly match (including casing) the model ID you created on Neuronpedia." read modelid echo "" -echo "(Step 3 of 8)" +echo "(Step 3 of 9)" echo "What's the SAE ID?" echo "This was set when you did 'Add SAEs' on Neuronpedia. This must exactly match that ID (including casing). It's in the format [abbrev hook name]-[abbrev author name], like res-jb." read saeid echo "" -echo "(Step 4 of 8)" +echo "(Step 4 of 9)" echo "How many features are in this SAE?" read numfeatures echo "" -echo "(Step 5 of 8)" +echo "(Step 5 of 9)" +read -p "What's your feature sparsity threshold? (default: -5): " sparsity +[ -z "${sparsity}" ] && sparsity='-5' + +echo "" +echo "(Step 6 of 9)" read -p "How many features do you want generate per batch file? More requires more RAM. 
(default: 128): " perbatch [ -z "${perbatch}" ] && perbatch='128' echo "" -echo "(Step 6 of 8)" +echo "(Step 7 of 9)" read -p "Enter number of batches to sample from (default: 4096): " batches [ -z "${batches}" ] && batches='4096' echo "" -echo "(Step 7 of 8)" +echo "(Step 8 of 9)" read -p "Enter number of prompts to select from (default: 24576): " prompts [ -z "${prompts}" ] && prompts='24576' @@ -49,7 +54,7 @@ numbatches=$(expr $numfeatures / $perbatch) echo "===== INFO: We'll generate $numbatches batches of $perbatch features per batch = $numfeatures total features" echo "" -echo "(Step 8 of 8)" +echo "(Step 9 of 9)" read -p "Do you want to resume from a specific batch number? Enter 1 to start from the beginning (default: 1): " startbatch [ -z "${startbatch}" ] && startbatch='1' @@ -57,15 +62,15 @@ endbatch=$(expr $numbatches) echo "" -echo "===== Features will be output into [repo_dir]/neuronpedia_outputs/{modelId}_{saeId}_{hook_point}/batch-{batch_num}.json" +echo "===== Features will be output into [repo_dir]/neuronpedia_outputs/{modelId}_{saeId}_{hook_point}/" read -p "===== Hit ENTER to start!" start for j in $(seq $startbatch $endbatch) do echo "" echo "===== BATCH: $j" - echo "RUNNING: python make_batch.py $saepath $modelid $saeid $leftbuffer $rightbuffer $batches $prompts $perbatch $j $j" - python make_batch.py $saepath $modelid $saeid $leftbuffer $rightbuffer $batches $prompts $perbatch $j $j + echo "RUNNING: python make_batch.py $saepath $modelid $saeid $sparsity $batches $prompts $perbatch $j $j" + python make_batch.py $saepath $modelid $saeid $sparsity $batches $prompts $perbatch $j $j done echo "" From 9067380bf67b89d8b2d235944f696016286f683e Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Mon, 15 Apr 2024 15:08:01 -0700 Subject: [PATCH 09/30] Upload dead feature stubs --- .../generating_neuronpedia_outputs.ipynb | 38 +++++++------------ tutorials/neuronpedia/upload_batch.py | 16 -------- .../neuronpedia/upload_dead_feature_stubs.py | 18 +++++++++ .../neuronpedia/upload_dead_feature_stubs.sh | 19 ++++++++++ tutorials/neuronpedia/upload_features.sh | 3 -- 5 files changed, 50 insertions(+), 44 deletions(-) create mode 100644 tutorials/neuronpedia/upload_dead_feature_stubs.py create mode 100755 tutorials/neuronpedia/upload_dead_feature_stubs.sh diff --git a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb index 2346ad9a..5e65b4ff 100644 --- a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb +++ b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb @@ -90,8 +90,7 @@ "import os\n", "import requests\n", "\n", - "folder_path = \"../../neuronpedia_outputs/gpt2-small_blocks.0.hook_resid_pre_24576\" #runner.neuronpedia_folder\n", - "\n", + "FEATURE_OUTPUTS_FOLDER = runner.outputs_folder\n", "\n", "def nanToNeg999(obj: Any) -> Any:\n", " if isinstance(obj, dict):\n", @@ -110,13 +109,12 @@ "\n", "# Server info\n", "host = \"http://localhost:3000\"\n", - "sourceName = str(LAYER) + \"-\" + SOURCE\n", "\n", "# Upload alive features\n", - "for file_name in os.listdir(folder_path):\n", + "for file_name in os.listdir(FEATURE_OUTPUTS_FOLDER):\n", " if file_name.startswith(\"batch-\") and file_name.endswith(\".json\"):\n", " print(\"Uploading file: \" + file_name)\n", - " file_path = os.path.join(folder_path, file_name)\n", + " file_path = os.path.join(FEATURE_OUTPUTS_FOLDER, file_name)\n", " f = open(file_path, \"r\")\n", " data = json.load(f)\n", "\n", @@ -127,28 +125,18 @@ " url = host + 
\"/api/local/upload-features\"\n", " resp = requests.post(\n", " url,\n", - " json={\n", - " \"modelId\": MODEL,\n", - " \"layer\": sourceName,\n", - " \"features\": data,\n", - " },\n", + " json=data,\n", " )\n", "\n", - "# Upload dead features (just makes blanks features)\n", - "# We want this for completeness\n", - "# skipped_path = os.path.join(folder_path, \"skipped_indexes.json\")\n", - "# f = open(skipped_path, \"r\")\n", - "# data = json.load(f)\n", - "# skipped_indexes = data[\"skipped_indexes\"]\n", - "# url = host + \"/api/internal/upload-dead-features\"\n", - "# resp = requests.post(\n", - "# url,\n", - "# json={\n", - "# \"modelId\": MODEL,\n", - "# \"layer\": sourceName,\n", - "# \"deadIndexes\": skipped_indexes,\n", - "# },\n", - "# )" + "# Upload dead feature stubs\n", + "skipped_path = os.path.join(FEATURE_OUTPUTS_FOLDER, \"skipped_indexes.json\")\n", + "f = open(skipped_path, \"r\")\n", + "data = json.load(f)\n", + "url = host + \"/api/local/upload-dead-features\"\n", + "resp = requests.post(\n", + " url,\n", + " json=data,\n", + ")" ] }, { diff --git a/tutorials/neuronpedia/upload_batch.py b/tutorials/neuronpedia/upload_batch.py index f599799e..b5b0919c 100644 --- a/tutorials/neuronpedia/upload_batch.py +++ b/tutorials/neuronpedia/upload_batch.py @@ -45,19 +45,3 @@ def encode(self, o: Any, *args: Any, **kwargs: Any): url, json=data, ) - -# Upload dead features (just makes blanks features) -# We want this for completeness -# skipped_path = os.path.join(folder_path, "skipped_indexes.json") -# f = open(skipped_path, "r") -# data = json.load(f) -# skipped_indexes = data["skipped_indexes"] -# url = host + "/api/internal/upload-dead-features" -# resp = requests.post( -# url, -# json={ -# "modelId": MODEL, -# "layer": sourceName, -# "deadIndexes": skipped_indexes, -# }, -# ) diff --git a/tutorials/neuronpedia/upload_dead_feature_stubs.py b/tutorials/neuronpedia/upload_dead_feature_stubs.py new file mode 100644 index 00000000..e2f4f543 --- /dev/null +++ b/tutorials/neuronpedia/upload_dead_feature_stubs.py @@ -0,0 +1,18 @@ +import json +import os +import requests +import sys + +FEATURE_OUTPUTS_FOLDER = sys.argv[1] + +# Server info +host = "http://localhost:3000" + +skipped_path = os.path.join(FEATURE_OUTPUTS_FOLDER, "skipped_indexes.json") +f = open(skipped_path, "r") +data = json.load(f) +url = host + "/api/local/upload-dead-features" +resp = requests.post( + url, + json=data, +) diff --git a/tutorials/neuronpedia/upload_dead_feature_stubs.sh b/tutorials/neuronpedia/upload_dead_feature_stubs.sh new file mode 100755 index 00000000..1b082ba6 --- /dev/null +++ b/tutorials/neuronpedia/upload_dead_feature_stubs.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +echo "===== This will create stubs for dead features using skipped_indexes.json." +echo "===== You'll need Neuronpedia running at localhost:3000 for this to work." + +echo "" +echo "(Step 1 of 1)" +echo "What is the absolute, full local DIRECTORY PATH to your Neuronpedia batch outputs?" +read outputfilesdir + +echo "" +read -p "===== Hit ENTER to start uploading!" start + +echo "RUNNING: python upload_dead_feature_stubs.py $outputfilesdir" +python upload_dead_feature_stubs.py $outputfilesdir + +echo "" +echo "===== ALL DONE." 
+echo "===== Go to http://localhost:3000 to browse your features" \ No newline at end of file diff --git a/tutorials/neuronpedia/upload_features.sh b/tutorials/neuronpedia/upload_features.sh index 92cc08a7..719eb681 100755 --- a/tutorials/neuronpedia/upload_features.sh +++ b/tutorials/neuronpedia/upload_features.sh @@ -1,8 +1,5 @@ #!/bin/bash -# we use a script around python to work around OOM issues - this ensures every batch gets the whole system available memory -# better fix is to investigate and fix the memory issues - echo "===== This will start upload the feature batch files to Neuronpedia." echo "===== You'll need Neuronpedia running at localhost:3000 for this to work." From f769e7a65ab84d4073852931a86ff3b5076eea3c Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Mon, 15 Apr 2024 16:54:39 -0700 Subject: [PATCH 10/30] Eindex required by sae_vis --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0545dc62..8fac2c16 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ matplotlib-inline = "^0.1.6" datasets = "^2.17.1" babe = "^0.0.7" nltk = "^3.8.1" +eindex-callum = {git = "https://github.com/callummcdougall/eindex.git"} sae-vis = { git = "https://github.com/hijohnnylin/sae_vis.git", branch = "allow_disable_buffer" } mkdocs = "^1.5.3" mkdocs-material = "^9.5.15" From 1e3d53ec2b72897bfebb6065f3b530fe65d3a97c Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Mon, 15 Apr 2024 22:30:30 -0700 Subject: [PATCH 11/30] Formatting --- sae_lens/analysis/neuronpedia_runner.py | 115 ++++++++++++++++-------- tutorials/neuronpedia/upload_batch.py | 6 +- 2 files changed, 83 insertions(+), 38 deletions(-) diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 0cd8ef16..5bb44dde 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -110,10 +110,14 @@ def init_sae_session(self): self.sae_path, device=self.device ) loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg) - self.model, _, self.activation_store = loader.load_sae_training_group_session() + self.model, _, self.activation_store = ( + loader.load_sae_training_group_session() + ) def get_tokens( - self, n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6 + self, + n_batches_to_sample_from: int = 2**12, + n_prompts_to_select: int = 4096 * 6, ): all_tokens_list = [] pbar = tqdm(range(n_batches_to_sample_from)) @@ -132,7 +136,9 @@ def round_list(self, to_round: list[float]): return list(np.round(to_round, 3)) def to_str_tokens_safe( - self, vocab_dict: Dict[int, str], tokens: Union[int, List[int], torch.Tensor] + self, + vocab_dict: Dict[int, str], + tokens: Union[int, List[int], torch.Tensor], ): """ does to_str_tokens, except handles out of range @@ -173,12 +179,16 @@ def run(self): sparsity = load_sparsity(self.sae_path) sparsity = sparsity.to(self.device) self.target_feature_indexes = ( - (sparsity > self.sparsity_threshold).nonzero(as_tuple=True)[0].tolist() + (sparsity > self.sparsity_threshold) + .nonzero(as_tuple=True)[0] + .tolist() ) # divide into batches feature_idx = torch.tensor(self.target_feature_indexes) - n_subarrays = np.ceil(len(feature_idx) / self.n_features_at_a_time).astype(int) + n_subarrays = np.ceil( + len(feature_idx) / self.n_features_at_a_time + ).astype(int) feature_idx = np.array_split(feature_idx, n_subarrays) feature_idx = [x.tolist() for x in feature_idx] @@ -193,7 +203,9 @@ def run(self): exit() # write dead into file so 
we can create them as dead in Neuronpedia - skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes) + skipped_indexes = set(range(self.n_features)) - set( + self.target_feature_indexes + ) skipped_indexes_json = json.dumps( { "model_id": self.model_id, @@ -209,7 +221,9 @@ def run(self): print(f"Total skipped: {len(skipped_indexes)}") print(f"Total batches: {len(feature_idx)}") - print(f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}") + print( + f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}" + ) print(f"Hook Point: {self.sparse_autoencoder.cfg.hook_point}") print(f"Writing files to: {self.outputs_folder}") @@ -237,7 +251,9 @@ def run(self): for k, v in vocab_dict.items(): modified_key = k for anomaly in HTML_ANOMALIES: - modified_key = modified_key.replace(anomaly, HTML_ANOMALIES[anomaly]) + modified_key = modified_key.replace( + anomaly, HTML_ANOMALIES[anomaly] + ) new_vocab_dict[v] = modified_key vocab_dict = new_vocab_dict @@ -253,7 +269,10 @@ def run(self): if feature_batch_count < self.start_batch: # print(f"Skipping batch - it's after start_batch: {feature_batch_count}") continue - if self.end_batch is not None and feature_batch_count > self.end_batch: + if ( + self.end_batch is not None + and feature_batch_count > self.end_batch + ): # print(f"Skipping batch - it's after end_batch: {feature_batch_count}") continue @@ -294,13 +313,17 @@ def run(self): ) features_outputs = [] - for _, feat_index in enumerate(feature_data.feature_data_dict.keys()): + for _, feat_index in enumerate( + feature_data.feature_data_dict.keys() + ): feature = feature_data.feature_data_dict[feat_index] feature_output = {} feature_output["featureIndex"] = feat_index - top10_logits = self.round_list(feature.logits_table_data.top_logits) + top10_logits = self.round_list( + feature.logits_table_data.top_logits + ) bottom10_logits = self.round_list( feature.logits_table_data.bottom_logits ) @@ -309,29 +332,41 @@ def run(self): feature_output["neuron_alignment_indices"] = ( feature.feature_tables_data.neuron_alignment_indices ) - feature_output["neuron_alignment_values"] = self.round_list( - feature.feature_tables_data.neuron_alignment_values + feature_output["neuron_alignment_values"] = ( + self.round_list( + feature.feature_tables_data.neuron_alignment_values + ) ) - feature_output["neuron_alignment_l1"] = self.round_list( - feature.feature_tables_data.neuron_alignment_l1 + feature_output["neuron_alignment_l1"] = ( + self.round_list( + feature.feature_tables_data.neuron_alignment_l1 + ) ) feature_output["correlated_neurons_indices"] = ( feature.feature_tables_data.correlated_neurons_indices ) - feature_output["correlated_neurons_l1"] = self.round_list( - feature.feature_tables_data.correlated_neurons_cossim + feature_output["correlated_neurons_l1"] = ( + self.round_list( + feature.feature_tables_data.correlated_neurons_cossim + ) ) - feature_output["correlated_neurons_pearson"] = self.round_list( - feature.feature_tables_data.correlated_neurons_pearson + feature_output["correlated_neurons_pearson"] = ( + self.round_list( + feature.feature_tables_data.correlated_neurons_pearson + ) ) feature_output["correlated_features_indices"] = ( feature.feature_tables_data.correlated_features_indices ) - feature_output["correlated_features_l1"] = self.round_list( - feature.feature_tables_data.correlated_features_cossim + feature_output["correlated_features_l1"] = ( + self.round_list( + feature.feature_tables_data.correlated_features_cossim + ) ) - 
feature_output["correlated_features_pearson"] = self.round_list( - feature.feature_tables_data.correlated_features_pearson + feature_output["correlated_features_pearson"] = ( + self.round_list( + feature.feature_tables_data.correlated_features_pearson + ) ) feature_output["neg_str"] = self.to_str_tokens_safe( @@ -345,9 +380,9 @@ def run(self): feature_output["frac_nonzero"] = ( float( - feature.acts_histogram_data.title.split(" = ")[1].split( - "%" - )[0] + feature.acts_histogram_data.title.split(" = ")[ + 1 + ].split("%")[0] ) / 100 if feature.acts_histogram_data.title is not None @@ -355,18 +390,22 @@ def run(self): ) freq_hist_data = feature.acts_histogram_data - freq_bar_values = self.round_list(freq_hist_data.bar_values) - feature_output["freq_hist_data_bar_values"] = freq_bar_values - feature_output["freq_hist_data_bar_heights"] = self.round_list( - freq_hist_data.bar_heights + freq_bar_values = self.round_list( + freq_hist_data.bar_values + ) + feature_output["freq_hist_data_bar_values"] = ( + freq_bar_values + ) + feature_output["freq_hist_data_bar_heights"] = ( + self.round_list(freq_hist_data.bar_heights) ) logits_hist_data = feature.logits_histogram_data - feature_output["logits_hist_data_bar_heights"] = self.round_list( - logits_hist_data.bar_heights + feature_output["logits_hist_data_bar_heights"] = ( + self.round_list(logits_hist_data.bar_heights) ) - feature_output["logits_hist_data_bar_values"] = self.round_list( - logits_hist_data.bar_values + feature_output["logits_hist_data_bar_values"] = ( + self.round_list(logits_hist_data.bar_values) ) feature_output["num_tokens_for_dashboard"] = ( @@ -420,8 +459,12 @@ def run(self): {"pos": posContribs, "neg": negContribs} ) activation["tokens"] = strs - activation["values"] = self.round_list(sd.feat_acts) - activation["maxValue"] = max(activation["values"]) + activation["values"] = self.round_list( + sd.feat_acts + ) + activation["maxValue"] = max( + activation["values"] + ) activation["lossValues"] = self.round_list( sd.loss_contribution ) diff --git a/tutorials/neuronpedia/upload_batch.py b/tutorials/neuronpedia/upload_batch.py index b5b0919c..56f583e6 100644 --- a/tutorials/neuronpedia/upload_batch.py +++ b/tutorials/neuronpedia/upload_batch.py @@ -15,7 +15,9 @@ def nanToNeg999(obj: Any) -> Any: return {k: nanToNeg999(v) for k, v in obj.items()} elif isinstance(obj, list): return [nanToNeg999(v) for v in obj] - elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan(obj): + elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan( + obj + ): return -999 return obj @@ -29,7 +31,7 @@ def encode(self, o: Any, *args: Any, **kwargs: Any): host = "http://localhost:3000" # Upload alive features -for file_name in os.listdir(FEATURE_OUTPUTS_FOLDER): +for file_name in sorted(os.listdir(FEATURE_OUTPUTS_FOLDER)): if file_name.startswith("batch-") and file_name.endswith(".json"): print("Uploading file: " + file_name) file_path = os.path.join(FEATURE_OUTPUTS_FOLDER, file_name) From 8d7d4040033fb80c5b994cdc662b0f90b8fcc7aa Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Tue, 16 Apr 2024 16:19:35 -0700 Subject: [PATCH 12/30] convert sparsity to log sparsity if needed --- sae_lens/analysis/neuronpedia_runner.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 5bb44dde..03bf1a71 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -177,6 +177,10 @@ def run(self): # if we have 
feature sparsity, then use it to only generate outputs for non-dead features self.target_feature_indexes: list[int] = [] sparsity = load_sparsity(self.sae_path) + # convert sparsity to logged sparsity if it's not + # TODO: standardize the sparsity file format + if len(sparsity) > 0 and sparsity[0] >= 0: + sparsity = torch.log10(sparsity + 1e-10) sparsity = sparsity.to(self.device) self.target_feature_indexes = ( (sparsity > self.sparsity_threshold) From b611e721dd2620ab5a030cc0f6e37029c30711ca Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 17 Apr 2024 00:30:46 -0700 Subject: [PATCH 13/30] Use python typer instead of shell script for neuronpedia jobs --- pyproject.toml | 1 + sae_lens/analysis/neuronpedia_runner.py | 87 ++-- .../generating_neuronpedia_outputs.ipynb | 16 +- tutorials/neuronpedia/make_batch.py | 17 +- tutorials/neuronpedia/make_features.sh | 79 ---- tutorials/neuronpedia/neuronpedia.py | 410 ++++++++++++++++++ tutorials/neuronpedia/upload_batch.py | 49 --- .../neuronpedia/upload_dead_feature_stubs.py | 18 - .../neuronpedia/upload_dead_feature_stubs.sh | 19 - tutorials/neuronpedia/upload_features.sh | 19 - 10 files changed, 455 insertions(+), 260 deletions(-) delete mode 100755 tutorials/neuronpedia/make_features.sh create mode 100755 tutorials/neuronpedia/neuronpedia.py delete mode 100644 tutorials/neuronpedia/upload_batch.py delete mode 100644 tutorials/neuronpedia/upload_dead_feature_stubs.py delete mode 100755 tutorials/neuronpedia/upload_dead_feature_stubs.sh delete mode 100755 tutorials/neuronpedia/upload_features.sh diff --git a/pyproject.toml b/pyproject.toml index 1fd371f6..42c6cb78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ mkdocs-section-index = "^0.3.8" mkdocstrings = "^0.24.1" mkdocstrings-python = "^1.9.0" safetensors = "^0.4.2" +typer = "^0.12.3" [tool.poetry.group.dev.dependencies] diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 03bf1a71..2ed9bf7f 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -4,7 +4,6 @@ # set TOKENIZERS_PARALLELISM to false to avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" import json -import time import numpy as np import torch from matplotlib import colors @@ -61,18 +60,15 @@ class NeuronpediaRunner: def __init__( self, - sae_path: str, - model_id: str, sae_id: str, + sae_path: str, + outputs_dir: str, sparsity_threshold: int = DEFAULT_SPARSITY_THRESHOLD, - neuronpedia_outputs_folder: str = "../../neuronpedia_outputs", - init_session: bool = True, # token pars n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6, - # sampling pars - n_features_at_a_time: int = 1024, - # start and end batch + # batching + n_features_at_a_time: int = 128, start_batch_inclusive: int = 0, end_batch_inclusive: Optional[int] = None, ): @@ -84,10 +80,14 @@ def __init__( self.device = "cuda" self.sae_path = sae_path - if init_session: - self.init_sae_session() - - self.model_id = model_id + self.sparse_autoencoder = SparseAutoencoder.load_from_pretrained( + self.sae_path, device=self.device + ) + loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg) + self.model, _, self.activation_store = ( + loader.load_sae_training_group_session() + ) + self.model_id = self.model.cfg.model_name self.layer = self.sparse_autoencoder.cfg.hook_point_layer self.sae_id = sae_id self.sparsity_threshold = sparsity_threshold @@ -97,22 +97,9 @@ def __init__( self.start_batch = start_batch_inclusive 
self.end_batch = end_batch_inclusive - if not os.path.exists(neuronpedia_outputs_folder): - os.makedirs(neuronpedia_outputs_folder) - self.neuronpedia_outputs_folder = neuronpedia_outputs_folder - - self.outputs_folder = f"{neuronpedia_outputs_folder}/{self.sparse_autoencoder.cfg.model_name}_{self.sae_id}_{self.sparse_autoencoder.cfg.hook_point}" - if not os.path.exists(self.outputs_folder): - os.makedirs(self.outputs_folder) - - def init_sae_session(self): - self.sparse_autoencoder = SparseAutoencoder.load_from_pretrained( - self.sae_path, device=self.device - ) - loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg) - self.model, _, self.activation_store = ( - loader.load_sae_training_group_session() - ) + if not os.path.exists(outputs_dir): + os.makedirs(outputs_dir) + self.outputs_dir = outputs_dir def get_tokens( self, @@ -164,13 +151,6 @@ def to_str_tokens_safe( return np.reshape(str_tokens, tokens.shape).tolist() def run(self): - """ - Generate the Neuronpedia outputs. - """ - - if self.model is None: - self.init_sae_session() - self.n_features = self.sparse_autoencoder.cfg.d_sae assert self.n_features is not None @@ -196,9 +176,9 @@ def run(self): feature_idx = np.array_split(feature_idx, n_subarrays) feature_idx = [x.tolist() for x in feature_idx] - print(f"==== Starting at batch: {self.start_batch}") - if self.end_batch is not None: - print(f"==== Ending at batch: {self.end_batch}") + # print(f"==== Starting Batch: {self.start_batch}") + # if self.end_batch is not None and self.end_batch != self.start_batch: + # print(f"==== Ending at Batch: {self.end_batch}") if self.start_batch > len(feature_idx) + 1: print( @@ -218,32 +198,18 @@ def run(self): "skipped_indexes": list(skipped_indexes), } ) - with open(f"{self.outputs_folder}/skipped_indexes.json", "w") as f: + with open(f"{self.outputs_dir}/skipped_indexes.json", "w") as f: f.write(skipped_indexes_json) - print(f"Total features to run: {len(self.target_feature_indexes)}") - print(f"Total skipped: {len(skipped_indexes)}") - print(f"Total batches: {len(feature_idx)}") - - print( - f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}" - ) - print(f"Hook Point: {self.sparse_autoencoder.cfg.hook_point}") - print(f"Writing files to: {self.outputs_folder}") - - tokens_file = f"{self.outputs_folder}/tokens_{self.n_batches_to_sample_from}_{self.n_prompts_to_select}.pt" - + tokens_file = f"{self.outputs_dir}/tokens_{self.n_batches_to_sample_from}_{self.n_prompts_to_select}.pt" if os.path.isfile(tokens_file): - print("Tokens exist, loading it") + print("Tokens exist, loading them.") tokens = torch.load(tokens_file) else: - start = time.time() + print("Tokens don't exist, making them.") tokens = self.get_tokens( self.n_batches_to_sample_from, self.n_prompts_to_select ) - end = time.time() - print(f"Time to get tokens: {end - start}") - print("Saved tokens to: " + tokens_file) torch.save( tokens, tokens_file, @@ -280,8 +246,9 @@ def run(self): # print(f"Skipping batch - it's after end_batch: {feature_batch_count}") continue - print(f"Doing batch: {feature_batch_count}") - print(f"{features_to_process}") + print( + f"========== Running Batch #{feature_batch_count} ==========" + ) layout = SaeVisLayoutConfig( columns=[ @@ -483,11 +450,13 @@ def run(self): "layer": str(self.layer), "sae_id": self.sae_id, "features": features_outputs, + "n_batches_to_sample_from": self.n_batches_to_sample_from, + "n_prompts_to_select": self.n_prompts_to_select, } json_object = json.dumps(to_write, cls=NpEncoder) with open( - 
f"{self.outputs_folder}/batch-{feature_batch_count}.json", + f"{self.outputs_dir}/batch-{feature_batch_count}.json", "w", ) as f: f.write(json_object) diff --git a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb index 5e65b4ff..a208f77c 100644 --- a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb +++ b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb @@ -51,20 +51,20 @@ "from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner\n", "\n", "print(SAE_PATH)\n", - "NP_OUTPUT_FOLDER = \"../../neuronpedia_outputs\"\n", + "NP_OUTPUT_FOLDER = \"../../neuronpedia_outputs/my_outputs\"\n", + "\n", "runner = NeuronpediaRunner(\n", - " sae_path=SAE_PATH,\n", - " model_id=MODEL_ID,\n", " sae_id=SAE_ID,\n", + " sae_path=SAE_PATH,\n", + " outputs_dir=NP_OUTPUT_FOLDER,\n", " sparsity_threshold=-5,\n", - " neuronpedia_outputs_folder=NP_OUTPUT_FOLDER,\n", - " init_session=True,\n", " n_batches_to_sample_from=2**12,\n", - " n_prompts_to_select=4096 * 6,\n", + " n_prompts_to_select=4096*6,\n", " n_features_at_a_time=24,\n", " start_batch_inclusive=1,\n", " end_batch_inclusive=1,\n", ")\n", + "\n", "runner.run()" ] }, @@ -90,7 +90,7 @@ "import os\n", "import requests\n", "\n", - "FEATURE_OUTPUTS_FOLDER = runner.outputs_folder\n", + "FEATURE_OUTPUTS_FOLDER = runner.outputs_dir\n", "\n", "def nanToNeg999(obj: Any) -> Any:\n", " if isinstance(obj, dict):\n", @@ -163,7 +163,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/tutorials/neuronpedia/make_batch.py b/tutorials/neuronpedia/make_batch.py index 6003742b..b5f28e82 100644 --- a/tutorials/neuronpedia/make_batch.py +++ b/tutorials/neuronpedia/make_batch.py @@ -1,9 +1,12 @@ import sys from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner -SAE_PATH = sys.argv[1] -MODEL_ID = sys.argv[2] -SAE_ID = sys.argv[3] +# we use another python script to launch this using subprocess to work around OOM issues - this ensures every batch gets the whole system available memory +# better fix is to investigate and fix the memory issues + +SAE_ID = sys.argv[1] +SAE_PATH = sys.argv[2] +OUTPUTS_DIR = sys.argv[3] SPARSITY_THRESHOLD = int(sys.argv[4]) N_BATCHES_SAMPLE = int(sys.argv[5]) N_PROMPTS_SELECT = int(sys.argv[6]) @@ -11,15 +14,11 @@ START_BATCH_INCLUSIVE = int(sys.argv[8]) END_BATCH_INCLUSIVE = int(sys.argv[9]) -NP_OUTPUT_FOLDER = "../../neuronpedia_outputs" - runner = NeuronpediaRunner( - sae_path=SAE_PATH, - model_id=MODEL_ID, sae_id=SAE_ID, + sae_path=SAE_PATH, + outputs_dir=OUTPUTS_DIR, sparsity_threshold=SPARSITY_THRESHOLD, - neuronpedia_outputs_folder=NP_OUTPUT_FOLDER, - init_session=True, n_batches_to_sample_from=N_BATCHES_SAMPLE, n_prompts_to_select=N_PROMPTS_SELECT, n_features_at_a_time=FEATURES_AT_A_TIME, diff --git a/tutorials/neuronpedia/make_features.sh b/tutorials/neuronpedia/make_features.sh deleted file mode 100755 index acb79148..00000000 --- a/tutorials/neuronpedia/make_features.sh +++ /dev/null @@ -1,79 +0,0 @@ -#!/bin/bash - -# we use a script around python to work around OOM issues - this ensures every batch gets the whole system available memory -# better fix is to investigate and fix the memory issues - -echo "===== This will start a batch job that generates features to upload to Neuronpedia." -echo "===== This takes input of one SAE directory at a time." 
-echo "===== Features will be output into [repo_dir]/neuronpedia_outputs/{modelId}_{saeId}_{hook_point}/" - -echo "" -echo "(Step 1 of 9)" -echo "What is the absolute, full local file path to your SAE's directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors)?" -read saepath -# TODO: support huggingface directories - -echo "" -echo "(Step 2 of 9)" -echo "What's the model ID? This must exactly match (including casing) the model ID you created on Neuronpedia." -read modelid - -echo "" -echo "(Step 3 of 9)" -echo "What's the SAE ID?" -echo "This was set when you did 'Add SAEs' on Neuronpedia. This must exactly match that ID (including casing). It's in the format [abbrev hook name]-[abbrev author name], like res-jb." -read saeid - -echo "" -echo "(Step 4 of 9)" -echo "How many features are in this SAE?" -read numfeatures - -echo "" -echo "(Step 5 of 9)" -read -p "What's your feature sparsity threshold? (default: -5): " sparsity -[ -z "${sparsity}" ] && sparsity='-5' - -echo "" -echo "(Step 6 of 9)" -read -p "How many features do you want generate per batch file? More requires more RAM. (default: 128): " perbatch -[ -z "${perbatch}" ] && perbatch='128' - -echo "" -echo "(Step 7 of 9)" -read -p "Enter number of batches to sample from (default: 4096): " batches -[ -z "${batches}" ] && batches='4096' - -echo "" -echo "(Step 8 of 9)" -read -p "Enter number of prompts to select from (default: 24576): " prompts -[ -z "${prompts}" ] && prompts='24576' - -echo "" -numbatches=$(expr $numfeatures / $perbatch) -echo "===== INFO: We'll generate $numbatches batches of $perbatch features per batch = $numfeatures total features" - -echo "" -echo "(Step 9 of 9)" -read -p "Do you want to resume from a specific batch number? Enter 1 to start from the beginning (default: 1): " startbatch -[ -z "${startbatch}" ] && startbatch='1' - -endbatch=$(expr $numbatches) - - -echo "" -echo "===== Features will be output into [repo_dir]/neuronpedia_outputs/{modelId}_{saeId}_{hook_point}/" -read -p "===== Hit ENTER to start!" start - -for j in $(seq $startbatch $endbatch) - do - echo "" - echo "===== BATCH: $j" - echo "RUNNING: python make_batch.py $saepath $modelid $saeid $sparsity $batches $prompts $perbatch $j $j" - python make_batch.py $saepath $modelid $saeid $sparsity $batches $prompts $perbatch $j $j -done - -echo "" -echo "===== ALL DONE." -echo "===== Your features are under: [repo_dir]/neuronpedia_outputs/{model}_{hook_point}_{d_sae}" -echo "===== Use upload_features.sh to upload your features. Be sure to have the localhost server running first." 
\ No newline at end of file diff --git a/tutorials/neuronpedia/neuronpedia.py b/tutorials/neuronpedia/neuronpedia.py new file mode 100755 index 00000000..fbf32496 --- /dev/null +++ b/tutorials/neuronpedia/neuronpedia.py @@ -0,0 +1,410 @@ +# we use a script that launches separate python processes to work around OOM issues - this ensures every batch gets the whole system available memory +# better fix is to investigate and fix the memory issues + +import json +import os +import requests +import typer +import torch +import math +import subprocess +from typing import Any +from decimal import Decimal +from pathlib import Path +from typing_extensions import Annotated +from rich import print +from rich.align import Align +from rich.panel import Panel +from sae_lens.training.sparse_autoencoder import SparseAutoencoder +from sae_lens.toolkit.pretrained_saes import load_sparsity + +OUTPUT_DIR_BASE = Path("../../neuronpedia_outputs") + +app = typer.Typer( + add_completion=False, + no_args_is_help=True, + help="Tool that generates features (generate) and uploads features (upload) to Neuronpedia.", +) + + +@app.command() +def generate( + sae_id: Annotated[ + str, + typer.Option( + help="SAE ID to generate features for (must exactly match the one used on Neuronpedia). Example: res-jb", + prompt=""" +What is the SAE ID you want to generate features for? +This was set when you did 'Add SAEs' on Neuronpedia. This must exactly match that ID (including casing). +It's in the format [abbrev hook name]-[abbrev author name], like res-jb. +Enter SAE ID""", + ), + ], + sae_path: Annotated[ + Path, + typer.Option( + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + help="Absolute local path to the SAE directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors).", + prompt=""" +What is the absolute local path to your SAE's directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors)? +Enter path""", + ), + ], + log_sparsity: Annotated[ + int, + typer.Option( + min=-10, + max=0, + help="Desired feature log sparsity threshold. Range -10 to 0.", + prompt=""" +What is your desired feature log sparsity threshold? +Enter value from -10 to 0""", + ), + ] = -5, + feat_per_batch: Annotated[ + int, + typer.Option( + min=1, + max=2048, + help="Features to generate per batch. More requires more memory.", + prompt=""" +How many features do you want to generate per batch? More requires more memory. +Enter value""", + ), + ] = 128, + resume_from_batch: Annotated[ + int, + typer.Option( + min=1, + help="Batch number to resume from.", + prompt=""" +Do you want to resume from a specific batch number? +Enter 1 to start from the beginning""", + ), + ] = 1, + n_batches_to_sample: Annotated[ + int, + typer.Option( + min=1, + help="[Activation Text Generation] Number of batches to sample from.", + prompt=""" +[Activation Text Generation] How many batches do you want to sample from? +Enter value""", + ), + ] = 2 + ** 12, + n_prompts_to_select: Annotated[ + int, + typer.Option( + min=1, + help="[Activation Text Generation] Number of prompts to select from.", + prompt=""" +[Activation Text Generation] How many prompts do you want to select from? +Enter value""", + ), + ] = 4096 + * 6, +): + """ + This will start a batch job that generates features for Neuronpedia for a specific SAE. To upload those features, use the 'upload' command afterwards. 
+ """ + + # Check arguments + if sae_path.is_dir() is not True: + print("Error: SAE path must be a directory.") + raise typer.Abort() + if sae_path.joinpath("cfg.json").is_file() is not True: + print("Error: cfg.json file not found in SAE directory.") + raise typer.Abort() + if sae_path.joinpath("sae_weights.safetensors").is_file() is not True: + print( + "Error: sae_weights.safetensors file not found in SAE directory." + ) + raise typer.Abort() + if sae_path.joinpath("sparsity.safetensors").is_file() is not True: + print("Error: sparsity.safetensors file not found in SAE directory.") + raise typer.Abort() + + sae_path_string = sae_path.as_posix() + + # Load SAE + device = "cpu" + if torch.backends.mps.is_available(): + device = "mps" + elif torch.cuda.is_available(): + device = "cuda" + sparse_autoencoder = SparseAutoencoder.load_from_pretrained( + sae_path_string, device=device + ) + model_id = sparse_autoencoder.cfg.model_name + + outputs_subdir = f"{model_id}_{sae_id}_{sparse_autoencoder.cfg.hook_point}" + outputs_dir = OUTPUT_DIR_BASE.joinpath(outputs_subdir) + if outputs_dir.exists() and outputs_dir.is_file(): + print( + f"Error: Output directory {outputs_dir.as_posix()} exists and is a file." + ) + raise typer.Abort() + outputs_dir.mkdir(parents=True, exist_ok=True) + # Check if output_dir has any files starting with "batch_" + batch_files = list(outputs_dir.glob("batch-*.json")) + if len(batch_files) > 0 and resume_from_batch != 1: + print( + f"Error: Output directory {outputs_dir.as_posix()} has existing batch files. This is only allowed if you are resuming from a batch. Please delete or move the existing batch-*.json files." + ) + raise typer.Abort() + + sparsity = load_sparsity(sae_path_string) + # convert sparsity to logged sparsity if it's not + # TODO: standardize the sparsity file format + if len(sparsity) > 0 and sparsity[0] >= 0: + sparsity = torch.log10(sparsity + 1e-10) + sparsity = sparsity.to(device) + alive_indexes = ( + (sparsity > log_sparsity).nonzero(as_tuple=True)[0].tolist() + ) + num_alive = len(alive_indexes) + num_dead = sparse_autoencoder.d_sae - num_alive + + print("\n") + print( + Align.center( + Panel.fit( + f""" +[white]SAE Path: [green]{sae_path.as_posix()} +[white]Model ID: [green]{model_id} +[white]Hook Point: [green]{sparse_autoencoder.cfg.hook_point} +[white]Using Device: [green]{device} +""", + title="SAE Info", + ) + ) + ) + num_batches = math.ceil(num_alive / feat_per_batch) + print( + Align.center( + Panel.fit( + f""" +[white]Total Features: [green]{sparse_autoencoder.d_sae} +[white]Log Sparsity Threshold: [green]{log_sparsity} +[white]Alive Features: [green]{num_alive} +[white]Dead Features: [red]{num_dead} +[white]Features per Batch: [green]{feat_per_batch} +[white]Number of Batches: [green]{num_batches} +{resume_from_batch != 1 and f"[white]Resuming from Batch: [green]{resume_from_batch}" or ""} +""", + title="Number of Features", + ) + ) + ) + print( + Align.center( + Panel.fit( + f""" +[white]Dataset: [green]{sparse_autoencoder.cfg.dataset_path} +[white]Batches to Sample From: [green]{n_batches_to_sample} +[white]Prompts to Select From: [green]{n_prompts_to_select} +""", + title="Activation Text Settings", + ) + ) + ) + print( + Align.center( + Panel.fit( + f""" +[green]{outputs_dir.absolute().as_posix()} +""", + title="Output Directory", + ) + ) + ) + + print( + Align.center( + "\n========== [yellow]Starting batch feature generations...[/yellow] ==========" + ) + ) + + # iterate from 1 to num_batches + for i in range(1, num_batches + 1): + 
command = [ + "python", + "make_batch.py", + sae_id, + sae_path.absolute().as_posix(), + outputs_dir.absolute().as_posix(), + str(log_sparsity), + str(n_batches_to_sample), + str(n_prompts_to_select), + str(feat_per_batch), + str(i), + str(i), + ] + print("\n") + print( + Align.center( + Panel.fit( + f""" +[yellow]{" ".join(command)} +""", + title="Running Command for Batch #" + str(i), + ) + ) + ) + # make a subprocess call to python make_batch.py + subprocess.run( + [ + "python", + "make_batch.py", + sae_id, + sae_path, + outputs_dir, + str(log_sparsity), + str(n_batches_to_sample), + str(n_prompts_to_select), + str(feat_per_batch), + str(i), + str(i), + ] + ) + + print( + Align.center( + Panel( + f""" +Your Features Are In: [green]{outputs_dir.absolute().as_posix()} +Use [yellow]'neuronpedia.py upload'[/yellow] to upload your features to Neuronpedia. +""", + title="Generation Complete", + ) + ) + ) + + +@app.command() +def upload( + outputs_dir: Annotated[ + Path, + typer.Option( + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + prompt="What is the absolute, full local file path to the feature outputs directory?", + ), + ], + host: Annotated[ + str, + typer.Option( + prompt="""Host to upload to? (Default: http://localhost:3000)""", + ), + ] = "http://localhost:3000", +): + """ + This will upload features that were generated to Neuronpedia. It currently only works if you have admin access to a Neuronpedia instance via localhost:3000. + """ + + files_to_upload = list(outputs_dir.glob("batch-*.json")) + + # sort files by batch number + files_to_upload.sort(key=lambda x: int(x.stem.split("-")[1])) + + print("\n") + # Upload alive features + for file_path in files_to_upload: + print("===== Uploading file: " + os.path.basename(file_path)) + f = open(file_path, "r") + data = json.load(f) + + # Replace NaNs + data_fixed = json.dumps(data, cls=NanConverter) + data = json.loads(data_fixed) + + url = host + "/api/local/upload-features" + requests.post( + url, + json=data, + ) + + print( + Align.center( + Panel( + f""" +{len(files_to_upload)} batch files uploaded to Neuronpedia. +""", + title="Uploads Complete", + ) + ) + ) + + +@app.command() +def upload_dead_stubs( + outputs_dir: Annotated[ + Path, + typer.Option( + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + prompt="What is the absolute, full local file path to the feature outputs directory?", + ), + ], + host: Annotated[ + str, + typer.Option( + prompt="""Host to upload to? (Default: http://localhost:3000)""", + ), + ] = "http://localhost:3000", +): + """ + This will create "There are no activations for this feature" stubs for dead features on Neuronpedia. It currently only works if you have admin access to a Neuronpedia instance via localhost:3000. + """ + + skipped_path = os.path.join(outputs_dir, "skipped_indexes.json") + f = open(skipped_path, "r") + data = json.load(f) + url = host + "/api/local/upload-dead-features" + requests.post( + url, + json=data, + ) + + print( + Align.center( + Panel( + """ +Dead feature stubs created. 
+""", + title="Complete", + ) + ) + ) + + +# Helper utilities that help fix weird NaNs in the feature outputs + + +def nanToNeg999(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: nanToNeg999(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [nanToNeg999(v) for v in obj] + elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan( + obj + ): + return -999 + return obj + + +class NanConverter(json.JSONEncoder): + def encode(self, o: Any, *args: Any, **kwargs: Any): + return super().encode(nanToNeg999(o), *args, **kwargs) + + +if __name__ == "__main__": + app() diff --git a/tutorials/neuronpedia/upload_batch.py b/tutorials/neuronpedia/upload_batch.py deleted file mode 100644 index 56f583e6..00000000 --- a/tutorials/neuronpedia/upload_batch.py +++ /dev/null @@ -1,49 +0,0 @@ -# Helpers that fix weird NaN stuff -from decimal import Decimal -from typing import Any -import math -import json -import os -import requests -import sys - -FEATURE_OUTPUTS_FOLDER = sys.argv[1] - - -def nanToNeg999(obj: Any) -> Any: - if isinstance(obj, dict): - return {k: nanToNeg999(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [nanToNeg999(v) for v in obj] - elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan( - obj - ): - return -999 - return obj - - -class NanConverter(json.JSONEncoder): - def encode(self, o: Any, *args: Any, **kwargs: Any): - return super().encode(nanToNeg999(o), *args, **kwargs) - - -# Server info -host = "http://localhost:3000" - -# Upload alive features -for file_name in sorted(os.listdir(FEATURE_OUTPUTS_FOLDER)): - if file_name.startswith("batch-") and file_name.endswith(".json"): - print("Uploading file: " + file_name) - file_path = os.path.join(FEATURE_OUTPUTS_FOLDER, file_name) - f = open(file_path, "r") - data = json.load(f) - - # Replace NaNs - data_fixed = json.dumps(data, cls=NanConverter) - data = json.loads(data_fixed) - - url = host + "/api/local/upload-features" - resp = requests.post( - url, - json=data, - ) diff --git a/tutorials/neuronpedia/upload_dead_feature_stubs.py b/tutorials/neuronpedia/upload_dead_feature_stubs.py deleted file mode 100644 index e2f4f543..00000000 --- a/tutorials/neuronpedia/upload_dead_feature_stubs.py +++ /dev/null @@ -1,18 +0,0 @@ -import json -import os -import requests -import sys - -FEATURE_OUTPUTS_FOLDER = sys.argv[1] - -# Server info -host = "http://localhost:3000" - -skipped_path = os.path.join(FEATURE_OUTPUTS_FOLDER, "skipped_indexes.json") -f = open(skipped_path, "r") -data = json.load(f) -url = host + "/api/local/upload-dead-features" -resp = requests.post( - url, - json=data, -) diff --git a/tutorials/neuronpedia/upload_dead_feature_stubs.sh b/tutorials/neuronpedia/upload_dead_feature_stubs.sh deleted file mode 100755 index 1b082ba6..00000000 --- a/tutorials/neuronpedia/upload_dead_feature_stubs.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -echo "===== This will create stubs for dead features using skipped_indexes.json." -echo "===== You'll need Neuronpedia running at localhost:3000 for this to work." - -echo "" -echo "(Step 1 of 1)" -echo "What is the absolute, full local DIRECTORY PATH to your Neuronpedia batch outputs?" -read outputfilesdir - -echo "" -read -p "===== Hit ENTER to start uploading!" start - -echo "RUNNING: python upload_dead_feature_stubs.py $outputfilesdir" -python upload_dead_feature_stubs.py $outputfilesdir - -echo "" -echo "===== ALL DONE." 
-echo "===== Go to http://localhost:3000 to browse your features" \ No newline at end of file diff --git a/tutorials/neuronpedia/upload_features.sh b/tutorials/neuronpedia/upload_features.sh deleted file mode 100755 index 719eb681..00000000 --- a/tutorials/neuronpedia/upload_features.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -echo "===== This will start upload the feature batch files to Neuronpedia." -echo "===== You'll need Neuronpedia running at localhost:3000 for this to work." - -echo "" -echo "(Step 1 of 1)" -echo "What is the absolute, full local DIRECTORY PATH to your Neuronpedia batch outputs?" -read outputfilesdir - -echo "" -read -p "===== Hit ENTER to start uploading!" start - -echo "RUNNING: python upload_batch.py $outputfilesdir" -python upload_batch.py $outputfilesdir - -echo "" -echo "===== ALL DONE." -echo "===== Go to http://localhost:3000 to browse your features" \ No newline at end of file From 932f380971ce3d431e6592c804d12f6df2b4ec78 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 17 Apr 2024 00:34:06 -0700 Subject: [PATCH 14/30] Fix upload skipped/dead features --- tutorials/neuronpedia/neuronpedia.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/neuronpedia/neuronpedia.py b/tutorials/neuronpedia/neuronpedia.py index fbf32496..380a39fc 100755 --- a/tutorials/neuronpedia/neuronpedia.py +++ b/tutorials/neuronpedia/neuronpedia.py @@ -368,7 +368,7 @@ def upload_dead_stubs( skipped_path = os.path.join(outputs_dir, "skipped_indexes.json") f = open(skipped_path, "r") data = json.load(f) - url = host + "/api/local/upload-dead-features" + url = host + "/api/local/upload-skipped-features" requests.post( url, json=data, From eea7db4b99098c33cd862e7e2280a32b630826bd Mon Sep 17 00:00:00 2001 From: Phylliida Dev Date: Wed, 17 Apr 2024 09:39:06 -0700 Subject: [PATCH 15/30] feat: Mamba support vs mamba-lens (#79) * mamba support * added init * added optional model kwargs * Support transformers and mamba * forgot one model kwargs * failed opts * tokens input * hack to fix tokens, will look into fixing mambalens * fixed checkpoint * added sae group * removed some comments and fixed merge error * removed unneeded params since that issue is fixed in mambalens now * Unneded input param * removed debug checkpoing and eval * added refs to hookedrootmodule * feed linter * added example and fixed loading * made layer for eval change * fix linter issues * adding mamba-lens as optional dep, and fixing typing/linting * adding a test for loading mamba model * adding mamba-lens to dev for CI * updating min mamba-lens version * updating mamba-lens version --------- Co-authored-by: David Chanin --- __init__.py | 0 pyproject.toml | 5 ++ sae_lens/analysis/dashboard_runner.py | 10 +++- sae_lens/analysis/neuronpedia_runner.py | 10 +++- sae_lens/training/activations_store.py | 14 +++-- sae_lens/training/cache_activations_runner.py | 9 ++- sae_lens/training/config.py | 7 ++- sae_lens/training/evals.py | 35 +++++++---- sae_lens/training/load_model.py | 26 ++++++++ sae_lens/training/sae_group.py | 8 +-- sae_lens/training/session_loader.py | 15 +++-- .../training/train_sae_on_language_model.py | 6 +- tests/unit/training/test_load_model.py | 13 ++++ tutorials/mamba_train_example.py | 60 +++++++++++++++++++ 14 files changed, 182 insertions(+), 36 deletions(-) create mode 100644 __init__.py create mode 100644 sae_lens/training/load_model.py create mode 100644 tests/unit/training/test_load_model.py create mode 100644 tutorials/mamba_train_example.py diff --git 
a/__init__.py b/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index d07f13f0..4f62f578 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ mkdocs-section-index = "^0.3.8" mkdocstrings = "^0.24.1" mkdocstrings-python = "^1.9.0" safetensors = "^0.4.2" +mamba-lens = { version = "^0.0.4", optional = true } [tool.poetry.group.dev.dependencies] @@ -38,6 +39,10 @@ pre-commit = "^3.6.2" flake8 = "^7.0.0" isort = "^5.13.2" pyright = "^1.1.351" +mamba-lens = "^0.0.4" + +[tool.poetry.extras] +mamba = ["mamba-lens"] [tool.isort] diff --git a/sae_lens/analysis/dashboard_runner.py b/sae_lens/analysis/dashboard_runner.py index ebc25fc7..d77be35d 100644 --- a/sae_lens/analysis/dashboard_runner.py +++ b/sae_lens/analysis/dashboard_runner.py @@ -22,6 +22,7 @@ from sae_vis.data_fetching_fns import get_feature_data from torch.nn.functional import cosine_similarity from tqdm import tqdm +from transformer_lens import HookedTransformer import wandb from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader @@ -29,6 +30,8 @@ class DashboardRunner: + model: HookedTransformer | None = None + def __init__( self, sae_path: Optional[str] = None, @@ -131,10 +134,14 @@ def get_dashboard_folder_name(self): def init_sae_session(self): ( - self.model, + model, sae_group, self.activation_store, ) = LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path) + assert isinstance( + model, HookedTransformer + ) # only HookedTransformer is allowed to be used in the dashboard + self.model = model # TODO: handle multiple autoencoders self.sparse_autoencoder = next(iter(sae_group))[1] @@ -316,6 +323,7 @@ def run(self): if self.use_wandb: wandb.log({"time/time_to_get_tokens": end - start}) + assert self.model is not None vocab_dict = cast(Any, self.model.tokenizer).vocab vocab_dict = { v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items() diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 96475d64..14c39556 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -19,6 +19,7 @@ ) from sae_vis.data_fetching_fns import get_feature_data from tqdm import tqdm +from transformer_lens import HookedTransformer from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader @@ -42,6 +43,8 @@ def default(self, o: Any): class NeuronpediaRunner: + model: HookedTransformer | None = None + def __init__( self, sae_path: str, @@ -91,10 +94,13 @@ def get_folder_name(self): def init_sae_session(self): ( - self.model, + model, sae_group, self.activation_store, ) = LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path) + # only HookedTransformer works with this runner + assert isinstance(model, HookedTransformer) + self.model = model # TODO: handle multiple autoencoders self.sparse_autoencoder = next(iter(sae_group))[1] @@ -123,6 +129,7 @@ def to_str_tokens_safe( """ does to_str_tokens, except handles out of range """ + assert self.model is not None vocab_max_index = self.model.cfg.d_vocab - 1 # Deal with the int case separately if isinstance(tokens, int): @@ -205,6 +212,7 @@ def run(self): end = time.time() print(f"Time to get tokens: {end - start}") + assert self.model is not None vocab_dict = cast(Any, self.model.tokenizer).vocab vocab_dict = { v: k.replace("Ġ", " ").replace("\n", "\\n").replace("Ċ", "\n") diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 
f5c33ebb..604f7ba4 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -12,7 +12,7 @@ load_dataset, ) from torch.utils.data import DataLoader -from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookedRootModule from sae_lens.training.config import ( CacheActivationsRunnerConfig, @@ -28,7 +28,7 @@ class ActivationsStore: while training SAEs. """ - model: HookedTransformer + model: HookedRootModule dataset: HfDataset cached_activations_path: str | None tokens_column: Literal["tokens", "input_ids", "text"] @@ -39,7 +39,7 @@ class ActivationsStore: @classmethod def from_config( cls, - model: HookedTransformer, + model: HookedRootModule, cfg: LanguageModelSAERunnerConfig | CacheActivationsRunnerConfig, dataset: HfDataset | None = None, ) -> "ActivationsStore": @@ -66,11 +66,12 @@ def from_config( device=cfg.device, dtype=cfg.dtype, cached_activations_path=cached_activations_path, + model_kwargs=cfg.model_kwargs, ) def __init__( self, - model: HookedTransformer, + model: HookedRootModule, dataset: HfDataset | str, hook_point: str, hook_point_layers: list[int], @@ -85,8 +86,12 @@ def __init__( device: str | torch.device, dtype: str | torch.dtype, cached_activations_path: str | None = None, + model_kwargs: dict[str, Any] | None = None, ): self.model = model + if model_kwargs is None: + model_kwargs = {} + self.model_kwargs = model_kwargs self.dataset = ( load_dataset(dataset, split="train", streaming=True) if isinstance(dataset, str) @@ -248,6 +253,7 @@ def get_activations(self, batch_tokens: torch.Tensor): names_filter=act_names, stop_at_layer=hook_point_max_layer + 1, prepend_bos=self.prepend_bos, + **self.model_kwargs, )[1] activations_list = [layerwise_activations[act_name] for act_name in act_names] if self.hook_point_head_index is not None: diff --git a/sae_lens/training/cache_activations_runner.py b/sae_lens/training/cache_activations_runner.py index 2c939317..443b399e 100644 --- a/sae_lens/training/cache_activations_runner.py +++ b/sae_lens/training/cache_activations_runner.py @@ -3,16 +3,19 @@ import torch from tqdm import tqdm -from transformer_lens import HookedTransformer from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.config import CacheActivationsRunnerConfig +from sae_lens.training.load_model import load_model from sae_lens.training.utils import shuffle_activations_pairwise def cache_activations_runner(cfg: CacheActivationsRunnerConfig): - model = HookedTransformer.from_pretrained(cfg.model_name) - model.to(cfg.device) + model = load_model( + model_class_name=cfg.model_class_name, + model_name=cfg.model_name, + device=cfg.device, + ) activations_store = ActivationsStore.from_config( model, cfg, diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 1d531f51..2a91b86f 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Optional, cast import torch @@ -21,7 +21,9 @@ class LanguageModelSAERunnerConfig: # Data Generating Function (Model + Training Distibuion) model_name: str = "gelu-2l" + model_class_name: str = "HookedTransformer" hook_point: str = "blocks.{layer}.hook_mlp_out" + hook_point_eval: str = "blocks.{layer}.attn.pattern" hook_point_layer: int | list[int] = 0 hook_point_head_index: Optional[int] = None dataset_path: str = "NeelNanda/c4-tokenized-2b" @@ -91,6 +93,7 @@ class 
LanguageModelSAERunnerConfig: n_checkpoints: int = 0 checkpoint_path: str = "checkpoints" verbose: bool = True + model_kwargs: dict[str, Any] = field(default_factory=dict) def __post_init__(self): if self.use_cached_activations and self.cached_activations_path is None: @@ -190,6 +193,7 @@ class CacheActivationsRunnerConfig: # Data Generating Function (Model + Training Distibuion) model_name: str = "gelu-2l" + model_class_name: str = "HookedTransformer" hook_point: str = "blocks.{layer}.hook_mlp_out" hook_point_layer: int | list[int] = 0 hook_point_head_index: Optional[int] = None @@ -220,6 +224,7 @@ class CacheActivationsRunnerConfig: n_shuffles_with_last_section: int = 10 n_shuffles_in_entire_dir: int = 10 n_shuffles_final: int = 100 + model_kwargs: dict[str, Any] = field(default_factory=dict) def __post_init__(self): # Autofill cached_activations_path unless the user overrode it diff --git a/sae_lens/training/evals.py b/sae_lens/training/evals.py index 25a59e3f..80aa91dd 100644 --- a/sae_lens/training/evals.py +++ b/sae_lens/training/evals.py @@ -3,8 +3,7 @@ import pandas as pd import torch -from transformer_lens import HookedTransformer -from transformer_lens.utils import get_act_name +from transformer_lens.hook_points import HookedRootModule import wandb from sae_lens.training.activations_store import ActivationsStore @@ -15,14 +14,16 @@ def run_evals( sparse_autoencoder: SparseAutoencoder, activation_store: ActivationsStore, - model: HookedTransformer, + model: HookedRootModule, n_training_steps: int, suffix: str = "", ) -> Mapping[str, Any]: hook_point = sparse_autoencoder.cfg.hook_point hook_point_layer = sparse_autoencoder.hook_point_layer hook_point_head_index = sparse_autoencoder.cfg.hook_point_head_index - + hook_point_eval = sparse_autoencoder.cfg.hook_point_eval.format( + layer=hook_point_layer + ) ### Evals eval_tokens = activation_store.get_batch_tokens() @@ -43,7 +44,8 @@ def run_evals( _, cache = model.run_with_cache( eval_tokens, prepend_bos=False, - names_filter=[get_act_name("pattern", hook_point_layer), hook_point], + names_filter=[hook_point_eval, hook_point], + **sparse_autoencoder.cfg.model_kwargs, ) has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"] @@ -62,7 +64,9 @@ def run_evals( l2_norm_in = torch.norm(original_act, dim=-1) l2_norm_out = torch.norm(sae_out, dim=-1) - l2_norm_ratio = l2_norm_out / l2_norm_in + l2_norm_in_for_div = l2_norm_in.clone() + l2_norm_in_for_div[torch.abs(l2_norm_in_for_div) < 0.0001] = 1 + l2_norm_ratio = l2_norm_out / l2_norm_in_for_div metrics = { # l2 norms @@ -86,7 +90,7 @@ def run_evals( def recons_loss_batched( sparse_autoencoder: SparseAutoencoder, - model: HookedTransformer, + model: HookedRootModule, activation_store: ActivationsStore, n_batches: int = 100, ): @@ -115,11 +119,13 @@ def recons_loss_batched( @torch.no_grad() def get_recons_loss( sparse_autoencoder: SparseAutoencoder, - model: HookedTransformer, + model: HookedRootModule, batch_tokens: torch.Tensor, ): hook_point = sparse_autoencoder.cfg.hook_point - loss = model(batch_tokens, return_type="loss") + loss = model( + batch_tokens, return_type="loss", **sparse_autoencoder.cfg.model_kwargs + ) head_index = sparse_autoencoder.cfg.hook_point_head_index def standard_replacement_hook(activations: torch.Tensor, hook: Any): @@ -157,13 +163,20 @@ def single_head_replacement_hook(activations: torch.Tensor, hook: Any): batch_tokens, return_type="loss", fwd_hooks=[(hook_point, partial(replacement_hook))], + **sparse_autoencoder.cfg.model_kwargs, ) 
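# --- Hedged illustrative sketch (editorial addition, not part of this patch) ---
# The hunks above and below guard divisions against near-zero denominators
# (l2_norm_ratio above, and the reconstruction score just below). The helper
# name `safe_divide` and its standalone form are assumptions for illustration
# only; the patch itself inlines the same logic with a 1e-4 threshold.
import torch


def safe_divide(
    numerator: torch.Tensor, denominator: torch.Tensor, eps: float = 1e-4
) -> torch.Tensor:
    # Replace denominator entries that are (almost) zero with 1.0 so the ratio
    # stays finite instead of producing inf/NaN on degenerate activations.
    denom = denominator.clone()
    denom[torch.abs(denom) < eps] = 1.0
    return numerator / denom


# e.g. score = safe_divide(zero_abl_loss - recons_loss, zero_abl_loss - loss)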
zero_abl_loss = model.run_with_hooks( - batch_tokens, return_type="loss", fwd_hooks=[(hook_point, zero_ablate_hook)] + batch_tokens, + return_type="loss", + fwd_hooks=[(hook_point, zero_ablate_hook)], + **sparse_autoencoder.cfg.model_kwargs, ) - score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss) + div_val = zero_abl_loss - loss + div_val[torch.abs(div_val) < 0.0001] = 1.0 + + score = (zero_abl_loss - recons_loss) / div_val return score, loss, recons_loss, zero_abl_loss diff --git a/sae_lens/training/load_model.py b/sae_lens/training/load_model.py new file mode 100644 index 00000000..6a41c9df --- /dev/null +++ b/sae_lens/training/load_model.py @@ -0,0 +1,26 @@ +from typing import Any, cast + +import torch +from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookedRootModule + + +def load_model( + model_class_name: str, model_name: str, device: str | torch.device | None = None +) -> HookedRootModule: + if model_class_name == "HookedTransformer": + return HookedTransformer.from_pretrained(model_name=model_name, device=device) + elif model_class_name == "HookedMamba": + try: + from mamba_lens import HookedMamba + except ImportError: + raise ValueError( + "mamba-lens must be installed to work with mamba models. This can be added with `pip install sae-lens[mamba]`" + ) + # HookedMamba has incorrect typing information, so we need to cast the type here + return cast( + HookedRootModule, + HookedMamba.from_pretrained(model_name, device=cast(Any, device)), + ) + else: + raise ValueError(f"Unknown model class: {model_class_name}") diff --git a/sae_lens/training/sae_group.py b/sae_lens/training/sae_group.py index fb31cd1e..b7ae75f9 100644 --- a/sae_lens/training/sae_group.py +++ b/sae_lens/training/sae_group.py @@ -129,6 +129,9 @@ def load_from_pretrained_legacy(cls, path: str) -> "SparseAutoencoderDictionary" # handle loading old autoencoders where before SAEGroup existed, where we just save a dict if isinstance(group, dict): cfg = group["cfg"] + # need to add this field to old configs + if not hasattr(cfg, "model_kwargs"): + cfg.model_kwargs = {} sparse_autoencoder = SparseAutoencoder(cfg=cfg) sparse_autoencoder.load_state_dict(group["state_dict"]) group = cls(cfg) @@ -194,10 +197,7 @@ def save_saes(self, path: str): autoencoder.save_model(f"{path}/{i}") def get_name(self): - - sae_name = ( - f"sae_group_{self.cfg.model_name}_{self.cfg.hook_point}_{self.cfg.d_sae}" - ) + sae_name = f"sae_group_{self.cfg.model_name.replace('/', '_')}_{self.cfg.hook_point}_{self.cfg.d_sae}" return sae_name def eval(self): diff --git a/sae_lens/training/session_loader.py b/sae_lens/training/session_loader.py index 2d632551..8bbfb67c 100644 --- a/sae_lens/training/session_loader.py +++ b/sae_lens/training/session_loader.py @@ -1,9 +1,10 @@ from typing import Tuple -from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookedRootModule from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.config import LanguageModelSAERunnerConfig +from sae_lens.training.load_model import load_model from sae_lens.training.sae_group import SparseAutoencoderDictionary @@ -20,7 +21,7 @@ def __init__(self, cfg: LanguageModelSAERunnerConfig): def load_sae_training_group_session( self, - ) -> Tuple[HookedTransformer, SparseAutoencoderDictionary, ActivationsStore]: + ) -> Tuple[HookedRootModule, SparseAutoencoderDictionary, ActivationsStore]: """ Loads a session for training a sparse autoencoder on a language model. 
""" @@ -39,7 +40,7 @@ def load_sae_training_group_session( @classmethod def load_pretrained_sae( cls, path: str, device: str = "cpu" - ) -> Tuple[HookedTransformer, SparseAutoencoderDictionary, ActivationsStore]: + ) -> Tuple[HookedRootModule, SparseAutoencoderDictionary, ActivationsStore]: """ Loads a session for analysing a pretrained sparse autoencoder. """ @@ -56,7 +57,7 @@ def load_pretrained_sae( return model, sparse_autoencoders, activations_loader - def get_model(self, model_name: str) -> HookedTransformer: + def get_model(self, model_name: str) -> HookedRootModule: """ Loads a model from transformer lens. @@ -65,9 +66,7 @@ def get_model(self, model_name: str) -> HookedTransformer: # Todo: add check that model_name is valid - model = HookedTransformer.from_pretrained( - model_name, - device=self.cfg.device, + model = load_model( + self.cfg.model_class_name, model_name, device=self.cfg.device ) - return model diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index 901db9eb..efcd3473 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -7,7 +7,7 @@ from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import LRScheduler from tqdm import tqdm -from transformer_lens import HookedTransformer +from transformer_lens.hook_points import HookedRootModule import wandb from sae_lens.training.activations_store import ActivationsStore @@ -53,7 +53,7 @@ class TrainSAEGroupOutput: def train_sae_on_language_model( - model: HookedTransformer, + model: HookedRootModule, sae_group: SparseAutoencoderDictionary, activation_store: ActivationsStore, batch_size: int = 1024, @@ -79,7 +79,7 @@ def train_sae_on_language_model( def train_sae_group_on_language_model( - model: HookedTransformer, + model: HookedRootModule, sae_group: SparseAutoencoderDictionary, activation_store: ActivationsStore, batch_size: int = 1024, diff --git a/tests/unit/training/test_load_model.py b/tests/unit/training/test_load_model.py new file mode 100644 index 00000000..4a42f3eb --- /dev/null +++ b/tests/unit/training/test_load_model.py @@ -0,0 +1,13 @@ +from mamba_lens import HookedMamba + +from sae_lens.training.load_model import load_model + + +def test_load_model_works_with_mamba(): + model = load_model( + model_class_name="HookedMamba", + model_name="state-spaces/mamba-370m", + device="cpu", + ) + assert model is not None + assert isinstance(model, HookedMamba) diff --git a/tutorials/mamba_train_example.py b/tutorials/mamba_train_example.py new file mode 100644 index 00000000..12866bd2 --- /dev/null +++ b/tutorials/mamba_train_example.py @@ -0,0 +1,60 @@ +# install from https://github.com/Phylliida/MambaLens +import os +import sys + +import torch + +sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +# run this as python3 tutorials/mamba_train_example.py +# i.e. 
from the root directory +from sae_lens.training.config import LanguageModelSAERunnerConfig + +cfg = LanguageModelSAERunnerConfig( + # Data Generating Function (Model + Training Distibuion) + model_name="state-spaces/mamba-370m", + model_class_name="HookedMamba", + hook_point="blocks.39.hook_ssm_input", + hook_point_layer=39, + hook_point_eval="blocks.39.hook_ssm_output", # we compare this when replace hook_point activations with autoencode.decode(autoencoder.encode( hook_point activations)) + d_in=2048, + dataset_path="NeelNanda/openwebtext-tokenized-9b", + is_dataset_tokenized=True, + # SAE Parameters + expansion_factor=64, + b_dec_init_method="geometric_median", + # Training Parameters + lr=0.0004, + l1_coefficient=0.00006 * 0.2, + lr_scheduler_name="cosineannealingwarmrestarts", + train_batch_size=4096, + context_size=128, + lr_warm_up_steps=5000, + # Activation Store Parameters + n_batches_in_buffer=128, + total_training_tokens=1_000_000 * 300, + store_batch_size=32, + # Dead Neurons and Sparsity + use_ghost_grads=True, + feature_sampling_window=1000, + dead_feature_window=5000, + dead_feature_threshold=1e-6, + # WANDB + log_to_wandb=True, + wandb_project="sae_training_mamba", + wandb_entity=None, + wandb_log_frequency=100, + # Misc + device="cuda", + seed=42, + checkpoint_path="checkpoints", + dtype=torch.float32, + model_kwargs={ + "fast_ssm": True, + "fast_conv": True, + }, +) + +from sae_lens.training.lm_runner import language_model_sae_runner + +language_model_sae_runner(cfg) From 8f88d15cd8aa91de6027f8053c04cf875efc65fb Mon Sep 17 00:00:00 2001 From: github-actions Date: Wed, 17 Apr 2024 16:45:43 +0000 Subject: [PATCH 16/30] 0.5.0 Automatically generated by python-semantic-release --- CHANGELOG.md | 65 ++++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 2 +- sae_lens/__init__.py | 2 +- 3 files changed, 67 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a5ff3ed..29501cd2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,71 @@ +## v0.5.0 (2024-04-17) + +### Feature + +* feat: Mamba support vs mamba-lens (#79) + +* mamba support + +* added init + +* added optional model kwargs + +* Support transformers and mamba + +* forgot one model kwargs + +* failed opts + +* tokens input + +* hack to fix tokens, will look into fixing mambalens + +* fixed checkpoint + +* added sae group + +* removed some comments and fixed merge error + +* removed unneeded params since that issue is fixed in mambalens now + +* Unneded input param + +* removed debug checkpoing and eval + +* added refs to hookedrootmodule + +* feed linter + +* added example and fixed loading + +* made layer for eval change + +* fix linter issues + +* adding mamba-lens as optional dep, and fixing typing/linting + +* adding a test for loading mamba model + +* adding mamba-lens to dev for CI + +* updating min mamba-lens version + +* updating mamba-lens version + +--------- + +Co-authored-by: David Chanin <chanindav@gmail.com> ([`eea7db4`](https://github.com/jbloomAus/SAELens/commit/eea7db4b99098c33cd862e7e2280a32b630826bd)) + +### Unknown + +* update readme ([`440df7b`](https://github.com/jbloomAus/SAELens/commit/440df7b6c0ef55ba3d116054f81e1ee4a58f9089)) + +* update readme ([`3694fd2`](https://github.com/jbloomAus/SAELens/commit/3694fd2c4cc7438121e4549636508c45835a5d38)) + + ## v0.4.0 (2024-04-16) ### Feature diff --git a/pyproject.toml b/pyproject.toml index 4f62f578..1475de8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sae-lens" -version = 
"0.4.0" +version = "0.5.0" description = "Training and Analyzing Sparse Autoencoders (SAEs)" authors = ["Joseph Bloom"] readme = "README.md" diff --git a/sae_lens/__init__.py b/sae_lens/__init__.py index 86d4e514..e13470b4 100644 --- a/sae_lens/__init__.py +++ b/sae_lens/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.0" +__version__ = "0.5.0" from .training.activations_store import ActivationsStore from .training.cache_activations_runner import cache_activations_runner From e78207d086cb3372dc805cbb4c87b694749cd905 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 17 Apr 2024 16:43:20 -0700 Subject: [PATCH 17/30] Add API for getting Neuronpedia feature --- sae_lens/analysis/neuronpedia_integration.py | 22 ++++++++++++++++++- .../analysis/test_neuronpedia_integration.py | 11 ++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 tests/unit/analysis/test_neuronpedia_integration.py diff --git a/sae_lens/analysis/neuronpedia_integration.py b/sae_lens/analysis/neuronpedia_integration.py index 84f628ad..37d80e95 100644 --- a/sae_lens/analysis/neuronpedia_integration.py +++ b/sae_lens/analysis/neuronpedia_integration.py @@ -1,6 +1,7 @@ import json import urllib.parse import webbrowser +import requests def get_neuronpedia_quick_list( @@ -14,10 +15,29 @@ def get_neuronpedia_quick_list( name = urllib.parse.quote(name) url = url + "?name=" + name list_feature = [ - {"modelId": model, "layer": f"{layer}-{dataset}", "index": str(feature)} + { + "modelId": model, + "layer": f"{layer}-{dataset}", + "index": str(feature), + } for feature in features ] url = url + "&features=" + urllib.parse.quote(json.dumps(list_feature)) webbrowser.open(url) return url + + +def get_neuronpedia_feature( + feature: int, + layer: int, + model: str = "gpt2-small", + dataset: str = "res-jb", +): + url = "https://neuronpedia.org/api/feature/" + url = url + f"{model}/{layer}-{dataset}/{feature}" + + result = requests.get(url).json() + result["index"] = int(result["index"]) + + return result diff --git a/tests/unit/analysis/test_neuronpedia_integration.py b/tests/unit/analysis/test_neuronpedia_integration.py new file mode 100644 index 00000000..6b6c9c28 --- /dev/null +++ b/tests/unit/analysis/test_neuronpedia_integration.py @@ -0,0 +1,11 @@ +from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_feature + + +def test_get_neuronpedia_feature(): + result = get_neuronpedia_feature( + feature=0, layer=0, model="gpt2-small", dataset="res-jb" + ) + + assert result["modelId"] == "gpt2-small" + assert result["layer"] == "0-res-jb" + assert result["index"] == 0 From 138d5d445878c0830c6c96a5fbe6b10a1d9644b0 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 17 Apr 2024 19:11:06 -0700 Subject: [PATCH 18/30] Use correct model name for np runner --- sae_lens/analysis/neuronpedia_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 2033a7aa..99228c4a 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -88,7 +88,7 @@ def __init__( self.model, _, self.activation_store = ( loader.load_sae_training_group_session() ) - self.model_id = self.model.cfg.model_name + self.model_id = self.sparse_autoencoder.cfg.model_name self.layer = self.sparse_autoencoder.cfg.hook_point_layer self.sae_id = sae_id self.sparsity_threshold = sparsity_threshold From 1a7d636c95ef508dde8bd100ab6d9f241b0be977 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 17 Apr 2024 
19:55:03 -0700 Subject: [PATCH 19/30] Use original repo for sae_vis --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c4ce144b..38d2e911 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ matplotlib-inline = "^0.1.6" datasets = "^2.17.1" babe = "^0.0.7" nltk = "^3.8.1" -sae-vis = { git = "https://github.com/hijohnnylin/sae_vis.git", branch = "allow_disable_buffer" } +sae-vis = { git = "https://github.com/callummcdougall/sae_vis.git", branch = "allow_disable_buffer" } mkdocs = "^1.5.3" mkdocs-material = "^9.5.15" mkdocs-autorefs = "^1.0.1" From 145a407f8d57755301bf56c87efd4e775c59b980 Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Wed, 17 Apr 2024 23:42:16 -0700 Subject: [PATCH 20/30] Fix resuming from batch --- tutorials/neuronpedia/neuronpedia.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tutorials/neuronpedia/neuronpedia.py b/tutorials/neuronpedia/neuronpedia.py index 380a39fc..6e3e9af9 100755 --- a/tutorials/neuronpedia/neuronpedia.py +++ b/tutorials/neuronpedia/neuronpedia.py @@ -151,7 +151,7 @@ def generate( outputs_dir.mkdir(parents=True, exist_ok=True) # Check if output_dir has any files starting with "batch_" batch_files = list(outputs_dir.glob("batch-*.json")) - if len(batch_files) > 0 and resume_from_batch != 1: + if len(batch_files) > 0 and resume_from_batch == 1: print( f"Error: Output directory {outputs_dir.as_posix()} has existing batch files. This is only allowed if you are resuming from a batch. Please delete or move the existing batch-*.json files." ) @@ -230,7 +230,7 @@ def generate( ) # iterate from 1 to num_batches - for i in range(1, num_batches + 1): + for i in range(resume_from_batch, num_batches + 1): command = [ "python", "make_batch.py", From 040676db6814c1f64171a32344e0bed40528c8f9 Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Thu, 18 Apr 2024 10:55:49 +0100 Subject: [PATCH 21/30] format --- sae_lens/analysis/dashboard_runner.py | 2 +- sae_lens/analysis/neuronpedia_runner.py | 115 ++++++------------ sae_lens/training/config.py | 1 - sae_lens/training/evals.py | 2 +- sae_lens/training/lm_runner.py | 1 + sae_lens/training/toy_model_runner.py | 2 +- .../training/train_sae_on_language_model.py | 2 +- sae_lens/training/train_sae_on_toy_model.py | 2 +- .../test_train_sae_on_language_model.py | 2 +- tutorials/neuronpedia/make_batch.py | 1 + tutorials/neuronpedia/neuronpedia.py | 32 ++--- 11 files changed, 60 insertions(+), 102 deletions(-) diff --git a/sae_lens/analysis/dashboard_runner.py b/sae_lens/analysis/dashboard_runner.py index d77be35d..150b0a52 100644 --- a/sae_lens/analysis/dashboard_runner.py +++ b/sae_lens/analysis/dashboard_runner.py @@ -11,6 +11,7 @@ import plotly import plotly.express as px import torch +import wandb from sae_vis.data_config_classes import ( ActsHistogramConfig, Column, @@ -24,7 +25,6 @@ from tqdm import tqdm from transformer_lens import HookedTransformer -import wandb from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 99228c4a..401941c4 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -4,25 +4,27 @@ # set TOKENIZERS_PARALLELISM to false to avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" import json + import numpy as np import torch from matplotlib import colors from sae_vis.data_config_classes import ( ActsHistogramConfig, Column, 
+ FeatureTablesConfig, LogitsHistogramConfig, LogitsTableConfig, - FeatureTablesConfig, SaeVisConfig, SaeVisLayoutConfig, SequencesConfig, ) -from tqdm import tqdm from sae_vis.data_storing_fns import SaeVisData +from tqdm import tqdm from transformer_lens import HookedTransformer + +from sae_lens.toolkit.pretrained_saes import load_sparsity from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader from sae_lens.training.sparse_autoencoder import SparseAutoencoder -from sae_lens.toolkit.pretrained_saes import load_sparsity OUT_OF_RANGE_TOKEN = "<|outofrange|>" @@ -85,9 +87,7 @@ def __init__( self.sae_path, device=self.device ) loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg) - self.model, _, self.activation_store = ( - loader.load_sae_training_group_session() - ) + self.model, _, self.activation_store = loader.load_sae_training_group_session() self.model_id = self.sparse_autoencoder.cfg.model_name self.layer = self.sparse_autoencoder.cfg.hook_point_layer self.sae_id = sae_id @@ -165,16 +165,12 @@ def run(self): sparsity = torch.log10(sparsity + 1e-10) sparsity = sparsity.to(self.device) self.target_feature_indexes = ( - (sparsity > self.sparsity_threshold) - .nonzero(as_tuple=True)[0] - .tolist() + (sparsity > self.sparsity_threshold).nonzero(as_tuple=True)[0].tolist() ) # divide into batches feature_idx = torch.tensor(self.target_feature_indexes) - n_subarrays = np.ceil( - len(feature_idx) / self.n_features_at_a_time - ).astype(int) + n_subarrays = np.ceil(len(feature_idx) / self.n_features_at_a_time).astype(int) feature_idx = np.array_split(feature_idx, n_subarrays) feature_idx = [x.tolist() for x in feature_idx] @@ -189,9 +185,7 @@ def run(self): exit() # write dead into file so we can create them as dead in Neuronpedia - skipped_indexes = set(range(self.n_features)) - set( - self.target_feature_indexes - ) + skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes) skipped_indexes_json = json.dumps( { "model_id": self.model_id, @@ -223,9 +217,7 @@ def run(self): for k, v in vocab_dict.items(): modified_key = k for anomaly in HTML_ANOMALIES: - modified_key = modified_key.replace( - anomaly, HTML_ANOMALIES[anomaly] - ) + modified_key = modified_key.replace(anomaly, HTML_ANOMALIES[anomaly]) new_vocab_dict[v] = modified_key vocab_dict = new_vocab_dict @@ -241,16 +233,11 @@ def run(self): if feature_batch_count < self.start_batch: # print(f"Skipping batch - it's after start_batch: {feature_batch_count}") continue - if ( - self.end_batch is not None - and feature_batch_count > self.end_batch - ): + if self.end_batch is not None and feature_batch_count > self.end_batch: # print(f"Skipping batch - it's after end_batch: {feature_batch_count}") continue - print( - f"========== Running Batch #{feature_batch_count} ==========" - ) + print(f"========== Running Batch #{feature_batch_count} ==========") layout = SaeVisLayoutConfig( columns=[ @@ -286,17 +273,13 @@ def run(self): ) features_outputs = [] - for _, feat_index in enumerate( - feature_data.feature_data_dict.keys() - ): + for _, feat_index in enumerate(feature_data.feature_data_dict.keys()): feature = feature_data.feature_data_dict[feat_index] feature_output = {} feature_output["featureIndex"] = feat_index - top10_logits = self.round_list( - feature.logits_table_data.top_logits - ) + top10_logits = self.round_list(feature.logits_table_data.top_logits) bottom10_logits = self.round_list( feature.logits_table_data.bottom_logits ) @@ -305,41 +288,29 @@ def run(self): 
feature_output["neuron_alignment_indices"] = ( feature.feature_tables_data.neuron_alignment_indices ) - feature_output["neuron_alignment_values"] = ( - self.round_list( - feature.feature_tables_data.neuron_alignment_values - ) + feature_output["neuron_alignment_values"] = self.round_list( + feature.feature_tables_data.neuron_alignment_values ) - feature_output["neuron_alignment_l1"] = ( - self.round_list( - feature.feature_tables_data.neuron_alignment_l1 - ) + feature_output["neuron_alignment_l1"] = self.round_list( + feature.feature_tables_data.neuron_alignment_l1 ) feature_output["correlated_neurons_indices"] = ( feature.feature_tables_data.correlated_neurons_indices ) - feature_output["correlated_neurons_l1"] = ( - self.round_list( - feature.feature_tables_data.correlated_neurons_cossim - ) + feature_output["correlated_neurons_l1"] = self.round_list( + feature.feature_tables_data.correlated_neurons_cossim ) - feature_output["correlated_neurons_pearson"] = ( - self.round_list( - feature.feature_tables_data.correlated_neurons_pearson - ) + feature_output["correlated_neurons_pearson"] = self.round_list( + feature.feature_tables_data.correlated_neurons_pearson ) feature_output["correlated_features_indices"] = ( feature.feature_tables_data.correlated_features_indices ) - feature_output["correlated_features_l1"] = ( - self.round_list( - feature.feature_tables_data.correlated_features_cossim - ) + feature_output["correlated_features_l1"] = self.round_list( + feature.feature_tables_data.correlated_features_cossim ) - feature_output["correlated_features_pearson"] = ( - self.round_list( - feature.feature_tables_data.correlated_features_pearson - ) + feature_output["correlated_features_pearson"] = self.round_list( + feature.feature_tables_data.correlated_features_pearson ) feature_output["neg_str"] = self.to_str_tokens_safe( @@ -353,9 +324,9 @@ def run(self): feature_output["frac_nonzero"] = ( float( - feature.acts_histogram_data.title.split(" = ")[ - 1 - ].split("%")[0] + feature.acts_histogram_data.title.split(" = ")[1].split( + "%" + )[0] ) / 100 if feature.acts_histogram_data.title is not None @@ -363,22 +334,18 @@ def run(self): ) freq_hist_data = feature.acts_histogram_data - freq_bar_values = self.round_list( - freq_hist_data.bar_values - ) - feature_output["freq_hist_data_bar_values"] = ( - freq_bar_values - ) - feature_output["freq_hist_data_bar_heights"] = ( - self.round_list(freq_hist_data.bar_heights) + freq_bar_values = self.round_list(freq_hist_data.bar_values) + feature_output["freq_hist_data_bar_values"] = freq_bar_values + feature_output["freq_hist_data_bar_heights"] = self.round_list( + freq_hist_data.bar_heights ) logits_hist_data = feature.logits_histogram_data - feature_output["logits_hist_data_bar_heights"] = ( - self.round_list(logits_hist_data.bar_heights) + feature_output["logits_hist_data_bar_heights"] = self.round_list( + logits_hist_data.bar_heights ) - feature_output["logits_hist_data_bar_values"] = ( - self.round_list(logits_hist_data.bar_values) + feature_output["logits_hist_data_bar_values"] = self.round_list( + logits_hist_data.bar_values ) feature_output["num_tokens_for_dashboard"] = ( @@ -432,12 +399,8 @@ def run(self): {"pos": posContribs, "neg": negContribs} ) activation["tokens"] = strs - activation["values"] = self.round_list( - sd.feat_acts - ) - activation["maxValue"] = max( - activation["values"] - ) + activation["values"] = self.round_list(sd.feat_acts) + activation["maxValue"] = max(activation["values"]) activation["lossValues"] = self.round_list( 
sd.loss_contribution ) diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 2a91b86f..c28de151 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -2,7 +2,6 @@ from typing import Any, Optional, cast import torch - import wandb DTYPE_MAP = { diff --git a/sae_lens/training/evals.py b/sae_lens/training/evals.py index 80aa91dd..9742512d 100644 --- a/sae_lens/training/evals.py +++ b/sae_lens/training/evals.py @@ -3,9 +3,9 @@ import pandas as pd import torch +import wandb from transformer_lens.hook_points import HookedRootModule -import wandb from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.sparse_autoencoder import SparseAutoencoder diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py index c6a4ce79..4d7ea440 100644 --- a/sae_lens/training/lm_runner.py +++ b/sae_lens/training/lm_runner.py @@ -1,6 +1,7 @@ from typing import Any, cast import wandb + from sae_lens.training.config import LanguageModelSAERunnerConfig from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader diff --git a/sae_lens/training/toy_model_runner.py b/sae_lens/training/toy_model_runner.py index 29b436a6..83044c3c 100644 --- a/sae_lens/training/toy_model_runner.py +++ b/sae_lens/training/toy_model_runner.py @@ -3,8 +3,8 @@ import einops import torch - import wandb + from sae_lens.training.sparse_autoencoder import SparseAutoencoder from sae_lens.training.toy_models import Config as ToyConfig from sae_lens.training.toy_models import Model as ToyModel diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index efcd3473..9cf9e9a6 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -3,13 +3,13 @@ from typing import Any, cast import torch +import wandb from safetensors.torch import save_file from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import LRScheduler from tqdm import tqdm from transformer_lens.hook_points import HookedRootModule -import wandb from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.evals import run_evals from sae_lens.training.geometric_median import compute_geometric_median diff --git a/sae_lens/training/train_sae_on_toy_model.py b/sae_lens/training/train_sae_on_toy_model.py index e5227da9..61236b7a 100644 --- a/sae_lens/training/train_sae_on_toy_model.py +++ b/sae_lens/training/train_sae_on_toy_model.py @@ -1,10 +1,10 @@ from typing import Any, cast import torch +import wandb from torch.utils.data import DataLoader from tqdm import tqdm -import wandb from sae_lens.training.sparse_autoencoder import SparseAutoencoder diff --git a/tests/unit/training/test_train_sae_on_language_model.py b/tests/unit/training/test_train_sae_on_language_model.py index eae5b35d..e0c8b9e6 100644 --- a/tests/unit/training/test_train_sae_on_language_model.py +++ b/tests/unit/training/test_train_sae_on_language_model.py @@ -5,11 +5,11 @@ import pytest import torch +import wandb from datasets import Dataset from torch import Tensor from transformer_lens import HookedTransformer -import wandb from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.optim import get_scheduler from sae_lens.training.sae_group import SparseAutoencoderDictionary diff --git a/tutorials/neuronpedia/make_batch.py b/tutorials/neuronpedia/make_batch.py index b5f28e82..945a2939 100644 --- 
a/tutorials/neuronpedia/make_batch.py +++ b/tutorials/neuronpedia/make_batch.py @@ -1,4 +1,5 @@ import sys + from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner # we use another python script to launch this using subprocess to work around OOM issues - this ensures every batch gets the whole system available memory diff --git a/tutorials/neuronpedia/neuronpedia.py b/tutorials/neuronpedia/neuronpedia.py index 6e3e9af9..aa36a6b6 100755 --- a/tutorials/neuronpedia/neuronpedia.py +++ b/tutorials/neuronpedia/neuronpedia.py @@ -2,21 +2,23 @@ # better fix is to investigate and fix the memory issues import json -import os -import requests -import typer -import torch import math +import os import subprocess -from typing import Any from decimal import Decimal from pathlib import Path -from typing_extensions import Annotated +from typing import Any + +import requests +import torch +import typer from rich import print from rich.align import Align from rich.panel import Panel -from sae_lens.training.sparse_autoencoder import SparseAutoencoder +from typing_extensions import Annotated + from sae_lens.toolkit.pretrained_saes import load_sparsity +from sae_lens.training.sparse_autoencoder import SparseAutoencoder OUTPUT_DIR_BASE = Path("../../neuronpedia_outputs") @@ -120,9 +122,7 @@ def generate( print("Error: cfg.json file not found in SAE directory.") raise typer.Abort() if sae_path.joinpath("sae_weights.safetensors").is_file() is not True: - print( - "Error: sae_weights.safetensors file not found in SAE directory." - ) + print("Error: sae_weights.safetensors file not found in SAE directory.") raise typer.Abort() if sae_path.joinpath("sparsity.safetensors").is_file() is not True: print("Error: sparsity.safetensors file not found in SAE directory.") @@ -144,9 +144,7 @@ def generate( outputs_subdir = f"{model_id}_{sae_id}_{sparse_autoencoder.cfg.hook_point}" outputs_dir = OUTPUT_DIR_BASE.joinpath(outputs_subdir) if outputs_dir.exists() and outputs_dir.is_file(): - print( - f"Error: Output directory {outputs_dir.as_posix()} exists and is a file." 
- ) + print(f"Error: Output directory {outputs_dir.as_posix()} exists and is a file.") raise typer.Abort() outputs_dir.mkdir(parents=True, exist_ok=True) # Check if output_dir has any files starting with "batch_" @@ -163,9 +161,7 @@ def generate( if len(sparsity) > 0 and sparsity[0] >= 0: sparsity = torch.log10(sparsity + 1e-10) sparsity = sparsity.to(device) - alive_indexes = ( - (sparsity > log_sparsity).nonzero(as_tuple=True)[0].tolist() - ) + alive_indexes = (sparsity > log_sparsity).nonzero(as_tuple=True)[0].tolist() num_alive = len(alive_indexes) num_dead = sparse_autoencoder.d_sae - num_alive @@ -394,9 +390,7 @@ def nanToNeg999(obj: Any) -> Any: return {k: nanToNeg999(v) for k, v in obj.items()} elif isinstance(obj, list): return [nanToNeg999(v) for v in obj] - elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan( - obj - ): + elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan(obj): return -999 return obj From 11a71e1b95576ef6dc3dbec7eb1c76ce7ca44dfd Mon Sep 17 00:00:00 2001 From: Joseph Bloom Date: Tue, 16 Apr 2024 20:18:32 +0000 Subject: [PATCH 22/30] get decoder fine tuning working --- sae_lens/training/activations_store.py | 2 +- sae_lens/training/cache_activations_runner.py | 4 +- sae_lens/training/config.py | 30 +- sae_lens/training/sae_group.py | 4 + sae_lens/training/sparse_autoencoder.py | 16 +- .../training/train_sae_on_language_model.py | 44 +- scripts/run.ipynb | 1011 +++++++---------- .../test_language_model_sae_runner.py | 2 +- tests/unit/helpers.py | 2 +- .../test_train_sae_on_language_model.py | 5 +- tutorials/training_a_sparse_autoencoder.ipynb | 2 +- 11 files changed, 488 insertions(+), 634 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 604f7ba4..f0efa3ff 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -59,7 +59,7 @@ def from_config( context_size=cfg.context_size, d_in=cfg.d_in, n_batches_in_buffer=cfg.n_batches_in_buffer, - total_training_tokens=cfg.total_training_tokens, + total_training_tokens=cfg.training_tokens, store_batch_size=cfg.store_batch_size, train_batch_size=cfg.train_batch_size, prepend_bos=cfg.prepend_bos, diff --git a/sae_lens/training/cache_activations_runner.py b/sae_lens/training/cache_activations_runner.py index 443b399e..db9b470f 100644 --- a/sae_lens/training/cache_activations_runner.py +++ b/sae_lens/training/cache_activations_runner.py @@ -31,11 +31,11 @@ def cache_activations_runner(cfg: CacheActivationsRunnerConfig): else: os.makedirs(activations_store.cached_activations_path) - print(f"Started caching {cfg.total_training_tokens} activations") + print(f"Started caching {cfg.training_tokens} activations") tokens_per_buffer = ( cfg.store_batch_size * cfg.context_size * cfg.n_batches_in_buffer ) - n_buffers = math.ceil(cfg.total_training_tokens / tokens_per_buffer) + n_buffers = math.ceil(cfg.training_tokens / tokens_per_buffer) # for i in tqdm(range(n_buffers), desc="Caching activations"): for i in range(n_buffers): buffer = activations_store.get_buffer(cfg.n_batches_in_buffer) diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 2a91b86f..22857e90 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -45,7 +45,8 @@ class LanguageModelSAERunnerConfig: # Activation Store Parameters n_batches_in_buffer: int = 20 - total_training_tokens: int = 2_000_000 + training_tokens: int = 2_000_000 + finetuning_tokens: int = 0 store_batch_size: int = 
32 train_batch_size: int = 4096 @@ -56,11 +57,20 @@ class LanguageModelSAERunnerConfig: prepend_bos: bool = True # Training Parameters + + ## Batch size + train_batch_size: int = 4096 + + ## Adam adam_beta1: float | list[float] = 0 adam_beta2: float | list[float] = 0.999 + + ## Loss Function mse_loss_normalization: Optional[str] = None l1_coefficient: float | list[float] = 1e-3 lp_norm: float | list[float] = 1 + + ## Learning Rate Schedule lr: float | list[float] = 3e-4 lr_scheduler_name: str | list[str] = ( "constant" # constant, cosineannealing, cosineannealingwarmrestarts @@ -71,7 +81,9 @@ class LanguageModelSAERunnerConfig: ) lr_decay_steps: int | list[int] = 0 n_restart_cycles: int | list[int] = 1 # used only for cosineannealingwarmrestarts - train_batch_size: int = 4096 + + ## FineTuning + finetuning_method: Optional[str] = None # scale, decoder or unrotated_decoder # Resampling protocol args use_ghost_grads: bool | list[bool] = ( @@ -111,7 +123,7 @@ def __post_init__(self): ) if self.run_name is None: - self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" + self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}" if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]: raise ValueError( @@ -129,6 +141,12 @@ def __post_init__(self): elif isinstance(self.dtype, str): self.dtype: torch.dtype = DTYPE_MAP[self.dtype] + # if we use decoder fine tuning, we can't be applying b_dec to the input + if (self.finetuning_method == "decoder") and (self.apply_b_dec_to_input): + raise ValueError( + "If we are fine tuning the decoder, we can't be applying b_dec to the input.\nSet apply_b_dec_to_input to False." + ) + self.device: str | torch.device = torch.device(self.device) if self.lr_end is None: @@ -144,7 +162,7 @@ def __post_init__(self): if self.verbose: print( - f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" + f"Run name: {self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.training_tokens:3.3e}" ) # Print out some useful info: n_tokens_per_buffer = ( @@ -156,7 +174,7 @@ def __post_init__(self): f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}" ) - total_training_steps = self.total_training_tokens // self.train_batch_size + total_training_steps = self.training_tokens // self.train_batch_size print(f"Total training steps: {total_training_steps}") total_wandb_updates = total_training_steps // self.wandb_log_frequency @@ -209,7 +227,7 @@ class CacheActivationsRunnerConfig: # Activation Store Parameters n_batches_in_buffer: int = 20 - total_training_tokens: int = 2_000_000 + training_tokens: int = 2_000_000 store_batch_size: int = 32 train_batch_size: int = 4096 diff --git a/sae_lens/training/sae_group.py b/sae_lens/training/sae_group.py index b7ae75f9..15b651ba 100644 --- a/sae_lens/training/sae_group.py +++ b/sae_lens/training/sae_group.py @@ -133,6 +133,10 @@ def load_from_pretrained_legacy(cls, path: str) -> "SparseAutoencoderDictionary" if not hasattr(cfg, "model_kwargs"): cfg.model_kwargs = {} sparse_autoencoder = SparseAutoencoder(cfg=cfg) + # add dummy scaling factor to the state dict + group["state_dict"]["scaling_factor"] = torch.ones( + cfg.d_sae, dtype=cfg.dtype, device=cfg.device + ) sparse_autoencoder.load_state_dict(group["state_dict"]) group = cls(cfg) for key in group.autoencoders: diff --git a/sae_lens/training/sparse_autoencoder.py 
b/sae_lens/training/sparse_autoencoder.py index 33d89d2a..9272a4d9 100644 --- a/sae_lens/training/sparse_autoencoder.py +++ b/sae_lens/training/sparse_autoencoder.py @@ -98,6 +98,11 @@ def __init__( torch.zeros(self.d_in, dtype=self.dtype, device=self.device) ) + # scaling factor for fine-tuning (not to be used in initial training) + self.scaling_factor = nn.Parameter( + torch.ones(self.d_sae, dtype=self.dtype, device=self.device) + ) + self.hook_sae_in = HookPoint() self.hook_hidden_pre = HookPoint() self.hook_hidden_post = HookPoint() @@ -124,7 +129,8 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None) sae_out = self.hook_sae_out( einops.einsum( - feature_acts, + feature_acts + * self.scaling_factor, # need to make sure this handled when loading old models. self.W_dec, "... d_sae, d_sae d_in -> ... d_in", ) @@ -330,6 +336,14 @@ def load_from_pretrained(cls, path: str, device: str = "cpu"): with safe_open(weight_path, framework="pt", device=device) as f: # type: ignore for k in f.keys(): tensors[k] = f.get_tensor(k) + + # old saves may not have scaling factors. + if "scaling_factor" not in tensors: + assert isinstance(config.d_sae, int) + tensors["scaling_factor"] = torch.ones( + config.d_sae, dtype=config.dtype, device=config.device + ) + sae.load_state_dict(tensors) return sae diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index efcd3473..59285289 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -17,6 +17,13 @@ from sae_lens.training.sae_group import SparseAutoencoderDictionary from sae_lens.training.sparse_autoencoder import SparseAutoencoder +# used to map between parameters which are updated during finetuning and the config str. 
+FINETUNING_PARAMETERS = { + "scale": ["scaling_factor"], + "decoder": ["scaling_factor", "W_dec", "b_dec"], + "unrotated_decoder": ["scaling_factor", "b_dec"], +} + def _log_feature_sparsity( feature_sparsity: torch.Tensor, eps: float = 1e-10 @@ -35,6 +42,7 @@ class SAETrainContext: n_frac_active_tokens: int optimizer: Optimizer scheduler: LRScheduler + finetuning: bool = False @property def feature_sparsity(self) -> torch.Tensor: @@ -44,6 +52,21 @@ def feature_sparsity(self) -> torch.Tensor: def log_feature_sparsity(self) -> torch.Tensor: return _log_feature_sparsity(self.feature_sparsity) + def begin_finetuning(self, sae: SparseAutoencoder): + + # finetuning method should be set in the config + # if not, then we don't finetune + if not isinstance(sae.cfg.finetuning_method, str): + return + + for name, param in sae.named_parameters(): + if name in FINETUNING_PARAMETERS[sae.cfg.finetuning_method]: + param.requires_grad = True + else: + param.requires_grad = False + + self.finetuning = True + @dataclass class TrainSAEGroupOutput: @@ -88,10 +111,13 @@ def train_sae_group_on_language_model( use_wandb: bool = False, wandb_log_frequency: int = 50, ) -> TrainSAEGroupOutput: - total_training_tokens = sae_group.cfg.total_training_tokens + total_training_tokens = ( + sae_group.cfg.training_tokens + sae_group.cfg.finetuning_tokens + ) total_training_steps = total_training_tokens // batch_size n_training_steps = 0 n_training_tokens = 0 + started_fine_tuning = False checkpoint_thresholds = [] if n_checkpoints > 0: @@ -180,6 +206,16 @@ def train_sae_group_on_language_model( ) pbar.update(batch_size) + ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already) + if (not started_fine_tuning) and ( + n_training_tokens > sae_group.cfg.training_tokens + ): + started_fine_tuning = True + for name, sparse_autoencoder in sae_group.autoencoders.items(): + ctx = train_contexts[name] + # this should turn grads on for the scaling factor and other parameters. + ctx.begin_finetuning(sae_group.autoencoders[name]) + # save final sae group to checkpoints folder final_checkpoint = _save_checkpoint( sae_group, @@ -248,6 +284,12 @@ def _build_train_context( ) n_frac_active_tokens = 0 + # we don't train the scaling factor (initially) + # set requires grad to false for the scaling factor + for name, param in sae.named_parameters(): + if "scaling_factor" in name: + param.requires_grad = False + optimizer = Adam( sae.parameters(), lr=sae.cfg.lr, diff --git a/scripts/run.ipynb b/scripts/run.ipynb index 8e6cc3a4..824b431c 100644 --- a/scripts/run.ipynb +++ b/scripts/run.ipynb @@ -24,14 +24,14 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Using device: mps\n" + "Using device: cuda\n" ] } ], @@ -60,273 +60,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Gelu-2L\n", - "\n", - "An example of a toy language model we're able to train on." 
- ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### MLP Out" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gelu-2l\",\n", - " hook_point=\"blocks.0.hook_mlp_out\",\n", - " hook_point_layer=0,\n", - " d_in=512,\n", - " dataset_path=\"NeelNanda/c4-tokenized-2b\",\n", - " is_dataset_tokenized=True,\n", - " # SAE Parameters\n", - " expansion_factor=[16, 32, 64],\n", - " b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n", - " # Training Parameters\n", - " lr=0.0012,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " l1_coefficient=0.00016,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 100,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " use_ghost_grads=True,\n", - " feature_sampling_window=5000,\n", - " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-4,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_models_gelu_2l_test\",\n", - " wandb_log_frequency=10,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## GPT2 - Small" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Residual Stream" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "layer = 3\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gpt2-small\",\n", - " hook_point=f\"blocks.{layer}.hook_resid_pre\",\n", - " hook_point_layer=layer,\n", - " d_in=768,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=32, # determines the dimension of the SAE.\n", - " b_dec_init_method=\"mean\", # geometric median is better but slower to get started\n", - " # Training Parameters\n", - " lr=0.0004,\n", - " l1_coefficient=0.00008,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " lr_warm_up_steps=5000,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 300, # 200M tokens seems doable overnight.\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " use_ghost_grads=True,\n", - " feature_sampling_window=2500,\n", - " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-8,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_models_resid_pre_test\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=100,\n", - " # Misc\n", - " device=\"cuda\",\n", - " seed=42,\n", - " n_checkpoints=10,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - 
"cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pythia 70-M" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "import cProfile\n", - "\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"pythia-70m-deduped\",\n", - " hook_point=\"blocks.0.hook_mlp_out\",\n", - " hook_point_layer=0,\n", - " d_in=512,\n", - " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=64,\n", - " # Training Parameters\n", - " lr=3e-4,\n", - " l1_coefficient=4e-5,\n", - " train_batch_size=8192,\n", - " context_size=128,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " lr_warm_up_steps=10_000,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=64,\n", - " total_training_tokens=1_000_000 * 800,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " feature_sampling_window=2000, # Doesn't currently matter.\n", - " dead_feature_window=40000,\n", - " dead_feature_threshold=1e-8,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=20,\n", - " # Misc\n", - " device=\"cuda\",\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Pythia 70M Hook Q" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"../\")\n", - "\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"pythia-70m-deduped\",\n", - " hook_point=\"blocks.2.attn.hook_q\",\n", - " hook_point_layer=2,\n", - " hook_point_head_index=7,\n", - " d_in=64,\n", - " dataset_path=\"EleutherAI/the_pile_deduplicated\",\n", - " is_dataset_tokenized=False,\n", - " # SAE Parameters\n", - " expansion_factor=16,\n", - " # Training Parameters\n", - " lr=0.0012,\n", - " l1_coefficient=0.003,\n", - " lr_scheduler_name=\"constantwithwarmup\",\n", - " lr_warm_up_steps=1000, # about 4 million tokens.\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 1500,\n", - " store_batch_size=32,\n", - " # Resampling protocol\n", - " feature_sampling_method=\"anthropic\",\n", - " feature_sampling_window=1000, # doesn't do anything currently.\n", - " feature_reinit_scale=0.2,\n", - " resample_batches=8,\n", - " dead_feature_window=60000,\n", - " dead_feature_threshold=1e-5,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " 
wandb_project=\"mats_sae_training_pythia_70M_hook_q_L2H7\",\n", - " wandb_entity=None,\n", - " wandb_log_frequency=100,\n", - " # Misc\n", - " device=\"mps\",\n", - " seed=42,\n", - " n_checkpoints=15,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Tiny Stories" + "# Tiny Stories - 1L" ] }, { @@ -338,49 +72,222 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.002048\n", + "Total training steps: 6103\n", + "Total wandb updates: 610\n", + "n_tokens_per_feature_sampling_window (millions): 524.288\n", + "n_tokens_per_dead_feature_window (millions): 524.288\n", + "We will reset the sparsity calculation 6 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n", + "Loaded pretrained model tiny-stories-1L-21M into HookedTransformer\n", + "Moving model to device: cuda\n", + "Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.002048\n", + "Total training steps: 6103\n", + "Total wandb updates: 610\n", + "n_tokens_per_feature_sampling_window (millions): 524.288\n", + "n_tokens_per_dead_feature_window (millions): 524.288\n", + "We will reset the sparsity calculation 6 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n", + "Run name: 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07\n", + "n_tokens_per_buffer (millions): 0.262144\n", + "Lower bound: n_contexts_per_buffer (millions): 0.002048\n", + "Total training steps: 6103\n", + "Total wandb updates: 610\n", + "n_tokens_per_feature_sampling_window (millions): 524.288\n", + "n_tokens_per_dead_feature_window (millions): 524.288\n", + "We will reset the sparsity calculation 6 times.\n", + "Number tokens in sparsity calculation window: 4.10e+06\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.6" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /home/paperspace/mats_sae_training/scripts/wandb/run-20240416_135218-opqs9dgl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/jbloom/sae_lens_tutorial" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/jbloom/sae_lens_tutorial/runs/opqs9dgl" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Objective value: 1781464.6250: 4%|▍ | 4/100 [00:00<00:00, 206.25it/s]\n", + "/home/paperspace/mats_sae_training/sae_lens/training/sparse_autoencoder.py:176: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", + "135| MSE Loss 0.257 | L1 1.354: 1%| | 552960/50000000 [00:13<19:08, 43042.90it/s] /home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (v0kr8hz9) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n", + "6104| MSE Loss 0.072 | L1 0.024: : 25001984it [18:07, 22981.57it/s]\n", + "12208| MSE Loss 0.070 | L1 0.024: 100%|█████████▉| 49999872/50000000 [20:15<00:00, 30551.50it/s]" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bb94759b99e14133aece0058a423e305", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='128.448 MB of 128.448 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_learning_rate▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████████
details/n_training_tokens▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▅▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▄▄▅▆▆▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation▅▂▅▅█▆▆▇▅▅▆▅▄▅▅▄▅▃▄▄▄▄▃▆▆▄▁▄▆▃▆▃▅▆▂▃▆▄▃▅
metrics/ce_loss_with_sae█▅▅▄▃▃▃▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/ce_loss_without_sae▆▂▂▁▇▄▇▆▃▃▅▅▆▇▃▆▃▄▄▄█▂▂▄▄▃▁▅▄▄▂▅▄█▃▄▄▄▅▆
metrics/explained_variance▁▄▄▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/explained_variance_std▆▁▄▇███▇▇▆▆▆▆▆▆▆▆▅▅▅▅▅▄▅▄▄▅▄▄▄▄▄▄▄▄▄▄▄▄▄
metrics/l0█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm█▁▂▃▄▄▃▅▅▄▅▅▄▅▄▅▅▅▅▄▆▅▆▅▆▆▅▆▆▇█▇▇▆▆▇▆▆▇▇
metrics/l2_ratio█▁▂▃▃▃▃▄▄▃▄▄▄▄▄▄▄▄▄▄▅▅▅▅▆▆▅▆▆▆▆▆▆▅▆▆▅▆▆▆
metrics/mean_log10_feature_sparsity█▅▃▃▂▁▁▁▁▁▁▁
sparsity/below_1e-5▁▁▁▁▂▆███▅██
sparsity/below_1e-6▁▁▁▁▁▁█▇█▂▆▆
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██▅▅▁▁▁▁▅▁▁▁▁▁▅▅▁▁
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▂▂▂▃▂▄▄▅▇▇▇▇██▆▇▆███▆▅▅█▇▇▇▇██

Run summary:


details/current_learning_rate: 0.0008
details/n_training_tokens: 49971200
losses/ghost_grad_loss: 0.0
losses/l1_loss: 15.59199
losses/mse_loss: 0.07019
losses/overall_loss: 0.09358
metrics/CE_loss_score: 0.86351
metrics/ce_loss_with_ablation: 8.5168
metrics/ce_loss_with_sae: 3.00156
metrics/ce_loss_without_sae: 2.12988
metrics/explained_variance: 0.56934
metrics/explained_variance_std: 0.14386
metrics/l0: 19.32129
metrics/l2_norm: 15.93428
metrics/l2_ratio: 0.86545
metrics/mean_log10_feature_sparsity: -4.81775
sparsity/below_1e-5: 6329
sparsity/below_1e-6: 81
sparsity/dead_features: 0
sparsity/mean_passes_since_fired: 29.02307

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run 16384-L1-0.0015-LR-0.0008-Tokens-2.500e+07 at: https://wandb.ai/jbloom/sae_lens_tutorial/runs/opqs9dgl
View project at: https://wandb.ai/jbloom/sae_lens_tutorial
Synced 7 W&B file(s), 0 media file(s), 3 artifact file(s) and 1 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: ./wandb/run-20240416_135218-opqs9dgl/logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "12208| MSE Loss 0.070 | L1 0.024: : 50003968it [20:27, 30551.50it/s] /home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (opqs9dgl) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n" + ] + } + ], "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "if device == \"cpu\" and torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", "cfg = LanguageModelSAERunnerConfig(\n", " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"tiny-stories-1M\",\n", - " hook_point=\"blocks.1.mlp.hook_post\",\n", - " hook_point_layer=1,\n", - " d_in=256,\n", - " # dataset_path=\"roneneldan/TinyStories\",\n", - " # is_dataset_tokenized=False,\n", - " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", + " model_name=\"tiny-stories-1L-21M\", # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)\n", + " hook_point=\"blocks.0.hook_mlp_out\", # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)\n", + " hook_point_layer=0, # Only one layer in the model.\n", + " d_in=1024, # the width of the mlp output.\n", + " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", " is_dataset_tokenized=True,\n", + " \n", " # SAE Parameters\n", - " expansion_factor=16,\n", + " mse_loss_normalization=None, # We won't normalize the mse loss,\n", + " expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training.\n", + " b_dec_init_method=\"geometric_median\", # The geometric median can be used to initialize the decoder weights.\n", + " apply_b_dec_to_input=False, # We won't apply the decoder to the input.\n", + " \n", " # Training Parameters\n", - " lr=1e-4,\n", - " lp_norm=1.0,\n", - " l1_coefficient=2e-4,\n", + " lr=0.0008, # lower the better, we'll go fairly high to speed up the tutorial.\n", + " lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", + " lr_warm_up_steps=10000, # this can help avoid too many dead features initially.\n", + " l1_coefficient=0.0015, # will control how sparse the feature activations are\n", + " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", " train_batch_size=4096,\n", - " context_size=128,\n", + " context_size=128, # will control the lenght of the prompts we feed to the model. 
Larger is better but slower.\n", + " \n", " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 20,\n", + " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", + " training_tokens=1_000_000 * 25, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", + " finetuning_method=\"decoder\",\n", + " finetuning_tokens=1_000_000 * 25,\n", " store_batch_size=32,\n", - " feature_sampling_window=500, # So we see the histograms.\n", - " dead_feature_window=250,\n", + " \n", + " \n", + " # Resampling protocol\n", + " use_ghost_grads=False,\n", + " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", + " dead_feature_window=1000, # would effect resampling or ghost grads if we were using it.\n", + " dead_feature_threshold=1e-4, # would effect resampling or ghost grads if we were using it.\n", " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", + " log_to_wandb=True, # always use wandb unless you are just testing code.\n", + " wandb_project=\"sae_lens_tutorial\",\n", " wandb_log_frequency=10,\n", " # Misc\n", " device=device,\n", @@ -390,82 +297,145 @@ " dtype=torch.float32,\n", ")\n", "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" + "# look at the next cell to see some instruction for what to do while this is running.\n", + "sparse_autoencoder_dictionary = language_model_sae_runner(cfg)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# GPT2 - Small" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Hook Z\n", - "\n" + "### Residual Stream" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Run name: 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07\n", - "n_tokens_per_buffer (millions): 0.524288\n", + "Run name: 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08\n", + "n_tokens_per_buffer (millions): 1.048576\n", "Lower bound: n_contexts_per_buffer (millions): 0.004096\n", - "Total training steps: 4882\n", + "Total training steps: 48828\n", "Total wandb updates: 488\n", - "n_tokens_per_feature_sampling_window (millions): 262.144\n", - "n_tokens_per_dead_feature_window (millions): 131.072\n", - "We will reset the sparsity calculation 9 times.\n", - "Number tokens in sparsity calculation window: 2.05e+06\n", - "Loaded pretrained model tiny-stories-1M into HookedTransformer\n", - "Moving model to device: mps\n" + "n_tokens_per_feature_sampling_window (millions): 2621.44\n", + "n_tokens_per_dead_feature_window (millions): 5242.88\n", + "We will reset the sparsity calculation 19 times.\n", + "Number tokens in sparsity calculation window: 1.02e+07\n", + "Loaded pretrained model gpt2-small into HookedTransformer\n", + "Moving model to device: cuda\n" ] }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fee8922d83f04003a2f1441eeb30200d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Resolving data files: 0%| | 0/73 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": 
"ea686292dff7449a9846fcfa29d6ff74", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='0.064 MB of 0.064 MB uploaded\\r'), FloatProgress(value=1.0, max=1.0)))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


details/current_learning_rate▁▁▂▂▃▄▄▅▅▆▆▇▇███████████████████████████
details/n_training_tokens▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▅▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▆▇▇▇█████████
metrics/ce_loss_with_ablation▄▂▁▄▄▂▄▄▁▄▁▄▂█
metrics/ce_loss_with_sae█▄▂▂▂▂▁▂▁▁▁▁▁▁
metrics/ce_loss_without_sae▂▇▄█▅▃▄▇▆▅▄▆▁▁
metrics/explained_variance▁▃▄▆▆▇▇▇▇▇▇▇████████████████████████████
metrics/explained_variance_std█▄█▇▅▄▄▄▃▃▂▂▂▂▂▂▂▂▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0█▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▅▆▇▇▇█▇██████
metrics/l2_ratio▁▅▆▇▇▇▇▇██████
metrics/mean_log10_feature_sparsity█▃▂▁▁
sparsity/below_1e-5▁▁▅██
sparsity/below_1e-6▁▁▁▄█
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃▄▅▆█
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▅▅▅▆▆▇▇█

Run summary:


details/current_learning_rate: 0.0004
details/n_training_tokens: 59801600
losses/ghost_grad_loss: 0.0
losses/l1_loss: 160.66861
losses/mse_loss: 1.68098
losses/overall_loss: 2.96633
metrics/CE_loss_score: 0.96258
metrics/ce_loss_with_ablation: 11.49633
metrics/ce_loss_with_sae: 3.62324
metrics/ce_loss_without_sae: 3.3166
metrics/explained_variance: 0.78709
metrics/explained_variance_std: 0.05978
metrics/l0: 50.03076
metrics/l2_norm: 102.32782
metrics/l2_ratio: 0.8864
metrics/mean_log10_feature_sparsity: -5.31744
sparsity/below_1e-5: 19194
sparsity/below_1e-6: 11736
sparsity/dead_features: 60
sparsity/mean_passes_since_fired: 640.44727

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { "data": { "text/html": [ - "wandb version 0.16.5 is available! To upgrade, please run:\n", - " $ pip install wandb --upgrade" + " View run 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08 at: https://wandb.ai/jbloom/gpt2_small_experiments_april/runs/pq5q3x9s
View project at: https://wandb.ai/jbloom/gpt2_small_experiments_april
Synced 7 W&B file(s), 0 media file(s), 0 artifact file(s) and 1 other file(s)" ], "text/plain": [ "" @@ -477,7 +447,7 @@ { "data": { "text/html": [ - "Tracking run with wandb version 0.16.3" + "Find logs at: ./wandb/run-20240416_155117-pq5q3x9s/logs" ], "text/plain": [ "" @@ -489,7 +459,7 @@ { "data": { "text/html": [ - "Run data is saved locally in /Users/josephbloom/GithubRepositories/mats_sae_training/scripts/wandb/run-20240326_191703-ec6k6v87" + "Successfully finished last run (ID:pq5q3x9s). Initializing new run:
" ], "text/plain": [ "" @@ -498,10 +468,24 @@ "metadata": {}, "output_type": "display_data" }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "fd02fd0295cc4afda9bb0e1367c87f84", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='Waiting for wandb.init()...\\r'), FloatProgress(value=0.011112805799995032, max=1.0…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, { "data": { "text/html": [ - "Syncing run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 to Weights & Biases (docs)
" + "Tracking run with wandb version 0.16.6" ], "text/plain": [ "" @@ -513,7 +497,7 @@ { "data": { "text/html": [ - " View project at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests" + "Run data is saved locally in /home/paperspace/mats_sae_training/scripts/wandb/run-20240416_165827-vbwoyzi8" ], "text/plain": [ "" @@ -525,7 +509,7 @@ { "data": { "text/html": [ - " View run at https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87" + "Syncing run 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08 to Weights & Biases (docs)
" ], "text/plain": [ "" @@ -535,91 +519,56 @@ "output_type": "display_data" }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "Objective value: 116883.7422: 10%|█ | 10/100 [00:00<00:00, 128.72it/s]\n", - "/Users/josephbloom/GithubRepositories/mats_sae_training/sae_training/sparse_autoencoder.py:161: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", - " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", - "100%|██████████| 10/10 [00:02<00:00, 4.93it/s] 405504/20000000 [00:14<08:53, 36739.57it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 811008/20000000 [00:31<18:45, 17042.47it/s] \n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 1224704/20000000 [00:47<10:43, 29194.89it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s]| 1634304/20000000 [01:05<08:10, 37468.33it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.64it/s]| 2039808/20000000 [01:20<07:36, 39322.02it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.08it/s]| 2453504/20000000 [01:37<07:55, 36873.53it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s]| 2863104/20000000 [01:52<07:16, 39292.24it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]| 3272704/20000000 [02:09<06:52, 40537.06it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 3678208/20000000 [02:26<27:40, 9829.56it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s]| 4087808/20000000 [02:41<06:11, 42798.13it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.50it/s] | 4497408/20000000 [03:01<08:53, 29055.95it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 4911104/20000000 [03:16<06:55, 36330.89it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.57it/s] | 5316608/20000000 [03:34<06:31, 37461.30it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.87it/s] | 5726208/20000000 [03:50<05:45, 41309.20it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.51it/s] | 6139904/20000000 [04:07<06:03, 38122.10it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.90it/s] | 6549504/20000000 [04:24<05:43, 39198.19it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.91it/s] | 6955008/20000000 [04:43<05:01, 43328.38it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.84it/s] | 7368704/20000000 [05:00<12:14, 17200.22it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 7778304/20000000 [05:14<04:44, 43005.09it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.78it/s] | 8183808/20000000 [05:32<06:31, 30153.11it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.80it/s] | 8597504/20000000 [05:47<04:22, 43375.86it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 5.00it/s] | 9007104/20000000 [06:09<05:16, 34784.52it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.55it/s] | 9416704/20000000 [06:24<04:36, 38252.78it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.75it/s] | 9822208/20000000 [06:42<03:58, 42593.01it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.99it/s] | 10235904/20000000 [06:59<19:05, 8524.91it/s] \n", - "100%|██████████| 10/10 [00:02<00:00, 4.98it/s] | 10645504/20000000 [07:14<03:30, 44384.65it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.89it/s] | 11055104/20000000 [07:31<05:24, 27562.66it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.83it/s] | 11464704/20000000 [07:45<03:26, 41316.56it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.81it/s] | 11870208/20000000 [08:02<03:44, 36217.25it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 
4.89it/s] | 12279808/20000000 [08:16<02:52, 44715.52it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.85it/s] | 12693504/20000000 [08:34<03:02, 40061.41it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.02it/s] | 13103104/20000000 [08:48<02:38, 43563.35it/s]\n", - "100%|██████████| 10/10 [00:04<00:00, 2.17it/s] | 13508608/20000000 [09:05<02:34, 41937.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 13922304/20000000 [09:24<05:07, 19779.09it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 14327808/20000000 [09:38<02:05, 45367.15it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 14741504/20000000 [09:54<02:49, 30943.53it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.05it/s] | 15147008/20000000 [10:08<01:46, 45610.98it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.06it/s] | 15556608/20000000 [10:24<01:49, 40440.85it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.03it/s] | 15966208/20000000 [10:38<01:29, 45251.75it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16379904/20000000 [10:55<01:22, 43941.70it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.09it/s] | 16789504/20000000 [11:11<04:30, 11859.26it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.04it/s] | 17195008/20000000 [11:25<01:02, 44607.68it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s] | 17608704/20000000 [11:41<01:38, 24188.35it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.00it/s] | 18018304/20000000 [11:54<00:42, 46425.69it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.06it/s]▏| 18423808/20000000 [12:13<00:44, 35420.18it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.97it/s]▍| 18837504/20000000 [12:27<00:26, 43914.73it/s]\n", - "100%|██████████| 10/10 [00:01<00:00, 5.01it/s]▌| 19243008/20000000 [12:45<00:19, 38931.67it/s]\n", - "100%|██████████| 10/10 [00:02<00:00, 4.95it/s]▊| 19656704/20000000 [12:59<00:07, 43804.93it/s]\n", - "4883| MSE Loss 0.000 | L1 0.000: 100%|█████████▉| 19996672/20000000 [13:14<00:00, 37714.53it/s]" - ] + "data": { + "text/html": [ + " View project at https://wandb.ai/jbloom/gpt2_small_experiments_april" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "name": "stdout", + "data": { + "text/html": [ + " View run at https://wandb.ai/jbloom/gpt2_small_experiments_april/runs/vbwoyzi8" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", "output_type": "stream", "text": [ - "Saved model to checkpoints/sf7u2imk/final_sae_group_tiny-stories-1M_blocks.1.attn.hook_z_1024.pt\n" + "Objective value: 46608928.0000: 2%|▏ | 2/100 [00:00<00:01, 55.75it/s]\n", + "/home/paperspace/mats_sae_training/sae_lens/training/sparse_autoencoder.py:176: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", + " out = torch.tensor(origin, dtype=self.dtype, device=self.device)\n", + "120| MSE Loss 31.151 | L1 65.750: 0%| | 487424/300000000 [00:15<1:28:10, 56617.16it/s]/home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (4elmsny3) is finished. The call to `_console_raw_callback` will be ignored. 
Please make sure that you are using an active run.\n", + " lambda data: self._console_raw_callback(\"stderr\", data),\n", + "2407| MSE Loss 0.070 | L1 0.027: 20%|█▉ | 9859072/50000000 [3:33:05<14:27:36, 771.10it/s]\n", + "73243| MSE Loss 1.416 | L1 1.255: : 300003328it [2:44:02, 54947.70it/s] " ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "1dffd84a387d4cf48100fbe143287481", + "model_id": "6111ba99afb144ae82bab7723efb2c86", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "VBox(children=(Label(value='0.053 MB of 0.569 MB uploaded\\r'), FloatProgress(value=0.0935266880101429, max=1.0…" + "VBox(children=(Label(value='721.959 MB of 721.959 MB uploaded (0.005 MB deduped)\\r'), FloatProgress(value=1.0,…" ] }, "metadata": {}, "output_type": "display_data" }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "wandb: WARNING Source type is set to 'repo' but some required information is missing from the environment. A job will not be created from this run. See https://docs.wandb.ai/guides/launch/create-job\n" - ] - }, { "data": { "text/html": [ @@ -628,7 +577,7 @@ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n", " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n", " \n", - "

Run history:


details/current_learning_rate▁▃▅▆████████████████████████████████████
details/n_training_tokens▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss██▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▄▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▄▅▆▆▇▇▇▇▇▇▇▇███████████████████████████
metrics/ce_loss_with_ablation▂▃▂▅▃▆▅▃▄▆▇▆▅▇▅▄▇▅▁▆▄▅▆▄█▄▅▆▄▅▅▃▂▄▄▅▅█▆▆
metrics/ce_loss_with_sae█▅▄▃▃▃▂▂▂▂▃▂▂▂▂▂▂▂▁▂▂▂▂▁▂▂▂▂▁▂▂▁▁▂▁▂▁▂▂▂
metrics/ce_loss_without_sae▄▄▁▃▄▆▅▃▆▅█▆▅▆▅▄▅▆▁▇▆▅▆▃█▆▆▆▄▇▆▃▃▆▃▆▄█▇▅
metrics/explained_variance▁▅▇▇▇███████████████████████████████████
metrics/explained_variance_std██▆▄▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0██▇▆▅▅▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▄▆▆▇▇▇▆▆▆▇█▇▇▆▇▇▆▇▇▇▆▇████▇▇▇▇▇▇▇▇▇█▇▇▇
metrics/l2_ratio▁▃▁▂▄▃▂▄▆▅▅▅▅▆▅▆▇▆▆▇▇▆▆▆▇▆▆▇▆▇▆▇▇▇█▆▆▇▇▇
metrics/mean_log10_feature_sparsity█▇▅▄▃▃▂▁▁
sparsity/below_1e-5▁▁▁▁▁▁▁▁▁
sparsity/below_1e-6▁▁▁▁▁▁▁▁▁
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▁▂▂▂▄▇▄▃▅▆▄▇██

Run summary:


details/current_learning_rate: 0.0001
details/n_training_tokens: 19988480
losses/ghost_grad_loss: 0.0
losses/l1_loss: 1.41017
losses/mse_loss: 8e-05
losses/overall_loss: 0.00036
metrics/CE_loss_score: 0.98362
metrics/ce_loss_with_ablation: 5.49512
metrics/ce_loss_with_sae: 2.71813
metrics/ce_loss_without_sae: 2.67199
metrics/explained_variance: 0.98647
metrics/explained_variance_std: 0.00905
metrics/l0: 166.02246
metrics/l2_norm: 1.39317
metrics/l2_ratio: 0.99823
metrics/mean_log10_feature_sparsity: -1.53525
sparsity/below_1e-5: 0
sparsity/below_1e-6: 0
sparsity/dead_features: 0
sparsity/mean_passes_since_fired: 0.02051

" + "

Run history:


details/current_learning_rate▁▄▆█████████████████████████████████████
details/n_training_tokens▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
losses/ghost_grad_loss▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/l1_loss█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score▁▆▇▇████████████████████████████████████
metrics/ce_loss_with_ablation▅▃▄▃▅▄▄▃▄▄▂▂▂▃▄▃▂▆▃▆▃▆▃▆▅▁▁▅▂▃▃▄▃▅▄▃█▄▃▆
metrics/ce_loss_with_sae█▄▂▂▂▂▂▁▂▁▁▂▂▁▁▁▂▁▁▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/ce_loss_without_sae▁▅▆▂▆▄▅▁▄▄▄▇▆▄▃▄▇▄▂▆▆█▄▅▅▄▄▄▅▄▅▅▃▅▄▅▂▃▂▄
metrics/explained_variance▁▆▇▇▇▇▇█████████████████████████████████
metrics/explained_variance_std█▄▃▂▂▂▂▁▂▁▂▁▁▂▂▁▂▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l2_norm▁▄▅▆▆▆▆▆▆▆▆▆▆▆▆▆▇▆▇▆▆▇▆▆▆▇▆█████████████
metrics/l2_ratio▁▄▅▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆█████████████
metrics/mean_log10_feature_sparsity█▅▄▄▄▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity/below_1e-5▁▁▅██████████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
sparsity/below_1e-6▁▁▁▃▅▆▇██████████████████████
sparsity/dead_features▁▁▁▁▁▁▁▁▁▁▂▂▃▄▅▆▆▇▇▇▇███████████████████
sparsity/mean_passes_since_fired▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███

Run summary:


details/current_learning_rate: 0.0004
details/n_training_tokens: 299827200
losses/ghost_grad_loss: 0.0
losses/l1_loss: 162.07342
losses/mse_loss: 1.42934
losses/overall_loss: 2.72593
metrics/CE_loss_score: 0.97257
metrics/ce_loss_with_ablation: 11.42603
metrics/ce_loss_with_sae: 3.61949
metrics/ce_loss_without_sae: 3.39944
metrics/explained_variance: 0.82112
metrics/explained_variance_std: 0.0526
metrics/l0: 50.53198
metrics/l2_norm: 108.35806
metrics/l2_ratio: 0.94604
metrics/mean_log10_feature_sparsity: -7.89094
sparsity/below_1e-5: 18079
sparsity/below_1e-6: 18075
sparsity/dead_features: 16912
sparsity/mean_passes_since_fired: 27024.85938

" ], "text/plain": [ "" @@ -640,7 +589,7 @@ { "data": { "text/html": [ - " View run 1024-L1-0.0002-LR-0.0001-Tokens-2.000e+07 at: https://wandb.ai/jbloom/mats_sae_training_language_benchmark_tests/runs/ec6k6v87
Synced 7 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)" + " View run 24576-L1-0.008-LR-0.0004-Tokens-2.000e+08 at: https://wandb.ai/jbloom/gpt2_small_experiments_april/runs/vbwoyzi8
View project at: https://wandb.ai/jbloom/gpt2_small_experiments_april
Synced 7 W&B file(s), 0 media file(s), 15 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" @@ -652,7 +601,7 @@ { "data": { "text/html": [ - "Find logs at: ./wandb/run-20240326_191703-ec6k6v87/logs" + "Find logs at: ./wandb/run-20240416_165827-vbwoyzi8/logs" ], "text/plain": [ "" @@ -665,223 +614,56 @@ "name": "stderr", "output_type": "stream", "text": [ - "4883| MSE Loss 0.000 | L1 0.000: : 20000768it [13:29, 37714.53it/s] /Users/josephbloom/miniforge3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2171: UserWarning: Run (ec6k6v87) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", + "73243| MSE Loss 1.416 | L1 1.255: : 300003328it [2:44:12, 54947.70it/s]/home/paperspace/miniconda3/envs/mats_sae_training/lib/python3.11/site-packages/wandb/sdk/wandb_run.py:2265: UserWarning: Run (vbwoyzi8) is finished. The call to `_console_raw_callback` will be ignored. Please make sure that you are using an active run.\n", " lambda data: self._console_raw_callback(\"stderr\", data),\n" ] } ], "source": [ - "import torch\n", - "import os\n", - "\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", - "\n", - "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", - "if device == \"cpu\" and torch.backends.mps.is_available():\n", - " device = \"mps\"\n", - "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "cfg = LanguageModelSAERunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"tiny-stories-1M\",\n", - " hook_point=\"blocks.1.attn.hook_z\",\n", - " hook_point_layer=1,\n", - " d_in=64,\n", - " # dataset_path=\"roneneldan/TinyStories\",\n", - " # is_dataset_tokenized=False,\n", - " # Dan at Apollo pretokenized this dataset for us which will speed up training.\n", - " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\",\n", - " is_dataset_tokenized=True,\n", - " # SAE Parameters\n", - " expansion_factor=16,\n", - " # Training Parameters\n", - " lr=1e-4,\n", - " lp_norm=1.0,\n", - " l1_coefficient=2e-4,\n", - " train_batch_size=4096,\n", - " context_size=128,\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=128,\n", - " total_training_tokens=1_000_000 * 20,\n", - " store_batch_size=32,\n", - " feature_sampling_window=500, # So we see the histograms.\n", - " dead_feature_window=250,\n", - " # WANDB\n", - " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_language_benchmark_tests\",\n", - " wandb_log_frequency=10,\n", - " # Misc\n", - " device=device,\n", - " seed=42,\n", - " n_checkpoints=0,\n", - " checkpoint_path=\"checkpoints\",\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "sparse_autoencoder = language_model_sae_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Toy Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from sae_lens.training.toy_model_runner import (\n", - " SAEToyModelRunnerConfig,\n", - " toy_model_sae_runner,\n", - ")\n", - "\n", - "\n", - "cfg = SAEToyModelRunnerConfig(\n", - " # Model Details\n", - " n_features=200,\n", - " n_hidden=5,\n", - " n_correlated_pairs=0,\n", - " n_anticorrelated_pairs=0,\n", - " feature_probability=0.025,\n", - " model_training_steps=10_000,\n", - " # SAE Parameters\n", - " d_sae=240,\n", - " l1_coefficient=0.001,\n", - " # 
SAE Train Config\n", - " train_batch_size=1028,\n", - " feature_sampling_window=3_000,\n", - " dead_feature_window=1_000,\n", - " feature_reinit_scale=0.5,\n", - " total_training_tokens=4096 * 300,\n", - " # Other parameters\n", - " log_to_wandb=True,\n", - " wandb_project=\"sae-training-test\",\n", - " wandb_log_frequency=5,\n", - " device=\"mps\",\n", - ")\n", - "\n", - "trained_sae = toy_model_sae_runner(cfg)\n", - "\n", - "assert trained_sae is not None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Run caching of activations to disk" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", - "import sys\n", - "\n", - "sys.path.append(\"..\")\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", - "\n", - "from sae_lens.training.config import CacheActivationsRunnerConfig\n", - "from sae_lens.training.cache_activations_runner import cache_activations_runner\n", - "\n", - "cfg = CacheActivationsRunnerConfig(\n", - " # Data Generating Function (Model + Training Distibuion)\n", - " model_name=\"gpt2-small\",\n", - " hook_point=\"blocks.10.attn.hook_q\",\n", - " hook_point_layer=10,\n", - " hook_point_head_index=7,\n", - " d_in=64,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " cached_activations_path=\"../activations/\",\n", - " # Activation Store Parameters\n", - " n_batches_in_buffer=16,\n", - " total_training_tokens=500_000_000,\n", - " store_batch_size=32,\n", - " # Activation caching shuffle parameters\n", - " n_shuffles_final=16,\n", - " # Misc\n", - " device=\"mps\",\n", - " seed=42,\n", - " dtype=torch.float32,\n", - ")\n", - "\n", - "cache_activations_runner(cfg)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Train an SAE using the cached activations stored on disk\n", - "Pass `use_cached_activations=True` into the config" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import os\n", "\n", - "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", - "os.environ[\"WANDB__SERVICE_WAIT\"] = \"300\"\n", - "from sae_lens.training.config import LanguageModelSAERunnerConfig\n", - "from sae_lens.training.lm_runner import language_model_sae_runner\n", "\n", "cfg = LanguageModelSAERunnerConfig(\n", " # Data Generating Function (Model + Training Distibuion)\n", " model_name=\"gpt2-small\",\n", - " hook_point=\"blocks.10.hook_resid_pre\",\n", - " hook_point_layer=11,\n", + " hook_point=\"blocks.8.hook_resid_pre\",\n", + " hook_point_layer=8,\n", " d_in=768,\n", - " dataset_path=\"Skylion007/openwebtext\",\n", - " is_dataset_tokenized=False,\n", - " use_cached_activations=True,\n", + " dataset_path=\"apollo-research/Skylion007-openwebtext-tokenizer-gpt2\",\n", + " is_dataset_tokenized=True,\n", + " prepend_bos=True, # should experiment with turning this off.\n", " # SAE Parameters\n", - " expansion_factor=64, # determines the dimension of the SAE.\n", + " expansion_factor=32, # determines the dimension of the SAE.\n", + " b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n", + " apply_b_dec_to_input=False,\n", " # Training Parameters\n", - " lr=1e-5,\n", - " l1_coefficient=5e-4,\n", - " 
lr_scheduler_name=None,\n", + " adam_beta1=0,\n", + " adam_beta2=0.999,\n", + " lr=0.0004,\n", + " l1_coefficient=0.008,\n", + " lr_scheduler_name=\"constant\",\n", " train_batch_size=4096,\n", - " context_size=128,\n", + " context_size=256,\n", + " lr_warm_up_steps=5000,\n", " # Activation Store Parameters\n", - " n_batches_in_buffer=64,\n", - " total_training_tokens=200_000,\n", + " n_batches_in_buffer=128,\n", + " training_tokens=1_000_000 * 200, # 200M tokens seems doable overnight.\n", + " finetuning_method=\"decoder\",\n", + " finetuning_tokens=1_000_000 * 100,\n", " store_batch_size=32,\n", + " \n", " # Resampling protocol\n", - " feature_sampling_method=\"l2\",\n", - " feature_sampling_window=1000,\n", - " feature_reinit_scale=0.2,\n", + " use_ghost_grads=False,\n", + " feature_sampling_window=2500,\n", " dead_feature_window=5000,\n", - " dead_feature_threshold=1e-7,\n", + " dead_feature_threshold=1e-8,\n", + " \n", " # WANDB\n", " log_to_wandb=True,\n", - " wandb_project=\"mats_sae_training_gpt2_small\",\n", + " wandb_project=\"gpt2_small_experiments_april\",\n", " wandb_entity=None,\n", - " wandb_log_frequency=50,\n", + " wandb_log_frequency=100,\n", " # Misc\n", - " device=\"mps\",\n", + " device=device,\n", " seed=42,\n", " n_checkpoints=5,\n", " checkpoint_path=\"checkpoints\",\n", @@ -890,13 +672,6 @@ "\n", "sparse_autoencoder = language_model_sae_runner(cfg)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -915,7 +690,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.6" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/tests/benchmark/test_language_model_sae_runner.py b/tests/benchmark/test_language_model_sae_runner.py index e0030ce5..e1c878c0 100644 --- a/tests/benchmark/test_language_model_sae_runner.py +++ b/tests/benchmark/test_language_model_sae_runner.py @@ -30,7 +30,7 @@ def test_language_model_sae_runner(): context_size=128, # Activation Store Parameters n_batches_in_buffer=24, - total_training_tokens=1_000_000 * 10, + training_tokens=1_000_000 * 10, store_batch_size=32, # Resampling protocol use_ghost_grads=True, diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py index 4d18b6ab..0f172ab5 100644 --- a/tests/unit/helpers.py +++ b/tests/unit/helpers.py @@ -32,7 +32,7 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: feature_sampling_window=50, dead_feature_threshold=1e-7, n_batches_in_buffer=2, - total_training_tokens=1_000_000, + training_tokens=1_000_000, store_batch_size=4, log_to_wandb=False, wandb_project="test_project", diff --git a/tests/unit/training/test_train_sae_on_language_model.py b/tests/unit/training/test_train_sae_on_language_model.py index eae5b35d..98898806 100644 --- a/tests/unit/training/test_train_sae_on_language_model.py +++ b/tests/unit/training/test_train_sae_on_language_model.py @@ -26,6 +26,7 @@ from tests.unit.helpers import build_sae_cfg +# TODO: Address why we have this code here rather than importing it. 
def build_train_ctx( sae: SparseAutoencoder, act_freq_scores: Tensor | None = None, @@ -310,11 +311,11 @@ def test_train_sae_group_on_language_model__runs( cfg = build_sae_cfg( checkpoint_path=checkpoint_dir, train_batch_size=32, - total_training_tokens=100, + training_tokens=100, context_size=8, ) # just a tiny datast which will run quickly - dataset = Dataset.from_list([{"text": "hello world"}] * 1000) + dataset = Dataset.from_list([{"text": "hello world"}] * 2000) activation_store = ActivationsStore.from_config(ts_model, cfg, dataset=dataset) sae_group = SparseAutoencoderDictionary(cfg) res = train_sae_group_on_language_model( diff --git a/tutorials/training_a_sparse_autoencoder.ipynb b/tutorials/training_a_sparse_autoencoder.ipynb index 73f4ccdd..17eed89c 100644 --- a/tutorials/training_a_sparse_autoencoder.ipynb +++ b/tutorials/training_a_sparse_autoencoder.ipynb @@ -335,7 +335,7 @@ " context_size=512, # will control the lenght of the prompts we feed to the model. Larger is better but slower.\n", " # Activation Store Parameters\n", " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", - " total_training_tokens=1_000_000\n", + " training_tokens=1_000_000\n", " * 50, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", " store_batch_size=16,\n", " # Resampling protocol\n", From 822882cac9c05449b7237b7d42ce17297903da2f Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Thu, 18 Apr 2024 10:36:48 +0100 Subject: [PATCH 23/30] reformat run.ipynb --- scripts/run.ipynb | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/scripts/run.ipynb b/scripts/run.ipynb index 824b431c..f13aac59 100644 --- a/scripts/run.ipynb +++ b/scripts/run.ipynb @@ -256,13 +256,11 @@ " d_in=1024, # the width of the mlp output.\n", " dataset_path=\"apollo-research/roneneldan-TinyStories-tokenizer-gpt2\", # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.\n", " is_dataset_tokenized=True,\n", - " \n", " # SAE Parameters\n", " mse_loss_normalization=None, # We won't normalize the mse loss,\n", " expansion_factor=16, # the width of the SAE. Larger will result in better stats but slower training.\n", " b_dec_init_method=\"geometric_median\", # The geometric median can be used to initialize the decoder weights.\n", " apply_b_dec_to_input=False, # We won't apply the decoder to the input.\n", - " \n", " # Training Parameters\n", " lr=0.0008, # lower the better, we'll go fairly high to speed up the tutorial.\n", " lr_scheduler_name=\"constant\", # constant learning rate with warmup. Could be better schedules out there.\n", @@ -271,15 +269,13 @@ " lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n", " train_batch_size=4096,\n", " context_size=128, # will control the lenght of the prompts we feed to the model. Larger is better but slower.\n", - " \n", " # Activation Store Parameters\n", " n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n", - " training_tokens=1_000_000 * 25, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n", + " training_tokens=1_000_000\n", + " * 25, # 100 million tokens is quite a few, but we want to see good stats. 
Get a coffee, come back.\n", " finetuning_method=\"decoder\",\n", " finetuning_tokens=1_000_000 * 25,\n", " store_batch_size=32,\n", - " \n", - " \n", " # Resampling protocol\n", " use_ghost_grads=False,\n", " feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n", @@ -620,8 +616,6 @@ } ], "source": [ - "\n", - "\n", "cfg = LanguageModelSAERunnerConfig(\n", " # Data Generating Function (Model + Training Distibuion)\n", " model_name=\"gpt2-small\",\n", @@ -630,7 +624,7 @@ " d_in=768,\n", " dataset_path=\"apollo-research/Skylion007-openwebtext-tokenizer-gpt2\",\n", " is_dataset_tokenized=True,\n", - " prepend_bos=True, # should experiment with turning this off.\n", + " prepend_bos=True, # should experiment with turning this off.\n", " # SAE Parameters\n", " expansion_factor=32, # determines the dimension of the SAE.\n", " b_dec_init_method=\"geometric_median\", # geometric median is better but slower to get started\n", @@ -650,13 +644,11 @@ " finetuning_method=\"decoder\",\n", " finetuning_tokens=1_000_000 * 100,\n", " store_batch_size=32,\n", - " \n", " # Resampling protocol\n", " use_ghost_grads=False,\n", " feature_sampling_window=2500,\n", " dead_feature_window=5000,\n", " dead_feature_threshold=1e-8,\n", - " \n", " # WANDB\n", " log_to_wandb=True,\n", " wandb_project=\"gpt2_small_experiments_april\",\n", From bc766e4f7a8d472f647408b8a5cd3c6140d856b7 Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Thu, 18 Apr 2024 10:43:45 +0100 Subject: [PATCH 24/30] minor changes --- docs/training_saes.md | 4 ++-- sae_lens/training/config.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/training_saes.md b/docs/training_saes.md index 632eefbd..3d6a3ca8 100644 --- a/docs/training_saes.md +++ b/docs/training_saes.md @@ -49,7 +49,7 @@ cfg = LanguageModelSAERunnerConfig( # Activation Store Parameters n_batches_in_buffer = 128, - total_training_tokens = 1_000_000 * 300, + training_tokens = 1_000_000 * 300, store_batch_size = 32, # Dead Neurons and Sparsity @@ -60,7 +60,7 @@ cfg = LanguageModelSAERunnerConfig( # WANDB log_to_wandb = True, - wandb_project= "mats_sae_training_gpt2", + wandb_project= "gpt2", wandb_entity = None, wandb_log_frequency=100, diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 22857e90..dbad703b 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -174,7 +174,9 @@ def __post_init__(self): f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 **6}" ) - total_training_steps = self.training_tokens // self.train_batch_size + total_training_steps = ( + self.training_tokens + self.finetuning_tokens + ) // self.train_batch_size print(f"Total training steps: {total_training_steps}") total_wandb_updates = total_training_steps // self.wandb_log_frequency From 2bb5975226807d352f2d3cf6b6dad7aefaf1b662 Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Thu, 18 Apr 2024 11:12:01 +0100 Subject: [PATCH 25/30] par update --- tutorials/mamba_train_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/mamba_train_example.py b/tutorials/mamba_train_example.py index 12866bd2..19ff5074 100644 --- a/tutorials/mamba_train_example.py +++ b/tutorials/mamba_train_example.py @@ -32,7 +32,7 @@ lr_warm_up_steps=5000, # Activation Store Parameters n_batches_in_buffer=128, - total_training_tokens=1_000_000 * 300, + training_tokens=1_000_000 * 300, store_batch_size=32, # Dead Neurons and Sparsity use_ghost_grads=True, From 
9c44731a9b7718c9f0913136ed9df42dac87c390 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Thu, 18 Apr 2024 19:08:39 +0100 Subject: [PATCH 26/30] chore: re-enabling isort in CI (#86) --- .github/workflows/build.yml | 4 ++-- pyproject.toml | 3 ++- sae_lens/analysis/neuronpedia_integration.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 4a473451..ba0a3d64 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -64,8 +64,8 @@ jobs: run: poetry run flake8 . - name: black code formatting run: poetry run black . --check - # - name: isort linting - # run: poetry run isort . --check-only --diff + - name: isort linting + run: poetry run isort . --check-only --diff - name: type checking run: poetry run pyright - name: Run Unit Tests diff --git a/pyproject.toml b/pyproject.toml index 38d2e911..0f08c82c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ pytest = "^8.0.2" pytest-cov = "^4.1.0" pre-commit = "^3.6.2" flake8 = "^7.0.0" -isort = "^5.13.2" +isort = "5.13.2" pyright = "^1.1.351" mamba-lens = "^0.0.4" @@ -48,6 +48,7 @@ mamba = ["mamba-lens"] [tool.isort] profile = "black" +src_paths = ["sae_lens", "tests"] [tool.pyright] typeCheckingMode = "strict" diff --git a/sae_lens/analysis/neuronpedia_integration.py b/sae_lens/analysis/neuronpedia_integration.py index 37d80e95..512bac32 100644 --- a/sae_lens/analysis/neuronpedia_integration.py +++ b/sae_lens/analysis/neuronpedia_integration.py @@ -1,6 +1,7 @@ import json import urllib.parse import webbrowser + import requests From 0ac218bf8068b8568310b40a1399f9eb3c8d992e Mon Sep 17 00:00:00 2001 From: Johnny Lin Date: Thu, 18 Apr 2024 13:57:08 -0700 Subject: [PATCH 27/30] v0.5.1 --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0f08c82c..fb9ad251 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sae-lens" -version = "0.5.0" +version = "0.5.1" description = "Training and Analyzing Sparse Autoencoders (SAEs)" authors = ["Joseph Bloom"] readme = "README.md" @@ -20,7 +20,7 @@ matplotlib-inline = "^0.1.6" datasets = "^2.17.1" babe = "^0.0.7" nltk = "^3.8.1" -sae-vis = { git = "https://github.com/callummcdougall/sae_vis.git", branch = "allow_disable_buffer" } +sae-vis = "^0.2.15" mkdocs = "^1.5.3" mkdocs-material = "^9.5.15" mkdocs-autorefs = "^1.0.1" From 25cebf1e5e0630a377a5045c1b3571a5f181853f Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Fri, 19 Apr 2024 10:30:41 +0100 Subject: [PATCH 28/30] fix: typing issue, temporary --- sae_lens/analysis/neuronpedia_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 401941c4..a34e970a 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -244,7 +244,7 @@ def run(self): Column( SequencesConfig( stack_mode="stack-all", - buffer=None, + buffer=None, # type: ignore compute_buffer=True, n_quantiles=5, top_acts_group_size=20, From 00940219754ddb1be6708e54cdd0ac6ed5dc3948 Mon Sep 17 00:00:00 2001 From: jbloom-md Date: Fri, 19 Apr 2024 11:57:44 +0100 Subject: [PATCH 29/30] fix: pin pyzmq==26.0.1 temporarily --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index fb9ad251..c0aef664 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ mkdocstrings-python = "^1.9.0" 
safetensors = "^0.4.2" typer = "^0.12.3" mamba-lens = { version = "^0.0.4", optional = true } +pyzmq = "26.0.0" [tool.poetry.group.dev.dependencies] From 2d316724240e2cd6aceed118e3fb73811b48ecaa Mon Sep 17 00:00:00 2001 From: github-actions Date: Fri, 19 Apr 2024 11:04:29 +0000 Subject: [PATCH 30/30] 0.5.1 Automatically generated by python-semantic-release --- CHANGELOG.md | 87 ++++++++++++++++++++++++++++++++++++++++++++ sae_lens/__init__.py | 2 +- 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29501cd2..c1e6e8c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,55 @@ +## v0.5.1 (2024-04-19) + +### Chore + +* chore: re-enabling isort in CI (#86) ([`9c44731`](https://github.com/jbloomAus/SAELens/commit/9c44731a9b7718c9f0913136ed9df42dac87c390)) + +### Fix + +* fix: pin pyzmq==26.0.1 temporarily ([`0094021`](https://github.com/jbloomAus/SAELens/commit/00940219754ddb1be6708e54cdd0ac6ed5dc3948)) + +* fix: typing issue, temporary ([`25cebf1`](https://github.com/jbloomAus/SAELens/commit/25cebf1e5e0630a377a5045c1b3571a5f181853f)) + +### Unknown + +* v0.5.1 ([`0ac218b`](https://github.com/jbloomAus/SAELens/commit/0ac218bf8068b8568310b40a1399f9eb3c8d992e)) + +* Merge pull request #91 from jbloomAus/decoder-fine-tuning + +Decoder fine tuning ([`1fc652c`](https://github.com/jbloomAus/SAELens/commit/1fc652c19e2a34172c1fd520565e9620366f565c)) + +* par update ([`2bb5975`](https://github.com/jbloomAus/SAELens/commit/2bb5975226807d352f2d3cf6b6dad7aefaf1b662)) + +* Merge pull request #89 from jbloomAus/fix_np + +Enhance + Fix Neuronpedia generation / upload ([`38d507c`](https://github.com/jbloomAus/SAELens/commit/38d507c052875cbd78f8fd9dae45a658e47c2b9d)) + +* minor changes ([`bc766e4`](https://github.com/jbloomAus/SAELens/commit/bc766e4f7a8d472f647408b8a5cd3c6140d856b7)) + +* reformat run.ipynb ([`822882c`](https://github.com/jbloomAus/SAELens/commit/822882cac9c05449b7237b7d42ce17297903da2f)) + +* get decoder fine tuning working ([`11a71e1`](https://github.com/jbloomAus/SAELens/commit/11a71e1b95576ef6dc3dbec7eb1c76ce7ca44dfd)) + +* format ([`040676d`](https://github.com/jbloomAus/SAELens/commit/040676db6814c1f64171a32344e0bed40528c8f9)) + +* Merge pull request #88 from jbloomAus/get_feature_from_neuronpedia + +FEAT: Add API for getting Neuronpedia feature ([`1666a68`](https://github.com/jbloomAus/SAELens/commit/1666a68bb7d7ee4837e03d95b203ee371ca9ea9e)) + +* Fix resuming from batch ([`145a407`](https://github.com/jbloomAus/SAELens/commit/145a407f8d57755301bf56c87efd4e775c59b980)) + +* Use original repo for sae_vis ([`1a7d636`](https://github.com/jbloomAus/SAELens/commit/1a7d636c95ef508dde8bd100ab6d9f241b0be977)) + +* Use correct model name for np runner ([`138d5d4`](https://github.com/jbloomAus/SAELens/commit/138d5d445878c0830c6c96a5fbe6b10a1d9644b0)) + +* Merge main, remove eindex ([`6578436`](https://github.com/jbloomAus/SAELens/commit/6578436891e71a0ef60fb2ed6d6a6b6279d71cbc)) + +* Add API for getting Neuronpedia feature ([`e78207d`](https://github.com/jbloomAus/SAELens/commit/e78207d086cb3372dc805cbb4c87b694749cd905)) + + ## v0.5.0 (2024-04-17) ### Feature @@ -66,6 +115,14 @@ Co-authored-by: David Chanin <chanindav@gmail.com> ([`eea7db4`](https://gi * update readme ([`3694fd2`](https://github.com/jbloomAus/SAELens/commit/3694fd2c4cc7438121e4549636508c45835a5d38)) +* Fix upload skipped/dead features ([`932f380`](https://github.com/jbloomAus/SAELens/commit/932f380971ce3d431e6592c804d12f6df2b4ec78)) + +* Use python typer instead of shell script for 
neuronpedia jobs ([`b611e72`](https://github.com/jbloomAus/SAELens/commit/b611e721dd2620ab5a030cc0f6e37029c30711ca)) + +* Merge branch 'main' into fix_np ([`cc6cb6a`](https://github.com/jbloomAus/SAELens/commit/cc6cb6a96b793e41fb91f4ebbaf3bfa5e7c11b4e)) + +* convert sparsity to log sparsity if needed ([`8d7d404`](https://github.com/jbloomAus/SAELens/commit/8d7d4040033fb80c5b994cdc662b0f90b8fcc7aa)) + ## v0.4.0 (2024-04-16) @@ -87,8 +144,26 @@ Co-authored-by: David Chanin <chanindav@gmail.com> ([`eea7db4`](https://gi * default orthogonal init false ([`a8b0113`](https://github.com/jbloomAus/SAELens/commit/a8b0113140bd2f9b97befccc8f158dace02a4810)) +* Formatting ([`1e3d53e`](https://github.com/jbloomAus/SAELens/commit/1e3d53ec2b72897bfebb6065f3b530fe65d3a97c)) + +* Eindex required by sae_vis ([`f769e7a`](https://github.com/jbloomAus/SAELens/commit/f769e7a65ab84d4073852931a86ff3b5076eea3c)) + +* Upload dead feature stubs ([`9067380`](https://github.com/jbloomAus/SAELens/commit/9067380bf67b89d8b2d235944f696016286f683e)) + +* Make feature sparsity an argument ([`8230570`](https://github.com/jbloomAus/SAELens/commit/8230570297d68e35cb614a63abf442e4a01174d2)) + +* Fix buffer" ([`dde2481`](https://github.com/jbloomAus/SAELens/commit/dde248162b70ff4311d4182333b7cce43aed78df)) + +* Merge branch 'main' into fix_np ([`6658392`](https://github.com/jbloomAus/SAELens/commit/66583923cd625bfc1c1ef152bc5f5beaa764b2d6)) + * notebook update ([`feca408`](https://github.com/jbloomAus/SAELens/commit/feca408cf003737cd4eb529ca7fea2f77984f5c6)) +* Merge branch 'main' into fix_np ([`f8fb3ef`](https://github.com/jbloomAus/SAELens/commit/f8fb3efbde7fc79e6fafe2d9b3324c9f0b2a337d)) + +* Final fixes ([`e87788d`](https://github.com/jbloomAus/SAELens/commit/e87788d63a9b767e34e497c85a318337ab8aabb8)) + +* Don't use buffer, fix anomalies ([`2c9ca64`](https://github.com/jbloomAus/SAELens/commit/2c9ca642b334b7a444544a4640c483229dc04c62)) + ## v0.3.0 (2024-04-15) @@ -109,6 +184,10 @@ Co-authored-by: David Chanin <chanindav@gmail.com> ([`eea7db4`](https://gi * make dense_batch_mse_normalization optional ([`c41774e`](https://github.com/jbloomAus/SAELens/commit/c41774e5cfaeb195e3320e9e3fc93d60d921337d)) +* Runner is fixed, faster, cleaned up, and now gives whole sequences instead of buffer. ([`3837884`](https://github.com/jbloomAus/SAELens/commit/383788485917cee114fba24e8ded944aefcfb568)) + +* Merge branch 'main' into fix_np ([`3ed30cf`](https://github.com/jbloomAus/SAELens/commit/3ed30cf2b84a2444c8ed030641214f0dbb65898a)) + * add warning in run script ([`9a772ca`](https://github.com/jbloomAus/SAELens/commit/9a772ca6da155b5e97bc3109da74457f5addfbfd)) * update sae loading code ([`356a8ef`](https://github.com/jbloomAus/SAELens/commit/356a8efba06e4f453d2f15afe9171b71d780819a)) @@ -139,6 +218,10 @@ Co-authored-by: David Chanin <chanindav@gmail.com> ([`eea7db4`](https://gi ### Unknown +* Use legacy loader, add back histograms, logits. Fix anomaly characters. 
([`ebbb622`](https://github.com/jbloomAus/SAELens/commit/ebbb622353bef21c953f844a108ea8d9fe31e9f9)) + +* Merge branch 'main' into fix_np ([`586e088`](https://github.com/jbloomAus/SAELens/commit/586e0881e08a9b013e2d4101878ef054c1f3dd8b)) + * Merge pull request #80 from wllgrnt/will-update-tutorial bugfix - minimum viable updates to tutorial notebook ([`e51016b`](https://github.com/jbloomAus/SAELens/commit/e51016b01f3b0f30c83365c54430908779671d87)) @@ -183,6 +266,8 @@ Fix artifact saving loading ([`8784c74`](https://github.com/jbloomAus/SAELens/co * add safetensors to project ([`0da48b0`](https://github.com/jbloomAus/SAELens/commit/0da48b044357eed17e5afffd3ce541e064185043)) +* Don't precompute background colors and tick values ([`271dbf0`](https://github.com/jbloomAus/SAELens/commit/271dbf05567b6e6ae4cfc1dab138132872038381)) + * Merge pull request #71 from weissercn/main Addressing notebook issues ([`8417505`](https://github.com/jbloomAus/SAELens/commit/84175055ba5876b335cbc0de38bf709d0b11cec1)) @@ -191,6 +276,8 @@ Addressing notebook issues ([`8417505`](https://github.com/jbloomAus/SAELens/com chore: updating README.md with pip install instructions and PyPI badge ([`4d7d1e7`](https://github.com/jbloomAus/SAELens/commit/4d7d1e7db5e952c7e9accf19c0ccce466cdcf6cf)) +* FIX: Add back correlated neurons, frac_nonzero ([`d532b82`](https://github.com/jbloomAus/SAELens/commit/d532b828bd77c18b73f495d6b42ca53b5148fd2f)) + * linting ([`1db0b5a`](https://github.com/jbloomAus/SAELens/commit/1db0b5ae7e091822c72bba0488d30fc16bc9a1c6)) * fixed graph name ([`ace4813`](https://github.com/jbloomAus/SAELens/commit/ace481322103737de2e80d688683d0c937ac5558)) diff --git a/sae_lens/__init__.py b/sae_lens/__init__.py index e13470b4..e5dff1f7 100644 --- a/sae_lens/__init__.py +++ b/sae_lens/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.5.0" +__version__ = "0.5.1" from .training.activations_store import ActivationsStore from .training.cache_activations_runner import cache_activations_runner
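
Note on the step-count arithmetic introduced by the decoder fine-tuning patch above: `__post_init__` now counts both the main training and the fine-tuning token budgets when reporting total training steps. The following is a minimal standalone sketch of that calculation, using the GPT-2 small run settings from the notebook diff earlier in this series; it is illustrative only and not part of any patch.

```python
# Sketch of the step-count arithmetic from LanguageModelSAERunnerConfig.__post_init__,
# using the GPT-2 small run settings shown in the notebook above.
training_tokens = 1_000_000 * 200    # main SAE training budget
finetuning_tokens = 1_000_000 * 100  # decoder fine-tuning budget
train_batch_size = 4096

# Total optimizer steps now span both phases.
total_training_steps = (training_tokens + finetuning_tokens) // train_batch_size
print(total_training_steps)  # 73242, consistent with the ~73k steps in the run's progress bar
```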