feat: breaks up SAE.forward() into encode() and decode() (jbloomAus#107)
* breaks up SAE.forward() into encode() and decode()

* cleans up return typing of encode by splitting into a hidden and public function
evanhanders authored Apr 29, 2024
1 parent 2cd6895 commit e620bed
Showing 2 changed files with 73 additions and 1 deletion.
26 changes: 25 additions & 1 deletion sae_lens/training/sparse_autoencoder.py
@@ -10,6 +10,7 @@
 
 import einops
 import torch
+from jaxtyping import Float
 from safetensors import safe_open
 from safetensors.torch import save_file
 from torch import nn
@@ -118,7 +119,16 @@ def __init__(
 
         self.setup() # Required for `HookedRootModule`s
 
-    def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None):
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        feature_acts, _ = self._encode_with_hidden_pre(x)
+        return feature_acts
+
+    def _encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """Encodes input activation tensor x into an SAE feature activation tensor."""
         # move x to correct dtype
         x = x.to(self.dtype)
         sae_in = self.hook_sae_in(
@@ -139,6 +149,12 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)
         noisy_hidden_pre = hidden_pre + noise
         feature_acts = self.hook_hidden_post(self.activation_fn(noisy_hidden_pre))
 
+        return feature_acts, hidden_pre
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
         sae_out = self.hook_sae_out(
             einops.einsum(
                 feature_acts
@@ -148,6 +164,14 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)
             )
             + self.b_dec
         )
+        return sae_out
+
+    def forward(
+        self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None
+    ) -> ForwardOutput:
+
+        feature_acts, hidden_pre = self._encode_with_hidden_pre(x)
+        sae_out = self.decode(feature_acts)
 
         # add config for whether l2 is normalized:
         per_item_mse_loss = _per_item_mse_loss_with_target_norm(
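For reference, a minimal usage sketch of the new public API (not part of the commit): it assumes an already-constructed SparseAutoencoder instance and an activation batch whose trailing dimension matches d_in; the helper name `roundtrip` is illustrative.

import torch

from sae_lens.training.sparse_autoencoder import SparseAutoencoder


def roundtrip(sae: SparseAutoencoder, activations: torch.Tensor) -> torch.Tensor:
    """Encode activations into SAE features, then decode them back.

    Assumes `activations` has trailing dimension sae.d_in.
    """
    feature_acts = sae.encode(activations)     # shape (..., d_sae)
    reconstruction = sae.decode(feature_acts)  # shape (..., d_in)
    return reconstruction

As the new unit tests below check, these results should match the sae_out and feature_acts entries returned by forward().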
48 changes: 48 additions & 0 deletions tests/unit/training/test_sparse_autoencoder.py
@@ -149,6 +149,54 @@ def test_SparseAutoencoder_save_and_load_from_pretrained_lacks_scaling_factor(
     )
 
 
+def test_sparse_autoencoder_encode(sparse_autoencoder: SparseAutoencoder):
+    batch_size = 32
+    d_in = sparse_autoencoder.d_in
+    d_sae = sparse_autoencoder.d_sae
+
+    x = torch.randn(batch_size, d_in)
+    feature_acts1 = sparse_autoencoder.encode(x)
+    (
+        _,
+        feature_acts2,
+        _,
+        _,
+        _,
+        _,
+    ) = sparse_autoencoder.forward(
+        x,
+    )
+
+    # Check shape
+    assert feature_acts1.shape == (batch_size, d_sae)
+
+    # Check values
+    assert torch.allclose(feature_acts1, feature_acts2)
+
+
+def test_sparse_autoencoder_decode(sparse_autoencoder: SparseAutoencoder):
+    batch_size = 32
+    d_in = sparse_autoencoder.d_in
+
+    x = torch.randn(batch_size, d_in)
+    feature_acts = sparse_autoencoder.encode(x)
+    sae_out1 = sparse_autoencoder.decode(feature_acts)
+
+    (
+        sae_out2,
+        _,
+        _,
+        _,
+        _,
+        _,
+    ) = sparse_autoencoder.forward(
+        x,
+    )
+
+    assert sae_out1.shape == x.shape
+    assert torch.allclose(sae_out1, sae_out2)
+
+
 def test_sparse_autoencoder_forward(sparse_autoencoder: SparseAutoencoder):
     batch_size = 32
     d_in = sparse_autoencoder.d_in
