breaks up SAE.forward() into encode() and decode() #107
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #107      +/-   ##
==========================================
+ Coverage   58.15%   58.44%   +0.29%
==========================================
  Files          17       17
  Lines        1429     1439      +10
  Branches      237      237
==========================================
+ Hits          831      841      +10
  Misses        547      547
  Partials       51       51
@@ -118,7 +119,13 @@ def __init__(

        self.setup()  # Required for `HookedRootModule`s

    def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None):
    def encode(
        self, x: Float[torch.Tensor, "... d_in"], return_hidden_pre: bool = False
This looks like a great change, but could this be broken into 2 methods for the 2 return types here? Having different return types based on a boolean flag breaks type checking, and means that any time this method is called it will need to be followed by an `assert` or a `cast` to tell the type checker the correct type of the result.
I'd propose something like the following 2 methods:

    def encode(self, x: Float[torch.Tensor, "... d_in"]) -> Float[torch.Tensor, "... d_sae"]:
        return self._encode_with_hidden_pre(x)[0]

    def _encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        ...

In the above, `_encode_with_hidden_pre()` is prefixed with a `_` to indicate it's more of a private method for the SAE forward, but this is optional if we think users might want to call this method directly themselves. Probably most users would just want to call `sae.encode()` and get the features back.
Curious for what @jbloomAus thinks on this as well
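To make the typing concern concrete, here is a minimal sketch that contrasts the flag-based signature with the proposed split. It is a toy, not the project's code: `Vector` (a plain list of floats) stands in for a tensor, and `_hidden_pre`, `_relu`, and `encode_flagged` are hypothetical names invented for illustration. Only `encode` and `_encode_with_hidden_pre` come from the suggestion above.

```python
from typing import Union, cast

Vector = list[float]  # toy stand-in for a torch.Tensor (assumption)


def _hidden_pre(x: Vector) -> Vector:
    # Hypothetical pre-activation: a fixed affine map, purely illustrative.
    return [2.0 * v - 1.0 for v in x]


def _relu(h: Vector) -> Vector:
    return [max(0.0, v) for v in h]


def encode_flagged(
    x: Vector, return_hidden_pre: bool = False
) -> Union[Vector, tuple[Vector, Vector]]:
    # Flag-based version: the declared return type has to be a union.
    hidden_pre = _hidden_pre(x)
    feature_acts = _relu(hidden_pre)
    if return_hidden_pre:
        return feature_acts, hidden_pre
    return feature_acts


def _encode_with_hidden_pre(x: Vector) -> tuple[Vector, Vector]:
    # Split version: always returns (feature_acts, hidden_pre).
    hidden_pre = _hidden_pre(x)
    return _relu(hidden_pre), hidden_pre


def encode(x: Vector) -> Vector:
    # Public method: one return type, no narrowing needed by callers.
    return _encode_with_hidden_pre(x)[0]


x = [0.0, 1.0, 2.0]

# With the flag, every caller has to narrow the union by hand.
flagged = cast(Vector, encode_flagged(x))

# With the split, each call site gets a precise type for free.
features = encode(x)
feature_acts, hidden_pre = _encode_with_hidden_pre(x)
```

The `cast` on `encode_flagged` is the kind of per-call override the review comment wants to avoid. The two split functions need none.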
Agree. Seems weird to also return hidden pre. @evanhanders what's the thinking here?
I was struggling with typing in the tests and this seems like a great solution! I'll implement later today. (I think another alternative is to add a hook for hidden_pre?)
@jbloomAus -- we need hidden_pre later in the forward pass for the ghost grads calculation, so need to get it back somewhere.
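As a rough sketch of that point, a forward pass can call the private helper to get both values and hand `hidden_pre` to a later step. Everything here is an assumption made for illustration: plain lists stand in for tensors, the affine map and `decode` are invented, and selecting the pre-activations of dead features is only a placeholder for whatever the ghost grads calculation actually consumes.

```python
from typing import Optional

Vector = list[float]  # toy stand-in for a torch.Tensor (assumption)


def _encode_with_hidden_pre(x: Vector) -> tuple[Vector, Vector]:
    # Hypothetical encoder: affine map followed by ReLU.
    hidden_pre = [2.0 * v - 1.0 for v in x]
    feature_acts = [max(0.0, h) for h in hidden_pre]
    return feature_acts, hidden_pre


def decode(feature_acts: Vector) -> Vector:
    # Hypothetical decoder: undoes the affine map for active features.
    return [(f + 1.0) / 2.0 for f in feature_acts]


def forward(
    x: Vector, dead_neuron_mask: Optional[list[bool]] = None
) -> tuple[Vector, Vector]:
    feature_acts, hidden_pre = _encode_with_hidden_pre(x)
    sae_out = decode(feature_acts)

    # Placeholder for the later step that needs hidden_pre: keep only the
    # pre-activations of features flagged as dead.
    dead_hidden_pre: Vector = []
    if dead_neuron_mask is not None:
        dead_hidden_pre = [
            h for h, dead in zip(hidden_pre, dead_neuron_mask) if dead
        ]
    return sae_out, dead_hidden_pre


sae_out, dead_hidden_pre = forward(
    [0.0, 1.0, 2.0], dead_neuron_mask=[True, False, False]
)
```

Because `forward` uses the helper directly, the public `encode` can stay a single-return-type method while the pre-activations remain available inside the forward pass.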
Just implemented suggestion by @chanind & pushed!
Nice work with this! Love the tests too
* breaks up SAE.forward() into encode() and decode()
* cleans up return typing of encode by splitting into a hidden and public function
Description
This PR breaks up the forward() pass of the sparse autoencoder into separate encode() and decode() functions, so that users can, e.g., run just the encoder on inputs or just the decoder on SAE activations.
I'm not 100% sure I have the 'right' unit tests for these functions, so happy to iterate on those.
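The kind of usage this enables might look like the following sketch: encode an input, edit the feature activations, then decode only the edited features. `ToySAE` is a hypothetical stand-in with elementwise weights on plain lists, not the library's SAE class. Only the method names `encode`, `decode`, and `forward` are taken from the PR.

```python
from dataclasses import dataclass

Vector = list[float]  # toy stand-in for a torch.Tensor (assumption)


@dataclass
class ToySAE:
    """Hypothetical elementwise autoencoder used only for illustration."""

    w_enc: Vector
    b_enc: Vector
    w_dec: Vector
    b_dec: Vector

    def encode(self, x: Vector) -> Vector:
        # ReLU(w_enc * x + b_enc), applied per element.
        return [
            max(0.0, w * v + b)
            for w, v, b in zip(self.w_enc, x, self.b_enc)
        ]

    def decode(self, feature_acts: Vector) -> Vector:
        # w_dec * feature_acts + b_dec, applied per element.
        return [
            w * f + b
            for w, f, b in zip(self.w_dec, feature_acts, self.b_dec)
        ]

    def forward(self, x: Vector) -> Vector:
        return self.decode(self.encode(x))


sae = ToySAE(
    w_enc=[1.0, 2.0], b_enc=[0.0, -1.0],
    w_dec=[1.0, 0.5], b_dec=[0.0, 0.5],
)

# Run just the encoder.
features = sae.encode([3.0, 0.25])

# Edit the activations (zero out the first feature), then run just the decoder.
edited = [0.0, features[1]]
reconstruction = sae.decode(edited)

# The full pass is still available and is simply the composition of the two.
full = sae.forward([3.0, 0.25])
```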
Checklist:
You have tested formatting, typing and unit tests (acceptance tests not currently in use)
Run `make check-ci` to check format and linting. (You can run `make format` to format code if needed.)