diff --git a/sae_lens/training/lm_runner.py b/sae_lens/training/lm_runner.py
index a3c2d673..eb20f598 100644
--- a/sae_lens/training/lm_runner.py
+++ b/sae_lens/training/lm_runner.py
@@ -14,6 +14,9 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
     """
     """
 
+    training_run_state = None
+    train_contexts = None
+
     if cfg.resume:
         try:
             checkpoint_path = cfg.get_resume_checkpoint_path()
@@ -41,8 +44,6 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
             cfg.resume = False
 
     if not cfg.resume:
-        training_run_state = None
-        train_contexts = None
         if cfg.from_pretrained_path is not None:
             (
                 model,
@@ -62,11 +63,10 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
         wandb.init(project=cfg.wandb_project, config=cast(Any, cfg), name=cfg.run_name)
 
     # train SAE
-    sparse_autoencoder = train_sae_group_on_language_model(
-        model=model,
-        sae_group=sparse_autoencoder,
-        activation_store=activations_loader,
+        model=model,  # pyright: ignore [reportPossiblyUnboundVariable]
+        sae_group=sparse_autoencoder,  # pyright: ignore [reportPossiblyUnboundVariable]
+        activation_store=activations_loader,  # pyright: ignore [reportPossiblyUnboundVariable]
         train_contexts=train_contexts,
         training_run_state=training_run_state,
         batch_size=cfg.train_batch_size,
diff --git a/sae_lens/training/train_sae_on_language_model.py b/sae_lens/training/train_sae_on_language_model.py
index a42f2f13..ff31cec9 100644
--- a/sae_lens/training/train_sae_on_language_model.py
+++ b/sae_lens/training/train_sae_on_language_model.py
@@ -104,8 +104,8 @@ def load(cls, path: str, sae: SparseAutoencoder, total_training_steps: int):
                 state_dict[attr.name] = value
             # non-tensor values (like int or bool)
             elif not type(value) is torch.Tensor:
-                state_dict[attr.name] = state_dict[attr.name].item()
-        ctx = cls(**state_dict)
+                state_dict[attr.name] = state_dict[attr.name].item()  # pyright: ignore [reportArgumentType]
+        ctx = cls(**state_dict)  # pyright: ignore [reportArgumentType]
         # if fine tuning, we need to set sae requires grad properly
         if ctx.finetuning:
             ctx.begin_finetuning(sae=sae)
@@ -129,9 +129,9 @@ class SAETrainingRunState:
     n_training_tokens: int = 0
     started_fine_tuning: bool = False
     checkpoint_paths: list[str] = field(default_factory=list)
-    torch_state: Optional[torch.ByteTensor] = None
-    torch_cuda_state: Optional[torch.ByteTensor] = None
-    numpy_state: Optional[tuple[str, np.ndarray[np.uint32], int, int, float]] = None
+    torch_state: Optional[torch.Tensor] = None
+    torch_cuda_state: Optional[list[torch.Tensor]] = None
+    numpy_state: Optional[dict[str, Any] | tuple[str, np.ndarray[Any, np.dtype[np.uint32]], int, int, float]] = None
     random_state: Optional[Any] = None
 
     def __post_init__(self):
@@ -145,9 +145,13 @@ def __post_init__(self):
             self.random_state = random.getstate()
 
     def set_random_state(self):
+        assert self.torch_state is not None
         torch.random.set_rng_state(self.torch_state)
+        assert self.torch_cuda_state is not None
         torch.cuda.set_rng_state_all(self.torch_cuda_state)
+        assert self.numpy_state is not None
         np.random.set_state(self.numpy_state)
+        assert self.random_state is not None
         random.setstate(self.random_state)
 
     @classmethod
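
Not part of the patch above: a minimal standalone sketch of the RNG capture/restore pattern that SAETrainingRunState implements, shown only to illustrate why the `assert ... is not None` guards satisfy pyright (`__post_init__` always populates each state before `set_random_state` reads it). All names below are illustrative, not the library's API.

# Illustrative sketch only (not from the patch): the save/restore pattern behind
# SAETrainingRunState, so a resumed run draws the same random sequence.
import random

import numpy as np
import torch

# Capture every RNG stream the training loop touches.
torch_state = torch.random.get_rng_state()
torch_cuda_state = torch.cuda.get_rng_state_all()  # empty list when no CUDA devices
numpy_state = np.random.get_state()
python_state = random.getstate()

# ... the run would be checkpointed here and the states reloaded on resume ...

# Restore; the asserts mirror the patch, narrowing Optional fields for pyright.
assert torch_state is not None
torch.random.set_rng_state(torch_state)
torch.cuda.set_rng_state_all(torch_cuda_state)
np.random.set_state(numpy_state)
random.setstate(python_state)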