Support model name as a positional parameter

GregoryComer committed Apr 17, 2024
1 parent 422e23d · commit c7ecf7c
Showing 6 changed files with 53 additions and 21 deletions.
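In short: the entry points accept a model name (or alias) and fetch it on demand. eval.py, export.py, and generate.py now check `is_model_downloaded` before building, so an invocation along the lines of `python3 torchchat.py generate <model-name>` works without a separate download step (the accepted names are defined by `model_aliases` in build/model.py).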
cli.py: 2 changes (1 addition, 1 deletion)
@@ -207,7 +207,7 @@ def _add_arguments_common(parser):
     )
     parser.add_argument(
         '--model-directory',
-        type=str,
+        type=Path,
         default='.model-artifacts',
         help='The directory to store downloaded model artifacts'
     )
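Why the `type` change matters: download.py below joins this directory with a model name using the `/` operator, which wants a `Path` rather than a `str`. A minimal, self-contained sketch of the behavior being relied on (not part of the diff; the repo id in the last line is only an example):

import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
parser.add_argument(
    '--model-directory',
    type=Path,
    # argparse runs string defaults through `type`, so the default is a Path too.
    default='.model-artifacts',
)

args = parser.parse_args([])
assert isinstance(args.model_directory, Path)

# Downstream code can build paths with `/` instead of string formatting:
model_dir = args.model_directory / 'meta-llama/Llama-2-7b-chat-hf'
print(model_dir)  # .model-artifacts/meta-llama/Llama-2-7b-chat-hf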
download.py: 49 changes (30 additions, 19 deletions)
@@ -6,46 +6,57 @@
 import os
 
 from build.convert_hf_checkpoint import convert_hf_checkpoint
+from build.model import model_aliases
 from pathlib import Path
 from typing import Optional
 
 from requests.exceptions import HTTPError
 
 
-def hf_download(
-    repo_id: Optional[str] = None,
-    model_dir: Optional[Path] = None,
+def download_and_convert(
+    model: str,
+    models_dir: Path,
     hf_token: Optional[str] = None) -> None:
     from huggingface_hub import snapshot_download
 
-    if model_dir is None:
-        model_dir = Path(".model-artifacts/{repo_id}")
+    if model in model_aliases:
+        model = model_aliases[model]
 
+    model_dir = models_dir / model
+    os.makedirs(model_dir, exist_ok=True)
+
+    # Download and store the HF model artifacts.
+    print(f"Downloading {model} from HuggingFace...")
     try:
         snapshot_download(
-            repo_id,
+            model,
             local_dir=model_dir,
             local_dir_use_symlinks=False,
             token=hf_token,
             ignore_patterns="*safetensors*")
     except HTTPError as e:
         if e.response.status_code == 401:
-            print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
+            raise RuntimeError("You need to pass a valid `--hf_token=...` to download private checkpoints.")
         else:
             raise e
 
-
-def main(args):
-    model_dir = Path(args.model_directory) / args.model
-    os.makedirs(model_dir, exist_ok=True)
-
-    # Download and store the HF model artifacts.
-    print(f"Downloading {args.model} from HuggingFace...")
-    hf_download(args.model, model_dir, args.hf_token)
-
     # Convert the model to the torchchat format.
-    print(f"Converting {args.model} to torchchat format...")
+    print(f"Converting {model} to torchchat format...")
     convert_hf_checkpoint(
         model_dir=model_dir,
-        model_name=Path(args.model),
+        model_name=Path(model),
         remove_bin_files=True)
+
+def is_model_downloaded(
+    model: str,
+    models_dir: Path) -> bool:
+    if model in model_aliases:
+        model = model_aliases[model]
+
+    model_dir = models_dir / model
+
+    # TODO Can we be more thorough here?
+    return os.path.isdir(model_dir)
+
+
+def main(args):
+    download_and_convert(args.model, args.model_directory, args.hf_token)
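The three entry points below (eval.py, export.py, generate.py) each inline the same download-on-demand guard ahead of their existing setup. A condensed sketch of that calling pattern; `ensure_model` is a hypothetical name used only for this illustration, since the real files inline the two lines directly in `main`:

from download import download_and_convert, is_model_downloaded

def ensure_model(args) -> None:
    # Hypothetical helper illustrating the guard added to each entry point.
    # `args.model` is None when the user points at a local checkpoint instead
    # of naming a model, in which case there is nothing to fetch.
    if args.model and not is_model_downloaded(args.model, args.model_directory):
        download_and_convert(args.model, args.model_directory, args.hf_token)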
eval.py: 5 changes (5 additions, 0 deletions)
@@ -20,6 +20,7 @@
 
 from build.model import Transformer
 from cli import add_arguments_for_eval, arg_init
+from download import download_and_convert, is_model_downloaded
 from generate import encode_tokens, model_forward
 
 from quantize import set_precision
@@ -221,6 +222,10 @@ def main(args) -> None:
     """
 
+    # If a named model was provided and not downloaded, download it.
+    if args.model and not is_model_downloaded(args.model, args.model_directory):
+        download_and_convert(args.model, args.model_directory, args.hf_token)
+
     builder_args = BuilderArgs.from_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
     quantize = args.quantize
export.py: 5 changes (5 additions, 0 deletions)
@@ -11,6 +11,7 @@
 
 from build.builder import _initialize_model, BuilderArgs
 from cli import add_arguments_for_export, arg_init, check_args
+from download import download_and_convert, is_model_downloaded
 from export_aoti import export_model as export_model_aoti
 
 from quantize import set_precision
@@ -36,6 +37,10 @@ def device_sync(device):
 
 
 def main(args):
+    # If a named model was provided and not downloaded, download it.
+    if args.model and not is_model_downloaded(args.model, args.model_directory):
+        download_and_convert(args.model, args.model_directory, args.hf_token)
+
     builder_args = BuilderArgs.from_args(args)
     quantize = args.quantize
generate.py: 11 changes (11 additions, 0 deletions)
@@ -25,6 +25,7 @@
 )
 from build.model import Transformer
 from cli import add_arguments_for_generate, arg_init, check_args
+from download import download_and_convert, is_model_downloaded
 from quantize import set_precision
 
@@ -402,6 +403,11 @@ def _main(
         device_sync(device=builder_args.device)
         if i >= 0 and generator_args.chat_mode:
             prompt = input("What is your prompt? ")
+
+            # DEBUG DO NOT COMMIT
+            B_INST = ""
+            E_INST = ""
+
             if chat_mode:
                 prompt = f"{B_INST} {prompt.strip()} {E_INST}"
             encoded = encode_tokens(
@@ -487,6 +493,11 @@ def callback(x):
 
 def main(args):
     is_chat = args.subcommand == "chat"
+
+    # If a named model was provided and not downloaded, download it.
+    if args.model and not is_model_downloaded(args.model, args.model_directory):
+        download_and_convert(args.model, args.model_directory, args.hf_token)
+
     builder_args = BuilderArgs.from_args(args)
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
torchchat.py: 2 changes (1 addition, 1 deletion)
@@ -63,4 +63,4 @@
 
         export_main(args)
     else:
-        raise RuntimeError("Must specify valid subcommands: generate, export, eval")
+        raise RuntimeError("Must specify a valid subcommand: download, chat, generate, export, or eval.")
