From cc9899c7c7ecfa734bdb9260b80878092a0b3d4f Mon Sep 17 00:00:00 2001
From: David Chanin
Date: Sat, 2 Mar 2024 23:41:27 +0000
Subject: [PATCH] chore: adding more tests to ActivationsStore + light refactoring

---
 pyproject.toml                       |   1 +
 sae_training/activations_store.py    |  99 ++++++-------
 tests/unit/conftest.py               |   9 ++
 tests/unit/helpers.py                |  50 ++++++++
 tests/unit/test_activations_store.py | 174 ++++++++++++---------------
 5 files changed, 190 insertions(+), 143 deletions(-)
 create mode 100644 tests/unit/conftest.py
 create mode 100644 tests/unit/helpers.py

diff --git a/pyproject.toml b/pyproject.toml
index e7388ad8..abbd631b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ reportUnnecessaryIsInstance = "none"
 reportUnnecessaryComparison = "none"
 reportConstantRedefinition = "none"
 reportUnknownLambdaType = "none"
+reportPrivateUsage = "none"
 
 [build-system]
 requires = ["poetry-core"]
diff --git a/sae_training/activations_store.py b/sae_training/activations_store.py
index 0af9cfd6..82f18e93 100644
--- a/sae_training/activations_store.py
+++ b/sae_training/activations_store.py
@@ -2,10 +2,18 @@
 from typing import Any, Iterator, cast
 
 import torch
-from datasets import load_dataset
+from datasets import (
+    Dataset,
+    DatasetDict,
+    IterableDataset,
+    IterableDatasetDict,
+    load_dataset,
+)
 from torch.utils.data import DataLoader
 from transformer_lens import HookedTransformer
 
+HfDataset = DatasetDict | Dataset | IterableDatasetDict | IterableDataset
+
 
 class ActivationsStore:
     """
@@ -17,11 +25,14 @@ def __init__(
         self,
         cfg: Any,
         model: HookedTransformer,
+        dataset: HfDataset | None = None,
         create_dataloader: bool = True,
     ):
         self.cfg = cfg
         self.model = model
-        self.dataset = load_dataset(cfg.dataset_path, split="train", streaming=True)
+        self.dataset = dataset or load_dataset(
+            cfg.dataset_path, split="train", streaming=True
+        )
         self.iterable_dataset = iter(self.dataset)
 
         # Check if dataset is tokenized
@@ -78,23 +89,7 @@ def get_batch_tokens(self):
 
         # pbar = tqdm(total=batch_size, desc="Filling batches")
         while batch_tokens.shape[0] < batch_size:
-            if not self.cfg.is_dataset_tokenized:
-                s = next(self.iterable_dataset)["text"]
-                tokens = self.model.to_tokens(
-                    s,
-                    truncate=True,
-                    move_to_device=True,
-                ).squeeze(0)
-                assert (
-                    len(tokens.shape) == 1
-                ), f"tokens.shape should be 1D but was {tokens.shape}"
-            else:
-                tokens = torch.tensor(
-                    next(self.iterable_dataset)["tokens"],
-                    dtype=torch.long,
-                    device=device,
-                    requires_grad=False,
-                )
+            tokens = self._get_next_dataset_tokens()
             token_len = tokens.shape[0]
 
             # TODO: Fix this so that we are limiting how many tokens we get from the same context.
@@ -120,16 +115,18 @@ def get_batch_tokens(self):
 
                     # Remove used part, add BOS
                     tokens = tokens[space_left:]
-                    tokens = torch.cat(
-                        (
-                            bos_token_id_tensor,
-                            tokens,
-                        ),
-                        dim=0,
-                    )
-
                     token_len -= space_left
-                    token_len += 1
+
+                    # only add BOS if it's not already the first token
+                    if tokens[0] != bos_token_id_tensor:
+                        tokens = torch.cat(
+                            (
+                                bos_token_id_tensor,
+                                tokens,
+                            ),
+                            dim=0,
+                        )
+                        token_len += 1
                     current_length = context_size
 
                 # If a batch is full, concatenate and move to next batch
@@ -145,7 +142,7 @@ def get_batch_tokens(self):
             # pbar.refresh()
         return batch_tokens[:batch_size]
 
-    def get_activations(self, batch_tokens: torch.Tensor, get_loss: bool = False):
+    def get_activations(self, batch_tokens: torch.Tensor):
         """
         Returns activations of shape (batches, context, num_layers, d_in)
         """
@@ -156,24 +153,15 @@ def get_activations(self, batch_tokens: torch.Tensor, get_loss: bool = False):
         )
         act_names = [self.cfg.hook_point.format(layer=layer) for layer in layers]
         hook_point_max_layer = max(layers)
+        layerwise_activations = self.model.run_with_cache(
+            batch_tokens,
+            names_filter=act_names,
+            stop_at_layer=hook_point_max_layer + 1,
+        )[1]
+        activations_list = [layerwise_activations[act_name] for act_name in act_names]
         if self.cfg.hook_point_head_index is not None:
-            layerwise_activations = self.model.run_with_cache(
-                batch_tokens,
-                names_filter=act_names,
-                stop_at_layer=hook_point_max_layer + 1,
-            )[1]
-            activations_list = [
-                layerwise_activations[act_name][:, :, self.cfg.hook_point_head_index]
-                for act_name in act_names
-            ]
-        else:
-            layerwise_activations = self.model.run_with_cache(
-                batch_tokens,
-                names_filter=act_names,
-                stop_at_layer=hook_point_max_layer + 1,
-            )[1]
             activations_list = [
-                layerwise_activations[act_name] for act_name in act_names
+                act[:, :, self.cfg.hook_point_head_index] for act in activations_list
             ]
 
         # Stack along a new dimension to keep separate layers distinct
@@ -315,3 +303,24 @@ def next_batch(self):
             # If the DataLoader is exhausted, create a new one
             self.dataloader = self.get_data_loader()
             return next(self.dataloader)
+
+    def _get_next_dataset_tokens(self) -> torch.Tensor:
+        device = self.cfg.device
+        if not self.cfg.is_dataset_tokenized:
+            s = next(self.iterable_dataset)["text"]
+            tokens = self.model.to_tokens(
+                s,
+                truncate=True,
+                move_to_device=True,
+            ).squeeze(0)
+            assert (
+                len(tokens.shape) == 1
+            ), f"tokens.shape should be 1D but was {tokens.shape}"
+        else:
+            tokens = torch.tensor(
+                next(self.iterable_dataset)["tokens"],
+                dtype=torch.long,
+                device=device,
+                requires_grad=False,
+            )
+        return tokens
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 00000000..b614bbd5
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,9 @@
+import pytest
+from transformer_lens import HookedTransformer
+
+from tests.unit.helpers import TEST_MODEL
+
+
+@pytest.fixture
+def model():
+    return HookedTransformer.from_pretrained(TEST_MODEL, device="cpu")
diff --git a/tests/unit/helpers.py b/tests/unit/helpers.py
new file mode 100644
index 00000000..8c70d77d
--- /dev/null
+++ b/tests/unit/helpers.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+import torch
+
+from sae_training.config import LanguageModelSAERunnerConfig
+
+TEST_MODEL = "tiny-stories-1M"
+TEST_DATASET = "roneneldan/TinyStories"
+
+
+def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig:
+    """
+    Helper to create a mock instance of LanguageModelSAERunnerConfig.
+    """
+    # Create a mock object with the necessary attributes
+    mock_config = LanguageModelSAERunnerConfig(
+        model_name=TEST_MODEL,
+        hook_point="blocks.0.hook_mlp_out",
+        hook_point_layer=0,
+        hook_point_head_index=None,
+        dataset_path=TEST_DATASET,
+        is_dataset_tokenized=False,
+        use_cached_activations=False,
+        d_in=64,
+        expansion_factor=2,
+        d_sae=64 * 2,
+        l1_coefficient=2e-3,
+        lp_norm=1,
+        lr=2e-4,
+        train_batch_size=2048,
+        context_size=64,
+        feature_sampling_window=50,
+        dead_feature_threshold=1e-7,
+        n_batches_in_buffer=10,
+        total_training_tokens=1_000_000,
+        store_batch_size=32,
+        log_to_wandb=False,
+        wandb_project="test_project",
+        wandb_entity="test_entity",
+        wandb_log_frequency=10,
+        device=torch.device("cpu"),
+        seed=24,
+        checkpoint_path="test/checkpoints",
+        dtype=torch.float32,
+    )
+
+    for key, val in kwargs.items():
+        setattr(mock_config, key, val)
+
+    return mock_config
diff --git a/tests/unit/test_activations_store.py b/tests/unit/test_activations_store.py
index ebed81b1..7ee32c90 100644
--- a/tests/unit/test_activations_store.py
+++ b/tests/unit/test_activations_store.py
@@ -1,109 +1,28 @@
 from collections.abc import Iterable
-from types import SimpleNamespace
+from math import ceil
 from typing import Any
 
 import pytest
 import torch
-from datasets import IterableDataset
+from datasets import Dataset, IterableDataset
 from transformer_lens import HookedTransformer
 
 from sae_training.activations_store import ActivationsStore
+from tests.unit.helpers import build_sae_cfg
 
 
-TEST_MODEL = "tiny-stories-1M"
-TEST_DATASET = "roneneldan/TinyStories"
-
-
-@pytest.fixture
-def cfg():
-    """
-    Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig.
-    """
-    # Create a mock object with the necessary attributes
-    mock_config = SimpleNamespace()
-    mock_config.model_name = TEST_MODEL
-    mock_config.hook_point = "blocks.0.hook_mlp_out"
-    mock_config.hook_point_layer = 1
-    mock_config.dataset_path = TEST_DATASET
-    mock_config.is_dataset_tokenized = False
-    mock_config.d_in = 64
-    mock_config.expansion_factor = 2
-    mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor
-    mock_config.l1_coefficient = 2e-3
-    mock_config.lr = 2e-4
-    mock_config.train_batch_size = 32
-    mock_config.context_size = 16
-    mock_config.use_cached_activations = False
-    mock_config.hook_point_head_index = None
-    mock_config.lp_norm = 1
-
-    mock_config.feature_sampling_method = None
-    mock_config.feature_sampling_window = 50
-    mock_config.feature_reinit_scale = 0.1
-    mock_config.dead_feature_threshold = 1e-7
-
-    mock_config.n_batches_in_buffer = 4
-    mock_config.total_training_tokens = 1_000_000
-    mock_config.store_batch_size = 32
-
-    mock_config.log_to_wandb = False
-    mock_config.wandb_project = "test_project"
-    mock_config.wandb_entity = "test_entity"
-    mock_config.wandb_log_frequency = 10
-    mock_config.device = torch.device("cpu")
-    mock_config.seed = 24
-    mock_config.checkpoint_path = "test/checkpoints"
-    mock_config.dtype = torch.float32
-
-    return mock_config
+def tokenize_with_bos(model: HookedTransformer, text: str) -> list[int]:
+    assert model.tokenizer is not None
+    assert model.tokenizer.bos_token_id is not None
+    return [model.tokenizer.bos_token_id] + model.tokenizer.encode(text)
 
 
 @pytest.fixture
-def cfg_head_hook():
+def cfg():
     """
     Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig.
     """
-    # Create a mock object with the necessary attributes
-    mock_config = SimpleNamespace()
-    mock_config.model_name = TEST_MODEL
-    mock_config.hook_point = "blocks.0.attn.hook_q"
-    mock_config.hook_point_layer = 1
-    mock_config.hook_point_head_index = 2
-    mock_config.dataset_path = TEST_DATASET
-    mock_config.is_dataset_tokenized = False
-    mock_config.d_in = 4
-    mock_config.expansion_factor = 2
-    mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor
-    mock_config.l1_coefficient = 2e-3
-    mock_config.lr = 2e-4
-    mock_config.train_batch_size = 32
-    mock_config.context_size = 128
-    mock_config.use_cached_activations = False
-    mock_config.hook_point_head_index = 0
-
-    mock_config.feature_sampling_method = None
-    mock_config.feature_sampling_window = 50
-    mock_config.feature_reinit_scale = 0.1
-    mock_config.dead_feature_threshold = 1e-7
-
-    mock_config.n_batches_in_buffer = 4
-    mock_config.total_training_tokens = 1_000_000
-    mock_config.store_batch_size = 32
-
-    mock_config.log_to_wandb = False
-    mock_config.wandb_project = "test_project"
-    mock_config.wandb_entity = "test_entity"
-    mock_config.wandb_log_frequency = 10
-    mock_config.device = torch.device("cpu")
-    mock_config.seed = 24
-    mock_config.checkpoint_path = "test/checkpoints"
-    mock_config.dtype = torch.float32
-
-    return mock_config
-
-
-@pytest.fixture
-def model():
-    return HookedTransformer.from_pretrained(TEST_MODEL, device="cpu")
+    return build_sae_cfg()
 
 
 @pytest.fixture
@@ -111,11 +30,6 @@ def activation_store(cfg: Any, model: HookedTransformer):
     return ActivationsStore(cfg, model)
 
 
-@pytest.fixture
-def activation_store_head_hook(cfg_head_hook: Any, model: HookedTransformer):
-    return ActivationsStore(cfg_head_hook, model)
-
-
 def test_activations_store__init__(cfg: Any, model: HookedTransformer):
     store = ActivationsStore(cfg, model)
 
@@ -159,9 +73,14 @@ def test_activations_store__get_activations(activation_store: ActivationsStore):
     assert activations.device == cfg.device
 
 
-def test_activations_store__get_activations_head_hook(
-    activation_store_head_hook: ActivationsStore,
-):
+def test_activations_store__get_activations_head_hook(model: HookedTransformer):
+    cfg = build_sae_cfg(
+        hook_point="blocks.0.attn.hook_q",
+        hook_point_head_index=2,
+        hook_point_layer=1,
+        d_in=4,
+    )
+    activation_store_head_hook = ActivationsStore(cfg, model)
     batch = activation_store_head_hook.get_batch_tokens()
     activations = activation_store_head_hook.get_activations(batch)
 
@@ -181,3 +100,62 @@ def test_activations_store__get_buffer(activation_store: ActivationsStore):
 
     assert buffer.shape == (buffer_size_expected, 1, cfg.d_in)
     assert buffer.device == cfg.device
+
+
+# 12 is divisible by the length of "hello world", 11 and 13 are not
+@pytest.mark.parametrize("context_size", [11, 12, 13])
+def test_activations_store__get_batch_tokens__fills_the_context_separated_by_bos(
+    model: HookedTransformer, context_size: int
+):
+    assert model.tokenizer is not None
+    dataset = Dataset.from_list(
+        [
+            {"text": "hello world"},
+        ]
+        * 100
+    )
+    cfg = build_sae_cfg(
+        store_batch_size=2,
+        context_size=context_size,
+    )
+
+    activation_store = ActivationsStore(
+        cfg, model, dataset=dataset, create_dataloader=False
+    )
+    encoded_text = tokenize_with_bos(model, "hello world")
+    tokens = activation_store.get_batch_tokens()
+    assert tokens.shape == (2, context_size)  # batch_size x context_size
+    all_expected_tokens = (encoded_text * ceil(2 * context_size / len(encoded_text)))[
+        : 2 * context_size
+    ]
+    expected_tokens1 = all_expected_tokens[:context_size]
+    expected_tokens2 = all_expected_tokens[context_size:]
+    if expected_tokens2[0] != model.tokenizer.bos_token_id:
+        expected_tokens2 = [model.tokenizer.bos_token_id] + expected_tokens2[:-1]
+    assert tokens[0].tolist() == expected_tokens1
+    assert tokens[1].tolist() == expected_tokens2
+
+
+def test_activations_store__get_next_dataset_tokens__tokenizes_each_example_in_order(
+    cfg: Any, model: HookedTransformer
+):
+    dataset = Dataset.from_list(
+        [
+            {"text": "hello world1"},
+            {"text": "hello world2"},
+            {"text": "hello world3"},
+        ]
+    )
+    activation_store = ActivationsStore(
+        cfg, model, dataset=dataset, create_dataloader=False
+    )
+
+    assert activation_store._get_next_dataset_tokens().tolist() == tokenize_with_bos(
+        model, "hello world1"
+    )
+    assert activation_store._get_next_dataset_tokens().tolist() == tokenize_with_bos(
+        model, "hello world2"
+    )
+    assert activation_store._get_next_dataset_tokens().tolist() == tokenize_with_bos(
+        model, "hello world3"
+    )
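Note for reviewers on how the new pieces fit together: ActivationsStore now accepts an optional `dataset` argument that is used instead of streaming `cfg.dataset_path`, and `create_dataloader=False` skips filling the activation buffer at construction time; the new tests rely on both to stay fast and offline. The snippet below is a minimal usage sketch only, assuming the `tiny-stories-1M` weights can be loaded through transformer_lens and reusing the `build_sae_cfg` test helper added in tests/unit/helpers.py purely for brevity (real training code would construct a LanguageModelSAERunnerConfig directly).

    from datasets import Dataset
    from transformer_lens import HookedTransformer

    from sae_training.activations_store import ActivationsStore
    from tests.unit.helpers import build_sae_cfg

    model = HookedTransformer.from_pretrained("tiny-stories-1M", device="cpu")

    # In-memory dataset injected via the new `dataset` parameter; no Hub streaming needed.
    dataset = Dataset.from_list([{"text": "hello world"}] * 100)
    cfg = build_sae_cfg(store_batch_size=2, context_size=16)

    # create_dataloader=False avoids filling the activation buffer up front,
    # which keeps construction cheap for quick experiments and tests.
    store = ActivationsStore(cfg, model, dataset=dataset, create_dataloader=False)

    tokens = store.get_batch_tokens()     # shape: (store_batch_size, context_size)
    acts = store.get_activations(tokens)  # shape: (batches, context, num_layers, d_in)

Note also that get_batch_tokens now only prepends a BOS token when a context does not already start with one (per the change in get_batch_tokens above), which the parametrized BOS test exercises for context sizes 11, 12, and 13.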