fix: Qwen2 converter
dest1n1s committed Nov 5, 2024
1 parent 3b4cd8a commit abeb3e5
Showing 5 changed files with 89 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checks.yml
@@ -54,7 +54,7 @@ jobs:
      - name: Set up Python
        run: uv python install 3.12
      - name: Type check
        run: uv run --group torch pyright
        run: uv run --group torch basedpyright

ruff:
runs-on: ubuntu-latest
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -38,6 +38,9 @@ cuda12 = [
flax = [
    "flax>=0.10.1",
]
metal = [
    "jax-metal>=0.1.1",
]

[tool.uv.workspace]
members = ["xlens"]
@@ -138,4 +141,4 @@ reportUntypedFunctionDecorator = false
reportUnknownArgumentType = false
reportUnknownVariableType = false
reportMissingTypeStubs = false
reportConstantRedefinition = false
reportConstantRedefinition = false
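With the new metal extra installed (Apple-silicon machines), a quick sanity check is to ask JAX which backend and devices it sees. A minimal sketch, assuming jax-metal is installed and active in the current environment; the exact backend name and device list printed will vary and are not part of this commit:

import jax

# With jax-metal active, the default backend is expected to be something other
# than plain CPU, and jax.devices() should list the Metal device (illustrative).
print(jax.default_backend())
print(jax.devices())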
35 changes: 31 additions & 4 deletions src/xlens/pretrained/convert.py
@@ -28,9 +28,36 @@
)


def get_pretrained_model_config(model_name: str) -> HookedTransformerConfig:
    return converter.get_pretrained_model_config(model_name)
def get_pretrained_model_config(model_name_or_path: str) -> HookedTransformerConfig:
    """Get the configuration for a pretrained model from Hugging Face.

    Args:
        model_name_or_path (str): The name of the model on Hugging Face Hub (e.g., 'gpt2', 'facebook/opt-125m')

    Returns:
        HookedTransformerConfig: Configuration object containing the model architecture details

    Raises:
        ValueError: If the model architecture is not supported
    """
    return converter.get_pretrained_model_config(model_name_or_path)


def get_pretrained_weights(cfg: HookedTransformerConfig, model_name: str, hf_model: Any = None) -> dict[str, jax.Array]:
    return converter.get_pretrained_weights(cfg, model_name, hf_model=hf_model)
def get_pretrained_weights(
    cfg: HookedTransformerConfig, model_name_or_path: str, hf_model: Any = None
) -> dict[str, jax.Array]:
    """Load pretrained weights from a Hugging Face model and convert them to JAX arrays.

    Args:
        cfg (HookedTransformerConfig): Configuration object for the target model
        model_name_or_path (str): The name of the model on Hugging Face Hub (e.g., 'gpt2', 'facebook/opt-125m')
        hf_model (Any, optional): Pre-loaded Hugging Face model. If None, the model will be loaded
            from the Hub. Defaults to None.

    Returns:
        dict[str, jax.Array]: Dictionary mapping parameter names to their values as JAX arrays

    Raises:
        ValueError: If the model architecture is not supported or weights cannot be converted
    """
    return converter.get_pretrained_weights(cfg, model_name_or_path, hf_model=hf_model)
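These two helpers remain the module's entry points: resolve the architecture config first, then convert the checkpoint weights against it. A minimal usage sketch, assuming the helpers are importable from xlens.pretrained.convert and using a hypothetical Qwen2 model id; neither detail is stated in this commit:

from xlens.pretrained.convert import get_pretrained_model_config, get_pretrained_weights

# Hypothetical model id, chosen only to match the converter this commit fixes.
cfg = get_pretrained_model_config("Qwen/Qwen2-0.5B")
weights = get_pretrained_weights(cfg, "Qwen/Qwen2-0.5B")  # dict of parameter name -> jax.Array
print(cfg.n_heads, len(weights))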
27 changes: 24 additions & 3 deletions src/xlens/pretrained/converters/qwen2.py
@@ -92,9 +92,30 @@ def convert_hf_weights(
state_dict[f"blocks.{l}.attn.W_K"] = W_K
state_dict[f"blocks.{l}.attn.W_V"] = W_V

state_dict[f"blocks.{l}.attn.b_Q"] = jnp.zeros((cfg.n_heads, cfg.d_head))
state_dict[f"blocks.{l}.attn.b_K"] = jnp.zeros((cfg.n_key_value_heads, cfg.d_head))
state_dict[f"blocks.{l}.attn.b_V"] = jnp.zeros((cfg.n_key_value_heads, cfg.d_head))
b_Q = hf_weights[f"model.layers.{l}.self_attn.q_proj.bias"]
b_Q = einops.rearrange(
b_Q,
"(n_head d_head) -> n_head d_head",
n_head=cfg.n_heads,
)

b_K = hf_weights[f"model.layers.{l}.self_attn.k_proj.bias"]
b_K = einops.rearrange(
b_K,
"(n_head d_head) -> n_head d_head",
n_head=cfg.n_key_value_heads,
)

b_V = hf_weights[f"model.layers.{l}.self_attn.v_proj.bias"]
b_V = einops.rearrange(
b_V,
"(n_head d_head) -> n_head d_head",
n_head=cfg.n_key_value_heads,
)

state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
state_dict[f"blocks.{l}.attn.b_K"] = b_K
state_dict[f"blocks.{l}.attn.b_V"] = b_V

W_O = hf_weights[f"model.layers.{l}.self_attn.o_proj.weight"]
W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads)
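The substance of the fix is the bias handling above: Hugging Face Qwen2 checkpoints store the q/k/v projection biases as flat vectors of length n_heads * d_head (or n_key_value_heads * d_head for K/V), and the converter previously discarded them by writing zeros. The rearrange splits each flat bias into one row per head, matching the weight layout. A self-contained sketch of that reshape, with toy dimensions chosen purely for illustration:

import einops
import jax.numpy as jnp

# Toy sizes standing in for cfg.n_heads and cfg.d_head.
n_heads, d_head = 4, 8
flat_bias = jnp.arange(n_heads * d_head, dtype=jnp.float32)  # like q_proj.bias, shape (32,)

# Same pattern the converter applies: split the flat bias into per-head rows.
b_Q = einops.rearrange(flat_bias, "(n_head d_head) -> n_head d_head", n_head=n_heads)
assert b_Q.shape == (n_heads, d_head)
# Head 1 gets the second contiguous slice of the flat vector.
assert bool(jnp.all(b_Q[1] == flat_bias[d_head : 2 * d_head]))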
30 changes: 29 additions & 1 deletion uv.lock

Some generated files are not rendered by default.
