testing that sae.forward() with error term works with hooks
chanind committed Oct 17, 2024
1 parent 9b2f417 commit ff865b3
Showing 1 changed file with 27 additions and 4 deletions: tests/unit/training/test_sae_basic.py
@@ -5,6 +5,7 @@
 import pytest
 import torch
 from torch import nn
+from transformer_lens.hook_points import HookPoint
 
 from sae_lens.config import LanguageModelSAERunnerConfig
 from sae_lens.sae import SAE, _disable_hooks
@@ -146,7 +147,7 @@ def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig):
 
 
 def test_sae_save_and_load_from_pretrained(tmp_path: Path) -> None:
-    cfg = build_sae_cfg(device="cpu")
+    cfg = build_sae_cfg()
     model_path = str(tmp_path)
     sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
     sae_state_dict = sae.state_dict()
@@ -172,7 +173,7 @@ def test_sae_save_and_load_from_pretrained(tmp_path: Path) -> None:
 
 
 def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None:
-    cfg = build_sae_cfg(architecture="gated", device="cpu")
+    cfg = build_sae_cfg(architecture="gated")
     model_path = str(tmp_path)
     sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
     sae_state_dict = sae.state_dict()
@@ -230,7 +231,7 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
 # def test_sae_save_and_load_from_pretrained_lacks_scaling_factor(
 #     tmp_path: Path,
 # ) -> None:
-#     cfg = build_sae_cfg(device="cpu")
+#     cfg = build_sae_cfg()
 #     model_path = str(tmp_path)
 #     sparse_autoencoder = saeBase(**cfg.get_sae_base_parameters())
 #     sparse_autoencoder_state_dict = sparse_autoencoder.state_dict()
@@ -269,7 +270,7 @@ def test_sae_get_name_returns_correct_name_from_cfg_vals() -> None:
 
 
 def test_sae_move_between_devices() -> None:
-    cfg = build_sae_cfg(device="cpu")
+    cfg = build_sae_cfg()
     sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
 
     sae.to("meta")
@@ -422,3 +423,25 @@ def test_disable_hooks_temporarily_stops_hooks_from_running():
     assert disabled_cache.keys() == set()
     for key in orig_cache.keys():
         assert torch.allclose(orig_cache[key], subseq_cache[key])
+
+
+@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu"])
+def test_sae_forward_pass_works_with_error_term_and_hooks(architecture: str):
+    cfg = build_sae_cfg(architecture=architecture, d_in=32, d_sae=64)
+    sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
+    sae.use_error_term = True
+    sae_in = torch.randn(10, cfg.d_in)
+    original_out, original_cache = sae.run_with_cache(sae_in)
+
+    def ablate_hooked_sae(acts: torch.Tensor, hook: HookPoint):
+        acts[:, :] = 20  # This is absurd
+        return acts
+
+    with sae.hooks(fwd_hooks=[("hook_sae_acts_post", ablate_hooked_sae)]):
+        ablated_out, ablated_cache = sae.run_with_cache(sae_in)
+
+    assert not torch.allclose(original_out, ablated_out)
+    assert torch.all(ablated_cache["hook_sae_acts_post"] == 20)
+    assert torch.allclose(
+        original_cache["hook_sae_error"], ablated_cache["hook_sae_error"]
+    )
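
Context for the new test: with use_error_term enabled, the SAE adds sae_error = input - reconstruction back onto its output, so an un-hooked forward pass returns (approximately) its input, and a hook that edits hook_sae_acts_post changes the output while the cached error stays the same. The sketch below is not part of the commit; it illustrates that property under stated assumptions, reusing the build_sae_cfg helper from these tests. The tests.unit.helpers import path and the tolerance are assumptions, and it presumes the default config applies no activation normalization.

# Minimal sketch (assumptions noted above) of the error-term behaviour the test relies on.
import torch

from sae_lens.sae import SAE
from tests.unit.helpers import build_sae_cfg  # hypothetical import path for the test helper

cfg = build_sae_cfg(d_in=32, d_sae=64)
sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
sae.use_error_term = True  # forward() returns reconstruction + (input - reconstruction)

x = torch.randn(10, cfg.d_in)
out, cache = sae.run_with_cache(x)

# With the error term enabled and no hooks editing activations, the SAE should
# round-trip its input, since the cached error exactly fills the reconstruction gap.
assert torch.allclose(out, x, atol=1e-5)
print(cache["hook_sae_error"].abs().mean())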
