diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py
index 34636d0d..35e72f0d 100644
--- a/sae_lens/training/config.py
+++ b/sae_lens/training/config.py
@@ -64,6 +64,7 @@ class LanguageModelSAERunnerConfig:
     seed: int = 42
     dtype: str | torch.dtype = "float32"  # type: ignore #
     prepend_bos: bool = True
+    autocast: bool = False  # autocast to autocast_dtype during training
 
     # Training Parameters
diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py
index d2ec213a..5a7c0bd9 100644
--- a/sae_lens/training/lm_runner.py
+++ b/sae_lens/training/lm_runner.py
@@ -84,6 +84,7 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
         use_wandb=cfg.log_to_wandb,
         wandb_log_frequency=cfg.wandb_log_frequency,
         eval_every_n_wandb_logs=cfg.eval_every_n_wandb_logs,
+        autocast=cfg.autocast,
     ).sae_group
 
     if cfg.log_to_wandb:
diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py
index 73c50a97..639281a5 100644
--- a/sae_lens/training/train_sae_on_language_model.py
+++ b/sae_lens/training/train_sae_on_language_model.py
@@ -1,3 +1,4 @@
+import contextlib
 import os
 import pickle
 import random
@@ -186,6 +187,7 @@ def train_sae_on_language_model(
     use_wandb: bool = False,
     wandb_log_frequency: int = 50,
     eval_every_n_wandb_logs: int = 100,
+    autocast: bool = False,
 ) -> SparseAutoencoderDictionary:
     """
     @deprecated Use `train_sae_group_on_language_model` instead. This method is kept for backward compatibility.
@@ -200,6 +202,7 @@
         use_wandb=use_wandb,
         wandb_log_frequency=wandb_log_frequency,
         eval_every_n_wandb_logs=eval_every_n_wandb_logs,
+        autocast=autocast,
     ).sae_group
 
 
@@ -219,6 +222,7 @@ def train_sae_group_on_language_model(
     use_wandb: bool = False,
     wandb_log_frequency: int = 50,
     eval_every_n_wandb_logs: int = 100,
+    autocast: bool = False,
 ) -> TrainSAEGroupOutput:
     total_training_tokens = get_total_training_tokens(sae_group=sae_group)
     _update_sae_lens_training_version(sae_group)
@@ -289,6 +293,7 @@ def interrupt_callback(sig_num: Any, stack_frame: Any):
                 all_layers=all_layers,
                 batch_size=batch_size,
                 wandb_suffix=wandb_suffix,
+                autocast=autocast,
             )
             mse_losses.append(step_output.mse_loss)
             l1_losses.append(step_output.l1_loss)
@@ -539,6 +544,7 @@ def _train_step(
     all_layers: list[int],
     batch_size: int,
     wandb_suffix: str,
+    autocast: bool = True,
 ) -> TrainStepOutput:
     assert sparse_autoencoder.cfg.d_sae is not None  # keep pyright happy
     layer_id = all_layers.index(sparse_autoencoder.hook_point_layer)
@@ -579,18 +585,33 @@
         ctx.n_forward_passes_since_fired > sparse_autoencoder.cfg.dead_feature_window
     ).bool()
 
+    # Setup autocast if using
+    scaler = torch.cuda.amp.GradScaler(enabled=autocast)
+    if autocast:
+        autocast_if_enabled = torch.autocast(
+            device_type="cuda",
+            dtype=torch.bfloat16,
+            enabled=autocast,
+        )
+    else:
+        autocast_if_enabled = contextlib.nullcontext()
+
     # Forward and Backward Passes
-    (
-        sae_out,
-        feature_acts,
-        loss,
-        mse_loss,
-        l1_loss,
-        ghost_grad_loss,
-    ) = sparse_autoencoder(
-        sae_in,
-        ghost_grad_neuron_mask,
-    )
+    # for documentation on autocasting see:
+    # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html
+    with autocast_if_enabled:
+        (
+            sae_out,
+            feature_acts,
+            loss,
+            mse_loss,
+            l1_loss,
+            ghost_grad_loss,
+        ) = sparse_autoencoder(
+            sae_in,
+            ghost_grad_neuron_mask,
+        )
+
     did_fire = (feature_acts > 0).float().sum(-2) > 0
     ctx.n_forward_passes_since_fired += 1
     ctx.n_forward_passes_since_fired[did_fire] = 0
@@ -600,17 +621,19 @@ def _train_step(
     ctx.act_freq_scores += (feature_acts.abs() > 0).float().sum(0)
     ctx.n_frac_active_tokens += batch_size
 
-    ctx.optimizer.zero_grad()
-    loss.backward()
-
-    # clip grad norm
-    # TODO: Work out if this should be in config / how to test it.
+    # Scaler will rescale gradients if autocast is enabled
+    scaler.scale(loss).backward()  # loss.backward() if not autocasting
+    scaler.unscale_(ctx.optimizer)  # needed to clip correctly
+    # TODO: Work out if grad norm clipping should be in config / how to test it.
     torch.nn.utils.clip_grad_norm_(sparse_autoencoder.parameters(), 1.0)
+    scaler.step(ctx.optimizer)  # just ctx.optimizer.step() if not autocasting
+    scaler.update()
 
     if sparse_autoencoder.normalize_sae_decoder:
         sparse_autoencoder.remove_gradient_parallel_to_decoder_directions()
 
-    ctx.optimizer.step()
+    ctx.optimizer.zero_grad()
+
     ctx.lr_scheduler.step()
     ctx.l1_scheduler.step()
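
With the patch applied, autocast is opt-in: setting autocast=True on LanguageModelSAERunnerConfig flows through language_model_sae_runner into _train_step, while the default of False leaves training numerically unchanged. For reference, below is a minimal standalone sketch, not part of the patch, showing the same autocast + GradScaler pattern the new _train_step follows; the toy model, loss, and train_step helper are hypothetical, and a CUDA device is assumed whenever autocast=True (with autocast=False the scaler and context manager are passthroughs, so it also runs on CPU).

# Minimal sketch, not part of the patch: mirrors the autocast + GradScaler
# structure of the patched _train_step on a hypothetical toy model.
import contextlib

import torch


def train_step(model, batch, optimizer, autocast: bool = False):
    # GradScaler is a passthrough when enabled=False, so this is safe on CPU.
    scaler = torch.cuda.amp.GradScaler(enabled=autocast)
    if autocast:
        autocast_if_enabled = torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=autocast
        )
    else:
        autocast_if_enabled = contextlib.nullcontext()

    # Forward pass runs in reduced precision only when autocast is enabled.
    with autocast_if_enabled:
        out = model(batch)
        loss = (out - batch).pow(2).mean()

    scaler.scale(loss).backward()  # loss.backward() if not autocasting
    scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)  # optimizer.step() if not autocasting
    scaler.update()
    optimizer.zero_grad()
    return loss.item()


model = torch.nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
print(train_step(model, torch.randn(8, 16), optimizer, autocast=False))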