From 87a34acc5e97b626ac0d77469ecaec21a96bdbe7 Mon Sep 17 00:00:00 2001
From: ajhai
Date: Wed, 23 Oct 2024 16:29:46 -0700
Subject: [PATCH] Emit usage data from agent controller

---
 llmstack/apps/runner/agent_actor.py      | 43 ++++++++++++++++++++++--
 llmstack/apps/runner/agent_controller.py | 21 +++++++++++-
 2 files changed, 60 insertions(+), 4 deletions(-)

diff --git a/llmstack/apps/runner/agent_actor.py b/llmstack/apps/runner/agent_actor.py
index 383e77f5ec..8f432415fe 100644
--- a/llmstack/apps/runner/agent_actor.py
+++ b/llmstack/apps/runner/agent_actor.py
@@ -16,9 +16,12 @@
 )
 from llmstack.apps.runner.output_actor import OutputActor
 from llmstack.common.utils.liquid import render_template
+from llmstack.common.utils.provider_config import get_matched_provider_config
 from llmstack.play.messages import ContentData, Error, Message, MessageType
 from llmstack.play.output_stream import stitch_model_objects
 from llmstack.play.utils import run_coro_in_new_loop
+from llmstack.processors.providers.config import ProviderConfigSource
+from llmstack.processors.providers.metrics import MetricType
 
 logger = logging.getLogger(__name__)
 
@@ -36,11 +39,18 @@ def __init__(
         self._process_output_task = None
         self._config = agent_config
         self._provider_configs = provider_configs
+        self._provider_slug = self._config.get("provider_slug", "openai")
+        self._model_slug = self._config.get("model", "gpt-4o-mini")
+        self._provider_config = get_matched_provider_config(
+            provider_configs=self._provider_configs,
+            provider_slug=self._provider_slug,
+            model_slug=self._model_slug,
+        )
 
         self._controller_config = AgentControllerConfig(
             provider_configs=self._provider_configs,
-            provider_slug=self._config.get("provider_slug", "openai"),
-            model_slug=self._config.get("model", "gpt-4o-mini"),
+            provider_slug=self._provider_slug,
+            model_slug=self._model_slug,
             system_message=self._config.get("system_message", "You are a helpful assistant."),
             tools=tools,
             stream=True if self._config.get("stream") is None else self._config.get("stream"),
@@ -127,7 +137,9 @@ async def _process_output(self):
                             "chunks": self._stitched_data,
                         }
                     )
-                    self._bookkeeping_data_map["agent"] = self._stitched_data["agent"]
+
+                    self._bookkeeping_data_map["agent"]["config"] = self._config
+                    self._bookkeeping_data_map["agent"]["output"] = self._stitched_data["agent"]
                     self._bookkeeping_data_map["agent"]["timestamp"] = time.time()
                     self._bookkeeping_data_future.set(self._bookkeeping_data_map)
                 elif controller_output.type == AgentControllerDataType.TOOL_CALLS:
@@ -195,6 +207,30 @@
                     message_index += 1
                 elif controller_output.type == AgentControllerDataType.AGENT_OUTPUT_END:
                     message_index = 0
+                elif controller_output.type == AgentControllerDataType.USAGE_DATA:
+                    self._bookkeeping_data_map["agent"]["usage_data"] = {
+                        "usage_metrics": [
+                            [
+                                ("promptly/*/*/*", MetricType.INVOCATION, (ProviderConfigSource.PLATFORM_DEFAULT, 1)),
+                                (
+                                    f"{self._provider_slug}/*/{self._model_slug}/*",
+                                    MetricType.INPUT_TOKENS,
+                                    (
+                                        self._provider_config.provider_config_source,
+                                        controller_output.data.prompt_tokens,
+                                    ),
+                                ),
+                                (
+                                    f"{self._provider_slug}/*/{self._model_slug}/*",
+                                    MetricType.OUTPUT_TOKENS,
+                                    (
+                                        self._provider_config.provider_config_source,
+                                        controller_output.data.completion_tokens,
+                                    ),
+                                ),
+                            ]
+                        ]
+                    }
             except asyncio.QueueEmpty:
                 await asyncio.sleep(0.1)
             except Exception as e:
@@ -307,6 +343,7 @@ def reset(self):
         super().reset()
         self._stitched_data = {"agent": {}}
         self._agent_outputs = {}
+        self._bookkeeping_data_map = {"agent": {}}
 
         if self._process_output_task:
             self._process_output_task.cancel()
diff --git a/llmstack/apps/runner/agent_controller.py b/llmstack/apps/runner/agent_controller.py
index 03b136c4d9..851045ab40 100644
--- a/llmstack/apps/runner/agent_controller.py
+++ b/llmstack/apps/runner/agent_controller.py
@@ -37,6 +37,13 @@ class AgentControllerDataType(StrEnum):
     AGENT_OUTPUT = "agent_output"
     AGENT_OUTPUT_END = "agent_output_end"
     ERROR = "error"
+    USAGE_DATA = "usage_data"
+
+
+class AgentUsageData(BaseModel):
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
 
 
 class AgentMessageRole(StrEnum):
@@ -86,7 +93,7 @@ class AgentToolCallsMessage(BaseModel):
 
 class AgentControllerData(BaseModel):
     type: AgentControllerDataType
-    data: Optional[Union[AgentUserMessage, AgentAssistantMessage, AgentToolCallsMessage]] = None
+    data: Optional[Union[AgentUserMessage, AgentAssistantMessage, AgentToolCallsMessage, AgentUsageData]] = None
 
 
 class AgentController:
@@ -200,6 +207,18 @@ def add_response_to_output_queue(self, response: Any):
         """
         Add the response to the output queue as well as update _messages
         """
+        if response.usage:
+            self._output_queue.put_nowait(
+                AgentControllerData(
+                    type=AgentControllerDataType.USAGE_DATA,
+                    data=AgentUsageData(
+                        prompt_tokens=response.usage.input_tokens,
+                        completion_tokens=response.usage.output_tokens,
+                        total_tokens=response.usage.total_tokens,
+                    ),
+                )
+            )
+
         # For streaming responses, add the content to the output queue and messages
         if isinstance(response, ChatCompletionChunk) and response.choices[0].delta.content:
             self._output_queue.put_nowait(
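
Note: the sketch below illustrates the message flow this patch introduces: the controller enqueues a USAGE_DATA message on its output queue, and a consumer drains the queue and accumulates token counts, much as the new branch in agent_actor.py's _process_output does. It is a minimal, self-contained illustration and not code from this patch: the AgentControllerDataType, AgentUsageData, and AgentControllerData definitions are simplified stand-ins for the ones in agent_controller.py (the real enum has more members and uses StrEnum), and drain_usage is a hypothetical helper.

import asyncio
from enum import Enum
from typing import Optional

from pydantic import BaseModel


# Simplified stand-ins for the types in agent_controller.py; the real
# AgentControllerData union also carries user, assistant, and tool-call
# messages.
class AgentControllerDataType(str, Enum):
    USAGE_DATA = "usage_data"


class AgentUsageData(BaseModel):
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class AgentControllerData(BaseModel):
    type: AgentControllerDataType
    data: Optional[AgentUsageData] = None


def drain_usage(queue: "asyncio.Queue[AgentControllerData]") -> AgentUsageData:
    # Hypothetical helper: accumulates token counts across messages, with
    # the same QueueEmpty handling the actor's _process_output loop uses.
    totals = AgentUsageData()
    while True:
        try:
            controller_output = queue.get_nowait()
        except asyncio.QueueEmpty:
            break
        if controller_output.type == AgentControllerDataType.USAGE_DATA:
            totals.prompt_tokens += controller_output.data.prompt_tokens
            totals.completion_tokens += controller_output.data.completion_tokens
            totals.total_tokens += controller_output.data.total_tokens
    return totals


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    # The controller enqueues one USAGE_DATA message per model response
    # that reports usage (see add_response_to_output_queue above).
    queue.put_nowait(
        AgentControllerData(
            type=AgentControllerDataType.USAGE_DATA,
            data=AgentUsageData(prompt_tokens=120, completion_tokens=45, total_tokens=165),
        )
    )
    totals = drain_usage(queue)
    print(totals.prompt_tokens, totals.completion_tokens, totals.total_tokens)  # 120 45 165


asyncio.run(main())

Emitting usage on the same output queue keeps token accounting in-band and ordered with the streamed content, so the actor can fold it into the per-run bookkeeping record it already builds.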