diff --git a/sae_training/config.py b/sae_training/config.py
index a6da855d..87aaa523 100644
--- a/sae_training/config.py
+++ b/sae_training/config.py
@@ -63,12 +63,14 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
     l1_coefficient: float = 1e-3
     lp_norm: float = 1
     lr: float = 3e-4
+    lr_end: float | None = None  # only used for cosine annealing, default is lr / 10
     lr_scheduler_name: str = (
         "constant"  # constant, cosineannealing, cosineannealingwarmrestarts
     )
     lr_warm_up_steps: int = 500
     lr_decay_steps: int = 0
     train_batch_size: int = 4096
+    n_restart_cycles: int = 1  # only used for cosineannealingwarmrestarts

     # Resampling protocol args
     use_ghost_grads: bool = False  # want to change this to true on some timeline.
@@ -112,6 +114,9 @@ def __post_init__(self):

         self.device = torch.device(self.device)

+        if self.lr_end is None:
+            self.lr_end = self.lr / 10
+
         unique_id = cast(
             Any, wandb
         ).util.generate_id()  # not sure why this type is erroring
diff --git a/sae_training/optim.py b/sae_training/optim.py
index 4f3e9514..0b1917b8 100644
--- a/sae_training/optim.py
+++ b/sae_training/optim.py
@@ -6,19 +6,19 @@
 import torch.optim.lr_scheduler as lr_scheduler


-# None
-# Linear Warmup and decay
+# Constant
 # Cosine Annealing with Warmup
 # Cosine Annealing with Warmup / Restarts
