breaks up SAE.forward() into encode() and decode() #107

Merged
merged 2 commits into from
Apr 29, 2024

Conversation

evanhanders
Contributor

Description

This PR breaks up the forward() pass of the sparse autoencoder into separate encode() and decode() functions, so that users can, e.g., run only the encoder on inputs or only the decoder on SAE activations.
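
As a rough illustration of the split described above, here is a minimal sketch. `ToySAE`, its weights, and the use of plain Python lists in place of torch tensors are all assumptions made for the example; biases and hooks are omitted, and this is not the SAELens implementation.

```python
from __future__ import annotations


class ToySAE:
    """Hypothetical toy autoencoder; plain lists stand in for torch tensors."""

    def __init__(self, w_enc: list[list[float]], w_dec: list[list[float]]) -> None:
        self.w_enc = w_enc  # shape (d_in, d_sae)
        self.w_dec = w_dec  # shape (d_sae, d_in)

    @staticmethod
    def _matvec(x: list[float], w: list[list[float]]) -> list[float]:
        # Computes x @ w for a single input vector x.
        return [sum(xi * row[j] for xi, row in zip(x, w)) for j in range(len(w[0]))]

    def encode(self, x: list[float]) -> list[float]:
        # Linear map followed by ReLU gives the feature activations.
        return [max(0.0, h) for h in self._matvec(x, self.w_enc)]

    def decode(self, feature_acts: list[float]) -> list[float]:
        # Map feature activations back to the input space.
        return self._matvec(feature_acts, self.w_dec)

    def forward(self, x: list[float]) -> list[float]:
        # The full pass is the composition of the two halves.
        return self.decode(self.encode(x))


sae = ToySAE(w_enc=[[1.0, -1.0], [0.0, 1.0]], w_dec=[[1.0, 0.0], [0.0, 1.0]])
features = sae.encode([2.0, 1.0])       # encoder only
reconstruction = sae.decode(features)   # decoder only, on SAE activations
```

With this shape, forward() stays available for the full pass, while callers who only need features or only need reconstructions can call one half directly.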

I'm not 100% sure I have the 'right' unit tests for these functions, so happy to iterate on those.

Checklist:

  • [ x ] I have commented my code, particularly in hard-to-understand areas
  • [ x ] I have made corresponding changes to the documentation
  • [ x ] My changes generate no new warnings
  • [ x ] I have added tests that prove my fix is effective or that my feature works
  • [ x ] New and existing unit tests pass locally with my changes
  • [ x ] I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and unit tests (acceptance tests not currently in use)

  • [ x ] I have run make check-ci to check format and linting. (you can run make format to format code if needed.)


codecov bot commented Apr 26, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 58.44%. Comparing base (0184671) to head (ddf8244).
Report is 6 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #107      +/-   ##
==========================================
+ Coverage   58.15%   58.44%   +0.29%     
==========================================
  Files          17       17              
  Lines        1429     1439      +10     
  Branches      237      237              
==========================================
+ Hits          831      841      +10     
  Misses        547      547              
  Partials       51       51              


@@ -118,7 +119,13 @@ def __init__(

        self.setup()  # Required for `HookedRootModule`s

-    def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None):
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"], return_hidden_pre: bool = False
Collaborator

This looks like a great change, but could this be broken into 2 methods for the 2 return types here? Having the return type depend on a boolean flag breaks type checking: every call to this method then has to be followed by an assert or a cast to tell the type checker which type the result actually is.

I'd propose something like the following 2 methods:

def encode(self, x: Float[torch.Tensor, "... d_in"]) -> Float[torch.Tensor, "... d_sae"]:
    return self._encode_with_hidden_pre(x)[0]

def _encode_with_hidden_pre(self, x: Float[torch.Tensor, "... d_in"]) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
    ...

In the above, _encode_with_hidden_pre() is prefixed with a _ to indicate it's more of a private method for the SAE forward, but this is optional if we think users might want to call this method directly themselves. Probably most users would just want to call sae.encode() and get the features back.

Curious for what @jbloomAus thinks on this as well

Owner

Agree. Seems weird to also return hidden pre. @evanhanders what's the thinking here?

Contributor Author

@evanhanders evanhanders Apr 29, 2024

I was struggling with typing in the tests and this seems like a great solution! I'll implement later today. (I think another alternative is to add a hook for hidden_pre?)

@jbloomAus -- we need hidden_pre later in the forward pass for the ghost grads calculation, so need to get it back somewhere.
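
For comparison, the hook alternative mentioned above could look roughly like the sketch below. The `ToyEncoder` class and its `hidden_pre_hook` attribute are hypothetical and are not the hook API used by `HookedRootModule`; the arithmetic is a placeholder and plain lists stand in for tensors.

```python
from __future__ import annotations

from typing import Callable, Optional

Vec = list[float]


class ToyEncoder:
    """Hypothetical encoder that exposes hidden_pre through a callback."""

    def __init__(self) -> None:
        # Optional callback invoked with the pre-activation values.
        self.hidden_pre_hook: Optional[Callable[[Vec], None]] = None

    def encode(self, x: Vec) -> Vec:
        hidden_pre = [2.0 * v - 1.0 for v in x]  # placeholder pre-activation
        if self.hidden_pre_hook is not None:
            self.hidden_pre_hook(hidden_pre)
        # encode() itself still returns only the feature activations.
        return [max(0.0, h) for h in hidden_pre]


captured: list[Vec] = []
encoder = ToyEncoder()
encoder.hidden_pre_hook = captured.append  # record hidden_pre as a side effect
acts = encoder.encode([1.0, 0.0])
```

The trade-off is that a hook keeps encode()'s return type simple but delivers hidden_pre as a side effect, whereas a private method that returns both values keeps the data flow explicit for the forward pass.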

Contributor Author

Just implemented suggestion by @chanind & pushed!

Collaborator

@chanind chanind left a comment

Nice work with this! Love the tests too

@chanind chanind merged commit 7b4311b into jbloomAus:main Apr 29, 2024
7 checks passed
tom-pollak pushed a commit to tom-pollak/SAELens that referenced this pull request Oct 22, 2024
* breaks up SAE.forward() into encode() and decode()

* cleans up return typing of encode by splitting into a hidden and public function