Skip to content

Commit

Permalink
Merge GenerateArgs changes
Browse files Browse the repository at this point in the history
  • Loading branch information
GregoryComer committed Apr 17, 2024
1 parent c7ecf7c commit 7a60392
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class GeneratorArgs:
def from_args(cls, args): # -> GeneratorArgs:
return cls(
prompt=args.prompt,
chat_mode=args.subcommand == "chat",
chat_mode=hasattr(args, "subcommand") and args.subcommand == "chat",
gui_mode=args.gui,
num_samples=args.num_samples,
max_new_tokens=args.max_new_tokens,
Expand Down Expand Up @@ -404,12 +404,11 @@ def _main(
if i >= 0 and generator_args.chat_mode:
prompt = input("What is your prompt? ")

# DEBUG DO NOT COMMIT
B_INST = ""
E_INST = ""
if generator_args.chat_mode:
# TODO Where should B_INST, E_INST come from? Model args?
#prompt = f"{B_INST} {prompt.strip()} {E_INST}"
prompt = prompt.strip()

if chat_mode:
prompt = f"{B_INST} {prompt.strip()} {E_INST}"
encoded = encode_tokens(
tokenizer, prompt, bos=True, device=builder_args.device
)
Expand Down Expand Up @@ -492,8 +491,6 @@ def callback(x):


def main(args):
is_chat = args.subcommand == "chat"

# If a named model was provided and not downloaded, download it.
if args.model and not is_model_downloaded(args.model, args.model_directory):
download_and_convert(args.model, args.model_directory, args.hf_token)
Expand Down

0 comments on commit 7a60392

Please sign in to comment.