format
jbloom-md committed Apr 18, 2024
1 parent b63a8f3 commit f1d6350
Showing 11 changed files with 60 additions and 102 deletions.
2 changes: 1 addition & 1 deletion sae_lens/analysis/dashboard_runner.py
@@ -11,6 +11,7 @@
 import plotly
 import plotly.express as px
 import torch
+import wandb
 from sae_vis.data_config_classes import (
     ActsHistogramConfig,
     Column,
@@ -24,7 +25,6 @@
 from tqdm import tqdm
 from transformer_lens import HookedTransformer

-import wandb
 from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader


115 changes: 39 additions & 76 deletions sae_lens/analysis/neuronpedia_runner.py
@@ -4,25 +4,27 @@
 # set TOKENIZERS_PARALLELISM to false to avoid warnings
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 import json
+
 import numpy as np
 import torch
 from matplotlib import colors
 from sae_vis.data_config_classes import (
     ActsHistogramConfig,
     Column,
+    FeatureTablesConfig,
     LogitsHistogramConfig,
     LogitsTableConfig,
-    FeatureTablesConfig,
     SaeVisConfig,
     SaeVisLayoutConfig,
     SequencesConfig,
 )
-from tqdm import tqdm
 from sae_vis.data_storing_fns import SaeVisData
+from tqdm import tqdm
 from transformer_lens import HookedTransformer
+
+from sae_lens.toolkit.pretrained_saes import load_sparsity
 from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder
-from sae_lens.toolkit.pretrained_saes import load_sparsity

 OUT_OF_RANGE_TOKEN = "<|outofrange|>"

@@ -85,9 +87,7 @@ def __init__(
             self.sae_path, device=self.device
         )
         loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg)
-        self.model, _, self.activation_store = (
-            loader.load_sae_training_group_session()
-        )
+        self.model, _, self.activation_store = loader.load_sae_training_group_session()
         self.model_id = self.sparse_autoencoder.cfg.model_name
         self.layer = self.sparse_autoencoder.cfg.hook_point_layer
         self.sae_id = sae_id
@@ -165,16 +165,12 @@ def run(self):
         sparsity = torch.log10(sparsity + 1e-10)
         sparsity = sparsity.to(self.device)
         self.target_feature_indexes = (
-            (sparsity > self.sparsity_threshold)
-            .nonzero(as_tuple=True)[0]
-            .tolist()
+            (sparsity > self.sparsity_threshold).nonzero(as_tuple=True)[0].tolist()
         )

         # divide into batches
         feature_idx = torch.tensor(self.target_feature_indexes)
-        n_subarrays = np.ceil(
-            len(feature_idx) / self.n_features_at_a_time
-        ).astype(int)
+        n_subarrays = np.ceil(len(feature_idx) / self.n_features_at_a_time).astype(int)
         feature_idx = np.array_split(feature_idx, n_subarrays)
         feature_idx = [x.tolist() for x in feature_idx]

@@ -189,9 +185,7 @@ def run(self):
             exit()

         # write dead into file so we can create them as dead in Neuronpedia
-        skipped_indexes = set(range(self.n_features)) - set(
-            self.target_feature_indexes
-        )
+        skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes)
         skipped_indexes_json = json.dumps(
             {
                 "model_id": self.model_id,
@@ -223,9 +217,7 @@ def run(self):
         for k, v in vocab_dict.items():
             modified_key = k
             for anomaly in HTML_ANOMALIES:
-                modified_key = modified_key.replace(
-                    anomaly, HTML_ANOMALIES[anomaly]
-                )
+                modified_key = modified_key.replace(anomaly, HTML_ANOMALIES[anomaly])
             new_vocab_dict[v] = modified_key
         vocab_dict = new_vocab_dict

@@ -241,16 +233,11 @@ def run(self):
             if feature_batch_count < self.start_batch:
                 # print(f"Skipping batch - it's after start_batch: {feature_batch_count}")
                 continue
-            if (
-                self.end_batch is not None
-                and feature_batch_count > self.end_batch
-            ):
+            if self.end_batch is not None and feature_batch_count > self.end_batch:
                 # print(f"Skipping batch - it's after end_batch: {feature_batch_count}")
                 continue

-            print(
-                f"========== Running Batch #{feature_batch_count} =========="
-            )
+            print(f"========== Running Batch #{feature_batch_count} ==========")

             layout = SaeVisLayoutConfig(
                 columns=[
@@ -286,17 +273,13 @@ def run(self):
             )

             features_outputs = []
-            for _, feat_index in enumerate(
-                feature_data.feature_data_dict.keys()
-            ):
+            for _, feat_index in enumerate(feature_data.feature_data_dict.keys()):
                 feature = feature_data.feature_data_dict[feat_index]

                 feature_output = {}
                 feature_output["featureIndex"] = feat_index

-                top10_logits = self.round_list(
-                    feature.logits_table_data.top_logits
-                )
+                top10_logits = self.round_list(feature.logits_table_data.top_logits)
                 bottom10_logits = self.round_list(
                     feature.logits_table_data.bottom_logits
                 )
@@ -305,41 +288,29 @@ def run(self):
feature_output["neuron_alignment_indices"] = (
feature.feature_tables_data.neuron_alignment_indices
)
feature_output["neuron_alignment_values"] = (
self.round_list(
feature.feature_tables_data.neuron_alignment_values
)
feature_output["neuron_alignment_values"] = self.round_list(
feature.feature_tables_data.neuron_alignment_values
)
feature_output["neuron_alignment_l1"] = (
self.round_list(
feature.feature_tables_data.neuron_alignment_l1
)
feature_output["neuron_alignment_l1"] = self.round_list(
feature.feature_tables_data.neuron_alignment_l1
)
feature_output["correlated_neurons_indices"] = (
feature.feature_tables_data.correlated_neurons_indices
)
feature_output["correlated_neurons_l1"] = (
self.round_list(
feature.feature_tables_data.correlated_neurons_cossim
)
feature_output["correlated_neurons_l1"] = self.round_list(
feature.feature_tables_data.correlated_neurons_cossim
)
feature_output["correlated_neurons_pearson"] = (
self.round_list(
feature.feature_tables_data.correlated_neurons_pearson
)
feature_output["correlated_neurons_pearson"] = self.round_list(
feature.feature_tables_data.correlated_neurons_pearson
)
feature_output["correlated_features_indices"] = (
feature.feature_tables_data.correlated_features_indices
)
feature_output["correlated_features_l1"] = (
self.round_list(
feature.feature_tables_data.correlated_features_cossim
)
feature_output["correlated_features_l1"] = self.round_list(
feature.feature_tables_data.correlated_features_cossim
)
feature_output["correlated_features_pearson"] = (
self.round_list(
feature.feature_tables_data.correlated_features_pearson
)
feature_output["correlated_features_pearson"] = self.round_list(
feature.feature_tables_data.correlated_features_pearson
)

feature_output["neg_str"] = self.to_str_tokens_safe(
@@ -353,32 +324,28 @@ def run(self):

feature_output["frac_nonzero"] = (
float(
feature.acts_histogram_data.title.split(" = ")[
1
].split("%")[0]
feature.acts_histogram_data.title.split(" = ")[1].split(
"%"
)[0]
)
/ 100
if feature.acts_histogram_data.title is not None
else 0
)

freq_hist_data = feature.acts_histogram_data
freq_bar_values = self.round_list(
freq_hist_data.bar_values
)
feature_output["freq_hist_data_bar_values"] = (
freq_bar_values
)
feature_output["freq_hist_data_bar_heights"] = (
self.round_list(freq_hist_data.bar_heights)
freq_bar_values = self.round_list(freq_hist_data.bar_values)
feature_output["freq_hist_data_bar_values"] = freq_bar_values
feature_output["freq_hist_data_bar_heights"] = self.round_list(
freq_hist_data.bar_heights
)

logits_hist_data = feature.logits_histogram_data
feature_output["logits_hist_data_bar_heights"] = (
self.round_list(logits_hist_data.bar_heights)
feature_output["logits_hist_data_bar_heights"] = self.round_list(
logits_hist_data.bar_heights
)
feature_output["logits_hist_data_bar_values"] = (
self.round_list(logits_hist_data.bar_values)
feature_output["logits_hist_data_bar_values"] = self.round_list(
logits_hist_data.bar_values
)

feature_output["num_tokens_for_dashboard"] = (
@@ -432,12 +399,8 @@ def run(self):
{"pos": posContribs, "neg": negContribs}
)
activation["tokens"] = strs
activation["values"] = self.round_list(
sd.feat_acts
)
activation["maxValue"] = max(
activation["values"]
)
activation["values"] = self.round_list(sd.feat_acts)
activation["maxValue"] = max(activation["values"])
activation["lossValues"] = self.round_list(
sd.loss_contribution
)
1 change: 0 additions & 1 deletion sae_lens/training/config.py
@@ -2,7 +2,6 @@
 from typing import Any, Optional, cast

 import torch
-
 import wandb

 DTYPE_MAP = {
2 changes: 1 addition & 1 deletion sae_lens/training/evals.py
@@ -3,9 +3,9 @@

 import pandas as pd
 import torch
+import wandb
 from transformer_lens.hook_points import HookedRootModule

-import wandb
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder

1 change: 1 addition & 0 deletions sae_lens/training/lm_runner.py
@@ -1,6 +1,7 @@
 from typing import Any, cast

 import wandb
+
 from sae_lens.training.config import LanguageModelSAERunnerConfig
 from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader

2 changes: 1 addition & 1 deletion sae_lens/training/toy_model_runner.py
@@ -3,8 +3,8 @@

 import einops
 import torch
-
 import wandb
+
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder
 from sae_lens.training.toy_models import Config as ToyConfig
 from sae_lens.training.toy_models import Model as ToyModel
2 changes: 1 addition & 1 deletion sae_lens/training/train_sae_on_language_model.py
@@ -3,13 +3,13 @@
 from typing import Any, cast

 import torch
+import wandb
 from safetensors.torch import save_file
 from torch.optim import Adam, Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from tqdm import tqdm
 from transformer_lens.hook_points import HookedRootModule

-import wandb
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.evals import run_evals
 from sae_lens.training.geometric_median import compute_geometric_median
2 changes: 1 addition & 1 deletion sae_lens/training/train_sae_on_toy_model.py
@@ -1,10 +1,10 @@
 from typing import Any, cast

 import torch
+import wandb
 from torch.utils.data import DataLoader
 from tqdm import tqdm

-import wandb
 from sae_lens.training.sparse_autoencoder import SparseAutoencoder


2 changes: 1 addition & 1 deletion tests/unit/training/test_train_sae_on_language_model.py
@@ -5,11 +5,11 @@

 import pytest
 import torch
+import wandb
 from datasets import Dataset
 from torch import Tensor
 from transformer_lens import HookedTransformer

-import wandb
 from sae_lens.training.activations_store import ActivationsStore
 from sae_lens.training.optim import get_scheduler
 from sae_lens.training.sae_group import SparseAutoencoderDictionary
1 change: 1 addition & 0 deletions tutorials/neuronpedia/make_batch.py
@@ -1,4 +1,5 @@
 import sys
+
 from sae_lens.analysis.neuronpedia_runner import NeuronpediaRunner

 # we use another python script to launch this using subprocess to work around OOM issues - this ensures every batch gets the whole system available memory
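As a rough illustration of the workaround that comment describes, a parent script might launch make_batch.py once per batch via subprocess, so each batch starts in a fresh process and its memory is returned to the OS when the child exits. This is a hypothetical sketch, not a script from this repository; SAE_PATH, N_BATCHES, and the CLI argument layout are assumed:

```python
# Hypothetical launcher sketch (not the repository's actual script).
# Running each batch in a child process means no activations accumulate
# in one long-lived Python process.
import subprocess
import sys

SAE_PATH = "path/to/sae"  # assumed placeholder
N_BATCHES = 10            # assumed placeholder

for batch in range(N_BATCHES):
    result = subprocess.run(
        [sys.executable, "make_batch.py", SAE_PATH, str(batch)],
        check=False,  # don't abort the sweep if one batch is OOM-killed
    )
    if result.returncode != 0:
        print(f"batch {batch} exited with code {result.returncode}", file=sys.stderr)
```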