Merge pull request #20 from chanind/activations_store_tests
chore: adding more tests to ActivationsStore + light refactoring
jbloomAus authored Mar 21, 2024
2 parents e814054 + 4896d0a commit 69dcf8e
Showing 5 changed files with 200 additions and 46 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -48,6 +48,7 @@ reportUnnecessaryIsInstance = "none"
reportUnnecessaryComparison = "none"
reportConstantRedefinition = "none"
reportUnknownLambdaType = "none"
reportPrivateUsage = "none"

[build-system]
requires = ["poetry-core"]
99 changes: 54 additions & 45 deletions sae_training/activations_store.py
@@ -2,10 +2,18 @@
from typing import Any, Iterator, cast

import torch
-from datasets import load_dataset
+from datasets import (
+    Dataset,
+    DatasetDict,
+    IterableDataset,
+    IterableDatasetDict,
+    load_dataset,
+)
from torch.utils.data import DataLoader
from transformer_lens import HookedTransformer

HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset


class ActivationsStore:
"""
@@ -17,11 +25,14 @@ def __init__(
self,
cfg: Any,
model: HookedTransformer,
dataset: HfDataset | None = None,
create_dataloader: bool = True,
):
self.cfg = cfg
self.model = model
-        self.dataset = load_dataset(cfg.dataset_path, split="train", streaming=True)
+        self.dataset = dataset or load_dataset(
+            cfg.dataset_path, split="train", streaming=True
+        )
self.iterable_dataset = iter(self.dataset)

# Check if dataset is tokenized
@@ -89,23 +100,7 @@ def get_batch_tokens(self):

# pbar = tqdm(total=batch_size, desc="Filling batches")
while batch_tokens.shape[0] < batch_size:
-            if not self.cfg.is_dataset_tokenized:
-                s = next(self.iterable_dataset)[self.tokens_column]
-                tokens = self.model.to_tokens(
-                    s,
-                    truncate=True,
-                    move_to_device=True,
-                ).squeeze(0)
-                assert (
-                    len(tokens.shape) == 1
-                ), f"tokens.shape should be 1D but was {tokens.shape}"
-            else:
-                tokens = torch.tensor(
-                    next(self.iterable_dataset)[self.tokens_column],
-                    dtype=torch.long,
-                    device=device,
-                    requires_grad=False,
-                )
+            tokens = self._get_next_dataset_tokens()
token_len = tokens.shape[0]

# TODO: Fix this so that we are limiting how many tokens we get from the same context.
@@ -131,16 +126,18 @@

# Remove used part, add BOS
tokens = tokens[space_left:]
-                tokens = torch.cat(
-                    (
-                        bos_token_id_tensor,
-                        tokens,
-                    ),
-                    dim=0,
-                )

                token_len -= space_left
-                token_len += 1

+                # only add BOS if it's not already the first token
+                if tokens[0] != bos_token_id_tensor:
+                    tokens = torch.cat(
+                        (
+                            bos_token_id_tensor,
+                            tokens,
+                        ),
+                        dim=0,
+                    )
+                    token_len += 1
current_length = context_size

# If a batch is full, concatenate and move to next batch
@@ -156,7 +153,7 @@
# pbar.refresh()
return batch_tokens[:batch_size]

-    def get_activations(self, batch_tokens: torch.Tensor, get_loss: bool = False):
+    def get_activations(self, batch_tokens: torch.Tensor):
"""
Returns activations of shape (batches, context, num_layers, d_in)
"""
@@ -167,24 +164,15 @@ def get_activations(self, batch_tokens: torch.Tensor, get_loss: bool = False):
)
act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers]
hook_point_max_layer = max(layers)
+        layerwise_activations = self.model.run_with_cache(
+            batch_tokens,
+            names_filter=act_names,
+            stop_at_layer=hook_point_max_layer + 1,
+        )[1]
+        activations_list = [layerwise_activations[act_name] for act_name in act_names]
        if self.cfg.hook_point_head_index is not None:
-            layerwise_activations = self.model.run_with_cache(
-                batch_tokens,
-                names_filter=act_names,
-                stop_at_layer=hook_point_max_layer + 1,
-            )[1]
-            activations_list = [
-                layerwise_activations[act_name][:, :, self.cfg.hook_point_head_index]
-                for act_name in act_names
-            ]
-        else:
-            layerwise_activations = self.model.run_with_cache(
-                batch_tokens,
-                names_filter=act_names,
-                stop_at_layer=hook_point_max_layer + 1,
-            )[1]
-            activations_list = [
-                layerwise_activations[act_name] for act_name in act_names
-            ]
+            activations_list = [
+                act[:, :, self.cfg.hook_point_head_index] for act in activations_list
+            ]

# Stack along a new dimension to keep separate layers distinct
@@ -326,3 +314,24 @@ def next_batch(self):
# If the DataLoader is exhausted, create a new one
self.dataloader = self.get_data_loader()
return next(self.dataloader)

def _get_next_dataset_tokens(self) -> torch.Tensor:
device = self.cfg.device
if not self.cfg.is_dataset_tokenized:
s = next(self.iterable_dataset)[self.tokens_column]
tokens = self.model.to_tokens(
s,
truncate=True,
move_to_device=True,
).squeeze(0)
assert (
len(tokens.shape) == 1
), f"tokens.shape should be 1D but was {tokens.shape}"
else:
tokens = torch.tensor(
next(self.iterable_dataset)[self.tokens_column],
dtype=torch.long,
device=device,
requires_grad=False,
)
return tokens
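
For context, here is a minimal usage sketch (not part of this diff) of how the refactored ActivationsStore can now be driven with an in-memory dataset instead of a streamed one, mirroring the new tests further down; it assumes the tiny-stories-1M model and the build_sae_cfg helper introduced in tests/unit/helpers.py below.

from datasets import Dataset
from transformer_lens import HookedTransformer

from sae_training.activations_store import ActivationsStore
from tests.unit.helpers import build_sae_cfg

# Sketch only: exercises the new `dataset` and `create_dataloader` parameters.
model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")
cfg = build_sae_cfg(context_size=64, store_batch_size=2)

# An in-memory dataset keeps the store from streaming data from the Hub.
dataset = Dataset.from_list([{"text": "hello world"}] * 100)
store = ActivationsStore(cfg, model, dataset=dataset, create_dataloader=False)

tokens = store.get_batch_tokens()      # shape: (store_batch_size, context_size)
acts = store.get_activations(tokens)   # shape: (batch, context, num_layers, d_in)
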
9 changes: 9 additions & 0 deletions tests/unit/conftest.py
@@ -0,0 +1,9 @@
import pytest
from transformer_lens import HookedTransformer

from tests.unit.helpers import TINYSTORIES_MODEL


@pytest.fixture
def ts_model():
return HookedTransformer.from_pretrained(TINYSTORIES_MODEL, device="cpu")
50 changes: 50 additions & 0 deletions tests/unit/helpers.py
@@ -0,0 +1,50 @@
from typing import Any

import torch

from sae_training.config import LanguageModelSAERunnerConfig

TINYSTORIES_MODEL = "tiny-stories-1M"
TINYSTORIES_DATASET = "roneneldan/TinyStories"


def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
"""
Helper to create a mock instance of LanguageModelSAERunnerConfig.
"""
# Create a mock object with the necessary attributes
mock_config = LanguageModelSAERunnerConfig(
model_name=TINYSTORIES_MODEL,
hook_point="blocks.0.hook_mlp_out",
hook_point_layer=0,
hook_point_head_index=None,
dataset_path=TINYSTORIES_DATASET,
is_dataset_tokenized=False,
use_cached_activations=False,
d_in=64,
expansion_factor=2,
d_sae=64 * 2,
l1_coefficient=2e-3,
lp_norm=1,
lr=2e-4,
train_batch_size=2048,
context_size=64,
feature_sampling_window=50,
dead_feature_threshold=1e-7,
n_batches_in_buffer=10,
total_training_tokens=1_000_000,
store_batch_size=32,
log_to_wandb=False,
wandb_project="test_project",
wandb_entity="test_entity",
wandb_log_frequency=10,
device=torch.device("cpu"),
seed=24,
checkpoint_path="test/checkpoints",
dtype=torch.float32,
)

for key, val in kwargs.items():
setattr(mock_config, key, val)

return mock_config
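
Because the kwargs at the end are applied with plain setattr calls on an already-constructed config, a test can pin only the fields it cares about and inherit the TinyStories defaults for everything else. A rough usage sketch (not part of the diff):

from tests.unit.helpers import build_sae_cfg

# Override only what the test needs; all other fields keep the defaults above.
cfg = build_sae_cfg(context_size=128, store_batch_size=4)
assert cfg.context_size == 128
assert cfg.model_name == "tiny-stories-1M"

Worth noting: since the overrides land after construction, anything the config computes for itself at init time is not recomputed from them.
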
87 changes: 86 additions & 1 deletion tests/unit/test_activations_store.py
@@ -1,12 +1,20 @@
from collections.abc import Iterable
from math import ceil
from types import SimpleNamespace

import pytest
import torch
-from datasets import IterableDataset
+from datasets import Dataset, IterableDataset
from transformer_lens import HookedTransformer

from sae_training.activations_store import ActivationsStore
from tests.unit.helpers import build_sae_cfg


def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
assert model.tokenizer is not None
assert model.tokenizer.bos_token_id is not None
return [model.tokenizer.bos_token_id] + model.tokenizer.encode(text)


# Define a new fixture for different configurations
@@ -163,6 +171,23 @@ def test_activations_store__get_activations(activation_store: ActivationsStore):
assert activations.device == cfg.device


def test_activations_store__get_activations_head_hook(ts_model: HookedTransformer):
cfg = build_sae_cfg(
hook_point="blocks.0.attn.hook_q",
hook_point_head_index=2,
hook_point_layer=1,
d_in=4,
)
activation_store_head_hook = ActivationsStore(cfg, ts_model)
batch = activation_store_head_hook.get_batch_tokens()
activations = activation_store_head_hook.get_activations(batch)

cfg = activation_store_head_hook.cfg
assert isinstance(activations, torch.Tensor)
assert activations.shape == (cfg.store_batch_size, cfg.context_size, 1, cfg.d_in)
assert activations.device == cfg.device


def test_activations_store__get_buffer(activation_store: ActivationsStore):
n_batches_in_buffer = 3
buffer = activation_store.get_buffer(n_batches_in_buffer)
@@ -173,3 +198,63 @@ def test_activations_store__get_buffer(activation_store: ActivationsStore):

assert buffer.shape == (buffer_size_expected, 1, cfg.d_in)
assert buffer.device == cfg.device


# 12 is divisible by the length of "hello world", 11 and 13 are not
@pytest.mark.parametrize("context_size", [11, 12, 13])
def test_activations_store__get_batch_tokens__fills_the_context_separated_by_bos(
ts_model: HookedTransformer, context_size: int
):
assert ts_model.tokenizer is not None
dataset = Dataset.from_list(
[
{"text": "hello world"},
]
* 100
)
cfg = build_sae_cfg(
store_batch_size=2,
context_size=context_size,
)

activation_store = ActivationsStore(
cfg, ts_model, dataset=dataset, create_dataloader=False
)
encoded_text = tokenize_with_bos(ts_model, "hello world")
tokens = activation_store.get_batch_tokens()
assert tokens.shape == (2, context_size) # batch_size x context_size
all_expected_tokens = (encoded_text * ceil(2 * context_size / len(encoded_text)))[
: 2 * context_size
]
expected_tokens1 = all_expected_tokens[:context_size]
expected_tokens2 = all_expected_tokens[context_size:]
if expected_tokens2[0] != ts_model.tokenizer.bos_token_id:
expected_tokens2 = [ts_model.tokenizer.bos_token_id] + expected_tokens2[:-1]
assert tokens[0].tolist() == expected_tokens1
assert tokens[1].tolist() == expected_tokens2


def test_activations_store__get_next_dataset_tokens__tokenizes_each_example_in_order(
ts_model: HookedTransformer,
):
cfg = build_sae_cfg()
dataset = Dataset.from_list(
[
{"text": "hello world1"},
{"text": "hello world2"},
{"text": "hello world3"},
]
)
activation_store = ActivationsStore(
cfg, ts_model, dataset=dataset, create_dataloader=False
)

assert activation_store._get_next_dataset_tokens().tolist() == tokenize_with_bos(
ts_model, "hello world1"
)
assert activation_store._get_next_dataset_tokens().tolist() == tokenize_with_bos(
ts_model, "hello world2"
)
assert activation_store._get_next_dataset_tokens().tolist() == tokenize_with_bos(
ts_model, "hello world3"
)
