diff --git a/sae_training/config.py b/sae_training/config.py index 9341d57d..0c578696 100644 --- a/sae_training/config.py +++ b/sae_training/config.py @@ -75,6 +75,7 @@ class LanguageModelSAERunnerConfig(RunnerConfig): # WANDB log_to_wandb: bool = True wandb_project: str = "mats_sae_training_language_model" + run_name: Optional[str] = None wandb_entity: str = None wandb_log_frequency: int = 10 @@ -89,7 +90,8 @@ def __post_init__(self): self.train_batch_size * self.context_size * self.n_batches_in_buffer ) - self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" + if self.run_name is None: + self.run_name = f"{self.d_sae}-L1-{self.l1_coefficient}-LR-{self.lr}-Tokens-{self.total_training_tokens:3.3e}" if self.b_dec_init_method not in ["geometric_median", "mean", "zeros"]: raise ValueError(