Skip to content

Commit

Permalink
Revert chat subcommand/arg changes
Browse files Browse the repository at this point in the history
  • Loading branch information
GregoryComer committed Apr 17, 2024
1 parent 42e69e7 commit 81b74a5
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 12 deletions.
8 changes: 6 additions & 2 deletions cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@
def check_args(args, command_name: str):
global strict

# chat and generate support the same options
if command_name in ["chat", "generate", "gui"]:
if command_name in ["generate", "gui"]:
# examples, can add more. Note that attributes convert dash to _
disallowed_args = ["output_pte_path", "output_dso_path"]
elif command_name == "export":
Expand Down Expand Up @@ -223,6 +222,11 @@ def _add_arguments_common(parser):
default=".model-artifacts",
help="The directory to store downloaded model artifacts",
)
parser.add_argument(
"--chat",
action="store_true",
help="Use torchchat to for an interactive chat session.",
)


def arg_init(args):
Expand Down
4 changes: 0 additions & 4 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,8 @@ class GeneratorArgs:
def from_args(cls, args): # -> GeneratorArgs:
return cls(
prompt=args.prompt,
<<<<<<< HEAD
encoded_prompt=None,
chat_mode=args.chat,
=======
chat_mode=hasattr(args, "subcommand") and args.subcommand == "chat",
>>>>>>> 0dc48b0 (Merge GenerateArgs changes)
gui_mode=args.gui,
num_samples=args.num_samples,
max_new_tokens=args.max_new_tokens,
Expand Down
9 changes: 3 additions & 6 deletions torchchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,9 @@
parser = argparse.ArgumentParser(description="Top-level command")
subparsers = parser.add_subparsers(
dest="subcommand",
help="Use `chat`, `download`, `generate`, `eval` or `export` followed by subcommand specific options.",
help="Use `download`, `generate`, `eval` or `export` followed by subcommand specific options.",
)

parser_chat = subparsers.add_parser("chat")
add_arguments_for_generate(parser_chat)

parser_download = subparsers.add_parser("download")
add_arguments_for_download(parser_download)

Expand All @@ -48,8 +45,8 @@
from download import main as download_main

download_main(args)
elif args.subcommand == "generate" or args.subcommand == "chat":
check_args(args, args.subcommand)
elif args.subcommand == "generate":
check_args(args, "generate")
from generate import main as generate_main

generate_main(args)
Expand Down

0 comments on commit 81b74a5

Please sign in to comment.