From 4985c080f3bf4debfcb0bfd175a9a0c4e089645d Mon Sep 17 00:00:00 2001 From: Mergen Nachin Date: Tue, 16 Apr 2024 12:00:49 -0400 Subject: [PATCH] Use lintrunner across the project (#216) --- .lintrunner.toml | 51 +++++ build/builder.py | 119 +++++----- build/gguf_loader.py | 54 +++-- build/gguf_util.py | 123 ++++++---- build/model.py | 41 ++-- build/model_aoti.py | 10 +- build/model_et.py | 3 +- cli.py | 120 +++------- eval.py | 66 +++--- export.py | 20 +- export_aoti.py | 12 +- export_et.py | 37 ++- generate.py | 79 ++++--- quantize.py | 372 ++++++++++++++++++++----------- quantized_ops.py | 83 ++++--- requirements-lintrunner.txt | 18 ++ scripts/convert_hf_checkpoint.py | 30 ++- scripts/download.py | 25 ++- torchat.py | 14 +- utils/tokenizer.py | 24 +- 20 files changed, 786 insertions(+), 515 deletions(-) create mode 100644 .lintrunner.toml create mode 100644 requirements-lintrunner.txt diff --git a/.lintrunner.toml b/.lintrunner.toml new file mode 100644 index 000000000..97de2b6a2 --- /dev/null +++ b/.lintrunner.toml @@ -0,0 +1,51 @@ +merge_base_with = "origin/main" + +[[linter]] +code = 'FLAKE8' +include_patterns = ['**/*.py'] +command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'flake8_linter', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'pip_init', + '--dry-run={{DRYRUN}}', + '--requirement=requirements-lintrunner.txt', +] + +# Black + usort +[[linter]] +code = 'UFMT' +include_patterns = [ + '**/*.py', + '**/*.pyi', +] +command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'ufmt_linter', + '--', + '@{{PATHSFILE}}' +] +init_command = [ + 'python', + '-m', + 'lintrunner_adapters', + 'run', + 'pip_init', + '--dry-run={{DRYRUN}}', + '--no-black-binary', + '--requirement=requirements-lintrunner.txt', +] +is_formatter = true \ No newline at end of file diff --git a/build/builder.py b/build/builder.py index f62c42131..c8b1fe78c 100644 --- a/build/builder.py +++ b/build/builder.py @@ -6,21 +6,19 @@ import itertools import sys import time +from dataclasses import dataclass from pathlib import Path -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch import torch._dynamo.config import torch._inductor.config - -from quantize import ( - quantize_model, name_to_dtype, set_precision, get_precision -) from cli import cli_args -from dataclasses import dataclass -from typing import Union, Optional + +from quantize import get_precision, name_to_dtype, quantize_model, set_precision from sentencepiece import SentencePieceProcessor + from build.model import Transformer @@ -40,43 +38,50 @@ class BuilderArgs: def __post_init__(self): if not ( - (self.checkpoint_path and self.checkpoint_path.is_file()) or - (self.checkpoint_dir and self.checkpoint_path.is_dir()) or - (self.gguf_path and self.gguf_path.is_file()) or - (self.dso_path and Path(self.dso_path).is_file()) or - (self.pte_path and Path(self.pte_path).is_file()) + (self.checkpoint_path and self.checkpoint_path.is_file()) + or (self.checkpoint_dir and self.checkpoint_path.is_dir()) + or (self.gguf_path and self.gguf_path.is_file()) + or (self.dso_path and Path(self.dso_path).is_file()) + or (self.pte_path and Path(self.pte_path).is_file()) ): - raise RuntimeError("need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path") + raise RuntimeError( + "need to specified a valid checkpoint path, checkpoint dir, gguf path, DSO path, or PTE path" + ) - if (self.dso_path and self.pte_path): + if self.dso_path and self.pte_path: raise RuntimeError("specify either DSO path or PTE path, but not both") - if (self.checkpoint_path and (self.dso_path or self.pte_path)): - print("Warning: checkpoint path ignored because an exported DSO or PTE path specified") - if (self.checkpoint_dir and (self.dso_path or self.pte_path)): - print("Warning: checkpoint dir ignored because an exported DSO or PTE path specified") - if (self.gguf_path and (self.dso_path or self.pte_path)): - print("Warning: GGUF path ignored because an exported DSO or PTE path specified") - + if self.checkpoint_path and (self.dso_path or self.pte_path): + print( + "Warning: checkpoint path ignored because an exported DSO or PTE path specified" + ) + if self.checkpoint_dir and (self.dso_path or self.pte_path): + print( + "Warning: checkpoint dir ignored because an exported DSO or PTE path specified" + ) + if self.gguf_path and (self.dso_path or self.pte_path): + print( + "Warning: GGUF path ignored because an exported DSO or PTE path specified" + ) @classmethod - def from_args(cls, args): # -> BuilderArgs: + def from_args(cls, args): # -> BuilderArgs: return cls( - checkpoint_path = args.checkpoint_path, - checkpoint_dir = args.checkpoint_dir, - params_path = args.params_path, - params_table = args.params_table, - gguf_path = args.gguf_path, - dso_path = args.dso_path, - pte_path = args.pte_path, - device = args.device, - precision = name_to_dtype(args.dtype), - setup_caches = (args.output_dso_path or args.output_pte_path), - use_tp = False, + checkpoint_path=args.checkpoint_path, + checkpoint_dir=args.checkpoint_dir, + params_path=args.params_path, + params_table=args.params_table, + gguf_path=args.gguf_path, + dso_path=args.dso_path, + pte_path=args.pte_path, + device=args.device, + precision=name_to_dtype(args.dtype), + setup_caches=(args.output_dso_path or args.output_pte_path), + use_tp=False, ) @classmethod - def from_speculative_args(cls, args): # -> BuilderArgs: + def from_speculative_args(cls, args): # -> BuilderArgs: speculative_builder_args = BuilderArgs.from_args(args) # let's limit multi-checkpoint to checker speculative_builder_args.checkpoint_dir = None @@ -94,7 +99,7 @@ class TokenizerArgs: is_TikToken: bool = False @classmethod - def from_args(cls, args): # -> TokenizerArgs: + def from_args(cls, args): # -> TokenizerArgs: is_SentencePiece = True is_TikToken = False @@ -108,7 +113,7 @@ def from_args(cls, args): # -> TokenizerArgs: raise RuntimeError(f"cannot find tokenizer model") if not tokenizer_path.is_file(): - raise RuntimeError(f"did not find tokenizer at {tokenizer_path}") + raise RuntimeError(f"did not find tokenizer at {tokenizer_path}") if args.tiktoken: is_SentencePiece = False @@ -117,9 +122,10 @@ def from_args(cls, args): # -> TokenizerArgs: return cls( tokenizer_path=tokenizer_path, is_SentencePiece=is_SentencePiece, - is_TikToken=is_TikToken + is_TikToken=is_TikToken, ) + def _initialize_tokenizer(tokenizer_args: TokenizerArgs): if tokenizer_args.is_SentencePiece: return SentencePieceProcessor(model_file=str(tokenizer_args.tokenizer_path)) @@ -147,6 +153,7 @@ def device_sync(device): wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) + def _load_model(builder_args): if builder_args.gguf_path: model = Transformer.from_gguf(builder_args.gguf_path) @@ -160,9 +167,8 @@ def _load_model(builder_args): else: return _load_model_not_gguf(builder_args) -def _load_model_not_gguf( - builder_args -): + +def _load_model_not_gguf(builder_args): assert not builder_args.gguf_path with torch.device("meta"): @@ -200,7 +206,12 @@ def _load_model_not_gguf( else: checkpoint[key] = cps[0][key] else: - checkpoint = torch.load(builder_args.checkpoint_path, map_location=builder_args.device, mmap=True, weights_only=True) + checkpoint = torch.load( + builder_args.checkpoint_path, + map_location=builder_args.device, + mmap=True, + weights_only=True, + ) if "model" in checkpoint and "stories" in str(builder_args.checkpoint_path): checkpoint = checkpoint["model"] @@ -218,21 +229,21 @@ def _load_model_not_gguf( def _initialize_model( - builder_args, - quantize, + builder_args, + quantize, ): print("Loading model ...") t0 = time.time() - model_ = _load_model( - builder_args - ) + model_ = _load_model(builder_args) device_sync(device=builder_args.device) print(f"Time to load model: {time.time() - t0:.02f} seconds") if builder_args.dso_path: # make sure user did not try to set dtype # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export." - assert quantize is None or quantize == "{ }", f"quantize not valid for exported DSO model. Specify quantization during export." + assert ( + quantize is None or quantize == "{ }" + ), f"quantize not valid for exported DSO model. Specify quantization during export." try: model = model_ # Replace model forward with the AOT-compiled forward @@ -241,15 +252,20 @@ def _initialize_model( # attributes will NOT be seen on by AOTI-compiled forward # function, e.g. calling model.setup_cache will NOT touch # AOTI compiled and maintained model buffers such as kv_cache. - model.forward = torch._export.aot_load(str(builder_args.dso_path.absolute()), builder_args.device) + model.forward = torch._export.aot_load( + str(builder_args.dso_path.absolute()), builder_args.device + ) except: raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}") elif builder_args.pte_path: # make sure user did not try to set dtype # assert model_dtype == "float32", f"dtype setting not valid for a DSO model. Specify dtype during export." - assert quantize is None or quantize == "{ }", f"quantize not valid for exported PTE model. Specify quantization during export." + assert ( + quantize is None or quantize == "{ }" + ), f"quantize not valid for exported PTE model. Specify quantization during export." try: from build.model_et import PTEModel + model = PTEModel(model_.config, builder_args.pte_path) except Exception as e: raise RuntimeError(f"Failed to load ET compiled {builder_args.pte_path}") @@ -265,10 +281,7 @@ def _initialize_model( if builder_args.setup_caches: max_seq_length = 350 with torch.device(builder_args.device): - model.setup_caches( - max_batch_size=1, - max_seq_length=max_seq_length - ) + model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) model.to(dtype=builder_args.precision) diff --git a/build/gguf_loader.py b/build/gguf_loader.py index a52c15274..3655e804a 100644 --- a/build/gguf_loader.py +++ b/build/gguf_loader.py @@ -8,26 +8,33 @@ import argparse import copy +import logging import sys from dataclasses import dataclass from pathlib import Path -from typing import Any, Mapping, Dict -import logging -from quantize import WeightOnlyInt4Linear, pack_scales_and_zeros, group_dequantize_tensor_from_qparams -from build.gguf_util import F16, F32, Q4_0, Q6_K +from typing import Any, Dict, Mapping + import gguf import torch import torch.nn as nn from gguf import GGUFValueType, ReaderTensor +from quantize import ( + group_dequantize_tensor_from_qparams, + pack_scales_and_zeros, + WeightOnlyInt4Linear, +) + +from build.gguf_util import F16, F32, Q4_0, Q6_K wd = Path(__file__).parent.resolve() sys.path.append(str(wd)) -from model import ModelArgs, Transformer from typing import Set +from model import ModelArgs, Transformer + logger: logging.Logger = logging.getLogger(__name__) @@ -101,7 +108,9 @@ def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str: def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs: arch = metadata["general.architecture"] - assert arch == "llama", f"Only general.architecture=llama is supported, but got general.architecture={arch}" + assert ( + arch == "llama" + ), f"Only general.architecture=llama is supported, but got general.architecture={arch}" return GGUFModelArgs( arch=arch, embedding_length=metadata[f"{arch}.embedding_length"], @@ -119,6 +128,7 @@ def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs: ), ) + def _fqn_lookup(fqn: str, module: torch.nn.Module) -> Any: if fqn == "": return module @@ -147,7 +157,9 @@ def _fqn_last(fqn: str) -> str: return atoms[-1] -def load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], inner_k_tiles = 8) -> None: +def load_weights( + pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], inner_k_tiles=8 +) -> None: fqns = [] for fqn in pt_model.state_dict(): assert _fqn_last(fqn) == "weight" @@ -159,7 +171,10 @@ def load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], t = weight_map[f"{fqn}.weight"] - if isinstance(mod, torch.nn.Linear) and t.tensor_type == gguf.GGMLQuantizationType.Q4_0: + if ( + isinstance(mod, torch.nn.Linear) + and t.tensor_type == gguf.GGMLQuantizationType.Q4_0 + ): assert not mod.bias out_features = mod.out_features in_features = mod.in_features @@ -167,30 +182,36 @@ def load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], q, s, z = Q4_0.unpack(t) scales_and_zeros = pack_scales_and_zeros(s, z) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(q, inner_k_tiles) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + q, inner_k_tiles + ) - state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu') - state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu') + state_dict[f"{fqn}.weight"] = weight_int4pack.to("cpu") + state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cpu") parent = _fqn_lookup(_fqn_up(fqn), pt_model) setattr( parent, _fqn_last(fqn), WeightOnlyInt4Linear( - "cpu", # TODO: should --device work for gguf load? (yes?!) + "cpu", # TODO: should --device work for gguf load? (yes?!) in_features, out_features, bias=False, groupsize=Q4_0.groupsize, inner_k_tiles=inner_k_tiles, - ) + ), ) else: # All other weights are dequantized to float if t.tensor_type == gguf.GGMLQuantizationType.Q4_0: - as_float = group_dequantize_tensor_from_qparams(*Q4_0.unpack(t), Q4_0.n_bit, Q4_0.groupsize) + as_float = group_dequantize_tensor_from_qparams( + *Q4_0.unpack(t), Q4_0.n_bit, Q4_0.groupsize + ) elif t.tensor_type == gguf.GGMLQuantizationType.Q6_K: - as_float = group_dequantize_tensor_from_qparams(*Q6_K.unpack(t), Q6_K.n_bit, Q6_K.groupsize) + as_float = group_dequantize_tensor_from_qparams( + *Q6_K.unpack(t), Q6_K.n_bit, Q6_K.groupsize + ) elif t.tensor_type == gguf.GGMLQuantizationType.F16: as_float = F16.unpack(t) elif t.tensor_type == gguf.GGMLQuantizationType.F32: @@ -198,7 +219,7 @@ def load_weights(pt_model: torch.nn.Module, weight_map: Dict[str, ReaderTensor], else: raise ValueError(f"Unsupported tensor type {t.tensor_type}") - state_dict[f"{fqn}.weight"] = as_float.to('cpu') + state_dict[f"{fqn}.weight"] = as_float.to("cpu") pt_model.load_state_dict(state_dict) return pt_model @@ -245,7 +266,6 @@ def load_llama_from_gguf_file(gguf_file: str) -> torch.nn.Module: logger.info("Creating initial PT model.") pt_model = _create_pt_model(model_args) - logger.info("Reading GGUF weights.") gguf_weights = GGUFWeights(tensors=reader.tensors) diff --git a/build/gguf_util.py b/build/gguf_util.py index 3160af0dc..9f8a07661 100644 --- a/build/gguf_util.py +++ b/build/gguf_util.py @@ -4,10 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import torch import gguf +import torch from quantize import group_dequantize_tensor_from_qparams + def to_float(t: gguf.gguf_reader.ReaderTensor): """ Unpack and dequantize GGUF tensor to torch tensor of type torch.float32. @@ -15,9 +16,13 @@ def to_float(t: gguf.gguf_reader.ReaderTensor): # All other weights are dequantized to float if t.tensor_type == gguf.GGMLQuantizationType.Q4_0: - return group_dequantize_tensor_from_qparams(*Q4_0.unpack(t), Q4_0.n_bit, Q4_0.groupsize).to(torch.float32) + return group_dequantize_tensor_from_qparams( + *Q4_0.unpack(t), Q4_0.n_bit, Q4_0.groupsize + ).to(torch.float32) elif t.tensor_type == gguf.GGMLQuantizationType.Q6_K: - return group_dequantize_tensor_from_qparams(*Q6_K.unpack(t), Q6_K.n_bit, Q6_K.groupsize).to(torch.float32) + return group_dequantize_tensor_from_qparams( + *Q6_K.unpack(t), Q6_K.n_bit, Q6_K.groupsize + ).to(torch.float32) elif t.tensor_type == gguf.GGMLQuantizationType.F16: return F16.unpack(t).to(torch.float32) elif t.tensor_type == gguf.GGMLQuantizationType.F32: @@ -41,15 +46,21 @@ def test_by_to_float(source_file: str, target_file: str) -> None: gguf_targets = {t.name: t for t in gguf.GGUFReader(target_file, "r").tensors} for t in gguf_targets.values(): - assert t.tensor_type == gguf.GGMLQuantizationType.F32, f"target_file must only contain F32 tensors, but found tensor {t.name} with type {repr(t.tensor_type)}." - assert gguf_sources.keys() == gguf_targets.keys(), "source_file and target_file should have the same tensors (by name)" + assert ( + t.tensor_type == gguf.GGMLQuantizationType.F32 + ), f"target_file must only contain F32 tensors, but found tensor {t.name} with type {repr(t.tensor_type)}." + assert ( + gguf_sources.keys() == gguf_targets.keys() + ), "source_file and target_file should have the same tensors (by name)" for k in gguf_sources: source = to_float(gguf_sources[k]) target = to_float(gguf_targets[k]) if not torch.allclose(source, target): - print(f"After calling to_float on source tensor {k} of type {repr(gguf_sources[k].tensor_type)} it does not match its target.") + print( + f"After calling to_float on source tensor {k} of type {repr(gguf_sources[k].tensor_type)} it does not match its target." + ) print("First 5 elements of converted source: ", source.reshape(-1)[0:5]) print("First 5 elements of target: ", target.reshape(-1)[0:5]) assert False, "found mismatch" @@ -68,6 +79,7 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor): new_tensor = gguf_tensor.data.reshape(reversed_shape) return torch.from_numpy(new_tensor).to(torch.float16) + class F32: @staticmethod def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor): @@ -79,6 +91,7 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor): new_tensor = gguf_tensor.data.reshape(reversed_shape) return torch.from_numpy(new_tensor).to(torch.float32) + class Q4_0: groupsize = 32 n_bit = 4 @@ -111,24 +124,24 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor): assert gguf_tensor.tensor_type == gguf.GGMLQuantizationType.Q4_0 assert len(gguf_tensor.shape) == 2 - nc, nr = gguf_tensor.shape # GGUF tensor has reversed shape + nc, nr = gguf_tensor.shape # GGUF tensor has reversed shape - QK4_0 = 32 # groupsize + QK4_0 = 32 # groupsize # Parse block_q4_0 block_q4_0_size = int(2 + QK4_0 / 2) packed = torch.from_numpy(gguf_tensor.data.reshape(-1, block_q4_0_size)) assert packed.dtype == torch.uint8 - ng = packed.shape[0] # number of groups/blocks + ng = packed.shape[0] # number of groups/blocks curr = 0 - size = 2 # half size - d = packed[:,curr:(curr+size)].contiguous() + size = 2 # half size + d = packed[:, curr : (curr + size)].contiguous() d = torch.tensor(d.untyped_storage(), dtype=torch.float16).reshape(ng, 1) curr += size size = int(QK4_0 / 2) - qs = packed[:,curr:(curr+size)].contiguous() + qs = packed[:, curr : (curr + size)].contiguous() curr += size # Check we finished parsing @@ -141,7 +154,7 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor): int32_data = torch.cat([x0, x1], dim=1).to(torch.int32).reshape(ng, QK4_0) assert int32_data.dtype == torch.int32 assert int32_data.min().item() >= 0 - assert int32_data.max().item() <= 2**4-1 + assert int32_data.max().item() <= 2**4 - 1 assert int32_data.shape == (ng, QK4_0) # Prepare for return @@ -194,62 +207,86 @@ def unpack(gguf_tensor: gguf.gguf_reader.ReaderTensor): QK_K = 256 # Parse block_q6_K - block_q6_K_size = int(QK_K/2 + QK_K/4 + QK_K/16 + 2) + block_q6_K_size = int(QK_K / 2 + QK_K / 4 + QK_K / 16 + 2) packed = torch.from_numpy(gguf_tensor.data.reshape(-1, block_q6_K_size)) assert packed.dtype == torch.uint8 - ng = packed.shape[0] # number of groups/blocks + ng = packed.shape[0] # number of groups/blocks curr = 0 - size = int(QK_K/2) - ql = packed[:,curr:(curr+size)].contiguous() + size = int(QK_K / 2) + ql = packed[:, curr : (curr + size)].contiguous() assert ql.shape == (ng, 128) curr += size - size = int(QK_K/4) - qh = packed[:,curr:(curr+size)].contiguous() + size = int(QK_K / 4) + qh = packed[:, curr : (curr + size)].contiguous() assert qh.shape == (ng, 64) curr += size - size = int(QK_K/16) - scales = packed[:,curr:(curr+size)].contiguous() - scales = torch.tensor(scales.untyped_storage(), dtype=torch.int8).reshape(ng, int(QK_K/16)).to(torch.float32) + size = int(QK_K / 16) + scales = packed[:, curr : (curr + size)].contiguous() + scales = ( + torch.tensor(scales.untyped_storage(), dtype=torch.int8) + .reshape(ng, int(QK_K / 16)) + .to(torch.float32) + ) curr += size - size = 2 # half size - d = packed[:,curr:(curr+size)].contiguous() - d = torch.tensor(d.untyped_storage(), dtype=torch.float16).reshape(ng, 1).to(torch.float32) + size = 2 # half size + d = packed[:, curr : (curr + size)].contiguous() + d = ( + torch.tensor(d.untyped_storage(), dtype=torch.float16) + .reshape(ng, 1) + .to(torch.float32) + ) curr += size # Check we finished parsing assert curr == block_q6_K_size # Unpack quantized values. Unlike the code in ggml-quants.c, we do not subtract 32 - q1 = ((ql[:,0:32] & 0xF) | (((qh[:,0:32] >> 0) & 3) << 4)) - q2 = ((ql[:,32:64] & 0xF) | (((qh[:,0:32] >> 2) & 3) << 4)) - q3 = ((ql[:,0:32] >> 4) | (((qh[:,0:32] >> 4) & 3) << 4)) - q4 = ((ql[:,32:64] >> 4) | (((qh[:,0:32] >> 6) & 3) << 4)) + q1 = (ql[:, 0:32] & 0xF) | (((qh[:, 0:32] >> 0) & 3) << 4) + q2 = (ql[:, 32:64] & 0xF) | (((qh[:, 0:32] >> 2) & 3) << 4) + q3 = (ql[:, 0:32] >> 4) | (((qh[:, 0:32] >> 4) & 3) << 4) + q4 = (ql[:, 32:64] >> 4) | (((qh[:, 0:32] >> 6) & 3) << 4) - q5 = ((ql[:,64:96] & 0xF) | (((qh[:,32:64] >> 0) & 3) << 4)) - q6 = ((ql[:,96:128] & 0xF) | (((qh[:,32:64] >> 2) & 3) << 4)) - q7 = ((ql[:,64:96] >> 4) | (((qh[:,32:64] >> 4) & 3) << 4)) - q8 = ((ql[:,96:128] >> 4) | (((qh[:,32:64] >> 6) & 3) << 4)) + q5 = (ql[:, 64:96] & 0xF) | (((qh[:, 32:64] >> 0) & 3) << 4) + q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:64] >> 2) & 3) << 4) + q7 = (ql[:, 64:96] >> 4) | (((qh[:, 32:64] >> 4) & 3) << 4) + q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:64] >> 6) & 3) << 4) q = torch.cat([q1, q2, q3, q4, q5, q6, q7, q8], dim=1).to(torch.int32) assert q.shape == (ng, QK_K) assert q.min().item() >= 0 - assert q.max().item() <= 2**6-1 + assert q.max().item() <= 2**6 - 1 # Unpack scales - s1 = d * torch.cat([scales[:,0].reshape(-1,1), scales[:,1].reshape(-1,1)], dim=1) - s2 = d * torch.cat([scales[:,2].reshape(-1,1), scales[:,3].reshape(-1,1)], dim=1) - s3 = d * torch.cat([scales[:,4].reshape(-1,1), scales[:,5].reshape(-1,1)], dim=1) - s4 = d * torch.cat([scales[:,6].reshape(-1,1), scales[:,7].reshape(-1,1)], dim=1) - - s5 = d * torch.cat([scales[:,8].reshape(-1,1), scales[:,9].reshape(-1,1)], dim=1) - s6 = d * torch.cat([scales[:,10].reshape(-1,1), scales[:,11].reshape(-1,1)], dim=1) - s7 = d * torch.cat([scales[:,12].reshape(-1,1), scales[:,13].reshape(-1,1)], dim=1) - s8 = d * torch.cat([scales[:,14].reshape(-1,1), scales[:,15].reshape(-1,1)], dim=1) + s1 = d * torch.cat( + [scales[:, 0].reshape(-1, 1), scales[:, 1].reshape(-1, 1)], dim=1 + ) + s2 = d * torch.cat( + [scales[:, 2].reshape(-1, 1), scales[:, 3].reshape(-1, 1)], dim=1 + ) + s3 = d * torch.cat( + [scales[:, 4].reshape(-1, 1), scales[:, 5].reshape(-1, 1)], dim=1 + ) + s4 = d * torch.cat( + [scales[:, 6].reshape(-1, 1), scales[:, 7].reshape(-1, 1)], dim=1 + ) + + s5 = d * torch.cat( + [scales[:, 8].reshape(-1, 1), scales[:, 9].reshape(-1, 1)], dim=1 + ) + s6 = d * torch.cat( + [scales[:, 10].reshape(-1, 1), scales[:, 11].reshape(-1, 1)], dim=1 + ) + s7 = d * torch.cat( + [scales[:, 12].reshape(-1, 1), scales[:, 13].reshape(-1, 1)], dim=1 + ) + s8 = d * torch.cat( + [scales[:, 14].reshape(-1, 1), scales[:, 15].reshape(-1, 1)], dim=1 + ) s = torch.cat([s1, s2, s3, s4, s5, s6, s7, s8], dim=1) assert s.shape == (ng, 16) diff --git a/build/model.py b/build/model.py index bc1a783b1..9e3cc1e71 100644 --- a/build/model.py +++ b/build/model.py @@ -9,10 +9,11 @@ import torch import torch.nn as nn + +from quantize import get_precision from torch import Tensor from torch.nn import functional as F -from quantize import get_precision def find_multiple(n: int, k: int) -> int: if n % k == 0: @@ -86,7 +87,9 @@ def from_name(cls, name: str): config[1] ), name # make sure only one 'best' match elif len(config) == 0: - raise ValueError(f"Unknown model directory name {name}. Must be one of {list(transformer_configs.keys())}.") + raise ValueError( + f"Unknown model directory name {name}. Must be one of {list(transformer_configs.keys())}." + ) return cls(**transformer_configs[config[0]]) @@ -107,9 +110,7 @@ def from_name(cls, name: str): hidden_dim=22016, rope_base=1000000, ), # CodeLlama-34B-Python-hf - "70B": dict( - n_layer=80, n_heads=64, dim=8192, n_local_heads=8, hidden_dim=28672 - ), + "70B": dict(n_layer=80, n_heads=64, dim=8192, n_local_heads=8, hidden_dim=28672), "Mistral-7B": dict( n_layer=32, n_heads=32, @@ -140,12 +141,11 @@ def from_name(cls, name: str): class KVCache(nn.Module): - def __init__( - self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=None): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=None): # torch.float): # bfloat16 ): super().__init__() if not dtype: - dtype=get_precision() + dtype = get_precision() cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype)) self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype)) @@ -238,9 +238,10 @@ def from_params(cls, params_path: str): @classmethod def from_gguf(cls, gguf_path: str): from build.gguf_loader import load_llama_from_gguf_file + model = load_llama_from_gguf_file(gguf_path) return model - + class TransformerBlock(nn.Module): def __init__(self, config: ModelArgs) -> None: @@ -267,8 +268,12 @@ def __init__(self, config: ModelArgs): # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) self.wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False) - self.wk = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False) - self.wv = nn.Linear(config.dim, config.n_local_heads * config.head_dim, bias=False) + self.wk = nn.Linear( + config.dim, config.n_local_heads * config.head_dim, bias=False + ) + self.wv = nn.Linear( + config.dim, config.n_local_heads * config.head_dim, bias=False + ) self.wo = nn.Linear(config.dim, config.dim, bias=False) self.kv_cache = None @@ -297,7 +302,6 @@ def load_hook(self, state_dict, prefix, *args): return - def _unfuse_wqkv_state_dict( state_dict: Dict[str, torch.Tensor], dim: int, @@ -306,15 +310,16 @@ def _unfuse_wqkv_state_dict( if key.endswith("wqkv.weight"): tensor = state_dict[key] wq_key = key.replace("wqkv.weight", "wq.weight") - state_dict[wq_key] = tensor[: dim] + state_dict[wq_key] = tensor[:dim] wk_key = key.replace("wqkv.weight", "wk.weight") wv_key = key.replace("wqkv.weight", "wv.weight") - wk, wv = tensor[dim :].chunk(2, 0) + wk, wv = tensor[dim:].chunk(2, 0) state_dict[wk_key] = wk state_dict[wv_key] = wv state_dict.pop(key) else: continue + _unfuse_wqkv_state_dict(state_dict, self.dim) def forward( @@ -326,7 +331,6 @@ def forward( ) -> Tensor: bsz, seqlen, _ = x.shape - q = self.wq(x) k = self.wk(x) v = self.wv(x) @@ -379,8 +383,11 @@ def forward(self, x: Tensor) -> Tensor: output = self._norm(x.float()).type_as(x) return output * self.weight + # transpsoed first two arguments to align with model in ET -def precompute_freqs_cis(n_elem: int, seq_len: int, base: int = 10000, dtype=None) -> Tensor: +def precompute_freqs_cis( + n_elem: int, seq_len: int, base: int = 10000, dtype=None +) -> Tensor: if not dtype: dtype = get_precision() freqs = 1.0 / ( @@ -390,7 +397,7 @@ def precompute_freqs_cis(n_elem: int, seq_len: int, base: int = 10000, dtype=Non freqs = torch.outer(t, freqs) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) - return cache.to(dtype=dtype) # bfloat16) + return cache.to(dtype=dtype) # bfloat16) def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: diff --git a/build/model_aoti.py b/build/model_aoti.py index 4832c248d..50fa2b939 100644 --- a/build/model_aoti.py +++ b/build/model_aoti.py @@ -11,7 +11,7 @@ # with open("./dso_model.h", "rb") as f: # dso_src = f.read().decode("utf-8") -dso_src ="" +dso_src = "" src = """ #include @@ -36,7 +36,6 @@ """ - class DSOModel(nn.Module): def __init__(self, config, dso_path) -> None: super().__init__() @@ -44,8 +43,8 @@ def __init__(self, config, dso_path) -> None: # build transformer model global src, dso_src - - src = src.replace('***my_model.so***', str(dso_path)) + + src = src.replace("***my_model.so***", str(dso_path)) async_compile = AsyncCompile() self.transformer_model = async_compile.cpp_pybinding( ["long *", "long *", "float *"], dso_src + src @@ -53,9 +52,8 @@ def __init__(self, config, dso_path) -> None: async_compile.wait(globals()) del async_compile - def forward(self, x, input_pos): - vocab_size = self.config.vocab_size # 32000 + vocab_size = self.config.vocab_size # 32000 assert x.dim() == 2 and x.size(0) == 1 and x.size(1) == 1 logits = torch.empty(1, 1, vocab_size) x = x.to(torch.long) diff --git a/build/model_et.py b/build/model_et.py index cee0d5c3d..f7bd02194 100644 --- a/build/model_et.py +++ b/build/model_et.py @@ -2,8 +2,9 @@ import torch import torch.nn as nn -from torch import empty from executorch.extension.pybindings import portable_lib as exec_lib +from torch import empty + class PTEModel(nn.Module): def __init__(self, config, path) -> None: diff --git a/cli.py b/cli.py index b14f0a944..9bdca3776 100644 --- a/cli.py +++ b/cli.py @@ -4,8 +4,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import time import os +import time from pathlib import Path import torch @@ -15,22 +15,23 @@ strict = False + def check_args(args, command_name: str): global strict # chat and generate support the same options - if command_name in ["generate", "chat", "gui"]: + if command_name in ["generate", "chat", "gui"]: # examples, can add more. Note that attributes convert dash to _ - disallowed_args = ['output_pte_path', 'output_dso_path' ] + disallowed_args = ["output_pte_path", "output_dso_path"] elif command_name == "export": # examples, can add more. Note that attributes convert dash to _ - disallowed_args = ['pte_path', 'dso_path' ] + disallowed_args = ["pte_path", "dso_path"] elif command_name == "eval": # TBD disallowed_args = [] else: raise RuntimeError(f"{command_name} is not a valid command") - + for disallowed in disallowed_args: if hasattr(args, disallowed): text = f"command {command_name} does not support option {disallowed.replace('_', '-')}" @@ -39,7 +40,7 @@ def check_args(args, command_name: str): else: print(f"Warning: {text}") - + def cli_args(): import argparse @@ -48,8 +49,8 @@ def cli_args(): parser.add_argument( "--seed", type=int, - default=1234, # set None for release - help="Initialize torch seed" + default=1234, # set None for release + help="Initialize torch seed", ) parser.add_argument( "--prompt", type=str, default="Hello, my name is", help="Input prompt." @@ -78,55 +79,31 @@ def cli_args(): "--chat", action="store_true", help="Use torchat to for an interactive chat session.", - ) + ) parser.add_argument( "--gui", action="store_true", help="Use torchat to for an interactive gui-chat session.", - ) - parser.add_argument( - "--num-samples", - type=int, - default=1, - help="Number of samples.") - parser.add_argument( - "--max-new-tokens", - type=int, - default=200, - help="Maximum number of new tokens." ) + parser.add_argument("--num-samples", type=int, default=1, help="Number of samples.") parser.add_argument( - "--top-k", - type=int, - default=200, - help="Top-k for sampling.") + "--max-new-tokens", type=int, default=200, help="Maximum number of new tokens." + ) + parser.add_argument("--top-k", type=int, default=200, help="Top-k for sampling.") parser.add_argument( - "--temperature", - type=float, - default=0.8, - help="Temperature for sampling." + "--temperature", type=float, default=0.8, help="Temperature for sampling." ) parser.add_argument( - "--compile", - action="store_true", - help="Whether to compile the model." + "--compile", action="store_true", help="Whether to compile the model." ) parser.add_argument( "--compile-prefill", action="store_true", help="Whether to compile the prefill (improves prefill perf, but higher compile times)", ) + parser.add_argument("--profile", type=Path, default=None, help="Profile path.") parser.add_argument( - "--profile", - type=Path, - default=None, - help="Profile path." - ) - parser.add_argument( - "--speculate-k", - type=int, - default=5, - help="Speculative execution depth." + "--speculate-k", type=int, default=5, help="Speculative execution depth." ) parser.add_argument( "--draft-checkpoint-path", @@ -163,31 +140,18 @@ def cli_args(): type=Path, default=None, help="Model checkpoint path.", - ) - parser.add_argument( - "--output-pte-path", - type=str, - default=None, - help="Filename" ) + parser.add_argument("--output-pte-path", type=str, default=None, help="Filename") + parser.add_argument("--output-dso-path", type=str, default=None, help="Filename") parser.add_argument( - "--output-dso-path", - type=str, - default=None, - help="Filename" - ) - parser.add_argument( - "--dso-path", - type=Path, - default=None, - help="Use the specified AOTI DSO model." + "--dso-path", type=Path, default=None, help="Use the specified AOTI DSO model." ) parser.add_argument( "--pte-path", type=Path, default=None, - help="Use the specified Executorch PTE model." - ) + help="Use the specified Executorch PTE model.", + ) parser.add_argument( "-d", "--dtype", @@ -196,48 +160,36 @@ def cli_args(): ) parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument( - "--quantize", - type=str, - default="{ }", - help="Quantization options." + "--quantize", type=str, default="{ }", help="Quantization options." ) parser.add_argument( - "--device", - type=str, - default=default_device, - help="Device to use" + "--device", type=str, default=default_device, help="Device to use" ) + parser.add_argument("--params-table", type=str, default=None, help="Device to use") parser.add_argument( - "--params-table", - type=str, - default=None, - help="Device to use" - ) - parser.add_argument( - '--tasks', - nargs='+', + "--tasks", + nargs="+", type=str, default=["hellaswag"], - help='list of lm-eluther tasks to evaluate usage: --tasks task1 task2' + help="list of lm-eluther tasks to evaluate usage: --tasks task1 task2", ) parser.add_argument( - '--limit', type=int, - default=None, - help='number of samples to evaluate' + "--limit", type=int, default=None, help="number of samples to evaluate" ) parser.add_argument( - '--max-seq-length', + "--max-seq-length", type=int, default=None, - help='maximum length sequence to evaluate') - + help="maximum length sequence to evaluate", + ) + args = parser.parse_args() - if (Path(args.quantize).is_file()): + if Path(args.quantize).is_file(): with open(args.quantize, "r") as f: args.quantize = json.loads(f.read()) if args.seed: - torch.manual_seed(args.seed) + torch.manual_seed(args.seed) return args diff --git a/eval.py b/eval.py index 8ac8c457f..9ebd5337d 100644 --- a/eval.py +++ b/eval.py @@ -18,32 +18,36 @@ torch._inductor.config.triton.cudagraphs = True torch._dynamo.config.cache_size_limit = 100000 +from build.model import Transformer from cli import cli_args from quantize import name_to_dtype, set_precision -from build.model import Transformer - try: import lm_eval + lm_eval_available = True except: lm_eval_available = False -from build.builder import _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs +from build.builder import ( + _initialize_model, + _initialize_tokenizer, + BuilderArgs, + TokenizerArgs, +) from generate import encode_tokens, model_forward if lm_eval_available: - try: # lm_eval version 0.4 + try: # lm_eval version 0.4 + from lm_eval.evaluator import evaluate from lm_eval.models.huggingface import HFLM as eval_wrapper from lm_eval.tasks import get_task_dict - from lm_eval.evaluator import evaluate - except: #lm_eval version 0.3 - from lm_eval import base - from lm_eval import tasks - from lm_eval import evaluator - eval_wrapper=base.BaseLM - get_task_dict=tasks.get_task_dict - evaluate=evaluator.evaluate + except: # lm_eval version 0.3 + from lm_eval import base, evaluator, tasks + + eval_wrapper = base.BaseLM + get_task_dict = tasks.get_task_dict + evaluate = evaluator.evaluate def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( @@ -84,20 +88,22 @@ def setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( return seq, input_pos, max_seq_length + class GPTFastEvalWrapper(eval_wrapper): """ A wrapper class for GPTFast, providing integration with the lm-evaluation-harness library. """ + def __init__( self, model: Transformer, tokenizer, - max_seq_length: Optional[int]=None, + max_seq_length: Optional[int] = None, ): super().__init__() self._model = model self._tokenizer = tokenizer - self._device = torch.device('cuda') + self._device = torch.device("cuda") self._max_seq_length = 2048 if max_seq_length is None else max_seq_length @property @@ -121,8 +127,7 @@ def device(self): return self._device def tok_encode(self, string: str, **kwargs): - encoded = encode_tokens(self._tokenizer, - string, bos=True, device=self._device) + encoded = encode_tokens(self._tokenizer, string, bos=True, device=self._device) # encoded is a pytorch tensor, but some internal logic in the # eval harness expects it to be a list instead # TODO: verify this for multi-batch as well @@ -138,19 +143,20 @@ def _model_call(self, inps): inps = inps.squeeze(0) max_new_tokens = 1 - seq, input_pos, max_seq_length = \ + seq, input_pos, max_seq_length = ( setup_cache_padded_seq_input_pos_max_seq_length_for_prefill( self._model, inps, max_new_tokens, self.max_length, ) + ) x = seq.index_select(0, input_pos).view(1, -1) logits = model_forward(self._model, x, input_pos) return logits def _model_generate(self, context, max_length, eos_token_id): - raise Exception('unimplemented') + raise Exception("unimplemented") @torch.no_grad() @@ -185,8 +191,8 @@ def eval( except: pass - if 'hendrycks_test' in tasks: - tasks.remove('hendrycks_test') + if "hendrycks_test" in tasks: + tasks.remove("hendrycks_test") tasks += [x for x in lm_eval.tasks.hendrycks_test.create_all_tasks().keys()] task_dict = get_task_dict(tasks) @@ -212,7 +218,7 @@ def main(args) -> None: builder_args = BuilderArgs.from_args(args) tokenizer_args = TokenizerArgs.from_args(args) - + checkpoint_path = args.checkpoint_path checkpoint_dir = args.checkpoint_dir params_path = args.params_path @@ -223,12 +229,12 @@ def main(args) -> None: pte_path = args.pte_path quantize = args.quantize device = args.device - model_dtype = args.dtype + model_dtype = args.dtype tasks = args.tasks limit = args.limit max_seq_length = args.max_seq_length use_tiktoken = args.tiktoken - + print(f"Using device={device}") set_precision(buildeer_args.precision) @@ -240,9 +246,13 @@ def main(args) -> None: ) if compile: - assert not (builder_args.dso_path or builder_args.pte_path), "cannot compile exported model" + assert not ( + builder_args.dso_path or builder_args.pte_path + ), "cannot compile exported model" global model_forward - model_forward = torch.compile(model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True) + model_forward = torch.compile( + model_forward, mode="reduce-overhead", dynamic=True, fullgraph=True + ) torch._inductor.config.coordinate_descent_tuning = True t1 = time.time() @@ -268,11 +278,13 @@ def main(args) -> None: for task, res in result["results"].items(): print(f"{task}: {res}") -if __name__ == '__main__': + +if __name__ == "__main__": + def cli(): args = cli_args() main(args) if __name__ == "__main__": - cli() + cli() diff --git a/export.py b/export.py index 1e5fb5d37..f31af803a 100644 --- a/export.py +++ b/export.py @@ -4,17 +4,17 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import time import os +import time from pathlib import Path import torch import torch.nn as nn -from torch.export import Dim, export - -from quantize import quantize_model, name_to_dtype, set_precision, get_precision from cli import cli_args +from quantize import get_precision, name_to_dtype, quantize_model, set_precision +from torch.export import Dim, export + try: executorch_export_available = True from export_et import export_model as export_model_et @@ -22,12 +22,12 @@ executorch_exception = f"ET EXPORT EXCEPTION: {e}" executorch_export_available = False -from export_aoti import export_model as export_model_aoti +from build.builder import _initialize_model, BuilderArgs, TokenizerArgs from build.model import Transformer -from build.builder import _initialize_model, BuilderArgs, TokenizerArgs +from export_aoti import export_model as export_model_aoti from generate import decode_one_token -from quantize import quantize_model, name_to_dtype +from quantize import name_to_dtype, quantize_model from torch._export import capture_pre_autograd_graph default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu' @@ -42,7 +42,6 @@ def device_sync(device): print(f"device={device} is not yet suppported") - def main(args): builder_args = BuilderArgs.from_args(args) tokenizer_args = TokenizerArgs.from_args(args) @@ -70,7 +69,9 @@ def main(args): print(f"Exporting model using Executorch to {output_pte_path}") export_model_et(model, builder_args.device, args.output_pte_path, args) else: - print(f"Export with executorch requested but Executorch could not be loaded") + print( + f"Export with executorch requested but Executorch could not be loaded" + ) print(executorch_exception) if output_dso_path: output_dso_path = str(os.path.abspath(output_dso_path)) @@ -82,5 +83,6 @@ def cli(): args = cli_args() main(args) + if __name__ == "__main__": cli() diff --git a/export_aoti.py b/export_aoti.py index 6501b9e98..b9a59c3bb 100644 --- a/export_aoti.py +++ b/export_aoti.py @@ -12,12 +12,12 @@ import torch import torch.nn as nn -from torch.export import Dim, export + +from build.model import Transformer from generate import decode_one_token from quantize import quantize_model - -from build.model import Transformer +from torch.export import Dim, export default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu' @@ -33,11 +33,11 @@ def device_sync(device): def export_model(model: nn.Module, device, output_path, args=None): max_seq_length = 350 -# with torch.device(device): -# model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) + # with torch.device(device): + # model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length) input = ( - torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), + torch.tensor([[1, 9038, 2501, 263, 931]], dtype=torch.int, device=device), torch.tensor([0, 1, 2, 3, 4], dtype=torch.int, device=device), ) diff --git a/export_et.py b/export_et.py index 030fd0b6c..6e7de441b 100644 --- a/export_et.py +++ b/export_et.py @@ -9,23 +9,10 @@ import torch import torch.nn as nn -from torch.export import Dim, export -from torch._export import capture_pre_autograd_graph - -from generate import decode_one_token -from quantize import ( - quantize_model, name_to_dtype, set_precision, get_precision, -) -from build.model import Transformer from build.model import Transformer -from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( - XnnpackPartitioner, -) -# from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( -# XnnpackDynamicallyQuantizedPartitioner, -#) -from executorch_portable_utils import export_to_edge +from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner + # TODO: change back to executorch.examples.portable.utils # when executorch installs correctly @@ -33,6 +20,16 @@ from executorch.exir.passes.quant_fusion_pass import QuantFusionPass from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass +# from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( +# XnnpackDynamicallyQuantizedPartitioner, +# ) +from executorch_portable_utils import export_to_edge + +from generate import decode_one_token +from quantize import get_precision, name_to_dtype, quantize_model, set_precision +from torch._export import capture_pre_autograd_graph +from torch.export import Dim, export + default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu' @@ -99,7 +96,7 @@ def export_model(model, device, output_path, args=None) -> str: # noqa: C901 _skip_type_promotion=bool(target_precision == torch.float16), ) - if target_precision == torch.float16: # or args.quantization_mode=="int4": + if target_precision == torch.float16: # or args.quantization_mode=="int4": if state_dict_dtype != torch.float16: print("model.to torch.float16") model = model.to(dtype=torch.float16) @@ -111,11 +108,11 @@ def export_model(model, device, output_path, args=None) -> str: # noqa: C901 else: raise ValueError(f"Unsupported dtype for ET export: {target_precision}") - with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]), torch.no_grad(): + with torch.nn.attention.sdpa_kernel( + [torch.nn.attention.SDPBackend.MATH] + ), torch.no_grad(): m = capture_pre_autograd_graph( - export_model, - input, - dynamic_shapes=dynamic_shapes + export_model, input, dynamic_shapes=dynamic_shapes ) edge_manager = export_to_edge( diff --git a/generate.py b/generate.py index ad6085582..18583c2f9 100644 --- a/generate.py +++ b/generate.py @@ -4,48 +4,55 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import itertools -import sys import os +import sys import time +from dataclasses import dataclass from pathlib import Path from typing import Optional, Tuple -from dataclasses import dataclass import torch import torch._dynamo.config import torch._inductor.config -from build.builder import _load_model, _initialize_model, _initialize_tokenizer, BuilderArgs, TokenizerArgs +from build.builder import ( + _initialize_model, + _initialize_tokenizer, + _load_model, + BuilderArgs, + TokenizerArgs, +) from build.model import Transformer -from quantize import quantize_model, name_to_dtype, set_precision, get_precision from cli import cli_args +from quantize import get_precision, name_to_dtype, quantize_model, set_precision + @dataclass class GeneratorArgs: prompt: str = "torchat is pronounced torch-chat and is so cool because" - chat: bool = False, - gui: bool = False, - num_samples: int =1, - max_new_tokens: int = 200, - top_k: int = 200, - temperature: int = 0, # deterministic argmax - compile: bool = False, - compile_prefill: bool = False, - speculate_k: int = 5, + chat: bool = (False,) + gui: bool = (False,) + num_samples: int = (1,) + max_new_tokens: int = (200,) + top_k: int = (200,) + temperature: int = (0,) # deterministic argmax + compile: bool = (False,) + compile_prefill: bool = (False,) + speculate_k: int = (5,) @classmethod - def from_args(cls, args): # -> GeneratorArgs: + def from_args(cls, args): # -> GeneratorArgs: return cls( - prompt = args.prompt, - chat = args.chat, - gui = args.gui, - num_samples = args.num_samples, - max_new_tokens = args.max_new_tokens, - top_k = args.top_k, - temperature = args.temperature, - compile = args.compile, - compile_prefill = args.compile_prefill, - speculate_k = args.speculate_k, + prompt=args.prompt, + chat=args.chat, + gui=args.gui, + num_samples=args.num_samples, + max_new_tokens=args.max_new_tokens, + top_k=args.top_k, + temperature=args.temperature, + compile=args.compile, + compile_prefill=args.compile_prefill, + speculate_k=args.speculate_k, ) @@ -152,6 +159,7 @@ def decode_n_tokens( # except: # print("compiled model load not successful, running eager model") + def model_forward(model, x, input_pos): return model(x, input_pos) @@ -336,20 +344,20 @@ def _main( is_chat = "chat" in str(os.path.basename(builder_args.checkpoint_path)) if is_chat: - raise RuntimeError("need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!") + raise RuntimeError( + "need to stop filename based kludgery, at a minimum need to look at all pathnames. in particular, this now fails because chat is part of the pathname, yuck!" + ) tokenizer = _initialize_tokenizer(tokenizer_args) builder_args.setup_caches = False - model = _initialize_model( - builder_args, - quantize - ) + model = _initialize_model(builder_args, quantize) # will add a version of _initialize_model in future # (need additional args) if is_speculative: from builder import _load_model + speculative_builder_args = builder_args draft_model = _load_model( @@ -359,7 +367,7 @@ def _main( draft_model = None encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device) - print (encoded) + print(encoded) prompt_length = encoded.size(0) model_size = sum( @@ -369,7 +377,9 @@ def _main( ] ) if compile: - if is_speculative and builder_args.use_tp: # and ("cuda" in builder_args.device): + if ( + is_speculative and builder_args.use_tp + ): # and ("cuda" in builder_args.device): torch._inductor.config.triton.cudagraph_trees = ( False # Bug with cudagraph trees in this case ) @@ -401,7 +411,9 @@ def _main( prompt = input("What is your prompt? ") if is_chat: prompt = f"{B_INST} {prompt.strip()} {E_INST}" - encoded = encode_tokens(tokenizer, prompt, bos=True, device=builder_args.device) + encoded = encode_tokens( + tokenizer, prompt, bos=True, device=builder_args.device + ) if chat_mode and i >= 0: buffer = [] @@ -503,10 +515,11 @@ def main(args): args.quantize, ) + def cli(): args = cli_args() main(args) if __name__ == "__main__": - cli() + cli() diff --git a/quantize.py b/quantize.py index 4890c4b42..766efac04 100644 --- a/quantize.py +++ b/quantize.py @@ -4,21 +4,21 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import json from functools import reduce from math import gcd from typing import Dict, Optional, Tuple -import json + +import quantized_ops import torch import torch.nn as nn import torch.nn.functional as F -import quantized_ops - try: + from eval import evaluate, get_task_dict, lm_eval from GPTQ import GenericGPTQRunner, InputRecorder - from eval import get_task_dict, evaluate, lm_eval except: pass @@ -27,34 +27,39 @@ precision = torch.float + def set_precision(dtype): global precision precision = dtype + def get_precision(): global precision return precision + def name_to_dtype(name): if name in name_to_dtype_dict: return name_to_dtype_dict[name] else: raise RuntimeError(f"unsupported dtype name {name} specified") + name_to_dtype_dict = { - "fp32" : torch.float, - "fp16" : torch.float16, - "bf16" : torch.bfloat16, - "float" : torch.float, - "half" : torch.float16, - "float32" : torch.float, - "float16" : torch.float16, - "bfloat16" : torch.bfloat16, + "fp32": torch.float, + "fp16": torch.float16, + "bf16": torch.bfloat16, + "float": torch.float, + "half": torch.float16, + "float32": torch.float, + "float16": torch.float16, + "bfloat16": torch.bfloat16, } ########################################################################## ### process quantization dictionary ### + def quantize_model(model: nn.Module, device, quantize_options): """ Quantize the specified model using the quantizers described by @@ -73,46 +78,34 @@ def quantize_model(model: nn.Module, device, quantize_options): for quantizer, q_kwargs in quantize_options.items(): if quantizer == "embedding": model = EmbeddingOnlyInt8QuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif linears_quantized: - assert 0==1, "can only specify one linear quantizer" + assert 0 == 1, "can only specify one linear quantizer" elif quantizer == "linear:int8": linears_quantized = True model = WeightOnlyInt8QuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "linear:int4": linears_quantized = True model = WeightOnlyInt4QuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "linear:a8w4dq": linears_quantized = True model = Int8DynActInt4WeightQuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "linear:gptq": linears_quantized = True model = WeightOnlyInt4GPTQQuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "linear:hqq": linears_quantized = True model = WeightOnlyInt4HqqQuantHandler( - model, - device, - **q_kwargs + model, device, **q_kwargs ).quantized_model() elif quantizer == "precision": model.to(**q_kwargs) @@ -123,6 +116,7 @@ def quantize_model(model: nn.Module, device, quantize_options): ######################################################################### ##### Quantization Primitives ###### + def dynamically_quantize_per_channel( x, quant_min, @@ -217,8 +211,7 @@ def dynamically_quantize_per_channel( return quant, scales, zero_points - -def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype= torch.float): +def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype=torch.float): # needed for GPTQ with padding if groupsize > w.shape[-1]: groupsize = w.shape[-1] @@ -324,6 +317,7 @@ def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128): w_int32, scales, zeros, n_bit, groupsize ) + ######################################################################### ### QuantHandler API definition ### @@ -349,9 +343,11 @@ def quantized_model(self) -> nn.Module: ##### Weight-only int8 per-channel quantized code ###### -def replace_linear_weight_only_int8_per_channel(module, device, node_type, groupsize=None): +def replace_linear_weight_only_int8_per_channel( + module, device, node_type, groupsize=None +): if groupsize is not None and groupsize != 0: - pass # groupsize = 2 ** groupsize + pass # groupsize = 2 ** groupsize for name, child in module.named_children(): # print(f"name: {name}") @@ -367,10 +363,14 @@ def replace_linear_weight_only_int8_per_channel(module, device, node_type, group setattr( module, name, - WeightOnlyInt8Linear(device, child.in_features, child.out_features, groupsize), + WeightOnlyInt8Linear( + device, child.in_features, child.out_features, groupsize + ), ) else: - replace_linear_weight_only_int8_per_channel(child, device, node_type, groupsize) + replace_linear_weight_only_int8_per_channel( + child, device, node_type, groupsize + ) class WeightOnlyInt8QuantHandler(QuantHandler): @@ -443,7 +443,9 @@ def create_quantized_state_dict(self) -> Dict: return cur_state_dict def convert_for_runtime(self) -> nn.Module: - replace_linear_weight_only_int8_per_channel(self.mod, self.device, self.node_type, self.groupsize) + replace_linear_weight_only_int8_per_channel( + self.mod, self.device, self.node_type, self.groupsize + ) return self.mod def quantized_model(self) -> nn.Module: @@ -474,14 +476,19 @@ def __init__( self.in_features = in_features self.out_features = out_features self.register_buffer( - "weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device) + "weight", + torch.empty((out_features, in_features), dtype=torch.int8, device=device), ) - dtype=get_precision() + dtype = get_precision() if groupsize is None or (groupsize == 0): - self.register_buffer("scales", torch.ones(out_features, dtype=dtype, device=device)) + self.register_buffer( + "scales", torch.ones(out_features, dtype=dtype, device=device) + ) else: groups = (in_features + groupsize - 1) // groupsize - self.register_buffer("scales", torch.ones(out_features, groups, dtype=dtype, device=device)) + self.register_buffer( + "scales", torch.ones(out_features, groups, dtype=dtype, device=device) + ) def forward(self, input: torch.Tensor) -> torch.Tensor: scales = self.scales @@ -496,7 +503,13 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: if scales.shape[1] == 1: return F.linear(input, weight.to(dtype=input.dtype)) * self.scales else: - return F.linear(input, (weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1) * scales.view(weight.shape[0], no_groups, -1)).view(weight.shape[0], -1)) + return F.linear( + input, + ( + weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1) + * scales.view(weight.shape[0], no_groups, -1) + ).view(weight.shape[0], -1), + ) ######################################################################### @@ -504,7 +517,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: def replace_embedding_weight_only_grouped_int8_per_channel( - module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False + module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False ): for name, child in module.named_children(): # print(f"name: {name}") @@ -529,9 +542,17 @@ def replace_embedding_weight_only_grouped_int8_per_channel( class EmbeddingOnlyInt8QuantHandler(QuantHandler): - def __init__(self, mod, device, *, bitwidth: int = 8, groupsize: Optional[int] = None, packed = False): + def __init__( + self, + mod, + device, + *, + bitwidth: int = 8, + groupsize: Optional[int] = None, + packed=False, + ): if isinstance(packed, str): - packed = (packed == "True") + packed = packed == "True" self.mod = mod self.device = device self.groupsize = groupsize @@ -540,7 +561,6 @@ def __init__(self, mod, device, *, bitwidth: int = 8, groupsize: Optional[int] = if (bitwidth != 4) and packed: raise RuntimeError("pack only works with bitsize 4") - @torch.no_grad() def create_quantized_state_dict(self, packed=False) -> Dict: cur_state_dict = self.mod.state_dict() @@ -555,9 +575,7 @@ def create_quantized_state_dict(self, packed=False) -> Dict: raise ValueError(f"Unsupported bitwidth {self.bitwidth}") for fqn, mod in self.mod.named_modules(): - if ( - isinstance(mod, nn.Embedding) - ): + if isinstance(mod, nn.Embedding): # print("****") # print(f"Embedding identified: {fqn, mod}") # print(f"weights size: {mod.weight.size()}") @@ -576,17 +594,15 @@ def create_quantized_state_dict(self, packed=False) -> Dict: ) if packed: - if weight.shape[-1] %2 != 0: + if weight.shape[-1] % 2 != 0: raise RuntimeError("automatic padding not implemented yet") weight_range_shifted = weight.add(8).view(torch.uint8) weight_view = weight_range_shifted.view( - weight.shape[0], - weight.shape[1] //2, - 2 - ) - weight_even = weight_view[:,:,0] * 16 # left shift 4 - weight_odd = weight_view[:,:,1] + weight.shape[0], weight.shape[1] // 2, 2 + ) + weight_even = weight_view[:, :, 0] * 16 # left shift 4 + weight_odd = weight_view[:, :, 1] weight_packed = weight_even + weight_odd weight = weight_packed @@ -630,16 +646,25 @@ def __init__( self.packed = packed if not packed: self.register_buffer( - "weight", torch.empty((vocab_size, embedding_dim), dtype=torch.int8, device=device) + "weight", + torch.empty( + (vocab_size, embedding_dim), dtype=torch.int8, device=device + ), ) - else: # packed + else: # packed self.register_buffer( - "weight", torch.empty((vocab_size, embedding_dim//2), dtype=torch.uint8, device=device) + "weight", + torch.empty( + (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device + ), ) groups_per_row = (embedding_dim + groupsize - 1) // groupsize if groups_per_row > 1: self.register_buffer( - "scales", torch.ones((vocab_size, groups_per_row), dtype=torch.float16, device=device) + "scales", + torch.ones( + (vocab_size, groups_per_row), dtype=torch.float16, device=device + ), ) else: self.register_buffer( @@ -648,17 +673,16 @@ def __init__( @torch.no_grad() def forward(self, indices: torch.Tensor) -> torch.Tensor: - if False: # Used for Executorch + if False: # Used for Executorch return torch.ops.llama_quantized.embedding_byte.dtype( self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype ) - # result_weights = self.weight.index_select(0, indices.view(-1)) # result_scales = self.scales.index_select(0, indices.view(-1)) if self.packed: - weight_even = self.weight.div(16, rounding_mode='trunc') + weight_even = self.weight.div(16, rounding_mode="trunc") weight_odd = self.weight.remainder(16) weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1) weight = weight_unpacked.view(self.weight.shape[0], -1) @@ -671,8 +695,22 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: result_weights = F.embedding(indices, weight) result_scales = F.embedding(indices, scales) - rw_view = result_weights.to(dtype=result_scales.dtype).view(tuple(result_weights.shape[:-1] + (scales.shape[1], -1, ))) - rs_view = result_scales.view(tuple(result_scales.shape[:-1]) + (scales.shape[1], 1, )) + rw_view = result_weights.to(dtype=result_scales.dtype).view( + tuple( + result_weights.shape[:-1] + + ( + scales.shape[1], + -1, + ) + ) + ) + rs_view = result_scales.view( + tuple(result_scales.shape[:-1]) + + ( + scales.shape[1], + 1, + ) + ) # print(f"rw_view {rw_view.shape}") # print(f"rs_view {rs_view.shape}") @@ -685,17 +723,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: ######################################################################### ##### weight only int4 per channel groupwise quantized code ###### -def _int4_prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles): + +def _int4_prepare_int4_weight_and_scales_and_zeros( + weight_bf16, groupsize, inner_k_tiles +): weight_int32, scales_and_zeros = group_quantize_tensor( weight_bf16, n_bit=4, groupsize=groupsize ) - weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles) + weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + weight_int32, inner_k_tiles + ) return weight_int4pack, scales_and_zeros + def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1): from build.model import find_multiple + return find_multiple(k, 1024) + def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) @@ -705,31 +751,41 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou if "mps" in str(x.device): new_shape = origin_x_size[:-1] + (out_features,) return torch.zeros(new_shape, dtype=x.dtype, device=x.device) - + c = torch.ops.aten._weight_int4pack_mm( - x.to(torch.bfloat16), # TODO: should probably make a warning if x is not already bfloat16 + x.to( + torch.bfloat16 + ), # TODO: should probably make a warning if x is not already bfloat16 weight_int4pack, groupsize, - scales_and_zeros.to(torch.bfloat16), # TODO: should probably make a warning if not already bfloat16 - ).to(x.dtype) # cast back to x.dtype + scales_and_zeros.to( + torch.bfloat16 + ), # TODO: should probably make a warning if not already bfloat16 + ).to( + x.dtype + ) # cast back to x.dtype new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) return c -def _int4_check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1): +def _int4_check_linear_int4_k(k, groupsize=1, inner_k_tiles=1): return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0 + def replace_linear_int4( - module, - device, - groupsize, - inner_k_tiles, - padding_allowed, + module, + device, + groupsize, + inner_k_tiles, + padding_allowed, ): for name, child in module.named_children(): if isinstance(child, nn.Linear): - if _int4_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: + if ( + _int4_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) + or padding_allowed + ): setattr( module, name, @@ -740,7 +796,8 @@ def replace_linear_int4( bias=False, groupsize=groupsize, inner_k_tiles=inner_k_tiles, - )) + ), + ) else: replace_linear_int4( child, device, groupsize, inner_k_tiles, padding_allowed @@ -748,7 +805,9 @@ def replace_linear_int4( class WeightOnlyInt4QuantHandler(QuantHandler): - def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True): + def __init__( + self, mod, device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True + ): self.mod = mod self.device = device self.groupsize = groupsize @@ -769,19 +828,30 @@ def create_quantized_state_dict(self): print(f"linear: {fqn}, in={in_features}, out={out_features}") weight = mod.weight.data - if not _int4_check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles): + if not _int4_check_linear_int4_k( + in_features, self.groupsize, self.inner_k_tiles + ): if self.padding_allowed: - from build.model import find_multiple import torch.nn.functional as F - print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0") + from build.model import find_multiple + + print( + f"warning: {fqn} is padded to satisfy in_features % 1024 == 0" + ) padded_in_features = find_multiple(in_features, 1024) - weight = F.pad(weight, pad=(0, padded_in_features - in_features)) + weight = F.pad( + weight, pad=(0, padded_in_features - in_features) + ) else: - print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + - "and that groupsize and inner_k_tiles*16 evenly divide into it") + print( + f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " + + "and that groupsize and inner_k_tiles*16 evenly divide into it" + ) continue - weight_int4pack, scales_and_zeros = _int4_prepare_int4_weight_and_scales_and_zeros( - weight.to(torch.float), self.groupsize, self.inner_k_tiles + weight_int4pack, scales_and_zeros = ( + _int4_prepare_int4_weight_and_scales_and_zeros( + weight.to(torch.float), self.groupsize, self.inner_k_tiles + ) ) weight_int4pack = weight_int4pack.to(device=self.device) scales_and_zeros = scales_and_zeros.to(device=self.device) @@ -790,9 +860,14 @@ def create_quantized_state_dict(self): return cur_state_dict - def convert_for_runtime(self): - replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding_allowed) + replace_linear_int4( + self.mod, + self.device, + self.groupsize, + self.inner_k_tiles, + self.padding_allowed, + ) return self.mod def quantized_model(self) -> nn.Module: @@ -803,25 +878,28 @@ def quantized_model(self) -> nn.Module: class WeightOnlyInt4Linear(torch.nn.Module): - __constants__ = ['in_features', 'out_features'] + __constants__ = ["in_features", "out_features"] in_features: int out_features: int weight: torch.Tensor def __init__( - self, - device: str, - in_features: int, - out_features: int, - bias=True, - dtype=None, - groupsize: int = 128, - inner_k_tiles: int = 8, + self, + device: str, + in_features: int, + out_features: int, + bias=True, + dtype=None, + groupsize: int = 128, + inner_k_tiles: int = 8, ) -> None: super().__init__() - self.padding = not _int4_check_linear_int4_k(in_features, groupsize, inner_k_tiles) + self.padding = not _int4_check_linear_int4_k( + in_features, groupsize, inner_k_tiles + ) if self.padding: from build.model import find_multiple + self.origin_in_features = in_features in_features = find_multiple(in_features, 1024) @@ -832,14 +910,21 @@ def __init__( self.inner_k_tiles = inner_k_tiles assert out_features % 8 == 0, "require out_features % 8 == 0" - assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" + assert ( + in_features % (inner_k_tiles * 16) == 0 + ), "require in_features % (innerKTiles * 16) == 0" self.register_buffer( "weight", torch.empty( - (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), + ( + out_features // 8, + in_features // (inner_k_tiles * 16), + 32, + inner_k_tiles // 2, + ), dtype=torch.int32, device=device, - ) + ), ) # MKG: torch.float self.register_buffer( @@ -848,7 +933,7 @@ def __init__( (in_features // groupsize, out_features, 2), dtype=get_precision(), device=device, - ) + ), ) def forward(self, input: torch.Tensor) -> torch.Tensor: @@ -856,18 +941,17 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: # input = input.to(torch.float) if self.padding: import torch.nn.functional as F + input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) return linear_forward_int4( - input, - self.weight, - self.scales_and_zeros, - self.out_features, - self.groupsize + input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize ) + ######################################################################### ##### Int8 Dynamic Activations 4 Bit Weights ##### + def prepare_int4_weight_and_scales_and_zeros(weight, groupsize, precision): weight_int8, scales, zeros = group_quantize_tensor_symmetric( weight, @@ -924,6 +1008,7 @@ def find_multiple(n: int, *args: Tuple[int]) -> int: def _check_linear_int4_k(k, groupsize=1): return k % groupsize == 0 + def _calc_padded_size_linear_int4(k, groupsize=1): return find_multiple(k, groupsize) @@ -965,7 +1050,7 @@ def __init__( self, mod, device, - * , + *, groupsize=256, padding_allowed=False, precision=torch.float32, @@ -1130,6 +1215,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ######################################################################### ##### GPTQ ##### + class GPTQQuantHandler(QuantHandler): """ This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class. @@ -1192,6 +1278,7 @@ class GPTQQuantHandler(QuantHandler): names_and_values_dict: a dictionary mapping the name of the parameters of the quantized module to the corresponding quantized weights and qparams. """ + def __init__(self): assert self.mod is not None assert self.get_qparams_func is not None @@ -1201,7 +1288,14 @@ def __init__(self): assert self.make_names_and_values_dict_func is not None @staticmethod - def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) -> "MultiInput": + def get_inputs( + model, + tokenizer, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) -> "MultiInput": input_recorder = InputRecorder( model, tokenizer, @@ -1223,9 +1317,9 @@ def get_inputs(model, tokenizer, calibration_tasks, calibration_limit, calibrati ) inputs = input_recorder.get_recorded_inputs() assert inputs is not None, ( - f"No inputs were collected, use a task other than {calibration_tasks}, "+ - f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently "+ - f"{calibration_seq_length})" + f"No inputs were collected, use a task other than {calibration_tasks}, " + + f"use option pad_calibration_inputs, or decrease calibration_sequence_length (currently " + + f"{calibration_seq_length})" ) print(f"Obtained {len(inputs[0].values)} calibration samples") return inputs @@ -1242,7 +1336,14 @@ def create_quantized_state_dict( calibration_seq_length, pad_calibration_inputs, ) -> "StateDict": - inputs = GPTQQuantHandler.get_inputs(self.mod, tokenizer, calibration_tasks, calibration_limit, calibration_seq_length, pad_calibration_inputs) + inputs = GPTQQuantHandler.get_inputs( + self.mod, + tokenizer, + calibration_tasks, + calibration_limit, + calibration_seq_length, + pad_calibration_inputs, + ) print("Tracing model for GPTQ") GPTQ_runner = GenericGPTQRunner( self.mod, @@ -1256,7 +1357,7 @@ def create_quantized_state_dict( self.dequantize_func, self.combine_qparams_list_func, self.make_names_and_values_dict_func, - self.skip_layer_func + self.skip_layer_func, ) print("Applying GPTQ to weights") @@ -1270,40 +1371,52 @@ def convert_for_runtime(self) -> "nn.Module": class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler): def __init__(self, mod, device, *, groupsize=128, inner_k_tiles=8, padding=True): from build.model import find_multiple + self.mod = mod self.device = device self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles self.padding = padding self.get_qparams_func = lambda w: get_group_qparams(w, 4, groupsize) - self.quantize_func = lambda w, qparams: \ - group_quantize_tensor_from_qparams(w, qparams[0], qparams[1], 4, groupsize) - self.dequantize_func = lambda q, qparams: \ - group_dequantize_tensor_from_qparams(q, qparams[0], qparams[1], 4, groupsize).float() - self.combine_qparams_list_func = lambda qparams_list: \ - [torch.cat(x, dim=1) for x in zip(*qparams_list)] + self.quantize_func = lambda w, qparams: group_quantize_tensor_from_qparams( + w, qparams[0], qparams[1], 4, groupsize + ) + self.dequantize_func = lambda q, qparams: group_dequantize_tensor_from_qparams( + q, qparams[0], qparams[1], 4, groupsize + ).float() + self.combine_qparams_list_func = lambda qparams_list: [ + torch.cat(x, dim=1) for x in zip(*qparams_list) + ] # skip unless padding=True or its correctly sized self.skip_layer_func = lambda linear_weight: not ( - _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) or padding + _check_linear_int4_k(linear_weight.shape[-1], groupsize, inner_k_tiles) + or padding ) + # we need to do the padding here, both for q and the qparams if necessary def make_names_and_values_dict_func(q, qparams): k = q.shape[1] new_k = find_multiple(k, 1024) # how much we need to pad the weight delta_k = new_k - q.shape[1] - final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles) + final_q = torch.ops.aten._convert_weight_to_int4pack( + F.pad(q, pad=(0, delta_k)), inner_k_tiles + ) scales_and_zeros = pack_scales_and_zeros(*qparams) # how many new groups we need for padded weight delta_groups = new_k // groupsize - scales_and_zeros.shape[0] - final_s_and_z = F.pad(scales_and_zeros, pad=(0,0,0,0,0, delta_groups), value=1) + final_s_and_z = F.pad( + scales_and_zeros, pad=(0, 0, 0, 0, 0, delta_groups), value=1 + ) return {"weight": final_q, "scales_and_zeros": final_s_and_z} + self.make_names_and_values_dict_func = make_names_and_values_dict_func super().__init__() - def convert_for_runtime(self): - replace_linear_int4(self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding) + replace_linear_int4( + self.mod, self.device, self.groupsize, self.inner_k_tiles, self.padding + ) return self.mod def quantized_model(self) -> nn.Module: @@ -1313,7 +1426,6 @@ def quantized_model(self) -> nn.Module: return self.mod - # class Int8DynActInt4WeightGPTQQuantHandler(GPTQQuantHandler): # def __init__( # self, @@ -1388,6 +1500,7 @@ def quantized_model(self) -> nn.Module: ################################################################## ### WIP: HQQ ### + class WeightOnlyInt4HqqQuantHandler: def __init__(self, mod, device, *, groupsize): self.mod = mod @@ -1397,7 +1510,6 @@ def __init__(self, mod, device, *, groupsize): def create_quantized_state_dict(self): from hqq.core.quantize import Quantizer # TODO maybe torchao - for m in self.mod.modules(): for name, child in m.named_children(): if isinstance(child, torch.nn.Linear): diff --git a/quantized_ops.py b/quantized_ops.py index 7ac39b85e..e01cdf3af 100644 --- a/quantized_ops.py +++ b/quantized_ops.py @@ -10,15 +10,13 @@ import torch.nn.functional as F from torch.library import impl, impl_abstract -torchat_lib = torch.library.Library( - "torchat", "DEF" -) +torchat_lib = torch.library.Library("torchat", "DEF") torchat_lib.define( - "embedding_int8(Tensor input, Tensor weight, " - "Tensor scales) -> Tensor", + "embedding_int8(Tensor input, Tensor weight, " "Tensor scales) -> Tensor", ) + @impl(torchat_lib, "embedding_int8", "CompositeExplicitAutograd") def embedding_int8( input: torch.Tensor, @@ -27,9 +25,7 @@ def embedding_int8( ) -> torch.Tensor: indices = input # embedding_byte_weight_checks(weight, weight_scales, weight_zero_points) - groupsize = weight.size(1) // ( - scales.size(1) if scales.dim() == 2 else 1 - ) + groupsize = weight.size(1) // (scales.size(1) if scales.dim() == 2 else 1) # ET definition if False: weight_zero_points = None @@ -45,68 +41,82 @@ def embedding_int8( ) return torch.ops.aten.embedding.default(weight, indices) - scales = scales.view(weight.shape[0], -1) + scales = scales.view(weight.shape[0], -1) result_weights = F.embedding(indices, weight) result_scales = F.embedding(indices, scales) - rw_view = result_weights.to(dtype=result_scales.dtype).view(tuple(result_weights.shape[:-1]) + (scales.shape[1], -1, )) - rs_view = result_scales.view(tuple(result_scales.shape[:-1]) + (scales.shape[1], 1, )) + rw_view = result_weights.to(dtype=result_scales.dtype).view( + tuple(result_weights.shape[:-1]) + + ( + scales.shape[1], + -1, + ) + ) + rs_view = result_scales.view( + tuple(result_scales.shape[:-1]) + + ( + scales.shape[1], + 1, + ) + ) # print(f"rw_view {rw_view.shape}") # print(f"rs_view {rs_view.shape}") r = rw_view * rs_view return r.view(indices.size() + (-1,)) - - + + torchat_lib.define( "linear_int8(Tensor input, Tensor weight, Tensor scales, " "Tensor bias = None) -> Tensor", ) + @impl(torchat_lib, "linear_int8", "CompositeExplicitAutograd") def linear_int8( - input: torch.Tensor, - weight: torch.Tensor, - scales: torch.Tensor, - bias: Optional[torch.Tensor] = None, + input: torch.Tensor, + weight: torch.Tensor, + scales: torch.Tensor, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert bias is None, "bias != None not implemented" - + scales = scales.view(scales.shape[0], -1) no_groups = scales.shape[1] # for now, we special-case channel-wise, because we know how to - # make that fast with Triton + # make that fast with Triton if scales.shape[1] == 1: return F.linear(input, weight.to(dtype=input.dtype)) * scales else: return F.linear( input, - (weight.to(dtype=input.dtype).view(weight.shape[0],no_groups, -1) - * scales.view(weight.shape[0], no_groups, -1) - ).view(weight.shape[0], -1) + ( + weight.to(dtype=input.dtype).view(weight.shape[0], no_groups, -1) + * scales.view(weight.shape[0], no_groups, -1) + ).view(weight.shape[0], -1), ) - torchat_lib.define( "linear_int4(Tensor input, Tensor weight, Tensor scales_and_zeros, " "Tensor bias=None, *, int groupsize, int origin_in_features, " "int int_features, int out_features, bool padding = True) -> Tensor", ) + @impl(torchat_lib, "linear_int4", "CompositeExplicitAutograd") def linear_int4( - input: torch.Tensor, - weight: torch.Tensor, - scales_and_zeros: torch.Tensor, - bias: torch.Tensor, - *, - groupsize: int, - origin_in_features: int, - in_features: int, - out_features: int, - padding: bool = True, + input: torch.Tensor, + weight: torch.Tensor, + scales_and_zeros: torch.Tensor, + bias: torch.Tensor, + *, + groupsize: int, + origin_in_features: int, + in_features: int, + out_features: int, + padding: bool = True, ) -> torch.Tensor: assert bias is None, "bias != None not implemented" @@ -116,7 +126,7 @@ def linear_int4( # the weight is in int4pack format # rename to remind ourselves of that weight_int4pack = weight - + origin_input_size = input.size() input = input.reshape(-1, origin_input_size[-1]) c = torch.ops.aten._weight_int4pack_mm( @@ -136,10 +146,9 @@ def linear_int4( "dtype precision) -> Tensor", ) + @impl(torchat_lib, "linear_a8w4dq", "CompositeExplicitAutograd") -def linear_a8w4dq( - input, weight, scales, zeros, out_features, groupsize, precision -): +def linear_a8w4dq(input, weight, scales, zeros, out_features, groupsize, precision): x = per_token_dynamic_quant(input) weight_int8 = weight # TODO: verify and remove following reshape code diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt new file mode 100644 index 000000000..45bf4d81b --- /dev/null +++ b/requirements-lintrunner.txt @@ -0,0 +1,18 @@ +# Lintrunner itself +lintrunner==0.11.0 +lintrunner-adapters==0.11.0 + +# Flake 8 and its dependencies +flake8==6.0.0 +flake8-breakpoint==1.1.0 +flake8-bugbear==23.6.5 +flake8-comprehensions==3.12.0 +flake8-pyi==23.5.0 +mccabe==0.7.0 +pycodestyle==2.10.0 +torchfix==0.1.1 + +# UFMT +black==24.2.0 +ufmt==2.5.1 +usort==1.0.5 diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 428c4a733..6ac1dc3e1 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -22,7 +22,9 @@ @torch.inference_mode() def convert_hf_checkpoint( *, - checkpoint_dir: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf"), + checkpoint_dir: Path = Path( + "checkpoints/meta-Transformer/Transformer-2-7b-chat-hf" + ), model_name: Optional[str] = None, ) -> None: if model_name is None: @@ -45,8 +47,8 @@ def convert_hf_checkpoint( "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", - 'model.layers.{}.self_attn.rotary_emb.inv_freq': None, - 'model.layers.{}.mlp.gate_proj.weight': 'layers.{}.feed_forward.w1.weight', + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", @@ -66,13 +68,15 @@ def permute(w, n_heads): merged_result = {} for file in sorted(bin_files): - state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) + state_dict = torch.load( + str(file), map_location="cpu", mmap=True, weights_only=True + ) merged_result.update(state_dict) final_result = {} for key, value in merged_result.items(): if "layers" in key: - abstract_key = re.sub(r'(\d+)', '{}', key) - layer_num = re.search(r'\d+', key).group(0) + abstract_key = re.sub(r"(\d+)", "{}", key) + layer_num = re.search(r"\d+", key).group(0) new_key = weight_map[abstract_key] if new_key is None: continue @@ -96,11 +100,17 @@ def permute(w, n_heads): print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") torch.save(final_result, checkpoint_dir / "model.pth") -if __name__ == '__main__': + +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Convert HuggingFace checkpoint.') - parser.add_argument('--checkpoint-dir', type=Path, default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf")) - parser.add_argument('--model-name', type=str, default=None) + + parser = argparse.ArgumentParser(description="Convert HuggingFace checkpoint.") + parser.add_argument( + "--checkpoint-dir", + type=Path, + default=Path("checkpoints/meta-llama/llama-2-7b-chat-hf"), + ) + parser.add_argument("--model-name", type=str, default=None) args = parser.parse_args() convert_hf_checkpoint( diff --git a/scripts/download.py b/scripts/download.py index 849095ddf..387c41243 100644 --- a/scripts/download.py +++ b/scripts/download.py @@ -11,6 +11,7 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) -> None: from huggingface_hub import snapshot_download + os.makedirs(f"checkpoints/{repo_id}", exist_ok=True) try: snapshot_download( @@ -18,18 +19,30 @@ def hf_download(repo_id: Optional[str] = None, hf_token: Optional[str] = None) - local_dir=f"checkpoints/{repo_id}", local_dir_use_symlinks=False, token=hf_token, - ignore_patterns="*safetensors*") + ignore_patterns="*safetensors*", + ) except HTTPError as e: if e.response.status_code == 401: - print("You need to pass a valid `--hf_token=...` to download private checkpoints.") + print( + "You need to pass a valid `--hf_token=...` to download private checkpoints." + ) else: raise e -if __name__ == '__main__': + +if __name__ == "__main__": import argparse - parser = argparse.ArgumentParser(description='Download data from HuggingFace Hub.') - parser.add_argument('--repo-id', type=str, default="checkpoints/meta-llama/llama-2-7b-chat-hf", help='Repository ID to download from.') - parser.add_argument('--hf-token', type=str, default=None, help='HuggingFace API token.') + + parser = argparse.ArgumentParser(description="Download data from HuggingFace Hub.") + parser.add_argument( + "--repo-id", + type=str, + default="checkpoints/meta-llama/llama-2-7b-chat-hf", + help="Repository ID to download from.", + ) + parser.add_argument( + "--hf-token", type=str, default=None, help="HuggingFace API token." + ) args = parser.parse_args() hf_download(args.repo_id, args.hf_token) diff --git a/torchat.py b/torchat.py index 4b720b8dd..d01b12f88 100644 --- a/torchat.py +++ b/torchat.py @@ -4,24 +4,25 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import time import os +import time from pathlib import Path import torch import torch.nn as nn -from torch.export import Dim, export +from cli import check_args, cli_args +from eval import main as eval_main from export import main as export_main from generate import main as generate_main -from eval import main as eval_main -from cli import cli_args, check_args +from torch.export import Dim, export default_device = "cpu" # 'cuda' if torch.cuda.is_available() else 'cpu' + def cli(): args = cli_args() - + if args.generate or args.chat: check_args(args, "generate") generate_main(args) @@ -32,6 +33,7 @@ def cli(): export_main(args) else: raise RuntimeError("must specify either --generate or --export") - + + if __name__ == "__main__": cli() diff --git a/utils/tokenizer.py b/utils/tokenizer.py index f3c0cc324..b20dccf1d 100644 --- a/utils/tokenizer.py +++ b/utils/tokenizer.py @@ -2,14 +2,15 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. +import argparse import os import struct -import argparse from typing import List from sentencepiece import SentencePieceProcessor -TOKENIZER_MODEL = "tokenizer.model" # the llama sentencepiece tokenizer model +TOKENIZER_MODEL = "tokenizer.model" # the llama sentencepiece tokenizer model + class Tokenizer: def __init__(self, tokenizer_model=None): @@ -23,7 +24,7 @@ def __init__(self, tokenizer_model=None): self.bos_id: int = self.sp_model.bos_id() self.eos_id: int = self.sp_model.eos_id() self.pad_id: int = self.sp_model.pad_id() - #print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") + # print(f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}") assert self.sp_model.vocab_size() == self.sp_model.get_piece_size() def encode(self, s: str, bos: bool, eos: bool) -> List[int]: @@ -48,11 +49,11 @@ def export(self): t = self.sp_model.id_to_piece(i) s = self.sp_model.get_score(i) if i == self.bos_id: - t = '\n\n' + t = "\n\n" elif i == self.eos_id: - t = '\n\n' - t = t.replace('▁', ' ') # sentencepiece uses this character as whitespace - b = t.encode('utf-8') # bytes of this token, utf-8 encoded + t = "\n\n" + t = t.replace("▁", " ") # sentencepiece uses this character as whitespace + b = t.encode("utf-8") # bytes of this token, utf-8 encoded tokens.append(b) scores.append(s) @@ -62,16 +63,19 @@ def export(self): # write to a binary file # the tokenizer.bin file is the same as .model file, but .bin - tokenizer_bin = self.model_path.replace('.model', '.bin') - with open(tokenizer_bin, 'wb') as f: + tokenizer_bin = self.model_path.replace(".model", ".bin") + with open(tokenizer_bin, "wb") as f: f.write(struct.pack("I", max_token_length)) for bytes, score in zip(tokens, scores): f.write(struct.pack("fI", score, len(bytes))) f.write(bytes) + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer ") + parser.add_argument( + "-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer " + ) args = parser.parse_args() t = Tokenizer(args.tokenizer_model)