Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make load_state_dict use strict=False #95

Merged
merged 3 commits into from
Apr 21, 2024
Merged

Conversation

neelnanda-io
Copy link
Collaborator

@neelnanda-io neelnanda-io commented Apr 20, 2024

Your library is too strict, which creates errors if loading state dicts that don't eg contain a scaling factor. I added a quick fix that won't raise errors when loading a state dict that only contains some parameters. This runs the risk that someone may accidentally load a state dict with too few parameters, which is bad? But seems pretty unlikely to me

Copy link

codecov bot commented Apr 20, 2024

Codecov Report

Attention: Patch coverage is 25.00000% with 3 lines in your changes are missing coverage. Please review.

Project coverage is 57.62%. Comparing base (6a056b7) to head (c22fbbd).

Files Patch % Lines
sae_lens/training/sparse_autoencoder.py 25.00% 2 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #95      +/-   ##
==========================================
- Coverage   57.74%   57.62%   -0.13%     
==========================================
  Files          16       16              
  Lines        1394     1397       +3     
  Branches      227      228       +1     
==========================================
  Hits          805      805              
- Misses        543      545       +2     
- Partials       46       47       +1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@hijohnnylin
Copy link
Collaborator

Perhaps checking that the state dict’s keys is a superset of a set of required keys?

I recently made a very similar change to sae_vis here: https://github.com/callummcdougall/sae_vis/pull/39/files

@jbloomAus
Copy link
Owner

Hmm. So this is resulting from the decoder fine-tuning changes which added a new state_dict parameter to SAEs. This parameter doesn't effect the output if unused.

I think we should:

  • explicitly check whether the extra weight is the scale parameter.
  • deliberately allow strict = false if that's the only difference.

@jbloomAus jbloomAus merged commit 4a9e274 into main Apr 21, 2024
5 of 7 checks passed
@jbloomAus jbloomAus deleted the load-state-dict-not-strict branch May 20, 2024 13:36
tom-pollak pushed a commit to tom-pollak/SAELens that referenced this pull request Oct 22, 2024
…rict

Make load_state_dict use strict=False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants