diff --git a/builder.py b/build/builder.py
similarity index 99%
rename from builder.py
rename to build/builder.py
index 1b44f2f42..abbb5b777 100644
--- a/builder.py
+++ b/build/builder.py
@@ -21,7 +21,7 @@
 from typing import Union, Optional

 from sentencepiece import SentencePieceProcessor

-from model import Transformer
+from build.model import Transformer

 @dataclass
diff --git a/gguf_loader.py b/build/gguf_loader.py
similarity index 100%
rename from gguf_loader.py
rename to build/gguf_loader.py
diff --git a/gguf_util.py b/build/gguf_util.py
similarity index 100%
rename from gguf_util.py
rename to build/gguf_util.py
diff --git a/build/model.py b/build/model.py
new file mode 100644
index 000000000..9c3bb6b95
--- /dev/null
+++ b/build/model.py
@@ -0,0 +1,395 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import json
+from dataclasses import dataclass
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+
+from quantize import get_precision
+
+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+@dataclass
+class ModelArgs:
+    block_size: int = 2048
+    vocab_size: int = 32000
+    n_layer: int = 32
+    # n_head in gpt-fast
+    n_heads: int = 32
+    dim: int = 4096
+    # hidden dim is intermediate_size in gpt-fast
+    hidden_dim: int = None
+    n_local_heads: int = -1
+    head_dim: int = 64
+    rope_base: float = 10000
+    norm_eps: float = 1e-5
+    multiple_of: int = 256
+    ffn_dim_multiplier: Optional[float] = None
+
+    def __post_init__(self):
+        if self.n_local_heads == -1:
+            self.n_local_heads = self.n_heads
+        if self.hidden_dim is None:
+            # If hidden_dim is not explicitly set in the ModelArgs,
+            # then calculate implicitly based on dim and
+            # also multiple of `args.multiple_of`
+            multiple_of = self.multiple_of
+            hidden_dim = 4 * self.dim
+            hidden_dim = int(2 * hidden_dim / 3)
+            if self.ffn_dim_multiplier is not None:
+                hidden_dim = int(self.ffn_dim_multiplier * hidden_dim)
+            self.hidden_dim = find_multiple(hidden_dim, multiple_of)
+        self.head_dim = self.dim // self.n_heads
+
+    @classmethod
+    def from_params(cls, params_path):
+        with open(params_path, "r") as f:
+            params = json.loads(f.read())
+        return cls(**params)
+
+    @classmethod
+    def from_table(cls, name: str):
+        print(f"name {name}")
+        if name in transformer_configs:
+            return cls(**transformer_configs[name])
+        else:
+            raise RuntimeError(f"unknown table index {name} for transformer_configs")
+
+    @classmethod
+    def from_name(cls, name: str):
+        print(f"name {name}")
+        if name in transformer_configs:
+            return cls(**transformer_configs[name])
+        # fuzzy search
+        config = [
+            config
+            for config in transformer_configs
+            if config in str(name).upper() or config in str(name)
+        ]
+
+        # We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match:
+        # take the longer name (as it has more symbols matched)
+        if len(config) > 1:
+            config.sort(key=len, reverse=True)
+            assert len(config[0]) != len(
+                config[1]
+            ), name  # make sure only one 'best' match
+        elif len(config) == 0:
+            raise ValueError(
+                f"Unknown model directory name {name}. Must be one of {list(transformer_configs.keys())}."
+            )
+
+        return cls(**transformer_configs[config[0]])
+
+
+transformer_configs = {
+    "CodeLlama-7b-Python-hf": dict(
+        block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000
+    ),
+    "7B": dict(n_layer=32, n_heads=32, dim=4096),
+    "13B": dict(n_layer=40, n_heads=40, dim=5120),
+    "30B": dict(n_layer=60, n_heads=52, dim=6656),
+    "34B": dict(
+        n_layer=48,
+        n_heads=64,
+        dim=8192,
+        vocab_size=32000,
+        n_local_heads=8,
+        hidden_dim=22016,
+        rope_base=1000000,
+    ),  # CodeLlama-34B-Python-hf
+    "70B": dict(
+        n_layer=80, n_heads=64, dim=8192, n_local_heads=8, hidden_dim=28672
+    ),
+    "Mistral-7B": dict(
+        n_layer=32,
+        n_heads=32,
+        n_local_heads=8,
+        dim=4096,
+        hidden_dim=14336,
+        vocab_size=32000,
+    ),
+    "Mistral-7B-Instruct-v0.1": dict(
+        n_layer=32,
+        n_heads=32,
+        n_local_heads=8,
+        dim=4096,
+        hidden_dim=14336,
+        vocab_size=32000,
+    ),
+    "Mistral-7B-Instruct-v0.2": dict(
+        n_layer=32,
+        n_heads=32,
+        n_local_heads=8,
+        dim=4096,
+        hidden_dim=14336,
+        vocab_size=32000,
+    ),
+    "stories15M": dict(n_layer=6, n_heads=6, dim=288),
+    "stories110M": dict(n_layer=12, n_heads=12, dim=768),
+}
+
+
+class KVCache(nn.Module):
+    def __init__(
+        self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=None):
+        # torch.float): # bfloat16 ):
+        super().__init__()
+        if not dtype:
+            dtype=get_precision()
+        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
+        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))
+
+    def update(self, input_pos, k_val, v_val):
+        # input_pos: [S], k_val: [B, H, S, D]
+        assert input_pos.shape[0] == k_val.shape[2]
+
+        k_out = self.k_cache
+        v_out = self.v_cache
+        k_out[:, :, input_pos] = k_val
+        v_out[:, :, input_pos] = v_val
+
+        return k_out, v_out
+
+
+class Transformer(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.config = config
+
+        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
+        self.layers = nn.ModuleList(
+            TransformerBlock(config) for _ in range(config.n_layer)
+        )
+        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
+
+        # self.freqs_cis: Optional[Tensor] = None
+        # self.mask_cache: Optional[Tensor] = None
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+
+    def setup_caches(self, max_batch_size, max_seq_length):
+        if (
+            self.max_seq_length >= max_seq_length
+            and self.max_batch_size >= max_batch_size
+        ):
+            return
+        head_dim = self.config.dim // self.config.n_heads
+        max_seq_length = find_multiple(max_seq_length, 8)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+        for b in self.layers:
+            b.attention.kv_cache = KVCache(
+                max_batch_size, max_seq_length, self.config.n_local_heads, head_dim
+            )
+
+        freqs_cis = precompute_freqs_cis(
+            self.config.dim // self.config.n_heads,
+            self.config.block_size * 2,
+            self.config.rope_base,
+        )
+        self.register_buffer("freqs_cis", freqs_cis, persistent=True)
+        causal_mask = torch.tril(
+            torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
+        )
+        self.register_buffer("causal_mask", causal_mask, persistent=True)
+
+    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        # print ("*")
+        # print (f"* shape idx: {idx.shape}")
+        # print (f"* shape pos: {input_pos.shape}")
+        # print("@")
+
+        assert self.freqs_cis is not None, "Caches must be initialized first"
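+        # Gather the causal-mask rows and rotary-embedding entries for the
+        # positions in input_pos, so prefill and one-token decode share the
+        # same forward path.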
+        mask = self.causal_mask[None, None, input_pos]
+        freqs_cis = self.freqs_cis[input_pos]
+        x = self.tok_embeddings(idx)
+
+        for i, layer in enumerate(self.layers):
+            x = layer(x, input_pos, freqs_cis, mask)
+        x = self.norm(x)
+        logits = self.output(x)
+        # print(f"******** logits shape: {logits.shape}")
+        return logits
+
+    @classmethod
+    def from_name(cls, name: str):
+        return cls(ModelArgs.from_name(name))
+
+    @classmethod
+    def from_table(cls, name: str):
+        return cls(ModelArgs.from_table(name))
+
+    @classmethod
+    def from_params(cls, params_path: str):
+        return cls(ModelArgs.from_params(params_path))
+
+    @classmethod
+    def from_gguf(cls, gguf_path: str):
+        from build.gguf_loader import load_llama_from_gguf_file
+        model = load_llama_from_gguf_file(gguf_path)
+        return model
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.attention = Attention(config)
+        self.feed_forward = FeedForward(config)
+        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
+        self.attention_norm = RMSNorm(config.dim, config.norm_eps)
+
+    def forward(
+        self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor
+    ) -> Tensor:
+        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
+        out = h + self.feed_forward(self.ffn_norm(h))
+        return out
+
+
+class Attention(nn.Module):
+    def __init__(self, config: ModelArgs):
+        super().__init__()
+        assert config.dim % config.n_heads == 0
+
+        # key, query, value projections for all heads, but in a batch
+        # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim
+        # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
+        self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False)
+        self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
+        self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False)
+
+        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.kv_cache = None
+
+        self.n_heads = config.n_heads
+        self.head_dim = config.head_dim
+        self.n_local_heads = config.n_local_heads
+        self.dim = config.dim
+        self._register_load_state_dict_pre_hook(self.load_hook)
+
+    def load_hook(self, state_dict, prefix, *args):
+        # if prefix + "wq.weight" in state_dict:
+        #     wq = state_dict.pop(prefix + "wq.weight")
+        #     wk = state_dict.pop(prefix + "wk.weight")
+        #     wv = state_dict.pop(prefix + "wv.weight")
+        #     state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
+
+        def _unfuse_wqkv_state_dict(
+            state_dict: Dict[str, torch.Tensor],
+            dim: int,
+        ):
+            for key in list(state_dict):
+                if key.endswith("wqkv.weight"):
+                    tensor = state_dict[key]
+                    wq_key = key.replace("wqkv.weight", "wq.weight")
+                    state_dict[wq_key] = tensor[: dim]
+                    wk_key = key.replace("wqkv.weight", "wk.weight")
+                    wv_key = key.replace("wqkv.weight", "wv.weight")
+                    wk, wv = tensor[dim :].chunk(2, 0)
+                    state_dict[wk_key] = wk
+                    state_dict[wv_key] = wv
+                    state_dict.pop(key)
+                else:
+                    continue
+
+        _unfuse_wqkv_state_dict(state_dict, self.dim)
+
+    def forward(
+        self,
+        x: Tensor,
+        freqs_cis: Tensor,
+        mask: Tensor,
+        input_pos: Optional[Tensor] = None,
+    ) -> Tensor:
+        bsz, seqlen, _ = x.shape
+
+        q = self.wq(x)
+        k = self.wk(x)
+        v = self.wv(x)
+        # kv_size = self.n_local_heads * self.head_dim
+        # q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+        q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
+        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)
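+        # Rotate q and k with the precomputed frequencies, then transpose so the
+        # head dimension precedes the sequence dimension, the layout expected by
+        # the KV cache and scaled_dot_product_attention.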
+
+        q = apply_rotary_emb(q, freqs_cis)
+        k = apply_rotary_emb(k, freqs_cis)
+
+        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+        if self.kv_cache is not None:
+            k, v = self.kv_cache.update(input_pos, k, v)
+
+        k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
+        v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1)
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+        y = self.wo(y)
+        return y
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config: ModelArgs) -> None:
+        super().__init__()
+        self.w1 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+        self.w2 = nn.Linear(config.hidden_dim, config.dim, bias=False)
+        self.w3 = nn.Linear(config.dim, config.hidden_dim, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-5):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
+
+    def forward(self, x: Tensor) -> Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+# transposed first two arguments to align with model in ET
+def precompute_freqs_cis(n_elem: int, seq_len: int, base: int = 10000, dtype=None) -> Tensor:
+    if not dtype:
+        dtype = get_precision()
+    freqs = 1.0 / (
+        base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)
+    )
+    t = torch.arange(seq_len, device=freqs.device)
+    freqs = torch.outer(t, freqs)
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
+    return cache.to(dtype=dtype)  # bfloat16)
+
+
+def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
+    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
+    x_out2 = torch.stack(
+        [
+            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
+            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
+        ],
+        -1,
+    )
+
+    x_out2 = x_out2.flatten(3)
+    return x_out2.type_as(x)
diff --git a/model_aoti.py b/build/model_aoti.py
similarity index 100%
rename from model_aoti.py
rename to build/model_aoti.py
diff --git a/model_et.py b/build/model_et.py
similarity index 100%
rename from model_et.py
rename to build/model_et.py
diff --git a/eval.py b/eval.py
index 1b34aa9c1..90f2aa79c 100644
--- a/eval.py
+++ b/eval.py
@@ -21,9 +21,7 @@
 from cli import cli_args
 from quantize import name_to_dtype, set_precision

-from sentencepiece import SentencePieceProcessor
-
-from model import Transformer
+from build.model import Transformer

 try:
     import lm_eval
@@ -31,7 +29,7 @@
 except:
     lm_eval_available = False

-from builder import _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs
+from build.builder import _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs
 from generate import encode_tokens, model_forward

 if lm_eval_available:
@@ -200,7 +198,7 @@ def eval(
     return eval_results


-def eval_main(args) -> None:
+def main(args) -> None:
     """Evaluates model on a task from the `lm-evaluation-harness` library.

     Args:
@@ -234,7 +232,7 @@ def eval_main(args) -> None:
     print(f"Using device={device}")
     set_precision(buildeer_args.precision)

-    tokenizer = SentencePieceProcessor(model_file=str(tokenizer_path))
+    tokenizer = _initialize_tokenizer(tokenizer_args)
     builder_args.setup_caches = False
     model = _initialize_model(
         buildeer_args,
diff --git a/export.py b/export.py
index db8046eff..1e5fb5d37 100644
--- a/export.py
+++ b/export.py
@@ -24,8 +24,8 @@

 from export_aoti import export_model as export_model_aoti

-from model import Transformer
-from builder import _initialize_model, BuilderArgs, TokenizerArgs
+from build.model import Transformer
+from build.builder import _initialize_model, BuilderArgs, TokenizerArgs
 from generate import decode_one_token
 from quantize import quantize_model, name_to_dtype
 from torch._export import capture_pre_autograd_graph
diff --git a/export_aoti.py b/export_aoti.py
index c7d4d6d92..6501b9e98 100644
--- a/export_aoti.py
+++ b/export_aoti.py
@@ -17,7 +17,7 @@
 from generate import decode_one_token
 from quantize import quantize_model

-from model import Transformer
+from build.model import Transformer


 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
diff --git a/export_et.py b/export_et.py
index bfdb8d087..030fd0b6c 100644
--- a/export_et.py
+++ b/export_et.py
@@ -10,26 +10,28 @@
 import torch
 import torch.nn as nn
 from torch.export import Dim, export
+from torch._export import capture_pre_autograd_graph

 from generate import decode_one_token
-from quantize import quantize_model
-from quantize import quantize_model, name_to_dtype, set_precision, get_precision
+from quantize import (
+    quantize_model, name_to_dtype, set_precision, get_precision,
+)
+from build.model import Transformer

-from model import Transformer
-# from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
-#     XnnpackDynamicallyQuantizedPartitioner,
-#)
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackPartitioner,
 )
-from executorch_portable_utils import export_to_edge # TODO: change back to executorch.examples.portable.utils when executorch installs correctly
+# from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+#     XnnpackDynamicallyQuantizedPartitioner,
+#)
+from executorch_portable_utils import export_to_edge
+# TODO: change back to executorch.examples.portable.utils
+# when executorch installs correctly

 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from model import Transformer
-from torch._export import capture_pre_autograd_graph

 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'
diff --git a/generate.py b/generate.py
index 99c75aeca..03f1233e5 100644
--- a/generate.py
+++ b/generate.py
@@ -8,13 +8,14 @@
 import time
 from pathlib import Path
 from typing import Optional, Tuple
-from builder import _load_model, _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs

 from dataclasses import dataclass

 import torch
 import torch._dynamo.config
 import torch._inductor.config
+from build.builder import _load_model, _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs
+from build.model import Transformer

 from quantize import quantize_model, name_to_dtype, set_precision, get_precision
 from cli import cli_args
@@ -65,10 +66,6 @@ def device_sync(device):
 wd = Path(__file__).parent.parent.resolve()
 sys.path.append(str(wd))

-from sentencepiece import SentencePieceProcessor
-
-from model import Transformer
-

 def multinomial_sample_one_no_sync(
     probs_sort,
diff --git a/torchat.py b/torchat.py
index d3514777d..4b720b8dd 100644
--- a/torchat.py
+++ b/torchat.py
@@ -14,7 +14,7 @@
 from export import main as export_main
 from generate import main as generate_main
-from eval import eval_main
+from eval import main as eval_main
 from cli import cli_args, check_args


 default_device = "cpu"  # 'cuda' if torch.cuda.is_available() else 'cpu'