Commit
Merge pull request #65 from chanind/fix-forgotten-scheduler-opts
passing accidentally overlooked scheduler opts
jbloomAus authored Apr 3, 2024
2 parents c960d99 + ad089b7 commit 773bc02
Showing 5 changed files with 72 additions and 13 deletions.
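
In short: the training loop was constructing the LR scheduler with a hard-coded lr_end heuristic and never forwarding n_restart_cycles; after this change every scheduler option comes from the config. A minimal sketch of the call as it now looks, assuming the repo's sae_training package is importable; the toy optimizer and the concrete values are illustrative, not taken from the repo:

import torch
from torch import nn, optim

from sae_training.optim import get_scheduler

# Stand-in parameters/optimizer so the call runs end to end.
params = nn.Parameter(torch.zeros(4))
optimizer = optim.Adam([params], lr=3e-4)

scheduler = get_scheduler(
    "cosineannealingwarmrestarts",
    optimizer,
    training_steps=1000,
    lr=3e-4,
    warm_up_steps=500,
    decay_steps=0,
    lr_end=3e-5,    # cfg.lr_end, which now defaults to lr / 10
    num_cycles=2,   # cfg.n_restart_cycles, previously never passed through
)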
5 changes: 5 additions & 0 deletions sae_training/config.py
@@ -63,12 +63,14 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
l1_coefficient: float = 1e-3
lp_norm: float = 1
lr: float = 3e-4
lr_end: float | None = None # only used for cosine annealing, default is lr / 10
lr_scheduler_name: str = (
"constant" # constant, cosineannealing, cosineannealingwarmrestarts
)
lr_warm_up_steps: int = 500
lr_decay_steps: int = 0
train_batch_size: int = 4096
n_restart_cycles: int = 1 # only used for cosineannealingwarmrestarts

# Resampling protocol args
use_ghost_grads: bool = False # want to change this to true on some timeline.
@@ -112,6 +114,9 @@ def __post_init__(self):

self.device = torch.device(self.device)

if self.lr_end is None:
self.lr_end = self.lr / 10

unique_id = cast(
Any, wandb
).util.generate_id() # not sure why this type is erroring
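
The __post_init__ hunk above makes lr_end fall back to lr / 10 when it is left unset. A hypothetical stand-alone illustration of that defaulting logic; SchedulerOpts is an invented stand-in, not the real LanguageModelSAERunnerConfig:

from __future__ import annotations

from dataclasses import dataclass


@dataclass
class SchedulerOpts:
    # Invented stand-in mirroring only the scheduler-related fields shown above.
    lr: float = 3e-4
    lr_end: float | None = None  # resolved to lr / 10 in __post_init__
    lr_scheduler_name: str = "cosineannealing"
    n_restart_cycles: int = 1  # only read for cosineannealingwarmrestarts

    def __post_init__(self) -> None:
        if self.lr_end is None:
            self.lr_end = self.lr / 10


opts = SchedulerOpts(lr=1e-3)
assert opts.lr_end == 1e-4  # lr / 10 when lr_end is not given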
12 changes: 6 additions & 6 deletions sae_training/optim.py
@@ -6,19 +6,19 @@
import torch.optim.lr_scheduler as lr_scheduler


# None
# Linear Warmup and decay
# Constant
# Cosine Annealing with Warmup
# Cosine Annealing with Warmup / Restarts
# No default values specified so the type-checker can verify we don't forget any arguments.
def get_scheduler(
scheduler_name: str,
optimizer: optim.Optimizer,
training_steps: int,
lr: float,
warm_up_steps: int = 0,
decay_steps: int = 0,
num_cycles: int = 1,
lr_end: float = 0.0,
warm_up_steps: int,
decay_steps: int,
lr_end: float,
num_cycles: int,
) -> lr_scheduler.LRScheduler:
"""
Loosely based on this, seemed simpler to write this than import
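
The comment added in this hunk explains the signature change: with the default values removed, forgetting to pass a scheduler option is now a hard error rather than a silent fallback. A small sketch of that failure mode, assuming the package is importable; the dummy optimizer and values are illustrative:

import torch
from torch import nn, optim

from sae_training.optim import get_scheduler

optimizer = optim.Adam([nn.Parameter(torch.zeros(1))], lr=3e-4)

try:
    get_scheduler(
        "cosineannealing",
        optimizer,
        training_steps=10,
        lr=3e-4,
        warm_up_steps=0,
        decay_steps=0,
        lr_end=3e-5,
        # num_cycles deliberately omitted
    )
except TypeError as err:
    # e.g. "get_scheduler() missing 1 required positional argument: 'num_cycles'"
    print(err)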
5 changes: 4 additions & 1 deletion sae_training/train_sae_on_language_model.py
@@ -216,13 +216,16 @@ def _build_train_context(
n_frac_active_tokens = 0

optimizer = Adam(sae.parameters(), lr=sae.cfg.lr)
assert sae.cfg.lr_end is not None # this is set in config post-init
scheduler = get_scheduler(
sae.cfg.lr_scheduler_name,
lr=sae.cfg.lr,
optimizer=optimizer,
warm_up_steps=sae.cfg.lr_warm_up_steps,
decay_steps=sae.cfg.lr_decay_steps,
training_steps=total_training_steps,
lr_end=sae.cfg.lr / 10, # heuristic for now.
lr_end=sae.cfg.lr_end,
num_cycles=sae.cfg.n_restart_cycles,
)

return SAETrainContext(
54 changes: 49 additions & 5 deletions tests/unit/test_optim.py
@@ -31,19 +31,44 @@ def step(optimizer: Adam, scheduler: LRScheduler):

def test_get_scheduler_errors_on_uknown_scheduler(optimizer: Adam):
with pytest.raises(ValueError, match="Unsupported scheduler: unknown"):
get_scheduler("unknown", optimizer, lr=LR, training_steps=10)
get_scheduler(
"unknown",
optimizer,
lr=LR,
training_steps=10,
warm_up_steps=0,
decay_steps=0,
lr_end=0.0,
num_cycles=1,
)


def test_get_scheduler_constant(optimizer: Adam):
scheduler = get_scheduler("constant", optimizer, lr=LR, training_steps=4)
scheduler = get_scheduler(
"constant",
optimizer,
lr=LR,
training_steps=4,
warm_up_steps=0,
decay_steps=0,
lr_end=0.0,
num_cycles=1,
)
assert scheduler.get_last_lr() == [0.1]
step_times(3, optimizer, scheduler)
assert scheduler.get_last_lr() == [0.1]


def test_get_scheduler_constantwithwarmup(optimizer: Adam):
scheduler = get_scheduler(
"constant", optimizer, lr=LR, warm_up_steps=2, training_steps=4
"constant",
optimizer,
lr=LR,
warm_up_steps=2,
training_steps=4,
decay_steps=0,
lr_end=0.0,
num_cycles=1,
)
assert scheduler.get_last_lr() == [pytest.approx(0.05)]
step(optimizer, scheduler)
@@ -54,7 +79,14 @@ def test_get_scheduler_constantwithwarmup(optimizer: Adam):

def test_get_scheduler_linearwarmupdecay(optimizer: Adam):
scheduler = get_scheduler(
"constant", optimizer, lr=LR, warm_up_steps=2, decay_steps=4, training_steps=6
"constant",
optimizer,
lr=LR,
warm_up_steps=2,
decay_steps=4,
training_steps=6,
lr_end=0.0,
num_cycles=1,
)
# first, ramp up for 2 steps
assert scheduler.get_last_lr() == [0.05]
@@ -80,14 +112,23 @@ def test_get_scheduler_errors_if_lr_end_is_0_and_decay_is_set(optimizer: Adam):
optimizer,
lr=LR,
lr_end=0.0,
warm_up_steps=0,
decay_steps=2,
training_steps=6,
num_cycles=1,
)


def test_get_scheduler_cosineannealing(optimizer: Adam):
scheduler: Any = get_scheduler(
"cosineannealing", optimizer, lr=LR, training_steps=4, lr_end=0.05
"cosineannealing",
optimizer,
lr=LR,
training_steps=4,
lr_end=0.05,
warm_up_steps=0,
decay_steps=0,
num_cycles=1,
)
assert len(scheduler._schedulers) == 1
main_scheduler = scheduler._schedulers[0]
@@ -107,6 +148,7 @@ def test_get_scheduler_cosineannealing_with_warmup_and_decay():
training_steps=8,
decay_steps=2,
lr_end=lr_end,
num_cycles=1,
)
# first, ramp up for 2 steps
assert scheduler.get_last_lr() == [0.05]
@@ -147,6 +189,8 @@ def test_get_scheduler_cosineannealingwarmrestarts(optimizer: Adam):
training_steps=8,
lr_end=0.05,
num_cycles=2,
warm_up_steps=0,
decay_steps=0,
)
assert len(scheduler._schedulers) == 1
main_scheduler = scheduler._schedulers[0]
9 changes: 8 additions & 1 deletion tests/unit/test_train_sae_on_language_model.py
@@ -47,7 +47,14 @@ def build_train_ctx(
n_frac_active_tokens=n_frac_active_tokens,
optimizer=optimizer,
scheduler=get_scheduler(
"constant", lr=sae.cfg.lr, optimizer=optimizer, training_steps=1000
"constant",
lr=sae.cfg.lr,
optimizer=optimizer,
training_steps=1000,
lr_end=0,
warm_up_steps=0,
decay_steps=0,
num_cycles=1,
),
)

