Commit

Run lint
GregoryComer committed Apr 17, 2024
1 parent c1660a3 commit 42e69e7
Showing 7 changed files with 52 additions and 41 deletions.
20 changes: 12 additions & 8 deletions build/builder.py
@@ -14,13 +14,11 @@
import torch
import torch._dynamo.config
import torch._inductor.config

from build.model import model_aliases
from quantize import name_to_dtype, quantize_model

from sentencepiece import SentencePieceProcessor

from build.model import Transformer
from build.model import model_aliases, Transformer


@dataclass
@@ -69,9 +67,11 @@ def __post_init__(self):
@classmethod
def from_args(cls, args): # -> BuilderArgs:
model = resolve_model_name(args.model) if args.model else None
checkpoint_path = Path(args.model_directory) / model / "model.pth" \
if model and not args.checkpoint_path \
checkpoint_path = (
Path(args.model_directory) / model / "model.pth"
if model and not args.checkpoint_path
else args.checkpoint_path
)

is_chat_model = False
if args.is_chat_model:
@@ -130,9 +130,11 @@ def from_args(cls, args): # -> TokenizerArgs:
is_TikToken = False

model = resolve_model_name(args.model) if args.model else None
checkpoint_dir = Path(args.model_directory) / model \
if not args.checkpoint_dir and args.model \
checkpoint_dir = (
Path(args.model_directory) / model
if not args.checkpoint_dir and args.model
else args.checkpoint_dir
)

if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
@@ -254,6 +256,7 @@ def _load_model(builder_args):

if builder_args.use_tp:
from tp import apply_tp

print("Applying tensor parallel to model ...")
apply_tp(model)

@@ -320,9 +323,10 @@ def _initialize_model(

return model


def resolve_model_name(model: str) -> str:
# If the provided model name is an alias, retrieve the full path.
if model in model_aliases:
return model_aliases[model]
else:
return model
return model
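
Side note on the alias handling touched above: resolve_model_name maps a short alias through model_aliases (the table defined in build/model.py later in this diff) and otherwise returns the name unchanged. A minimal, self-contained sketch of that behavior, not the repo's exact code:

# Hypothetical sketch mirroring the alias table shown in build/model.py.
model_aliases = {"llama2": "meta-llama/Llama-2-7b-chat-hf"}

def resolve_model_name(model: str) -> str:
    # Aliases resolve to the full HuggingFace path; anything else passes through.
    return model_aliases.get(model, model)

assert resolve_model_name("llama2") == "meta-llama/Llama-2-7b-chat-hf"
assert resolve_model_name("meta-llama/Llama-2-7b-chat-hf") == "meta-llama/Llama-2-7b-chat-hf"
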
2 changes: 1 addition & 1 deletion build/convert_hf_checkpoint.py
@@ -25,7 +25,7 @@ def convert_hf_checkpoint(
*,
model_dir: Optional[Path] = None,
model_name: Optional[str] = None,
remove_bin_files: bool = False
remove_bin_files: bool = False,
) -> None:
if model_dir is None:
model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
10 changes: 5 additions & 5 deletions build/gguf_loader.py
@@ -21,14 +21,14 @@
sys.path.append(str(wd))

from gguf import GGUFValueType, ReaderTensor
from model import ModelArgs, Transformer
from quantize import (
group_dequantize_tensor_from_qparams,
pack_scales_and_zeros,
WeightOnlyInt4Linear,
)

from build.gguf_util import F16, F32, Q4_0, Q6_K, to_float
from model import ModelArgs, Transformer

logger: logging.Logger = logging.getLogger(__name__)

@@ -116,9 +116,7 @@ def load_model(gguf_file: str) -> torch.nn.Module:
metadata = _get_metadata(reader)

arch = metadata["general.architecture"]
assert (
arch == "llama"
), "Only LLaMa models are supported by this converter."
assert arch == "llama", "Only LLaMa models are supported by this converter."

model_args = ModelArgs(
dim=metadata[f"{arch}.embedding_length"],
@@ -139,7 +137,9 @@ def load_model(gguf_file: str) -> torch.nn.Module:
return model


def load_model_and_state_dict(gguf_file: str, load_as_quantized: bool, *, inner_k_tiles = 8) -> torch.nn.Module:
def load_model_and_state_dict(
gguf_file: str, load_as_quantized: bool, *, inner_k_tiles=8
) -> torch.nn.Module:
"""
Parses the GGUF file and returns an nn.Module on meta device along with a state_dict
that can be loaded into it.
10 changes: 6 additions & 4 deletions build/model.py
@@ -93,11 +93,10 @@ def from_name(cls, name: str):

return cls(**transformer_configs[config[0]])


# Aliases for well-known models. Maps a short name to a HuggingFace path. These
# can be used from the CLI in-place of the full model path.
model_aliases = {
"llama2": "meta-llama/Llama-2-7b-chat-hf"
}
model_aliases = {"llama2": "meta-llama/Llama-2-7b-chat-hf"}

transformer_configs = {
"CodeLlama-7b-Python-hf": {
@@ -253,7 +252,10 @@ def from_params(cls, params_path: str):
@classmethod
def from_gguf(cls, gguf_path: str):
from build.gguf_loader import load_model_and_state_dict
model, state_dict = load_model_and_state_dict(gguf_path, load_as_quantized=True, inner_k_tiles=8)

model, state_dict = load_model_and_state_dict(
gguf_path, load_as_quantized=True, inner_k_tiles=8
)
model.load_state_dict(state_dict, assign=True)
return model

20 changes: 11 additions & 9 deletions cli.py
@@ -30,7 +30,9 @@ def check_args(args, command_name: str):
disallowed_args = []
elif command_name == "download":
if not args.model:
raise RuntimeError(f"Download requires a valid model name or HuggingFace model path.")
raise RuntimeError(
f"Download requires a valid model name or HuggingFace model path."
)

# TBD
disallowed_args = []
@@ -77,7 +79,7 @@ def _add_arguments_common(parser):
type=str,
nargs="?",
default=None,
help="Model name or HuggingFace model path."
help="Model name or HuggingFace model path.",
)
parser.add_argument(
"--checkpoint-path",
@@ -207,21 +209,21 @@ def _add_arguments_common(parser):
"--max-seq-length",
type=int,
default=None,
help='maximum length sequence to evaluate'
help="maximum length sequence to evaluate",
)
parser.add_argument(
'--hf-token',
"--hf-token",
type=str,
default=None,
help='A HuggingFace API token to use when downloading model artifacts'
help="A HuggingFace API token to use when downloading model artifacts",
)
parser.add_argument(
'--model-directory',
"--model-directory",
type=Path,
default='.model-artifacts',
help='The directory to store downloaded model artifacts'
default=".model-artifacts",
help="The directory to store downloaded model artifacts",
)


def arg_init(args):

27 changes: 14 additions & 13 deletions download.py
@@ -4,18 +4,18 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from pathlib import Path
from typing import Optional

from build.convert_hf_checkpoint import convert_hf_checkpoint
from build.model import model_aliases
from pathlib import Path
from typing import Optional

from requests.exceptions import HTTPError


def download_and_convert(
model: str,
models_dir: Path,
hf_token: Optional[str] = None) -> None:
model: str, models_dir: Path, hf_token: Optional[str] = None
) -> None:
from huggingface_hub import snapshot_download

if model in model_aliases:
@@ -32,23 +32,24 @@ def download_and_convert(
local_dir=model_dir,
local_dir_use_symlinks=False,
token=hf_token,
ignore_patterns="*safetensors*")
ignore_patterns="*safetensors*",
)
except HTTPError as e:
if e.response.status_code == 401:
raise RuntimeError("You need to pass a valid `--hf_token=...` to download private checkpoints.")
raise RuntimeError(
"You need to pass a valid `--hf_token=...` to download private checkpoints."
)
else:
raise e

# Convert the model to the torchchat format.
print(f"Converting {model} to torchchat format...")
convert_hf_checkpoint(
model_dir=model_dir,
model_name=Path(model),
remove_bin_files=True)
model_dir=model_dir, model_name=Path(model), remove_bin_files=True
)


def is_model_downloaded(
model: str,
models_dir: Path) -> bool:
def is_model_downloaded(model: str, models_dir: Path) -> bool:
if model in model_aliases:
model = model_aliases[model]

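
For context on how the pieces in this file fit together, a hypothetical sketch of calling the downloader directly. Function names and signatures are those visible in this diff; the import path, token, and directory are assumptions:

from pathlib import Path
from download import download_and_convert, is_model_downloaded

models_dir = Path(".model-artifacts")  # default --model-directory shown in cli.py
if not is_model_downloaded("llama2", models_dir):
    # Fetches the HuggingFace snapshot and converts it to the torchchat format.
    download_and_convert("llama2", models_dir, hf_token="<YOUR_HF_TOKEN>")
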
4 changes: 3 additions & 1 deletion torchchat.py
@@ -63,4 +63,6 @@

export_main(args)
else:
raise RuntimeError("Must specify a valid subcommand: download, chat, generate, export, or eval.")
raise RuntimeError(
"Must specify a valid subcommand: download, chat, generate, export, or eval."
)
