From 9bc0f3d5ce636e3afcbfed64340f1ccff15f17d9 Mon Sep 17 00:00:00 2001 From: Diwank Singh Tomer Date: Sun, 10 Nov 2024 14:54:16 +0530 Subject: [PATCH] refactor(agents-api): Refactor context.model_dump() to be more lightweight Signed-off-by: Diwank Singh Tomer --- .../activities/task_steps/evaluate_step.py | 2 +- .../activities/task_steps/for_each_step.py | 6 ++--- .../activities/task_steps/if_else_step.py | 2 +- .../activities/task_steps/log_step.py | 2 +- .../activities/task_steps/map_reduce_step.py | 4 +++- .../activities/task_steps/prompt_step.py | 2 +- .../activities/task_steps/return_step.py | 2 +- .../activities/task_steps/set_value_step.py | 2 +- .../activities/task_steps/switch_step.py | 2 +- .../activities/task_steps/tool_call_step.py | 2 +- .../task_steps/wait_for_input_step.py | 2 +- .../activities/task_steps/yield_step.py | 2 +- agents-api/agents_api/clients/temporal.py | 9 ++++++-- .../agents_api/common/protocol/tasks.py | 23 +++++++++++-------- 14 files changed, 36 insertions(+), 26 deletions(-) diff --git a/agents-api/agents_api/activities/task_steps/evaluate_step.py b/agents-api/agents_api/activities/task_steps/evaluate_step.py index 76031c8e3..4458bbd2d 100644 --- a/agents-api/agents_api/activities/task_steps/evaluate_step.py +++ b/agents-api/agents_api/activities/task_steps/evaluate_step.py @@ -23,7 +23,7 @@ async def evaluate_step( else context.current_step.evaluate ) - values = context.model_dump(include_remote=True) | additional_values + values = context.prepare_for_step(include_remote=True) | additional_values output = simple_eval_dict(expr, values) result = StepOutcome(output=output) diff --git a/agents-api/agents_api/activities/task_steps/for_each_step.py b/agents-api/agents_api/activities/task_steps/for_each_step.py index c9a48e44d..df01c1ca8 100644 --- a/agents-api/agents_api/activities/task_steps/for_each_step.py +++ b/agents-api/agents_api/activities/task_steps/for_each_step.py @@ -1,5 +1,3 @@ -import logging - from beartype import beartype from temporalio import activity @@ -20,12 +18,12 @@ async def for_each_step(context: StepContext) -> StepOutcome: assert isinstance(context.current_step, ForeachStep) output = await base_evaluate( - context.current_step.foreach.in_, context.model_dump() + context.current_step.foreach.in_, context.prepare_for_step() ) return StepOutcome(output=output) except BaseException as e: - logging.error(f"Error in for_each_step: {e}") + activity.logger.error(f"Error in for_each_step: {e}") return StepOutcome(error=str(e)) diff --git a/agents-api/agents_api/activities/task_steps/if_else_step.py b/agents-api/agents_api/activities/task_steps/if_else_step.py index 1b6aeb60e..9b90647de 100644 --- a/agents-api/agents_api/activities/task_steps/if_else_step.py +++ b/agents-api/agents_api/activities/task_steps/if_else_step.py @@ -20,7 +20,7 @@ async def if_else_step(context: StepContext) -> StepOutcome: assert isinstance(context.current_step, IfElseWorkflowStep) expr: str = context.current_step.if_ - output = await base_evaluate(expr, context.model_dump()) + output = await base_evaluate(expr, context.prepare_for_step()) output: bool = bool(output) result = StepOutcome(output=output) diff --git a/agents-api/agents_api/activities/task_steps/log_step.py b/agents-api/agents_api/activities/task_steps/log_step.py index 80a61089f..4c5158279 100644 --- a/agents-api/agents_api/activities/task_steps/log_step.py +++ b/agents-api/agents_api/activities/task_steps/log_step.py @@ -22,7 +22,7 @@ async def log_step(context: StepContext) -> StepOutcome: template: str = context.current_step.log output = await render_template( template, - context.model_dump(include_remote=True), + context.prepare_for_step(include_remote=True), skip_vars=["developer_id"], ) diff --git a/agents-api/agents_api/activities/task_steps/map_reduce_step.py b/agents-api/agents_api/activities/task_steps/map_reduce_step.py index 43cd13690..904a7082a 100644 --- a/agents-api/agents_api/activities/task_steps/map_reduce_step.py +++ b/agents-api/agents_api/activities/task_steps/map_reduce_step.py @@ -19,7 +19,9 @@ async def map_reduce_step(context: StepContext) -> StepOutcome: try: assert isinstance(context.current_step, MapReduceStep) - output = await base_evaluate(context.current_step.over, context.model_dump()) + output = await base_evaluate( + context.current_step.over, context.prepare_for_step() + ) return StepOutcome(output=output) diff --git a/agents-api/agents_api/activities/task_steps/prompt_step.py b/agents-api/agents_api/activities/task_steps/prompt_step.py index 038830511..ad16ed6bd 100644 --- a/agents-api/agents_api/activities/task_steps/prompt_step.py +++ b/agents-api/agents_api/activities/task_steps/prompt_step.py @@ -90,7 +90,7 @@ def format_tool(tool: Tool) -> dict: async def prompt_step(context: StepContext) -> StepOutcome: # Get context data prompt: str | list[dict] = context.current_step.model_dump()["prompt"] - context_data: dict = context.model_dump(include_remote=True) + context_data: dict = context.prepare_for_step(include_remote=True) # If the prompt is a string and starts with $_ then we need to evaluate it should_evaluate_prompt = isinstance(prompt, str) and prompt.startswith( diff --git a/agents-api/agents_api/activities/task_steps/return_step.py b/agents-api/agents_api/activities/task_steps/return_step.py index cb85c38be..e58c4d7e7 100644 --- a/agents-api/agents_api/activities/task_steps/return_step.py +++ b/agents-api/agents_api/activities/task_steps/return_step.py @@ -18,7 +18,7 @@ async def return_step(context: StepContext) -> StepOutcome: assert isinstance(context.current_step, ReturnStep) exprs: dict[str, str] = context.current_step.return_ - output = await base_evaluate(exprs, context.model_dump()) + output = await base_evaluate(exprs, context.prepare_for_step()) result = StepOutcome(output=output) return result diff --git a/agents-api/agents_api/activities/task_steps/set_value_step.py b/agents-api/agents_api/activities/task_steps/set_value_step.py index 5a89fb87b..707dbdeb0 100644 --- a/agents-api/agents_api/activities/task_steps/set_value_step.py +++ b/agents-api/agents_api/activities/task_steps/set_value_step.py @@ -21,7 +21,7 @@ async def set_value_step( try: expr = override_expr if override_expr is not None else context.current_step.set - values = context.model_dump() | additional_values + values = context.prepare_for_step() | additional_values output = simple_eval_dict(expr, values) result = StepOutcome(output=output) diff --git a/agents-api/agents_api/activities/task_steps/switch_step.py b/agents-api/agents_api/activities/task_steps/switch_step.py index 29ad8ea65..413611e27 100644 --- a/agents-api/agents_api/activities/task_steps/switch_step.py +++ b/agents-api/agents_api/activities/task_steps/switch_step.py @@ -21,7 +21,7 @@ async def switch_step(context: StepContext) -> StepOutcome: output: int = -1 cases: list[str] = [c.case for c in context.current_step.switch] - evaluator = get_evaluator(names=context.model_dump()) + evaluator = get_evaluator(names=context.prepare_for_step()) for i, case in enumerate(cases): result = evaluator.eval(case) diff --git a/agents-api/agents_api/activities/task_steps/tool_call_step.py b/agents-api/agents_api/activities/task_steps/tool_call_step.py index a0ea7e7bf..03525e5ed 100644 --- a/agents-api/agents_api/activities/task_steps/tool_call_step.py +++ b/agents-api/agents_api/activities/task_steps/tool_call_step.py @@ -61,7 +61,7 @@ async def tool_call_step(context: StepContext) -> StepOutcome: raise ApplicationError(f"Tool {tool_name} not found in the toolset") arguments = await base_evaluate( - context.current_step.arguments, context.model_dump() + context.current_step.arguments, context.prepare_for_step() ) call_id = generate_call_id() diff --git a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py index 59a75ea00..d9839bc8e 100644 --- a/agents-api/agents_api/activities/task_steps/wait_for_input_step.py +++ b/agents-api/agents_api/activities/task_steps/wait_for_input_step.py @@ -15,7 +15,7 @@ async def wait_for_input_step(context: StepContext) -> StepOutcome: assert isinstance(context.current_step, WaitForInputStep) exprs = context.current_step.wait_for_input.info - output = await base_evaluate(exprs, context.model_dump()) + output = await base_evaluate(exprs, context.prepare_for_step()) result = StepOutcome(output=output) return result diff --git a/agents-api/agents_api/activities/task_steps/yield_step.py b/agents-api/agents_api/activities/task_steps/yield_step.py index 6a97b5a07..ec6c08353 100644 --- a/agents-api/agents_api/activities/task_steps/yield_step.py +++ b/agents-api/agents_api/activities/task_steps/yield_step.py @@ -25,7 +25,7 @@ async def yield_step(context: StepContext) -> StepOutcome: ], f"Workflow {workflow} not found in task" # Evaluate the expressions in the arguments - arguments = await base_evaluate(exprs, context.model_dump()) + arguments = await base_evaluate(exprs, context.prepare_for_step()) # Transition to the first step of that workflow transition_target = TransitionTarget( diff --git a/agents-api/agents_api/clients/temporal.py b/agents-api/agents_api/clients/temporal.py index deb4809f1..5737bd97e 100644 --- a/agents-api/agents_api/clients/temporal.py +++ b/agents-api/agents_api/clients/temporal.py @@ -1,6 +1,7 @@ from datetime import timedelta from uuid import UUID +from beartype import beartype from temporalio.client import Client, TLSConfig from temporalio.common import ( SearchAttributeKey, @@ -42,16 +43,20 @@ async def get_client( ) +@beartype async def run_task_execution_workflow( *, execution_input: ExecutionInput, job_id: UUID, - start: TransitionTarget = TransitionTarget(workflow="main", step=0), - previous_inputs: list[dict] = [], + start: TransitionTarget | None = None, + previous_inputs: list[dict] | None = None, client: Client | None = None, ): from ..workflows.task_execution import TaskExecutionWorkflow + start: TransitionTarget = start or TransitionTarget(workflow="main", step=0) + previous_inputs: list[dict] = previous_inputs or [] + client = client or (await get_client()) execution_id_key = SearchAttributeKey.for_keyword("CustomStringField") diff --git a/agents-api/agents_api/common/protocol/tasks.py b/agents-api/agents_api/common/protocol/tasks.py index 06044d4a0..6a379b6a4 100644 --- a/agents-api/agents_api/common/protocol/tasks.py +++ b/agents-api/agents_api/common/protocol/tasks.py @@ -228,21 +228,26 @@ def is_main(self) -> Annotated[bool, Field(exclude=True)]: return self.cursor.workflow == "main" def model_dump(self, *args, **kwargs) -> dict[str, Any]: + dump = super().model_dump(*args, **kwargs) + execution_input: dict = dump.pop("execution_input") + + return dump | execution_input + + def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]: current_input = self.current_input if activity.in_activity(): current_input = load_from_blob_store_if_remote(current_input) - dump = super().model_dump(*args, **kwargs) - # Merge execution inputs into the dump dict - execution_input: dict = dump.pop("execution_input") - dump = { - **dump, - **execution_input, - "_": current_input, - } + dump = self.model_dump(*args, **kwargs) + prepared = dump | {"_": current_input} + + for i, input in enumerate(self.inputs): + prepared = prepared | {f"_{i}": input} + if i >= 100: + break - return dump + return prepared class StepOutcome(BaseModel):