diff --git a/sae_lens/training/sparse_autoencoder.py b/sae_lens/training/sparse_autoencoder.py
index 4694a7de..73363527 100644
--- a/sae_lens/training/sparse_autoencoder.py
+++ b/sae_lens/training/sparse_autoencoder.py
@@ -10,6 +10,7 @@
 
 import einops
 import torch
+from jaxtyping import Float
 from safetensors import safe_open
 from safetensors.torch import save_file
 from torch import nn
@@ -118,7 +119,16 @@ def __init__(
 
         self.setup()  # Required for `HookedRootModule`s
 
-    def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None):
+    def encode(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> Float[torch.Tensor, "... d_sae"]:
+        feature_acts, _ = self._encode_with_hidden_pre(x)
+        return feature_acts
+
+    def _encode_with_hidden_pre(
+        self, x: Float[torch.Tensor, "... d_in"]
+    ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]:
+        """Encodes input activation tensor x into an SAE feature activation tensor."""
         # move x to correct dtype
         x = x.to(self.dtype)
         sae_in = self.hook_sae_in(
@@ -139,6 +149,12 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)
         noisy_hidden_pre = hidden_pre + noise
         feature_acts = self.hook_hidden_post(self.activation_fn(noisy_hidden_pre))
 
+        return feature_acts, hidden_pre
+
+    def decode(
+        self, feature_acts: Float[torch.Tensor, "... d_sae"]
+    ) -> Float[torch.Tensor, "... d_in"]:
+        """Decodes SAE feature activation tensor into a reconstructed input activation tensor."""
         sae_out = self.hook_sae_out(
             einops.einsum(
                 feature_acts
@@ -148,6 +164,14 @@ def forward(self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None)
             )
             + self.b_dec
         )
+        return sae_out
+
+    def forward(
+        self, x: torch.Tensor, dead_neuron_mask: torch.Tensor | None = None
+    ) -> ForwardOutput:
+
+        feature_acts, hidden_pre = self._encode_with_hidden_pre(x)
+        sae_out = self.decode(feature_acts)
 
         # add config for whether l2 is normalized:
         per_item_mse_loss = _per_item_mse_loss_with_target_norm(
diff --git a/tests/unit/training/test_sparse_autoencoder.py b/tests/unit/training/test_sparse_autoencoder.py
index 828e3bc2..a9a59d23 100644
--- a/tests/unit/training/test_sparse_autoencoder.py
+++ b/tests/unit/training/test_sparse_autoencoder.py
@@ -149,6 +149,54 @@ def test_SparseAutoencoder_save_and_load_from_pretrained_lacks_scaling_factor(
     )
 
 
+def test_sparse_autoencoder_encode(sparse_autoencoder: SparseAutoencoder):
+    batch_size = 32
+    d_in = sparse_autoencoder.d_in
+    d_sae = sparse_autoencoder.d_sae
+
+    x = torch.randn(batch_size, d_in)
+    feature_acts1 = sparse_autoencoder.encode(x)
+    (
+        _,
+        feature_acts2,
+        _,
+        _,
+        _,
+        _,
+    ) = sparse_autoencoder.forward(
+        x,
+    )
+
+    # Check shape
+    assert feature_acts1.shape == (batch_size, d_sae)
+
+    # Check values
+    assert torch.allclose(feature_acts1, feature_acts2)
+
+
+def test_sparse_autoencoder_decode(sparse_autoencoder: SparseAutoencoder):
+    batch_size = 32
+    d_in = sparse_autoencoder.d_in
+
+    x = torch.randn(batch_size, d_in)
+    feature_acts = sparse_autoencoder.encode(x)
+    sae_out1 = sparse_autoencoder.decode(feature_acts)
+
+    (
+        sae_out2,
+        _,
+        _,
+        _,
+        _,
+        _,
+    ) = sparse_autoencoder.forward(
+        x,
+    )
+
+    assert sae_out1.shape == x.shape
+    assert torch.allclose(sae_out1, sae_out2)
+
+
 def test_sparse_autoencoder_forward(sparse_autoencoder: SparseAutoencoder):
     batch_size = 32
     d_in = sparse_autoencoder.d_in
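
Note: the patch above factors the old monolithic forward pass into the standard SAE encode/decode split, so forward() reduces to decode(_encode_with_hidden_pre(x)) plus loss bookkeeping. A minimal self-contained sketch of that same pattern follows; this is a toy module with illustrative names, not the library's SparseAutoencoder, and it omits the hooks, training noise, and loss terms seen in the diff. Centering the input by b_dec before encoding is one common choice; the library's exact preprocessing is truncated out of the hunk above.

import torch
from torch import nn


class ToySAE(nn.Module):
    """Toy stand-in illustrating the encode/decode factorization."""

    def __init__(self, d_in: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_in, d_sae) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_in) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_in))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Project into the (overcomplete) feature basis; ReLU keeps activations sparse.
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
        # Linear reconstruction of the input from feature activations.
        return feature_acts @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # forward is now just decode(encode(x)), mirroring the refactor above.
        return self.decode(self.encode(x))


sae = ToySAE(d_in=64, d_sae=256)
x = torch.randn(32, 64)
assert sae.encode(x).shape == (32, 256)
assert torch.allclose(sae(x), sae.decode(sae.encode(x)))

The new tests in the diff check the same invariant against the real class: encode() and decode() must reproduce the feature activations and reconstruction that forward() returns.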