Skip to content

Commit

Permalink
fix: ensure that load_from_pretrained config typing is correct
Browse files Browse the repository at this point in the history
  • Loading branch information
chanind committed May 7, 2024
1 parent 67660ec commit b0f4454
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ automated-interpretability = "^0.0.3"
python-dotenv = "^1.0.1"
pyyaml = "^6.0.1"
pytest-profiling = "^1.7.0"
typeguard = "^2.13.3"


[tool.poetry.group.dev.dependencies]
Expand Down
22 changes: 18 additions & 4 deletions sae_lens/toolkit/pretrained_sae_loaders.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import json
from typing import Protocol
from dataclasses import fields
from typing import Any, Protocol

import torch
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from typeguard import TypeCheckError, check_type

from sae_lens import __version__
from sae_lens.training.config import LanguageModelSAERunnerConfig
Expand Down Expand Up @@ -44,9 +46,21 @@ def load_pretrained_sae_lens_sae_components(
) -> tuple[LanguageModelSAERunnerConfig, dict[str, torch.Tensor]]:
with open(cfg_path, "r") as f:
config = json.load(f)
var_names = LanguageModelSAERunnerConfig.__init__.__code__.co_varnames
# filter config for varnames
config = {k: v for k, v in config.items() if k in var_names}

config_fields = fields(LanguageModelSAERunnerConfig)
fields_by_name = {f.name: f for f in config_fields}

def is_valid_config_field(var_name: str, val: Any) -> bool:
    """Return True iff *var_name* names a LanguageModelSAERunnerConfig field
    and *val* type-checks against that field's declared type.

    Used to drop unknown / wrongly-typed keys from a loaded JSON config
    before constructing the dataclass.
    """
    # Unknown keys are rejected outright.
    field = fields_by_name.get(var_name)
    if field is None:
        return False
    # NOTE(review): this 2-arg check_type(value, type) call and the
    # TypeCheckError import match typeguard >=3.x, but pyproject pins
    # typeguard "^2.13.3" (2.x uses check_type(argname, value, type) and
    # raises TypeError) — confirm the pinned version actually supports this.
    try:
        check_type(val, field.type)
    except TypeCheckError:
        return False
    return True

# filter out any invalid config fields
config = {k: v for k, v in config.items() if is_valid_config_field(k, v)}
config["verbose"] = False
config["device"] = device

Expand Down
3 changes: 2 additions & 1 deletion sae_lens/training/sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import json
import os
import pickle
from typing import Callable, NamedTuple, Optional
from dataclasses import fields
from typing import Any, Callable, NamedTuple, Optional

import einops
import torch
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/training/test_session_loader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import os
import tempfile
from dataclasses import fields

import pytest
import torch
from huggingface_hub import hf_hub_download
from transformer_lens import HookedTransformer
from typeguard import TypeCheckError, check_type

from sae_lens.training.activations_store import ActivationsStore
from sae_lens.training.config import LanguageModelSAERunnerConfig
Expand Down Expand Up @@ -105,3 +107,13 @@ def test_load_pretrained_sae_from_huggingface():
assert isinstance(activation_store, ActivationsStore)
assert sae.cfg.hook_point_layer == layer
assert sae.cfg.model_name == "gpt2-small"

for field in fields(LanguageModelSAERunnerConfig):
val = getattr(sae.cfg, field.name)
try:
check_type(val, field.type)
except TypeCheckError as e:
# reraise to get nicer error message so we know what field is messed up
raise ValueError(
f"Field {field.name} with value {val} is not of type {field.type} in the config.\n{e}"
)

0 comments on commit b0f4454

Please sign in to comment.