Skip to content

Commit

Permalink
Support resuming after SIGINT
Browse files Browse the repository at this point in the history
  • Loading branch information
jspahrsummers committed Aug 19, 2024
1 parent fcdab94 commit 7f68621
Showing 1 changed file with 39 additions and 6 deletions.
45 changes: 39 additions & 6 deletions script/claude.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
#!.venv/bin/python

from collections.abc import Callable
import itertools
import sys
from signal import signal, SIGINT
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable

Expand Down Expand Up @@ -77,6 +80,27 @@ def load_context_from_file(path: Path) -> str:
def load_context_from_paths(paths: Iterable[Path]) -> str:
return "\n".join(load_context_from_file(path) for path in paths)

@contextmanager
def catch_interrupts(should_continue_fn: Callable[[], bool]):
in_handler = False
def handler(signum, frame):
nonlocal in_handler
if in_handler:
raise KeyboardInterrupt

in_handler = True
should_continue = should_continue_fn()
in_handler = False

if not should_continue:
raise KeyboardInterrupt

prev = signal(SIGINT, handler)
try:
yield
finally:
signal(SIGINT, prev)

def sample(
messages: list[MessageParam], append_to_system_prompt: str | None = None
) -> PromptCachingBetaMessage:
Expand Down Expand Up @@ -108,7 +132,20 @@ def sample(
}
]

with client.beta.prompt_caching.messages.stream(
current_tool: str | None = None
status: Status | None = None

def check_continue() -> bool:
if status:
status.stop()

result = Confirm.ask("\n\nInterrupted. Do you want to continue?", default=True)
if result and status:
status.start()

return result

with catch_interrupts(check_continue), client.beta.prompt_caching.messages.stream(
model=MODEL,
max_tokens=8192,
system=[system_block],
Expand All @@ -118,9 +155,6 @@ def sample(
"anthropic-beta": "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15"
},
) as stream:
current_tool: str | None = None
status: Status | None = None

try:
for event in stream:
match event.type:
Expand Down Expand Up @@ -274,8 +308,7 @@ def handle_command(command: str) -> None:
messages.append({"role": assistant_message.role, "content": assistant_message.content})
console.print()
except KeyboardInterrupt:
console.print("\n\nInterrupted. Discarding last turn.\n", style="info")
# TODO: Offer resumption
console.print("\n\nDiscarding last turn.\n", style="info")


if __name__ == "__main__":
Expand Down

0 comments on commit 7f68621

Please sign in to comment.