
fix: fixing jumprelu encode and save/load #373

Merged
merged 3 commits from jumprelu-fixes into main on Nov 12, 2024

Conversation

@chanind (Collaborator) commented Nov 11, 2024

Description

This PR fixes some minor issues related to jumprelu SAEs:

  • fixes a bug where we were accidentally creating both a threshold param and a log_threshold param on TrainingSAE; encode() incorrectly used the threshold param, while encode_with_hidden_pre() correctly used log_threshold (see the sketch after this list)
  • enables loading saved jumprelu SAEs for further training
  • fixes a bug where jumprelu bandwidth and init threshold weren't being saved along with the rest of the config
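
A minimal sketch of the intended behavior (illustrative only, not the actual TrainingSAE implementation; the class and dimension names are made up): only a log_threshold parameter is stored, and both encode() and encode_with_hidden_pre() derive the threshold from it, so the two code paths cannot drift apart.

import torch
import torch.nn as nn


class JumpReLUSketch(nn.Module):
    # Toy jumprelu encoder: a single log_threshold parameter shared by both encode paths.
    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_in, d_sae) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        # the only threshold-related parameter; exp() keeps the threshold positive
        self.log_threshold = nn.Parameter(torch.zeros(d_sae))

    @property
    def threshold(self) -> torch.Tensor:
        return torch.exp(self.log_threshold)

    def encode_with_hidden_pre(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        hidden_pre = x @ self.W_enc + self.b_enc
        # JumpReLU: keep pre-activations only where they exceed the (positive) threshold
        feature_acts = hidden_pre * (hidden_pre > self.threshold)
        return feature_acts, hidden_pre

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # must reuse the same thresholding as encode_with_hidden_pre()
        return self.encode_with_hidden_pre(x)[0]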

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • This change requires a documentation update

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and unit tests (acceptance tests not currently in use)

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

@chanind (Collaborator, Author) commented Nov 11, 2024

Informally tagging @anthonyduong9 as a reviewer

codecov bot commented Nov 11, 2024

Codecov Report

Attention: Patch coverage is 93.33333% with 2 lines in your changes missing coverage. Please review.

Project coverage is 72.74%. Comparing base (aa98caf) to head (4461ccb).
Report is 5 commits behind head on main.

Files with missing lines               Patch %    Lines
sae_lens/training/training_sae.py      90.00%     1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #373      +/-   ##
==========================================
+ Coverage   72.46%   72.74%   +0.28%     
==========================================
  Files          22       22              
  Lines        3258     3266       +8     
  Branches      431      431              
==========================================
+ Hits         2361     2376      +15     
+ Misses        767      762       -5     
+ Partials      130      128       -2     


Comment on lines 271 to 291
self.b_enc = nn.Parameter(
    torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
)

self.W_dec = nn.Parameter(
    torch.nn.init.kaiming_uniform_(
        torch.empty(
            self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
        )
    )
)
self.W_enc = nn.Parameter(
    torch.nn.init.kaiming_uniform_(
        torch.empty(
            self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
        )
    )
)
self.b_dec = nn.Parameter(
    torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
)
Contributor:

This block is also in SAE.initialize_weights_basic() and SAE.initialize_weights_jumprelu() (where we have a comment noting the parameters are the same as in the former). Extracting it to a function (e.g. a method on SAE), or refactoring so we can just call SAE.initialize_weights_basic() here and in SAE.initialize_weights_jumprelu(), would remove a lot of duplicated code and make it self-documenting that the three initialize_ methods are almost the same (see the sketch below). I'm not sure if we should do it in this PR or a separate one though.
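
Roughly what that extraction might look like (method names are hypothetical, not the existing sae_lens API):

def _initialize_shared_weights(self) -> None:
    # the W_enc / W_dec / b_enc / b_dec block that is currently duplicated
    # across the initialize_weights_* methods
    self.b_enc = nn.Parameter(
        torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device)
    )
    self.W_dec = nn.Parameter(
        torch.nn.init.kaiming_uniform_(
            torch.empty(
                self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device
            )
        )
    )
    self.W_enc = nn.Parameter(
        torch.nn.init.kaiming_uniform_(
            torch.empty(
                self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device
            )
        )
    )
    self.b_dec = nn.Parameter(
        torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device)
    )

def initialize_weights_basic(self) -> None:
    self._initialize_shared_weights()

def initialize_weights_jumprelu(self) -> None:
    self._initialize_shared_weights()
    # ...then add the jumprelu-specific threshold parameter on top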

Collaborator (Author):

aah good call - I'll make that change

Comment on lines +129 to +149
def test_TrainingSAE_jumprelu_save_and_load(tmp_path: Path):
    cfg = build_sae_cfg(architecture="jumprelu")
    training_sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

    training_sae.save_model(str(tmp_path))

    loaded_training_sae = TrainingSAE.load_from_pretrained(str(tmp_path))
    loaded_sae = SAE.load_from_pretrained(str(tmp_path))

    assert training_sae.cfg.to_dict() == loaded_training_sae.cfg.to_dict()
    for param_name, param in training_sae.named_parameters():
        assert torch.allclose(param, loaded_training_sae.state_dict()[param_name])

    test_input = torch.randn(32, cfg.d_in)
    training_sae_out = training_sae.encode_with_hidden_pre_fn(test_input)[0]
    loaded_training_sae_out = loaded_training_sae.encode_with_hidden_pre_fn(test_input)[
        0
    ]
    loaded_sae_out = loaded_sae.encode(test_input)
    assert torch.allclose(training_sae_out, loaded_training_sae_out)
    assert torch.allclose(training_sae_out, loaded_sae_out)
Contributor:

This test looks like a test that I wrote, except I just realized I accidentally passed "gated" instead of "jumprelu".

def test_sae_save_and_load_from_pretrained_jumprelu(tmp_path: Path) -> None:
    cfg = build_sae_cfg(architecture="gated")
    model_path = str(tmp_path)
    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
    sae_state_dict = sae.state_dict()
    sae.save_model(model_path)
    assert os.path.exists(model_path)

    sae_loaded = SAE.load_from_pretrained(model_path, device="cpu")
    sae_loaded_state_dict = sae_loaded.state_dict()

    # check state_dict matches the original
    for key in sae.state_dict().keys():
        assert torch.allclose(
            sae_state_dict[key],
            sae_loaded_state_dict[key],
        )

    sae_in = torch.randn(10, cfg.d_in, device=cfg.device)
    sae_out_1 = sae(sae_in)
    sae_out_2 = sae_loaded(sae_in)
    assert torch.allclose(sae_out_1, sae_out_2)

So we should delete the test that I wrote too.

Comment on lines 70 to 71
cfg = build_sae_cfg(architecture="jumprelu")
sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
Contributor:

We have this block in three of the tests, so it might help to extract it to a fixture.
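
For concreteness, the suggested fixture could look something like this (the fixture name is hypothetical):

import pytest


@pytest.fixture
def jumprelu_training_sae() -> TrainingSAE:
    # shared setup used by the jumprelu tests
    cfg = build_sae_cfg(architecture="jumprelu")
    return TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())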

Collaborator (Author):

I'm not a huge fan of pulling things out into fixtures unless it saves a lot of lines of code. Making this into a fixture means you can't just look at the test and understand what it's doing; you also need to scroll to look at the fixture. Since this is just 2 lines of code, I'd rather leave it as is IMO.

Comment on lines 94 to 126
def test_TrainingSAE_jumprelu_sae_training_forward_pass():
    cfg = build_sae_cfg(architecture="jumprelu")
    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

    batch_size = 32
    d_in = sae.cfg.d_in

    x = torch.randn(batch_size, d_in)
    train_step_output = sae.training_forward_pass(
        sae_in=x,
        current_l1_coefficient=sae.cfg.l1_coefficient,
    )

    assert train_step_output.sae_out.shape == (batch_size, d_in)
    assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae)
    assert (
        pytest.approx(train_step_output.loss.detach(), rel=1e-3)
        == (
            train_step_output.losses["mse_loss"] + train_step_output.losses["l0_loss"]
        ).item()  # type: ignore
    )

    expected_mse_loss = (
        (torch.pow((train_step_output.sae_out - x.float()), 2))
        .sum(dim=-1)
        .mean()
        .detach()
        .float()
    )

    assert (
        pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss  # type: ignore
    )
Contributor:

What's the reasoning for moving this out of test_jumprelu_sae.py?

I created that file while pattern matching against test_gated_sae.py (and probably other files), but I'm wondering whether we should move the tests from test_gated_sae.py into here as well (though not in this PR) and, more generally, how we want to organize tests.

Collaborator (Author):

oh my bad, I didn't see test_gated_sae.py - I would argue that should be in here too. I usually aim for a 1-to-1 mapping of test file to library file, so if there's sae_lens/training/training_sae.py then I would expect a corresponding tests/training/test_training_sae.py. I'm not a fan of having tests that don't map onto source files, since it becomes hard to know where tests should go and where to look to see where something is tested. We should probably write up something about how we expect code / tests to be structured, since I don't think this has been discussed officially. Regardless, I think we should have some deterministic way to structure tests so it's not arbitrary where things are tested.

Collaborator (Author):

It looks like there are a lot of test files related to the SAE class that don't follow this paradigm, so I reverted the changes to test_jumprelu_sae.py pending a larger discussion.

@chanind merged commit 17506ac into main on Nov 12, 2024
7 checks passed
@chanind deleted the jumprelu-fixes branch on November 12, 2024 at 12:29