
Commit

Merge pull request #85 from jspahrsummers/more-script-improvements
More script improvements
jspahrsummers authored Aug 19, 2024
2 parents 29602f7 + b4ed60a commit bafce27
Showing 2 changed files with 138 additions and 96 deletions.
15 changes: 15 additions & 0 deletions .zed/settings.json
@@ -0,0 +1,15 @@
// Folder-specific settings
//
// For a full list of overridable settings, and general information on folder-specific settings,
// see the documentation: https://zed.dev/docs/configuring-zed#settings-files
{
"lsp": {
"pyright": {
"settings": {
"python": {
"pythonPath": ".venv/bin/python"
}
}
}
}
}
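
For context, this setting points Zed's pyright language server at the project-local virtualenv, which matches the #!.venv/bin/python shebang in script/claude.py below. A minimal sketch to confirm that interpreter exists (the check itself is illustrative and not part of the commit):

# Illustrative only, not part of the commit: verifies that the interpreter
# the "pythonPath" setting above refers to is actually present.
from pathlib import Path

venv_python = Path(".venv/bin/python")
print(venv_python.resolve(), "exists:", venv_python.exists())
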
219 changes: 123 additions & 96 deletions script/claude.py
@@ -1,22 +1,32 @@
#!.venv/bin/python

from collections.abc import Callable
import itertools
import sys
from signal import signal, SIGINT
from contextlib import contextmanager
from pathlib import Path
from typing import Iterable

from anthropic import Anthropic
from anthropic.types.beta.prompt_caching import (
PromptCachingBetaMessageParam as MessageParam,
)
from anthropic.types.beta.prompt_caching import PromptCachingBetaTextBlockParam
from anthropic.types.beta.prompt_caching import (
PromptCachingBetaTextBlockParam,
PromptCachingBetaToolParam,
)
from anthropic.types.beta.prompt_caching.prompt_caching_beta_message import PromptCachingBetaMessage
from anthropic.types.beta.prompt_caching.prompt_caching_beta_message_param import PromptCachingBetaMessageParam
from dotenv import load_dotenv
from prompt_toolkit import PromptSession
from prompt_toolkit.history import FileHistory
from rich.console import Console
from rich.prompt import Confirm
from rich.syntax import Syntax
from rich.status import Status
from rich.theme import Theme
from typing import cast

load_dotenv()

@@ -56,16 +66,9 @@
- There can only be one trading market for a whole star system. All planets within the system share the same trading market (meaning the same list of commodities and prices).
- Planetary descriptions should be interesting and captivating, but limit them to approximately five or so paragraphs.
You have the ability to create new files or update existing files in-place. To write to files, include <file> tags somewhere in your response, like this:
<file path="path/to/write.gd">
... file contents omitted for brevity ...
</file>
IMPORTANT:
- Wait to write code until specifically asked to do so.
- When writing files, ALWAYS include the FULL rewritten content, including ALL unchanged lines as well."""
You have a tool to replace the contents of a file, but wait to write code until specifically asked to do so."""

WRITE_FILE_TOOL = "write_file"

def load_context_from_file(path: Path) -> str:
contents = path.read_text()
@@ -74,14 +77,33 @@ def load_context_from_file(path: Path) -> str:
{contents}
</file>"""


def load_context_from_paths(paths: Iterable[Path]) -> str:
return "\n".join(load_context_from_file(path) for path in paths)

@contextmanager
def catch_interrupts(should_continue_fn: Callable[[], bool]):
in_handler = False
def handler(signum, frame):
nonlocal in_handler
if in_handler:
raise KeyboardInterrupt

in_handler = True
should_continue = should_continue_fn()
in_handler = False

if not should_continue:
raise KeyboardInterrupt

prev = signal(SIGINT, handler)
try:
yield
finally:
signal(SIGINT, prev)

def sample(
messages: list[MessageParam], append_to_system_prompt: str | None = None
) -> str:
) -> PromptCachingBetaMessage:
system_prompt = SYSTEM_PROMPT
if append_to_system_prompt is not None:
system_prompt += f"\n\n{append_to_system_prompt}"
@@ -92,95 +114,99 @@ def sample(
"cache_control": {"type": "ephemeral"},
}

with client.beta.prompt_caching.messages.stream(
tools: list[PromptCachingBetaToolParam] = [
{
"name": WRITE_FILE_TOOL,
"description": "Replace the contents of a file.",
"input_schema": {
"type": "object",
"properties": {
"path": {"type": "string", "description": "File path to write to"},
"content": {
"type": "string",
"description": "The full, new content for the file.",
},
},
"required": ["path", "content"],
},
}
]

current_tool: str | None = None
status: Status | None = None

def check_continue() -> bool:
if status:
status.stop()

result = Confirm.ask("\n\nInterrupted. Do you want to continue?", default=True)
if result and status:
status.start()

return result

with catch_interrupts(check_continue), client.beta.prompt_caching.messages.stream(
model=MODEL,
max_tokens=4096,
max_tokens=8192,
system=[system_block],
messages=messages,
stop_sequences=['<file path="'],
tools=tools,
extra_headers={
"anthropic-beta": "prompt-caching-2024-07-31,max-tokens-3-5-sonnet-2024-07-15"
},
) as stream:
for text in stream.text_stream:
console.out(text, end="", style="assistant")
try:
for event in stream:
match event.type:
case "text":
console.out(event.text, end="", style="assistant")

case "content_block_start":
if event.content_block.type == "tool_use":
current_tool = event.content_block.name
status = Status(f"{current_tool}…")
status.start()
else:
current_tool = None

case "content_block_stop":
if status:
status.stop()
status = None

if event.content_block.type == "tool_use":
assert event.content_block.name == current_tool

# Type of this is wrong in the SDK, for some reason
input = cast(dict, event.content_block.input)
path: str = input["path"]
code: str = input["content"]

syntax = Syntax(
code=code,
lexer=Syntax.guess_lexer(
path=path, code=code
),
theme="ansi_light",
)
console.print(syntax)

if Confirm.ask(f"\nWrite to file {path}?", default=False):
Path(path).write_text(code)

current_tool = None
finally:
if status:
status.stop()

console.out()

message = stream.get_final_message()
console.print(message.usage.to_json(indent=None), style="info")
assert message.content[0].type == "text"

match message.stop_reason:
case "stop_sequence":
assert message.stop_sequence == '<file path="'

assistant_turn = f'{message.content[0].text}<file path="'
if messages[-1]["role"] == "assistant":
last_assistant_message = messages[-1]
messages = messages[:-1]
assistant_turn = (
f"{last_assistant_message["content"]}{assistant_turn}"
)

with console.status("Writing a file…"):
path_message = client.beta.prompt_caching.messages.create(
model=MODEL,
max_tokens=100,
system=[system_block],
messages=[
*messages,
{"role": "assistant", "content": assistant_turn},
],
stop_sequences=['">'],
)
assert path_message.content[0].type == "text"
file_path = path_message.content[0].text
assistant_turn += f'{file_path}">'

contents_message = client.beta.prompt_caching.messages.create(
model=MODEL,
max_tokens=4096,
system=[system_block],
messages=[
*messages,
{"role": "assistant", "content": assistant_turn},
],
stop_sequences=["</file>"],
)

console.print(contents_message.usage.to_json(indent=None), style="info")
if contents_message.stop_reason == "max_tokens":
console.print("\nReached max tokens.\n", style="info")

assert contents_message.content[0].type == "text"
file_contents = contents_message.content[0].text
assistant_turn += f"{file_contents}</file>"

syntax = Syntax(
code=file_contents,
lexer=Syntax.guess_lexer(path=file_path, code=file_contents),
theme="ansi_light",
)
console.print(syntax)
if Confirm.ask(f"\nWrite to file {file_path}?", default=False):
Path(file_path).write_text(file_contents)

console.print()

messages = [*messages, {"role": "assistant", "content": assistant_turn}]
remaining_text = sample(
messages=messages, append_to_system_prompt=append_to_system_prompt
)
return assistant_turn + remaining_text
if message.stop_reason == "max_tokens":
console.print("\nReached max tokens.\n", style="info")

case "end_turn":
pass

case "max_tokens":
console.print("\nReached max tokens.\n", style="info")

case other:
console.print("\nUnexpected stop reason:", other, style="error")

return message.content[0].text
return message


def main() -> None:
@@ -277,13 +303,14 @@ def handle_command(command: str) -> None:
messages + [user_message],
append_to_system_prompt=f"Use these files from the project to help with your response:\n{context}",
)
messages += [
user_message,
{"role": "assistant", "content": assistant_message},
]

messages.append(user_message)
messages.append({"role": assistant_message.role, "content": assistant_message.content})
console.print()
except KeyboardInterrupt:
console.print("\n\nInterrupted. Discarding last turn.\n", style="info")
console.print("\n\nDiscarding last turn.\n", style="info")
except Exception:
console.print_exception()


if __name__ == "__main__":
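
The heart of this change is the catch_interrupts context manager paired with the new write_file tool: pressing Ctrl-C while a response streams now asks whether to keep going instead of killing the turn outright. Below is a minimal, self-contained sketch of that interrupt pattern, assuming a plain input() prompt in place of rich's Confirm.ask and a hypothetical demo() loop standing in for the streaming call:

import time
from collections.abc import Callable
from contextlib import contextmanager
from signal import SIGINT, signal


@contextmanager
def catch_interrupts(should_continue_fn: Callable[[], bool]):
    # Re-entrant guard: a second Ctrl-C while the confirmation prompt is
    # open raises immediately instead of stacking prompts.
    in_handler = False

    def handler(signum, frame):
        nonlocal in_handler
        if in_handler:
            raise KeyboardInterrupt
        in_handler = True
        should_continue = should_continue_fn()
        in_handler = False
        if not should_continue:
            raise KeyboardInterrupt

    prev = signal(SIGINT, handler)
    try:
        yield
    finally:
        # Always restore the previous SIGINT handler on exit.
        signal(SIGINT, prev)


def demo() -> None:
    # Hypothetical long-running work standing in for the streamed API call.
    def ask() -> bool:
        return input("\nInterrupted. Continue? [y/N] ").strip().lower() == "y"

    with catch_interrupts(ask):
        for i in range(10):
            print("working", i)
            time.sleep(1)


if __name__ == "__main__":
    try:
        demo()
    except KeyboardInterrupt:
        print("\nAborted.")

In claude.py the confirmation callback is check_continue, which also stops and restarts the rich Status spinner around the prompt; the sketch drops that detail.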
