From 465e00333b9fae79ce73c277e7e1018e863002ca Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Mon, 26 Feb 2024 13:40:16 +0000
Subject: [PATCH] chore: using poetry for dependency management

---
 .github/workflows/tests.yml                   | 16 +++++----
 .gitignore                                    |  2 +-
 README.md                                     |  8 ++---
 makefile                                      |  9 +++--
 pyproject.toml                                | 36 +++++++++++++++++++
 requirements.txt                              |  3 +-
 sae_analysis/dashboard_runner.py              |  6 ++--
 sae_analysis/visualizer/data_fns.py           |  5 +--
 sae_analysis/visualizer/html_fns.py           |  4 +--
 sae_analysis/visualizer/model_fns.py          |  8 ++---
 sae_training/activations_store.py             | 12 +++----
 sae_training/cache_activations_runner.py      |  2 +-
 sae_training/config.py                        | 11 +++---
 sae_training/evals.py                         |  7 ++--
 .../geom_median/src/geom_median/numpy/main.py |  2 +-
 .../src/geom_median/numpy/utils.py            |  1 +
 .../src/geom_median/numpy/weiszfeld_array.py  | 11 +++---
 .../numpy/weiszfeld_list_of_array.py          | 11 +++---
 .../geom_median/src/geom_median/torch/main.py |  2 +-
 .../src/geom_median/torch/utils.py            |  1 +
 .../src/geom_median/torch/weiszfeld_array.py  |  8 +++--
 .../torch/weiszfeld_list_of_array.py          | 11 +++---
 sae_training/lm_runner.py                     |  1 -
 sae_training/optim.py                         |  1 +
 sae_training/toy_model_runner.py              |  2 +-
 sae_training/toy_models.py                    |  1 +
 sae_training/train_sae_on_language_model.py   |  2 +-
 sae_training/train_sae_on_toy_model.py        |  2 +-
 scripts/generate_dashboards.py                |  6 ++--
 29 files changed, 122 insertions(+), 69 deletions(-)
 create mode 100644 pyproject.toml

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 21f5fd89..748aaef8 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -42,18 +42,20 @@ jobs:
         uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
-          cache: 'pip'
+      - name: Install Poetry
+        uses: snok/install-poetry@v1
       - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install flake8 pytest
-          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+        run: poetry install --no-interaction
      - name: Lint with flake8
        run: |
          # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+          poetry run flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+          poetry run flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+      - name: black code formatting
+        run: poetry run black . --check
+      - name: isort linting
+        run: poetry run isort . --check-only --diff
      - name: Run Unit Tests
        run: |
          make unit-test
diff --git a/.gitignore b/.gitignore
index 754fd271..c6c0809f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -99,7 +99,7 @@ ipython_config.py
 #   This is especially recommended for binary packages to ensure reproducibility, and is more
 #   commonly ignored for libraries.
 #   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
-#poetry.lock
+poetry.lock
 
 # pdm
 #   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
diff --git a/README.md b/README.md
index 13bcc59a..a0c15e63 100644
--- a/README.md
+++ b/README.md
@@ -5,12 +5,10 @@ This codebase contains training scripts and analysis code for Sparse AutoEncoder
 
 ## Set Up
 
-```
-
-conda create --name mats_sae_training python=3.11 -y
-conda activate mats_sae_training
-pip install -r requirements.txt
+This project uses [Poetry](https://python-poetry.org/) for dependency management. Ensure Poetry is installed, then to install the dependencies, run:
 
+```
+poetry install
 ```
 
 ## Background
diff --git a/makefile b/makefile
index ca0e9cc0..40924573 100644
--- a/makefile
+++ b/makefile
@@ -1,6 +1,11 @@
 format:
+	poetry run black .
+	poetry run isort .
 
 check-format:
+	poetry run flake8 .
+	poetry run black --check .
+	poetry run isort --check-only --diff .
 
 test:
@@ -8,7 +13,7 @@ test:
 	make acceptance-test
 
 unit-test:
-	pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/unit
+	poetry run pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/unit
 
 acceptance-test:
-	pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/acceptance
+	poetry run pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/acceptance
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..f02784f8
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,36 @@
+[tool.poetry]
+name = "mats_sae_training"
+version = "0.1.0"
+description = "Training Sparse Autoencoders (SAEs)"
+authors = ["Joseph Bloom"]
+readme = "README.md"
+packages = [{include = "sae_analysis"}, {include = "sae_training"}]
+
+[tool.poetry.dependencies]
+python = "^3.10"
+transformer-lens = "^1.14.0"
+transformers = "^4.38.1"
+jupyter = "^1.0.0"
+plotly = "^5.19.0"
+plotly-express = "^0.4.1"
+nbformat = "^5.9.2"
+ipykernel = "^6.29.2"
+matplotlib = "^3.8.3"
+matplotlib-inline = "^0.1.6"
+eindex = {git = "https://github.com/callummcdougall/eindex.git"}
+
+
+[tool.poetry.group.dev.dependencies]
+black = "^24.2.0"
+pytest = "^8.0.2"
+pytest-cov = "^4.1.0"
+pre-commit = "^3.6.2"
+flake8 = "^7.0.0"
+isort = "^5.13.2"
+
+[tool.isort]
+profile = "black"
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 6c962a30..c33375c4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,8 @@ nbformat==5.9.2
 ipykernel==6.27.1
 matplotlib==3.8.2
 matplotlib-inline==0.1.6
-pylint==3.0.2
+flake8==7.0.0
+isort==5.13.2
 black==23.11.0
 pytest==7.4.3
 pytest-cov==4.1.0
diff --git a/sae_analysis/dashboard_runner.py b/sae_analysis/dashboard_runner.py
index 3ff2bc7b..815c6b14 100644
--- a/sae_analysis/dashboard_runner.py
+++ b/sae_analysis/dashboard_runner.py
@@ -14,10 +14,10 @@ import plotly
 import plotly.express as px
 import torch
+import wandb
 from torch.nn.functional import cosine_similarity
 from tqdm import tqdm
 
-import wandb
 from sae_analysis.visualizer.data_fns import get_feature_data
 from sae_training.utils import LMSparseAutoencoderSessionloader
 
@@ -148,9 +148,7 @@ def init_sae_session(self):
             self.activation_store,
         ) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
 
-    def get_tokens(
-        self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6
-    ):
+    def get_tokens(self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6):
         """
         Get the tokens needed for dashboard generation.
""" diff --git a/sae_analysis/visualizer/data_fns.py b/sae_analysis/visualizer/data_fns.py index c6b1c486..effe6f3f 100644 --- a/sae_analysis/visualizer/data_fns.py +++ b/sae_analysis/visualizer/data_fns.py @@ -1,13 +1,10 @@ import gzip -import json import os import pickle import time -from collections import defaultdict from dataclasses import dataclass -from functools import partial from pathlib import Path -from typing import Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple, Union import einops import numpy as np diff --git a/sae_analysis/visualizer/html_fns.py b/sae_analysis/visualizer/html_fns.py index 3ece35ac..9b0e546d 100644 --- a/sae_analysis/visualizer/html_fns.py +++ b/sae_analysis/visualizer/html_fns.py @@ -230,8 +230,8 @@ def generate_tables_html( ], [None, "+.2f", ".1%", None, "+.2f", "+.2f", None, "+.2f", "+.2f"], ): - fn = ( - lambda m: str(mylist[int(m.group(1))]) + fn = lambda m: ( + str(mylist[int(m.group(1))]) if myformat is None else format(mylist[int(m.group(1))], myformat) ) diff --git a/sae_analysis/visualizer/model_fns.py b/sae_analysis/visualizer/model_fns.py index 068d94d8..3a4d6879 100644 --- a/sae_analysis/visualizer/model_fns.py +++ b/sae_analysis/visualizer/model_fns.py @@ -1,11 +1,11 @@ -from transformer_lens import utils -import torch import pprint +from dataclasses import dataclass + +import torch import torch.nn as nn import torch.nn.functional as F import tqdm.notebook as tqdm -from dataclasses import dataclass - +from transformer_lens import utils DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py index 19a0653d..9bae6824 100644 --- a/sae_training/activations_store.py +++ b/sae_training/activations_store.py @@ -206,9 +206,9 @@ def get_buffer(self, n_batches_in_buffer): taking_subset_of_file = True # Add it to the buffer - new_buffer[ - n_tokens_filled : n_tokens_filled + activations.shape[0] - ] = activations + new_buffer[n_tokens_filled : n_tokens_filled + activations.shape[0]] = ( + activations + ) # Update counters n_tokens_filled += activations.shape[0] @@ -235,9 +235,9 @@ def get_buffer(self, n_batches_in_buffer): for refill_batch_idx_start in refill_iterator: refill_batch_tokens = self.get_batch_tokens() refill_activations = self.get_activations(refill_batch_tokens) - new_buffer[ - refill_batch_idx_start : refill_batch_idx_start + batch_size - ] = refill_activations + new_buffer[refill_batch_idx_start : refill_batch_idx_start + batch_size] = ( + refill_activations + ) # pbar.update(1) diff --git a/sae_training/cache_activations_runner.py b/sae_training/cache_activations_runner.py index cb4daae6..b01797b3 100644 --- a/sae_training/cache_activations_runner.py +++ b/sae_training/cache_activations_runner.py @@ -2,8 +2,8 @@ import os import torch -from transformer_lens import HookedTransformer from tqdm import tqdm +from transformer_lens import HookedTransformer from sae_training.activations_store import ActivationsStore from sae_training.config import CacheActivationsRunnerConfig diff --git a/sae_training/config.py b/sae_training/config.py index 9341d57d..78c41205 100644 --- a/sae_training/config.py +++ b/sae_training/config.py @@ -3,7 +3,6 @@ from typing import Optional import torch - import wandb @@ -22,9 +21,9 @@ class RunnerConfig(ABC): is_dataset_tokenized: bool = True context_size: int = 128 use_cached_activations: bool = False - cached_activations_path: Optional[ - str - ] 
+    cached_activations_path: Optional[str] = (
+        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"
+    )
 
     # SAE Parameters
     d_in: int = 512
@@ -61,7 +60,9 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
     # Training Parameters
     l1_coefficient: float = 1e-3
     lr: float = 3e-4
-    lr_scheduler_name: str = "constantwithwarmup"  # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
+    lr_scheduler_name: str = (
+        "constantwithwarmup"  # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
+    )
 
     lr_warm_up_steps: int = 500
     train_batch_size: int = 4096
diff --git a/sae_training/evals.py b/sae_training/evals.py
index c06017fa..39b4c1f5 100644
--- a/sae_training/evals.py
+++ b/sae_training/evals.py
@@ -2,11 +2,11 @@
 
 import pandas as pd
 import torch
+import wandb
 from tqdm import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.utils import get_act_name
 
-import wandb
 from sae_training.activations_store import ActivationsStore
 from sae_training.sparse_autoencoder import SparseAutoencoder
 
@@ -27,7 +27,10 @@ def run_evals(
 
     # Get Reconstruction Score
     losses_df = recons_loss_batched(
-        sparse_autoencoder, model, activation_store, n_batches = 10,
+        sparse_autoencoder,
+        model,
+        activation_store,
+        n_batches=10,
     )
 
     recons_score = losses_df["score"].mean()
diff --git a/sae_training/geom_median/src/geom_median/numpy/main.py b/sae_training/geom_median/src/geom_median/numpy/main.py
index a52aef68..ddf82412 100644
--- a/sae_training/geom_median/src/geom_median/numpy/main.py
+++ b/sae_training/geom_median/src/geom_median/numpy/main.py
@@ -1,8 +1,8 @@
 import numpy as np
 
+from . import utils
 from .weiszfeld_array import geometric_median_array, geometric_median_per_component
 from .weiszfeld_list_of_array import geometric_median_list_of_array
-from . import utils
 
 
 def compute_geometric_median(
diff --git a/sae_training/geom_median/src/geom_median/numpy/utils.py b/sae_training/geom_median/src/geom_median/numpy/utils.py
index 29382e4a..bda8dac7 100644
--- a/sae_training/geom_median/src/geom_median/numpy/utils.py
+++ b/sae_training/geom_median/src/geom_median/numpy/utils.py
@@ -1,4 +1,5 @@
 from itertools import zip_longest
+
 import numpy as np
 
 
diff --git a/sae_training/geom_median/src/geom_median/numpy/weiszfeld_array.py b/sae_training/geom_median/src/geom_median/numpy/weiszfeld_array.py
index 5b352a7c..90b58eb7 100644
--- a/sae_training/geom_median/src/geom_median/numpy/weiszfeld_array.py
+++ b/sae_training/geom_median/src/geom_median/numpy/weiszfeld_array.py
@@ -1,6 +1,7 @@
-import numpy as np
 from types import SimpleNamespace
 
+import numpy as np
+
 
 def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
     """
@@ -36,9 +37,11 @@ def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
 
     return SimpleNamespace(
         median=median,
-        termination="function value converged within tolerance"
-        if early_termination
-        else "maximum iterations reached",
+        termination=(
+            "function value converged within tolerance"
+            if early_termination
+            else "maximum iterations reached"
+        ),
         logs=logs,
     )
diff --git a/sae_training/geom_median/src/geom_median/numpy/weiszfeld_list_of_array.py b/sae_training/geom_median/src/geom_median/numpy/weiszfeld_list_of_array.py
index 172894f0..6244a044 100644
--- a/sae_training/geom_median/src/geom_median/numpy/weiszfeld_list_of_array.py
+++ b/sae_training/geom_median/src/geom_median/numpy/weiszfeld_list_of_array.py
@@ -1,6 +1,7 @@
-import numpy as np
 from types import SimpleNamespace
 
+import numpy as np
+
 
 def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
     """
@@ -38,9 +39,11 @@ def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=
 
     return SimpleNamespace(
         median=median,
-        termination="function value converged within tolerance"
-        if early_termination
-        else "maximum iterations reached",
+        termination=(
+            "function value converged within tolerance"
+            if early_termination
+            else "maximum iterations reached"
+        ),
         logs=logs,
     )
diff --git a/sae_training/geom_median/src/geom_median/torch/main.py b/sae_training/geom_median/src/geom_median/torch/main.py
index 86eb4866..fbeacc29 100644
--- a/sae_training/geom_median/src/geom_median/torch/main.py
+++ b/sae_training/geom_median/src/geom_median/torch/main.py
@@ -1,8 +1,8 @@
 import torch
 
+from . import utils
 from .weiszfeld_array import geometric_median_array, geometric_median_per_component
 from .weiszfeld_list_of_array import geometric_median_list_of_array
-from . import utils
 
 
 def compute_geometric_median(
diff --git a/sae_training/geom_median/src/geom_median/torch/utils.py b/sae_training/geom_median/src/geom_median/torch/utils.py
index 02b69741..4b5e0950 100644
--- a/sae_training/geom_median/src/geom_median/torch/utils.py
+++ b/sae_training/geom_median/src/geom_median/torch/utils.py
@@ -1,4 +1,5 @@
 from itertools import zip_longest
+
 import torch
 
 
diff --git a/sae_training/geom_median/src/geom_median/torch/weiszfeld_array.py b/sae_training/geom_median/src/geom_median/torch/weiszfeld_array.py
index ae337a02..b0310866 100644
--- a/sae_training/geom_median/src/geom_median/torch/weiszfeld_array.py
+++ b/sae_training/geom_median/src/geom_median/torch/weiszfeld_array.py
@@ -48,9 +48,11 @@ def geometric_median_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
     return SimpleNamespace(
         median=median,
         new_weights=new_weights,
-        termination="function value converged within tolerance"
-        if early_termination
-        else "maximum iterations reached",
+        termination=(
+            "function value converged within tolerance"
+            if early_termination
+            else "maximum iterations reached"
+        ),
         logs=logs,
     )
diff --git a/sae_training/geom_median/src/geom_median/torch/weiszfeld_list_of_array.py b/sae_training/geom_median/src/geom_median/torch/weiszfeld_list_of_array.py
index 2920a480..5148b0e4 100644
--- a/sae_training/geom_median/src/geom_median/torch/weiszfeld_list_of_array.py
+++ b/sae_training/geom_median/src/geom_median/torch/weiszfeld_list_of_array.py
@@ -1,6 +1,7 @@
+from types import SimpleNamespace
+
 import numpy as np
 import torch
-from types import SimpleNamespace
 
 
 def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=1e-20):
@@ -43,9 +44,11 @@ def geometric_median_list_of_array(points, weights, eps=1e-6, maxiter=100, ftol=
     return SimpleNamespace(
         median=median,
         new_weights=new_weights,
-        termination="function value converged within tolerance"
-        if early_termination
-        else "maximum iterations reached",
+        termination=(
+            "function value converged within tolerance"
+            if early_termination
+            else "maximum iterations reached"
+        ),
         logs=logs,
     )
diff --git a/sae_training/lm_runner.py b/sae_training/lm_runner.py
index e6c382b2..ec038de8 100644
--- a/sae_training/lm_runner.py
+++ b/sae_training/lm_runner.py
@@ -1,7 +1,6 @@
 import os
 
 import torch
-
 import wandb
 
 # from sae_training.activation_store import ActivationStore
diff --git a/sae_training/optim.py b/sae_training/optim.py
index 27ea8084..d808f562 100644
--- a/sae_training/optim.py
+++ b/sae_training/optim.py
@@ -1,6 +1,7 @@
 """
 Took the LR scheduler from my previous work: https://github.com/jbloomAus/DecisionTransformerInterpretability/blob/ee55df35cdb92e81d689c72fb9dd5a7252893363/src/decision_transformer/utils.py#L425
 """
+
 import math
 from typing import Optional
diff --git a/sae_training/toy_model_runner.py b/sae_training/toy_model_runner.py
index 2ac4a36c..7726a24f 100644
--- a/sae_training/toy_model_runner.py
+++ b/sae_training/toy_model_runner.py
@@ -2,9 +2,9 @@
 
 import einops
 import torch
+import wandb
 from transformer_lens import HookedTransformer
 
-import wandb
 from sae_training.sparse_autoencoder import SparseAutoencoder
 from sae_training.toy_models import Config as ToyConfig
 from sae_training.toy_models import Model as ToyModel
diff --git a/sae_training/toy_models.py b/sae_training/toy_models.py
index 98f38439..32671de4 100644
--- a/sae_training/toy_models.py
+++ b/sae_training/toy_models.py
@@ -4,6 +4,7 @@
 https://github.com/callummcdougall/sae-exercises-mats?fbclid=IwAR3qYAELbyD_x5IAYN4yCDFQzxXHeuH6CwMi_E7g4Qg6G1QXRNAYabQ4xGs
 """
+
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Tuple, Union
diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py
index 4f53af31..b6bc3848 100644
--- a/sae_training/train_sae_on_language_model.py
+++ b/sae_training/train_sae_on_language_model.py
@@ -1,9 +1,9 @@
 import torch
+import wandb
 from torch.optim import Adam
 from tqdm import tqdm
 from transformer_lens import HookedTransformer
 
-import wandb
 from sae_training.activations_store import ActivationsStore
 from sae_training.evals import run_evals
 from sae_training.optim import get_scheduler
diff --git a/sae_training/train_sae_on_toy_model.py b/sae_training/train_sae_on_toy_model.py
index 522b151d..5554d4db 100644
--- a/sae_training/train_sae_on_toy_model.py
+++ b/sae_training/train_sae_on_toy_model.py
@@ -1,9 +1,9 @@
 import einops
 import torch
+import wandb
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
-import wandb
 from sae_training.sparse_autoencoder import SparseAutoencoder
 from sae_training.toy_models import Model as ToyModel
diff --git a/scripts/generate_dashboards.py b/scripts/generate_dashboards.py
index 4e8e450e..7da5dc02 100644
--- a/scripts/generate_dashboards.py
+++ b/scripts/generate_dashboards.py
@@ -14,10 +14,10 @@ import plotly
 import plotly.express as px
 import torch
+import wandb
 from torch.nn.functional import cosine_similarity
 from tqdm import tqdm
 
-import wandb
 from sae_analysis.visualizer.data_fns import get_feature_data
 from sae_training.utils import LMSparseAutoencoderSessionloader
 
@@ -128,9 +128,7 @@ def init_sae_session(self):
             self.activation_store,
         ) = LMSparseAutoencoderSessionloader.load_session_from_pretrained(self.sae_path)
 
-    def get_tokens(
-        self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6
-    ):
+    def get_tokens(self, n_batches_to_sample_from=2**12, n_prompts_to_select=4096 * 6):
         """
         Get the tokens needed for dashboard generation.
         """
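
Note: a minimal sketch of the day-to-day workflow this patch sets up, assuming Poetry is installed and on PATH. The commands simply mirror the CI steps and make targets introduced above; nothing here is new to the patch:

    poetry install                     # resolve and install main + dev dependencies into a managed virtualenv
    poetry run flake8 .                # lint, as the CI job now does
    poetry run black . --check        # verify formatting without modifying files
    poetry run isort . --check-only   # verify import ordering
    make unit-test                     # pytest with coverage, wrapped in `poetry run` by the makefile
    make format                        # apply black + isort fixes locally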