Add logic to train JumpReLU SAEs #352

Merged
Changes from 13 commits
3 changes: 3 additions & 0 deletions sae_lens/config.py
@@ -66,6 +66,7 @@ class LanguageModelSAERunnerConfig:
seed (int): The seed to use.
dtype (str): The data type to use.
prepend_bos (bool): Whether to prepend the beginning of sequence token. You should use whatever the model was trained with.
threshold (float): Threshold for training JumpReLU SAEs.
autocast (bool): Whether to use autocast during training. Saves vram.
autocast_lm (bool): Whether to use autocast during activation fetching.
compile_llm (bool): Whether to compile the LLM.
@@ -162,6 +163,7 @@ class LanguageModelSAERunnerConfig:
seed: int = 42
dtype: str = "float32" # type: ignore #
prepend_bos: bool = True
threshold: float = 0.001
Collaborator:
nit: I think this would be clearer as a jumprelu_init_threshold or something to make it clear this is only used for initialization

Collaborator:
This should also have jumprelu_bandwidth as a param; currently it seems to be hardcoded.

Contributor (author):
Oops, thanks for catching that the bandwidth wasn't configurable. I've renamed the first field and added the second one.
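For readers following along, the suggested rename would presumably look something like the sketch below; jumprelu_init_threshold and jumprelu_bandwidth are the reviewer's suggested names rather than confirmed final ones, and the real fields live on LanguageModelSAERunnerConfig.

from dataclasses import dataclass


@dataclass
class JumpReLUTrainingOptions:
    # Hypothetical illustration of the suggested rename (not the merged code).
    # Used only to initialize log_threshold in the training SAE.
    jumprelu_init_threshold: float = 0.001
    # Width of the rectangle kernel used by the straight-through estimators.
    jumprelu_bandwidth: float = 0.001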


# Performance - see compilation section of lm_runner.py for info
autocast: bool = False # autocast to autocast_dtype during training
@@ -410,6 +412,7 @@ def get_training_sae_cfg_dict(self) -> dict[str, Any]:
"decoder_heuristic_init": self.decoder_heuristic_init,
"init_encoder_as_decoder_transpose": self.init_encoder_as_decoder_transpose,
"normalize_activations": self.normalize_activations,
"threshold": self.threshold,
}

def to_dict(self) -> dict[str, Any]:
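To see where the new field plugs in, here is a rough usage sketch for launching a JumpReLU run. It assumes LanguageModelSAERunnerConfig exposes an architecture field (the tests in this PR set it via build_sae_cfg) and leaves every other option at its default; treat it as illustrative rather than documented API.

from sae_lens.config import LanguageModelSAERunnerConfig

# Illustrative sketch only: select the JumpReLU architecture and set the new
# `threshold` field added in this diff; all other options keep their defaults.
cfg = LanguageModelSAERunnerConfig(
    architecture="jumprelu",
    threshold=0.001,  # initial JumpReLU threshold (see the naming discussion above)
    l1_coefficient=1.0,  # scales the L0 sparsity penalty used by this architecture
)
# The training SAE receives the value via get_training_sae_cfg_dict().
print(cfg.get_training_sae_cfg_dict()["threshold"])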
1 change: 1 addition & 0 deletions sae_lens/sae.py
@@ -555,6 +555,7 @@ def load_from_pretrained(
sae_cfg = SAEConfig.from_dict(cfg_dict)

sae = cls(sae_cfg)

sae.load_state_dict(state_dict)

return sae
162 changes: 141 additions & 21 deletions sae_lens/training/training_sae.py
@@ -8,8 +8,10 @@
from typing import Any, Optional

import einops
import numpy as np
import torch
from jaxtyping import Float
from safetensors.torch import save_file
from torch import nn

from sae_lens.config import LanguageModelSAERunnerConfig
@@ -24,6 +26,68 @@
SAE_CFG_PATH = "cfg.json"


def rectangle(x: torch.Tensor) -> torch.Tensor:
return ((x > -0.5) & (x < 0.5)).to(x)


class Step(torch.autograd.Function):
@staticmethod
def forward(
x: torch.Tensor, threshold: torch.Tensor, bandwidth: float = 0.001
) -> torch.Tensor:
return (x > threshold).to(x)

@staticmethod
def setup_context(
ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
) -> None:
x, threshold, bandwidth = inputs
del output
ctx.save_for_backward(x, threshold)
ctx.bandwidth = bandwidth

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]: # type: ignore[override]
x, threshold = ctx.saved_tensors
bandwidth = ctx.bandwidth
x_grad = 0.0 * grad_output # We don't apply STE to x input
Collaborator:
I think it's fine to just return None for the x_grad rather than multiplying by 0. I know this is just from the example code

Contributor (author):
Oh cool, I've made the change.
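For completeness, a minimal self-contained sketch of the Step function with the x gradient dropped entirely, as suggested above (assumed — the updated commit isn't part of this 13-commit snapshot):

import torch


def rectangle(x: torch.Tensor) -> torch.Tensor:
    return ((x > -0.5) & (x < 0.5)).to(x)


class StepSketch(torch.autograd.Function):
    # Sketch only: same forward as Step above, but backward returns None for x
    # (no gradient flows through the hard step) and None for the constant bandwidth.
    @staticmethod
    def forward(x: torch.Tensor, threshold: torch.Tensor, bandwidth: float) -> torch.Tensor:
        return (x > threshold).to(x)

    @staticmethod
    def setup_context(ctx, inputs, output) -> None:
        x, threshold, bandwidth = inputs
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth

    @staticmethod
    def backward(ctx, grad_output):  # type: ignore[override]
        x, threshold = ctx.saved_tensors
        threshold_grad = torch.sum(
            -(1.0 / ctx.bandwidth) * rectangle((x - threshold) / ctx.bandwidth) * grad_output,
            dim=0,
        )
        return None, threshold_grad, None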

threshold_grad = torch.sum(
-(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
dim=0,
)
return x_grad, threshold_grad, None


class JumpReLU(torch.autograd.Function):
@staticmethod
def forward(
x: torch.Tensor, threshold: torch.Tensor, bandwidth: float = 0.001
) -> torch.Tensor:
return (x * (x > threshold)).to(x)

@staticmethod
def setup_context(
ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
) -> None:
x, threshold, bandwidth = inputs
del output
ctx.save_for_backward(x, threshold)
ctx.bandwidth = bandwidth

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]: # type: ignore[override]
x, threshold = ctx.saved_tensors
bandwidth = ctx.bandwidth
x_grad = (x > threshold) * grad_output # We don't apply STE to x input
threshold_grad = torch.sum(
-(threshold / bandwidth)
* rectangle((x - threshold) / bandwidth)
* grad_output,
dim=0,
)
return x_grad, threshold_grad, None


@dataclass
class TrainStepOutput:
sae_in: torch.Tensor
@@ -50,6 +114,7 @@ class TrainingSAEConfig(SAEConfig):
decoder_heuristic_init: bool = False
init_encoder_as_decoder_transpose: bool = False
scale_sparsity_penalty_by_decoder_norm: bool = False
threshold: float = 0.001

@classmethod
def from_sae_runner_config(
@@ -90,6 +155,7 @@ def from_sae_runner_config(
normalize_activations=cfg.normalize_activations,
dataset_trust_remote_code=cfg.dataset_trust_remote_code,
model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs,
threshold=cfg.threshold,
)

@classmethod
@@ -173,11 +239,19 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
super().__init__(base_sae_cfg)
self.cfg = cfg # type: ignore

self.encode_with_hidden_pre_fn = (
self.encode_with_hidden_pre
if cfg.architecture != "gated"
else self.encode_with_hidden_pre_gated
)
if cfg.architecture == "standard":
self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre
elif cfg.architecture == "gated":
self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_gated
elif cfg.architecture == "jumprelu":
self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_jumprelu
threshold = cfg.threshold
self.log_threshold = nn.Parameter(
torch.ones(cfg.d_sae, dtype=self.dtype, device=self.device)
* np.log(threshold)
)
else:
raise ValueError(f"Unknown architecture: {cfg.architecture}")

self.check_cfg_compatibility()

@@ -211,6 +285,24 @@ def encode_standard(
feature_acts, _ = self.encode_with_hidden_pre_fn(x)
return feature_acts

def encode_with_hidden_pre_jumprelu(
self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
sae_in = self.process_sae_in(x)

hidden_pre = sae_in @ self.W_enc + self.b_enc

if self.training:
hidden_pre = (
hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
)

threshold = torch.exp(self.log_threshold)

feature_acts = JumpReLU.apply(hidden_pre, threshold)

return feature_acts, hidden_pre # type: ignore

def encode_with_hidden_pre(
self, x: Float[torch.Tensor, "... d_in"]
) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
@@ -271,27 +363,16 @@ def training_forward_pass(

# do a forward pass to get SAE out, but we also need the
# hidden pre.
feature_acts, _ = self.encode_with_hidden_pre_fn(sae_in)
feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
sae_out = self.decode(feature_acts)

# MSE LOSS
per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in)
mse_loss = per_item_mse_loss.sum(dim=-1).mean()

# GHOST GRADS
if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None:

# first half of second forward pass
_, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
ghost_grad_loss = self.calculate_ghost_grad_loss(
x=sae_in,
sae_out=sae_out,
per_item_mse_loss=per_item_mse_loss,
hidden_pre=hidden_pre,
dead_neuron_mask=dead_neuron_mask,
)
else:
ghost_grad_loss = 0.0
l1_loss = torch.tensor(0.0, device=sae_in.device)
aux_reconstruction_loss = torch.tensor(0.0, device=sae_in.device)
ghost_grad_loss = torch.tensor(0.0, device=sae_in.device)

if self.cfg.architecture == "gated":
# Gated SAE Loss Calculation
@@ -316,6 +397,11 @@
).mean()

loss = mse_loss + l1_loss + aux_reconstruction_loss
elif self.cfg.architecture == "jumprelu":
threshold = torch.exp(self.log_threshold)
l0 = torch.sum(Step.apply(hidden_pre, threshold), dim=-1) # type: ignore
l1_loss = (current_l1_coefficient * l0).mean()
loss = mse_loss + l1_loss
else:
# default SAE sparsity loss
weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
@@ -326,7 +412,19 @@
l1_loss = (current_l1_coefficient * sparsity).mean()
loss = mse_loss + l1_loss + ghost_grad_loss

aux_reconstruction_loss = torch.tensor(0.0)
if (
self.cfg.use_ghost_grads
and self.training
and dead_neuron_mask is not None
):
ghost_grad_loss = self.calculate_ghost_grad_loss(
x=sae_in,
sae_out=sae_out,
per_item_mse_loss=per_item_mse_loss,
hidden_pre=hidden_pre,
dead_neuron_mask=dead_neuron_mask,
)
loss = loss + ghost_grad_loss

return TrainStepOutput(
sae_in=sae_in,
@@ -403,6 +501,28 @@ def batch_norm_mse_loss_fn(
else:
return standard_mse_loss_fn

def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):
if not os.path.exists(path):
os.mkdir(path)

state_dict = self.state_dict().copy()

if self.cfg.architecture == "jumprelu":
threshold = torch.exp(self.log_threshold).detach()
del state_dict["log_threshold"]
state_dict["threshold"] = threshold
Collaborator:
👍
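For readers skimming the thread: the training SAE learns the threshold in log space, which keeps it positive under unconstrained gradient updates, and exponentiates it once at save time so the saved weights carry a plain threshold tensor. A tiny illustration of the round trip (shapes and values made up):

import numpy as np
import torch

# Illustration only: mirrors the log-space initialization in __init__ and the
# exp() applied in save_model above.
init_threshold = 0.001
log_threshold = torch.ones(4) * np.log(init_threshold)  # the learned parameter
saved_threshold = torch.exp(log_threshold)  # what save_model writes to disk
assert torch.allclose(saved_threshold, torch.full((4,), init_threshold))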


save_file(state_dict, f"{path}/{SAE_WEIGHTS_PATH}")

# Save the config
config = self.cfg.to_dict()
with open(f"{path}/{SAE_CFG_PATH}", "w") as f:
json.dump(config, f)

if sparsity is not None:
sparsity_in_dict = {"sparsity": sparsity}
save_file(sparsity_in_dict, f"{path}/{SPARSITY_PATH}")

@classmethod
def load_from_pretrained(
cls,
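Before the tests, a small standalone sanity check of the two estimators defined in this file may help: JumpReLU produces the activations (zero below the learned threshold, identity above), while Step produces a per-unit L0 count whose gradient with respect to the threshold is estimated through the rectangle kernel; the JumpReLU training loss above is then mse_loss plus the l1 coefficient times the mean L0. The snippet is illustrative only and passes the bandwidth explicitly.

import torch

from sae_lens.training.training_sae import JumpReLU, Step

# Illustrative check of the straight-through estimators above (not part of the PR).
x = torch.randn(4, 8, requires_grad=True)  # stand-in for hidden_pre
threshold = torch.full((8,), 0.1, requires_grad=True)  # stand-in for exp(log_threshold)
bandwidth = 0.001

acts = JumpReLU.apply(x, threshold, bandwidth)  # x * (x > threshold)
l0 = Step.apply(x, threshold, bandwidth).sum(-1)  # per-sample L0, differentiable in threshold

(acts.sum() + l0.sum()).backward()
# x receives gradient only through JumpReLU (and only where x > threshold);
# threshold receives the rectangle-kernel pseudo-gradient from both functions.
print(x.grad.abs().sum(), threshold.grad.shape)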
58 changes: 58 additions & 0 deletions tests/unit/test_jumprelu_sae.py
@@ -0,0 +1,58 @@
import pytest
import torch

from sae_lens.training.training_sae import JumpReLU, TrainingSAE
from tests.unit.helpers import build_sae_cfg


def test_jumprelu_sae_encoding():
cfg = build_sae_cfg(architecture="jumprelu")
sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

batch_size = 32
d_in = sae.cfg.d_in
d_sae = sae.cfg.d_sae

x = torch.randn(batch_size, d_in)
feature_acts, hidden_pre = sae.encode_with_hidden_pre_jumprelu(x)

assert feature_acts.shape == (batch_size, d_sae)
assert hidden_pre.shape == (batch_size, d_sae)

# Check the JumpReLU thresholding
sae_in = sae.process_sae_in(x)
expected_hidden_pre = sae_in @ sae.W_enc + sae.b_enc
threshold = torch.exp(sae.log_threshold)
expected_feature_acts = JumpReLU.apply(expected_hidden_pre, threshold)

assert torch.allclose(feature_acts, expected_feature_acts, atol=1e-6) # type: ignore


def test_jumprelu_sae_training_forward_pass():
cfg = build_sae_cfg(architecture="jumprelu")
sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

batch_size = 32
d_in = sae.cfg.d_in

x = torch.randn(batch_size, d_in)
train_step_output = sae.training_forward_pass(
sae_in=x,
current_l1_coefficient=sae.cfg.l1_coefficient,
)

assert train_step_output.sae_out.shape == (batch_size, d_in)
assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae)
assert pytest.approx(train_step_output.loss.detach(), rel=1e-3) == (
train_step_output.mse_loss + train_step_output.l1_loss
)

expected_mse_loss = (
(torch.pow((train_step_output.sae_out - x.float()), 2))
.sum(dim=-1)
.mean()
.detach()
.float()
)

assert pytest.approx(train_step_output.mse_loss) == expected_mse_loss
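One additional check that could sit alongside these tests (not part of this diff; it reuses the imports above) is that the loss actually routes a gradient back to log_threshold, confirming the straight-through estimator is wired up:

def test_jumprelu_log_threshold_receives_gradient():
    # Hypothetical extra test, not included in this PR.
    cfg = build_sae_cfg(architecture="jumprelu")
    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

    x = torch.randn(32, sae.cfg.d_in)
    train_step_output = sae.training_forward_pass(
        sae_in=x,
        current_l1_coefficient=sae.cfg.l1_coefficient,
    )
    train_step_output.loss.backward()

    # The L0 penalty reaches the threshold via the Step/rectangle estimator.
    assert sae.log_threshold.grad is not None
    assert sae.log_threshold.grad.shape == (sae.cfg.d_sae,)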
27 changes: 27 additions & 0 deletions tests/unit/training/test_sae_basic.py
@@ -9,6 +9,7 @@

from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.sae import SAE, _disable_hooks
from sae_lens.training.training_sae import TrainingSAE
from tests.unit.helpers import build_sae_cfg


@@ -198,6 +199,32 @@ def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None:
assert torch.allclose(sae_out_1, sae_out_2)


def test_sae_save_and_load_from_pretrained_jumprelu(tmp_path: Path) -> None:
cfg = build_sae_cfg(architecture="jumprelu")
model_path = str(tmp_path)
sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
sae_state_dict = sae.state_dict()
sae.save_model(model_path)

assert os.path.exists(model_path)

sae_loaded = SAE.load_from_pretrained(model_path, device="cpu")

sae_loaded_state_dict = sae_loaded.state_dict()

# check state_dict matches the original (save_model stores the trained
# threshold as exp(log_threshold) under the "threshold" key)
for key in sae_loaded_state_dict.keys():
    if key == "threshold":
        assert torch.allclose(
            torch.exp(sae_state_dict["log_threshold"]),
            sae_loaded_state_dict[key],
        )
    else:
        assert torch.allclose(
            sae_state_dict[key],
            sae_loaded_state_dict[key],
        )

sae_in = torch.randn(10, cfg.d_in, device=cfg.device)
sae_out_1 = sae(sae_in)
sae_out_2 = sae_loaded(sae_in)
assert torch.allclose(sae_out_1, sae_out_2)


def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None:
cfg = build_sae_cfg(activation_fn_kwargs={"k": 30})
model_path = str(tmp_path)