feat: adding load_model helper for huggingface causal LM models #226
Conversation
Force-pushed from 33de7f1 to dc627db
Codecov Report

Attention: Patch coverage is

Additional details and impacted files:

@@            Coverage Diff             @@
##             main     #226      +/-   ##
==========================================
+ Coverage   66.80%   67.08%   +0.27%
==========================================
  Files          25       25
  Lines        3389     3463      +74
  Branches      434      451      +17
==========================================
+ Hits         2264     2323      +59
- Misses       1005     1012       +7
- Partials      120      128       +8

☔ View full report in Codecov by Sentry.
tagging @anthonyduong9 as a reviewer informally, since GitHub won't let me use the "reviewers" list to do this.
@@ -40,5 +47,130 @@ def load_model(
            model_name, device=cast(Any, device), **model_from_pretrained_kwargs
        ),
    )
    elif model_class_name == "AutoModelForCausalLM":
I see many `elif`s and `else`s after `return`s. I find it easier to read when the `elif`s are `if`s and there's no `else`, but I've known others who are against this, so am curious to hear your thoughts.

There's a Pylint message for this that seems to be enabled in .pylintrc, but it seems we don't use .pylintrc. Perhaps I should open an issue on Pylint for discussion?
I don't have a strong opinion on this; if we can have a linting rule set up to enforce a style here, I'm happy with that. The .pylintrc is likely legacy, so I'll delete it. I'd support moving to Ruff, since it supports everything from pylint, flake8, pyflakes, isort, black, etc., is a lot faster, and seems to be what the industry is moving towards. I'll open an issue for this.
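For illustration, a minimal sketch of the two styles under discussion (a hypothetical function, not code from this PR); the relevant checks are, I believe, Pylint's no-else-return (R1705) and Ruff's RET505:

# Style with elif/else after return:
def sign(n: int) -> str:
    if n < 0:
        return "negative"
    elif n == 0:
        return "zero"
    else:
        return "positive"

# Flattened style: every branch returns, so no elif/else is needed.
def sign_flat(n: int) -> str:
    if n < 0:
        return "negative"
    if n == 0:
        return "zero"
    return "positive"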
sae_lens/load_model.py (Outdated)
    **kwargs: Any,
) -> Output | Loss:
    # This is just what's needed for evals, not everything that HookedTransformer has
    assert return_type in (
I see a lot of `assert`s in the codebase outside of tests, but a lot of people say not to have `assert` in production code and to use exceptions instead. Should we have `assert` outside of tests? Perhaps I could open an issue for discussion.
I can see the argument for that. I would at least agree that a user of the library should never see an assert fail. We use a lot of asserts throughout the codebase to narrow types for pyright, which seems fine IMO, but we should change to throwing exceptions if it's something we expect a user to see. Here, this is something that should never be triggered; I just wanted to make sure that if we did try to pass another return type into this wrapper, we'd at least get an error. I'll change this to a NotImplementedError - that should accomplish the same goal.
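A sketch of that change (the exact set of return types the wrapper accepts is an assumption here, inferred from the snippet above):

# Before: a bare assert, which is stripped under `python -O` and gives a terse error.
assert return_type in ("logits", "loss")

# After: an explicit exception for anything the eval wrapper doesn't implement.
if return_type not in ("logits", "loss"):
    raise NotImplementedError(
        f"return_type '{return_type}' is not supported by this wrapper; "
        "only what's needed for evals is implemented"
    )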
tests/unit/test_evals.py (Outdated)
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.4.hook_resid_pre",
    device="cpu",
)[0]
hf_model = load_model(
    model_class_name="AutoModelForCausalLM",
    model_name="gpt2",
    device="cpu",
)
tlens_model = HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")

example_ds = Dataset.from_list(
    [
        {"text": "hello world1"},
        {"text": "hello world2"},
        {"text": "hello world3"},
    ]
    * 20
)
cfg = build_sae_cfg(hook_name="transformer.h.3")
sae.cfg.hook_name = "transformer.h.3"
hf_store = ActivationsStore.from_config(hf_model, cfg, override_dataset=example_ds)
This block is repeated in test_get_sparsity_and_variance_metrics_with_hf_model_gives_same_results_as_tlens_model(). Should we extract the repeated code to pytest fixtures?
model_name="gpt2", | ||
device="cpu", | ||
) | ||
assert model is not None |
Should we remove this? If `model` is `None`, the next `assert` would fail.
model = load_model(
    model_class_name="AutoModelForCausalLM",
    model_name="gpt2",
    device="cpu",
)
We have this in four tests in this file; perhaps we should extract it to a fixture, as in the sketch below.
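A minimal sketch of such a fixture (the fixture name and import path are assumptions, not code from this PR):

import pytest

from sae_lens.load_model import load_model  # path assumed from the file shown above


@pytest.fixture
def gpt2_hf_model():
    # Shared by the four tests that currently repeat this load_model call.
    return load_model(
        model_class_name="AutoModelForCausalLM",
        model_name="gpt2",
        device="cpu",
    )


def test_uses_shared_model(gpt2_hf_model):  # hypothetical consumer of the fixture
    assert gpt2_hf_model is not None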
Force-pushed from b613da6 to 0c3a5b8
Description

This PR adds support for loading huggingface `AutoModelForCausalLM` models by internally wrapping them with a `HookedRootModule` subclass. To load a model from Huggingface, you can specify `model_class_name = 'AutoModelForCausalLM'` in the SAE runner config. The `hook_name` will need to match the named_parameters of the huggingface model, so the usual `blocks.0.hook_resid_pre` won't work. Otherwise everything should work the same as when working with TransformerLens models.
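A hedged usage sketch of the new option (assuming the runner config class is LanguageModelSAERunnerConfig; fields other than model_class_name, model_name, and hook_name are illustrative, not prescriptions):

from sae_lens import LanguageModelSAERunnerConfig

# hook_name must match the huggingface model's module names
# (e.g. "transformer.h.3" for GPT-2), not the TransformerLens-style
# "blocks.3.hook_resid_pre".
cfg = LanguageModelSAERunnerConfig(
    model_class_name="AutoModelForCausalLM",  # the option this PR adds
    model_name="gpt2",
    hook_name="transformer.h.3",
    device="cpu",
)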
Checklist:

- You have tested formatting, typing and unit tests (acceptance tests not currently in use)
- You have run `make check-ci` to check format and linting (you can run `make format` to format code if needed)