feat: rename batch sizes to give informative units (#133)
BREAKING CHANGE: renamed batch sizing config params

* renaming batch sizes to give units

* changes in notebooks

* missed one!

---------

Co-authored-by: David Chanin <[email protected]>
tomMcGrath and chanind authored May 10, 2024
1 parent 007141e commit cc78e27
Showing 21 changed files with 81 additions and 74 deletions.
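For downstream users hit by the breaking change, here is a minimal migration sketch (not part of the commit; the parameter names are taken from the diffs below, and the migrate_config helper is hypothetical):

# Illustrative mapping of the renamed batch-size parameters and the unit each value is counted in.
RENAMED_PARAMS = {
    "store_batch_size": "store_batch_size_prompts",  # prompts per forward pass when filling the activation buffer
    "train_batch_size": "train_batch_size_tokens",   # activation tokens consumed per SAE training step
    "n_eval_seqs": "eval_batch_size_prompts",        # prompts per evaluation batch
}

def migrate_config(old_cfg: dict) -> dict:
    """Rename any old-style keys in a plain config dict to the new names."""
    return {RENAMED_PARAMS.get(key, key): value for key, value in old_cfg.items()}

print(migrate_config({"store_batch_size": 32, "train_batch_size": 4096}))
# {'store_batch_size_prompts': 32, 'train_batch_size_tokens': 4096}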
2 changes: 1 addition & 1 deletion sae_lens/toolkit/pretrained_saes.py
@@ -124,7 +124,7 @@ def convert_connor_rob_sae_to_our_saelens_format(
expansion_factor=expansion_factor,
context_size=config["seq_len"], # type: ignore
device=device,
store_batch_size=32,
store_batch_size_prompts=32,
n_batches_in_buffer=10,
prepend_bos=False,
verbose=False,
18 changes: 9 additions & 9 deletions sae_lens/training/activations_store.py
@@ -64,8 +64,8 @@ def from_config(
d_in=cfg.d_in,
n_batches_in_buffer=cfg.n_batches_in_buffer,
total_training_tokens=cfg.training_tokens,
store_batch_size=cfg.store_batch_size,
train_batch_size=cfg.train_batch_size,
store_batch_size_prompts=cfg.store_batch_size_prompts,
train_batch_size_tokens=cfg.train_batch_size_tokens,
prepend_bos=cfg.prepend_bos,
normalize_activations=cfg.normalize_activations,
device=cfg.device,
@@ -86,8 +86,8 @@ def __init__(
d_in: int,
n_batches_in_buffer: int,
total_training_tokens: int,
store_batch_size: int,
train_batch_size: int,
store_batch_size_prompts: int,
train_batch_size_tokens: int,
prepend_bos: bool,
normalize_activations: bool,
device: str | torch.device,
@@ -111,8 +111,8 @@ def __init__(
self.d_in = d_in
self.n_batches_in_buffer = n_batches_in_buffer
self.total_training_tokens = total_training_tokens
self.store_batch_size = store_batch_size
self.train_batch_size = train_batch_size
self.store_batch_size_prompts = store_batch_size_prompts
self.train_batch_size_tokens = train_batch_size_tokens
self.prepend_bos = prepend_bos
self.normalize_activations = normalize_activations
self.device = device
@@ -196,7 +196,7 @@ def get_batch_tokens(self, batch_size: int | None = None):
Streams a batch of tokens from a dataset.
"""
if not batch_size:
batch_size = self.store_batch_size
batch_size = self.store_batch_size_prompts
context_size = self.context_size
device = self.device

@@ -306,7 +306,7 @@ def get_activations(self, batch_tokens: torch.Tensor):

def get_buffer(self, n_batches_in_buffer: int) -> torch.Tensor:
context_size = self.context_size
batch_size = self.store_batch_size
batch_size = self.store_batch_size_prompts
d_in = self.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = len(self.hook_point_layers) # Number of hook points or layers
@@ -426,7 +426,7 @@ def get_data_loader(
"""

batch_size = self.train_batch_size
batch_size = self.train_batch_size_tokens

# 1. # create new buffer by mixing stored and new buffer
mixing_buffer = torch.cat(
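A standalone sketch (not repository code) of why the two sizes now carry different units, using the defaults from config.py below and assuming context_size=128 as in the example scripts:

# Store batches are counted in prompts; train batches are counted in tokens.
store_batch_size_prompts = 32    # prompts fetched per call to get_batch_tokens
context_size = 128               # tokens per prompt (assumed; set per run)
train_batch_size_tokens = 4096   # activation rows yielded per optimizer step by get_data_loader

tokens_per_store_batch = store_batch_size_prompts * context_size
print(tokens_per_store_batch)                              # 4096 activation tokens per store batch
print(tokens_per_store_batch // train_batch_size_tokens)   # 1: one training step's worth per store batch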
4 changes: 2 additions & 2 deletions sae_lens/training/cache_activations_runner.py
@@ -42,7 +42,7 @@ def __str__(self):
)
tokens_in_buffer = (
self.cfg.n_batches_in_buffer
* self.cfg.store_batch_size
* self.cfg.store_batch_size_prompts
* self.cfg.context_size
)
total_training_tokens = self.cfg.training_tokens
@@ -75,7 +75,7 @@ def run(self):

print(f"Started caching {self.cfg.training_tokens} activations")
tokens_per_buffer = (
self.cfg.store_batch_size
self.cfg.store_batch_size_prompts
* self.cfg.context_size
* self.cfg.n_batches_in_buffer
)
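A worked example (illustrative) of the tokens_per_buffer arithmetic above, using the config.py defaults and an assumed context_size of 128:

n_batches_in_buffer = 20        # default in the runner configs
store_batch_size_prompts = 32   # default prompt batch size
context_size = 128              # assumed; depends on the run configuration

tokens_per_buffer = store_batch_size_prompts * context_size * n_batches_in_buffer
print(f"{tokens_per_buffer:_} tokens cached per buffer")  # 81_920 tokens cached per buffer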
32 changes: 18 additions & 14 deletions sae_lens/training/config.py
@@ -55,8 +55,8 @@ class LanguageModelSAERunnerConfig:
n_batches_in_buffer: int = 20
training_tokens: int = 2_000_000
finetuning_tokens: int = 0
store_batch_size: int = 32
train_batch_size: int = 4096
store_batch_size_prompts: int = 32
train_batch_size_tokens: int = 4096
normalize_activations: bool = False

# Misc
@@ -75,7 +75,7 @@ class LanguageModelSAERunnerConfig:
# Training Parameters

## Batch size
train_batch_size: int = 4096
train_batch_size_tokens: int = 4096

## Adam
adam_beta1: float | list[float] = 0
@@ -114,7 +114,7 @@ class LanguageModelSAERunnerConfig:

# Evals
n_eval_batches: int = 10
n_eval_seqs: int | None = None # useful if evals cause OOM
eval_batch_size_prompts: int | None = None # useful if evals cause OOM

# WANDB
log_to_wandb: bool = True
@@ -149,7 +149,7 @@ def __post_init__(self):
if not isinstance(self.expansion_factor, list):
self.d_sae = self.d_in * self.expansion_factor
self.tokens_per_buffer = (
self.train_batch_size * self.context_size * self.n_batches_in_buffer
self.train_batch_size_tokens * self.context_size * self.n_batches_in_buffer
)

if self.run_name is None:
@@ -203,17 +203,21 @@ def __post_init__(self):
)
# Print out some useful info:
n_tokens_per_buffer = (
self.store_batch_size * self.context_size * self.n_batches_in_buffer
self.store_batch_size_prompts
* self.context_size
* self.n_batches_in_buffer
)
print(f"n_tokens_per_buffer (millions): {n_tokens_per_buffer / 10 ** 6}")
n_contexts_per_buffer = self.store_batch_size * self.n_batches_in_buffer
n_contexts_per_buffer = (
self.store_batch_size_prompts * self.n_batches_in_buffer
)
print(
f"Lower bound: n_contexts_per_buffer (millions): {n_contexts_per_buffer / 10 ** 6}"
)

total_training_steps = (
self.training_tokens + self.finetuning_tokens
) // self.train_batch_size
) // self.train_batch_size_tokens
print(f"Total training steps: {total_training_steps}")

total_wandb_updates = total_training_steps // self.wandb_log_frequency
@@ -225,17 +229,17 @@ def __post_init__(self):
total_training_steps // self.feature_sampling_window
)
print(
f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size) / 10 ** 6}"
f"n_tokens_per_feature_sampling_window (millions): {(self.feature_sampling_window * self.context_size * self.train_batch_size_tokens) / 10 ** 6}"
)
print(
f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size) / 10 ** 6}"
f"n_tokens_per_dead_feature_window (millions): {(self.dead_feature_window * self.context_size * self.train_batch_size_tokens) / 10 ** 6}"
)
print(
f"We will reset the sparsity calculation {n_feature_window_samples} times."
)
# print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size)
# print("Number tokens in dead feature calculation window: ", self.dead_feature_window * self.train_batch_size_tokens)
print(
f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size:.2e}"
f"Number tokens in sparsity calculation window: {self.feature_sampling_window * self.train_batch_size_tokens:.2e}"
)

if not isinstance(self.use_ghost_grads, list) and self.use_ghost_grads:
@@ -313,8 +317,8 @@ class CacheActivationsRunnerConfig:
# Activation Store Parameters
n_batches_in_buffer: int = 20
training_tokens: int = 2_000_000
store_batch_size: int = 32
train_batch_size: int = 4096
store_batch_size_prompts: int = 32
train_batch_size_tokens: int = 4096
normalize_activations: bool = False

# Misc
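A worked example (illustrative) of the step-count printout in __post_init__, using the defaults shown in this diff:

training_tokens = 2_000_000
finetuning_tokens = 0
train_batch_size_tokens = 4096

total_training_steps = (training_tokens + finetuning_tokens) // train_batch_size_tokens
print(total_training_steps)  # 488 optimizer steps, since each step consumes 4096 activation tokens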
10 changes: 5 additions & 5 deletions sae_lens/training/evals.py
@@ -18,7 +18,7 @@ def run_evals(
n_training_steps: int,
suffix: str = "",
n_eval_batches: int = 10,
n_eval_seqs: int | None = None,
eval_batch_size_prompts: int | None = None,
) -> Mapping[str, Any]:
hook_point = sparse_autoencoder.cfg.hook_point
hook_point_layer = sparse_autoencoder.hook_point_layer
@@ -27,15 +27,15 @@
layer=hook_point_layer
)
### Evals
eval_tokens = activation_store.get_batch_tokens(n_eval_seqs)
eval_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)

# Get Reconstruction Score
losses_df = recons_loss_batched(
sparse_autoencoder,
model,
activation_store,
n_batches=n_eval_batches,
n_eval_seqs=n_eval_seqs,
eval_batch_size_prompts=eval_batch_size_prompts,
)

recons_score = losses_df["score"].mean()
@@ -103,11 +103,11 @@ def recons_loss_batched(
model: HookedRootModule,
activation_store: ActivationsStore,
n_batches: int = 100,
n_eval_seqs: int | None = None,
eval_batch_size_prompts: int | None = None,
):
losses = []
for _ in range(n_batches):
batch_tokens = activation_store.get_batch_tokens(n_eval_seqs)
batch_tokens = activation_store.get_batch_tokens(eval_batch_size_prompts)
score, loss, recons_loss, zero_abl_loss = get_recons_loss(
sparse_autoencoder, model, batch_tokens
)
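A standalone sketch (the resolve_eval_batch_size helper is hypothetical) of the fallback implied by get_batch_tokens in activations_store.py above: when eval_batch_size_prompts is left as None, the store's own prompt batch size is used.

def resolve_eval_batch_size(eval_batch_size_prompts: int | None, store_batch_size_prompts: int = 32) -> int:
    # Mirrors `if not batch_size: batch_size = self.store_batch_size_prompts`
    return eval_batch_size_prompts or store_batch_size_prompts

print(resolve_eval_batch_size(None))  # 32 -> falls back to store_batch_size_prompts
print(resolve_eval_batch_size(8))     # 8  -> smaller eval batches, useful if evals cause OOM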
6 changes: 3 additions & 3 deletions sae_lens/training/lm_runner.py
@@ -36,7 +36,7 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
checkpoint_path=checkpoint_path,
cfg=cfg,
model=model,
batch_size=cfg.train_batch_size,
batch_size=cfg.train_batch_size_tokens,
)
# no checkpoints found, don't resume
except FileNotFoundError:
@@ -106,15 +106,15 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
activation_store=activations_loader, # pyright: ignore [reportPossiblyUnboundVariable]
train_contexts=train_contexts,
training_run_state=training_run_state,
batch_size=cfg.train_batch_size,
batch_size=cfg.train_batch_size_tokens,
n_checkpoints=cfg.n_checkpoints,
feature_sampling_window=cfg.feature_sampling_window,
use_wandb=cfg.log_to_wandb,
wandb_log_frequency=cfg.wandb_log_frequency,
eval_every_n_wandb_logs=cfg.eval_every_n_wandb_logs,
autocast=cfg.autocast,
n_eval_batches=cfg.n_eval_batches,
n_eval_seqs=cfg.n_eval_seqs,
eval_batch_size_prompts=cfg.eval_batch_size_prompts,
).sae_group

if cfg.log_to_wandb:
8 changes: 4 additions & 4 deletions sae_lens/training/train_sae_on_language_model.py
@@ -189,7 +189,7 @@ def train_sae_on_language_model(
eval_every_n_wandb_logs: int = 100,
autocast: bool = False,
n_eval_batches: int = 10,
n_eval_seqs: int | None = None,
eval_batch_size_prompts: int | None = None,
) -> SparseAutoencoderDictionary:
"""
@deprecated Use `train_sae_group_on_language_model` instead. This method is kept for backward compatibility.
@@ -206,7 +206,7 @@
eval_every_n_wandb_logs=eval_every_n_wandb_logs,
autocast=autocast,
n_eval_batches=n_eval_batches,
n_eval_seqs=n_eval_seqs,
eval_batch_size_prompts=eval_batch_size_prompts,
).sae_group


@@ -228,7 +228,7 @@ def train_sae_group_on_language_model(
eval_every_n_wandb_logs: int = 100,
autocast: bool = False,
n_eval_batches: int = 10,
n_eval_seqs: int | None = None,
eval_batch_size_prompts: int | None = None,
) -> TrainSAEGroupOutput:
total_training_tokens = get_total_training_tokens(sae_group=sae_group)
_update_sae_lens_training_version(sae_group)
@@ -332,7 +332,7 @@ def interrupt_callback(sig_num: Any, stack_frame: Any):
training_run_state.n_training_steps,
suffix=wandb_suffix,
n_eval_batches=n_eval_batches,
n_eval_seqs=n_eval_seqs,
eval_batch_size_prompts=eval_batch_size_prompts,
)
sparse_autoencoder.train()

4 changes: 2 additions & 2 deletions scripts/caching_replication_how_train_saes.py
@@ -67,10 +67,10 @@
is_dataset_tokenized=True,
prepend_bos=True,
training_tokens=total_training_tokens, # For initial testing I think this is a good number.
train_batch_size=4096,
train_batch_size_tokens=4096,
# buffer details
n_batches_in_buffer=4,
store_batch_size=128,
store_batch_size_prompts=128,
normalize_activations=False,
#
shuffle_every_n_buffers=8,
2 changes: 1 addition & 1 deletion scripts/replication_how_train_SAEs.ipynb
@@ -96,7 +96,7 @@
"final_l1_value = sparse_autoencoder.cfg.l1_coefficient\n",
"\n",
"l1_scheduler = L1Scheduler(\n",
" total_steps=sparse_autoencoder.cfg.training_tokens // sparse_autoencoder.cfg.train_batch_size,\n",
" total_steps=sparse_autoencoder.cfg.training_tokens // sparse_autoencoder.cfg.train_batch_size_tokens,\n",
" l1_warm_up_steps=l1_warmup_steps,\n",
" sparse_autoencoder=sparse_autoencoder\n",
")\n",
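A worked example (illustrative) of the L1Scheduler step count in the cell above; the 25M-token figure is borrowed from the run.ipynb configs later in this diff:

training_tokens = 1_000_000 * 25
train_batch_size_tokens = 4096

total_steps = training_tokens // train_batch_size_tokens
print(total_steps)  # 6103 scheduler steps, one per optimizer step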
4 changes: 2 additions & 2 deletions scripts/replication_how_train_saes.py
@@ -60,7 +60,7 @@
use_cached_activations=True,
cached_activations_path="/home/paperspace/shared_volumes/activations_volume_1/gelu-1l",
training_tokens=total_training_tokens, # For initial testing I think this is a good number.
train_batch_size=4096,
train_batch_size_tokens=4096,
# Loss Function
## Reconstruction Coefficient.
mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch.
@@ -92,7 +92,7 @@
adam_beta2=0.999,
# Buffer details won't matter in we cache / shuffle our activations ahead of time.
n_batches_in_buffer=64,
store_batch_size=16,
store_batch_size_prompts=16,
normalize_activations=False,
# Feature Store
feature_sampling_window=1000,
4 changes: 2 additions & 2 deletions scripts/replication_how_train_saes_control.py
@@ -60,7 +60,7 @@
use_cached_activations=True,
cached_activations_path="/home/paperspace/shared_volumes/activations_volume_1/gelu-1l",
training_tokens=total_training_tokens, # For initial testing I think this is a good number.
train_batch_size=4096,
train_batch_size_tokens=4096,
# Loss Function
## Reconstruction Coefficient.
mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch.
@@ -92,7 +92,7 @@
adam_beta2=0.999,
# Buffer details won't matter in we cache / shuffle our activations ahead of time.
n_batches_in_buffer=64,
store_batch_size=16,
store_batch_size_prompts=16,
normalize_activations=False,
# Feature Store
feature_sampling_window=1000,
8 changes: 4 additions & 4 deletions scripts/run.ipynb
@@ -267,15 +267,15 @@
" lr_warm_up_steps=10000, # this can help avoid too many dead features initially.\n",
" l1_coefficient=0.0015, # will control how sparse the feature activations are\n",
" lp_norm=1.0, # the L1 penalty (and not a Lp for p < 1)\n",
" train_batch_size=4096,\n",
" train_batch_size_tokens=4096,\n",
" context_size=128, # will control the lenght of the prompts we feed to the model. Larger is better but slower.\n",
" # Activation Store Parameters\n",
" n_batches_in_buffer=64, # controls how many activations we store / shuffle.\n",
" training_tokens=1_000_000\n",
" * 25, # 100 million tokens is quite a few, but we want to see good stats. Get a coffee, come back.\n",
" finetuning_method=\"decoder\",\n",
" finetuning_tokens=1_000_000 * 25,\n",
" store_batch_size=32,\n",
" store_batch_size_prompts=32,\n",
" # Resampling protocol\n",
" use_ghost_grads=False,\n",
" feature_sampling_window=1000, # this controls our reporting of feature sparsity stats\n",
@@ -635,15 +635,15 @@
" lr=0.0004,\n",
" l1_coefficient=0.008,\n",
" lr_scheduler_name=\"constant\",\n",
" train_batch_size=4096,\n",
" train_batch_size_tokens=4096,\n",
" context_size=256,\n",
" lr_warm_up_steps=5000,\n",
" # Activation Store Parameters\n",
" n_batches_in_buffer=128,\n",
" training_tokens=1_000_000 * 200, # 200M tokens seems doable overnight.\n",
" finetuning_method=\"decoder\",\n",
" finetuning_tokens=1_000_000 * 100,\n",
" store_batch_size=32,\n",
" store_batch_size_prompts=32,\n",
" # Resampling protocol\n",
" use_ghost_grads=False,\n",
" feature_sampling_window=2500,\n",
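A worked comparison (illustrative; the tokens_per_buffer helper is hypothetical) of the buffer sizes implied by the two run.ipynb configurations above:

def tokens_per_buffer(store_batch_size_prompts: int, context_size: int, n_batches_in_buffer: int) -> int:
    return store_batch_size_prompts * context_size * n_batches_in_buffer

print(f"{tokens_per_buffer(32, 128, 64):_}")   # 262_144 tokens per buffer (first config)
print(f"{tokens_per_buffer(32, 256, 128):_}")  # 1_048_576 tokens per buffer (second config)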
4 changes: 2 additions & 2 deletions tests/benchmark/test_cache_activations_runner.py
@@ -55,10 +55,10 @@ def test_cache_activations_runner():
is_dataset_tokenized=True,
prepend_bos=True,
training_tokens=total_training_tokens, # For initial testing I think this is a good number.
train_batch_size=4096,
train_batch_size_tokens=4096,
# buffer details
n_batches_in_buffer=32,
store_batch_size=16,
store_batch_size_prompts=16,
normalize_activations=False,
#
shuffle_every_n_buffers=8,
(The remaining changed files in this commit are not shown.)
