diff --git a/sae_lens/analysis/dashboard_runner.py b/sae_lens/analysis/dashboard_runner.py
index d77be35d..150b0a52 100644
--- a/sae_lens/analysis/dashboard_runner.py
+++ b/sae_lens/analysis/dashboard_runner.py
@@ -11,6 +11,7 @@
 import plotly
 import plotly.express as px
 import torch
+import wandb
 from sae_vis.data_config_classes import (
     ActsHistogramConfig,
     Column,
@@ -24,7 +25,6 @@
 from tqdm import tqdm
 from transformer_lens import HookedTransformer
 
-import wandb
 from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
 
 
diff --git a/sae_lens/analysis/neuronpedia_runner.py b/sae_lens/analysis/neuronpedia_runner.py
index 99228c4a..401941c4 100644
--- a/sae_lens/analysis/neuronpedia_runner.py
+++ b/sae_lens/analysis/neuronpedia_runner.py
@@ -4,25 +4,27 @@
 # set TOKENIZERS_PARALLELISM to false to avoid warnings
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import json
+
 import numpy as np
 import torch
 from matplotlib import colors
 from sae_vis.data_config_classes import (
     ActsHistogramConfig,
     Column,
+    FeatureTablesConfig,
     LogitsHistogramConfig,
     LogitsTableConfig,
-    FeatureTablesConfig,
     SaeVisConfig,
     SaeVisLayoutConfig,
     SequencesConfig,
 )
-from tqdm import tqdm
 from sae_vis.data_storing_fns import SaeVisData
+from tqdm import tqdm
 from transformer_lens import HookedTransformer
+
+from sae_lens.toolkit.pretrained_saes import load_sparsity
 from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder
-from sae_lens.toolkit.pretrained_saes import load_sparsity
 
 
 OUT_OF_RANGE_TOKEN = "<|outofrange|>"
@@ -85,9 +87,7 @@ def __init__(
             self.sae_path, device=self.device
         )
         loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg)
-        self.model, _, self.activation_store = (
-            loader.load_sae_training_group_session()
-        )
+        self.model, _, self.activation_store = loader.load_sae_training_group_session()
         self.model_id = self.sparse_autoencoder.cfg.model_name
         self.layer = self.sparse_autoencoder.cfg.hook_point_layer
         self.sae_id = sae_id
@@ -165,16 +165,12 @@
         sparsity = torch.log10(sparsity + 1e-10)
         sparsity = sparsity.to(self.device)
         self.target_feature_indexes = (
-            (sparsity > self.sparsity_threshold)
-            .nonzero(as_tuple=True)[0]
-            .tolist()
+            (sparsity > self.sparsity_threshold).nonzero(as_tuple=True)[0].tolist()
         )
 
         # divide into batches
         feature_idx = torch.tensor(self.target_feature_indexes)
-        n_subarrays = np.ceil(
-            len(feature_idx) / self.n_features_at_a_time
-        ).astype(int)
+        n_subarrays = np.ceil(len(feature_idx) / self.n_features_at_a_time).astype(int)
         feature_idx = np.array_split(feature_idx, n_subarrays)
         feature_idx = [x.tolist() for x in feature_idx]
 
@@ -189,9 +185,7 @@
             exit()
 
         # write dead into file so we can create them as dead in Neuronpedia
