[Bug Report] Error Term flag prevents Feature Ablation in Gemma 2 2b #326

Closed
gboxo opened this issue Oct 9, 2024 · 2 comments · Fixed by #328

gboxo commented Oct 9, 2024

The use_error_term flag prevents feature ablation.
Ablating SAE features while use_error_term = True works with GPT-2 but not with Gemma-2.
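For reference, a minimal sketch of what the working GPT-2 case looks like. The gpt2-small-res-jb release, the blocks.8.hook_resid_pre hook point, and the prompt are illustrative assumptions, not taken from the original setup:

import torch
from sae_lens import HookedSAETransformer, SAE

gpt2 = HookedSAETransformer.from_pretrained("gpt2", device="cpu")
# Any standard GPT-2 SAE should work here; this release/id pair is only illustrative.
gpt2_sae, _, _ = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre", device="cpu")
gpt2_sae.use_error_term = True
gpt2.add_sae(gpt2_sae)

def zero_all_features(acts, hook):
    acts[:, :, :] = 0.0  # ablate every feature
    return acts

toks = gpt2.to_tokens("The quick brown fox jumps over the lazy dog.")
clean_logits = gpt2(toks)
with gpt2.hooks(fwd_hooks=[("blocks.8.hook_resid_pre.hook_sae_acts_post", zero_all_features)]):
    ablated_logits = gpt2(toks)
print(torch.allclose(clean_logits, ablated_logits, atol=1))  # False: the ablation takes effect on GPT-2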

Code example

import torch
from copy import deepcopy

from sae_lens import HookedSAETransformer, SAE, SAEConfig
from gemma_utils import get_gemma_2_config, gemma_2_sae_loader  # local helper module

torch.set_grad_enabled(False)
device = "cpu"

model = HookedSAETransformer.from_pretrained("google/gemma-2-2b-it", device=device)

full_strings = {10: "layer_10/width_16k/average_l0_77"}
attn_repo_id = "google/gemma-scope-2b-pt-att"
layers = [10]

with torch.no_grad():
    # Load the Gemma Scope residual-stream SAE for layer 10 as an SAELens SAE
    repo_id = "google/gemma-scope-2b-pt-res"
    folder_name = "layer_10/width_16k/average_l0_77"
    config = get_gemma_2_config(repo_id, folder_name)
    cfg, state_dict, log_spar = gemma_2_sae_loader(repo_id, folder_name)
    sae_cfg = SAEConfig.from_dict(cfg)
    sae = SAE(sae_cfg)
    sae.load_state_dict(state_dict)
    sae.d_head = 256
    sae.use_error_term = True


string = "The quick brown fox jumps over the lazy dog."
tokens = model.to_tokens(string)

sae_filter = lambda x: "hook_sae_output" in x
# ============= Original Logits =========

original_logits,original_cache = model.run_with_cache(tokens, names_filter = sae_filter)
model.reset_hooks(including_permanent=True)


# ============= Add SAEs with error term ===========
sae.use_error_term = True
model.add_sae(sae)
logits_with_sae_we, cache_with_sae_we = model.run_with_cache(tokens,names_filter = sae_filter)


# ============ Add SAEs w/o error term =========

model.reset_saes() # Reset the model SAEs
sae_ne = deepcopy(sae)
sae_ne.use_error_term = False
model.add_sae(sae_ne)
logits_with_sae_ne, cache_with_sae_ne = model.run_with_cache(tokens,names_filter = sae_filter)

# =========== Add SAEs with error term and ablate some feature =========

model.reset_hooks()  # reset hooks first, then the SAEs (correct order)
model.reset_saes()
sae.use_error_term = True
model.add_sae(sae)
def ablate_hooked_sae(acts, hook):
    acts[:, :, :] = 20  # deliberately absurd value so the ablation should be unmistakable
    return acts
with model.hooks(fwd_hooks = [("blocks.10.hook_resid_post.hook_sae_acts_post",ablate_hooked_sae)]):
    logits_with_ablated_sae,cache_with_ablated_sae = model.run_with_cache(tokens, names_filter = sae_filter)




# ===== Comparison of the logits ==========
print("Original Logits & Logits with SAEs with error term") # Should be true
print(torch.allclose(logits_with_sae_we, original_logits, atol=1)) 

print("Original Logits & Logits with SAEs with error term") # Should be false
print(torch.allclose(logits_with_sae_ne, original_logits, atol=1)) 


print("Original Logits & Logits with SAEs with error term") # Should be false
print(torch.allclose(logits_with_ablated_sae, original_logits, atol=1)) 

# ===== Comparison of the SAE output ==========
cache_with_sae_we = cache_with_sae_we["blocks.10.hook_resid_post.hook_sae_output"]
cache_with_sae_ne = cache_with_sae_ne["blocks.10.hook_resid_post.hook_sae_output"]
cache_with_ablated_sae = cache_with_ablated_sae["blocks.10.hook_resid_post.hook_sae_output"]

print("Cache with SAEs with error term & Cache with SAEs without error term") # Should be False
print(torch.allclose(cache_with_sae_we, cache_with_sae_ne, atol=1)) 


print("Cache with SAEs with error term & Cache with SAEs with error term and ablation") # Should be false
print(torch.allclose(cache_with_ablated_sae, cache_with_sae_we, atol=1)) 

print("Cache with SAEs with no error term & Cache with SAEs with error term and ablation") # Should be false
print(torch.allclose(cache_with_ablated_sae, cache_with_sae_ne, atol=1)) 

Output:

Logits:
Original Logits & Logits with SAEs with error term
True
Original Logits & Logits with SAEs with error term
False
Original Logits & Logits with SAEs with error term
True
------------------
SAE Output:
Cache with SAEs with error term & Cache with SAEs without error term (should be False)
False
Cache with SAEs with error term & Cache with SAEs with error term and ablation (should be False)
True
Cache with SAEs with no error term & Cache with SAEs with error term and ablation (should be False)
False


System Info
Describe the characteristics of your environment:

  • TransformerLens installed through pip

  • Linux

  • Python 3.10.13

  • [x] I have checked that there is no similar issue in the repo (required)

chanind commented Oct 17, 2024

Thank you for finding this! It looks like this is caused by a bad copy/paste that duplicated the jumprelu forward in https://github.com/jbloomAus/SAELens/blob/main/sae_lens/sae.py#L479 but forgot to remove the hook calls; it is fixed in #328. We should have tests covering this exact scenario to keep it from arising again. I'll add those tests to #328.
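Schematically, the error term has to come from a hook-free "clean" pass. If the duplicated forward still fires hook_sae_acts_post there, the ablated reconstruction gets subtracted straight back out (sae_out_ablated + (x - sae_out_ablated) = x), so the ablation is invisible downstream. A self-contained toy sketch of that logic, not the actual SAELens implementation:

import torch

class ToySAE(torch.nn.Module):
    def __init__(self, d_in=8, d_sae=32):
        super().__init__()
        self.enc = torch.nn.Linear(d_in, d_sae)
        self.dec = torch.nn.Linear(d_sae, d_in)
        self.acts_hook = lambda acts: acts  # stand-in for hook_sae_acts_post

    def forward_with_error_term(self, x):
        # Hooked pass: an ablation hook attached to acts_hook can edit the features.
        feature_acts = self.acts_hook(torch.relu(self.enc(x)))
        sae_out = self.dec(feature_acts)

        # Clean pass for the error term: must NOT run acts_hook again, or the
        # ablation is baked into the error and cancels out of the final output.
        with torch.no_grad():
            clean_out = self.dec(torch.relu(self.enc(x)))
        error = x - clean_out
        return sae_out + error  # the ablation survives in sae_out

sae = ToySAE()
x = torch.randn(4, 8)
baseline = sae.forward_with_error_term(x)
sae.acts_hook = lambda acts: torch.zeros_like(acts)  # "ablate" every feature
ablated = sae.forward_with_error_term(x)
print(torch.allclose(baseline, ablated))  # False: the ablation changes the output even with the error term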

chanind commented Oct 17, 2024

Added a test: 5314f2a. This test fails on main.
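A regression test for this scenario could look roughly like the sketch below. The release and sae_id names are assumed from the SAELens pretrained registry, and the actual test in commit 5314f2a may be structured quite differently (e.g. using a much smaller SAE):

import torch
from sae_lens import SAE, HookedSAETransformer

def test_feature_ablation_with_error_term_jumprelu():
    # Hypothetical sketch; model and SAE names are assumptions.
    model = HookedSAETransformer.from_pretrained("google/gemma-2-2b", device="cpu")
    sae, _, _ = SAE.from_pretrained(
        "gemma-scope-2b-pt-res", "layer_10/width_16k/average_l0_77", device="cpu"
    )
    sae.use_error_term = True
    model.add_sae(sae)

    tokens = model.to_tokens("The quick brown fox jumps over the lazy dog.")
    baseline_logits = model(tokens)

    def ablate_all_features(acts, hook):
        acts[:, :, :] = 0.0
        return acts

    hook_name = f"{sae.cfg.hook_name}.hook_sae_acts_post"
    with model.hooks(fwd_hooks=[(hook_name, ablate_all_features)]):
        ablated_logits = model(tokens)

    # Fails on main: the error term rebuilds the clean input, so the outputs match.
    assert not torch.allclose(baseline_logits, ablated_logits, atol=1e-4)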
