From 92270a34ba36a4b663f6738e1bc46ec7a15bafbe Mon Sep 17 00:00:00 2001 From: tomMcGrath Date: Tue, 7 May 2024 13:16:13 +0100 Subject: [PATCH 1/6] add bf16 autocast and gradient scaling --- sae_lens/training/config.py | 1 + sae_lens/training/lm_runner.py | 1 + .../training/train_sae_on_language_model.py | 62 ++++++++++++++----- 3 files changed, 47 insertions(+), 17 deletions(-) diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 34636d0d..13739071 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -64,6 +64,7 @@ class LanguageModelSAERunnerConfig: seed: int = 42 dtype: str | torch.dtype = "float32" # type: ignore # prepend_bos: bool = True + autocast: bool = False # autocast to bf16 during training # Training Parameters diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py index d2ec213a..5a7c0bd9 100644 --- a/sae_lens/training/lm_runner.py +++ b/sae_lens/training/lm_runner.py @@ -84,6 +84,7 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig): use_wandb=cfg.log_to_wandb, wandb_log_frequency=cfg.wandb_log_frequency, eval_every_n_wandb_logs=cfg.eval_every_n_wandb_logs, + autocast=cfg.autocast, ).sae_group if cfg.log_to_wandb: diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index 73c50a97..cd3ae073 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -1,3 +1,4 @@ +import contextlib import os import pickle import random @@ -186,6 +187,7 @@ def train_sae_on_language_model( use_wandb: bool = False, wandb_log_frequency: int = 50, eval_every_n_wandb_logs: int = 100, + autocast: bool = False, ) -> SparseAutoencoderDictionary: """ @deprecated Use `train_sae_group_on_language_model` instead. This method is kept for backward compatibility. 
@@ -200,6 +202,7 @@ def train_sae_on_language_model( use_wandb=use_wandb, wandb_log_frequency=wandb_log_frequency, eval_every_n_wandb_logs=eval_every_n_wandb_logs, + autocast=autocast, ).sae_group @@ -219,6 +222,7 @@ def train_sae_group_on_language_model( use_wandb: bool = False, wandb_log_frequency: int = 50, eval_every_n_wandb_logs: int = 100, + autocast: bool = False, ) -> TrainSAEGroupOutput: total_training_tokens = get_total_training_tokens(sae_group=sae_group) _update_sae_lens_training_version(sae_group) @@ -289,6 +293,7 @@ def interrupt_callback(sig_num: Any, stack_frame: Any): all_layers=all_layers, batch_size=batch_size, wandb_suffix=wandb_suffix, + autocast=autocast, ) mse_losses.append(step_output.mse_loss) l1_losses.append(step_output.l1_loss) @@ -539,6 +544,7 @@ def _train_step( all_layers: list[int], batch_size: int, wandb_suffix: str, + autocast: bool = True, # TODO(tomMcGrath): propagate up to config ) -> TrainStepOutput: assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy layer_id = all_layers.index(sparse_autoencoder.hook_point_layer) @@ -579,18 +585,31 @@ def _train_step( ctx.n_forward_passes_since_fired > sparse_autoencoder.cfg.dead_feature_window ).bool() + # Setup autocast if necessary + if autocast: + scaler = torch.cuda.amp.GradScaler() + autocast_if_enabled = torch.autocast(device_type='cuda', dtype=torch.bfloat16) + + else: + autocast_if_enabled = contextlib.nullcontext() + scaler = None + # Forward and Backward Passes - ( - sae_out, - feature_acts, - loss, - mse_loss, - l1_loss, - ghost_grad_loss, - ) = sparse_autoencoder( - sae_in, - ghost_grad_neuron_mask, - ) + # for documentation on autocasting see: + # https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html + with autocast_if_enabled: + ( + sae_out, + feature_acts, + loss, + mse_loss, + l1_loss, + ghost_grad_loss, + ) = sparse_autoencoder( + sae_in, + ghost_grad_neuron_mask, + ) + did_fire = (feature_acts > 0).float().sum(-2) > 0 ctx.n_forward_passes_since_fired += 1 ctx.n_forward_passes_since_fired[did_fire] = 0 @@ -600,17 +619,26 @@ def _train_step( ctx.act_freq_scores += (feature_acts.abs() > 0).float().sum(0) ctx.n_frac_active_tokens += batch_size - ctx.optimizer.zero_grad() - loss.backward() + # Rescale gradients if we autocasted + if autocast and scaler is not None: # 2nd condition to keep pylance happy + scaler.scale(loss).backward() + scaler.unscale_(ctx.optimizer) # needed to clip correctly + # TODO: Work out if grad norm clipping should be in config / how to test it. + torch.nn.utils.clip_grad_norm_(sparse_autoencoder.parameters(), 1.0) + scaler.step(ctx.optimizer) + scaler.update() - # clip grad norm - # TODO: Work out if this should be in config / how to test it. - torch.nn.utils.clip_grad_norm_(sparse_autoencoder.parameters(), 1.0) + else: + loss.backward() + # TODO: Work out if grad norm clipping should be in config / how to test it. 
+ torch.nn.utils.clip_grad_norm_(sparse_autoencoder.parameters(), 1.0) + ctx.optimizer.step() if sparse_autoencoder.normalize_sae_decoder: sparse_autoencoder.remove_gradient_parallel_to_decoder_directions() - ctx.optimizer.step() + ctx.optimizer.zero_grad() + ctx.lr_scheduler.step() ctx.l1_scheduler.step() From d80af890d6193b76d9e51103f811aaad40c4ce54 Mon Sep 17 00:00:00 2001 From: tomMcGrath Date: Tue, 7 May 2024 13:55:31 +0100 Subject: [PATCH 2/6] simplify autocast setup --- .../training/train_sae_on_language_model.py | 35 ++++++++----------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index cd3ae073..5681dc01 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -585,14 +585,16 @@ def _train_step( ctx.n_forward_passes_since_fired > sparse_autoencoder.cfg.dead_feature_window ).bool() - # Setup autocast if necessary + # Setup autocast if using + scaler = torch.cuda.amp.GradScaler(enabled=autocast) if autocast: - scaler = torch.cuda.amp.GradScaler() - autocast_if_enabled = torch.autocast(device_type='cuda', dtype=torch.bfloat16) - + autocast_if_enabled = torch.autocast( + device_type="cuda", + dtype=torch.bfloat16, + enabled=autocast, + ) else: autocast_if_enabled = contextlib.nullcontext() - scaler = None # Forward and Backward Passes # for documentation on autocasting see: @@ -609,7 +611,7 @@ def _train_step( sae_in, ghost_grad_neuron_mask, ) - + did_fire = (feature_acts > 0).float().sum(-2) > 0 ctx.n_forward_passes_since_fired += 1 ctx.n_forward_passes_since_fired[did_fire] = 0 @@ -619,20 +621,13 @@ def _train_step( ctx.act_freq_scores += (feature_acts.abs() > 0).float().sum(0) ctx.n_frac_active_tokens += batch_size - # Rescale gradients if we autocasted - if autocast and scaler is not None: # 2nd condition to keep pylance happy - scaler.scale(loss).backward() - scaler.unscale_(ctx.optimizer) # needed to clip correctly - # TODO: Work out if grad norm clipping should be in config / how to test it. - torch.nn.utils.clip_grad_norm_(sparse_autoencoder.parameters(), 1.0) - scaler.step(ctx.optimizer) - scaler.update() - - else: - loss.backward() - # TODO: Work out if grad norm clipping should be in config / how to test it. - torch.nn.utils.clip_grad_norm_(sparse_autoencoder.parameters(), 1.0) - ctx.optimizer.step() + # Scaler will rescale gradients if autocast is enabled + scaler.scale(loss).backward() # loss.backward() if not autocasting + scaler.unscale_(ctx.optimizer) # needed to clip correctly + # TODO: Work out if grad norm clipping should be in config / how to test it. 
+ torch.nn.utils.clip_grad_norm_(sparse_autoencoder.parameters(), 1.0) + scaler.step(ctx.optimizer) # just ctx.optimizer.step() if not autocasting + scaler.update() if sparse_autoencoder.normalize_sae_decoder: sparse_autoencoder.remove_gradient_parallel_to_decoder_directions() From 862fdd4e17e5cea7c2231b081c173e6f3f2836f0 Mon Sep 17 00:00:00 2001 From: tomMcGrath Date: Tue, 7 May 2024 13:59:17 +0100 Subject: [PATCH 3/6] remove completed TODO --- sae_lens/training/train_sae_on_language_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index 5681dc01..ac2c2075 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -544,7 +544,7 @@ def _train_step( all_layers: list[int], batch_size: int, wandb_suffix: str, - autocast: bool = True, # TODO(tomMcGrath): propagate up to config + autocast: bool = True, ) -> TrainStepOutput: assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy layer_id = all_layers.index(sparse_autoencoder.hook_point_layer) From 9eb27aa00e98bd637139a4050507b678618fecb1 Mon Sep 17 00:00:00 2001 From: tomMcGrath Date: Tue, 7 May 2024 14:34:36 +0100 Subject: [PATCH 4/6] add autocast dtype selection (generally keep bf16) --- sae_lens/training/config.py | 3 ++- sae_lens/training/lm_runner.py | 1 + sae_lens/training/train_sae_on_language_model.py | 8 +++++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 13739071..1a7b9821 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -64,7 +64,8 @@ class LanguageModelSAERunnerConfig: seed: int = 42 dtype: str | torch.dtype = "float32" # type: ignore # prepend_bos: bool = True - autocast: bool = False # autocast to bf16 during training + autocast: bool = False # autocast to autocast_dtype during training + autocast_dtype: torch.dtype = torch.bfloat16 # float16 is typically unstable # Training Parameters diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py index 5a7c0bd9..d7bcd04f 100644 --- a/sae_lens/training/lm_runner.py +++ b/sae_lens/training/lm_runner.py @@ -85,6 +85,7 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig): wandb_log_frequency=cfg.wandb_log_frequency, eval_every_n_wandb_logs=cfg.eval_every_n_wandb_logs, autocast=cfg.autocast, + autocast_dtype=cfg.autocast_dtype, ).sae_group if cfg.log_to_wandb: diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index ac2c2075..dfc1366a 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -188,6 +188,7 @@ def train_sae_on_language_model( wandb_log_frequency: int = 50, eval_every_n_wandb_logs: int = 100, autocast: bool = False, + autocast_dtype: torch.dtype = torch.bfloat16, ) -> SparseAutoencoderDictionary: """ @deprecated Use `train_sae_group_on_language_model` instead. This method is kept for backward compatibility. 
@@ -203,6 +204,7 @@ def train_sae_on_language_model( wandb_log_frequency=wandb_log_frequency, eval_every_n_wandb_logs=eval_every_n_wandb_logs, autocast=autocast, + autocast_dtype=autocast_dtype, ).sae_group @@ -223,6 +225,7 @@ def train_sae_group_on_language_model( wandb_log_frequency: int = 50, eval_every_n_wandb_logs: int = 100, autocast: bool = False, + autocast_dtype: torch.dtype = torch.bfloat16, ) -> TrainSAEGroupOutput: total_training_tokens = get_total_training_tokens(sae_group=sae_group) _update_sae_lens_training_version(sae_group) @@ -294,6 +297,7 @@ def interrupt_callback(sig_num: Any, stack_frame: Any): batch_size=batch_size, wandb_suffix=wandb_suffix, autocast=autocast, + autocast_dtype=autocast_dtype, ) mse_losses.append(step_output.mse_loss) l1_losses.append(step_output.l1_loss) @@ -545,6 +549,7 @@ def _train_step( batch_size: int, wandb_suffix: str, autocast: bool = True, + autocast_dtype: torch.dtype = torch.bfloat16, ) -> TrainStepOutput: assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy layer_id = all_layers.index(sparse_autoencoder.hook_point_layer) @@ -590,7 +595,8 @@ def _train_step( if autocast: autocast_if_enabled = torch.autocast( device_type="cuda", - dtype=torch.bfloat16, + # dtype=torch.bfloat16, + dtype=autocast_dtype, enabled=autocast, ) else: From 40e868d27b479c36d4b77f2c8455fbe6ec91a6c7 Mon Sep 17 00:00:00 2001 From: tomMcGrath Date: Tue, 7 May 2024 14:35:32 +0100 Subject: [PATCH 5/6] formatting fix --- sae_lens/training/train_sae_on_language_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index dfc1366a..851c5885 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -598,7 +598,7 @@ def _train_step( # dtype=torch.bfloat16, dtype=autocast_dtype, enabled=autocast, - ) + ) else: autocast_if_enabled = contextlib.nullcontext() From ba0b8a7f7397cb458fcb4228bf0a26e5ceb9aff2 Mon Sep 17 00:00:00 2001 From: tomMcGrath Date: Tue, 7 May 2024 14:45:51 +0100 Subject: [PATCH 6/6] remove autocast dtype --- sae_lens/training/config.py | 1 - sae_lens/training/lm_runner.py | 1 - sae_lens/training/train_sae_on_language_model.py | 8 +------- 3 files changed, 1 insertion(+), 9 deletions(-) diff --git a/sae_lens/training/config.py b/sae_lens/training/config.py index 1a7b9821..35e72f0d 100644 --- a/sae_lens/training/config.py +++ b/sae_lens/training/config.py @@ -65,7 +65,6 @@ class LanguageModelSAERunnerConfig: dtype: str | torch.dtype = "float32" # type: ignore # prepend_bos: bool = True autocast: bool = False # autocast to autocast_dtype during training - autocast_dtype: torch.dtype = torch.bfloat16 # float16 is typically unstable # Training Parameters diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py index d7bcd04f..5a7c0bd9 100644 --- a/sae_lens/training/lm_runner.py +++ b/sae_lens/training/lm_runner.py @@ -85,7 +85,6 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig): wandb_log_frequency=cfg.wandb_log_frequency, eval_every_n_wandb_logs=cfg.eval_every_n_wandb_logs, autocast=cfg.autocast, - autocast_dtype=cfg.autocast_dtype, ).sae_group if cfg.log_to_wandb: diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py index 851c5885..639281a5 100644 --- a/sae_lens/training/train_sae_on_language_model.py +++ b/sae_lens/training/train_sae_on_language_model.py @@ -188,7 +188,6 
@@ def train_sae_on_language_model( wandb_log_frequency: int = 50, eval_every_n_wandb_logs: int = 100, autocast: bool = False, - autocast_dtype: torch.dtype = torch.bfloat16, ) -> SparseAutoencoderDictionary: """ @deprecated Use `train_sae_group_on_language_model` instead. This method is kept for backward compatibility. @@ -204,7 +203,6 @@ def train_sae_on_language_model( wandb_log_frequency=wandb_log_frequency, eval_every_n_wandb_logs=eval_every_n_wandb_logs, autocast=autocast, - autocast_dtype=autocast_dtype, ).sae_group @@ -225,7 +223,6 @@ def train_sae_group_on_language_model( wandb_log_frequency: int = 50, eval_every_n_wandb_logs: int = 100, autocast: bool = False, - autocast_dtype: torch.dtype = torch.bfloat16, ) -> TrainSAEGroupOutput: total_training_tokens = get_total_training_tokens(sae_group=sae_group) _update_sae_lens_training_version(sae_group) @@ -297,7 +294,6 @@ def interrupt_callback(sig_num: Any, stack_frame: Any): batch_size=batch_size, wandb_suffix=wandb_suffix, autocast=autocast, - autocast_dtype=autocast_dtype, ) mse_losses.append(step_output.mse_loss) l1_losses.append(step_output.l1_loss) @@ -549,7 +545,6 @@ def _train_step( batch_size: int, wandb_suffix: str, autocast: bool = True, - autocast_dtype: torch.dtype = torch.bfloat16, ) -> TrainStepOutput: assert sparse_autoencoder.cfg.d_sae is not None # keep pyright happy layer_id = all_layers.index(sparse_autoencoder.hook_point_layer) @@ -595,8 +590,7 @@ def _train_step( if autocast: autocast_if_enabled = torch.autocast( device_type="cuda", - # dtype=torch.bfloat16, - dtype=autocast_dtype, + dtype=torch.bfloat16, enabled=autocast, ) else:
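
Taken together, the series leaves _train_step on the standard PyTorch automatic mixed precision recipe that the in-code comment links to (https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html). The sketch below is illustrative only, not SAELens code: the stand-in autoencoder, its loss, and the names toy_sae, d_in, d_sae, and batch_size are invented for the example, but the autocast/GradScaler sequence matches the end state of the patches.

    # Minimal sketch of the final _train_step mixed-precision pattern.
    # Assumptions: a toy nn.Sequential stands in for the real SparseAutoencoder,
    # and plain MSE stands in for the full MSE + L1 (+ ghost grads) SAE loss.
    import contextlib

    import torch
    import torch.nn as nn

    d_in, d_sae, batch_size = 512, 2048, 32
    autocast = torch.cuda.is_available()  # mirrors the cfg.autocast flag added in PATCH 1/6
    device = "cuda" if autocast else "cpu"

    toy_sae = nn.Sequential(
        nn.Linear(d_in, d_sae), nn.ReLU(), nn.Linear(d_sae, d_in)
    ).to(device)
    optimizer = torch.optim.Adam(toy_sae.parameters(), lr=3e-4)

    # GradScaler with enabled=False is a no-op, so one code path serves both the
    # autocast and non-autocast cases (the simplification made in PATCH 2/6).
    scaler = torch.cuda.amp.GradScaler(enabled=autocast)
    if autocast:
        autocast_if_enabled = torch.autocast(
            device_type="cuda", dtype=torch.bfloat16, enabled=autocast
        )
    else:
        autocast_if_enabled = contextlib.nullcontext()

    sae_in = torch.randn(batch_size, d_in, device=device)

    with autocast_if_enabled:  # forward pass runs under bf16 autocast when enabled
        sae_out = toy_sae(sae_in)
        loss = (sae_out - sae_in).pow(2).mean()

    scaler.scale(loss).backward()  # reduces to loss.backward() when scaling is disabled
    scaler.unscale_(optimizer)     # unscale before clipping so the 1.0 threshold is meaningful
    torch.nn.utils.clip_grad_norm_(toy_sae.parameters(), 1.0)
    scaler.step(optimizer)         # reduces to optimizer.step() when scaling is disabled
    scaler.update()
    optimizer.zero_grad()

In SAELens itself the behaviour is switched on by setting autocast=True on LanguageModelSAERunnerConfig and passing the config to language_model_sae_runner. The separate autocast_dtype option introduced in PATCH 4/6 is removed again in PATCH 6/6, so bfloat16 is the only autocast dtype after this series.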