Skip to content

Commit

Permalink
Merge branch 'main' into dev_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed Apr 20, 2024
2 parents 7b1d6b3 + 2d31672 commit 0e26c4a
Show file tree
Hide file tree
Showing 34 changed files with 1,470 additions and 950 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ jobs:
run: poetry run flake8 .
- name: black code formatting
run: poetry run black . --check
# - name: isort linting
# run: poetry run isort . --check-only --diff
- name: isort linting
run: poetry run isort . --check-only --diff
- name: type checking
run: poetry run pyright
- name: Run Unit Tests
Expand Down
152 changes: 152 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,128 @@



## v0.5.1 (2024-04-19)

### Chore

* chore: re-enabling isort in CI (#86) ([`9c44731`](https://github.com/jbloomAus/SAELens/commit/9c44731a9b7718c9f0913136ed9df42dac87c390))

### Fix

* fix: pin pyzmq==26.0.1 temporarily ([`0094021`](https://github.com/jbloomAus/SAELens/commit/00940219754ddb1be6708e54cdd0ac6ed5dc3948))

* fix: typing issue, temporary ([`25cebf1`](https://github.com/jbloomAus/SAELens/commit/25cebf1e5e0630a377a5045c1b3571a5f181853f))

### Unknown

* v0.5.1 ([`0ac218b`](https://github.com/jbloomAus/SAELens/commit/0ac218bf8068b8568310b40a1399f9eb3c8d992e))

* Merge pull request #91 from jbloomAus/decoder-fine-tuning

Decoder fine tuning ([`1fc652c`](https://github.com/jbloomAus/SAELens/commit/1fc652c19e2a34172c1fd520565e9620366f565c))

* par update ([`2bb5975`](https://github.com/jbloomAus/SAELens/commit/2bb5975226807d352f2d3cf6b6dad7aefaf1b662))

* Merge pull request #89 from jbloomAus/fix_np

Enhance + Fix Neuronpedia generation / upload ([`38d507c`](https://github.com/jbloomAus/SAELens/commit/38d507c052875cbd78f8fd9dae45a658e47c2b9d))

* minor changes ([`bc766e4`](https://github.com/jbloomAus/SAELens/commit/bc766e4f7a8d472f647408b8a5cd3c6140d856b7))

* reformat run.ipynb ([`822882c`](https://github.com/jbloomAus/SAELens/commit/822882cac9c05449b7237b7d42ce17297903da2f))

* get decoder fine tuning working ([`11a71e1`](https://github.com/jbloomAus/SAELens/commit/11a71e1b95576ef6dc3dbec7eb1c76ce7ca44dfd))

* format ([`040676d`](https://github.com/jbloomAus/SAELens/commit/040676db6814c1f64171a32344e0bed40528c8f9))

* Merge pull request #88 from jbloomAus/get_feature_from_neuronpedia

FEAT: Add API for getting Neuronpedia feature ([`1666a68`](https://github.com/jbloomAus/SAELens/commit/1666a68bb7d7ee4837e03d95b203ee371ca9ea9e))

* Fix resuming from batch ([`145a407`](https://github.com/jbloomAus/SAELens/commit/145a407f8d57755301bf56c87efd4e775c59b980))

* Use original repo for sae_vis ([`1a7d636`](https://github.com/jbloomAus/SAELens/commit/1a7d636c95ef508dde8bd100ab6d9f241b0be977))

* Use correct model name for np runner ([`138d5d4`](https://github.com/jbloomAus/SAELens/commit/138d5d445878c0830c6c96a5fbe6b10a1d9644b0))

* Merge main, remove eindex ([`6578436`](https://github.com/jbloomAus/SAELens/commit/6578436891e71a0ef60fb2ed6d6a6b6279d71cbc))

* Add API for getting Neuronpedia feature ([`e78207d`](https://github.com/jbloomAus/SAELens/commit/e78207d086cb3372dc805cbb4c87b694749cd905))


## v0.5.0 (2024-04-17)

### Feature

* feat: Mamba support vs mamba-lens (#79)

* mamba support

* added init

* added optional model kwargs

* Support transformers and mamba

* forgot one model kwargs

* failed opts

* tokens input

* hack to fix tokens, will look into fixing mambalens

* fixed checkpoint

* added sae group

* removed some comments and fixed merge error

* removed unneeded params since that issue is fixed in mambalens now

* Unneded input param

* removed debug checkpoing and eval

* added refs to hookedrootmodule

* feed linter

* added example and fixed loading

* made layer for eval change

* fix linter issues

* adding mamba-lens as optional dep, and fixing typing/linting

* adding a test for loading mamba model

* adding mamba-lens to dev for CI

* updating min mamba-lens version

* updating mamba-lens version

---------

Co-authored-by: David Chanin <[email protected]> ([`eea7db4`](https://github.com/jbloomAus/SAELens/commit/eea7db4b99098c33cd862e7e2280a32b630826bd))

### Unknown

* update readme ([`440df7b`](https://github.com/jbloomAus/SAELens/commit/440df7b6c0ef55ba3d116054f81e1ee4a58f9089))

* update readme ([`3694fd2`](https://github.com/jbloomAus/SAELens/commit/3694fd2c4cc7438121e4549636508c45835a5d38))

* Fix upload skipped/dead features ([`932f380`](https://github.com/jbloomAus/SAELens/commit/932f380971ce3d431e6592c804d12f6df2b4ec78))

* Use python typer instead of shell script for neuronpedia jobs ([`b611e72`](https://github.com/jbloomAus/SAELens/commit/b611e721dd2620ab5a030cc0f6e37029c30711ca))

* Merge branch 'main' into fix_np ([`cc6cb6a`](https://github.com/jbloomAus/SAELens/commit/cc6cb6a96b793e41fb91f4ebbaf3bfa5e7c11b4e))

* convert sparsity to log sparsity if needed ([`8d7d404`](https://github.com/jbloomAus/SAELens/commit/8d7d4040033fb80c5b994cdc662b0f90b8fcc7aa))


## v0.4.0 (2024-04-16)

### Feature
Expand All @@ -22,8 +144,26 @@

* default orthogonal init false ([`a8b0113`](https://github.com/jbloomAus/SAELens/commit/a8b0113140bd2f9b97befccc8f158dace02a4810))

* Formatting ([`1e3d53e`](https://github.com/jbloomAus/SAELens/commit/1e3d53ec2b72897bfebb6065f3b530fe65d3a97c))

* Eindex required by sae_vis ([`f769e7a`](https://github.com/jbloomAus/SAELens/commit/f769e7a65ab84d4073852931a86ff3b5076eea3c))

* Upload dead feature stubs ([`9067380`](https://github.com/jbloomAus/SAELens/commit/9067380bf67b89d8b2d235944f696016286f683e))

* Make feature sparsity an argument ([`8230570`](https://github.com/jbloomAus/SAELens/commit/8230570297d68e35cb614a63abf442e4a01174d2))

* Fix buffer" ([`dde2481`](https://github.com/jbloomAus/SAELens/commit/dde248162b70ff4311d4182333b7cce43aed78df))

* Merge branch 'main' into fix_np ([`6658392`](https://github.com/jbloomAus/SAELens/commit/66583923cd625bfc1c1ef152bc5f5beaa764b2d6))

* notebook update ([`feca408`](https://github.com/jbloomAus/SAELens/commit/feca408cf003737cd4eb529ca7fea2f77984f5c6))

* Merge branch 'main' into fix_np ([`f8fb3ef`](https://github.com/jbloomAus/SAELens/commit/f8fb3efbde7fc79e6fafe2d9b3324c9f0b2a337d))

* Final fixes ([`e87788d`](https://github.com/jbloomAus/SAELens/commit/e87788d63a9b767e34e497c85a318337ab8aabb8))

* Don't use buffer, fix anomalies ([`2c9ca64`](https://github.com/jbloomAus/SAELens/commit/2c9ca642b334b7a444544a4640c483229dc04c62))


## v0.3.0 (2024-04-15)

Expand All @@ -44,6 +184,10 @@

* make dense_batch_mse_normalization optional ([`c41774e`](https://github.com/jbloomAus/SAELens/commit/c41774e5cfaeb195e3320e9e3fc93d60d921337d))

* Runner is fixed, faster, cleaned up, and now gives whole sequences instead of buffer. ([`3837884`](https://github.com/jbloomAus/SAELens/commit/383788485917cee114fba24e8ded944aefcfb568))

* Merge branch 'main' into fix_np ([`3ed30cf`](https://github.com/jbloomAus/SAELens/commit/3ed30cf2b84a2444c8ed030641214f0dbb65898a))

* add warning in run script ([`9a772ca`](https://github.com/jbloomAus/SAELens/commit/9a772ca6da155b5e97bc3109da74457f5addfbfd))

* update sae loading code ([`356a8ef`](https://github.com/jbloomAus/SAELens/commit/356a8efba06e4f453d2f15afe9171b71d780819a))
Expand Down Expand Up @@ -74,6 +218,10 @@

### Unknown

* Use legacy loader, add back histograms, logits. Fix anomaly characters. ([`ebbb622`](https://github.com/jbloomAus/SAELens/commit/ebbb622353bef21c953f844a108ea8d9fe31e9f9))

* Merge branch 'main' into fix_np ([`586e088`](https://github.com/jbloomAus/SAELens/commit/586e0881e08a9b013e2d4101878ef054c1f3dd8b))

* Merge pull request #80 from wllgrnt/will-update-tutorial

bugfix - minimum viable updates to tutorial notebook ([`e51016b`](https://github.com/jbloomAus/SAELens/commit/e51016b01f3b0f30c83365c54430908779671d87))
Expand Down Expand Up @@ -118,6 +266,8 @@ Fix artifact saving loading ([`8784c74`](https://github.com/jbloomAus/SAELens/co

* add safetensors to project ([`0da48b0`](https://github.com/jbloomAus/SAELens/commit/0da48b044357eed17e5afffd3ce541e064185043))

* Don't precompute background colors and tick values ([`271dbf0`](https://github.com/jbloomAus/SAELens/commit/271dbf05567b6e6ae4cfc1dab138132872038381))

* Merge pull request #71 from weissercn/main

Addressing notebook issues ([`8417505`](https://github.com/jbloomAus/SAELens/commit/84175055ba5876b335cbc0de38bf709d0b11cec1))
Expand All @@ -126,6 +276,8 @@ Addressing notebook issues ([`8417505`](https://github.com/jbloomAus/SAELens/com

chore: updating README.md with pip install instructions and PyPI badge ([`4d7d1e7`](https://github.com/jbloomAus/SAELens/commit/4d7d1e7db5e952c7e9accf19c0ccce466cdcf6cf))

* FIX: Add back correlated neurons, frac_nonzero ([`d532b82`](https://github.com/jbloomAus/SAELens/commit/d532b828bd77c18b73f495d6b42ca53b5148fd2f))

* linting ([`1db0b5a`](https://github.com/jbloomAus/SAELens/commit/1db0b5ae7e091822c72bba0488d30fc16bc9a1c6))

* fixed graph name ([`ace4813`](https://github.com/jbloomAus/SAELens/commit/ace481322103737de2e80d688683d0c937ac5558))
Expand Down
Empty file added __init__.py
Empty file.
4 changes: 2 additions & 2 deletions docs/training_saes.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ cfg = LanguageModelSAERunnerConfig(

# Activation Store Parameters
n_batches_in_buffer = 128,
total_training_tokens = 1_000_000 * 300,
training_tokens = 1_000_000 * 300,
store_batch_size = 32,

# Dead Neurons and Sparsity
Expand All @@ -60,7 +60,7 @@ cfg = LanguageModelSAERunnerConfig(

# WANDB
log_to_wandb = True,
wandb_project= "mats_sae_training_gpt2",
wandb_project= "gpt2",
wandb_entity = None,
wandb_log_frequency=100,

Expand Down
14 changes: 11 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sae-lens"
version = "0.4.0"
version = "0.5.1"
description = "Training and Analyzing Sparse Autoencoders (SAEs)"
authors = ["Joseph Bloom"]
readme = "README.md"
Expand All @@ -20,14 +20,17 @@ matplotlib-inline = "^0.1.6"
datasets = "^2.17.1"
babe = "^0.0.7"
nltk = "^3.8.1"
sae-vis = "0.2.6"
sae-vis = "^0.2.15"
mkdocs = "^1.5.3"
mkdocs-material = "^9.5.15"
mkdocs-autorefs = "^1.0.1"
mkdocs-section-index = "^0.3.8"
mkdocstrings = "^0.24.1"
mkdocstrings-python = "^1.9.0"
safetensors = "^0.4.2"
typer = "^0.12.3"
mamba-lens = { version = "^0.0.4", optional = true }
pyzmq = "26.0.0"


[tool.poetry.group.dev.dependencies]
Expand All @@ -36,12 +39,17 @@ pytest = "^8.0.2"
pytest-cov = "^4.1.0"
pre-commit = "^3.6.2"
flake8 = "^7.0.0"
isort = "^5.13.2"
isort = "5.13.2"
pyright = "^1.1.351"
mamba-lens = "^0.0.4"

[tool.poetry.extras]
mamba = ["mamba-lens"]


[tool.isort]
profile = "black"
src_paths = ["sae_lens", "tests"]

[tool.pyright]
typeCheckingMode = "strict"
Expand Down
2 changes: 1 addition & 1 deletion sae_lens/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.4.0"
__version__ = "0.5.1"

from .training.activations_store import ActivationsStore
from .training.cache_activations_runner import cache_activations_runner
Expand Down
12 changes: 10 additions & 2 deletions sae_lens/analysis/dashboard_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import plotly
import plotly.express as px
import torch
import wandb
from sae_vis.data_config_classes import (
ActsHistogramConfig,
Column,
Expand All @@ -22,13 +23,15 @@
from sae_vis.data_fetching_fns import get_feature_data
from torch.nn.functional import cosine_similarity
from tqdm import tqdm
from transformer_lens import HookedTransformer

import wandb
from sae_lens.training.session_loader import LMSparseAutoencoderSessionloader


class DashboardRunner:

model: HookedTransformer | None = None

def __init__(
self,
sae_path: Optional[str] = None,
Expand Down Expand Up @@ -131,10 +134,14 @@ def get_dashboard_folder_name(self):

def init_sae_session(self):
(
self.model,
model,
sae_group,
self.activation_store,
) = LMSparseAutoencoderSessionloader.load_pretrained_sae(self.sae_path)
assert isinstance(
model, HookedTransformer
) # only HookedTransformer is allowed to be used in the dashboard
self.model = model
# TODO: handle multiple autoencoders
self.sparse_autoencoder = next(iter(sae_group))[1]

Expand Down Expand Up @@ -316,6 +323,7 @@ def run(self):
if self.use_wandb:
wandb.log({"time/time_to_get_tokens": end - start})

assert self.model is not None
vocab_dict = cast(Any, self.model.tokenizer).vocab
vocab_dict = {
v: k.replace("Ġ", " ").replace("\n", "\\n") for k, v in vocab_dict.items()
Expand Down
23 changes: 22 additions & 1 deletion sae_lens/analysis/neuronpedia_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import urllib.parse
import webbrowser

import requests


def get_neuronpedia_quick_list(
features: list[int],
Expand All @@ -14,10 +16,29 @@ def get_neuronpedia_quick_list(
name = urllib.parse.quote(name)
url = url + "?name=" + name
list_feature = [
{"modelId": model, "layer": f"{layer}-{dataset}", "index": str(feature)}
{
"modelId": model,
"layer": f"{layer}-{dataset}",
"index": str(feature),
}
for feature in features
]
url = url + "&features=" + urllib.parse.quote(json.dumps(list_feature))
webbrowser.open(url)

return url


def get_neuronpedia_feature(
feature: int,
layer: int,
model: str = "gpt2-small",
dataset: str = "res-jb",
):
url = "https://neuronpedia.org/api/feature/"
url = url + f"{model}/{layer}-{dataset}/{feature}"

result = requests.get(url).json()
result["index"] = int(result["index"])

return result
Loading

0 comments on commit 0e26c4a

Please sign in to comment.