From 42e69e70081b4f4f0d33428f826e46dce108e62a Mon Sep 17 00:00:00 2001
From: Gregory Comer
Date: Wed, 17 Apr 2024 12:36:04 -0700
Subject: [PATCH] Run lint

---
 build/builder.py               | 20 ++++++++++++--------
 build/convert_hf_checkpoint.py |  2 +-
 build/gguf_loader.py           | 10 +++++-----
 build/model.py                 | 10 ++++++----
 cli.py                         | 20 +++++++++++---------
 download.py                    | 27 ++++++++++++++-------------
 torchchat.py                   |  4 +++-
 7 files changed, 52 insertions(+), 41 deletions(-)

diff --git a/build/builder.py b/build/builder.py
index 2ddcf6b65..5201beea1 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -14,13 +14,11 @@
 import torch
 import torch._dynamo.config
 import torch._inductor.config
-
-from build.model import model_aliases
 from quantize import name_to_dtype, quantize_model
 
 from sentencepiece import SentencePieceProcessor
 
-from build.model import Transformer
+from build.model import model_aliases, Transformer
 
 
 @dataclass
@@ -69,9 +67,11 @@ def __post_init__(self):
     @classmethod
     def from_args(cls, args):  # -> BuilderArgs:
         model = resolve_model_name(args.model) if args.model else None
-        checkpoint_path = Path(args.model_directory) / model / "model.pth" \
-            if model and not args.checkpoint_path \
+        checkpoint_path = (
+            Path(args.model_directory) / model / "model.pth"
+            if model and not args.checkpoint_path
             else args.checkpoint_path
+        )
 
         is_chat_model = False
         if args.is_chat_model:
@@ -130,9 +130,11 @@ def from_args(cls, args):  # -> TokenizerArgs:
         is_TikToken = False
 
         model = resolve_model_name(args.model) if args.model else None
-        checkpoint_dir = Path(args.model_directory) / model \
-            if not args.checkpoint_dir and args.model \
+        checkpoint_dir = (
+            Path(args.model_directory) / model
+            if not args.checkpoint_dir and args.model
             else args.checkpoint_dir
+        )
 
         if args.tokenizer_path:
             tokenizer_path = args.tokenizer_path
@@ -254,6 +256,7 @@ def _load_model(builder_args):
 
     if builder_args.use_tp:
         from tp import apply_tp
+
         print("Applying tensor parallel to model ...")
         apply_tp(model)
 
