Commit
pyright typing
Phylliida committed Apr 25, 2024
1 parent 9137d42 commit 920c1de
Showing 2 changed files with 15 additions and 11 deletions.
12 changes: 6 additions & 6 deletions sae_lens/training/lm_runner.py
@@ -14,6 +14,9 @@
 
 def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
     """ """
+    training_run_state = None
+    train_contexts = None
+
     if cfg.resume:
         try:
             checkpoint_path = cfg.get_resume_checkpoint_path()
@@ -41,8 +44,6 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
             cfg.resume = False
 
     if not cfg.resume:
-        training_run_state = None
-        train_contexts = None
         if cfg.from_pretrained_path is not None:
             (
                 model,
@@ -62,11 +63,10 @@ def language_model_sae_runner(cfg: LanguageModelSAERunnerConfig):
         wandb.init(project=cfg.wandb_project, config=cast(Any, cfg), name=cfg.run_name)
 
     # train SAE
-
     sparse_autoencoder = train_sae_group_on_language_model(
-        model=model,
-        sae_group=sparse_autoencoder,
-        activation_store=activations_loader,
+        model=model,  # pyright: ignore [reportPossiblyUnboundVariable]
+        sae_group=sparse_autoencoder,  # pyright: ignore [reportPossiblyUnboundVariable]
+        activation_store=activations_loader,  # pyright: ignore [reportPossiblyUnboundVariable]
         train_contexts=train_contexts,
         training_run_state=training_run_state,
         batch_size=cfg.train_batch_size,
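The lm_runner.py fix has two parts: training_run_state and train_contexts are now bound before any branching, and the names pyright still cannot prove are bound (model, sparse_autoencoder, and activations_loader are each assigned inside one of two mutually exclusive branches) get targeted suppressions. A minimal, self-contained sketch of both remedies; the function and values below are illustrative, not sae_lens code:

# Sketch of the two remedies in lm_runner.py; everything here is
# illustrative rather than taken from sae_lens.
from typing import Optional

def run(resume: bool) -> None:
    # Remedy 1: bind the name up front so every path sees a defined
    # (if None) value, as the commit does for training_run_state and
    # train_contexts.
    state: Optional[dict] = None
    if resume:
        state = {"step": 0}  # stand-in for a loaded checkpoint

    # Remedy 2: `model` is assigned in both branches, but pyright
    # cannot prove that at least one of them always runs, so the use
    # below is flagged unless suppressed.
    if resume:
        model = "restored from checkpoint"
    if not resume:
        model = "built from scratch"
    print(model, state)  # pyright: ignore [reportPossiblyUnboundVariable]

run(resume=False)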
14 changes: 9 additions & 5 deletions sae_lens/training/train_sae_on_language_model.py
@@ -104,8 +104,8 @@ def load(cls, path: str, sae: SparseAutoencoder, total_training_steps: int):
                 state_dict[attr.name] = value
             # non-tensor values (like int or bool)
             elif not type(value) is torch.Tensor:
-                state_dict[attr.name] = state_dict[attr.name].item()
-        ctx = cls(**state_dict)
+                state_dict[attr.name] = state_dict[attr.name].item()  # pyright: ignore [reportArgumentType]
+        ctx = cls(**state_dict)  # pyright: ignore [reportArgumentType]
         # if fine tuning, we need to set sae requires grad properly
         if ctx.finetuning:
             ctx.begin_finetuning(sae=sae)
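The reportArgumentType suppressions in load() arise because state_dict carries tensor-typed values: writing the scalar from .item() back into it, and unpacking it into the dataclass constructor, both hand pyright an argument type it considers incompatible, even though both are correct at runtime. A tiny sketch of the same shape of error, with illustrative names:

# Sketch (not sae_lens code) of why pyright reports
# reportArgumentType around .item() and cls(**state_dict).
from dataclasses import dataclass

import torch

@dataclass
class Ctx:
    n_steps: int = 0

state: dict[str, torch.Tensor] = {"n_steps": torch.tensor(3)}
# .item() yields a plain int, which pyright rejects as a value for
# dict[str, torch.Tensor]; at runtime the conversion is exactly
# what we want.
state["n_steps"] = state["n_steps"].item()  # pyright: ignore [reportArgumentType]
# Unpacking tensor-typed values into an int field is likewise flagged.
ctx = Ctx(**state)  # pyright: ignore [reportArgumentType]
print(ctx)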
@@ -129,9 +129,9 @@ class SAETrainingRunState:
     n_training_tokens: int = 0
     started_fine_tuning: bool = False
     checkpoint_paths: list[str] = field(default_factory=list)
-    torch_state: Optional[torch.ByteTensor] = None
-    torch_cuda_state: Optional[torch.ByteTensor] = None
-    numpy_state: Optional[tuple[str, np.ndarray[np.uint32], int, int, float]] = None
+    torch_state: Optional[torch.Tensor] = None
+    torch_cuda_state: Optional[list[torch.Tensor]] = None
+    numpy_state: Optional[dict[str, Any] | tuple[str, np.ndarray[Any, np.dtype[np.uint32]], int, int, float]] = None
     random_state: Optional[Any] = None
 
     def __post_init__(self):
@@ -145,9 +145,13 @@ def __post_init__(self):
             self.random_state = random.getstate()
 
     def set_random_state(self):
+        assert self.torch_state is not None
         torch.random.set_rng_state(self.torch_state)
+        assert self.torch_cuda_state is not None
         torch.cuda.set_rng_state_all(self.torch_cuda_state)
+        assert self.numpy_state is not None
         np.random.set_state(self.numpy_state)
+        assert self.random_state is not None
         random.setstate(self.random_state)
 
     @classmethod
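The remaining changes correct the RNG-state annotations to match what the libraries actually return (torch.random.get_rng_state() is typed as torch.Tensor, torch.cuda.get_rng_state_all() returns one tensor per CUDA device, and np.random.get_state() returns either a dict or the legacy tuple depending on its legacy flag), and use assert ... is not None so pyright narrows each Optional field before the setter call. A standalone sketch of the same pattern; the RngSnapshot class is illustrative, not sae_lens code:

# Standalone sketch (not sae_lens code): capture RNG state into
# Optional fields, then narrow with asserts before restoring.
import random
from dataclasses import dataclass
from typing import Any, Optional

import numpy as np
import torch

@dataclass
class RngSnapshot:
    torch_state: Optional[torch.Tensor] = None              # get_rng_state()
    torch_cuda_state: Optional[list[torch.Tensor]] = None   # one per device
    numpy_state: Optional[Any] = None                       # dict or legacy tuple
    random_state: Optional[Any] = None

    def capture(self) -> None:
        self.torch_state = torch.random.get_rng_state()
        self.torch_cuda_state = torch.cuda.get_rng_state_all()
        self.numpy_state = np.random.get_state()
        self.random_state = random.getstate()

    def restore(self) -> None:
        # Each assert narrows Optional[T] to T, so the setters below
        # type-check without ignore comments.
        assert self.torch_state is not None
        torch.random.set_rng_state(self.torch_state)
        assert self.torch_cuda_state is not None
        torch.cuda.set_rng_state_all(self.torch_cuda_state)
        assert self.numpy_state is not None
        np.random.set_state(self.numpy_state)
        assert self.random_state is not None
        random.setstate(self.random_state)

snap = RngSnapshot()
snap.capture()
before = torch.rand(3)
snap.restore()
after = torch.rand(3)  # identical draw: the state was rewound
assert torch.equal(before, after)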
