diff --git a/generate.py b/generate.py index 7183213e7..c1c02b448 100644 --- a/generate.py +++ b/generate.py @@ -46,7 +46,7 @@ class GeneratorArgs: def from_args(cls, args): # -> GeneratorArgs: return cls( prompt=args.prompt, - chat_mode=args.subcommand == "chat", + chat_mode=hasattr(args, "subcommand") and args.subcommand == "chat", gui_mode=args.gui, num_samples=args.num_samples, max_new_tokens=args.max_new_tokens, @@ -404,12 +404,11 @@ def _main( if i >= 0 and generator_args.chat_mode: prompt = input("What is your prompt? ") - # DEBUG DO NOT COMMIT - B_INST = "" - E_INST = "" + if generator_args.chat_mode: + # TODO Where should B_INST, E_INST come from? Model args? + #prompt = f"{B_INST} {prompt.strip()} {E_INST}" + prompt = prompt.strip() - if chat_mode: - prompt = f"{B_INST} {prompt.strip()} {E_INST}" encoded = encode_tokens( tokenizer, prompt, bos=True, device=builder_args.device ) @@ -492,8 +491,6 @@ def callback(x): def main(args): - is_chat = args.subcommand == "chat" - # If a named model was provided and not downloaded, download it. if args.model and not is_model_downloaded(args.model, args.model_directory): download_and_convert(args.model, args.model_directory, args.hf_token)