Skip to content

Commit

Permalink
Create subtask results
Browse files Browse the repository at this point in the history
  • Loading branch information
dewmal committed Aug 21, 2024
1 parent c1f3f22 commit 9a57d0a
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 34 deletions.
34 changes: 25 additions & 9 deletions bindings/ceylon/ceylon/llm/llm_task_coordinator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from textwrap import dedent
from typing import Dict, List, Set

Expand All @@ -9,7 +10,7 @@

from ceylon.llm.llm_task_operator import LLMTaskOperator
from ceylon.static_val import DEFAULT_WORKSPACE_ID, DEFAULT_ADMIN_PORT
from ceylon.task import TaskResult, Task, SubTask
from ceylon.task import SubTaskResult, Task, SubTask
from ceylon.task.task_coordinator import TaskCoordinator
from ceylon.task.task_operation import TaskDeliverable

Expand Down Expand Up @@ -59,7 +60,7 @@ class SubTaskList(pydantic.v1.BaseModel):
class LLMTaskCoordinator(TaskCoordinator):
tasks: List[Task] = []
agents: List[LLMTaskOperator] = []
results: Dict[str, List[TaskResult]] = {}
results: Dict[str, List[SubTaskResult]] = {}
team_network: nx.Graph = nx.Graph()

def __init__(self, tasks: List[Task], agents: List[LLMTaskOperator],
Expand All @@ -71,8 +72,8 @@ def __init__(self, tasks: List[Task], agents: List[LLMTaskOperator],
port=DEFAULT_ADMIN_PORT):
self.context = context
self.team_goal = team_goal
self.llm = llm
self.tool_llm = tool_llm
self.llm = copy.copy(llm)
self.tool_llm = copy.copy(tool_llm) if tool_llm is not None else copy.copy(llm)
self.tasks = tasks
self.agents = agents
self.initialize_team_network()
Expand Down Expand Up @@ -194,8 +195,8 @@ def get_valid_agent_name(max_attempts=3):

if response in agent_names:
return response
print(response)
print(f"Attempt {attempt + 1}: Invalid agent name received: {response} {subtask}. Retrying...")
# print(response)
logger.info(f"Attempt {attempt + 1}: Invalid agent name received: {response} {subtask}. Retrying...")

raise Exception(f"Failed to get a valid agent name after {max_attempts} attempts.")

Expand Down Expand Up @@ -234,7 +235,7 @@ async def generate_final_sub_task_from_description(self, task: Task) -> SubTask:
"task_deliverable": task.task_deliverable,
"existing_subtasks": "\n".join([f"{t.name}- {t.description}" for t in task.subtasks.values()])
})
print(sub_task)
# print(sub_task)
return sub_task.to_v2(task.id)

