Commit 46a6751

changes from CR

chanind committed Nov 12, 2024
1 parent 49212f5 commit 46a6751

Showing 5 changed files with 68 additions and 153 deletions.
22 changes: 1 addition & 21 deletions sae_lens/sae.py

```diff
@@ -314,27 +314,7 @@ def initialize_weights_jumprelu(self):
         self.threshold = nn.Parameter(
             torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
         )
-        self.b_enc = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-
-        self.W_dec = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-        self.W_enc = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
-        )
+        self.initialize_weights_basic()
 
     @overload
     def to(
```
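Both JumpReLU initializers now delegate their shared encoder/decoder setup to `initialize_weights_basic`. As a reading aid, here is a minimal sketch of what that shared initializer plausibly does, reconstructed from the lines deleted above; the actual `sae_lens` method may differ in details such as decoder normalization options.

```python
import torch
import torch.nn as nn


def initialize_weights_basic(self) -> None:
    # Sketch reconstructed from the deleted code above: zero biases plus
    # Kaiming-uniform encoder/decoder weights. Not the verbatim sae_lens
    # implementation.
    self.b_enc = nn.Parameter(
        torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
    )
    self.b_dec = nn.Parameter(
        torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
    )
    self.W_dec = nn.Parameter(
        torch.nn.init.kaiming_uniform_(
            torch.empty(
                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
            )
        )
    )
    self.W_enc = nn.Parameter(
        torch.nn.init.kaiming_uniform_(
            torch.empty(
                self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
            )
        )
    )
```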
24 changes: 2 additions & 22 deletions sae_lens/training/training_sae.py

```diff
@@ -266,29 +266,9 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
     def initialize_weights_jumprelu(self):
         # same as the superclass, except we use a log_threshold parameter instead of threshold
         self.log_threshold = nn.Parameter(
-            torch.ones(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-        self.b_enc = nn.Parameter(
-            torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
-        )
-
-        self.W_dec = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-        self.W_enc = nn.Parameter(
-            torch.nn.init.kaiming_uniform_(
-                torch.empty(
-                    self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
-                )
-            )
-        )
-        self.b_dec = nn.Parameter(
-            torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
+            torch.empty(self.cfg.d_sae, dtype=self.dtype, device=self.device)
         )
+        self.initialize_weights_basic()
 
     @property
     def threshold(self) -> torch.Tensor:
```
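The training variant keeps a `log_threshold` parameter instead of `threshold`, and after this change allocates it with `torch.empty` rather than `torch.ones`; judging from the retained test in `test_training_sae.py`, which expects `sae.threshold` to equal `cfg.jumprelu_init_threshold`, the values are presumably filled in from the config during initialization. A minimal sketch of the log-space parameterization, assuming the `threshold` property simply exponentiates, as the new tests do with `torch.exp(sae.log_threshold)`:

```python
import math

import torch
import torch.nn as nn


class LogThresholdSketch(nn.Module):
    """Illustrative stand-in, not the sae_lens TrainingSAE class."""

    def __init__(self, d_sae: int, init_threshold: float):
        super().__init__()
        # Storing the threshold in log space keeps it strictly positive
        # under unconstrained gradient updates; initializing from a config
        # value like this is a hypothetical reading of the diff.
        self.log_threshold = nn.Parameter(
            torch.full((d_sae,), math.log(init_threshold))
        )

    @property
    def threshold(self) -> torch.Tensor:
        return torch.exp(self.log_threshold)
```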
65 changes: 65 additions & 0 deletions tests/unit/training/test_jumprelu_sae.py

```diff
@@ -0,0 +1,65 @@
+import pytest
+import torch
+
+from sae_lens.training.training_sae import JumpReLU, TrainingSAE
+from tests.unit.helpers import build_sae_cfg
+
+
+def test_jumprelu_sae_encoding():
+    cfg = build_sae_cfg(architecture="jumprelu")
+    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
+
+    batch_size = 32
+    d_in = sae.cfg.d_in
+    d_sae = sae.cfg.d_sae
+
+    x = torch.randn(batch_size, d_in)
+    feature_acts, hidden_pre = sae.encode_with_hidden_pre_jumprelu(x)
+
+    assert feature_acts.shape == (batch_size, d_sae)
+    assert hidden_pre.shape == (batch_size, d_sae)
+
+    # Check the JumpReLU thresholding
+    sae_in = sae.process_sae_in(x)
+    expected_hidden_pre = sae_in @ sae.W_enc + sae.b_enc
+    threshold = torch.exp(sae.log_threshold)
+    expected_feature_acts = JumpReLU.apply(
+        expected_hidden_pre, threshold, sae.bandwidth
+    )
+
+    assert torch.allclose(feature_acts, expected_feature_acts, atol=1e-6)  # type: ignore
+
+
+def test_jumprelu_sae_training_forward_pass():
+    cfg = build_sae_cfg(architecture="jumprelu")
+    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
+
+    batch_size = 32
+    d_in = sae.cfg.d_in
+
+    x = torch.randn(batch_size, d_in)
+    train_step_output = sae.training_forward_pass(
+        sae_in=x,
+        current_l1_coefficient=sae.cfg.l1_coefficient,
+    )
+
+    assert train_step_output.sae_out.shape == (batch_size, d_in)
+    assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae)
+    assert (
+        pytest.approx(train_step_output.loss.detach(), rel=1e-3)
+        == (
+            train_step_output.losses["mse_loss"] + train_step_output.losses["l0_loss"]
+        ).item()  # type: ignore
+    )
+
+    expected_mse_loss = (
+        (torch.pow((train_step_output.sae_out - x.float()), 2))
+        .sum(dim=-1)
+        .mean()
+        .detach()
+        .float()
+    )
+
+    assert (
+        pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss  # type: ignore
+    )
```
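For context on what these tests exercise: a JumpReLU activation passes each pre-activation through unchanged where it exceeds a learned per-feature threshold and zeroes it otherwise. Below is a sketch of the forward semantics the assertions above rely on; the library's `JumpReLU.apply` additionally takes a `bandwidth` argument, which affects only the straight-through backward pass used to obtain gradients with respect to the threshold.

```python
import torch


def jumprelu_forward(hidden_pre: torch.Tensor, threshold: torch.Tensor) -> torch.Tensor:
    # JumpReLU(x) = x * H(x - threshold): identity above the per-feature
    # threshold, zero at or below it (H is the Heaviside step function).
    return hidden_pre * (hidden_pre > threshold)
```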
27 changes: 0 additions & 27 deletions tests/unit/training/test_sae_basic.py

```diff
@@ -9,7 +9,6 @@
 
 from sae_lens.config import LanguageModelSAERunnerConfig
 from sae_lens.sae import SAE, _disable_hooks
-from sae_lens.training.training_sae import TrainingSAE
 from tests.unit.helpers import build_sae_cfg
 
 
@@ -203,32 +202,6 @@ def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None:
     assert torch.allclose(sae_out_1, sae_out_2)
 
 
-def test_sae_save_and_load_from_pretrained_jumprelu(tmp_path: Path) -> None:
-    cfg = build_sae_cfg(architecture="gated")
-    model_path = str(tmp_path)
-    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
-    sae_state_dict = sae.state_dict()
-    sae.save_model(model_path)
-
-    assert os.path.exists(model_path)
-
-    sae_loaded = SAE.load_from_pretrained(model_path, device="cpu")
-
-    sae_loaded_state_dict = sae_loaded.state_dict()
-
-    # check state_dict matches the original
-    for key in sae.state_dict().keys():
-        assert torch.allclose(
-            sae_state_dict[key],
-            sae_loaded_state_dict[key],
-        )
-
-    sae_in = torch.randn(10, cfg.d_in, device=cfg.device)
-    sae_out_1 = sae(sae_in)
-    sae_out_2 = sae_loaded(sae_in)
-    assert torch.allclose(sae_out_1, sae_out_2)
-
-
 def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
     cfg = build_sae_cfg(activation_fn_kwargs={"k": 30})
     model_path = str(tmp_path)
```
83 changes: 0 additions & 83 deletions tests/unit/training/test_training_sae.py

```diff
@@ -64,86 +64,3 @@ def test_TrainingSAE_initializes_only_with_log_threshold_if_jumprelu():
         sae.threshold,
         torch.ones_like(sae.log_threshold.data) * cfg.jumprelu_init_threshold,
     )
-
-
-def test_TrainingSAE_jumprelu_sae_encoding():
-    cfg = build_sae_cfg(architecture="jumprelu")
-    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
-
-    batch_size = 32
-    d_in = sae.cfg.d_in
-    d_sae = sae.cfg.d_sae
-
-    x = torch.randn(batch_size, d_in)
-    feature_acts, hidden_pre = sae.encode_with_hidden_pre_jumprelu(x)
-
-    assert feature_acts.shape == (batch_size, d_sae)
-    assert hidden_pre.shape == (batch_size, d_sae)
-
-    # Check the JumpReLU thresholding
-    sae_in = sae.process_sae_in(x)
-    expected_hidden_pre = sae_in @ sae.W_enc + sae.b_enc
-    threshold = torch.exp(sae.log_threshold)
-    expected_feature_acts = JumpReLU.apply(
-        expected_hidden_pre, threshold, sae.bandwidth
-    )
-
-    assert torch.allclose(feature_acts, expected_feature_acts, atol=1e-6)  # type: ignore
-
-
-def test_TrainingSAE_jumprelu_sae_training_forward_pass():
-    cfg = build_sae_cfg(architecture="jumprelu")
-    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
-
-    batch_size = 32
-    d_in = sae.cfg.d_in
-
-    x = torch.randn(batch_size, d_in)
-    train_step_output = sae.training_forward_pass(
-        sae_in=x,
-        current_l1_coefficient=sae.cfg.l1_coefficient,
-    )
-
-    assert train_step_output.sae_out.shape == (batch_size, d_in)
-    assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae)
-    assert (
-        pytest.approx(train_step_output.loss.detach(), rel=1e-3)
-        == (
-            train_step_output.losses["mse_loss"] + train_step_output.losses["l0_loss"]
-        ).item()  # type: ignore
-    )
-
-    expected_mse_loss = (
-        (torch.pow((train_step_output.sae_out - x.float()), 2))
-        .sum(dim=-1)
-        .mean()
-        .detach()
-        .float()
-    )
-
-    assert (
-        pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss  # type: ignore
-    )
-
-
-def test_TrainingSAE_jumprelu_save_and_load(tmp_path: Path):
-    cfg = build_sae_cfg(architecture="jumprelu")
-    training_sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
-
-    training_sae.save_model(str(tmp_path))
-
-    loaded_training_sae = TrainingSAE.load_from_pretrained(str(tmp_path))
-    loaded_sae = SAE.load_from_pretrained(str(tmp_path))
-
-    assert training_sae.cfg.to_dict() == loaded_training_sae.cfg.to_dict()
-    for param_name, param in training_sae.named_parameters():
-        assert torch.allclose(param, loaded_training_sae.state_dict()[param_name])
-
-    test_input = torch.randn(32, cfg.d_in)
-    training_sae_out = training_sae.encode_with_hidden_pre_fn(test_input)[0]
-    loaded_training_sae_out = loaded_training_sae.encode_with_hidden_pre_fn(test_input)[
-        0
-    ]
-    loaded_sae_out = loaded_sae.encode(test_input)
-    assert torch.allclose(training_sae_out, loaded_training_sae_out)
-    assert torch.allclose(training_sae_out, loaded_sae_out)
```
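Both copies of the training forward-pass test (the new one above and the one removed here) pin down the loss decomposition: the total loss equals `mse_loss + l0_loss`, with the MSE summed over the feature dimension and averaged over the batch. A sketch of that decomposition follows, using a hard, non-differentiable L0 count where the library trains through a smoothed straight-through estimate; the `l0_coefficient` name is illustrative (the config field is called `l1_coefficient`).

```python
import torch


def jumprelu_training_loss_sketch(
    sae_out: torch.Tensor,
    sae_in: torch.Tensor,
    feature_acts: torch.Tensor,
    l0_coefficient: float,
) -> torch.Tensor:
    # MSE summed over features, averaged over the batch, matching the
    # expected_mse_loss computation in the tests.
    mse_loss = (sae_out - sae_in.float()).pow(2).sum(dim=-1).mean()
    # Hard L0 penalty: mean number of active features per example. The
    # real training loss uses a differentiable estimate of this count.
    l0_loss = l0_coefficient * (feature_acts > 0).float().sum(dim=-1).mean()
    return mse_loss + l0_loss
```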
