diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index da0c735..a7e63dc 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -54,7 +54,7 @@ jobs: - name: Set up Python run: uv python install 3.12 - name: Type check - run: uv run --group torch pyright + run: uv run --group torch basedpyright ruff: runs-on: ubuntu-latest diff --git a/pyproject.toml b/pyproject.toml index a8e91e0..b0dbfa0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,9 @@ cuda12 = [ flax = [ "flax>=0.10.1", ] +metal = [ + "jax-metal>=0.1.1", +] [tool.uv.workspace] members = ["xlens"] @@ -138,4 +141,4 @@ reportUntypedFunctionDecorator = false reportUnknownArgumentType = false reportUnknownVariableType = false reportMissingTypeStubs = false -reportConstantRedefinition = false \ No newline at end of file +reportConstantRedefinition = false diff --git a/src/xlens/pretrained/convert.py b/src/xlens/pretrained/convert.py index b2b7185..df4ab83 100644 --- a/src/xlens/pretrained/convert.py +++ b/src/xlens/pretrained/convert.py @@ -28,9 +28,36 @@ ) -def get_pretrained_model_config(model_name: str) -> HookedTransformerConfig: - return converter.get_pretrained_model_config(model_name) +def get_pretrained_model_config(model_name_or_path: str) -> HookedTransformerConfig: + """Get the configuration for a pretrained model from Hugging Face. 
+ Args: + model_name_or_path (str): The name of the model on Hugging Face Hub (e.g., 'gpt2', 'facebook/opt-125m') -def get_pretrained_weights(cfg: HookedTransformerConfig, model_name: str, hf_model: Any = None) -> dict[str, jax.Array]: - return converter.get_pretrained_weights(cfg, model_name, hf_model=hf_model) + Returns: + HookedTransformerConfig: Configuration object containing the model architecture details + + Raises: + ValueError: If the model architecture is not supported + """ + return converter.get_pretrained_model_config(model_name_or_path) + + +def get_pretrained_weights( + cfg: HookedTransformerConfig, model_name_or_path: str, hf_model: Any = None +) -> dict[str, jax.Array]: + """Load pretrained weights from a Hugging Face model and convert them to JAX arrays. + + Args: + cfg (HookedTransformerConfig): Configuration object for the target model + model_name_or_path (str): The name of the model on Hugging Face Hub (e.g., 'gpt2', 'facebook/opt-125m') + hf_model (Any, optional): Pre-loaded Hugging Face model. If None, the model will be loaded + from the Hub. Defaults to None. 
+ + Returns: + dict[str, jax.Array]: Dictionary mapping parameter names to their values as JAX arrays + + Raises: + ValueError: If the model architecture is not supported or weights cannot be converted + """ + return converter.get_pretrained_weights(cfg, model_name_or_path, hf_model=hf_model) diff --git a/src/xlens/pretrained/converters/qwen2.py b/src/xlens/pretrained/converters/qwen2.py index ee0c83c..dcb281a 100644 --- a/src/xlens/pretrained/converters/qwen2.py +++ b/src/xlens/pretrained/converters/qwen2.py @@ -92,9 +92,30 @@ def convert_hf_weights( state_dict[f"blocks.{l}.attn.W_K"] = W_K state_dict[f"blocks.{l}.attn.W_V"] = W_V - state_dict[f"blocks.{l}.attn.b_Q"] = jnp.zeros((cfg.n_heads, cfg.d_head)) - state_dict[f"blocks.{l}.attn.b_K"] = jnp.zeros((cfg.n_key_value_heads, cfg.d_head)) - state_dict[f"blocks.{l}.attn.b_V"] = jnp.zeros((cfg.n_key_value_heads, cfg.d_head)) + b_Q = hf_weights[f"model.layers.{l}.self_attn.q_proj.bias"] + b_Q = einops.rearrange( + b_Q, + "(n_head d_head) -> n_head d_head", + n_head=cfg.n_heads, + ) + + b_K = hf_weights[f"model.layers.{l}.self_attn.k_proj.bias"] + b_K = einops.rearrange( + b_K, + "(n_head d_head) -> n_head d_head", + n_head=cfg.n_key_value_heads, + ) + + b_V = hf_weights[f"model.layers.{l}.self_attn.v_proj.bias"] + b_V = einops.rearrange( + b_V, + "(n_head d_head) -> n_head d_head", + n_head=cfg.n_key_value_heads, + ) + + state_dict[f"blocks.{l}.attn.b_Q"] = b_Q + state_dict[f"blocks.{l}.attn.b_K"] = b_K + state_dict[f"blocks.{l}.attn.b_V"] = b_V W_O = hf_weights[f"model.layers.{l}.self_attn.o_proj.weight"] W_O = einops.rearrange(W_O, "m (n h)->n h m", n=cfg.n_heads) diff --git a/uv.lock b/uv.lock index 7e4b040..308b344 100644 --- a/uv.lock +++ b/uv.lock @@ -569,6 +569,21 @@ with-cuda = [ { name = "nvidia-nvjitlink-cu12" }, ] +[[package]] +name = "jax-metal" +version = "0.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jax" }, + { name = "jaxlib" }, + { name = "six" }, + { name = 
"wheel" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl", hash = "sha256:d918a78443cb808c9491a24a5c2a94cc4eabfd0461d5bcda29a8f332dfbe9b7e", size = 54662235 }, + { url = "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl", hash = "sha256:f1dbfecb298cdd3ba6da3ad6dc9a2adb63d71741f8b8ece28c296b32d608b6c8", size = 41179678 }, +] + [[package]] name = "jaxlib" version = "0.4.34" @@ -1694,7 +1709,7 @@ name = "triton" version = "3.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock" }, + { name = "filelock", marker = "python_full_version < '3.13'" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/78/eb/65f5ba83c2a123f6498a3097746607e5b2f16add29e36765305e4ac7fdd8/triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc", size = 209551444 }, @@ -1780,6 +1795,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/92/52/a8563300b7a0474f68081acf5427331c8c2d8233137e49a29c1bcdf41817/wandb-0.18.5-py3-none-win_amd64.whl", hash = "sha256:83b619167eb2ffdd1188cba3805ccad158f6fd7fc06bef43daf6d2729a787fa0", size = 15417794 }, ] +[[package]] +name = "wheel" +version = "0.44.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b7/a0/95e9e962c5fd9da11c1e28aa4c0d8210ab277b1ada951d2aee336b505813/wheel-0.44.0.tar.gz", hash = "sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49", size = 100733 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1b/d1/9babe2ccaecff775992753d8686970b1e2755d21c8a63be73aba7a4e7d77/wheel-0.44.0-py3-none-any.whl", hash = 
"sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f", size = 67059 }, +] + [[package]] name = "xlens" version = "0.1.0" @@ -1807,6 +1831,9 @@ dev = [ flax = [ { name = "flax" }, ] +metal = [ + { name = "jax-metal" }, +] test = [ { name = "pytest" }, ] @@ -1835,6 +1862,7 @@ dev = [ { name = "ruff", specifier = ">=0.7.1" }, ] flax = [{ name = "flax", specifier = ">=0.10.1" }] +metal = [{ name = "jax-metal", specifier = ">=0.1.1" }] test = [{ name = "pytest", specifier = ">=8.3.3" }] torch = [ { name = "torch", specifier = ">=2.5.1" },