fix: Test JumpReLU/Gated SAE and fix sae forward with error term #328
    x = self.run_time_activation_norm_fn_in(x)

    # apply b_dec_to_input if using that method.
    sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input))
We were calling `hook_sae_input()` here after subtracting `b_dec`, but in other `encode()` variants we call `hook_sae_input()` after `reshape_fn_in()` instead. I assumed this difference was a bug and not intentional.
yep.
        @ self.W_dec
        + self.b_dec,
        d_head=self.d_head,
    )
I assumed the code in these codepaths is trying to exactly recreate all the `encode_x()` variants followed by `decode()`, but is duplicating code just to avoid triggering hooks.
yes. Disabling hooks is a good solution, better than duplicating code.
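A minimal sketch of the approach agreed on in this thread, using toy stand-ins rather than the real SAELens classes. `ToyHook`, `ToySAE` and `disable_hooks` are hypothetical names invented for illustration; the PR's actual `_disable_hooks()` may be implemented differently.

```python
from contextlib import contextmanager


class ToyHook:
    """Hypothetical stand-in for a hook point: records every value passed through it."""

    def __init__(self):
        self.calls = []

    def __call__(self, value):
        self.calls.append(value)
        return value


class ToySAE:
    """Hypothetical toy model whose encode/decode pair fires hooks."""

    def __init__(self):
        self.hook_sae_input = ToyHook()
        self.hook_sae_recons = ToyHook()

    def encode(self, x):
        return [2 * v for v in self.hook_sae_input(x)]

    def decode(self, acts):
        return self.hook_sae_recons([v / 2 for v in acts])


@contextmanager
def disable_hooks(sae):
    """Temporarily swap every ToyHook attribute for a pass-through function."""
    saved = {name: attr for name, attr in vars(sae).items() if isinstance(attr, ToyHook)}
    try:
        for name in saved:
            setattr(sae, name, lambda value: value)
        yield
    finally:
        # Restore the original hooks even if the body raised.
        for name, hook in saved.items():
            setattr(sae, name, hook)


sae = ToySAE()
with disable_hooks(sae):
    clean_out = sae.decode(sae.encode([1.0, 2.0]))  # reuses encode/decode, no hook fires
hooked_out = sae.decode(sae.encode([1.0, 2.0]))  # hooks fire once each
```

Because the clean pass reuses the same `encode()` and `decode()` methods, there is no second copy of the logic that could drift out of sync.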
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #328      +/-   ##
==========================================
+ Coverage   63.97%   64.92%   +0.94%
==========================================
  Files          25       25
  Lines        3223     3190      -33
  Branches      408      407       -1
==========================================
+ Hits         2062     2071       +9
+ Misses       1052     1013      -39
+ Partials      109      106       -3
ff865b3 to 5314f2a
    # "... d_in, d_in d_sae -> ... d_sae",
    hidden_pre = sae_in @ self.W_enc + self.b_enc
    feature_acts = self.hook_sae_acts_post(
calling the hook here is a bug and is responsible for #326
correct, thanks for catching this!
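To illustrate the kind of bug described in this thread, here is a toy sketch (not the SAELens code) of a forward pass that recomputes activations for an error term. `CountingHook`, `forward_buggy` and `forward_fixed` are hypothetical names, and the decoder is assumed to be the identity to keep the example short. In the buggy variant the recomputation passes through the hook again, so anything attached to that hook runs twice per forward pass.

```python
class CountingHook:
    """Hypothetical hook that counts how many times it is invoked."""

    def __init__(self):
        self.count = 0

    def __call__(self, value):
        self.count += 1
        return value


def relu(values):
    return [max(v, 0.0) for v in values]


def forward_buggy(x, hook):
    sae_out = hook(relu(x))    # normal path: hook fires once
    clean_out = hook(relu(x))  # error-term recomputation: hook fires again
    error = [xi - ci for xi, ci in zip(x, clean_out)]
    return [si + ei for si, ei in zip(sae_out, error)]


def forward_fixed(x, hook):
    sae_out = hook(relu(x))    # normal path: hook fires once
    clean_out = relu(x)        # recomputed without touching the hook
    error = [xi - ci for xi, ci in zip(x, clean_out)]
    return [si + ei for si, ei in zip(sae_out, error)]


x = [-1.0, 2.0]
buggy_hook, fixed_hook = CountingHook(), CountingHook()
buggy_result = forward_buggy(x, buggy_hook)
fixed_result = forward_fixed(x, fixed_hook)
```

Both variants return the same values here, which is why a test that only checks outputs would miss the problem; counting hook invocations is one way to catch it.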
Great PR. Sorry for all the BS. Really like the context for disabling hooks.
Description
This PR adds test coverage to the JumpReLU and GatedSAE encode methods.
In doing this, I realized there's a lot of duplication between all the encode variants and cleaned that up as well. I think there were some potential minor bugs in this duplication, for instance in `forward()`, when adding error term, we called `run_time_activation_norm_fn_out()` after `reshape_fn_out()`, but we do the opposite in `decode()`.

It looks like most of the duplication between the `error` section of `forward()` and the normal `encode() / decode()` was there just to avoid triggering hooks, so I added a contextmanager `_disable_hooks()` which is used to disable hooks in this branch of the code while reusing our existing `encode() / decode()` methods. This should mean we don't need to worry about these duplicated codepaths diverging and causing bugs.

If the refactor is out of scope, I can revert the changes to `sae.py` and just leave the test coverage.

Fixes #323
Fixes #326
Note: after investigating #326, it looks like this is caused by a bad copy/paste in the duplicated code for jumprelu forward, where we're accidentally including the hooks: https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L479.
Type of change
Checklist:
You have tested formatting, typing and unit tests (acceptance tests not currently in use)
You have run `make check-ci` to check format and linting. (you can run `make format` to format code if needed.)