(fix) CodeActAgent/LLM: react on should_exit flag (user cancellation) (
tobitege authored Sep 20, 2024
1 parent ebd9397 commit 01462e1
Showing 4 changed files with 273 additions and 18 deletions.
4 changes: 4 additions & 0 deletions agenthub/codeact_agent/codeact_agent.py
@@ -5,6 +5,7 @@
 from openhands.controller.agent import Agent
 from openhands.controller.state.state import State
 from openhands.core.config import AgentConfig
+from openhands.core.exceptions import OperationCancelled
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import ImageContent, Message, TextContent
 from openhands.events.action import (
@@ -211,8 +212,11 @@ def step(self, state: State) -> Action:
                 'anthropic-beta': 'prompt-caching-2024-07-31',
             }

+        # TODO: move exception handling to agent_controller
         try:
             response = self.llm.completion(**params)
+        except OperationCancelled as e:
+            raise e
         except Exception as e:
             logger.error(f'{e}')
             error_message = '{}: {}'.format(type(e).__name__, str(e).split('\n')[0])
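Why the re-raise needs its own clause: Python tries except clauses in order, and OperationCancelled subclasses Exception, so without the dedicated clause the generic handler would swallow the cancellation and turn it into an ordinary error message for the agent. A small self-contained illustration (not repository code):

class OperationCancelled(Exception):
    pass


def call(cancelled: bool) -> str:
    try:
        if cancelled:
            raise OperationCancelled('Operation was cancelled')
        raise ValueError('some provider error')
    except OperationCancelled:
        raise  # propagate so the caller can stop cleanly
    except Exception as e:
        return f'{type(e).__name__}: {e}'  # recoverable path


print(call(cancelled=False))  # ValueError: some provider error
try:
    call(cancelled=True)
except OperationCancelled as e:
    print('propagated:', e)  # propagated: Operation was cancelled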
7 changes: 7 additions & 0 deletions openhands/core/exceptions.py
@@ -77,3 +77,10 @@ def __init__(self, message='User cancelled the request'):
 class MicroAgentValidationError(Exception):
     def __init__(self, message='Micro agent validation failed'):
         super().__init__(message)
+
+
+class OperationCancelled(Exception):
+    """Exception raised when an operation is cancelled (e.g. by a keyboard interrupt)."""
+
+    def __init__(self, message='Operation was cancelled'):
+        super().__init__(message)
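A dedicated exception type lets callers tell a user cancellation apart from a real failure. A hypothetical top-level handler (illustration only; the commit itself only adds the class and re-raises it in the agent):

from openhands.core.exceptions import OperationCancelled


def run_step(step):
    try:
        return step()
    except OperationCancelled:
        # Cancellation: stop quietly instead of reporting an error.
        print('Operation cancelled by the user; shutting down.')
        raise
    except Exception as e:
        print(f'Step failed: {type(e).__name__}: {e}')
        return None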
54 changes: 41 additions & 13 deletions openhands/llm/llm.py
@@ -24,15 +24,21 @@
 from tenacity import (
     retry,
     retry_if_exception_type,
+    retry_if_not_exception_type,
     stop_after_attempt,
     wait_exponential,
 )

-from openhands.core.exceptions import LLMResponseError, UserCancelledError
+from openhands.core.exceptions import (
+    LLMResponseError,
+    OperationCancelled,
+    UserCancelledError,
+)
 from openhands.core.logger import llm_prompt_logger, llm_response_logger
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
 from openhands.core.metrics import Metrics
+from openhands.runtime.utils.shutdown_listener import should_exit

 __all__ = ['LLM']

@@ -169,13 +175,18 @@ def __init__(

         completion_unwrapped = self._completion

-        def attempt_on_error(retry_state):
-            """Custom attempt function for litellm completion."""
+        def log_retry_attempt(retry_state):
+            """With before_sleep, this is called before `custom_completion_wait` and
+            ONLY if the retry is triggered by an exception."""
+            if should_exit():
+                raise OperationCancelled(
+                    'Operation cancelled.'
+                )  # exits the @retry loop
+            exception = retry_state.outcome.exception()
             logger.error(
-                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
+                f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
                 exc_info=False,
             )
-            return None

         def custom_completion_wait(retry_state):
             """Custom wait function for litellm completion."""
@@ -211,10 +222,13 @@ def custom_completion_wait(retry_state):
             return exponential_wait(retry_state)

         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         def wrapper(*args, **kwargs):
@@ -278,10 +292,13 @@ def wrapper(*args, **kwargs):
         async_completion_unwrapped = self._async_completion

         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_completion_wrapper(*args, **kwargs):
@@ -351,10 +368,13 @@ async def check_stopped():
                 pass

         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_acompletion_stream_wrapper(*args, **kwargs):
@@ -448,6 +468,9 @@ def _format_content_element(self, element):
         return str(element)

     async def _call_acompletion(self, *args, **kwargs):
+        """This is a wrapper for the litellm acompletion function which
+        makes it mockable for testing.
+        """
         return await litellm.acompletion(*args, **kwargs)

     @property
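The new docstring spells out why the call goes through a thin instance method: tests can replace _call_acompletion instead of patching litellm itself. A self-contained sketch of that pattern with a toy class (assumed test style, not code from this repository):

import asyncio
from unittest.mock import AsyncMock


class ToyLLM:
    async def _call_acompletion(self, *args, **kwargs):
        # The real method forwards to litellm.acompletion(*args, **kwargs).
        raise RuntimeError('would hit the network')

    async def complete(self, **kwargs):
        return await self._call_acompletion(**kwargs)


async def main():
    llm = ToyLLM()
    llm._call_acompletion = AsyncMock(return_value={'choices': []})
    print(await llm.complete(model='gpt-4o'))  # {'choices': []}


asyncio.run(main())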
@@ -528,10 +551,15 @@ def _post_completion(self, response) -> None:
             output_tokens = usage.get('completion_tokens')

             if input_tokens:
-                stats += 'Input tokens: ' + str(input_tokens) + '\n'
+                stats += 'Input tokens: ' + str(input_tokens)

             if output_tokens:
-                stats += 'Output tokens: ' + str(output_tokens) + '\n'
+                stats += (
+                    (' | ' if input_tokens else '')
+                    + 'Output tokens: '
+                    + str(output_tokens)
+                    + '\n'
+                )

             model_extra = usage.get('model_extra', {})

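Taken together, the retry policy now blocks cancellation in two places: the retry predicate never retries once OperationCancelled has been raised, and the before_sleep callback raises it between attempts as soon as the shutdown flag is set. A standalone sketch of the same tenacity wiring, with stand-ins for should_exit, self.retry_exceptions, and the custom wait (illustration only, not repository code):

from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_fixed,
)


class OperationCancelled(Exception):
    pass


_should_exit = False  # stand-in for openhands.runtime.utils.shutdown_listener.should_exit()


def log_retry_attempt(retry_state):
    if _should_exit:
        raise OperationCancelled('Operation cancelled.')  # exits the @retry loop
    print(f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number}')


@retry(
    before_sleep=log_retry_attempt,
    stop=stop_after_attempt(3),
    reraise=True,
    retry=(
        retry_if_exception_type(ConnectionError)  # stand-in for self.retry_exceptions
        & retry_if_not_exception_type(OperationCancelled)
    ),
    wait=wait_fixed(0),
)
def flaky_call():
    raise ConnectionError('transient failure')


try:
    flaky_call()
except ConnectionError as e:
    print('gave up after retries:', e)

Setting _should_exit to True before the call makes the next before_sleep raise OperationCancelled instead of sleeping, which aborts the remaining attempts.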