Skip to content

Commit

Permalink
Enable setting Adam parameters (beta1, beta2) in config
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Apr 16, 2024
1 parent c558849 commit 1e53ede
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
2 changes: 2 additions & 0 deletions sae_lens/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class LanguageModelSAERunnerConfig:
prepend_bos: bool = True

# Training Parameters
+ adam_beta1: float | list[float] = 0
+ adam_beta2: float | list[float] = 0.999
mse_loss_normalization: Optional[str] = None
l1_coefficient: float | list[float] = 1e-3
lp_norm: float | list[float] = 1
Expand Down
9 changes: 8 additions & 1 deletion sae_lens/training/train_sae_on_language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,14 @@ def _build_train_context(
)
n_frac_active_tokens = 0

- optimizer = Adam(sae.parameters(), lr=sae.cfg.lr)
+ optimizer = Adam(
+ sae.parameters(),
+ lr=sae.cfg.lr,
+ betas=(
+ sae.cfg.adam_beta1, # type: ignore
+ sae.cfg.adam_beta2, # type: ignore
+ ),
+ )
assert sae.cfg.lr_end is not None # this is set in config post-init
scheduler = get_scheduler(
sae.cfg.lr_scheduler_name,
Expand Down

0 comments on commit 1e53ede

Please sign in to comment.