fix: include a summary of function-tools registered in the Python REPL
ErikBjare committed Mar 22, 2024
1 parent 2fcc245 commit 00a7bea
Showing 10 changed files with 173 additions and 13 deletions.
2 changes: 1 addition & 1 deletion eval/filestore.py
@@ -28,7 +28,7 @@ def download(self) -> Files:
             if path.is_file():
                 key = str(path.relative_to(self.working_dir))
                 try:
-                    with open(path, "r") as f:
+                    with open(path) as f:
                         files[key] = f.read()
                 except UnicodeDecodeError:
                     # file is binary
5 changes: 3 additions & 2 deletions eval/types.py
@@ -1,7 +1,8 @@
 from dataclasses import dataclass
-from typing import Callable, Dict, TypedDict
+from typing import TypedDict
+from collections.abc import Callable
 
-Files = Dict[str, str | bytes]
+Files = dict[str, str | bytes]
 
 
 @dataclass
5 changes: 4 additions & 1 deletion gptme/cli.py
@@ -192,6 +192,9 @@ def chat(
     """
     Run the chat loop.
     prompt_msgs: list of messages to execute in sequence.
+    initial_msgs: list of history messages.
+    Callable from other modules.
     """
     # init
@@ -233,7 +236,7 @@ def chat(
             codeblock = log.get_last_code_block("assistant", history=1, content=False)
             if not (codeblock and is_supported_codeblock(codeblock)):
                 logger.info("Non-interactive and exhausted prompts, exiting")
-                exit(0)
+                break
 
         # ask for input if no prompt, generate reply, and run tools
         for msg in step(log, no_confirm, model, stream=stream):  # pragma: no cover
5 changes: 4 additions & 1 deletion gptme/prompts.py
@@ -9,7 +9,7 @@
 
 from .config import get_config
 from .message import Message
-from .tools import browser, patch
+from .tools import browser, patch, python
 
 PromptType = Literal["full", "short"]
 
@@ -153,6 +153,9 @@ def prompt_tools() -> Generator[Message, None, None]:
 The following libraries are available:
 {python_libraries_str}
+The following functions are available in the REPL:
+{python.get_functions_prompt()}
 ## bash
 When you send a message containing bash code, it will be executed in a stateful bash shell.
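
With this change, the Python-tool section of the system prompt ends with a generated summary of registered functions. A rough sketch of how that rendered section might read once the subagent functions from this commit are registered (illustrative only; the actual entries depend on what is registered at runtime):

The following functions are available in the REPL:
- subagent(prompt: str, agent_id: str): Runs a subagent and returns the resulting JSON output.
- subagent_status(agent_id: str): Returns the status of a subagent.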
2 changes: 2 additions & 0 deletions gptme/tools/__init__.py
@@ -6,8 +6,10 @@
 from .python import execute_python, init_python
 from .save import execute_save
 from .shell import execute_shell
+from .subagent import noop
 from .summarize import summarize
 
+noop()  # just to make sure the import isn't automatically removed
 logger = logging.getLogger(__name__)
 
 
2 changes: 1 addition & 1 deletion gptme/tools/base.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Callable
+from collections.abc import Callable
 
 
 @dataclass
38 changes: 36 additions & 2 deletions gptme/tools/python.py
@@ -31,7 +31,12 @@
 import re
 from collections.abc import Generator
 from logging import getLogger
-from typing import Callable, TypeVar
+from typing import (
+    Literal,
+    TypeVar,
+    get_origin,
+)
+from collections.abc import Callable
 
 from IPython.terminal.embed import InteractiveShellEmbed
 from IPython.utils.capture import capture_output
@@ -49,7 +54,7 @@ def init_python():
     check_available_packages()
 
 
-registered_functions = {}
+registered_functions: dict[str, Callable] = {}
 
 T = TypeVar("T", bound=Callable)
 
@@ -67,6 +72,35 @@ def register_function_if(condition: bool):
     return lambda x: x
 
 
+def derive_type(t) -> str:
+    print(t, get_origin(t))
+    if get_origin(t) == Literal:
+        v = ", ".join(f'"{a}"' for a in t.__args__)
+        return f"Literal[{v}]"
+    else:
+        return t.__name__
+
+
+def callable_signature(func: Callable) -> str:
+    # returns a signature f(arg1: type1, arg2: type2, ...) -> return_type
+    args = ", ".join(
+        f"{k}: {derive_type(v)}"
+        for k, v in func.__annotations__.items()
+        if k != "return"
+    )
+    ret_type = func.__annotations__.get("return")
+    ret = f" -> {derive_type(ret_type)}" if ret_type else ""
+    return f"{func.__name__}({args}){ret}"
+
+
+def get_functions_prompt() -> str:
+    # return a prompt with a brief description of the available functions
+    return "\n".join(
+        f"- {callable_signature(func)}: {func.__doc__ or 'No description'}"
+        for func in registered_functions.values()
+    )
+
+
 def _get_ipython():
     global _ipython
     if _ipython is None:
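
A minimal sketch of how the new helpers compose, assuming the module's existing register_function decorator; the greet function below is hypothetical and only illustrates the output format:

from gptme.tools.python import register_function, callable_signature, get_functions_prompt

@register_function
def greet(name: str) -> str:
    """Say hello."""
    return f"Hello, {name}!"

# derive_type maps plain annotations to their names and Literal[...] to its values
sig = callable_signature(greet)     # "greet(name: str) -> str"
# one "- signature: docstring" line per registered function
summary = get_functions_prompt()    # includes: - greet(name: str) -> str: Say hello.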
3 changes: 1 addition & 2 deletions gptme/tools/shell.py
@@ -40,7 +40,6 @@
 import subprocess
 import sys
 from collections.abc import Generator
-from typing import List
 
 import bashlex
 
@@ -262,7 +261,7 @@ def _shorten_stdout(stdout: str, pre_lines=None, post_lines=None) -> str:
     return "\n".join(lines)
 
 
-def split_commands(script: str) -> List[str]:
+def split_commands(script: str) -> list[str]:
     # TODO: write proper tests
     parts = bashlex.parse(script)
     commands = []
100 changes: 98 additions & 2 deletions gptme/tools/subagent.py
@@ -4,10 +4,106 @@
 Lets gptme break down a task into smaller parts, and delegate them to subagents.
 """
 
+import json
+import threading
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Literal, TypedDict
 
-def subagent(prompt: str):
+from ..message import Message
+from .python import register_function
+
+if TYPE_CHECKING:
+    # noreorder
+    from ..logmanager import LogManager  # fmt: skip
+
+Status = Literal["running", "success", "failure"]
+
+_subagents = []
+
+
+class ReturnType(TypedDict):
+    description: str
+    result: Literal["success", "failure"]
+
+
+@dataclass
+class Subagent:
+    prompt: str
+    agent_id: str
+    thread: threading.Thread
+
+    def get_log(self) -> "LogManager":
+        # noreorder
+        from gptme.cli import get_logfile  # fmt: skip
+
+        from ..logmanager import LogManager  # fmt: skip
+
+        name = f"subagent-{self.agent_id}"
+        logfile = get_logfile(name, interactive=False)
+        return LogManager.load(logfile)
+
+    def status(self) -> tuple[Status, ReturnType | None]:
+        # check if the last message contains the return JSON
+        last_msg = self.get_log().log[-1]
+        if last_msg.content.startswith("{"):
+            print("Subagent has returned a JSON response:")
+            print(last_msg.content)
+            result = ReturnType(**json.loads(last_msg.content))  # type: ignore
+            return result["result"], result
+        else:
+            return "running", None
+
+
+@register_function
+def subagent(prompt: str, agent_id: str):
+    """Runs a subagent and returns the resulting JSON output."""
+    # noreorder
+    from gptme import chat  # fmt: skip
+    from ..prompts import get_prompt  # fmt: skip
+
+    name = f"subagent-{agent_id}"
+
+    def run_subagent():
+        prompt_msgs = [Message("user", prompt)]
+        initial_msgs = [get_prompt()]
+
+        # add the return prompt
+        return_prompt = """When done with the task, please return a JSON response of this format:
+{
+    description: 'A description of the task result',
+    result: 'success' | 'failure',
+}"""
+        initial_msgs[0].content += "\n\n" + return_prompt
+
+        chat(
+            prompt_msgs,
+            initial_msgs,
+            name=name,
+            llm="openai",
+            model="gpt-4-1106-preview",
+            stream=False,
+            no_confirm=True,
+            interactive=False,
+            show_hidden=False,
+        )
+
+    # start a thread with a subagent
+    t = threading.Thread(
+        target=run_subagent,
+        daemon=True,
+    )
+    t.start()
+    _subagents.append(Subagent(prompt, agent_id, t))
+
+
+@register_function
+def subagent_status(agent_id: str):
+    """Returns the status of a subagent."""
+    for subagent in _subagents:
+        if subagent.agent_id == agent_id:
+            return subagent.status()
+    raise ValueError(f"Subagent with ID {agent_id} not found.")
+
+
-chat("Hello! I am a subagent.")
+def noop():
+    pass
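
Roughly how these tools are meant to be used from the assistant's Python REPL (a sketch; the agent id and task are made up, and a result only appears once the subagent's log ends with the JSON return message described above):

# spawn a background agent for a sub-task; this runs chat() in a daemon thread
subagent("Write a small haiku-printing script", agent_id="haiku")

# poll later: ("running", None) until the final JSON message is present,
# then ("success" | "failure", {"description": ..., "result": ...})
status, result = subagent_status("haiku")
if result is not None:
    print(status, result["description"])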
24 changes: 23 additions & 1 deletion tests/test_tools_python.py
@@ -1,4 +1,6 @@
-from gptme.tools.python import execute_python
+from typing import Literal, TypeAlias
+
+from gptme.tools.python import callable_signature, execute_python
 
 
 def run(code):
@@ -13,3 +15,23 @@ def test_execute_python():
     # test that vars are preserved between executions
     assert run("a = 2")
     assert "2\n" in run("print(a)")
+
+
+TestType: TypeAlias = Literal["a", "b"]
+
+
+def test_callable_signature():
+    def f():
+        pass
+
+    assert callable_signature(f) == "f()"
+
+    def g(a: int) -> str:
+        return str(a)
+
+    assert callable_signature(g) == "g(a: int) -> str"
+
+    def h(a: TestType) -> str:
+        return str(a)
+
+    assert callable_signature(h) == 'h(a: Literal["a", "b"]) -> str'