Commit

chore: adding more tests to ActivationsStore + light refactoring
chanind committed Mar 6, 2024
1 parent b2478c1 commit cc9899c
Showing 5 changed files with 190 additions and 143 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -46,6 +46,7 @@ reportUnnecessaryIsInstance = "none"
reportUnnecessaryComparison = "none"
reportConstantRedefinition = "none"
reportUnknownLambdaType = "none"
reportPrivateUsage = "none"

[build-system]
requires = ["poetry-core"]
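The new reportPrivateUsage = "none" entry relaxes pyright so that the test suite can call leading-underscore helpers such as the _get_next_dataset_tokens method added below. A minimal sketch of the kind of test this setting permits, using the model fixture and build_sae_cfg helper introduced later in this commit; the test name and assertion are illustrative, since the actual test file is not rendered on this page:

from sae_training.activations_store import ActivationsStore
from tests.unit.helpers import build_sae_cfg


def test_get_next_dataset_tokens_returns_1d_tokens(model):
    cfg = build_sae_cfg()
    # create_dataloader=False keeps construction cheap; we only exercise token fetching.
    store = ActivationsStore(cfg, model, create_dataloader=False)
    # Calling a private method from a test is exactly what reportPrivateUsage would normally flag.
    tokens = store._get_next_dataset_tokens()
    assert tokens.ndim == 1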
99 changes: 54 additions & 45 deletions sae_training/activations_store.py
@@ -2,10 +2,18 @@
from typing import Any, Iterator, cast

import torch
from datasets import load_dataset
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer

HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset


class ActivationsStore:
"""
@@ -17,11 +25,14 @@ def __init__(
self,
cfg: Any,
model: HookedTransformer,
dataset: HfDataset | None = None,
create_dataloader: bool = True,
):
self.cfg = cfg
self.model = model
self.dataset = load_dataset(cfg.dataset_path, split="train", streaming=True)
self.dataset = dataset or load_dataset(
cfg.dataset_path, split="train", streaming=True
)
self.iterable_dataset = iter(self.dataset)

# Check if dataset is tokenized
@@ -78,23 +89,7 @@ def get_batch_tokens(self):

# pbar = tqdm(total=batch_size, desc="Filling batches")
while batch_tokens.shape[0] < batch_size:
if not self.cfg.is_dataset_tokenized:
s = next(self.iterable_dataset)["text"]
tokens = self.model.to_tokens(
s,
truncate=True,
move_to_device=True,
).squeeze(0)
assert (
len(tokens.shape) == 1
), f"tokens.shape should be 1D but was {tokens.shape}"
else:
tokens = torch.tensor(
next(self.iterable_dataset)["tokens"],
dtype=torch.long,
device=device,
requires_grad=False,
)
tokens = self._get_next_dataset_tokens()
token_len = tokens.shape[0]

# TODO: Fix this so that we are limiting how many tokens we get from the same context.
@@ -120,16 +115,18 @@

# Remove used part, add BOS
tokens = tokens[space_left:]
tokens = torch.cat(
(
bos_token_id_tensor,
tokens,
),
dim=0,
)

token_len -= space_left
token_len += 1

# only add BOS if it's not already the first token
if tokens[0] != bos_token_id_tensor:
tokens = torch.cat(
(
bos_token_id_tensor,
tokens,
),
dim=0,
)
token_len += 1
current_length = context_size

# If a batch is full, concatenate and move to next batch
@@ -145,7 +142,7 @@
# pbar.refresh()
return batch_tokens[:batch_size]

def get_activations(self, batch_tokens: torch.Tensor, get_loss: bool = False):
def get_activations(self, batch_tokens: torch.Tensor):
"""
Returns activations of shape (batches, context, num_layers, d_in)
"""
@@ -156,24 +153,15 @@ def get_activations(self, batch_tokens: torch.Tensor, get_loss: bool = False):
)
act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers]
hook_point_max_layer = max(layers)
layerwise_activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
)[1]
activations_list = [layerwise_activations[act_name] for act_name in act_names]
if self.cfg.hook_point_head_index is not None:
layerwise_activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
)[1]
activations_list = [
layerwise_activations[act_name][:, :, self.cfg.hook_point_head_index]
for act_name in act_names
]
else:
layerwise_activations = self.model.run_with_cache(
batch_tokens,
names_filter=act_names,
stop_at_layer=hook_point_max_layer + 1,
)[1]
activations_list = [
layerwise_activations[act_name] for act_name in act_names
act[:, :, self.cfg.hook_point_head_index] for act in activations_list
]

# Stack along a new dimension to keep separate layers distinct
@@ -315,3 +303,24 @@ def next_batch(self):
# If the DataLoader is exhausted, create a new one
self.dataloader = self.get_data_loader()
return next(self.dataloader)

def _get_next_dataset_tokens(self) -> torch.Tensor:
device = self.cfg.device
if not self.cfg.is_dataset_tokenized:
s = next(self.iterable_dataset)["text"]
tokens = self.model.to_tokens(
s,
truncate=True,
move_to_device=True,
).squeeze(0)
assert (
len(tokens.shape) == 1
), f"tokens.shape should be 1D but was {tokens.shape}"
else:
tokens = torch.tensor(
next(self.iterable_dataset)["tokens"],
dtype=torch.long,
device=device,
requires_grad=False,
)
return tokens
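Because the constructor now accepts an optional dataset argument, tests can pass a small in-memory dataset instead of streaming cfg.dataset_path from the Hub. A minimal sketch under that assumption, using the TEST_MODEL constant and build_sae_cfg helper shown further down; the real test contents are not visible on this page:

import torch
from datasets import Dataset
from transformer_lens import HookedTransformer

from sae_training.activations_store import ActivationsStore
from tests.unit.helpers import TEST_MODEL, build_sae_cfg

# A tiny untokenized dataset with the "text" column the store reads when
# cfg.is_dataset_tokenized is False.
dataset = Dataset.from_list([{"text": "hello world"}] * 64)

# Small context and batch sizes so the toy dataset supplies enough tokens.
cfg = build_sae_cfg(is_dataset_tokenized=False, context_size=8, store_batch_size=2)
model = HookedTransformer.from_pretrained(TEST_MODEL, device="cpu")

store = ActivationsStore(cfg, model, dataset=dataset, create_dataloader=False)
batch = store.get_batch_tokens()
assert batch.shape == (cfg.store_batch_size, cfg.context_size)
assert batch.dtype == torch.long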
9 changes: 9 additions & 0 deletions tests/unit/conftest.py
@@ -0,0 +1,9 @@
import pytest
from transformer_lens import HookedTransformer

from tests.unit.helpers import TEST_MODEL


@pytest.fixture
def model():
return HookedTransformer.from_pretrained(TEST_MODEL, device="cpu")
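pytest discovers fixtures in conftest.py automatically, so any test under tests/unit can request the loaded model simply by naming a parameter model. A tiny illustrative example (not the actual test file); the assertion assumes tiny-stories-1M's hidden size of 64, which is also what d_in=64 in the helper below relies on:

def test_model_fixture_loads_tiny_stories(model):
    # pytest injects the HookedTransformer returned by the fixture above.
    assert model.cfg.d_model == 64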
50 changes: 50 additions & 0 deletions tests/unit/helpers.py
@@ -0,0 +1,50 @@
from typing import Any

import torch

from sae_training.config import LanguageModelSAERunnerConfig

TEST_MODEL = "tiny-stories-1M"
TEST_DATASET = "roneneldan/TinyStories"


def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
"""
Helper to create a mock instance of LanguageModelSAERunnerConfig.
"""
# Create a mock object with the necessary attributes
mock_config = LanguageModelSAERunnerConfig(
model_name=TEST_MODEL,
hook_point="blocks.0.hook_mlp_out",
hook_point_layer=0,
hook_point_head_index=None,
dataset_path=TEST_DATASET,
is_dataset_tokenized=False,
use_cached_activations=False,
d_in=64,
expansion_factor=2,
d_sae=64 * 2,
l1_coefficient=2e-3,
lp_norm=1,
lr=2e-4,
train_batch_size=2048,
context_size=64,
feature_sampling_window=50,
dead_feature_threshold=1e-7,
n_batches_in_buffer=10,
total_training_tokens=1_000_000,
store_batch_size=32,
log_to_wandb=False,
wandb_project="test_project",
wandb_entity="test_entity",
wandb_log_frequency=10,
device=torch.device("cpu"),
seed=24,
checkpoint_path="test/checkpoints",
dtype=torch.float32,
)

for key, val in kwargs.items():
setattr(mock_config, key, val)

return mock_config
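Because build_sae_cfg applies the keyword overrides with setattr after the config is constructed, individual tests can change only the fields they care about. A short usage sketch (the override values are illustrative):

# Override just the fields a test needs; everything else keeps the defaults above.
cfg = build_sae_cfg(
    context_size=8,
    store_batch_size=2,
    hook_point="blocks.0.attn.hook_q",
    hook_point_head_index=2,
)
assert cfg.context_size == 8
assert cfg.hook_point_head_index == 2
# Note: setattr does not recompute derived fields, so a test that changes
# expansion_factor should also pass a matching d_sae explicitly.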
(The diff for the fifth changed file, presumably the new ActivationsStore unit tests, is not rendered on this page.)
