Support seqpos slicing #294

Draft · wants to merge 5 commits into base: main
13 changes: 13 additions & 0 deletions sae_lens/config.py
@@ -60,6 +60,7 @@ class LanguageModelSAERunnerConfig:
store_batch_size_prompts (int): The batch size for storing activations. This controls how many prompts are in the batch of the language model when generating activations.
train_batch_size_tokens (int): The batch size for training. This controls the batch size of the SAE Training loop.
normalize_activations (str): Activation Normalization Strategy. Either none, expected_average_only_in (estimate the average activation norm and divide activations by it -> this can be folded post training and set to None), or constant_norm_rescale (at runtime set activation norm to sqrt(d_in) and then scale up the SAE output).
seqpos_slice (tuple): Determines slicing of (batch, seq, d_in) activations when constructing batches during training. For example, for Othello we sometimes use (5, -5).
device (str): The device to use. Usually cuda.
act_store_device (str): The device to use for the activation store. CPU is advised in order to save vram.
seed (int): The seed to use.
@@ -151,6 +152,7 @@ class LanguageModelSAERunnerConfig:
normalize_activations: str = (
"none" # none, expected_average_only_in (Anthropic April Update), constant_norm_rescale (Anthropic Feb Update)
)
seqpos_slice: tuple[int | None, ...] = (None,)

# Misc
device: str = "cpu"
@@ -378,6 +380,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]:
"normalize_activations": self.normalize_activations,
"activation_fn_kwargs": self.activation_fn_kwargs,
"model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
"seqpos_slice": self.seqpos_slice,
}

def get_training_sae_cfg_dict(self) -> dict[str, Any]:
@@ -419,6 +422,15 @@ def to_json(self, path: str) -> None:
def from_json(cls, path: str) -> "LanguageModelSAERunnerConfig":
with open(path + "cfg.json", "r") as f:
cfg = json.load(f)

# Ensure seqpos_slice is a tuple (JSON round-trips tuples as lists)
if "seqpos_slice" in cfg:
if isinstance(cfg["seqpos_slice"], list):
cfg["seqpos_slice"] = tuple(cfg["seqpos_slice"])
elif not isinstance(cfg["seqpos_slice"], tuple):
cfg["seqpos_slice"] = (cfg["seqpos_slice"],)

return cls(**cfg)


@@ -453,6 +465,7 @@ class CacheActivationsRunnerConfig:
store_batch_size_prompts: int = 32
train_batch_size_tokens: int = 4096
normalize_activations: str = "none" # should always be none for activation caching
seqpos_slice: tuple[int | None, ...] = (None,)

# Misc
device: str = "cpu"
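As a note on the semantics of the new option: seqpos_slice is unpacked into a standard Python slice over the sequence dimension (see the slice(*self.seqpos_slice) call in activations_store.py below), so the default (None,) keeps every position. A minimal standalone sketch, not part of this PR, using illustrative context lengths:

# Illustrative only: how seqpos_slice tuples map onto Python slices over token positions.
context = list(range(10))  # stand-in for a 10-token context

assert context[slice(*(None,))] == context            # default: keep every position
assert context[slice(*(2, 8))] == [2, 3, 4, 5, 6, 7]  # keep positions 2..7 only

othello_context = list(range(60))  # stand-in for an Othello-style context
assert othello_context[slice(*(5, -5))] == list(range(5, 55))  # drop the first and last 5 positions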
6 changes: 6 additions & 0 deletions sae_lens/sae.py
@@ -62,6 +62,7 @@ class SAEConfig:
activation_fn_kwargs: dict[str, Any] = field(default_factory=dict)
neuronpedia_id: Optional[str] = None
model_from_pretrained_kwargs: dict[str, Any] = field(default_factory=dict)
seqpos_slice: tuple[int | None, ...] = (None,)

@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
@@ -81,6 +82,10 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig":
for k, v in config_dict.items()
if k in cls.__dataclass_fields__ # pylint: disable=no-member
}

if "seqpos_slice" in config_dict:
config_dict["seqpos_slice"] = tuple(config_dict["seqpos_slice"])

return cls(**config_dict)

# def __post_init__(self):
@@ -108,6 +113,7 @@ def to_dict(self) -> dict[str, Any]:
"normalize_activations": self.normalize_activations,
"neuronpedia_id": self.neuronpedia_id,
"model_from_pretrained_kwargs": self.model_from_pretrained_kwargs,
"seqpos_slice": self.seqpos_slice,
}


36 changes: 21 additions & 15 deletions sae_lens/training/activations_store.py
@@ -87,6 +87,7 @@ def from_config(
model_kwargs=cfg.model_kwargs,
autocast_lm=cfg.autocast_lm,
dataset_trust_remote_code=cfg.dataset_trust_remote_code,
seqpos_slice=cfg.seqpos_slice,
)

@classmethod
@@ -122,6 +123,7 @@ def from_sae(
dataset_trust_remote_code=sae.cfg.dataset_trust_remote_code,
dtype=sae.cfg.dtype,
device=torch.device(device),
seqpos_slice=sae.cfg.seqpos_slice,
)

def __init__(
@@ -146,6 +148,7 @@ def __init__(
model_kwargs: dict[str, Any] | None = None,
autocast_lm: bool = False,
dataset_trust_remote_code: bool | None = None,
seqpos_slice: tuple[int | None, ...] = (None,),
):
self.model = model
if model_kwargs is None:
@@ -187,6 +190,7 @@ def __init__(
self.dtype = DTYPE_MAP[dtype]
self.cached_activations_path = cached_activations_path
self.autocast_lm = autocast_lm
self.seqpos_slice = seqpos_slice

self.n_dataset_processed = 0

@@ -428,37 +432,38 @@ def get_activations(self, batch_tokens: torch.Tensor):
autocast_if_enabled = contextlib.nullcontext()

with autocast_if_enabled:
layerwise_activations = self.model.run_with_cache(
layerwise_activations_cache = self.model.run_with_cache(
batch_tokens,
names_filter=[self.hook_name],
stop_at_layer=self.hook_layer + 1,
prepend_bos=False,
**self.model_kwargs,
)[1]

n_batches, n_context = batch_tokens.shape
layerwise_activations = layerwise_activations_cache[self.hook_name][
:, slice(*self.seqpos_slice)
]
n_batches, n_context = layerwise_activations.shape[:2]

stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in))

if self.hook_head_index is not None:
stacked_activations[:, :, 0] = layerwise_activations[self.hook_name][
stacked_activations[:, :, 0] = layerwise_activations[
:, :, self.hook_head_index
]
elif (
layerwise_activations[self.hook_name].ndim > 3
): # if we have a head dimension
elif layerwise_activations.ndim > 3: # if we have a head dimension
try:
stacked_activations[:, :, 0] = layerwise_activations[
self.hook_name
].view(n_batches, n_context, -1)
stacked_activations[:, :, 0] = layerwise_activations.view(
n_batches, n_context, -1
)
except RuntimeError as e:
print(f"Error during view operation: {e}")
print("Attempting to use reshape instead...")
stacked_activations[:, :, 0] = layerwise_activations[
self.hook_name
].reshape(n_batches, n_context, -1)
stacked_activations[:, :, 0] = layerwise_activations.reshape(
n_batches, n_context, -1
)
else:
stacked_activations[:, :, 0] = layerwise_activations[self.hook_name]
stacked_activations[:, :, 0] = layerwise_activations

return stacked_activations

@@ -474,14 +479,15 @@ def get_buffer(
If raise_on_epoch_end is True, when the dataset is exhausted it will automatically refill the dataset and then raise a StopIteration so that the caller has a chance to react.
"""
context_size = self.context_size
training_context_size = len(range(context_size)[slice(*self.seqpos_slice)])
batch_size = self.store_batch_size_prompts
d_in = self.d_in
total_size = batch_size * n_batches_in_buffer
num_layers = 1

if self.cached_activations_path is not None:
# Load the activations from disk
buffer_size = total_size * context_size
buffer_size = total_size * training_context_size
# Initialize an empty tensor with an additional dimension for layers
new_buffer = torch.zeros(
(buffer_size, num_layers, d_in),
@@ -535,7 +541,7 @@ def get_buffer(
refill_iterator = range(0, batch_size * n_batches_in_buffer, batch_size)
# Initialize empty tensor buffer of the maximum required size with an additional dimension for layers
new_buffer = torch.zeros(
(total_size, context_size, num_layers, d_in),
(total_size, training_context_size, num_layers, d_in),
dtype=self.dtype, # type: ignore
device=self.device,
)
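Because the buffer is now sized by the sliced context rather than the full context_size, the number of positions actually stored per prompt can be computed exactly as get_buffer does above. A small standalone sketch (the helper name is hypothetical, not part of the PR):

from __future__ import annotations

def training_context_size(context_size: int, seqpos_slice: tuple[int | None, ...]) -> int:
    # Mirrors the expression in get_buffer: slice a range of positions and count what remains.
    return len(range(context_size)[slice(*seqpos_slice)])

assert training_context_size(10, (None,)) == 10  # no slicing: full context
assert training_context_size(10, (2, 8)) == 6    # matches the unit test added below
assert training_context_size(60, (5, -5)) == 50  # Othello-style trimming of both ends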
13 changes: 13 additions & 0 deletions sae_lens/training/training_sae.py
@@ -75,6 +75,7 @@ def from_sae_runner_config(
context_size=cfg.context_size,
dataset_path=cfg.dataset_path,
prepend_bos=cfg.prepend_bos,
seqpos_slice=cfg.seqpos_slice,
# Training cfg
l1_coefficient=cfg.l1_coefficient,
lp_norm=cfg.lp_norm,
@@ -99,6 +100,18 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig":
valid_config_dict = {
key: val for key, val in config_dict.items() if key in valid_field_names
}

# Ensure seqpos_slice is a tuple
if "seqpos_slice" in valid_config_dict:
if isinstance(valid_config_dict["seqpos_slice"], list):
valid_config_dict["seqpos_slice"] = tuple(
valid_config_dict["seqpos_slice"]
)
elif not isinstance(valid_config_dict["seqpos_slice"], tuple):
valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],)

return TrainingSAEConfig(**valid_config_dict)

def to_dict(self) -> dict[str, Any]:
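The coercion above exists because JSON has no tuple type: a saved (2, 8) round-trips as [2, 8], and a bare value would otherwise break the later slice(*...) unpacking. A minimal standalone sketch of that normalization (the helper name is hypothetical, not part of the PR):

def normalize_seqpos_slice(value):
    # Lists (e.g. read back from JSON) become tuples; bare values become
    # one-element tuples; tuples pass through unchanged.
    if isinstance(value, list):
        return tuple(value)
    if not isinstance(value, tuple):
        return (value,)
    return value

assert normalize_seqpos_slice([2, 8]) == (2, 8)
assert normalize_seqpos_slice(5) == (5,)
assert normalize_seqpos_slice((None,)) == (None,)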
23 changes: 23 additions & 0 deletions tests/unit/training/test_activations_store.py
@@ -462,3 +462,26 @@ def test_validate_pretokenized_dataset_tokenizer_does_nothing_if_the_dataset_pat
model_tokenizer = ts_model.tokenizer
assert model_tokenizer is not None
validate_pretokenized_dataset_tokenizer(ds_path, model_tokenizer)


def test_activations_store_respects_seqpos_slice(ts_model: HookedTransformer):
cfg = build_sae_cfg(
context_size=10,
seqpos_slice=(2, 8), # Only consider positions 2 to 7 (inclusive)
)
dataset = Dataset.from_list(
[
{"text": "This is a test sentence for slicing."},
]
* 100
)

activation_store = ActivationsStore.from_config(
ts_model, cfg, override_dataset=dataset
)

batch = activation_store.get_batch_tokens(1)
activations = activation_store.get_activations(batch)

assert batch.shape == (1, 10) # Full context size
assert activations.shape == (1, 6, 1, cfg.d_in) # Only 6 positions (2 to 7)
Collaborator: nice! Really great test 🥇

1 change: 1 addition & 0 deletions tests/unit/training/test_config.py
@@ -67,6 +67,7 @@ def test_sae_training_runner_config_get_sae_base_parameters():
"model_from_pretrained_kwargs": {
"center_writing_weights": False,
},
"seqpos_slice": (None,),
}
assert expected_config == cfg.get_base_sae_cfg_dict()

17 changes: 17 additions & 0 deletions tests/unit/training/test_sae_basic.py
@@ -228,6 +228,23 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
assert torch.allclose(sae_out_1, sae_out_2)


def test_sae_seqpos(tmp_path: Path) -> None:
cfg = build_sae_cfg(
seqpos_slice=(1, 3),
device="cpu",
)
model_path = str(tmp_path)
sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())

assert sae.cfg.seqpos_slice == (1, 3)

sae.save_model(model_path)

sae_loaded = SAE.load_from_pretrained(model_path, device="cpu")

assert sae_loaded.cfg.seqpos_slice == (1, 3)


# TODO: Handle scaling factor in saeBase
# def test_sae_save_and_load_from_pretrained_lacks_scaling_factor(
# tmp_path: Path,