diff --git a/.zed/settings.json b/.zed/settings.json new file mode 100644 index 00000000..0c6f3718 --- /dev/null +++ b/.zed/settings.json @@ -0,0 +1,15 @@ +// Folder-specific settings +// +// For a full list of overridable settings, and general information on folder-specific settings, +// see the documentation: https://zed.dev/docs/configuring-zed#settings-files +{ + "lsp": { + "pyright": { + "settings": { + "python": { + "pythonPath": ".venv/bin/python" + } + } + } + } +} diff --git a/script/claude.py b/script/claude.py index b28d6d98..41d9c13f 100755 --- a/script/claude.py +++ b/script/claude.py @@ -1,7 +1,10 @@ #!.venv/bin/python +from collections.abc import Callable import itertools import sys +from signal import signal, SIGINT +from contextlib import contextmanager from pathlib import Path from typing import Iterable @@ -9,14 +12,21 @@ from anthropic.types.beta.prompt_caching import ( PromptCachingBetaMessageParam as MessageParam, ) -from anthropic.types.beta.prompt_caching import PromptCachingBetaTextBlockParam +from anthropic.types.beta.prompt_caching import ( + PromptCachingBetaTextBlockParam, + PromptCachingBetaToolParam, +) +from anthropic.types.beta.prompt_caching.prompt_caching_beta_message import PromptCachingBetaMessage +from anthropic.types.beta.prompt_caching.prompt_caching_beta_message_param import PromptCachingBetaMessageParam from dotenv import load_dotenv from prompt_toolkit import PromptSession from prompt_toolkit.history import FileHistory from rich.console import Console from rich.prompt import Confirm from rich.syntax import Syntax +from rich.status import Status from rich.theme import Theme +from typing import cast load_dotenv() @@ -56,16 +66,9 @@ - There can only be one trading market for a whole star system. All planets within the system share the same trading market (meaning the same list of commodities and prices). - Planetary descriptions should be interesting and captivating, but limit them to approximately five or so paragraphs. -You have the ability to create new files or update existing files in-place. To write to files, include tags somewhere in your response, like this: - - -... file contents omitted for brevity ... - - -IMPORTANT: -- Wait to write code until specifically asked to do so. -- When writing files, ALWAYS include the FULL rewritten content, including ALL unchanged lines as well.""" +You have a tool to replace the contents of a file, but wait to write code until specifically asked to do so.""" +WRITE_FILE_TOOL = "write_file" def load_context_from_file(path: Path) -> str: contents = path.read_text() @@ -74,14 +77,33 @@ def load_context_from_file(path: Path) -> str: {contents} """ - def load_context_from_paths(paths: Iterable[Path]) -> str: return "\n".join(load_context_from_file(path) for path in paths) +@contextmanager +def catch_interrupts(should_continue_fn: Callable[[], bool]): + in_handler = False + def handler(signum, frame): + nonlocal in_handler + if in_handler: + raise KeyboardInterrupt + + in_handler = True + should_continue = should_continue_fn() + in_handler = False + + if not should_continue: + raise KeyboardInterrupt + + prev = signal(SIGINT, handler) + try: + yield + finally: + signal(SIGINT, prev) def sample( messages: list[MessageParam], append_to_system_prompt: str | None = None -) -> str: +) -> PromptCachingBetaMessage: system_prompt = SYSTEM_PROMPT if append_to_system_prompt is not None: system_prompt += f"\n\n{append_to_system_prompt}" @@ -92,95 +114,99 @@ def sample( "cache_control": {"type": "ephemeral"}, } - with client.beta.prompt_caching.messages.stream( + tools: list[PromptCachingBetaToolParam] = [ + { + "name": WRITE_FILE_TOOL, + "description": "Replace the contents of a file.", + "input_schema": { + "type": "object", + "properties": { + "path": {"type": "string", "description": "File path to write to"}, + "content": { + "type": "string", + "description": "The full, new content for the file.", + }, + }, + "required": ["path", "content"], + }, + } + ] + + current_tool: str | None = None + status: Status | None = None + + def check_continue() -> bool: + if status: + status.stop() + + result = Confirm.ask("\n\nInterrupted. Do you want to continue?", default=True) + if result and status: + status.start() + + return result + + with catch_interrupts(check_continue), client.beta.prompt_caching.messages.stream( model=MODEL, - max_tokens=4096, + max_tokens=8192, system=[system_block], messages=messages, - stop_sequences=[''], - ) - assert path_message.content[0].type == "text" - file_path = path_message.content[0].text - assistant_turn += f'{file_path}">' - - contents_message = client.beta.prompt_caching.messages.create( - model=MODEL, - max_tokens=4096, - system=[system_block], - messages=[ - *messages, - {"role": "assistant", "content": assistant_turn}, - ], - stop_sequences=[""], - ) - - console.print(contents_message.usage.to_json(indent=None), style="info") - if contents_message.stop_reason == "max_tokens": - console.print("\nReached max tokens.\n", style="info") - - assert contents_message.content[0].type == "text" - file_contents = contents_message.content[0].text - assistant_turn += f"{file_contents}" - - syntax = Syntax( - code=file_contents, - lexer=Syntax.guess_lexer(path=file_path, code=file_contents), - theme="ansi_light", - ) - console.print(syntax) - if Confirm.ask(f"\nWrite to file {file_path}?", default=False): - Path(file_path).write_text(file_contents) - - console.print() - - messages = [*messages, {"role": "assistant", "content": assistant_turn}] - remaining_text = sample( - messages=messages, append_to_system_prompt=append_to_system_prompt - ) - return assistant_turn + remaining_text + if message.stop_reason == "max_tokens": + console.print("\nReached max tokens.\n", style="info") - case "end_turn": - pass - - case "max_tokens": - console.print("\nReached max tokens.\n", style="info") - - case other: - console.print("\nUnexpected stop reason:", other, style="error") - - return message.content[0].text + return message def main() -> None: @@ -277,13 +303,14 @@ def handle_command(command: str) -> None: messages + [user_message], append_to_system_prompt=f"Use these files from the project to help with your response:\n{context}", ) - messages += [ - user_message, - {"role": "assistant", "content": assistant_message}, - ] + + messages.append(user_message) + messages.append({"role": assistant_message.role, "content": assistant_message.content}) console.print() except KeyboardInterrupt: - console.print("\n\nInterrupted. Discarding last turn.\n", style="info") + console.print("\n\nDiscarding last turn.\n", style="info") + except Exception: + console.print_exception() if __name__ == "__main__":