diff --git a/pyproject.toml b/pyproject.toml index 1475de8f..38d2e911 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ matplotlib-inline = "^0.1.6" datasets = "^2.17.1" babe = "^0.0.7" nltk = "^3.8.1" -sae-vis = "0.2.6" +sae-vis = { git = "https://github.com/callummcdougall/sae_vis.git", branch = "allow_disable_buffer" } mkdocs = "^1.5.3" mkdocs-material = "^9.5.15" mkdocs-autorefs = "^1.0.1" @@ -28,6 +28,7 @@ mkdocs-section-index = "^0.3.8" mkdocstrings = "^0.24.1" mkdocstrings-python = "^1.9.0" safetensors = "^0.4.2" +typer = "^0.12.3" mamba-lens = { version = "^0.0.4", optional = true } diff --git a/sae_lens/analysis/dashboard_runner.py b/sae_lens/analysis/dashboard_runner.py index d77be35d..150b0a52 100644 --- a/sae_lens/analysis/dashboard_runner.py +++ b/sae_lens/analysis/dashboard_runner.py @@ -11,6 +11,7 @@ import plotly import plotly.express as px import torch +import wandb from sae_vis.data_config_classes import ( ActsHistogramConfig, Column, @@ -24,7 +25,6 @@ from tqdm import tqdm from transformer_lens import HookedTransformer -import wandb from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py index 14c39556..401941c4 100644 --- a/sae_lens/analysis/neuronpedia_runner.py +++ b/sae_lens/analysis/neuronpedia_runner.py @@ -4,7 +4,6 @@ # set TOKENIZERS_PARALLELISM to false to avoid warnings os.environ["TOKENIZERS_PARALLELISM"] = "false" import json -import time import numpy as np import torch @@ -13,15 +12,19 @@ ActsHistogramConfig, Column, FeatureTablesConfig, + LogitsHistogramConfig, + LogitsTableConfig, SaeVisConfig, SaeVisLayoutConfig, SequencesConfig, ) -from sae_vis.data_fetching_fns import get_feature_data +from sae_vis.data_storing_fns import SaeVisData from tqdm import tqdm from transformer_lens import HookedTransformer +from sae_lens.toolkit.pretrained_saes import load_sparsity from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader +from sae_lens.training.sparse_autoencoder import SparseAutoencoder OUT_OF_RANGE_TOKEN = "<|outofrange|>" @@ -29,6 +32,21 @@ "bg_color_map", ["white", "darkorange"] ) +DEFAULT_SPARSITY_THRESHOLD = -5 + +HTML_ANOMALIES = { + "âĢĶ": "—", + "âĢĵ": "–", + "âĢľ": "“", + "âĢĿ": "”", + "âĢĺ": "‘", + "âĢĻ": "’", + "âĢĭ": " ", # todo: this is actually zero width space + "Ġ": " ", + "Ċ": "\n", + "ĉ": "\t", +} + class NpEncoder(json.JSONEncoder): def default(self, o: Any): @@ -43,69 +61,51 @@ def default(self, o: Any): class NeuronpediaRunner: - model: HookedTransformer | None = None - def __init__( self, + sae_id: str, sae_path: str, - feature_sparsity_path: Optional[str] = None, - neuronpedia_parent_folder: str = "./neuronpedia_outputs", - init_session: bool = True, + outputs_dir: str, + sparsity_threshold: int = DEFAULT_SPARSITY_THRESHOLD, # token pars n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6, - # sampling pars - n_features_at_a_time: int = 1024, - buffer_tokens_left: int = 8, - buffer_tokens_right: int = 8, - # start and end batch + # batching + n_features_at_a_time: int = 128, start_batch_inclusive: int = 0, end_batch_inclusive: Optional[int] = None, ): - self.sae_path = sae_path - if init_session: - self.init_sae_session() - self.feature_sparsity_path = feature_sparsity_path + self.device = "cpu" + if torch.backends.mps.is_available(): + self.device = "mps" + elif torch.cuda.is_available(): + self.device = "cuda" + + self.sae_path = sae_path + 
self.sparse_autoencoder = SparseAutoencoder.load_from_pretrained( + self.sae_path, device=self.device + ) + loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg) + self.model, _, self.activation_store = loader.load_sae_training_group_session() + self.model_id = self.sparse_autoencoder.cfg.model_name + self.layer = self.sparse_autoencoder.cfg.hook_point_layer + self.sae_id = sae_id + self.sparsity_threshold = sparsity_threshold self.n_features_at_a_time = n_features_at_a_time - self.buffer_tokens_left = buffer_tokens_left - self.buffer_tokens_right = buffer_tokens_right self.n_batches_to_sample_from = n_batches_to_sample_from self.n_prompts_to_select = n_prompts_to_select self.start_batch = start_batch_inclusive self.end_batch = end_batch_inclusive - # Deal with file structure - if not os.path.exists(neuronpedia_parent_folder): - os.makedirs(neuronpedia_parent_folder) - self.neuronpedia_folder = ( - f"{neuronpedia_parent_folder}/{self.get_folder_name()}" - ) - if not os.path.exists(self.neuronpedia_folder): - os.makedirs(self.neuronpedia_folder) - - def get_folder_name(self): - model = self.sparse_autoencoder.cfg.model_name - hook_point = self.sparse_autoencoder.cfg.hook_point - d_sae = self.sparse_autoencoder.cfg.d_sae - dashboard_folder_name = f"{model}_{hook_point}_{d_sae}" - - return dashboard_folder_name - - def init_sae_session(self): - ( - model, - sae_group, - self.activation_store, - ) = LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path) - # only HookedTransformer works with this runner - assert isinstance(model, HookedTransformer) - self.model = model - # TODO: handle multiple autoencoders - self.sparse_autoencoder = next(iter(sae_group))[1] + if not os.path.exists(outputs_dir): + os.makedirs(outputs_dir) + self.outputs_dir = outputs_dir def get_tokens( - self, n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6 + self, + n_batches_to_sample_from: int = 2**12, + n_prompts_to_select: int = 4096 * 6, ): all_tokens_list = [] pbar = tqdm(range(n_batches_to_sample_from)) @@ -124,7 +124,9 @@ def round_list(self, to_round: list[float]): return list(np.round(to_round, 3)) def to_str_tokens_safe( - self, vocab_dict: Dict[int, str], tokens: Union[int, List[int], torch.Tensor] + self, + vocab_dict: Dict[int, str], + tokens: Union[int, List[int], torch.Tensor], ): """ does to_str_tokens, except handles out of range @@ -151,28 +153,20 @@ def to_str_tokens_safe( return np.reshape(str_tokens, tokens.shape).tolist() def run(self): - """ - Generate the Neuronpedia outputs. 
- """ - - if self.model is None: - self.init_sae_session() - self.n_features = self.sparse_autoencoder.cfg.d_sae assert self.n_features is not None # if we have feature sparsity, then use it to only generate outputs for non-dead features self.target_feature_indexes: list[int] = [] - if self.feature_sparsity_path: - loaded = torch.load( - self.feature_sparsity_path, map_location=self.sparse_autoencoder.device - ) - self.target_feature_indexes = ( - (loaded > -5).nonzero(as_tuple=True)[0].tolist() - ) - else: - self.target_feature_indexes = list(range(self.n_features)) - print("No feat sparsity path specified - doing all indexes.") + sparsity = load_sparsity(self.sae_path) + # convert sparsity to logged sparsity if it's not + # TODO: standardize the sparsity file format + if len(sparsity) > 0 and sparsity[0] >= 0: + sparsity = torch.log10(sparsity + 1e-10) + sparsity = sparsity.to(self.device) + self.target_feature_indexes = ( + (sparsity > self.sparsity_threshold).nonzero(as_tuple=True)[0].tolist() + ) # divide into batches feature_idx = torch.tensor(self.target_feature_indexes) @@ -180,9 +174,9 @@ def run(self): feature_idx = np.array_split(feature_idx, n_subarrays) feature_idx = [x.tolist() for x in feature_idx] - print(f"==== Starting at batch: {self.start_batch}") - if self.end_batch is not None: - print(f"==== Ending at batch: {self.end_batch}") + # print(f"==== Starting Batch: {self.start_batch}") + # if self.end_batch is not None and self.end_batch != self.start_batch: + # print(f"==== Ending at Batch: {self.end_batch}") if self.start_batch > len(feature_idx) + 1: print( @@ -192,32 +186,41 @@ def run(self): # write dead into file so we can create them as dead in Neuronpedia skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes) - skipped_indexes_json = json.dumps({"skipped_indexes": list(skipped_indexes)}) - with open(f"{self.neuronpedia_folder}/skipped_indexes.json", "w") as f: + skipped_indexes_json = json.dumps( + { + "model_id": self.model_id, + "layer": str(self.layer), + "sae_id": self.sae_id, + "skipped_indexes": list(skipped_indexes), + } + ) + with open(f"{self.outputs_dir}/skipped_indexes.json", "w") as f: f.write(skipped_indexes_json) - print(f"Total features to run: {len(self.target_feature_indexes)}") - print(f"Total skipped: {len(skipped_indexes)}") - print(f"Total batches: {len(feature_idx)}") - - print(f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}") - print(f"Hook Point: {self.sparse_autoencoder.cfg.hook_point}") - print(f"Writing files to: {self.neuronpedia_folder}") + tokens_file = f"{self.outputs_dir}/tokens_{self.n_batches_to_sample_from}_{self.n_prompts_to_select}.pt" + if os.path.isfile(tokens_file): + print("Tokens exist, loading them.") + tokens = torch.load(tokens_file) + else: + print("Tokens don't exist, making them.") + tokens = self.get_tokens( + self.n_batches_to_sample_from, self.n_prompts_to_select + ) + torch.save( + tokens, + tokens_file, + ) - # get tokens: - start = time.time() - tokens = self.get_tokens( - self.n_batches_to_sample_from, self.n_prompts_to_select - ) - end = time.time() - print(f"Time to get tokens: {end - start}") + vocab_dict = self.model.tokenizer.vocab + new_vocab_dict = {} + # Replace substrings in the keys of vocab_dict using HTML_ANOMALIES + for k, v in vocab_dict.items(): + modified_key = k + for anomaly in HTML_ANOMALIES: + modified_key = modified_key.replace(anomaly, HTML_ANOMALIES[anomaly]) + new_vocab_dict[v] = modified_key + vocab_dict = new_vocab_dict - assert self.model 
is not None - vocab_dict = cast(Any, self.model.tokenizer).vocab - vocab_dict = { - v: k.replace("Ġ", " ").replace("\n", "\\n").replace("Ċ", "\n") - for k, v in vocab_dict.items() - } # pad with blank tokens to the actual vocab size for i in range(len(vocab_dict), self.model.cfg.d_vocab): vocab_dict[i] = OUT_OF_RANGE_TOKEN @@ -234,44 +237,37 @@ def run(self): # print(f"Skipping batch - it's after end_batch: {feature_batch_count}") continue - print(f"Doing batch: {feature_batch_count}") + print(f"========== Running Batch #{feature_batch_count} ==========") layout = SaeVisLayoutConfig( columns=[ Column( SequencesConfig( stack_mode="stack-all", - buffer=( - self.buffer_tokens_left, - self.buffer_tokens_right, - ), - compute_buffer=False, - n_quantiles=10, + buffer=None, + compute_buffer=True, + n_quantiles=5, top_acts_group_size=20, quantile_group_size=5, ), - width=650, - ), - Column( ActsHistogramConfig(), - FeatureTablesConfig(n_rows=5), - width=500, - ), - ], - height=1000, + LogitsHistogramConfig(), + LogitsTableConfig(), + FeatureTablesConfig(n_rows=3), + ) + ] ) feature_vis_params = SaeVisConfig( hook_point=self.sparse_autoencoder.cfg.hook_point, - minibatch_size_features=256, + minibatch_size_features=128, minibatch_size_tokens=64, features=features_to_process, - verbose=False, + verbose=True, feature_centric_layout=layout, ) - - feature_data = get_feature_data( + feature_data = SaeVisData.create( encoder=self.sparse_autoencoder, # type: ignore - model=self.model, + model=cast(HookedTransformer, self.model), tokens=tokens, cfg=feature_vis_params, ) @@ -288,20 +284,6 @@ def run(self): feature.logits_table_data.bottom_logits ) - # TODO: don't precompute/store these. should do it on the frontend - max_value = max( - np.absolute(bottom10_logits).max(), - np.absolute(top10_logits).max(), - ) - neg_bg_values = self.round_list( - np.absolute(bottom10_logits) / max_value - ) - pos_bg_values = self.round_list( - np.absolute(top10_logits) / max_value - ) - feature_output["neg_bg_values"] = neg_bg_values - feature_output["pos_bg_values"] = pos_bg_values - if feature.feature_tables_data: feature_output["neuron_alignment_indices"] = ( feature.feature_tables_data.neuron_alignment_indices @@ -315,23 +297,21 @@ def run(self): feature_output["correlated_neurons_indices"] = ( feature.feature_tables_data.correlated_neurons_indices ) - # TODO: this value doesn't exist in the new output type, commenting out for now - # there is a cossim value though - is that what's needed? 
- # feature_output["correlated_neurons_l1"] = self.round_list( - # feature.feature_tables_data.correlated_neurons_l1 - # ) + feature_output["correlated_neurons_l1"] = self.round_list( + feature.feature_tables_data.correlated_neurons_cossim + ) feature_output["correlated_neurons_pearson"] = self.round_list( feature.feature_tables_data.correlated_neurons_pearson ) - # feature_output["correlated_features_indices"] = ( - # feature.feature_tables_data.correlated_features_indices - # ) - # feature_output["correlated_features_l1"] = self.round_list( - # feature.feature_tables_data.correlated_features_l1 - # ) - # feature_output["correlated_features_pearson"] = self.round_list( - # feature.feature_tables_data.correlated_features_pearson - # ) + feature_output["correlated_features_indices"] = ( + feature.feature_tables_data.correlated_features_indices + ) + feature_output["correlated_features_l1"] = self.round_list( + feature.feature_tables_data.correlated_features_cossim + ) + feature_output["correlated_features_pearson"] = self.round_list( + feature.feature_tables_data.correlated_features_pearson + ) feature_output["neg_str"] = self.to_str_tokens_safe( vocab_dict, feature.logits_table_data.bottom_token_ids @@ -342,30 +322,23 @@ def run(self): ) feature_output["pos_values"] = top10_logits - # TODO: don't know what this should be in the new version - # feature_output["frac_nonzero"] = ( - # feature.middle_plots_data.frac_nonzero - # ) + feature_output["frac_nonzero"] = ( + float( + feature.acts_histogram_data.title.split(" = ")[1].split( + "%" + )[0] + ) + / 100 + if feature.acts_histogram_data.title is not None + else 0 + ) freq_hist_data = feature.acts_histogram_data freq_bar_values = self.round_list(freq_hist_data.bar_values) feature_output["freq_hist_data_bar_values"] = freq_bar_values - feature_output["freq_hist_data_tick_vals"] = self.round_list( - freq_hist_data.tick_vals - ) - - # TODO: don't precompute/store these. 
should do it on the frontend - freq_bar_values_clipped = [ - (0.4 * max(freq_bar_values) + 0.6 * v) / max(freq_bar_values) - for v in freq_bar_values - ] - freq_bar_colors = [ - colors.rgb2hex(BG_COLOR_MAP(v)) for v in freq_bar_values_clipped - ] feature_output["freq_hist_data_bar_heights"] = self.round_list( freq_hist_data.bar_heights ) - feature_output["freq_bar_colors"] = freq_bar_colors logits_hist_data = feature.logits_histogram_data feature_output["logits_hist_data_bar_heights"] = self.round_list( @@ -374,11 +347,7 @@ def run(self): feature_output["logits_hist_data_bar_values"] = self.round_list( logits_hist_data.bar_values ) - feature_output["logits_hist_data_tick_vals"] = self.round_list( - logits_hist_data.tick_vals - ) - # TODO: check this feature_output["num_tokens_for_dashboard"] = ( self.n_prompts_to_select ) @@ -441,10 +410,19 @@ def run(self): features_outputs.append(feature_output) - json_object = json.dumps(features_outputs, cls=NpEncoder) + to_write = { + "model_id": self.model_id, + "layer": str(self.layer), + "sae_id": self.sae_id, + "features": features_outputs, + "n_batches_to_sample_from": self.n_batches_to_sample_from, + "n_prompts_to_select": self.n_prompts_to_select, + } + json_object = json.dumps(to_write, cls=NpEncoder) with open( - f"{self.neuronpedia_folder}/batch-{feature_batch_count}.json", "w" + f"{self.outputs_dir}/batch-{feature_batch_count}.json", + "w", ) as f: f.write(json_object) diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 2a91b86f..c28de151 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -2,7 +2,6 @@ from typing import Any, Optional, cast import torch - import wandb DTYPE_MAP = { diff --git a/sae_lens/training/evals.py b/sae_lens/training/evals.py index 80aa91dd..9742512d 100644 --- a/sae_lens/training/evals.py +++ b/sae_lens/training/evals.py @@ -3,9 +3,9 @@ import pandas as pd import torch +import wandb from transformer_lens.hook_points import HookedRootModule -import wandb from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.sparse_autoencoder import SparseAutoencoder diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py index c6a4ce79..4d7ea440 100644 --- a/sae_lens/training/lm_runner.py +++ b/sae_lens/training/lm_runner.py @@ -1,6 +1,7 @@ from typing import Any, cast import wandb + from sae_lens.training.config import LanguageModelSAERunnerConfig from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader diff --git a/sae_lens/training/toy_model_runner.py b/sae_lens/training/toy_model_runner.py index 29b436a6..83044c3c 100644 --- a/sae_lens/training/toy_model_runner.py +++ b/sae_lens/training/toy_model_runner.py @@ -3,8 +3,8 @@ import einops import torch - import wandb + from sae_lens.training.sparse_autoencoder import SparseAutoencoder from sae_lens.training.toy_models import Config as ToyConfig from sae_lens.training.toy_models import Model as ToyModel diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index efcd3473..9cf9e9a6 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -3,13 +3,13 @@ from typing import Any, cast import torch +import wandb from safetensors.torch import save_file from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import LRScheduler from tqdm import tqdm from transformer_lens.hook_points import HookedRootModule -import wandb from 
sae_lens.training.activations_store import ActivationsStore from sae_lens.training.evals import run_evals from sae_lens.training.geometric_median import compute_geometric_median diff --git a/sae_lens/training/train_sae_on_toy_model.py b/sae_lens/training/train_sae_on_toy_model.py index e5227da9..61236b7a 100644 --- a/sae_lens/training/train_sae_on_toy_model.py +++ b/sae_lens/training/train_sae_on_toy_model.py @@ -1,10 +1,10 @@ from typing import Any, cast import torch +import wandb from torch.utils.data import DataLoader from tqdm import tqdm -import wandb from sae_lens.training.sparse_autoencoder import SparseAutoencoder diff --git a/tests/unit/training/test_train_sae_on_language_model.py b/tests/unit/training/test_train_sae_on_language_model.py index eae5b35d..e0c8b9e6 100644 --- a/tests/unit/training/test_train_sae_on_language_model.py +++ b/tests/unit/training/test_train_sae_on_language_model.py @@ -5,11 +5,11 @@ import pytest import torch +import wandb from datasets import Dataset from torch import Tensor from transformer_lens import HookedTransformer -import wandb from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.optim import get_scheduler from sae_lens.training.sae_group import SparseAutoencoderDictionary diff --git a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb index 79474377..a208f77c 100644 --- a/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb +++ b/tutorials/neuronpedia/generating_neuronpedia_outputs.ipynb @@ -22,27 +22,17 @@ "metadata": {}, "outputs": [], "source": [ - "# from huggingface_hub import hf_hub_download\n", - "\n", - "# MODEL = \"gpt2-small\"\n", - "# LAYER = 0\n", - "# SOURCE = \"res-jb\"\n", - "# REPO_ID = \"jbloom/GPT2-Small-SAEs\"\n", - "# FILENAME = f\"final_sparse_autoencoder_gpt2-small_blocks.{LAYER}.hook_resid_pre_24576.pt\"\n", - "# SAE_PATH = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)\n", - "\n", - "# Change these\n", - "MODEL = \"pythia-70m-deduped\"\n", - "LAYER = 0\n", - "TYPE = \"resid\"\n", - "SOURCE_AUTHOR_SUFFIX = \"sm\"\n", - "SOURCE = \"res-sm\"\n", - "\n", - "# Change these depending on how your files are named\n", - "SAE_PATH = f\"../data/{SOURCE_AUTHOR_SUFFIX}/sae_{LAYER}_{TYPE}.pt\"\n", - "FEATURE_SPARSITY_PATH = (\n", - " f\"../data/{SOURCE_AUTHOR_SUFFIX}/feature_sparsity_{LAYER}_{TYPE}.pt\"\n", - ")" + "from sae_lens.toolkit.pretrained_saes import download_sae_from_hf\n", + "import os\n", + "\n", + "MODEL_ID = \"gpt2-small\"\n", + "SAE_ID = \"res-jb\"\n", + "\n", + "(_, SAE_WEIGHTS_PATH, _) = download_sae_from_hf(\n", + " \"jbloom/GPT2-Small-SAEs-Reformatted\", \"blocks.0.hook_resid_pre\"\n", + ")\n", + "\n", + "SAE_PATH = os.path.dirname(SAE_WEIGHTS_PATH)" ] }, { @@ -60,21 +50,21 @@ "source": [ "from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner\n", "\n", - "NP_OUTPUT_FOLDER = \"../neuronpedia_outputs\"\n", + "print(SAE_PATH)\n", + "NP_OUTPUT_FOLDER = \"../../neuronpedia_outputs/my_outputs\"\n", "\n", "runner = NeuronpediaRunner(\n", + " sae_id=SAE_ID,\n", " sae_path=SAE_PATH,\n", - " feature_sparsity_path=FEATURE_SPARSITY_PATH,\n", - " neuronpedia_parent_folder=NP_OUTPUT_FOLDER,\n", - " init_session=True,\n", + " outputs_dir=NP_OUTPUT_FOLDER,\n", + " sparsity_threshold=-5,\n", " n_batches_to_sample_from=2**12,\n", - " n_prompts_to_select=4096 * 6,\n", - " n_features_at_a_time=512,\n", - " buffer_tokens_left=64,\n", - " buffer_tokens_right=63,\n", - " start_batch_inclusive=22,\n", - " 
end_batch_inclusive=23,\n", + " n_prompts_to_select=4096*6,\n", + " n_features_at_a_time=24,\n", + " start_batch_inclusive=1,\n", + " end_batch_inclusive=1,\n", ")\n", + "\n", "runner.run()" ] }, @@ -100,8 +90,7 @@ "import os\n", "import requests\n", "\n", - "folder_path = runner.neuronpedia_folder\n", - "\n", + "FEATURE_OUTPUTS_FOLDER = runner.outputs_dir\n", "\n", "def nanToNeg999(obj: Any) -> Any:\n", " if isinstance(obj, dict):\n", @@ -120,13 +109,12 @@ "\n", "# Server info\n", "host = \"http://localhost:3000\"\n", - "sourceName = str(LAYER) + \"-\" + SOURCE\n", "\n", "# Upload alive features\n", - "for file_name in os.listdir(folder_path):\n", + "for file_name in os.listdir(FEATURE_OUTPUTS_FOLDER):\n", " if file_name.startswith(\"batch-\") and file_name.endswith(\".json\"):\n", " print(\"Uploading file: \" + file_name)\n", - " file_path = os.path.join(folder_path, file_name)\n", + " file_path = os.path.join(FEATURE_OUTPUTS_FOLDER, file_name)\n", " f = open(file_path, \"r\")\n", " data = json.load(f)\n", "\n", @@ -134,31 +122,21 @@ " data_fixed = json.dumps(data, cls=NanConverter)\n", " data = json.loads(data_fixed)\n", "\n", - " url = host + \"/api/internal/upload-features\"\n", + " url = host + \"/api/local/upload-features\"\n", " resp = requests.post(\n", " url,\n", - " json={\n", - " \"modelId\": MODEL,\n", - " \"layer\": sourceName,\n", - " \"features\": data,\n", - " },\n", + " json=data,\n", " )\n", "\n", - "# Upload dead features (just makes blanks features)\n", - "# We want this for completeness\n", - "# skipped_path = os.path.join(folder_path, \"skipped_indexes.json\")\n", - "# f = open(skipped_path, \"r\")\n", - "# data = json.load(f)\n", - "# skipped_indexes = data[\"skipped_indexes\"]\n", - "# url = host + \"/api/internal/upload-dead-features\"\n", - "# resp = requests.post(\n", - "# url,\n", - "# json={\n", - "# \"modelId\": MODEL,\n", - "# \"layer\": sourceName,\n", - "# \"deadIndexes\": skipped_indexes,\n", - "# },\n", - "# )" + "# Upload dead feature stubs\n", + "skipped_path = os.path.join(FEATURE_OUTPUTS_FOLDER, \"skipped_indexes.json\")\n", + "f = open(skipped_path, \"r\")\n", + "data = json.load(f)\n", + "url = host + \"/api/local/upload-dead-features\"\n", + "resp = requests.post(\n", + " url,\n", + " json=data,\n", + ")" ] }, { @@ -185,7 +163,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.11.8" } }, "nbformat": 4, diff --git a/tutorials/neuronpedia/make_batch.py b/tutorials/neuronpedia/make_batch.py new file mode 100644 index 00000000..945a2939 --- /dev/null +++ b/tutorials/neuronpedia/make_batch.py @@ -0,0 +1,29 @@ +import sys + +from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner + +# we use another python script to launch this using subprocess to work around OOM issues - this ensures every batch gets the whole system available memory +# better fix is to investigate and fix the memory issues + +SAE_ID = sys.argv[1] +SAE_PATH = sys.argv[2] +OUTPUTS_DIR = sys.argv[3] +SPARSITY_THRESHOLD = int(sys.argv[4]) +N_BATCHES_SAMPLE = int(sys.argv[5]) +N_PROMPTS_SELECT = int(sys.argv[6]) +FEATURES_AT_A_TIME = int(sys.argv[7]) +START_BATCH_INCLUSIVE = int(sys.argv[8]) +END_BATCH_INCLUSIVE = int(sys.argv[9]) + +runner = NeuronpediaRunner( + sae_id=SAE_ID, + sae_path=SAE_PATH, + outputs_dir=OUTPUTS_DIR, + sparsity_threshold=SPARSITY_THRESHOLD, + n_batches_to_sample_from=N_BATCHES_SAMPLE, + n_prompts_to_select=N_PROMPTS_SELECT, + n_features_at_a_time=FEATURES_AT_A_TIME, + 
start_batch_inclusive=START_BATCH_INCLUSIVE, + end_batch_inclusive=END_BATCH_INCLUSIVE, +) +runner.run() diff --git a/tutorials/neuronpedia/neuronpedia.py b/tutorials/neuronpedia/neuronpedia.py new file mode 100755 index 00000000..aa36a6b6 --- /dev/null +++ b/tutorials/neuronpedia/neuronpedia.py @@ -0,0 +1,404 @@ +# we use a script that launches separate python processes to work around OOM issues - this ensures every batch gets the whole system available memory +# better fix is to investigate and fix the memory issues + +import json +import math +import os +import subprocess +from decimal import Decimal +from pathlib import Path +from typing import Any + +import requests +import torch +import typer +from rich import print +from rich.align import Align +from rich.panel import Panel +from typing_extensions import Annotated + +from sae_lens.toolkit.pretrained_saes import load_sparsity +from sae_lens.training.sparse_autoencoder import SparseAutoencoder + +OUTPUT_DIR_BASE = Path("../../neuronpedia_outputs") + +app = typer.Typer( + add_completion=False, + no_args_is_help=True, + help="Tool that generates features (generate) and uploads features (upload) to Neuronpedia.", +) + + +@app.command() +def generate( + sae_id: Annotated[ + str, + typer.Option( + help="SAE ID to generate features for (must exactly match the one used on Neuronpedia). Example: res-jb", + prompt=""" +What is the SAE ID you want to generate features for? +This was set when you did 'Add SAEs' on Neuronpedia. This must exactly match that ID (including casing). +It's in the format [abbrev hook name]-[abbrev author name], like res-jb. +Enter SAE ID""", + ), + ], + sae_path: Annotated[ + Path, + typer.Option( + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + help="Absolute local path to the SAE directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors).", + prompt=""" +What is the absolute local path to your SAE's directory (with cfg.json, sae_weights.safetensors, sparsity.safetensors)? +Enter path""", + ), + ], + log_sparsity: Annotated[ + int, + typer.Option( + min=-10, + max=0, + help="Desired feature log sparsity threshold. Range -10 to 0.", + prompt=""" +What is your desired feature log sparsity threshold? +Enter value from -10 to 0""", + ), + ] = -5, + feat_per_batch: Annotated[ + int, + typer.Option( + min=1, + max=2048, + help="Features to generate per batch. More requires more memory.", + prompt=""" +How many features do you want to generate per batch? More requires more memory. +Enter value""", + ), + ] = 128, + resume_from_batch: Annotated[ + int, + typer.Option( + min=1, + help="Batch number to resume from.", + prompt=""" +Do you want to resume from a specific batch number? +Enter 1 to start from the beginning""", + ), + ] = 1, + n_batches_to_sample: Annotated[ + int, + typer.Option( + min=1, + help="[Activation Text Generation] Number of batches to sample from.", + prompt=""" +[Activation Text Generation] How many batches do you want to sample from? +Enter value""", + ), + ] = 2 + ** 12, + n_prompts_to_select: Annotated[ + int, + typer.Option( + min=1, + help="[Activation Text Generation] Number of prompts to select from.", + prompt=""" +[Activation Text Generation] How many prompts do you want to select from? +Enter value""", + ), + ] = 4096 + * 6, +): + """ + This will start a batch job that generates features for Neuronpedia for a specific SAE. To upload those features, use the 'upload' command afterwards. 
+ """ + + # Check arguments + if sae_path.is_dir() is not True: + print("Error: SAE path must be a directory.") + raise typer.Abort() + if sae_path.joinpath("cfg.json").is_file() is not True: + print("Error: cfg.json file not found in SAE directory.") + raise typer.Abort() + if sae_path.joinpath("sae_weights.safetensors").is_file() is not True: + print("Error: sae_weights.safetensors file not found in SAE directory.") + raise typer.Abort() + if sae_path.joinpath("sparsity.safetensors").is_file() is not True: + print("Error: sparsity.safetensors file not found in SAE directory.") + raise typer.Abort() + + sae_path_string = sae_path.as_posix() + + # Load SAE + device = "cpu" + if torch.backends.mps.is_available(): + device = "mps" + elif torch.cuda.is_available(): + device = "cuda" + sparse_autoencoder = SparseAutoencoder.load_from_pretrained( + sae_path_string, device=device + ) + model_id = sparse_autoencoder.cfg.model_name + + outputs_subdir = f"{model_id}_{sae_id}_{sparse_autoencoder.cfg.hook_point}" + outputs_dir = OUTPUT_DIR_BASE.joinpath(outputs_subdir) + if outputs_dir.exists() and outputs_dir.is_file(): + print(f"Error: Output directory {outputs_dir.as_posix()} exists and is a file.") + raise typer.Abort() + outputs_dir.mkdir(parents=True, exist_ok=True) + # Check if output_dir has any files starting with "batch_" + batch_files = list(outputs_dir.glob("batch-*.json")) + if len(batch_files) > 0 and resume_from_batch == 1: + print( + f"Error: Output directory {outputs_dir.as_posix()} has existing batch files. This is only allowed if you are resuming from a batch. Please delete or move the existing batch-*.json files." + ) + raise typer.Abort() + + sparsity = load_sparsity(sae_path_string) + # convert sparsity to logged sparsity if it's not + # TODO: standardize the sparsity file format + if len(sparsity) > 0 and sparsity[0] >= 0: + sparsity = torch.log10(sparsity + 1e-10) + sparsity = sparsity.to(device) + alive_indexes = (sparsity > log_sparsity).nonzero(as_tuple=True)[0].tolist() + num_alive = len(alive_indexes) + num_dead = sparse_autoencoder.d_sae - num_alive + + print("\n") + print( + Align.center( + Panel.fit( + f""" +[white]SAE Path: [green]{sae_path.as_posix()} +[white]Model ID: [green]{model_id} +[white]Hook Point: [green]{sparse_autoencoder.cfg.hook_point} +[white]Using Device: [green]{device} +""", + title="SAE Info", + ) + ) + ) + num_batches = math.ceil(num_alive / feat_per_batch) + print( + Align.center( + Panel.fit( + f""" +[white]Total Features: [green]{sparse_autoencoder.d_sae} +[white]Log Sparsity Threshold: [green]{log_sparsity} +[white]Alive Features: [green]{num_alive} +[white]Dead Features: [red]{num_dead} +[white]Features per Batch: [green]{feat_per_batch} +[white]Number of Batches: [green]{num_batches} +{resume_from_batch != 1 and f"[white]Resuming from Batch: [green]{resume_from_batch}" or ""} +""", + title="Number of Features", + ) + ) + ) + print( + Align.center( + Panel.fit( + f""" +[white]Dataset: [green]{sparse_autoencoder.cfg.dataset_path} +[white]Batches to Sample From: [green]{n_batches_to_sample} +[white]Prompts to Select From: [green]{n_prompts_to_select} +""", + title="Activation Text Settings", + ) + ) + ) + print( + Align.center( + Panel.fit( + f""" +[green]{outputs_dir.absolute().as_posix()} +""", + title="Output Directory", + ) + ) + ) + + print( + Align.center( + "\n========== [yellow]Starting batch feature generations...[/yellow] ==========" + ) + ) + + # iterate from 1 to num_batches + for i in range(resume_from_batch, num_batches + 1): + 
command = [ + "python", + "make_batch.py", + sae_id, + sae_path.absolute().as_posix(), + outputs_dir.absolute().as_posix(), + str(log_sparsity), + str(n_batches_to_sample), + str(n_prompts_to_select), + str(feat_per_batch), + str(i), + str(i), + ] + print("\n") + print( + Align.center( + Panel.fit( + f""" +[yellow]{" ".join(command)} +""", + title="Running Command for Batch #" + str(i), + ) + ) + ) + # make a subprocess call to python make_batch.py + subprocess.run( + [ + "python", + "make_batch.py", + sae_id, + sae_path, + outputs_dir, + str(log_sparsity), + str(n_batches_to_sample), + str(n_prompts_to_select), + str(feat_per_batch), + str(i), + str(i), + ] + ) + + print( + Align.center( + Panel( + f""" +Your Features Are In: [green]{outputs_dir.absolute().as_posix()} +Use [yellow]'neuronpedia.py upload'[/yellow] to upload your features to Neuronpedia. +""", + title="Generation Complete", + ) + ) + ) + + +@app.command() +def upload( + outputs_dir: Annotated[ + Path, + typer.Option( + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + prompt="What is the absolute, full local file path to the feature outputs directory?", + ), + ], + host: Annotated[ + str, + typer.Option( + prompt="""Host to upload to? (Default: http://localhost:3000)""", + ), + ] = "http://localhost:3000", +): + """ + This will upload features that were generated to Neuronpedia. It currently only works if you have admin access to a Neuronpedia instance via localhost:3000. + """ + + files_to_upload = list(outputs_dir.glob("batch-*.json")) + + # sort files by batch number + files_to_upload.sort(key=lambda x: int(x.stem.split("-")[1])) + + print("\n") + # Upload alive features + for file_path in files_to_upload: + print("===== Uploading file: " + os.path.basename(file_path)) + f = open(file_path, "r") + data = json.load(f) + + # Replace NaNs + data_fixed = json.dumps(data, cls=NanConverter) + data = json.loads(data_fixed) + + url = host + "/api/local/upload-features" + requests.post( + url, + json=data, + ) + + print( + Align.center( + Panel( + f""" +{len(files_to_upload)} batch files uploaded to Neuronpedia. +""", + title="Uploads Complete", + ) + ) + ) + + +@app.command() +def upload_dead_stubs( + outputs_dir: Annotated[ + Path, + typer.Option( + exists=True, + dir_okay=True, + readable=True, + resolve_path=True, + prompt="What is the absolute, full local file path to the feature outputs directory?", + ), + ], + host: Annotated[ + str, + typer.Option( + prompt="""Host to upload to? (Default: http://localhost:3000)""", + ), + ] = "http://localhost:3000", +): + """ + This will create "There are no activations for this feature" stubs for dead features on Neuronpedia. It currently only works if you have admin access to a Neuronpedia instance via localhost:3000. + """ + + skipped_path = os.path.join(outputs_dir, "skipped_indexes.json") + f = open(skipped_path, "r") + data = json.load(f) + url = host + "/api/local/upload-skipped-features" + requests.post( + url, + json=data, + ) + + print( + Align.center( + Panel( + """ +Dead feature stubs created. 
+""", + title="Complete", + ) + ) + ) + + +# Helper utilities that help fix weird NaNs in the feature outputs + + +def nanToNeg999(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: nanToNeg999(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [nanToNeg999(v) for v in obj] + elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan(obj): + return -999 + return obj + + +class NanConverter(json.JSONEncoder): + def encode(self, o: Any, *args: Any, **kwargs: Any): + return super().encode(nanToNeg999(o), *args, **kwargs) + + +if __name__ == "__main__": + app() diff --git a/tutorials/neuronpedia/np_runner.sh b/tutorials/neuronpedia/np_runner.sh deleted file mode 100755 index c005d48c..00000000 --- a/tutorials/neuronpedia/np_runner.sh +++ /dev/null @@ -1,15 +0,0 @@ -# script for working around memory issues - this ensures every batch gets the whole system available memory -# better fix is to investigate and fix the memory issues - -#!/bin/bash -LAYER=$1 -TYPE=$2 -SOURCE_AUTHOR_SUFFIX=$3 -FEATURES_AT_A_TIME=$4 -START_BATCH_INCLUSIVE=$5 -END_BATCH_INCLUSIVE=$6 -for j in $(seq $5 $6) - do - echo "Iteration: $j" - python np_runner_batch.py $1 $2 $3 $4 $j $j -done \ No newline at end of file diff --git a/tutorials/neuronpedia/np_runner_batch.py b/tutorials/neuronpedia/np_runner_batch.py deleted file mode 100644 index 23a72ec3..00000000 --- a/tutorials/neuronpedia/np_runner_batch.py +++ /dev/null @@ -1,35 +0,0 @@ -import sys - -LAYER = int(sys.argv[1]) # 0 -TYPE = sys.argv[2] # "resid" -SOURCE_AUTHOR_SUFFIX = sys.argv[3] # "sm" -FEATURES_AT_A_TIME = int( - sys.argv[4] -) # this must stay the same or your batching will be off -START_BATCH_INCLUSIVE = int(sys.argv[5]) -END_BATCH_INCLUSIVE = int(sys.argv[6]) if len(sys.argv) > 6 else None - -# Change these depending on how your files are named -SAE_PATH = f"../../data/{SOURCE_AUTHOR_SUFFIX}/sae_{LAYER}_{TYPE}.pt" -FEATURE_SPARSITY_PATH = ( - f"../../data/{SOURCE_AUTHOR_SUFFIX}/feature_sparsity_{LAYER}_{TYPE}.pt" -) - -from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner - -NP_OUTPUT_FOLDER = "../../neuronpedia_outputs" - -runner = NeuronpediaRunner( - sae_path=SAE_PATH, - feature_sparsity_path=FEATURE_SPARSITY_PATH, - neuronpedia_parent_folder=NP_OUTPUT_FOLDER, - init_session=True, - n_batches_to_sample_from=2**12, - n_prompts_to_select=4096 * 6, - n_features_at_a_time=FEATURES_AT_A_TIME, - buffer_tokens_left=64, - buffer_tokens_right=63, - start_batch_inclusive=START_BATCH_INCLUSIVE, - end_batch_inclusive=END_BATCH_INCLUSIVE, -) -runner.run()