Skip to content

Commit

Permalink
Allow GitAuto to call the same function with the same args up to 3 ti…
Browse files Browse the repository at this point in the history
…mes (previously not allowed)
  • Loading branch information
hiroshinishio committed Aug 19, 2024
1 parent 0a24bff commit bc76f78
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions services/openai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import logging
import time
from collections import defaultdict
from typing import Any, Iterable

# Third-party imports
Expand Down Expand Up @@ -174,7 +175,6 @@ def wait_on_run(
"""
client: OpenAI = create_openai_client()
input_data = ""
processed_calls = set()
p = 40
while run.status not in OPENAI_FINAL_STATUSES:
run = client.beta.threads.runs.retrieve(
Expand All @@ -187,16 +187,17 @@ def wait_on_run(
continue

try:
tool_outputs = call_functions(
run=run, base_args=base_args, processed_calls=processed_calls, p=p
)
tool_outputs = call_functions(run=run, base_args=base_args, p=p)
if not tool_outputs:
client.beta.threads.runs.cancel(thread_id=thread.id, run_id=run.id)
return run, input_data

# Serialize the tool outputs. Use default=str to serialize unterminated strings. (I encountered this error here: https://github.com/hiroshinishio/ClickHouse/blob/master/src/Functions/FunctionsStringHashFixedString.cpp)
tool_outputs_json: list[ToolOutput] = [
{"tool_call_id": tool_call.id, "output": json.dumps(result, default=str)}
{
"tool_call_id": tool_call.id,
"output": json.dumps(result, default=str),
}
for tool_call, result in tool_outputs
]

Expand Down Expand Up @@ -226,14 +227,13 @@ def wait_on_run(
return run, input_data


def call_functions(
run: Run, base_args: BaseArgs, processed_calls: set, p: int
) -> list[Any]:
def call_functions(run: Run, base_args: BaseArgs, p: int) -> list[Any]:
# Raise an error if there is no tool call in the run
if run.required_action is None:
raise ValueError("No tool call in the run.")
# Get the tool calls
results: list[Any] = []
call_counts = defaultdict(int)
for tool_call in run.required_action.submit_tool_outputs.tool_calls:
name: str = tool_call.function.name
args: dict[str, Any] = json.loads(s=tool_call.function.arguments)
Expand All @@ -250,11 +250,11 @@ def call_functions(

# Skip duplicate calls
call_signature = (name, json.dumps(args, sort_keys=True))
if call_signature in processed_calls:
msg = f"Skipping duplicate call: '{name}' with '{args}'\n"
call_counts[call_signature] += 1
if call_counts[call_signature] >= 3:
msg = f"Cancelled run: {run.id} because call {call_signature} was repeated too many times (3 times)."
logging.error(msg)
return []
processed_calls.add(call_signature)

# Skip the function if it doesn't exist
if name not in functions:
Expand Down

0 comments on commit bc76f78

Please sign in to comment.