fix: Mistral weight conversion

dest1n1s committed Nov 4, 2024
1 parent 8f6a48b commit 60332d1

Showing 4 changed files with 703 additions and 15 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ test = [
 ]
 torch = [
     "torch>=2.5.1",
+    "transformer-lens>=2.8.1",
 ]
 cuda12 = [
     "jax[cuda12]>=0.4.35",
4 changes: 2 additions & 2 deletions src/xlens/pretrained/convert_weight/mistral.py
@@ -28,8 +28,8 @@ def convert_mistral_weights(params: dict[str, jax.Array], cfg: HookedTransformer
         W_K = einops.rearrange(W_K, "(n h) m->n m h", n=cfg.n_key_value_heads)
         W_V = einops.rearrange(W_V, "(n h) m->n m h", n=cfg.n_key_value_heads)
         state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
-        state_dict[f"blocks.{l}.attn._W_K"] = W_K
-        state_dict[f"blocks.{l}.attn._W_V"] = W_V
+        state_dict[f"blocks.{l}.attn.W_K"] = W_K
+        state_dict[f"blocks.{l}.attn.W_V"] = W_V
 
         state_dict[f"blocks.{l}.attn.b_Q"] = jnp.zeros((cfg.n_heads, cfg.d_head))
         state_dict[f"blocks.{l}.attn.b_K"] = jnp.zeros((cfg.n_key_value_heads, cfg.d_head))
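The rename above is the substance of the fix: the converted key/value matrices are now stored under the `W_K`/`W_V` state-dict keys rather than the underscore-prefixed `_W_K`/`_W_V` ones. For context, a minimal sketch of what the `einops.rearrange` pattern does to the Hugging Face weights, using toy dimensions rather than Mistral-7B's real config:

    import einops
    import jax.numpy as jnp

    # Toy grouped-query-attention sizes, for illustration only.
    n_kv_heads, d_head, d_model = 2, 4, 8

    # HF stores the K projection as one (n_kv_heads * d_head, d_model) matrix.
    W_K_flat = jnp.ones((n_kv_heads * d_head, d_model))

    # "(n h) m->n m h" splits it into one (d_model, d_head) block per KV head.
    W_K = einops.rearrange(W_K_flat, "(n h) m->n m h", n=n_kv_heads)
    assert W_K.shape == (n_kv_heads, d_model, d_head)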
45 changes: 43 additions & 2 deletions tests/acceptance/computation/test_mistral_computation.py
@@ -19,16 +19,23 @@ def test_mistral_computation():
     hf_model.eval()
 
     hf_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
-    hf_output = hf_model(hf_input)
+    hf_output = hf_model(hf_input, output_hidden_states=True)
     hf_logits = hf_output.logits
+    hf_hidden_states = hf_output.hidden_states
 
     del hf_model
     torch.cuda.empty_cache()
 
     model = HookedTransformer.from_pretrained("mistralai/Mistral-7B-v0.1")
 
     input = jnp.array(hf_input)
-    logits = model(input)
+    logits, cache = model.run_with_cache(input, hook_names=[f"blocks.{i}.hook_resid_pre" for i in range(12)])
 
+    for i in range(12):
+        print(
+            f"Block {i} Residual Pre Difference: ",
+            jnp.linalg.norm(jnp.array(hf_hidden_states[i]) - cache[f"blocks.{i}.hook_resid_pre"]),
+        )
+
     print("Logits Difference: ", jnp.linalg.norm(logits - jnp.array(hf_logits)))
 
@@ -40,5 +47,39 @@ def test_mistral_computation():
     assert jnp.allclose(probs, jnp.array(hf_probs), atol=1e-3)
 
 
+# @torch.no_grad()
+# def test_mistral_computation_tl():
+#     import transformer_lens as tl
+
+#     tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
+#     tl_model = tl.HookedTransformer.from_pretrained("mistralai/Mistral-7B-v0.1", tokenizer=tokenizer, fold_ln=False)
+#     tl_model.eval()
+
+#     tl_input = tokenizer("Hello, my dog is cute!", return_tensors="pt")["input_ids"]
+#     tl_logits, tl_cache = tl_model.run_with_cache(
+#         tl_input, names_filter=[f"blocks.{i}.hook_resid_pre" for i in range(12)]
+#     )
+
+#     model = HookedTransformer.from_pretrained("mistralai/Mistral-7B-v0.1")
+
+#     input = jnp.array(tl_input)
+#     logits, cache = model.run_with_cache(input, hook_names=[f"blocks.{i}.hook_resid_pre" for i in range(12)])
+
+#     for i in range(12):
+#         print(
+#             f"Block {i} Residual Pre Difference: ",
+#             jnp.linalg.norm(jnp.array(tl_cache[f"blocks.{i}.hook_resid_pre"]) - cache[f"blocks.{i}.hook_resid_pre"]),
+#         )
+
+#     print("Logits Difference: ", jnp.linalg.norm(logits - jnp.array(tl_logits)))
+
+#     tl_probs = torch.nn.functional.softmax(tl_logits, dim=-1)
+#     probs = jax.nn.softmax(logits, axis=-1)
+
+#     print("Probs Difference: ", jnp.linalg.norm(probs - jnp.array(tl_probs)))
+
+#     assert jnp.allclose(probs, jnp.array(tl_probs), atol=1e-3)
+
+
 if __name__ == "__main__":
     test_mistral_computation()
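One property worth noting about the comparison in the test above: softmax is invariant to any constant added to all logits at a position, so checking post-softmax probabilities with `jnp.allclose(..., atol=1e-3)` tolerates per-position logit offsets that a raw-logit comparison would flag. A toy illustration of the invariance (made-up numbers, unrelated to any model output):

    import jax
    import jax.numpy as jnp

    # Two logit rows that differ only by a constant shift.
    a = jnp.array([[1.0, 2.0, 3.0]])
    b = a + 5.0

    assert not jnp.allclose(a, b)  # the raw logits differ...
    # ...but their softmax outputs are identical, so a probability-space
    # check is insensitive to this kind of offset.
    assert jnp.allclose(jax.nn.softmax(a, axis=-1), jax.nn.softmax(b, axis=-1))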
(Diff for the fourth changed file not shown.)
