feat: Support training JumpReLU SAEs (#352)
* adds JumpReLU logic to TrainingSAE

* adds unit tests for JumpReLU

* changes classes to match tutorial

* replaces bandwidth constant with param

* re-adds JumpReLU logic to TrainingSAE

* adds TrainingSAE.save_model()

* changes threshold to match paper

* adds tests for TrainingSAE when architecture is jumprelu

* adds test for SAE.load_from_pretrained() for JumpReLU

* removes code causing test to fail

* renames initial_threshold to threshold

* removes setattr()

* adds test for TrainingSAE.save_model()

* renames threshold to jumprelu_init_threshold

* adds jumprelu_bandwidth

* removes default value for jumprelu_init_threshold downstream

* replaces zero tensor with None in Step.backward()

* adds jumprelu to architecture type
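
For context, the two new config fields are consumed like any other runner option. Below is a minimal sketch of a JumpReLU training config, assuming only the architecture, jumprelu_init_threshold, and jumprelu_bandwidth fields added in this commit; all other values are illustrative placeholders:

from sae_lens.config import LanguageModelSAERunnerConfig

# Illustrative values; the JumpReLU-specific fields are the ones
# introduced by this commit.
cfg = LanguageModelSAERunnerConfig(
    architecture="jumprelu",
    d_in=512,
    d_sae=4096,
    jumprelu_init_threshold=0.001,  # initializes exp(log_threshold)
    jumprelu_bandwidth=0.001,  # width of the straight-through gradient kernel
)

# The fields propagate to the training SAE via get_training_sae_cfg_dict().
training_cfg = cfg.get_training_sae_cfg_dict()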
anthonyduong9 authored Nov 3, 2024
1 parent 9cb99e1 commit 0b56d03
Showing 7 changed files with 309 additions and 23 deletions.
8 changes: 7 additions & 1 deletion sae_lens/config.py
@@ -66,6 +66,8 @@ class LanguageModelSAERunnerConfig:
seed (int): The seed to use.
dtype (str): The data type to use.
prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
jumprelu_init_threshold (float): The threshold to initialize for training JumpReLU SAEs.
jumprelu_bandwidth (float): Bandwidth for training JumpReLU SAEs.
autocast (bool): Whether to use autocast during training. Saves vram.
autocast_lm (bool): Whether to use autocast during activation fetching.
compile_llm (bool): Whether to compile the LLM.
@@ -128,7 +130,7 @@ class LanguageModelSAERunnerConfig:
)

# SAE Parameters
architecture: Literal["standard", "gated"] = "standard"
architecture: Literal["standard", "gated", "jumprelu"] = "standard"
d_in: int = 512
d_sae: Optional[int] = None
b_dec_init_method: str = "geometric_median"
@@ -162,6 +164,8 @@
seed: int = 42
dtype: str = "float32" # type: ignore #
prepend_bos: bool = True
jumprelu_init_threshold: float = 0.001
jumprelu_bandwidth: float = 0.001

# Performance - see compilation section of lm_runner.py for info
autocast: bool = False # autocast to autocast_dtype during training
@@ -410,6 +414,8 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]:
"decoder_heuristic_init": self.decoder_heuristic_init,
"init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
"normalize_activations": self.normalize_activations,
"jumprelu_init_threshold": self.jumprelu_init_threshold,
"jumprelu_bandwidth": self.jumprelu_bandwidth,
}

def to_dict(self) -> dict[str, Any]:
1 change: 1 addition & 0 deletions sae_lens/sae.py
@@ -556,6 +556,7 @@ def load_from_pretrained(
sae_cfg = SAEConfig.from_dict(cfg_dict)

sae = cls(sae_cfg)

sae.load_state_dict(state_dict)

return sae
163 changes: 142 additions & 21 deletions sae_lens/training/training_sae.py
@@ -8,8 +8,10 @@
from typing import Any, Optional

import einops
import numpy as np
import torch
from jaxtyping import Float
from safetensors.torch import save_file
from torch import nn

from sae_lens.config import LanguageModelSAERunnerConfig
@@ -24,6 +26,67 @@
SAE_CFG_PATH = "cfg.json"


def rectangle(x: torch.Tensor) -> torch.Tensor:
return ((x > -0.5) & (x < 0.5)).to(x)


class Step(torch.autograd.Function):
@staticmethod
def forward(
x: torch.Tensor, threshold: torch.Tensor, bandwidth: float
) -> torch.Tensor:
return (x > threshold).to(x)

@staticmethod
def setup_context(
ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
) -> None:
x, threshold, bandwidth = inputs
del output
ctx.save_for_backward(x, threshold)
ctx.bandwidth = bandwidth

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[None, torch.Tensor, None]: # type: ignore[override]
x, threshold = ctx.saved_tensors
bandwidth = ctx.bandwidth
threshold_grad = torch.sum(
-(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
dim=0,
)
return None, threshold_grad, None


class JumpReLU(torch.autograd.Function):
@staticmethod
def forward(
x: torch.Tensor, threshold: torch.Tensor, bandwidth: float
) -> torch.Tensor:
return (x * (x > threshold)).to(x)

@staticmethod
def setup_context(
ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
) -> None:
x, threshold, bandwidth = inputs
del output
ctx.save_for_backward(x, threshold)
ctx.bandwidth = bandwidth

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]: # type: ignore[override]
x, threshold = ctx.saved_tensors
bandwidth = ctx.bandwidth
x_grad = (x > threshold) * grad_output # We don't apply STE to x input
threshold_grad = torch.sum(
-(threshold / bandwidth)
* rectangle((x - threshold) / bandwidth)
* grad_output,
dim=0,
)
return x_grad, threshold_grad, None


@dataclass
class TrainStepOutput:
sae_in: torch.Tensor
@@ -50,6 +113,8 @@ class TrainingSAEConfig(SAEConfig):
decoder_heuristic_init: bool = False
init_encoder_as_decoder_transpose: bool = False
scale_sparsity_penalty_by_decoder_norm: bool = False
jumprelu_init_threshold: float
jumprelu_bandwidth: float

@classmethod
def from_sae_runner_config(
@@ -90,6 +155,8 @@ def from_sae_runner_config(
normalize_activations=cfg.normalize_activations,
dataset_trust_remote_code=cfg.dataset_trust_remote_code,
model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs,
jumprelu_init_threshold=cfg.jumprelu_init_threshold,
jumprelu_bandwidth=cfg.jumprelu_bandwidth,
)

@classmethod
@@ -173,11 +240,19 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
super().__init__(base_sae_cfg)
self.cfg = cfg # type: ignore

self.encode_with_hidden_pre_fn = (
self.encode_with_hidden_pre
if cfg.architecture != "gated"
else self.encode_with_hidden_pre_gated
)
if cfg.architecture == "standard":
self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre
elif cfg.architecture == "gated":
self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_gated
elif cfg.architecture == "jumprelu":
self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_jumprelu
self.log_threshold = nn.Parameter(
torch.ones(cfg.d_sae, dtype=self.dtype, device=self.device)
* np.log(cfg.jumprelu_init_threshold)
)
self.bandwidth = cfg.jumprelu_bandwidth
else:
raise ValueError(f"Unknown architecture: {cfg.architecture}")

self.check_cfg_compatibility()

@@ -211,6 +286,24 @@ def encode_standard(
feature_acts, _ = self.encode_with_hidden_pre_fn(x)
return feature_acts

def encode_with_hidden_pre_jumprelu(
self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
sae_in = self.process_sae_in(x)

hidden_pre = sae_in @ self.W_enc + self.b_enc

if self.training:
hidden_pre = (
hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
)

threshold = torch.exp(self.log_threshold)

feature_acts = JumpReLU.apply(hidden_pre, threshold, self.bandwidth)

return feature_acts, hidden_pre # type: ignore

def encode_with_hidden_pre(
self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
@@ -271,27 +364,16 @@ def training_forward_pass(

# do a forward pass to get SAE out, but we also need the
# hidden pre.
feature_acts, _ = self.encode_with_hidden_pre_fn(sae_in)
feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
sae_out = self.decode(feature_acts)

# MSE LOSS
per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in)
mse_loss = per_item_mse_loss.sum(dim=-1).mean()

# GHOST GRADS
if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None:

# first half of second forward pass
_, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
ghost_grad_loss = self.calculate_ghost_grad_loss(
x=sae_in,
sae_out=sae_out,
per_item_mse_loss=per_item_mse_loss,
hidden_pre=hidden_pre,
dead_neuron_mask=dead_neuron_mask,
)
else:
ghost_grad_loss = 0.0
l1_loss = torch.tensor(0.0, device=sae_in.device)
aux_reconstruction_loss = torch.tensor(0.0, device=sae_in.device)
ghost_grad_loss = torch.tensor(0.0, device=sae_in.device)

if self.cfg.architecture == "gated":
# Gated SAE Loss Calculation
@@ -316,6 +398,11 @@
).mean()

loss = mse_loss + l1_loss + aux_reconstruction_loss
elif self.cfg.architecture == "jumprelu":
threshold = torch.exp(self.log_threshold)
l0 = torch.sum(Step.apply(hidden_pre, threshold, self.bandwidth), dim=-1) # type: ignore
l1_loss = (current_l1_coefficient * l0).mean()
loss = mse_loss + l1_loss
else:
# default SAE sparsity loss
weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
@@ -326,7 +413,19 @@
l1_loss = (current_l1_coefficient * sparsity).mean()
loss = mse_loss + l1_loss + ghost_grad_loss

aux_reconstruction_loss = torch.tensor(0.0)
if (
self.cfg.use_ghost_grads
and self.training
and dead_neuron_mask is not None
):
ghost_grad_loss = self.calculate_ghost_grad_loss(
x=sae_in,
sae_out=sae_out,
per_item_mse_loss=per_item_mse_loss,
hidden_pre=hidden_pre,
dead_neuron_mask=dead_neuron_mask,
)
loss = loss + ghost_grad_loss

return TrainStepOutput(
sae_in=sae_in,
@@ -403,6 +502,28 @@ def batch_norm_mse_loss_fn(
else:
return standard_mse_loss_fn

def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):
if not os.path.exists(path):
os.mkdir(path)

state_dict = self.state_dict().copy()

if self.cfg.architecture == "jumprelu":
threshold = torch.exp(self.log_threshold).detach()
del state_dict["log_threshold"]
state_dict["threshold"] = threshold

save_file(state_dict, f"{path}/{SAE_WEIGHTS_PATH}")

# Save the config
config = self.cfg.to_dict()
with open(f"{path}/{SAE_CFG_PATH}", "w") as f:
json.dump(config, f)

if sparsity is not None:
sparsity_in_dict = {"sparsity": sparsity}
save_file(sparsity_in_dict, f"{path}/{SPARSITY_PATH}")

@classmethod
def load_from_pretrained(
cls,
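
The Step and JumpReLU autograd functions above implement the straight-through estimator: both forward passes are piecewise constant in threshold, so the backward pass substitutes a rectangle-kernel pseudo-derivative of width bandwidth. Below is a minimal sketch of the effect, assuming the classes as defined in this diff; the wide bandwidth is an illustrative value chosen so the kernel fires on random inputs:

import torch

from sae_lens.training.training_sae import JumpReLU, Step

x = torch.randn(16, 8)
threshold = torch.full((8,), 0.5, requires_grad=True)
bandwidth = 0.5  # illustrative; training uses cfg.jumprelu_bandwidth

acts = JumpReLU.apply(x, threshold, bandwidth)  # x * (x > threshold)
l0 = Step.apply(x, threshold, bandwidth).sum(dim=-1)  # differentiable L0 proxy

# The hard threshold has zero gradient almost everywhere, but the
# rectangle-kernel backward still routes gradient into `threshold`.
(acts.sum() + l0.sum()).backward()
print(threshold.grad)  # nonzero for features with pre-activations near 0.5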
60 changes: 60 additions & 0 deletions tests/unit/test_jumprelu_sae.py
@@ -0,0 +1,60 @@
import pytest
import torch

from sae_lens.training.training_sae import JumpReLU, TrainingSAE
from tests.unit.helpers import build_sae_cfg


def test_jumprelu_sae_encoding():
cfg = build_sae_cfg(architecture="jumprelu")
sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

batch_size = 32
d_in = sae.cfg.d_in
d_sae = sae.cfg.d_sae

x = torch.randn(batch_size, d_in)
feature_acts, hidden_pre = sae.encode_with_hidden_pre_jumprelu(x)

assert feature_acts.shape == (batch_size, d_sae)
assert hidden_pre.shape == (batch_size, d_sae)

# Check the JumpReLU thresholding
sae_in = sae.process_sae_in(x)
expected_hidden_pre = sae_in @ sae.W_enc + sae.b_enc
threshold = torch.exp(sae.log_threshold)
expected_feature_acts = JumpReLU.apply(
expected_hidden_pre, threshold, sae.bandwidth
)

assert torch.allclose(feature_acts, expected_feature_acts, atol=1e-6) # type: ignore


def test_jumprelu_sae_training_forward_pass():
cfg = build_sae_cfg(architecture="jumprelu")
sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

batch_size = 32
d_in = sae.cfg.d_in

x = torch.randn(batch_size, d_in)
train_step_output = sae.training_forward_pass(
sae_in=x,
current_l1_coefficient=sae.cfg.l1_coefficient,
)

assert train_step_output.sae_out.shape == (batch_size, d_in)
assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae)
assert pytest.approx(train_step_output.loss.detach(), rel=1e-3) == (
train_step_output.mse_loss + train_step_output.l1_loss
)

expected_mse_loss = (
(torch.pow((train_step_output.sae_out - x.float()), 2))
.sum(dim=-1)
.mean()
.detach()
.float()
)

assert pytest.approx(train_step_output.mse_loss) == expected_mse_loss
27 changes: 27 additions & 0 deletions tests/unit/training/test_sae_basic.py
@@ -9,6 +9,7 @@

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.sae import SAE, _disable_hooks
from sae_lens.training.training_sae import TrainingSAE
from tests.unit.helpers import build_sae_cfg


@@ -202,6 +203,32 @@ def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None:
assert torch.allclose(sae_out_1, sae_out_2)


def test_sae_save_and_load_from_pretrained_jumprelu(tmp_path: Path) -> None:
cfg = build_sae_cfg(architecture="jumprelu")
model_path = str(tmp_path)
sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
sae_state_dict = sae.state_dict()
sae.save_model(model_path)

assert os.path.exists(model_path)

sae_loaded = SAE.load_from_pretrained(model_path, device="cpu")

sae_loaded_state_dict = sae_loaded.state_dict()

# check state_dict matches the original
for key in sae.state_dict().keys():
assert torch.allclose(
sae_state_dict[key],
sae_loaded_state_dict[key],
)

sae_in = torch.randn(10, cfg.d_in, device=cfg.device)
sae_out_1 = sae(sae_in)
sae_out_2 = sae_loaded(sae_in)
assert torch.allclose(sae_out_1, sae_out_2)


def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
cfg = build_sae_cfg(activation_fn_kwargs={"k": 30})
model_path = str(tmp_path)
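
One detail worth noting in the round trip tested above: during training the threshold is parameterized as log_threshold (keeping it positive under gradient updates), while save_model() stores threshold = exp(log_threshold) so the inference-time SAE can load the weights unchanged. A minimal sketch of inspecting a saved checkpoint follows; the sae_weights.safetensors filename is an assumption based on SAE_WEIGHTS_PATH:

from safetensors.torch import load_file

# Assumes save_model("out") was called; filename per SAE_WEIGHTS_PATH.
state_dict = load_file("out/sae_weights.safetensors")

assert "log_threshold" not in state_dict  # training-only parameterization
print(state_dict["threshold"])  # exp(log_threshold), ready for inference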