feat: Add torch compile (#129)
* Surface # of eval batches and # of eval sequences

* fix formatting

* config changes

* add compilation to lm_runner.py

* remove accidental print statement

* formatting fix
tomMcGrath authored May 8, 2024
1 parent 758a50b commit 5c41336
Showing 2 changed files with 35 additions and 1 deletion.
6 changes: 6 additions & 0 deletions sae_lens/training/config.py
@@ -64,7 +64,13 @@ class LanguageModelSAERunnerConfig:
     seed: int = 42
     dtype: str | torch.dtype = "float32"  # type: ignore #
     prepend_bos: bool = True
 
+    # Performance - see compilation section of lm_runner.py for info
+    autocast: bool = False  # autocast to autocast_dtype during training
+    compile_llm: bool = False  # use torch.compile on the LLM
+    llm_compilation_mode: str | None = None  # which torch.compile mode to use
+    compile_sae: bool = False  # use torch.compile on the SAE
+    sae_compilation_mode: str | None = None
 
     # Training Parameters
 
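For reference, a minimal sketch of how these new flags might be set when constructing the config. This exact usage is an assumption, not part of the commit; the mode strings follow the recommendation in the lm_runner.py comments further down, and any other required config fields are omitted for brevity.

from sae_lens.training.config import LanguageModelSAERunnerConfig

# Hypothetical usage sketch: enable compilation for both the LLM and the SAE,
# using the modes recommended in the lm_runner.py comments below.
cfg = LanguageModelSAERunnerConfig(
    compile_llm=True,
    llm_compilation_mode="max-autotune-no-cudagraphs",  # cudagraphs clash with the compiled SAE
    compile_sae=True,
    sae_compilation_mode="max-autotune",
)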
30 changes: 29 additions & 1 deletion sae_lens/training/lm_runner.py
@@ -1,6 +1,7 @@
 import traceback
 from typing import Any, cast
 
+import torch
 import wandb
 
 from sae_lens.training.config import LanguageModelSAERunnerConfig
@@ -71,9 +72,36 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
             id=cfg.wandb_id,
         )
 
+    # Compile model and SAE
+    # torch.compile can provide significant speedups (10-20% in testing)
+    # using max-autotune gives the best speedups but:
+    # (a) increases VRAM usage,
+    # (b) can't be used on both SAE and LM (some issue with cudagraphs), and
+    # (c) takes some time to compile
+    # optimal settings seem to be:
+    # use max-autotune on SAE and max-autotune-no-cudagraphs on LM
+    # (also pylance seems to really hate this)
+    if cfg.compile_llm:
+        model = torch.compile(
+            model,  # pyright: ignore [reportPossiblyUnboundVariable]
+            mode=cfg.llm_compilation_mode,
+        )
+
+    if cfg.compile_sae:
+        for (
+            k
+        ) in (
+            sparse_autoencoder.autoencoders.keys()  # pyright: ignore [reportPossiblyUnboundVariable]
+        ):
+            sae = sparse_autoencoder.autoencoders[  # pyright: ignore [reportPossiblyUnboundVariable]
+                k
+            ]
+            sae = torch.compile(sae, mode=cfg.sae_compilation_mode)
+            sparse_autoencoder.autoencoders[k] = sae  # type: ignore # pyright: ignore [reportPossiblyUnboundVariable]
+
     # train SAE
     sparse_autoencoder = train_sae_group_on_language_model(
-        model=model,  # pyright: ignore [reportPossiblyUnboundVariable]
+        model=model,  # pyright: ignore [reportPossiblyUnboundVariable]  # type: ignore
         sae_group=sparse_autoencoder,  # pyright: ignore [reportPossiblyUnboundVariable]
         activation_store=activations_loader,  # pyright: ignore [reportPossiblyUnboundVariable]
         train_contexts=train_contexts,
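As a self-contained illustration of the compile pattern above, the same two modes can be exercised with stand-in modules. These are assumed placeholders, not the actual LLM or SAE classes used by the runner.

import torch
import torch.nn as nn

# Standalone sketch (hypothetical stand-ins for the real models): compile the
# larger model without cudagraphs and the smaller one with full max-autotune,
# mirroring the recommendation in the comments above.
llm = nn.Linear(512, 512)   # stand-in for the language model
sae = nn.Linear(512, 2048)  # stand-in for the sparse autoencoder

llm = torch.compile(llm, mode="max-autotune-no-cudagraphs")
sae = torch.compile(sae, mode="max-autotune")

x = torch.randn(8, 512)
y = sae(llm(x))  # the first call triggers compilation; later calls reuse it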