async def generate_tasks_from_description(self, task: Task) -> List[SubTask]:
Expand Down Expand Up @@ -330,8 +331,15 @@ async def build_task_deliverable(self, task: Task):
2. A list of specific deliverables
3. Key features to be implemented
4. Any considerations or constraints based on the team's objectives
Please format your response in a structured manner, using bullet points or numbered lists where appropriate.
Please format your response as a JSON object with the following structure:
{{
"objectives": ["objective1", "objective2", ...],
"final_output": "Description of the final output",
"final_output_type": "Type of the final output"
}}
"""))

structured_llm = self.tool_llm.with_structured_output(TaskDeliverableModel)
Expand All @@ -341,4 +349,12 @@ async def build_task_deliverable(self, task: Task):
"objectives": self.team_goal,
"task_description": task.description
})
if task_deliverable is None:
return TaskDeliverable(
objectives=[
"Finish the required deliverable",
],
final_output=f"{task.description}",
final_output_type="text"
)
return task_deliverable.to_v2()
4 changes: 2 additions & 2 deletions bindings/ceylon/ceylon/llm/llm_task_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from loguru import logger

from ceylon.task import TaskResult
from ceylon.task import SubTaskResult
from ceylon.task.task_operation import SubTask
from ceylon.task.task_operator import TaskOperator

Expand All @@ -27,7 +27,7 @@ def __init__(self, name: str, role: str, context: str, skills: List[str],
self.skills = skills
self.llm = copy.copy(llm)
self.tool_llm = copy.copy(tool_llm)
self.history: Dict[str, List[TaskResult]] = {}
self.history: Dict[str, List[SubTaskResult]] = {}
super().__init__(name=name, role=role, workspace_id=workspace_id, admin_port=admin_port)

async def get_result(self, task):
Expand Down
2 changes: 1 addition & 1 deletion bindings/ceylon/ceylon/task/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .task_operation import SubTask, TaskResult, TaskAssignment, Task
from .task_operation import SubTask, SubTaskResult, TaskAssignment, Task
20 changes: 13 additions & 7 deletions bindings/ceylon/ceylon/task/task_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

from ceylon import CoreAdmin, on_message
from ceylon.ceylon import AgentDetail
from ceylon.task import Task, TaskResult, TaskAssignment, SubTask
from ceylon.task.task_operation import TaskResultStatus
from ceylon.task import Task, SubTaskResult, TaskAssignment, SubTask
from ceylon.task.task_operation import TaskResultStatus, TaskResult
from ceylon.task.task_operator import TaskOperator


class TaskCoordinator(CoreAdmin):
tasks: List[Task] = []
agents: List[TaskOperator] = []
results: Dict[str, List[TaskResult]] = {}
results: Dict[str, List[SubTaskResult]] = {}

def __init__(self, tasks: List[Task], agents: List[TaskOperator], name="ceylon_agent_stack", port=8888, *args,
**kwargs):
Expand All @@ -29,7 +29,7 @@ async def get_task_executor(self, task: SubTask) -> str:
async def run(self, inputs: bytes):
for idx, task in enumerate(self.tasks):
task = await self.update_task(idx, task)
print(task)
logger.info(f"Validating task {task.name}")
if task.validate_sub_tasks():
logger.info(f"Task {task.name} is valid")
else:
Expand All @@ -54,9 +54,9 @@ async def run_tasks(self):
await self.broadcast_data(
TaskAssignment(task=subtask_, assigned_agent=subtask_.executor))

@on_message(type=TaskResult)
async def on_task_result(self, result: TaskResult):
print(f"Received task result: {result}")
@on_message(type=SubTaskResult)
async def on_task_result(self, result: SubTaskResult):
logger.info(f"Received task result: {result}")
if result.status == TaskResultStatus.COMPLETED:
for idx, task in enumerate(self.tasks):
sub_task = task.get_next_subtask()
Expand All @@ -67,6 +67,12 @@ async def on_task_result(self, result: TaskResult):
task.update_subtask_status(sub_task[1].name, result.result)
break

# Task is completed
for task in self.tasks:
if task.is_completed():
if task.id == result.task_id:
await self.broadcast_data(TaskResult(task_id=task.id, result=result.result))

if self.all_tasks_completed():
await self.end_task_management()

Expand Down
15 changes: 12 additions & 3 deletions bindings/ceylon/ceylon/task/task_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from uuid import uuid4

import networkx as nx
from loguru import logger
from pydantic import BaseModel
from pydantic import Field

Expand Down Expand Up @@ -76,6 +77,9 @@ def _create_dependency_graph(self) -> nx.DiGraph:
graph = nx.DiGraph()
for subtask in self.subtasks.values():
graph.add_node(subtask.name)
for dep in subtask.depends_on:
if dep in self.subtasks:
graph.add_edge(dep, subtask.name)
return graph

def validate_sub_tasks(self) -> bool:
Expand All @@ -85,7 +89,7 @@ def validate_sub_tasks(self) -> bool:
for subtask in self.subtasks.values():
if not subtask.depends_on.issubset(subtask_names):
missing_deps = subtask.depends_on - subtask_names
print(f"Subtask '{subtask.name}' has missing dependencies: {missing_deps}")
logger.info(f"Subtask '{subtask.name}' has missing dependencies: {missing_deps}")
return False

# Check for circular dependencies
Expand Down Expand Up @@ -170,7 +174,7 @@ class TaskResultStatus(enum.Enum):
FAILED = "FAILED"


class TaskResult(BaseModel):
class SubTaskResult(BaseModel):
task_id: str
parent_task_id: str
agent: str
Expand All @@ -180,6 +184,11 @@ class TaskResult(BaseModel):
status: TaskResultStatus


class TaskResult(BaseModel):
task_id: str
final_answer: str


if __name__ == "__main__":
def execute_task(task: Task) -> None:
while True:
Expand All @@ -198,7 +207,7 @@ def execute_task(task: Task) -> None:
print(f"Simulating execution of {subtask_name}")

# Simulate a result (in a real scenario, this would be the outcome of the subtask execution)
result = True
result = "Success"

# Update the subtask status
task.update_subtask_status(subtask_name, result)
Expand Down
24 changes: 12 additions & 12 deletions bindings/ceylon/ceylon/task/task_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ceylon import Agent, on_message
from ceylon.static_val import DEFAULT_WORKSPACE_ID, DEFAULT_ADMIN_PORT
from ceylon.task import TaskAssignment, TaskResult
from ceylon.task import TaskAssignment, SubTaskResult
from ceylon.task.task_operation import TaskResultStatus


Expand All @@ -17,7 +17,7 @@ def __init__(self, name: str, role: str, workspace_id: str = DEFAULT_WORKSPACE_I
**kwargs):
self.task_history = []
self.exeuction_history = []
self.history: Dict[str, List[TaskResult]] = {}
self.history: Dict[str, List[SubTaskResult]] = {}
super().__init__(name=name, role=role, workspace_id=workspace_id, admin_port=admin_port, *args, **kwargs)

@on_message(type=TaskAssignment)
Expand All @@ -41,24 +41,24 @@ async def on_task_assignment(self, data: TaskAssignment):
result = str(e)
status = TaskResultStatus.FAILED

result_task = TaskResult(task_id=data.task.id,
name=data.task.name,
description=data.task.description,
agent=self.details().name,
parent_task_id=data.task.parent_task_id,
result=result,
status=status)
result_task = SubTaskResult(task_id=data.task.id,
name=data.task.name,
description=data.task.description,
agent=self.details().name,
parent_task_id=data.task.parent_task_id,
result=result,
status=status)
# Update task history
if status == TaskResultStatus.COMPLETED:
await self.add_result_to_history(result_task)
await self.broadcast_data(result_task)
logger.info(f"{self.details().name} sent subtask result: {data.task.description}")

@on_message(type=TaskResult)
async def on_task_result(self, data: TaskResult):
@on_message(type=SubTaskResult)
async def on_task_result(self, data: SubTaskResult):
await self.add_result_to_history(data)

async def add_result_to_history(self, data: TaskResult):
async def add_result_to_history(self, data: SubTaskResult):
if data.parent_task_id in self.history:
# If the task result already exists, replace it
for idx, result in enumerate(self.history[data.parent_task_id]):
Expand Down

0 comments on commit 9a57d0a

Please sign in to comment.