
Implement download subcommand, optional positional model name argument #234

Merged: 14 commits merged into pytorch:main on Apr 19, 2024

Conversation

@GregoryComer (Member) commented on Apr 17, 2024:

Implements a download subcommand to download and convert a model from HuggingFace. Also adds an optional positional argument to the other torchchat subcommands to use a downloaded model. The model name can be either a known HF path, such as meta-llama/Llama-2-7b-chat-hf, or an alias, such as llama2. Per-model configuration, including the download channel and model aliases, is under config/models.json.
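For context, a minimal sketch of how the alias lookup could work, assuming an illustrative schema for config/models.json (field names such as "aliases" are assumptions for illustration, not the actual file format):

```python
import json
from pathlib import Path

# Hypothetical sketch: resolve a user-supplied model name or alias to a
# canonical HF path using config/models.json. Model names are matched
# case-insensitively, per this PR.
def resolve_model_name(name: str, config_path: Path = Path("config/models.json")) -> str:
    with open(config_path) as f:
        configs = json.load(f)
    name = name.lower()
    for hf_path, cfg in configs.items():
        if name == hf_path.lower() or name in (a.lower() for a in cfg.get("aliases", [])):
            return hf_path
    raise ValueError(f"unknown model name or alias: {name}")
```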

Example usage:

python torchchat.py generate llama2

# Can also explicitly download via
python torchchat.py download llama2

As a follow-up, I intend to refactor the CLI model positional-arg handling. It might also be nice to intelligently handle multiple file types with the positional arg, such as a GGUF file.

Test Plan:

python torchchat.py generate llama2

rm -rf .model-artifacts
python torchchat.py download llama2
python torchchat.py generate llama2

python torchchat.py generate meta-llama/Llama-2-7b-chat-hf
python torchchat.py generate llama2 --dtype fp16 --device cuda

python torchchat.py generate --checkpoint-path=.model-artifacts/meta-llama/Llama-2-7b-chat-hf/model.pth --dtype fp16 --device cuda

python torchchat.py generate stories15M
python torchchat.py generate stories110M
python torchchat.py generate mistral-7b-instruct

CI for the model options is covered here:
--gguf-path:

- name: Run GGUF export + inference

--dso-path:
python3 -W ignore generate.py --dtype ${DTYPE} --checkpoint-path "$CHECKPOINT_PATH" --dso-path "$MODEL_DIR/${MODEL_NAME}.so" --prompt "$PROMPT" --device "$TARGET_DEVICE" > "$MODEL_DIR/output_aoti" || exit 1

--pte-path:
python3 -W ignore export.py --checkpoint-path "$CHECKPOINT_PATH" --output-pte-path "$MODEL_DIR/${MODEL_NAME}.pte" -d "fp32" || exit 1

Since there are many ways to load a model, I'm relying on CI to exercise many of the paths.

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Apr 17, 2024
@GregoryComer force-pushed the streamline-download branch 3 times, most recently from 9d595bd to 7a60392, on April 17, 2024 10:31
@GregoryComer changed the title from "Implement download subcommand (WIP)" to "Implement download subcommand, positional model name argument (WIP)" on Apr 17, 2024
@mergennachin (Contributor) commented:

Make sure to run the linter.

Setup:

pip install -r requirements-lintrunner.txt
lintrunner init

lintrunner -a --all-files

@GregoryComer changed the title from "Implement download subcommand, positional model name argument (WIP)" to "Implement download subcommand, positional model name argument" on Apr 17, 2024
@GregoryComer marked this pull request as ready for review on April 17, 2024 19:46
@GregoryComer changed the title from "Implement download subcommand, positional model name argument" to "Implement download subcommand, optional positional model name argument" on Apr 17, 2024
@mergennachin (Contributor) commented:

Can you add a CI test to exercise the download path?

@GregoryComer force-pushed the streamline-download branch 5 times, most recently from 4f56a24 to 29516d5, on April 17, 2024 22:09
@GregoryComer (Member, Author) commented on Apr 17, 2024:

> Can you add a CI test to exercise the download path?

I'm going to actually defer this because converting some of the larger models takes over an hour. We do need CI coverage, but I might need to experiment with runner size and the choice of model, and I want to land this to unblock others.

Tracking via T186104081.

@GregoryComer force-pushed the streamline-download branch 3 times, most recently from 88859fd to 3f6eb29, on April 17, 2024 23:37
@mikekgfb (Contributor) commented:

> Can you add a CI test to exercise the download path?
>
> I'm going to actually defer this because converting some of the larger models takes over an hour. We do need CI coverage, but I might need to experiment with runner size and the choice of model, and I want to land this to unblock others.
>
> Tracking via T186104081.

Sounds like you should download a GGUF file that's heavily quantized, and/or stories15M!

@GregoryComer (Member, Author) commented on Apr 18, 2024:

> Can you add a CI test to exercise the download path?
>
> I'm going to actually defer this because converting some of the larger models takes over an hour. We do need CI coverage, but I might need to experiment with runner size and the choice of model, and I want to land this to unblock others.
> Tracking via T186104081.
>
> Sounds like you should download a GGUF file that's heavily quantized, and/or stories15M!

GGUF has its own conversion logic. Stories is also a little special because it has a unique format, and I'll have to add special logic to handle the download. That being said, it would be nice to have, so I'll probably do that.

I want to look more into why it takes upwards of an hour to convert a 7B model on the runner, though. Seems like something is wrong. It shouldn't take that long to shuffle around the weights.

Edit:
I've done a bit more refactoring and added support for the stories models via the positional argument. There is a new ModelConfig class and dict that encapsulate the differences in download and conversion.
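Roughly, a sketch of the shape such a config could take (the field and enum names here are illustrative assumptions, not the exact implementation):

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, Sequence

# Hypothetical sketch of a per-model config entry. The distribution channel
# distinguishes HuggingFace snapshot downloads from direct-URL downloads
# (e.g. the stories checkpoints); names are assumptions for illustration.
class DistributionChannel(Enum):
    HUGGINGFACE_SNAPSHOT = "HuggingFaceSnapshot"
    DIRECT_DOWNLOAD = "DirectDownload"

@dataclass
class ModelConfig:
    name: str
    aliases: Sequence[str] = field(default_factory=list)
    channel: DistributionChannel = DistributionChannel.HUGGINGFACE_SNAPSHOT
    distribution_path: str = ""

# Known models keyed by canonical name; loaded from config/models.json in practice.
MODEL_CONFIGS: Dict[str, ModelConfig] = {
    "stories15M": ModelConfig(
        name="stories15M",
        aliases=["stories15m"],
        channel=DistributionChannel.DIRECT_DOWNLOAD,
        distribution_path="https://example.com/stories15M.pt",  # placeholder URL
    ),
}
```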

@GregoryComer force-pushed the streamline-download branch 6 times, most recently from 8fbc926 to da23171, on April 18, 2024 10:28
@byjlw (Contributor) left a comment:

Resolve Michael's issue, merge in the changes to support Llama3, and then merge.

cli.py (outdated):
"--checkpoint-dir",
type=Path,
default=None,
help="Model checkpoint directory.",
Contributor:

not sure what you mean by developer-only option

.github/workflows/pull.yml (resolved)
cli.py (outdated), comment on lines 101 to 106:
parser.add_argument(
"--gguf-path",
type=Path,
default=None,
help="GGUF file path.",
)
Contributor:

I see that currently, specifying this together with a DSO or PTE path is only a warning; IMO we should hard-error, because it's easily fixed and a great way to waste a lot of time.

Member (Author):

Agreed, though we should probably take this as a follow-up.
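For what it's worth, a minimal sketch of what that hard error could look like at argument-validation time (the function name and attribute access are assumptions that mirror the CLI flags discussed above):

```python
import argparse

# Hypothetical sketch: fail fast when --gguf-path is combined with an
# exported artifact path, instead of only emitting a warning.
def validate_model_source(args: argparse.Namespace) -> None:
    exported = [p for p in (getattr(args, "dso_path", None), getattr(args, "pte_path", None)) if p]
    if getattr(args, "gguf_path", None) and exported:
        raise RuntimeError("--gguf-path cannot be combined with --dso-path or --pte-path")
```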

Comment on lines +30 to +31
if model_dir is None:
model_dir = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf")
Contributor:

why is the default something that's not even in models.json?

Member (Author):

@mikekgfb Do we need this default value anymore?

Contributor:

no, but we should put this or a similarly situated chat model into the models.json.

BTW, I really think it's bad to have even the model name default to something (unless we're so excited about llama3 that we make it that.... but that will require users to have obtained a token)

Comment on lines +80 to +81
if model in model_aliases:
model = model_aliases[model]
Contributor:

nit: model = model_aliases.get(model, model) is shorter FWIW

Comment on lines +63 to +64
print(f"Downloading {url}...")
urllib.request.urlretrieve(url, str(local_path.absolute()))
Contributor:

would be nice to use progressbar or tqdm to show a progress bar since these downloads can be big; can leave for follow-up

Member (Author):

I was thinking about this for checkpoint conversion, as well. My only concern was an additional dependency, but if that's not a worry, I can go ahead and add it.

Contributor:

I think that makes sense. Let's make sure we lazily import it; maybe I don't want to wait for the import if I'm not downloading/converting.
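As one possible follow-up, a sketch of a lazily imported tqdm progress hook wrapped around the existing urllib.request.urlretrieve call (the helper name is hypothetical):

```python
import urllib.request
from pathlib import Path

# Hypothetical sketch: show a progress bar for large downloads while only
# importing tqdm when a download actually happens.
def download_with_progress(url: str, local_path: Path) -> None:
    from tqdm import tqdm  # lazy import: skipped entirely if nothing is downloaded

    print(f"Downloading {url}...")
    with tqdm(unit="B", unit_scale=True, desc=url.split("/")[-1]) as bar:
        def reporthook(block_num: int, block_size: int, total_size: int) -> None:
            if total_size > 0:
                bar.total = total_size
            bar.update(block_num * block_size - bar.n)

        urllib.request.urlretrieve(url, str(local_path.absolute()), reporthook=reporthook)
```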

@GregoryComer merged commit f08eb05 into pytorch:main on Apr 19, 2024
19 checks passed
@mikekgfb (Contributor) left a comment:

Please review and address the comments (either explain why we should do something else, or do as suggested; either works, but we should document why we chose what we chose).

@@ -134,9 +144,12 @@ def from_args(cls, args): # -> TokenizerArgs:

if args.tokenizer_path:
tokenizer_path = args.tokenizer_path
elif args.model: # Using a named, well-known model
model_config = resolve_model_config(args.model)
tokenizer_path = Path(args.model_directory) / model_config.name / "tokenizer.model"
Contributor:

Well-known doesn't mean it's local. How do you know where the tokenizer is?


@@ -546,8 +550,6 @@ def callback(x):


def main(args):
is_chat = args.subcommand == "chat"

# If a named model was provided and not downloaded, download it.
if args.model and not is_model_downloaded(args.model, args.model_directory):
Contributor:

You should intercept this in a central place, like in cli(), because all subcommands basically need to do the same thing; otherwise we duplicate it in a gazillion places.
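A sketch of what that centralized check might look like right after argument parsing (the hook point and download_model() are placeholders, not the actual code; is_model_downloaded() is the existing helper shown above):

```python
from argparse import Namespace

# Hypothetical sketch: run the "download if missing" check once, right after
# argument parsing, instead of repeating it in every subcommand's main().
def ensure_model_downloaded(args: Namespace) -> None:
    if getattr(args, "model", None) and not is_model_downloaded(args.model, args.model_directory):
        download_model(args)  # placeholder for the real download entry point
```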

@mikekgfb deleted the streamline-download branch on April 19, 2024 21:34
malfet pushed a commit that referenced this pull request on Jul 17, 2024 (#234)

* Implement download option

* Add support for model aliases

* Support model name as a positional parameter

* Merge GenerateArgs changes

* Run lint

* Revert chat subcommand/arg changes

* Add mistral-7b-instruct alias, fix lints

* Add model config for known models

* Move known model config to config/models.json

* Make model names case-insensitive

* Move known model configuration from build/model.py to config/model_config.py

* Fix lints

* Fixing issues after rebasing

* Update README