diff --git a/services/openai/agent.py b/services/openai/agent.py index 200413c7..203e7edb 100644 --- a/services/openai/agent.py +++ b/services/openai/agent.py @@ -2,6 +2,7 @@ import json import logging import time +from collections import defaultdict from typing import Any, Iterable # Third-party imports @@ -174,7 +175,6 @@ def wait_on_run( """ client: OpenAI = create_openai_client() input_data = "" - processed_calls = set() p = 40 while run.status not in OPENAI_FINAL_STATUSES: run = client.beta.threads.runs.retrieve( @@ -187,16 +187,17 @@ def wait_on_run( continue try: - tool_outputs = call_functions( - run=run, base_args=base_args, processed_calls=processed_calls, p=p - ) + tool_outputs = call_functions(run=run, base_args=base_args, p=p) if not tool_outputs: client.beta.threads.runs.cancel(thread_id=thread.id, run_id=run.id) return run, input_data # Serialize the tool outputs. Use default=str to serialize unterminated strings. (I encountered this error here: https://github.com/hiroshinishio/ClickHouse/blob/master/src/Functions/FunctionsStringHashFixedString.cpp) tool_outputs_json: list[ToolOutput] = [ - {"tool_call_id": tool_call.id, "output": json.dumps(result, default=str)} + { + "tool_call_id": tool_call.id, + "output": json.dumps(result, default=str), + } for tool_call, result in tool_outputs ] @@ -226,14 +227,13 @@ def wait_on_run( return run, input_data -def call_functions( - run: Run, base_args: BaseArgs, processed_calls: set, p: int -) -> list[Any]: +def call_functions(run: Run, base_args: BaseArgs, p: int) -> list[Any]: # Raise an error if there is no tool call in the run if run.required_action is None: raise ValueError("No tool call in the run.") # Get the tool calls results: list[Any] = [] + call_counts = defaultdict(int) for tool_call in run.required_action.submit_tool_outputs.tool_calls: name: str = tool_call.function.name args: dict[str, Any] = json.loads(s=tool_call.function.arguments) @@ -250,11 +250,11 @@ def call_functions( # Skip duplicate calls call_signature = (name, json.dumps(args, sort_keys=True)) - if call_signature in processed_calls: - msg = f"Skipping duplicate call: '{name}' with '{args}'\n" + call_counts[call_signature] += 1 + if call_counts[call_signature] >= 3: + msg = f"Cancelled run: {run.id} because call {call_signature} was repeated too many times (3 times)." logging.error(msg) return [] - processed_calls.add(call_signature) # Skip the function if it doesn't exist if name not in functions: