Support seqpos slicing #294
base: main
Conversation
Hey @callummcdougall I've pushed:
Got it, sorry for causing undue work - yes, in the future I'll make sure to add tests! I wasn't sure about putting it in the SAE config because it's about the SAE's training data (or what inputs make sense for it), not about e.g. the SAE's actual architecture. I was basing this on the fact that
@callummcdougall I think the idea is that if you couldn't evaluate the SAE without knowing about this property, then it needs to be in the SAE config. Speaking of which, I don't see any changes to evals.py, but presumably we should ensure that evals are only run on seqpos positions? Are you able to do this?
Code-wise this looks good to me, and looks like a reasonable addition to the library! Will defer to @jbloomAus on whether this is OK to merge. I guess there's a question of whether the expectation is that this would require different evals, or if this is something that only affects training.
```python
activations = activation_store.get_activations(batch)

assert batch.shape == (1, 10)  # Full context size
assert activations.shape == (1, 6, 1, cfg.d_in)  # Only 6 positions (2 to 7)
```
nice! Really great test 🥇
Think it does seem valuable to also have the logged metrics during training only apply to the right sequence positions - is that what you meant @jbloomAus, or did you mean evals that are applied in a non-training context? Either way I can likely get to that later this week.
This allows seqpos slicing during training. Basically we add a `seqpos_slice` arg to the `LanguageModelSAERunnerConfig` (in the form of a tuple, which gets converted to a slice via `slice(*seqpos_slice)` - this is because slice objects aren't serializable when we're saving the config).

Apart from this config, the only other file getting changed is `activations_store.py`. It now has a `seqpos_slice` attribute, and it uses this to slice the activations which are fetched from `get_activations` (and which are used in `get_buffer`).

Note that the default behaviour is `seqpos_slice = (None,)`, which slices over all sequence positions. Also note that `seqpos_slice` can be used in conjunction with `context_size` (i.e. one doesn't make the other redundant).
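
The tuple-to-slice mechanics described above can be sketched as follows. This is a minimal illustration of the `slice(*seqpos_slice)` conversion, not the library's actual code; the variable names and the toy data are assumptions.

```python
# Hypothetical config value: stored as a tuple because slice objects
# aren't serializable when the config is saved.
seqpos_slice = (2, 8)

# Converted to a slice object at use time.
sl = slice(*seqpos_slice)  # equivalent to slice(2, 8)

# A toy "sequence" of 10 positions standing in for the sequence dimension
# of an activations tensor with context_size = 10.
positions = list(range(10))
print(positions[sl])  # [2, 3, 4, 5, 6, 7] - 6 positions, as in the test above

# The default (None,) expands to slice(None), which keeps all positions.
assert positions[slice(*(None,))] == positions
```

This also shows why `seqpos_slice` and `context_size` compose rather than overlap: `context_size` fixes how many positions exist, while the slice selects which of them feed the SAE.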