Skip to content

Commit

Permalink
fixing tests
Browse files — browse the repository at this point in the history
  • Loading branch information
chanind committed Nov 11, 2024
1 parent 1742bf7 commit 49212f5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
9 changes: 6 additions & 3 deletions sae_lens/training/training_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ class TrainingSAE(SAE):
device: torch.device

def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):

super().__init__(cfg)
base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict())
super().__init__(base_sae_cfg)
self.cfg = cfg # type: ignore

if cfg.architecture == "standard":
Expand All @@ -244,6 +244,10 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
elif cfg.architecture == "jumprelu":
self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_jumprelu
self.bandwidth = cfg.jumprelu_bandwidth
self.log_threshold.data = torch.ones(
self.cfg.d_sae, dtype=self.dtype, device=self.device
) * np.log(cfg.jumprelu_init_threshold)

else:
raise ValueError(f"Unknown architecture: {cfg.architecture}")

Expand All @@ -263,7 +267,6 @@ def initialize_weights_jumprelu(self):
# same as the superclass, except we use a log_threshold parameter instead of threshold
self.log_threshold = nn.Parameter(
torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
* np.log(self.cfg.jumprelu_init_threshold)
)
self.b_enc = nn.Parameter(
torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
Expand Down
6 changes: 5 additions & 1 deletion tests/unit/training/test_training_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,15 @@ def test_TrainingSAE_encode_returns_same_value_as_encode_with_hidden_pre(


def test_TrainingSAE_initializes_only_with_log_threshold_if_jumprelu():
    """A jumprelu TrainingSAE should expose a trainable ``log_threshold``
    parameter (initialized from ``jumprelu_init_threshold``) and must NOT
    register a raw ``threshold`` parameter — ``threshold`` is derived from
    ``log_threshold`` instead.
    """
    cfg = build_sae_cfg(architecture="jumprelu", jumprelu_init_threshold=0.01)
    sae = TrainingSAE(TrainingSAEConfig.from_sae_runner_config(cfg))
    param_names = dict(sae.named_parameters()).keys()
    # Only the log-space parameter should be trainable.
    assert "log_threshold" in param_names
    assert "threshold" not in param_names
    # The derived threshold should match the configured initial value
    # (i.e. log_threshold was initialized to log(jumprelu_init_threshold)).
    assert torch.allclose(
        sae.threshold,
        torch.ones_like(sae.log_threshold.data) * cfg.jumprelu_init_threshold,
    )


def test_TrainingSAE_jumprelu_sae_encoding():
Expand Down

0 comments on commit 49212f5

Please sign in to comment.