Skip to content

Commit

Permalink
fix: load the same config from_pretrained and get_sae_config (#361)
Browse files Browse the repository at this point in the history
* fix: load the same config from_pretrained and get_sae_config

* merge neuronpedia_id into get_sae_config

* fixing test
  • Loading branch information
chanind authored Nov 6, 2024
1 parent b8703fe commit 8e09458
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 14 deletions.
4 changes: 0 additions & 4 deletions sae_lens/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,6 @@ def from_pretrained(
)
sae_info = sae_directory.get(release, None)
config_overrides = sae_info.config_overrides if sae_info is not None else None
neuronpedia_id = (
sae_info.neuronpedia_id[sae_id] if sae_info is not None else None
)

conversion_loader_name = get_conversion_loader_name(sae_info)
conversion_loader = NAMED_PRETRAINED_SAE_LOADERS[conversion_loader_name]
Expand All @@ -637,7 +634,6 @@ def from_pretrained(

sae = cls(SAEConfig.from_dict(cfg_dict))
sae.load_state_dict(state_dict)
sae.cfg.neuronpedia_id = neuronpedia_id

# Check if normalization is 'expected_average_only_in'
if cfg_dict.get("normalize_activations") == "expected_average_only_in":
Expand Down
32 changes: 23 additions & 9 deletions sae_lens/toolkit/pretrained_sae_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,9 @@ def sae_lens_loader(
options = SAEConfigLoadOptions(
device=device,
force_download=force_download,
cfg_overrides=cfg_overrides,
)
cfg_dict = get_sae_config(release, sae_id=sae_id, options=options)
# Apply overrides if provided
if cfg_overrides is not None:
cfg_dict.update(cfg_overrides)
cfg_dict["device"] = device
cfg_dict = handle_config_defaulting(cfg_dict)

repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id=sae_id)

weights_filename = f"{folder_name}/sae_weights.safetensors"
Expand Down Expand Up @@ -116,6 +111,9 @@ def get_sae_config_from_hf(
with open(cfg_path, "r") as f:
cfg_dict = json.load(f)

if options.device is not None:
cfg_dict["device"] = options.device

return cfg_dict


Expand All @@ -128,7 +126,7 @@ def handle_config_defaulting(cfg_dict: dict[str, Any]) -> dict[str, Any]:
cfg_dict.setdefault("sae_lens_training_version", None)
cfg_dict.setdefault("activation_fn_str", cfg_dict.get("activation_fn", "relu"))
cfg_dict.setdefault("architecture", "standard")
cfg_dict.setdefault("neuronpedia", None)
cfg_dict.setdefault("neuronpedia_id", None)

if "normalize_activations" in cfg_dict and isinstance(
cfg_dict["normalize_activations"], bool
Expand Down Expand Up @@ -310,7 +308,7 @@ def get_gemma_2_config(
else:
raise ValueError("Hook name not found in folder_name.")

return {
cfg = {
"architecture": "jumprelu",
"d_in": d_in,
"d_sae": d_sae,
Expand All @@ -329,6 +327,10 @@ def get_gemma_2_config(
"apply_b_dec_to_input": False,
"normalize_activations": None,
}
if options.device is not None:
cfg["device"] = options.device

return cfg


def gemma_2_sae_loader(
Expand Down Expand Up @@ -470,9 +472,21 @@ def get_sae_config(
saes_directory = get_pretrained_saes_directory()
sae_info = saes_directory.get(release, None)
repo_id, folder_name = get_repo_id_and_folder_name(release, sae_id=sae_id)
cfg_overrides = options.cfg_overrides or {}
if sae_info is not None:
# avoid modifying the original dict
sae_info_overrides: dict[str, Any] = {**(sae_info.config_overrides or {})}
if sae_info.neuronpedia_id is not None:
sae_info_overrides["neuronpedia_id"] = sae_info.neuronpedia_id.get(sae_id)
cfg_overrides = {**sae_info_overrides, **cfg_overrides}

conversion_loader_name = get_conversion_loader_name(sae_info)
config_getter = NAMED_PRETRAINED_SAE_CONFIG_GETTERS[conversion_loader_name]
return config_getter(repo_id, folder_name=folder_name, options=options)
cfg = {
**config_getter(repo_id, folder_name=folder_name, options=options),
**cfg_overrides,
}
return handle_config_defaulting(cfg)


def dictionary_learning_sae_loader_1(
Expand Down
33 changes: 32 additions & 1 deletion tests/unit/toolkit/test_pretrained_sae_loaders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from sae_lens.sae import SAE
from sae_lens.toolkit.pretrained_sae_loaders import SAEConfigLoadOptions, get_sae_config


Expand All @@ -9,11 +10,15 @@ def test_get_sae_config_sae_lens():
)

expected_cfg_dict = {
"activation_fn_str": "relu",
"apply_b_dec_to_input": True,
"architecture": "standard",
"model_name": "gpt2-small",
"hook_point": "blocks.0.hook_resid_pre",
"hook_point_layer": 0,
"hook_point_head_index": None,
"dataset_path": "Skylion007/openwebtext",
"dataset_trust_remote_code": True,
"is_dataset_tokenized": False,
"context_size": 128,
"use_cached_activations": False,
Expand All @@ -32,9 +37,13 @@ def test_get_sae_config_sae_lens():
"lr": 0.0004,
"lr_scheduler_name": None,
"lr_warm_up_steps": 5000,
"model_from_pretrained_kwargs": {
"center_writing_weights": True,
},
"train_batch_size": 4096,
"use_ghost_grads": False,
"feature_sampling_window": 1000,
"finetuning_scaling_factor": False,
"feature_sampling_method": None,
"resample_batches": 1028,
"feature_reinit_scale": 0.2,
Expand All @@ -50,6 +59,10 @@ def test_get_sae_config_sae_lens():
"d_sae": 24576,
"tokens_per_buffer": 67108864,
"run_name": "24576-L1-8e-05-LR-0.0004-Tokens-3.000e+08",
"neuronpedia_id": "gpt2-small/0-res-jb",
"normalize_activations": "none",
"prepend_bos": True,
"sae_lens_training_version": None,
}

assert cfg_dict == expected_cfg_dict
Expand Down Expand Up @@ -81,6 +94,7 @@ def test_get_sae_config_connor_rob_hook_z():
"context_size": 128,
"normalize_activations": "none",
"dataset_trust_remote_code": True,
"neuronpedia_id": "gpt2-small/0-att-kk",
}

assert cfg_dict == expected_cfg_dict
Expand Down Expand Up @@ -111,6 +125,8 @@ def test_get_sae_config_gemma_2():
"dataset_trust_remote_code": True,
"apply_b_dec_to_input": False,
"normalize_activations": None,
"device": "cpu",
"neuronpedia_id": None,
}

assert cfg_dict == expected_cfg_dict
Expand Down Expand Up @@ -143,7 +159,22 @@ def test_get_sae_config_dictionary_learning_1():
"dataset_trust_remote_code": True,
"context_size": 128,
"normalize_activations": "none",
"neuronpedia_id": None,
"neuronpedia_id": "gemma-2-2b/3-sae_bench-standard-res-4k__trainer_1_step_29292",
}

assert cfg_dict == expected_cfg_dict


def test_get_sae_config_matches_from_pretrained():
    """The cfg dict returned by ``SAE.from_pretrained`` must be identical to
    the one produced directly by ``get_sae_config`` for the same release and
    SAE id — this is the invariant the surrounding commit establishes."""
    release = "gpt2-small-res-jb"
    sae_id = "blocks.0.hook_resid_pre"

    # from_pretrained returns a tuple; index 1 is the config dict.
    loaded_cfg = SAE.from_pretrained(release, sae_id=sae_id, device="cpu")[1]
    direct_cfg = get_sae_config(
        release,
        sae_id=sae_id,
        options=SAEConfigLoadOptions(device="cpu"),
    )

    assert direct_cfg == loaded_cfg

0 comments on commit 8e09458

Please sign in to comment.