Skip to content

Commit

Permalink
allow multiline prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed Apr 12, 2024
1 parent ca07e5e commit 2a7c397
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions litgpt/chat/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@
from litgpt.utils import CLI, check_valid_checkpoint_dir, get_default_supported_precision, load_checkpoint


def read_multiline_input(prompt):
    """Read a multi-line prompt from stdin.

    Lines are collected until the user submits an empty (or whitespace-only)
    line, which acts as the terminator and is not included in the result.

    Args:
        prompt: Text displayed once before reading begins (no newline appended).

    Returns:
        The entered lines joined with ``"\\n"``; an empty string if the first
        line is already blank.
    """
    # flush=True: without it, a block-buffered stdout (e.g. piped output) may
    # never show the prompt before input() blocks waiting for the user.
    print(prompt, end="", flush=True)
    input_lines = []
    while True:
        line = input()
        # A blank line terminates input; it is not part of the prompt text.
        if line.strip() == "":
            break
        input_lines.append(line)
    return "\n".join(input_lines)


@torch.inference_mode()
def generate(
model: GPT,
Expand Down Expand Up @@ -163,14 +174,18 @@ def main(
)
stop_tokens = prompt_style.stop_tokens(tokenizer)

print(f"Now chatting with {config.name}.\nTo exit, press 'Enter' on an empty prompt.\n")
print(
f"Now chatting with {config.name}.\nTo exit, type 'exit'"
" or 'quit' and press 'Enter'.\nAfter entering your prompt,"
" hit 'Enter' once to start writing on a new line. Hit 'Enter'"
" twice to submit the prompt.")
L.seed_everything(1234)
while True:
try:
prompt = input(">> Prompt: ")
prompt = read_multiline_input(">> Prompt: ")
except KeyboardInterrupt:
break
if prompt.lower().strip() in ("", "quit", "exit"):
if prompt.lower().strip() in ("quit", "exit"):
break
prompt = prompt_style.apply(prompt=prompt)
encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
Expand Down

0 comments on commit 2a7c397

Please sign in to comment.