Add logic to train JumpReLU SAEs #352
Changes from 13 commits
@@ -8,8 +8,10 @@
from typing import Any, Optional

import einops
import numpy as np
import torch
from jaxtyping import Float
from safetensors.torch import save_file
from torch import nn

from sae_lens.config import LanguageModelSAERunnerConfig
@@ -24,6 +26,68 @@
SAE_CFG_PATH = "cfg.json"


def rectangle(x: torch.Tensor) -> torch.Tensor:
    return ((x > -0.5) & (x < 0.5)).to(x)


class Step(torch.autograd.Function):
    @staticmethod
    def forward(
        x: torch.Tensor, threshold: torch.Tensor, bandwidth: float = 0.001
    ) -> torch.Tensor:
        return (x > threshold).to(x)

    @staticmethod
    def setup_context(
        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
    ) -> None:
        x, threshold, bandwidth = inputs
        del output
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]:  # type: ignore[override]
        x, threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        x_grad = 0.0 * grad_output  # We don't apply STE to x input
Reviewer comment: I think it's fine to just return …

Author reply: Oh cool, I've made the change.
        threshold_grad = torch.sum(
            -(1.0 / bandwidth) * rectangle((x - threshold) / bandwidth) * grad_output,
            dim=0,
        )
        return x_grad, threshold_grad, None


class JumpReLU(torch.autograd.Function):
    @staticmethod
    def forward(
        x: torch.Tensor, threshold: torch.Tensor, bandwidth: float = 0.001
    ) -> torch.Tensor:
        return (x * (x > threshold)).to(x)

    @staticmethod
    def setup_context(
        ctx: Any, inputs: tuple[torch.Tensor, torch.Tensor, float], output: torch.Tensor
    ) -> None:
        x, threshold, bandwidth = inputs
        del output
        ctx.save_for_backward(x, threshold)
        ctx.bandwidth = bandwidth

    @staticmethod
    def backward(ctx: Any, grad_output: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]:  # type: ignore[override]
        x, threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        x_grad = (x > threshold) * grad_output  # We don't apply STE to x input
        threshold_grad = torch.sum(
            -(threshold / bandwidth)
            * rectangle((x - threshold) / bandwidth)
            * grad_output,
            dim=0,
        )
        return x_grad, threshold_grad, None


@dataclass
class TrainStepOutput:
    sae_in: torch.Tensor
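In equation form, the forward pass above is f(x) = x · H(x − θ), with Step computing just the indicator H(x − θ); both use a rectangle-kernel pseudo-derivative of width `bandwidth` so the learned threshold θ still receives a gradient even though the exact function is piecewise constant in θ. As a rough standalone illustration (not part of the PR), the sketch below calls JumpReLU directly; it assumes the class is importable from sae_lens.training.training_sae, as the new unit tests do, and the tensor sizes are made up:

import torch

from sae_lens.training.training_sae import JumpReLU  # imported the same way in the new tests

batch, d_sae = 16, 128
hidden_pre = torch.randn(batch, d_sae)
threshold = torch.full((d_sae,), 0.05, requires_grad=True)

# Forward: pre-activations at or below the threshold are zeroed out.
feature_acts = JumpReLU.apply(hidden_pre, threshold, 0.001)

# Backward: the threshold receives gradient only from pre-activations that land
# inside the +/- bandwidth/2 window around it (the rectangle kernel above).
feature_acts.sum().backward()
print(threshold.grad.shape)  # torch.Size([128])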
@@ -50,6 +114,7 @@ class TrainingSAEConfig(SAEConfig):
    decoder_heuristic_init: bool = False
    init_encoder_as_decoder_transpose: bool = False
    scale_sparsity_penalty_by_decoder_norm: bool = False
    threshold: float = 0.001

    @classmethod
    def from_sae_runner_config(
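With this field in place, selecting the JumpReLU architecture for a run only requires setting the architecture and the initial threshold. A rough sketch through the runner config is below; note that a threshold field on LanguageModelSAERunnerConfig is implied by the from_sae_runner_config change that follows but its definition is not part of this diff, and all other fields are left at their defaults:

from sae_lens.config import LanguageModelSAERunnerConfig

# Minimal illustration; real runs set many more fields (model, hook point, dataset, ...).
cfg = LanguageModelSAERunnerConfig(
    architecture="jumprelu",
    threshold=0.001,  # initial JumpReLU threshold; stored internally as log_threshold
)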
@@ -90,6 +155,7 @@ def from_sae_runner_config(
            normalize_activations=cfg.normalize_activations,
            dataset_trust_remote_code=cfg.dataset_trust_remote_code,
            model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs,
            threshold=cfg.threshold,
        )

    @classmethod
@@ -173,11 +239,19 @@ def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False):
        super().__init__(base_sae_cfg)
        self.cfg = cfg  # type: ignore

        self.encode_with_hidden_pre_fn = (
            self.encode_with_hidden_pre
            if cfg.architecture != "gated"
            else self.encode_with_hidden_pre_gated
        )
        if cfg.architecture == "standard":
            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre
        elif cfg.architecture == "gated":
            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_gated
        elif cfg.architecture == "jumprelu":
            self.encode_with_hidden_pre_fn = self.encode_with_hidden_pre_jumprelu
            threshold = cfg.threshold
            self.log_threshold = nn.Parameter(
                torch.ones(cfg.d_sae, dtype=self.dtype, device=self.device)
                * np.log(threshold)
            )
        else:
            raise ValueError(f"Unknown architecture: {cfg.architecture}")

        self.check_cfg_compatibility()
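The threshold itself is not the learned parameter here: the SAE stores log_threshold and exponentiates it wherever a threshold is needed, so the effective threshold stays strictly positive no matter what gradient updates do to the raw parameter. A minimal standalone sketch of the same initialization pattern (sizes are made up):

import numpy as np
import torch
from torch import nn

d_sae, init_threshold = 512, 0.001

# Store the threshold in log space; exp() of any real value is positive,
# so the effective threshold can never be driven to zero or below.
log_threshold = nn.Parameter(torch.ones(d_sae) * np.log(init_threshold))
threshold = torch.exp(log_threshold)  # shape (d_sae,), all entries equal 0.001 at init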
@@ -211,6 +285,24 @@ def encode_standard(
        feature_acts, _ = self.encode_with_hidden_pre_fn(x)
        return feature_acts

    def encode_with_hidden_pre_jumprelu(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
        sae_in = self.process_sae_in(x)

        hidden_pre = sae_in @ self.W_enc + self.b_enc

        if self.training:
            hidden_pre = (
                hidden_pre + torch.randn_like(hidden_pre) * self.cfg.noise_scale
            )

        threshold = torch.exp(self.log_threshold)

        feature_acts = JumpReLU.apply(hidden_pre, threshold)

        return feature_acts, hidden_pre  # type: ignore

    def encode_with_hidden_pre(
        self, x: Float[torch.Tensor, "... d_in"]
    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
@@ -271,27 +363,16 @@ def training_forward_pass(

        # do a forward pass to get SAE out, but we also need the
        # hidden pre.
        feature_acts, _ = self.encode_with_hidden_pre_fn(sae_in)
        feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
        sae_out = self.decode(feature_acts)

        # MSE LOSS
        per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in)
        mse_loss = per_item_mse_loss.sum(dim=-1).mean()

        # GHOST GRADS
        if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None:

            # first half of second forward pass
            _, hidden_pre = self.encode_with_hidden_pre_fn(sae_in)
            ghost_grad_loss = self.calculate_ghost_grad_loss(
                x=sae_in,
                sae_out=sae_out,
                per_item_mse_loss=per_item_mse_loss,
                hidden_pre=hidden_pre,
                dead_neuron_mask=dead_neuron_mask,
            )
        else:
            ghost_grad_loss = 0.0
        l1_loss = torch.tensor(0.0, device=sae_in.device)
        aux_reconstruction_loss = torch.tensor(0.0, device=sae_in.device)
        ghost_grad_loss = torch.tensor(0.0, device=sae_in.device)

        if self.cfg.architecture == "gated":
            # Gated SAE Loss Calculation
@@ -316,6 +397,11 @@
            ).mean()

            loss = mse_loss + l1_loss + aux_reconstruction_loss
        elif self.cfg.architecture == "jumprelu":
            threshold = torch.exp(self.log_threshold)
            l0 = torch.sum(Step.apply(hidden_pre, threshold), dim=-1)  # type: ignore
            l1_loss = (current_l1_coefficient * l0).mean()
            loss = mse_loss + l1_loss
        else:
            # default SAE sparsity loss
            weighted_feature_acts = feature_acts * self.W_dec.norm(dim=1)
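For the jumprelu branch the sparsity penalty is an L0 count rather than a decoder-norm-weighted L1: Step.apply yields a 0/1 indicator per feature, its sum over the feature dimension is the per-example number of active features, and the rectangle pseudo-derivative is what lets this otherwise non-differentiable count move the threshold. A rough standalone sketch of that term follows; it assumes Step is importable from sae_lens.training.training_sae alongside JumpReLU, and the coefficient value is made up:

import torch

from sae_lens.training.training_sae import Step  # assumed importable, like JumpReLU in the tests

hidden_pre = torch.randn(8, 64)
threshold = torch.full((64,), 0.1, requires_grad=True)
l1_coefficient = 5e-4  # arbitrary illustrative value

# Per-example L0: how many features fire above the threshold.
l0 = torch.sum(Step.apply(hidden_pre, threshold, 0.001), dim=-1)
sparsity_loss = (l1_coefficient * l0).mean()

# The hard count is piecewise constant in the activations, but Step.backward
# still supplies a threshold gradient through the rectangle pseudo-derivative.
sparsity_loss.backward()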
@@ -326,7 +412,19 @@
            l1_loss = (current_l1_coefficient * sparsity).mean()
            loss = mse_loss + l1_loss + ghost_grad_loss

            aux_reconstruction_loss = torch.tensor(0.0)
        if (
            self.cfg.use_ghost_grads
            and self.training
            and dead_neuron_mask is not None
        ):
            ghost_grad_loss = self.calculate_ghost_grad_loss(
                x=sae_in,
                sae_out=sae_out,
                per_item_mse_loss=per_item_mse_loss,
                hidden_pre=hidden_pre,
                dead_neuron_mask=dead_neuron_mask,
            )
            loss = loss + ghost_grad_loss

        return TrainStepOutput(
            sae_in=sae_in,
@@ -403,6 +501,28 @@ def batch_norm_mse_loss_fn(
        else:
            return standard_mse_loss_fn

    def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None):
        if not os.path.exists(path):
            os.mkdir(path)

        state_dict = self.state_dict().copy()

        if self.cfg.architecture == "jumprelu":
            threshold = torch.exp(self.log_threshold).detach()
            del state_dict["log_threshold"]
            state_dict["threshold"] = threshold
Reviewer comment: 👍
        save_file(state_dict, f"{path}/{SAE_WEIGHTS_PATH}")

        # Save the config
        config = self.cfg.to_dict()
        with open(f"{path}/{SAE_CFG_PATH}", "w") as f:
            json.dump(config, f)

        if sparsity is not None:
            sparsity_in_dict = {"sparsity": sparsity}
            save_file(sparsity_in_dict, f"{path}/{SPARSITY_PATH}")

    @classmethod
    def load_from_pretrained(
        cls,
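Note that save_model exports the exponentiated threshold rather than the raw log_threshold parameter. The corresponding load path is not shown in this diff; assuming it simply mirrors the save step, the inverse conversion would look roughly like this (purely illustrative):

import torch

# Hypothetical sketch of re-deriving the trainable parameter when loading;
# the key names mirror the save_model logic above, but the real load path is
# not part of this diff.
state_dict = {"threshold": torch.full((512,), 0.001)}  # stand-in for a loaded file

state_dict["log_threshold"] = torch.log(state_dict.pop("threshold"))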
@@ -0,0 +1,58 @@
import pytest
import torch

from sae_lens.training.training_sae import JumpReLU, TrainingSAE
from tests.unit.helpers import build_sae_cfg


def test_jumprelu_sae_encoding():
    cfg = build_sae_cfg(architecture="jumprelu")
    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

    batch_size = 32
    d_in = sae.cfg.d_in
    d_sae = sae.cfg.d_sae

    x = torch.randn(batch_size, d_in)
    feature_acts, hidden_pre = sae.encode_with_hidden_pre_jumprelu(x)

    assert feature_acts.shape == (batch_size, d_sae)
    assert hidden_pre.shape == (batch_size, d_sae)

    # Check the JumpReLU thresholding
    sae_in = sae.process_sae_in(x)
    expected_hidden_pre = sae_in @ sae.W_enc + sae.b_enc
    threshold = torch.exp(sae.log_threshold)
    expected_feature_acts = JumpReLU.apply(expected_hidden_pre, threshold)

    assert torch.allclose(feature_acts, expected_feature_acts, atol=1e-6)  # type: ignore


def test_jumprelu_sae_training_forward_pass():
    cfg = build_sae_cfg(architecture="jumprelu")
    sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())

    batch_size = 32
    d_in = sae.cfg.d_in

    x = torch.randn(batch_size, d_in)
    train_step_output = sae.training_forward_pass(
        sae_in=x,
        current_l1_coefficient=sae.cfg.l1_coefficient,
    )

    assert train_step_output.sae_out.shape == (batch_size, d_in)
    assert train_step_output.feature_acts.shape == (batch_size, sae.cfg.d_sae)
    assert pytest.approx(train_step_output.loss.detach(), rel=1e-3) == (
        train_step_output.mse_loss + train_step_output.l1_loss
    )

    expected_mse_loss = (
        (torch.pow((train_step_output.sae_out - x.float()), 2))
        .sum(dim=-1)
        .mean()
        .detach()
        .float()
    )

    assert pytest.approx(train_step_output.mse_loss) == expected_mse_loss
Reviewer comment: nit: I think this would be clearer as jumprelu_init_threshold or something, to make it clear this is only used for initialization.

Reviewer comment: This should have jumprelu_bandwidth as a param as well; currently it seems hardcoded.

Author reply: Oops, thanks for catching that bandwidth wasn't configurable. I've changed the first field's name and added the second field.
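Based on the review thread above, a rough sketch of what the renamed and added config fields might look like; the field names come from the reviewer's suggestions, while the defaults and the stripped-down class are assumptions:

from dataclasses import dataclass


@dataclass
class TrainingSAEConfig:  # heavily simplified; the real config has many more fields
    # Renamed from `threshold`: only used to initialize log_threshold.
    jumprelu_init_threshold: float = 0.001
    # Previously hardcoded in Step / JumpReLU; now exposed as a config field.
    jumprelu_bandwidth: float = 0.001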