Skip to content

Commit

Permalink
Label request_failure metric with error_code (#1862)
Browse files Browse the repository at this point in the history
---------

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Oct 2, 2023
1 parent 4b06628 commit 11e98fb
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 85 deletions.
4 changes: 4 additions & 0 deletions flytekit/exceptions/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,7 @@ def __init__(self, task_module, task_name=None, additional_msg=None):

class FlyteSystemAssertion(FlyteSystemException, AssertionError):
_ERROR_CODE = "SYSTEM:AssertionError"


class FlyteAgentNotFound(FlyteSystemException, AssertionError):
_ERROR_CODE = "SYSTEM:AgentNotFound"
161 changes: 79 additions & 82 deletions flytekit/extend/backend/agent_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import typing

import grpc
from flyteidl.admin.agent_pb2 import (
Expand All @@ -13,6 +14,7 @@
from prometheus_client import Counter, Summary

from flytekit import logger
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.extend.backend.base_agent import AgentRegistry
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
Expand All @@ -27,100 +29,95 @@
f"{metric_prefix}requests_success_total", "Total number of successful requests", ["task_type", "operation"]
)
request_failure_count = Counter(
f"{metric_prefix}requests_failure_total", "Total number of failed requests", ["task_type", "operation"]
f"{metric_prefix}requests_failure_total",
"Total number of failed requests",
["task_type", "operation", "error_code"],
)

request_latency = Summary(
f"{metric_prefix}request_latency_seconds", "Time spent processing agent request", ["task_type", "operation"]
)
input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"])

input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"])

class AsyncAgentService(AsyncAgentServiceServicer):
async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
try:
with request_latency.labels(task_type=request.template.type, operation=create_operation).time():
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None

input_literal_size.labels(task_type=tmp.type).observe(request.inputs.ByteSize())
def agent_exception_handler(func):
async def wrapper(
self,
request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest],
context: grpc.ServicerContext,
*args,
**kwargs,
):
if isinstance(request, CreateTaskRequest):
task_type = request.template.type
operation = create_operation
if request.inputs:
input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize())
elif isinstance(request, GetTaskRequest):
task_type = request.task_type
operation = get_operation
elif isinstance(request, DeleteTaskRequest):
task_type = request.task_type
operation = delete_operation
else:
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details("Method not implemented!")
return

agent = AgentRegistry.get_agent(tmp.type)
logger.info(f"{tmp.type} agent start creating the job")
if agent.asynchronous:
try:
res = await agent.async_create(
context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp
)
request_success_count.labels(task_type=tmp.type, operation=create_operation).inc()
return res
except Exception as e:
logger.error(f"failed to run async create with error {e}")
raise e
try:
res = await asyncio.to_thread(
agent.create,
context=context,
inputs=inputs,
output_prefix=request.output_prefix,
task_template=tmp,
)
request_success_count.labels(task_type=tmp.type, operation=create_operation).inc()
return res
except Exception as e:
logger.error(f"failed to run sync create with error {e}")
raise
try:
with request_latency.labels(task_type=task_type, operation=operation).time():
res = await func(self, request, context, *args, **kwargs)
request_success_count.labels(task_type=task_type, operation=operation).inc()
return res
except FlyteAgentNotFound:
error_message = f"Cannot find agent for task type: {task_type}."
logger.error(error_message)
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(error_message)
request_failure_count.labels(task_type=task_type, operation=operation, error_code="404").inc()
except Exception as e:
error_message = f"failed to {operation} {task_type} task with error {e}."
logger.error(error_message)
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to create task with error {e}")
request_failure_count.labels(task_type=tmp.type, operation=create_operation).inc()
context.set_details(error_message)
request_failure_count.labels(task_type=task_type, operation=operation, error_code="500").inc()

return wrapper


class AsyncAgentService(AsyncAgentServiceServicer):
@agent_exception_handler
async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
agent = AgentRegistry.get_agent(tmp.type)

logger.info(f"{tmp.type} agent start creating the job")
if agent.asynchronous:
return await agent.async_create(
context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp
)
return await asyncio.to_thread(
agent.create,
context=context,
inputs=inputs,
output_prefix=request.output_prefix,
task_template=tmp,
)

@agent_exception_handler
async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
try:
with request_latency.labels(task_type=request.task_type, operation="get").time():
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start checking the status of the job")
if agent.asynchronous:
try:
res = await agent.async_get(context=context, resource_meta=request.resource_meta)
request_success_count.labels(task_type=request.task_type, operation=get_operation).inc()
return res
except Exception as e:
logger.error(f"failed to run async get with error {e}")
raise
try:
res = await asyncio.to_thread(agent.get, context=context, resource_meta=request.resource_meta)
request_success_count.labels(task_type=request.task_type, operation=get_operation).inc()
return res
except Exception as e:
logger.error(f"failed to run sync get with error {e}")
raise
except Exception as e:
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to get task with error {e}")
request_failure_count.labels(task_type=request.task_type, operation=get_operation).inc()
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start checking the status of the job")
if agent.asynchronous:
return await agent.async_get(context=context, resource_meta=request.resource_meta)
return await asyncio.to_thread(agent.get, context=context, resource_meta=request.resource_meta)

@agent_exception_handler
async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
try:
with request_latency.labels(task_type=request.task_type, operation="delete").time():
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start deleting the job")
if agent.asynchronous:
try:
res = await agent.async_delete(context=context, resource_meta=request.resource_meta)
request_success_count.labels(task_type=request.task_type, operation=delete_operation).inc()
return res
except Exception as e:
logger.error(f"failed to run async delete with error {e}")
raise
try:
res = asyncio.to_thread(agent.delete, context=context, resource_meta=request.resource_meta)
request_success_count.labels(task_type=request.task_type, operation=delete_operation).inc()
return res
except Exception as e:
logger.error(f"failed to run sync delete with error {e}")
raise
except Exception as e:
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to delete task with error {e}")
request_failure_count.labels(task_type=request.task_type, operation=delete_operation).inc()
agent = AgentRegistry.get_agent(request.task_type)
logger.info(f"{agent.task_type} agent start deleting the job")
if agent.asynchronous:
return await agent.async_delete(context=context, resource_meta=request.resource_meta)
return await asyncio.to_thread(agent.delete, context=context, resource_meta=request.resource_meta)
3 changes: 2 additions & 1 deletion flytekit/extend/backend/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.type_engine import TypeEngine
from flytekit.exceptions.system import FlyteAgentNotFound
from flytekit.models.literals import LiteralMap


Expand Down Expand Up @@ -125,7 +126,7 @@ def register(agent: AgentBase):
@staticmethod
def get_agent(task_type: str) -> typing.Optional[AgentBase]:
if task_type not in AgentRegistry._REGISTRY:
raise ValueError(f"Unrecognized task type {task_type}")
raise FlyteAgentNotFound(f"Cannot find agent for task type: {task_type}.")
return AgentRegistry._REGISTRY[task_type]


Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-mmcloud/tests/test_mmcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def say_hello0(name: str) -> str:
return f"Hello, {name}."

task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello0)
agent = AgentRegistry.get_agent(context, task_spec.template.type)
agent = AgentRegistry.get_agent(task_spec.template.type)

assert isinstance(agent, MMCloudAgent)

Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(self, **kwargs):
t.execute()

t._task_type = "non-exist-type"
with pytest.raises(Exception, match="Unrecognized task type non-exist-type"):
with pytest.raises(Exception, match="Cannot find agent for task type: non-exist-type."):
t.execute()


Expand Down

0 comments on commit 11e98fb

Please sign in to comment.