Skip to content

Commit

Permalink
feat: Add bf16 autocast (jbloomAus#126)
Browse files Browse the repository at this point in the history
* add bf16 autocast and gradient scaling

* simplify autocast setup

* remove completed TODO

* add autocast dtype selection (generally keep bf16)

* formatting fix

* remove autocast dtype
  • Loading branch information
tomMcGrath authored May 7, 2024
1 parent 3265c06 commit a553408
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 17 deletions.
1 change: 1 addition & 0 deletions sae_lens/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class LanguageModelSAERunnerConfig:
seed: int = 42
dtype: str | torch.dtype = "float32" # type: ignore #
prepend_bos: bool = True
autocast: bool = False # autocast to autocast_dtype during training

# Training Parameters

Expand Down
1 change: 1 addition & 0 deletions sae_lens/training/lm_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
use_wandb=cfg.log_to_wandb,
wandb_log_frequency=cfg.wandb_log_frequency,
eval_every_n_wandb_logs=cfg.eval_every_n_wandb_logs,
autocast=cfg.autocast,
).sae_group

if cfg.log_to_wandb:
Expand Down
57 changes: 40 additions & 17 deletions sae_lens/training/train_sae_on_language_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import os
import pickle
import random
Expand Down Expand Up @@ -186,6 +187,7 @@ def train_sae_on_language_model(
use_wandb: bool = False,
wandb_log_frequency: int = 50,
eval_every_n_wandb_logs: int = 100,
autocast: bool = False,
) -> SparseAutoencoderDictionary:
"""
@deprecated Use `train_sae_group_on_language_model` instead. This method is kept for backward compatibility.
Expand All @@ -200,6 +202,7 @@ def train_sae_on_language_model(
use_wandb=use_wandb,
wandb_log_frequency=wandb_log_frequency,
eval_every_n_wandb_logs=eval_every_n_wandb_logs,
autocast=autocast,
).sae_group


Expand All @@ -219,6 +222,7 @@ def train_sae_group_on_language_model(
use_wandb: bool = False,
wandb_log_frequency: int = 50,
eval_every_n_wandb_logs: int = 100,
autocast: bool = False,
) -> TrainSAEGroupOutput:
total_training_tokens = get_total_training_tokens(sae_group=sae_group)
_update_sae_lens_training_version(sae_group)
Expand Down Expand Up @@ -289,6 +293,7 @@ def interrupt_callback(sig_num: Any, stack_frame: Any):
all_layers=all_layers,
batch_size=batch_size,
wandb_suffix=wandb_suffix,
autocast=autocast,
)
mse_losses.append(step_output.mse_loss)
l1_losses.append(step_output.l1_loss)
Expand Down Expand Up @@ -539,6 +544,7 @@ def _train_step(
all_layers: list[int],
batch_size: int,
wandb_suffix: str,
autocast: bool = True,
) -> TrainStepOutput:
assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy
layer_id = all_layers.index(sparse_autoencoder.hook_point_layer)
Expand Down Expand Up @@ -579,18 +585,33 @@ def _train_step(
ctx.n_forward_passes_since_fired > sparse_autoencoder.cfg.dead_feature_window
).bool()

# Setup autocast if using
scaler = torch.cuda.amp.GradScaler(enabled=autocast)
if autocast:
autocast_if_enabled = torch.autocast(
device_type="cuda",
dtype=torch.bfloat16,
enabled=autocast,
)
else:
autocast_if_enabled = contextlib.nullcontext()

# Forward and Backward Passes
(
sae_out,
feature_acts,
loss,
mse_loss,
l1_loss,
ghost_grad_loss,
) = sparse_autoencoder(
sae_in,
ghost_grad_neuron_mask,
)
# for documentation on autocasting see:
# https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
with autocast_if_enabled:
(
sae_out,
feature_acts,
loss,
mse_loss,
l1_loss,
ghost_grad_loss,
) = sparse_autoencoder(
sae_in,
ghost_grad_neuron_mask,
)

did_fire = (feature_acts > 0).float().sum(-2) > 0
ctx.n_forward_passes_since_fired += 1
ctx.n_forward_passes_since_fired[did_fire] = 0
Expand All @@ -600,17 +621,19 @@ def _train_step(
ctx.act_freq_scores += (feature_acts.abs() > 0).float().sum(0)
ctx.n_frac_active_tokens += batch_size

ctx.optimizer.zero_grad()
loss.backward()

# clip grad norm
# TODO: Work out if this should be in config / how to test it.
# Scaler will rescale gradients if autocast is enabled
scaler.scale(loss).backward() # loss.backward() if not autocasting
scaler.unscale_(ctx.optimizer) # needed to clip correctly
# TODO: Work out if grad norm clipping should be in config / how to test it.
torch.nn.utils.clip_grad_norm_(sparse_autoencoder.parameters(), 1.0)
scaler.step(ctx.optimizer) # just ctx.optimizer.step() if not autocasting
scaler.update()

if sparse_autoencoder.normalize_sae_decoder:
sparse_autoencoder.remove_gradient_parallel_to_decoder_directions()

ctx.optimizer.step()
ctx.optimizer.zero_grad()

ctx.lr_scheduler.step()
ctx.l1_scheduler.step()

Expand Down

0 comments on commit a553408

Please sign in to comment.