@@ -320,9 +323,10 @@ def _initialize_model(
 
     return model
 
+
 def resolve_model_name(model: str) -> str:
     # If the provided model name is an alias, retrieve the full path.
     if model in model_aliases:
         return model_aliases[model]
     else:
-        return model
\ No newline at end of file
+        return model
diff --git a/build/convert_hf_checkpoint.py b/build/convert_hf_checkpoint.py
index cf4928feb..cc2cc14f5 100644
--- a/build/convert_hf_checkpoint.py
+++ b/build/convert_hf_checkpoint.py
@@ -25,7 +25,7 @@ def convert_hf_checkpoint(
     *,
     model_dir: Optional[Path] = None,
     model_name: Optional[str] = None,
-    remove_bin_files: bool = False
+    remove_bin_files: bool = False,
 ) -> None:
     if model_dir is None:
         model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
diff --git a/build/gguf_loader.py b/build/gguf_loader.py
index f98e326da..93fdf66bd 100644
--- a/build/gguf_loader.py
+++ b/build/gguf_loader.py
@@ -21,6 +21,7 @@
 sys.path.append(str(wd))
 
 from gguf import GGUFValueType, ReaderTensor
+from model import ModelArgs, Transformer
 from quantize import (
     group_dequantize_tensor_from_qparams,
     pack_scales_and_zeros,
@@ -28,7 +29,6 @@
 )
 
 from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
-from model import ModelArgs, Transformer
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -116,9 +116,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     metadata = _get_metadata(reader)
 
     arch = metadata["general.architecture"]
-    assert (
-        arch == "llama"
-    ), "Only LLaMa models are supported by this converter."
+    assert arch == "llama", "Only LLaMa models are supported by this converter."
 
     model_args = ModelArgs(
         dim=metadata[f"{arch}.embedding_length"],
@@ -139,7 +137,9 @@ def load_model(gguf_file: str) -> torch.nn.Module:
     return model
 
 
-def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
+def load_model_and_state_dict(
+    gguf_file: str, load_as_quantized: bool, *, inner_k_tiles=8
+) -> torch.nn.Module:
     """
     Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
     that can be loaded into it.
diff --git a/build/model.py b/build/model.py
index bff132ca0..d233ad150 100644
--- a/build/model.py
+++ b/build/model.py
@@ -93,11 +93,10 @@ def from_name(cls, name: str):
 
         return cls(**transformer_configs[config[0]])
 
+
 # Aliases for well-known models. Maps a short name to a HuggingFace path. These
 # can be used from the CLI in-place of the full model path.
-model_aliases = {
-    "llama2": "meta-llama/Llama-2-7b-chat-hf"
-}
+model_aliases = {"llama2": "meta-llama/Llama-2-7b-chat-hf"}
 
 transformer_configs = {
     "CodeLlama-7b-Python-hf": {
@@ -253,7 +252,10 @@ def from_params(cls, params_path: str):
     @classmethod
     def from_gguf(cls, gguf_path: str):
         from build.gguf_loader import load_model_and_state_dict
-        model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)
+
+        model, state_dict = load_model_and_state_dict(
+            gguf_path, load_as_quantized=True, inner_k_tiles=8
+        )
         model.load_state_dict(state_dict, assign=True)
         return model
 
diff --git a/cli.py b/cli.py
index c5060e5db..b1e580dda 100644
--- a/cli.py
+++ b/cli.py
@@ -30,7 +30,9 @@ def check_args(args, command_name: str):
         disallowed_args = []
     elif command_name == "download":
         if not args.model:
-            raise RuntimeError(f"Download requires a valid model name or HuggingFace model path.")
+            raise RuntimeError(
+                f"Download requires a valid model name or HuggingFace model path."
+            )
         # TBD
         disallowed_args = []
 
@@ -77,7 +79,7 @@ def _add_arguments_common(parser):
         type=str,
         nargs="?",
         default=None,
-        help="Model name or HuggingFace model path."
+        help="Model name or HuggingFace model path.",
     )
     parser.add_argument(
         "--checkpoint-path",
@@ -207,21 +209,21 @@ def _add_arguments_common(parser):
         "--max-seq-length",
         type=int,
         default=None,
-        help='maximum length sequence to evaluate'
+        help="maximum length sequence to evaluate",
     )
 
     parser.add_argument(
-        '--hf-token',
+        "--hf-token",
         type=str,
         default=None,
-        help='A HuggingFace API token to use when downloading model artifacts'
+        help="A HuggingFace API token to use when downloading model artifacts",
     )
     parser.add_argument(
-        '--model-directory',
+        "--model-directory",
         type=Path,
-        default='.model-artifacts',
-        help='The directory to store downloaded model artifacts'
+        default=".model-artifacts",
+        help="The directory to store downloaded model artifacts",
     )
-    
+
 
 def arg_init(args):
diff --git a/download.py b/download.py
index dd243434a..49749c942 100644
--- a/download.py
+++ b/download.py
@@ -4,18 +4,18 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import os
+from pathlib import Path
+from typing import Optional
 
 from build.convert_hf_checkpoint import convert_hf_checkpoint
 from build.model import model_aliases
-from pathlib import Path
-from typing import Optional
 
 from requests.exceptions import HTTPError
 
+
 def download_and_convert(
-    model: str,
-    models_dir: Path,
-    hf_token: Optional[str] = None) -> None:
+    model: str, models_dir: Path, hf_token: Optional[str] = None
+) -> None:
     from huggingface_hub import snapshot_download
 
     if model in model_aliases:
@@ -32,23 +32,24 @@ def download_and_convert(
             local_dir=model_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
-            ignore_patterns="*safetensors*")
+            ignore_patterns="*safetensors*",
+        )
     except HTTPError as e:
         if e.response.status_code == 401:
-            raise RuntimeError("You need to pass a valid `--hf_token=...` to download private checkpoints.")
+            raise RuntimeError(
+                "You need to pass a valid `--hf_token=...` to download private checkpoints."
+            )
         else:
             raise e
 
     # Convert the model to the torchchat format.
     print(f"Converting {model} to torchchat format...")
     convert_hf_checkpoint(
-        model_dir=model_dir,
-        model_name=Path(model),
-        remove_bin_files=True)
+        model_dir=model_dir, model_name=Path(model), remove_bin_files=True
+    )
+
 
-def is_model_downloaded(
-    model: str,
-    models_dir: Path) -> bool:
+def is_model_downloaded(model: str, models_dir: Path) -> bool:
     if model in model_aliases:
         model = model_aliases[model]
 
diff --git a/torchchat.py b/torchchat.py
index c9879773b..0f9b4d022 100644
--- a/torchchat.py
+++ b/torchchat.py
@@ -63,4 +63,6 @@
         export_main(args)
     else:
-        raise RuntimeError("Must specify a valid subcommand: download, chat, generate, export, or eval.")
+        raise RuntimeError(
+            "Must specify a valid subcommand: download, chat, generate, export, or eval."
+        )