Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(agents-api): Refactor context.model_dump() to be more lightweight #824

Merged
merged 2 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ async def evaluate_step(
else context.current_step.evaluate
)

values = context.model_dump(include_remote=True) | additional_values
values = context.prepare_for_step(include_remote=True) | additional_values

output = simple_eval_dict(expr, values)
result = StepOutcome(output=output)
Expand Down
6 changes: 2 additions & 4 deletions agents-api/agents_api/activities/task_steps/for_each_step.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import logging

from beartype import beartype
from temporalio import activity

Expand All @@ -20,12 +18,12 @@ async def for_each_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, ForeachStep)

output = await base_evaluate(
context.current_step.foreach.in_, context.model_dump()
context.current_step.foreach.in_, context.prepare_for_step()
)
return StepOutcome(output=output)

except BaseException as e:
logging.error(f"Error in for_each_step: {e}")
activity.logger.error(f"Error in for_each_step: {e}")
return StepOutcome(error=str(e))


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def if_else_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, IfElseWorkflowStep)

expr: str = context.current_step.if_
output = await base_evaluate(expr, context.model_dump())
output = await base_evaluate(expr, context.prepare_for_step())
output: bool = bool(output)

result = StepOutcome(output=output)
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/log_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def log_step(context: StepContext) -> StepOutcome:
template: str = context.current_step.log
output = await render_template(
template,
context.model_dump(include_remote=True),
context.prepare_for_step(include_remote=True),
skip_vars=["developer_id"],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ async def map_reduce_step(context: StepContext) -> StepOutcome:
try:
assert isinstance(context.current_step, MapReduceStep)

output = await base_evaluate(context.current_step.over, context.model_dump())
output = await base_evaluate(
context.current_step.over, context.prepare_for_step()
)

return StepOutcome(output=output)

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/prompt_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def format_tool(tool: Tool) -> dict:
async def prompt_step(context: StepContext) -> StepOutcome:
# Get context data
prompt: str | list[dict] = context.current_step.model_dump()["prompt"]
context_data: dict = context.model_dump(include_remote=True)
context_data: dict = context.prepare_for_step(include_remote=True)

# If the prompt is a string and starts with $_ then we need to evaluate it
should_evaluate_prompt = isinstance(prompt, str) and prompt.startswith(
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/return_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async def return_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, ReturnStep)

exprs: dict[str, str] = context.current_step.return_
output = await base_evaluate(exprs, context.model_dump())
output = await base_evaluate(exprs, context.prepare_for_step())

result = StepOutcome(output=output)
return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def set_value_step(
try:
expr = override_expr if override_expr is not None else context.current_step.set

values = context.model_dump() | additional_values
values = context.prepare_for_step() | additional_values
output = simple_eval_dict(expr, values)
result = StepOutcome(output=output)

Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/switch_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def switch_step(context: StepContext) -> StepOutcome:
output: int = -1
cases: list[str] = [c.case for c in context.current_step.switch]

evaluator = get_evaluator(names=context.model_dump())
evaluator = get_evaluator(names=context.prepare_for_step())

for i, case in enumerate(cases):
result = evaluator.eval(case)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def tool_call_step(context: StepContext) -> StepOutcome:
raise ApplicationError(f"Tool {tool_name} not found in the toolset")

arguments = await base_evaluate(
context.current_step.arguments, context.model_dump()
context.current_step.arguments, context.prepare_for_step()
)

call_id = generate_call_id()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ async def wait_for_input_step(context: StepContext) -> StepOutcome:
assert isinstance(context.current_step, WaitForInputStep)

exprs = context.current_step.wait_for_input.info
output = await base_evaluate(exprs, context.model_dump())
output = await base_evaluate(exprs, context.prepare_for_step())

result = StepOutcome(output=output)
return result
Expand Down
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/yield_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def yield_step(context: StepContext) -> StepOutcome:
], f"Workflow {workflow} not found in task"

# Evaluate the expressions in the arguments
arguments = await base_evaluate(exprs, context.model_dump())
arguments = await base_evaluate(exprs, context.prepare_for_step())

# Transition to the first step of that workflow
transition_target = TransitionTarget(
Expand Down
9 changes: 7 additions & 2 deletions agents-api/agents_api/clients/temporal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import timedelta
from uuid import UUID

from beartype import beartype
from temporalio.client import Client, TLSConfig
from temporalio.common import (
SearchAttributeKey,
Expand Down Expand Up @@ -42,16 +43,20 @@ async def get_client(
)


@beartype
async def run_task_execution_workflow(
*,
execution_input: ExecutionInput,
job_id: UUID,
start: TransitionTarget = TransitionTarget(workflow="main", step=0),
previous_inputs: list[dict] = [],
start: TransitionTarget | None = None,
previous_inputs: list[dict] | None = None,
client: Client | None = None,
):
from ..workflows.task_execution import TaskExecutionWorkflow

start: TransitionTarget = start or TransitionTarget(workflow="main", step=0)
previous_inputs: list[dict] = previous_inputs or []

client = client or (await get_client())
execution_id_key = SearchAttributeKey.for_keyword("CustomStringField")

Expand Down
23 changes: 14 additions & 9 deletions agents-api/agents_api/common/protocol/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,21 +228,26 @@ def is_main(self) -> Annotated[bool, Field(exclude=True)]:
return self.cursor.workflow == "main"

def model_dump(self, *args, **kwargs) -> dict[str, Any]:
dump = super().model_dump(*args, **kwargs)
execution_input: dict = dump.pop("execution_input")

return dump | execution_input

def prepare_for_step(self, *args, **kwargs) -> dict[str, Any]:
current_input = self.current_input
if activity.in_activity():
current_input = load_from_blob_store_if_remote(current_input)

dump = super().model_dump(*args, **kwargs)

# Merge execution inputs into the dump dict
execution_input: dict = dump.pop("execution_input")
dump = {
**dump,
**execution_input,
"_": current_input,
}
dump = self.model_dump(*args, **kwargs)
prepared = dump | {"_": current_input}

for i, input in enumerate(self.inputs):
prepared = prepared | {f"_{i}": input}
if i >= 100:
break

return dump
return prepared


class StepOutcome(BaseModel):
Expand Down
Loading