+# No default values specified so the type-checker can verify we don't forget any arguments.
 def get_scheduler(
     scheduler_name: str,
     optimizer: optim.Optimizer,
     training_steps: int,
     lr: float,
-    warm_up_steps: int = 0,
-    decay_steps: int = 0,
-    num_cycles: int = 1,
-    lr_end: float = 0.0,
+    warm_up_steps: int,
+    decay_steps: int,
+    lr_end: float,
+    num_cycles: int,
 ) -> lr_scheduler.LRScheduler:
     """
     Loosely based on this, seemed simpler write this than import
diff --git a/sae_training/train_sae_on_language_model.py b/sae_training/train_sae_on_language_model.py
index bb2766d5..00e5113f 100644
--- a/sae_training/train_sae_on_language_model.py
+++ b/sae_training/train_sae_on_language_model.py
@@ -216,13 +216,16 @@ def _build_train_context(
     n_frac_active_tokens = 0

     optimizer = Adam(sae.parameters(), lr=sae.cfg.lr)
+    assert sae.cfg.lr_end is not None  # this is set in config post-init
     scheduler = get_scheduler(
         sae.cfg.lr_scheduler_name,
         lr=sae.cfg.lr,
         optimizer=optimizer,
         warm_up_steps=sae.cfg.lr_warm_up_steps,
+        decay_steps=sae.cfg.lr_decay_steps,
         training_steps=total_training_steps,
-        lr_end=sae.cfg.lr / 10,  # heuristic for now.
+        lr_end=sae.cfg.lr_end,
+        num_cycles=sae.cfg.n_restart_cycles,
     )

     return SAETrainContext(
diff --git a/tests/unit/test_optim.py b/tests/unit/test_optim.py
index d7a44eb4..2c89a96a 100644
--- a/tests/unit/test_optim.py
+++ b/tests/unit/test_optim.py
@@ -31,11 +31,29 @@ def step(optimizer: Adam, scheduler: LRScheduler):

 def test_get_scheduler_errors_on_uknown_scheduler(optimizer: Adam):
     with pytest.raises(ValueError, match="Unsupported scheduler: unknown"):
-        get_scheduler("unknown", optimizer, lr=LR, training_steps=10)
+        get_scheduler(
+            "unknown",
+            optimizer,
+            lr=LR,
+            training_steps=10,
+            warm_up_steps=0,
+            decay_steps=0,
+            lr_end=0.0,
+            num_cycles=1,
+        )


 def test_get_scheduler_constant(optimizer: Adam):
-    scheduler = get_scheduler("constant", optimizer, lr=LR, training_steps=4)
+    scheduler = get_scheduler(
+        "constant",
+        optimizer,
+        lr=LR,
+        training_steps=4,
+        warm_up_steps=0,
+        decay_steps=0,
+        lr_end=0.0,
+        num_cycles=1,
+    )
     assert scheduler.get_last_lr() == [0.1]
     step_times(3, optimizer, scheduler)
     assert scheduler.get_last_lr() == [0.1]
@@ -43,7 +61,14 @@ def test_get_scheduler_constantwithwarmup(optimizer: Adam):
     scheduler = get_scheduler(
-        "constant", optimizer, lr=LR, warm_up_steps=2, training_steps=4
+        "constant",
+        optimizer,
+        lr=LR,
+        warm_up_steps=2,
+        training_steps=4,
+        decay_steps=0,
+        lr_end=0.0,
+        num_cycles=1,
     )
     assert scheduler.get_last_lr() == [pytest.approx(0.05)]
     step(optimizer, scheduler)
@@ -54,7 +79,14 @@ def test_get_scheduler_linearwarmupdecay(optimizer: Adam):
     scheduler = get_scheduler(
-        "constant", optimizer, lr=LR, warm_up_steps=2, decay_steps=4, training_steps=6
+        "constant",
+        optimizer,
+        lr=LR,
+        warm_up_steps=2,
+        decay_steps=4,
+        training_steps=6,
+        lr_end=0.0,
+        num_cycles=1,
     )
     # first, ramp up for 2 steps
     assert scheduler.get_last_lr() == [0.05]
@@ -80,14 +112,23 @@ def test_get_scheduler_errors_if_lr_end_is_0_and_decay_is_set(optimizer: Adam):
             optimizer,
             lr=LR,
             lr_end=0.0,
+            warm_up_steps=0,
             decay_steps=2,
             training_steps=6,
+            num_cycles=1,
         )


 def test_get_scheduler_cosineannealing(optimizer: Adam):
     scheduler: Any = get_scheduler(
-        "cosineannealing", optimizer, lr=LR, training_steps=4, lr_end=0.05
+        "cosineannealing",
+        optimizer,
+        lr=LR,
+        training_steps=4,
+        lr_end=0.05,
+        warm_up_steps=0,
+        decay_steps=0,
+        num_cycles=1,
     )
     assert len(scheduler._schedulers) == 1
     main_scheduler = scheduler._schedulers[0]
@@ -107,6 +148,7 @@ def test_get_scheduler_cosineannealing_with_warmup_and_decay():
         training_steps=8,
         decay_steps=2,
         lr_end=lr_end,
+        num_cycles=1,
     )
     # first, ramp up for 2 steps
     assert scheduler.get_last_lr() == [0.05]
@@ -147,6 +189,8 @@ def test_get_scheduler_cosineannealingwarmrestarts(optimizer: Adam):
         training_steps=8,
         lr_end=0.05,
         num_cycles=2,
+        warm_up_steps=0,
+        decay_steps=0,
     )
     assert len(scheduler._schedulers) == 1
     main_scheduler = scheduler._schedulers[0]
diff --git a/tests/unit/test_train_sae_on_language_model.py b/tests/unit/test_train_sae_on_language_model.py
index 5fe681f9..3f5b4e58 100644
--- a/tests/unit/test_train_sae_on_language_model.py
+++ b/tests/unit/test_train_sae_on_language_model.py
@@ -47,7 +47,14 @@ def build_train_ctx(
         n_frac_active_tokens=n_frac_active_tokens,
         optimizer=optimizer,
         scheduler=get_scheduler(
-            "constant", lr=sae.cfg.lr, optimizer=optimizer, training_steps=1000
+            "constant",
+            lr=sae.cfg.lr,
+            optimizer=optimizer,
+            training_steps=1000,
+            lr_end=0,
+            warm_up_steps=0,
+            decay_steps=0,
+            num_cycles=1,
         ),
     )