diff --git a/.zed/settings.json b/.zed/settings.json
new file mode 100644
index 00000000..0c6f3718
--- /dev/null
+++ b/.zed/settings.json
@@ -0,0 +1,15 @@
+// Folder-specific settings
+//
+// For a full list of overridable settings, and general information on folder-specific settings,
+// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
+{
+ "lsp": {
+ "pyright": {
+ "settings": {
+ "python": {
+ "pythonPath": ".venv/bin/python"
+ }
+ }
+ }
+ }
+}
diff --git a/script/claude.py b/script/claude.py
index b28d6d98..41d9c13f 100755
--- a/script/claude.py
+++ b/script/claude.py
@@ -1,7 +1,10 @@
#!.venv/bin/python
+from collections.abc import Callable
import itertools
import sys
+from signal import signal, SIGINT
+from contextlib import contextmanager
from pathlib import Path
from typing import Iterable
@@ -9,14 +12,21 @@
from anthropic.types.beta.prompt_caching import (
PromptCachingBetaMessageParam as MessageParam,
)
-from anthropic.types.beta.prompt_caching import PromptCachingBetaTextBlockParam
+from anthropic.types.beta.prompt_caching import (
+ PromptCachingBetaTextBlockParam,
+ PromptCachingBetaToolParam,
+)
+from anthropic.types.beta.prompt_caching.prompt_caching_beta_message import PromptCachingBetaMessage
+from anthropic.types.beta.prompt_caching.prompt_caching_beta_message_param import PromptCachingBetaMessageParam
from dotenv import load_dotenv
from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
from rich.console import Console
from rich.prompt import Confirm
from rich.syntax import Syntax
+from rich.status import Status
from rich.theme import Theme
+from typing import cast
load_dotenv()
@@ -56,16 +66,9 @@
- There can only be one trading market for a whole star system. All planets within the system share the same trading market (meaning the same list of commodities and prices).
- Planetary descriptions should be interesting and captivating, but limit them to approximately five or so paragraphs.
-You have the ability to create new files or update existing files in-place. To write to files, include tags somewhere in your response, like this:
-
-
-... file contents omitted for brevity ...
-
-
-IMPORTANT:
-- Wait to write code until specifically asked to do so.
-- When writing files, ALWAYS include the FULL rewritten content, including ALL unchanged lines as well."""
+You have a tool to replace the contents of a file, but wait to write code until specifically asked to do so."""
+WRITE_FILE_TOOL = "write_file"
def load_context_from_file(path: Path) -> str:
contents = path.read_text()
@@ -74,14 +77,33 @@ def load_context_from_file(path: Path) -> str:
{contents}
"""
-
def load_context_from_paths(paths: Iterable[Path]) -> str:
return "\n".join(load_context_from_file(path) for path in paths)
+@contextmanager
+def catch_interrupts(should_continue_fn: Callable[[], bool]):
+ in_handler = False
+ def handler(signum, frame):
+ nonlocal in_handler
+ if in_handler:
+ raise KeyboardInterrupt
+
+ in_handler = True
+ should_continue = should_continue_fn()
+ in_handler = False
+
+ if not should_continue:
+ raise KeyboardInterrupt
+
+ prev = signal(SIGINT, handler)
+ try:
+ yield
+ finally:
+ signal(SIGINT, prev)
def sample(
messages: list[MessageParam], append_to_system_prompt: str | None = None
-) -> str:
+) -> PromptCachingBetaMessage:
system_prompt = SYSTEM_PROMPT
if append_to_system_prompt is not None:
system_prompt += f"\n\n{append_to_system_prompt}"
@@ -92,95 +114,99 @@ def sample(
"cache_control": {"type": "ephemeral"},
}
- with client.beta.prompt_caching.messages.stream(
+ tools: list[PromptCachingBetaToolParam] = [
+ {
+ "name": WRITE_FILE_TOOL,
+ "description": "Replace the contents of a file.",
+ "input_schema": {
+ "type": "object",
+ "properties": {
+ "path": {"type": "string", "description": "File path to write to"},
+ "content": {
+ "type": "string",
+ "description": "The full, new content for the file.",
+ },
+ },
+ "required": ["path", "content"],
+ },
+ }
+ ]
+
+ current_tool: str | None = None
+ status: Status | None = None
+
+ def check_continue() -> bool:
+ if status:
+ status.stop()
+
+ result = Confirm.ask("\n\nInterrupted. Do you want to continue?", default=True)
+ if result and status:
+ status.start()
+
+ return result
+
+ with catch_interrupts(check_continue), client.beta.prompt_caching.messages.stream(
model=MODEL,
- max_tokens=4096,
+ max_tokens=8192,
system=[system_block],
messages=messages,
- stop_sequences=[''],
- )
- assert path_message.content[0].type == "text"
- file_path = path_message.content[0].text
- assistant_turn += f'{file_path}">'
-
- contents_message = client.beta.prompt_caching.messages.create(
- model=MODEL,
- max_tokens=4096,
- system=[system_block],
- messages=[
- *messages,
- {"role": "assistant", "content": assistant_turn},
- ],
- stop_sequences=[""],
- )
-
- console.print(contents_message.usage.to_json(indent=None), style="info")
- if contents_message.stop_reason == "max_tokens":
- console.print("\nReached max tokens.\n", style="info")
-
- assert contents_message.content[0].type == "text"
- file_contents = contents_message.content[0].text
- assistant_turn += f"{file_contents}"
-
- syntax = Syntax(
- code=file_contents,
- lexer=Syntax.guess_lexer(path=file_path, code=file_contents),
- theme="ansi_light",
- )
- console.print(syntax)
- if Confirm.ask(f"\nWrite to file {file_path}?", default=False):
- Path(file_path).write_text(file_contents)
-
- console.print()
-
- messages = [*messages, {"role": "assistant", "content": assistant_turn}]
- remaining_text = sample(
- messages=messages, append_to_system_prompt=append_to_system_prompt
- )
- return assistant_turn + remaining_text
+ if message.stop_reason == "max_tokens":
+ console.print("\nReached max tokens.\n", style="info")
- case "end_turn":
- pass
-
- case "max_tokens":
- console.print("\nReached max tokens.\n", style="info")
-
- case other:
- console.print("\nUnexpected stop reason:", other, style="error")
-
- return message.content[0].text
+ return message
def main() -> None:
@@ -277,13 +303,14 @@ def handle_command(command: str) -> None:
messages + [user_message],
append_to_system_prompt=f"Use these files from the project to help with your response:\n{context}",
)
- messages += [
- user_message,
- {"role": "assistant", "content": assistant_message},
- ]
+
+ messages.append(user_message)
+ messages.append({"role": assistant_message.role, "content": assistant_message.content})
console.print()
except KeyboardInterrupt:
- console.print("\n\nInterrupted. Discarding last turn.\n", style="info")
+ console.print("\n\nDiscarding last turn.\n", style="info")
+ except Exception:
+ console.print_exception()
if __name__ == "__main__":