fix: fixing jumprelu encode and save/load #373
Conversation
Informally tagging @anthonyduong9 as a reviewer
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files

@@           Coverage Diff            @@
##            main     #373      +/-  ##
=========================================
+ Coverage   72.46%   72.74%    +0.28%
=========================================
  Files          22       22
  Lines        3258     3266        +8
  Branches      431      431
=========================================
+ Hits         2361     2376       +15
+ Misses        767      762        -5
+ Partials      130      128        -2

☔ View full report in Codecov by Sentry.
sae_lens/training/training_sae.py (outdated)
self.b_enc = nn.Parameter(
    torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
)

self.W_dec = nn.Parameter(
    torch.nn.init.kaiming_uniform_(
        torch.empty(
            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
        )
    )
)
self.W_enc = nn.Parameter(
    torch.nn.init.kaiming_uniform_(
        torch.empty(
            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
        )
    )
)
self.b_dec = nn.Parameter(
    torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
)
This block is also in SAE.initialize_weights_basic() and SAE.initialize_weights_jumprelu() (where we have a comment about the parameters being the same as in the former). Extracting it to a function (e.g. a method on SAE), or potentially refactoring so we can just call SAE.initialize_weights_basic() here and in SAE.initialize_weights_jumprelu(), would reduce a lot of code and make it self-documenting that the three initialize_ methods are almost the same. I'm not sure if we should do it in this PR or a separate one though.
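For illustration only, a sketch of what the extracted helper might look like. The helper name and its placement are hypothetical, not the actual SAELens API:

import torch
import torch.nn as nn

def _init_basic_params(sae) -> None:
    # Hypothetical shared helper: the b_enc / W_dec / W_enc / b_dec initialization
    # that is currently repeated across the initialize_weights_* methods.
    sae.b_enc = nn.Parameter(
        torch.zeros(sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
    )
    sae.W_dec = nn.Parameter(
        torch.nn.init.kaiming_uniform_(
            torch.empty(sae.cfg.d_sae, sae.cfg.d_in, dtype=sae.dtype, device=sae.device)
        )
    )
    sae.W_enc = nn.Parameter(
        torch.nn.init.kaiming_uniform_(
            torch.empty(sae.cfg.d_in, sae.cfg.d_sae, dtype=sae.dtype, device=sae.device)
        )
    )
    sae.b_dec = nn.Parameter(
        torch.zeros(sae.cfg.d_in, dtype=sae.dtype, device=sae.device)
    )

Each initialize_weights_* method would then only add its architecture-specific parameters (e.g. the jumprelu threshold) on top of this.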
aah good call - I'll make that change
def test_TrainingSAE_jumprelu_save_and_load(tmp_path: Path):
    cfg = build_sae_cfg(architecture="jumprelu")
    training_sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

    training_sae.save_model(str(tmp_path))

    loaded_training_sae = TrainingSAE.load_from_pretrained(str(tmp_path))
    loaded_sae = SAE.load_from_pretrained(str(tmp_path))

    assert training_sae.cfg.to_dict() == loaded_training_sae.cfg.to_dict()
    for param_name, param in training_sae.named_parameters():
        assert torch.allclose(param, loaded_training_sae.state_dict()[param_name])

    test_input = torch.randn(32, cfg.d_in)
    training_sae_out = training_sae.encode_with_hidden_pre_fn(test_input)[0]
    loaded_training_sae_out = loaded_training_sae.encode_with_hidden_pre_fn(test_input)[0]
    loaded_sae_out = loaded_sae.encode(test_input)
    assert torch.allclose(training_sae_out, loaded_training_sae_out)
    assert torch.allclose(training_sae_out, loaded_sae_out)
This test looks like a test that I wrote, except I just realized I accidentally passed "gated" instead of "jumprelu".
SAELens/tests/unit/training/test_sae_basic.py, lines 206 to 229 in 156ddc9:
def test_sae_save_and_load_from_pretrained_jumprelu(tmp_path: Path) -> None:
    cfg = build_sae_cfg(architecture="gated")
    model_path = str(tmp_path)
    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
    sae_state_dict = sae.state_dict()
    sae.save_model(model_path)
    assert os.path.exists(model_path)
    sae_loaded = SAE.load_from_pretrained(model_path, device="cpu")
    sae_loaded_state_dict = sae_loaded.state_dict()
    # check state_dict matches the original
    for key in sae.state_dict().keys():
        assert torch.allclose(
            sae_state_dict[key],
            sae_loaded_state_dict[key],
        )
    sae_in = torch.randn(10, cfg.d_in, device=cfg.device)
    sae_out_1 = sae(sae_in)
    sae_out_2 = sae_loaded(sae_in)
    assert torch.allclose(sae_out_1, sae_out_2)
So we should delete the test that I wrote too.
cfg = build_sae_cfg(architecture="jumprelu")
sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
We have this block in three of the tests, so it might help to extract it to a fixture.
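For example, the repeated setup could become something like this (the fixture name is hypothetical; build_sae_cfg and TrainingSAE would come from the test module's existing imports):

import pytest

@pytest.fixture
def jumprelu_training_sae():
    # Hypothetical fixture wrapping the repeated two-line setup used in several tests.
    cfg = build_sae_cfg(architecture="jumprelu")
    return TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())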
I'm not a huge fan of pulling stuff out into fixtures unless it saves a lot of lines of code. Making this into a fixture means you can't just look at the test and understand what it's doing; you also need to scroll to look at the fixture. Since this is just 2 lines of code, I'd rather leave it as is IMO.
def test_TrainingSAE_jumprelu_sae_training_forward_pass():
    cfg = build_sae_cfg(architecture="jumprelu")
    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

    batch_size = 32
    d_in = sae.cfg.d_in

    x = torch.randn(batch_size, d_in)
    train_step_output = sae.training_forward_pass(
        sae_in=x,
        current_l1_coefficient=sae.cfg.l1_coefficient,
    )

    assert train_step_output.sae_out.shape == (batch_size, d_in)
    assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae)
    assert (
        pytest.approx(train_step_output.loss.detach(), rel=1e-3)
        == (
            train_step_output.losses["mse_loss"] + train_step_output.losses["l0_loss"]
        ).item()  # type: ignore
    )

    expected_mse_loss = (
        (torch.pow((train_step_output.sae_out - x.float()), 2))
        .sum(dim=-1)
        .mean()
        .detach()
        .float()
    )

    assert (
        pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss  # type: ignore
    )
What's the reasoning for moving this out of test_jumprelu_sae.py? I created that file as I was pattern matching with test_gated_sae.py and probably other files, but am wondering if we should move tests from test_gated_sae.py into here as well (though not in this PR), and generally, how we want to organize tests.
oh my bad, I didn't see test_gated_sae.py - I would argue that should be in here too. I usually aim to have a 1-to-1 mapping of test file to library file, so if there's sae_lens/training/training_sae.py then I would expect a corresponding tests/training/test_training_sae.py. I'm not a fan of having tests that don't map onto source files, since it becomes hard to know where tests should go and where to look to see whether something is tested. We should probably write up something about how we expect code / tests to be structured, since I don't think this was discussed officially. Regardless, I think we should have some deterministic way to structure tests so it's not arbitrary where things are tested.
It looks like there are a lot of test files related to the SAE class that don't follow this paradigm, so I reverted the changes to test_jumprelu_sae.py pending a larger discussion.
Force-pushed from 46a6751 to 8052ed1 (Compare)
Force-pushed from 8052ed1 to 4461ccb (Compare)
Description
This PR fixes some minor issues related to jumprelu SAEs:
- There was both a threshold param and a log_threshold param on TrainingSAE, and calling encode() incorrectly used the threshold while encode_with_hidden_pre correctly used the log_threshold (see the sketch below)
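As a rough sketch of the intended relationship (illustrative only, not the actual TrainingSAE code): only log_threshold is a trained parameter, and the threshold used by encode() is derived from it, so the two encode paths can't disagree.

import torch
import torch.nn as nn

class JumpReLUEncoderSketch(nn.Module):
    # Minimal sketch, assuming the trained parameter is log_threshold and the
    # effective threshold is exp(log_threshold).
    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.empty(d_in, d_sae).normal_())
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.log_threshold = nn.Parameter(torch.zeros(d_sae))

    @property
    def threshold(self) -> torch.Tensor:
        # Derive the threshold from the trained log_threshold rather than
        # storing it as a separate parameter.
        return torch.exp(self.log_threshold)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        hidden_pre = x @ self.W_enc + self.b_enc
        # JumpReLU: pass pre-activations through only where they exceed the
        # (positive) threshold; everything else is zeroed.
        return torch.relu(hidden_pre) * (hidden_pre > self.threshold)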
Type of change
Please delete options that are not relevant.
Checklist:
- You have tested formatting, typing and unit tests (acceptance tests not currently in use): run make check-ci to check format and linting (you can run make format to format code if needed).