[WIP] feat: add mlp transcoders #183
base: main
Conversation
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #183      +/-   ##
==========================================
- Coverage   59.59%   58.48%   -1.11%
==========================================
  Files          25       25
  Lines        2636     2780     +144
  Branches      445      466      +21
==========================================
+ Hits         1571     1626      +55
- Misses        987     1075      +88
- Partials       78       79       +1

☔ View full report in Codecov by Sentry.
Initial Wandb run here: https://wandb.ai/dtch1997/benchmark/workspace
Benchmarked using the following command:
Green = MLP-out SAE, red = MLP transcoder
Current status of the PR:
The next priority might be to support pre-trained MLP transcoders.
Some notes on architecture.
The right balance depends on a few things:
Feedback on the last two points above would be very useful. Edit:
# NOTE: Transcoders have an additional b_dec_out parameter.
# Reference: https://github.com/jacobdunefsky/transcoder_circuits/blob/7b44d870a5a301ef29eddfd77cb1f4dca854760a/sae_training/sparse_autoencoder.py#L93C1-L97C14
self.b_dec_out = nn.Parameter(
    torch.zeros(self.cfg.d_out, dtype=self.dtype, device=self.device)
)
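For context, here is a minimal, self-contained sketch of where b_dec and b_dec_out enter a transcoder's computation. This is an illustration, not the PR's actual implementation: the parameter names mirror the reference above, but the shapes, initialization, and activation are assumptions.

```python
import torch
import torch.nn as nn


class TinyTranscoder(nn.Module):
    """Illustrative transcoder: encode from d_in, decode into a separate d_out space."""

    def __init__(self, d_in: int, d_sae: int, d_out: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_in, d_sae) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_in))        # pre-encoder bias (input space)
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_out) * 0.01)
        self.b_dec_out = nn.Parameter(torch.zeros(d_out))   # post-decoder bias (output space)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Subtract the input-space bias before encoding; add the output-space bias after decoding.
        feature_acts = torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)
        return feature_acts @ self.W_dec + self.b_dec_out
```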
I don't understand why the extra bias is needed. I'm probably just confused and missing something, but it would make the implementation simpler if you don't need it.
I understand that in normal SAEs people sometimes subtract b_dec from the input. This isn't strictly necessary, but it has a nice interpretation: you're choosing a new "zero point" that acts as the origin in the feature basis.
For transcoders this makes less sense. Since you aren't reconstructing the same activations, you probably don't want to tie the pre-encoder bias to the post-decoder bias.
Thus, in the current implementation we do:
feature_acts = ReLU((x - b_dec) @ W_enc + b_enc)
and
sae_out = feature_acts @ W_dec + b_dec_out
This isn't any more expressive: you can always fold the first two biases (b_dec and b_enc) into a single encoder bias, e.g. b_enc' = b_enc - b_dec @ W_enc.
Overall I'd recommend dropping the complexity here, which maybe means you can just eliminate the Transcoder class entirely.
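To make the folding argument concrete, here is a small standalone check (assumed shapes, not SAELens code) that subtracting a pre-encoder bias is equivalent to adjusting the encoder bias:

```python
import torch

d_in, d_sae = 8, 16
x = torch.randn(4, d_in)
W_enc = torch.randn(d_in, d_sae)
b_enc = torch.randn(d_sae)
b_dec = torch.randn(d_in)

# Form used with a pre-encoder bias: subtract b_dec from the input.
pre_acts_a = (x - b_dec) @ W_enc + b_enc

# Folded form: absorb b_dec into the encoder bias instead.
b_enc_folded = b_enc - b_dec @ W_enc
pre_acts_b = x @ W_enc + b_enc_folded

assert torch.allclose(pre_acts_a, pre_acts_b, atol=1e-5)
```

The same folding is what would let previously-trained checkpoints that include a b_dec be loaded without the extra parameter, as discussed in the reply below.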
This makes sense! I'll try dropping the extra b_dec term when training. I was initially concerned about supporting the previously-trained checkpoints, but as you say, weight folding should solve that.
Description
Add support for training, loading, and running inference on MLP transcoders. New classes (sketched below):
- Transcoder, a subclass of SAE
- TrainingTranscoder, a subclass of TrainingSAE
- TranscoderTrainer, a subclass of SAETrainer
- LanguageModelTranscoderTrainingRunner, a subclass of LanguageModelSAETrainingRunner
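A rough, purely illustrative skeleton of how these classes might relate; the base classes are stubbed here so the snippet stands alone, and the docstrings are assumptions rather than the PR's actual behavior.

```python
# Stubs standing in for the existing classes named in the list above.
class SAE: ...
class TrainingSAE(SAE): ...
class SAETrainer: ...
class LanguageModelSAETrainingRunner: ...


class Transcoder(SAE):
    """Inference-time transcoder: decodes into an output space (d_out) distinct from d_in."""


class TrainingTranscoder(TrainingSAE):
    """Training-time variant of the transcoder."""


class TranscoderTrainer(SAETrainer):
    """Runs the training loop against a TrainingTranscoder."""


class LanguageModelTranscoderTrainingRunner(LanguageModelSAETrainingRunner):
    """Wires language-model activations into transcoder training."""
```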
Fixes #182
Type of change
Please delete options that are not relevant.

Checklist:
- You have tested formatting, typing and unit tests (acceptance tests not currently in use). Run make check-ci to check format and linting (you can run make format to format code if needed).

Performance Check
If you have implemented a training change, please indicate precisely how performance changes with respect to the following metrics:
Please link to wandb dashboards with a control and test group.