-        skipped_indexes = set(range(self.n_features)) - set(
-            self.target_feature_indexes
-        )
+        skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes)
         skipped_indexes_json = json.dumps(
             {
                 "model_id": self.model_id,
@@ -223,9 +217,7 @@
         for k, v in vocab_dict.items():
             modified_key = k
             for anomaly in HTML_ANOMALIES:
-                modified_key = modified_key.replace(
-                    anomaly, HTML_ANOMALIES[anomaly]
-                )
+                modified_key = modified_key.replace(anomaly, HTML_ANOMALIES[anomaly])
             new_vocab_dict[v] = modified_key
         vocab_dict = new_vocab_dict
 
@@ -241,16 +233,11 @@
             if feature_batch_count < self.start_batch:
                 # print(f"Skipping batch - it's after start_batch: {feature_batch_count}")
                 continue
-            if (
-                self.end_batch is not None
-                and feature_batch_count > self.end_batch
-            ):
+            if self.end_batch is not None and feature_batch_count > self.end_batch:
                 # print(f"Skipping batch - it's after end_batch: {feature_batch_count}")
                 continue
 
-            print(
-                f"========== Running Batch #{feature_batch_count} =========="
-            )
+            print(f"========== Running Batch #{feature_batch_count} ==========")
 
             layout = SaeVisLayoutConfig(
                 columns=[
@@ -286,17 +273,13 @@
             )
 
             features_outputs = []
-            for _, feat_index in enumerate(
-                feature_data.feature_data_dict.keys()
-            ):
+            for _, feat_index in enumerate(feature_data.feature_data_dict.keys()):
                 feature = feature_data.feature_data_dict[feat_index]
 
                 feature_output = {}
                 feature_output["featureIndex"] = feat_index
 
-                top10_logits = self.round_list(
-                    feature.logits_table_data.top_logits
-                )
+                top10_logits = self.round_list(feature.logits_table_data.top_logits)
                 bottom10_logits = self.round_list(
                     feature.logits_table_data.bottom_logits
                 )
@@ -305,41 +288,29 @@
                 feature_output["neuron_alignment_indices"] = (
                     feature.feature_tables_data.neuron_alignment_indices
                 )
-                feature_output["neuron_alignment_values"] = (
-                    self.round_list(
-                        feature.feature_tables_data.neuron_alignment_values
-                    )
+                feature_output["neuron_alignment_values"] = self.round_list(
+                    feature.feature_tables_data.neuron_alignment_values
                 )
-                feature_output["neuron_alignment_l1"] = (
-                    self.round_list(
-                        feature.feature_tables_data.neuron_alignment_l1
-                    )
+                feature_output["neuron_alignment_l1"] = self.round_list(
+                    feature.feature_tables_data.neuron_alignment_l1
                 )
                 feature_output["correlated_neurons_indices"] = (
                     feature.feature_tables_data.correlated_neurons_indices
                 )
-                feature_output["correlated_neurons_l1"] = (
-                    self.round_list(
-                        feature.feature_tables_data.correlated_neurons_cossim
-                    )
+                feature_output["correlated_neurons_l1"] = self.round_list(
+                    feature.feature_tables_data.correlated_neurons_cossim
                 )
-                feature_output["correlated_neurons_pearson"] = (
-                    self.round_list(
-                        feature.feature_tables_data.correlated_neurons_pearson
-                    )
+                feature_output["correlated_neurons_pearson"] = self.round_list(
+                    feature.feature_tables_data.correlated_neurons_pearson
                 )
                 feature_output["correlated_features_indices"] = (
                     feature.feature_tables_data.correlated_features_indices
                 )
-                feature_output["correlated_features_l1"] = (
-                    self.round_list(
-                        feature.feature_tables_data.correlated_features_cossim
-                    )
+                feature_output["correlated_features_l1"] = self.round_list(
+                    feature.feature_tables_data.correlated_features_cossim
                 )
-                feature_output["correlated_features_pearson"] = (
-                    self.round_list(
-                        feature.feature_tables_data.correlated_features_pearson
-                    )
+                feature_output["correlated_features_pearson"] = self.round_list(
+                    feature.feature_tables_data.correlated_features_pearson
                 )
 
                 feature_output["neg_str"] = self.to_str_tokens_safe(
@@ -353,9 +324,9 @@
 
                 feature_output["frac_nonzero"] = (
                     float(
-                        feature.acts_histogram_data.title.split(" = ")[
-                            1
-                        ].split("%")[0]
+                        feature.acts_histogram_data.title.split(" = ")[1].split(
+                            "%"
+                        )[0]
                     )
                     / 100
                     if feature.acts_histogram_data.title is not None
@@ -363,22 +334,18 @@
                 )
 
                 freq_hist_data = feature.acts_histogram_data
-                freq_bar_values = self.round_list(
-                    freq_hist_data.bar_values
-                )
-                feature_output["freq_hist_data_bar_values"] = (
-                    freq_bar_values
-                )
-                feature_output["freq_hist_data_bar_heights"] = (
-                    self.round_list(freq_hist_data.bar_heights)
+                freq_bar_values = self.round_list(freq_hist_data.bar_values)
feature_output["freq_hist_data_bar_values"] = freq_bar_values + feature_output["freq_hist_data_bar_heights"] = self.round_list( + freq_hist_data.bar_heights ) logits_hist_data = feature.logits_histogram_data - feature_output["logits_hist_data_bar_heights"] = ( - self.round_list(logits_hist_data.bar_heights) + feature_output["logits_hist_data_bar_heights"] = self.round_list( + logits_hist_data.bar_heights ) - feature_output["logits_hist_data_bar_values"] = ( - self.round_list(logits_hist_data.bar_values) + feature_output["logits_hist_data_bar_values"] = self.round_list( + logits_hist_data.bar_values ) feature_output["num_tokens_for_dashboard"] = ( @@ -432,12 +399,8 @@ def run(self): {"pos": posContribs, "neg": negContribs} ) activation["tokens"] = strs - activation["values"] = self.round_list( - sd.feat_acts - ) - activation["maxValue"] = max( - activation["values"] - ) + activation["values"] = self.round_list(sd.feat_acts) + activation["maxValue"] = max(activation["values"]) activation["lossValues"] = self.round_list( sd.loss_contribution ) diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 2a91b86f..c28de151 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -2,7 +2,6 @@ from typing import Any, Optional, cast import torch - import wandb DTYPE_MAP = { diff --git a/sae_lens/training/evals.py b/sae_lens/training/evals.py index 80aa91dd..9742512d 100644 --- a/sae_lens/training/evals.py +++ b/sae_lens/training/evals.py @@ -3,9 +3,9 @@ import pandas as pd import torch +import wandb from transformer_lens.hook_points import HookedRootModule -import wandb from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.sparse_autoencoder import SparseAutoencoder diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py index c6a4ce79..4d7ea440 100644 --- a/sae_lens/training/lm_runner.py +++ b/sae_lens/training/lm_runner.py @@ -1,6 +1,7 @@ from typing import Any, cast import wandb + from sae_lens.training.config import LanguageModelSAERunnerConfig from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader diff --git a/sae_lens/training/toy_model_runner.py b/sae_lens/training/toy_model_runner.py index 29b436a6..83044c3c 100644 --- a/sae_lens/training/toy_model_runner.py +++ b/sae_lens/training/toy_model_runner.py @@ -3,8 +3,8 @@ import einops import torch - import wandb + from sae_lens.training.sparse_autoencoder import SparseAutoencoder from sae_lens.training.toy_models import Config as ToyConfig from sae_lens.training.toy_models import Model as ToyModel diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index efcd3473..9cf9e9a6 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -3,13 +3,13 @@ from typing import Any, cast import torch +import wandb from safetensors.torch import save_file from torch.optim import Adam, Optimizer from torch.optim.lr_scheduler import LRScheduler from tqdm import tqdm from transformer_lens.hook_points import HookedRootModule -import wandb from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.evals import run_evals from sae_lens.training.geometric_median import compute_geometric_median diff --git a/sae_lens/training/train_sae_on_toy_model.py b/sae_lens/training/train_sae_on_toy_model.py index e5227da9..61236b7a 100644 --- a/sae_lens/training/train_sae_on_toy_model.py +++ 
+++ b/sae_lens/training/train_sae_on_toy_model.py
@@ -1,10 +1,10 @@
 from typing import Any, cast
 
 import torch
+import wandb
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-import wandb
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder
 
 
diff --git a/tests/unit/training/test_train_sae_on_language_model.py b/tests/unit/training/test_train_sae_on_language_model.py
index eae5b35d..e0c8b9e6 100644
--- a/tests/unit/training/test_train_sae_on_language_model.py
+++ b/tests/unit/training/test_train_sae_on_language_model.py
@@ -5,11 +5,11 @@
 
 import pytest
 import torch
+import wandb
 from datasets import Dataset
 from torch import Tensor
 from transformer_lens import HookedTransformer
 
-import wandb
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.optim import get_scheduler
 from sae_lens.training.sae_group import SparseAutoencoderDictionary
diff --git a/tutorials/neuronpedia/make_batch.py b/tutorials/neuronpedia/make_batch.py
index b5f28e82..945a2939 100644
--- a/tutorials/neuronpedia/make_batch.py
+++ b/tutorials/neuronpedia/make_batch.py
@@ -1,4 +1,5 @@
 import sys
+
 from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner
 
 # we use another python script to launch this using subprocess to work around OOM issues - this ensures every batch gets the whole system available memory
diff --git a/tutorials/neuronpedia/neuronpedia.py b/tutorials/neuronpedia/neuronpedia.py
index 6e3e9af9..aa36a6b6 100755
--- a/tutorials/neuronpedia/neuronpedia.py
+++ b/tutorials/neuronpedia/neuronpedia.py
@@ -2,21 +2,23 @@
 # better fix is to investigate and fix the memory issues
 import json
-import os
-import requests
-import typer
-import torch
 import math
+import os
 import subprocess
-from typing import Any
 from decimal import Decimal
 from pathlib import Path
-from typing_extensions import Annotated
+from typing import Any
+
+import requests
+import torch
+import typer
 from rich import print
 from rich.align import Align
 from rich.panel import Panel
-from sae_lens.training.sparse_autoencoder import SparseAutoencoder
+from typing_extensions import Annotated
+
 from sae_lens.toolkit.pretrained_saes import load_sparsity
+from sae_lens.training.sparse_autoencoder import SparseAutoencoder
 
 
 OUTPUT_DIR_BASE = Path("../../neuronpedia_outputs")
@@ -120,9 +122,7 @@ def generate(
         print("Error: cfg.json file not found in SAE directory.")
         raise typer.Abort()
     if sae_path.joinpath("sae_weights.safetensors").is_file() is not True:
-        print(
-            "Error: sae_weights.safetensors file not found in SAE directory."
-        )
+        print("Error: sae_weights.safetensors file not found in SAE directory.")
         raise typer.Abort()
     if sae_path.joinpath("sparsity.safetensors").is_file() is not True:
         print("Error: sparsity.safetensors file not found in SAE directory.")
         raise typer.Abort()
@@ -144,9 +144,7 @@ def generate(
     outputs_subdir = f"{model_id}_{sae_id}_{sparse_autoencoder.cfg.hook_point}"
     outputs_dir = OUTPUT_DIR_BASE.joinpath(outputs_subdir)
     if outputs_dir.exists() and outputs_dir.is_file():
-        print(
-            f"Error: Output directory {outputs_dir.as_posix()} exists and is a file."
- ) + print(f"Error: Output directory {outputs_dir.as_posix()} exists and is a file.") raise typer.Abort() outputs_dir.mkdir(parents=True, exist_ok=True) # Check if output_dir has any files starting with "batch_" @@ -163,9 +161,7 @@ def generate( if len(sparsity) > 0 and sparsity[0] >= 0: sparsity = torch.log10(sparsity + 1e-10) sparsity = sparsity.to(device) - alive_indexes = ( - (sparsity > log_sparsity).nonzero(as_tuple=True)[0].tolist() - ) + alive_indexes = (sparsity > log_sparsity).nonzero(as_tuple=True)[0].tolist() num_alive = len(alive_indexes) num_dead = sparse_autoencoder.d_sae - num_alive @@ -394,9 +390,7 @@ def nanToNeg999(obj: Any) -> Any: return {k: nanToNeg999(v) for k, v in obj.items()} elif isinstance(obj, list): return [nanToNeg999(v) for v in obj] - elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan( - obj - ): + elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan(obj): return -999 return obj