diff --git a/agenthub/codeact_agent/codeact_agent.py b/agenthub/codeact_agent/codeact_agent.py
index c8d31ad04670..2471fd820665 100644
--- a/agenthub/codeact_agent/codeact_agent.py
+++ b/agenthub/codeact_agent/codeact_agent.py
@@ -5,6 +5,7 @@
 from openhands.controller.agent import Agent
 from openhands.controller.state.state import State
 from openhands.core.config import AgentConfig
+from openhands.core.exceptions import OperationCancelled
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import ImageContent, Message, TextContent
 from openhands.events.action import (
@@ -211,8 +212,11 @@ def step(self, state: State) -> Action:
                 'anthropic-beta': 'prompt-caching-2024-07-31',
             }
 
+        # TODO: move exception handling to agent_controller
         try:
             response = self.llm.completion(**params)
+        except OperationCancelled as e:
+            raise e
         except Exception as e:
             logger.error(f'{e}')
             error_message = '{}: {}'.format(type(e).__name__, str(e).split('\n')[0])
diff --git a/openhands/core/exceptions.py b/openhands/core/exceptions.py
index 3d28938ae14e..1a8f7aeb9ace 100644
--- a/openhands/core/exceptions.py
+++ b/openhands/core/exceptions.py
@@ -77,3 +77,10 @@ def __init__(self, message='User cancelled the request'):
 class MicroAgentValidationError(Exception):
     def __init__(self, message='Micro agent validation failed'):
         super().__init__(message)
+
+
+class OperationCancelled(Exception):
+    """Exception raised when an operation is cancelled (e.g. by a keyboard interrupt)."""
+
+    def __init__(self, message='Operation was cancelled'):
+        super().__init__(message)
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 1d5e59422304..b2dc0e7e573c 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -24,15 +24,21 @@
 from tenacity import (
     retry,
     retry_if_exception_type,
+    retry_if_not_exception_type,
     stop_after_attempt,
     wait_exponential,
 )
 
-from openhands.core.exceptions import LLMResponseError, UserCancelledError
+from openhands.core.exceptions import (
+    LLMResponseError,
+    OperationCancelled,
+    UserCancelledError,
+)
 from openhands.core.logger import llm_prompt_logger, llm_response_logger
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message
 from openhands.core.metrics import Metrics
+from openhands.runtime.utils.shutdown_listener import should_exit
 
 __all__ = ['LLM']
 
@@ -169,13 +175,18 @@ def __init__(
 
         completion_unwrapped = self._completion
 
-        def attempt_on_error(retry_state):
-            """Custom attempt function for litellm completion."""
+        def log_retry_attempt(retry_state):
+            """With before_sleep, this is called before `custom_completion_wait` and
+            ONLY if the retry is triggered by an exception."""
+            if should_exit():
+                raise OperationCancelled(
+                    'Operation cancelled.'
+                )  # exits the @retry loop
+            exception = retry_state.outcome.exception()
             logger.error(
-                f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
+                f'{exception}. Attempt #{retry_state.attempt_number} | You can customize retry values in the configuration.',
                 exc_info=False,
             )
-            return None
 
         def custom_completion_wait(retry_state):
             """Custom wait function for litellm completion."""
@@ -211,10 +222,13 @@ def custom_completion_wait(retry_state):
             return exponential_wait(retry_state)
 
         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         def wrapper(*args, **kwargs):
@@ -278,10 +292,13 @@ def wrapper(*args, **kwargs):
         async_completion_unwrapped = self._async_completion
 
         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_completion_wrapper(*args, **kwargs):
@@ -351,10 +368,13 @@ async def check_stopped():
                 pass
 
         @retry(
-            after=attempt_on_error,
+            before_sleep=log_retry_attempt,
             stop=stop_after_attempt(self.config.num_retries),
             reraise=True,
-            retry=retry_if_exception_type(self.retry_exceptions),
+            retry=(
+                retry_if_exception_type(self.retry_exceptions)
+                & retry_if_not_exception_type(OperationCancelled)
+            ),
             wait=custom_completion_wait,
         )
         async def async_acompletion_stream_wrapper(*args, **kwargs):
@@ -448,6 +468,9 @@ def _format_content_element(self, element):
         return str(element)
 
     async def _call_acompletion(self, *args, **kwargs):
+        """This is a wrapper for the litellm acompletion function which
+        makes it mockable for testing.
+ """ return await litellm.acompletion(*args, **kwargs) @property @@ -528,10 +551,15 @@ def _post_completion(self, response) -> None: output_tokens = usage.get('completion_tokens') if input_tokens: - stats += 'Input tokens: ' + str(input_tokens) + '\n' + stats += 'Input tokens: ' + str(input_tokens) if output_tokens: - stats += 'Output tokens: ' + str(output_tokens) + '\n' + stats += ( + (' | ' if input_tokens else '') + + 'Output tokens: ' + + str(output_tokens) + + '\n' + ) model_extra = usage.get('model_extra', {}) diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py index a08c6a84974f..827e27676a73 100644 --- a/tests/unit/test_llm.py +++ b/tests/unit/test_llm.py @@ -1,15 +1,38 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest +from litellm.exceptions import ( + APIConnectionError, + ContentPolicyViolationError, + InternalServerError, + OpenAIError, + RateLimitError, +) from openhands.core.config import LLMConfig +from openhands.core.exceptions import OperationCancelled from openhands.core.metrics import Metrics from openhands.llm.llm import LLM +@pytest.fixture(autouse=True) +def mock_logger(monkeypatch): + # suppress logging of completion data to file + mock_logger = MagicMock() + monkeypatch.setattr('openhands.llm.llm.llm_prompt_logger', mock_logger) + monkeypatch.setattr('openhands.llm.llm.llm_response_logger', mock_logger) + return mock_logger + + @pytest.fixture def default_config(): - return LLMConfig(model='gpt-4o', api_key='test_key') + return LLMConfig( + model='gpt-4o', + api_key='test_key', + num_retries=2, + retry_min_wait=1, + retry_max_wait=2, + ) def test_llm_init_with_default_config(default_config): @@ -64,7 +87,7 @@ def test_llm_init_with_metrics(): def test_llm_reset(): - llm = LLM(LLMConfig(model='gpt-3.5-turbo', api_key='test_key')) + llm = LLM(LLMConfig(model='gpt-4o-mini', api_key='test_key')) initial_metrics = llm.metrics llm.reset() assert llm.metrics is not initial_metrics @@ -73,7 +96,7 @@ def test_llm_reset(): @patch('openhands.llm.llm.litellm.get_model_info') def test_llm_init_with_openrouter_model(mock_get_model_info, default_config): - default_config.model = 'openrouter:gpt-3.5-turbo' + default_config.model = 'openrouter:gpt-4o-mini' mock_get_model_info.return_value = { 'max_input_tokens': 7000, 'max_output_tokens': 1500, @@ -81,4 +104,197 @@ def test_llm_init_with_openrouter_model(mock_get_model_info, default_config): llm = LLM(default_config) assert llm.config.max_input_tokens == 7000 assert llm.config.max_output_tokens == 1500 - mock_get_model_info.assert_called_once_with('openrouter:gpt-3.5-turbo') + mock_get_model_info.assert_called_once_with('openrouter:gpt-4o-mini') + + +# Tests involving completion and retries + + +@patch('openhands.llm.llm.litellm_completion') +def test_completion_with_mocked_logger( + mock_litellm_completion, default_config, mock_logger +): + mock_litellm_completion.return_value = { + 'choices': [{'message': {'content': 'Test response'}}] + } + + llm = LLM(config=default_config) + response = llm.completion( + messages=[{'role': 'user', 'content': 'Hello!'}], + stream=False, + ) + + assert response['choices'][0]['message']['content'] == 'Test response' + assert mock_litellm_completion.call_count == 1 + + mock_logger.debug.assert_called() + + +@pytest.mark.parametrize( + 'exception_class,extra_args,expected_retries', + [ + ( + APIConnectionError, + {'llm_provider': 'test_provider', 'model': 'test_model'}, + 2, + ), + ( + ContentPolicyViolationError, + {'model': 'test_model', 
+            2,
+        ),
+        (
+            InternalServerError,
+            {'llm_provider': 'test_provider', 'model': 'test_model'},
+            2,
+        ),
+        (OpenAIError, {}, 2),
+        (RateLimitError, {'llm_provider': 'test_provider', 'model': 'test_model'}, 2),
+    ],
+)
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_retries(
+    mock_litellm_completion,
+    default_config,
+    exception_class,
+    extra_args,
+    expected_retries,
+):
+    mock_litellm_completion.side_effect = [
+        exception_class('Test error message', **extra_args),
+        {'choices': [{'message': {'content': 'Retry successful'}}]},
+    ]
+
+    llm = LLM(config=default_config)
+    response = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert response['choices'][0]['message']['content'] == 'Retry successful'
+    assert mock_litellm_completion.call_count == expected_retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_rate_limit_wait_time(mock_litellm_completion, default_config):
+    with patch('time.sleep') as mock_sleep:
+        mock_litellm_completion.side_effect = [
+            RateLimitError(
+                'Rate limit exceeded', llm_provider='test_provider', model='test_model'
+            ),
+            {'choices': [{'message': {'content': 'Retry successful'}}]},
+        ]
+
+        llm = LLM(config=default_config)
+        response = llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+        assert response['choices'][0]['message']['content'] == 'Retry successful'
+        assert mock_litellm_completion.call_count == 2
+
+        mock_sleep.assert_called_once()
+        wait_time = mock_sleep.call_args[0][0]
+        assert (
+            60 <= wait_time <= 240
+        ), f'Expected wait time between 60 and 240 seconds, but got {wait_time}'
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_exhausts_retries(mock_litellm_completion, default_config):
+    mock_litellm_completion.side_effect = APIConnectionError(
+        'Persistent error', llm_provider='test_provider', model='test_model'
+    )
+
+    llm = LLM(config=default_config)
+    with pytest.raises(APIConnectionError):
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+    assert mock_litellm_completion.call_count == llm.config.num_retries
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_operation_cancelled(mock_litellm_completion, default_config):
+    mock_litellm_completion.side_effect = OperationCancelled('Operation cancelled')
+
+    llm = LLM(config=default_config)
+    with pytest.raises(OperationCancelled):
+        llm.completion(
+            messages=[{'role': 'user', 'content': 'Hello!'}],
+            stream=False,
+        )
+
+    assert mock_litellm_completion.call_count == 1
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_keyboard_interrupt(mock_litellm_completion, default_config):
+    def side_effect(*args, **kwargs):
+        raise KeyboardInterrupt('Simulated KeyboardInterrupt')
+
+    mock_litellm_completion.side_effect = side_effect
+
+    llm = LLM(config=default_config)
+    with pytest.raises(OperationCancelled):
+        try:
+            llm.completion(
+                messages=[{'role': 'user', 'content': 'Hello!'}],
+                stream=False,
+            )
+        except KeyboardInterrupt:
+            raise OperationCancelled('Operation cancelled due to KeyboardInterrupt')
+
+    assert mock_litellm_completion.call_count == 1
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_keyboard_interrupt_handler(mock_litellm_completion, default_config):
+    global _should_exit
+
+    def side_effect(*args, **kwargs):
+        global _should_exit
+        _should_exit = True
+        return {'choices': [{'message': {'content': 'Simulated interrupt response'}}]}
+
+    mock_litellm_completion.side_effect = side_effect
+
+    llm = LLM(config=default_config)
+    result = llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+    )
+
+    assert mock_litellm_completion.call_count == 1
+    assert result['choices'][0]['message']['content'] == 'Simulated interrupt response'
+    assert _should_exit
+
+    _should_exit = False
+
+
+@patch('openhands.llm.llm.litellm_completion')
+def test_completion_with_litellm_mock(mock_litellm_completion, default_config):
+    mock_response = {
+        'choices': [{'message': {'content': 'This is a mocked response.'}}]
+    }
+    mock_litellm_completion.return_value = mock_response
+
+    test_llm = LLM(config=default_config)
+    response = test_llm.completion(
+        messages=[{'role': 'user', 'content': 'Hello!'}],
+        stream=False,
+        drop_params=True,
+    )
+
+    # Assertions
+    assert response['choices'][0]['message']['content'] == 'This is a mocked response.'
+    mock_litellm_completion.assert_called_once()
+
+    # Check if the correct arguments were passed to litellm_completion
+    call_args = mock_litellm_completion.call_args[1]  # Get keyword arguments
+    assert call_args['model'] == default_config.model
+    assert call_args['messages'] == [{'role': 'user', 'content': 'Hello!'}]
+    assert not call_args['stream']
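
Note: the sketch below is not part of the patch. It is a minimal, self-contained illustration of the tenacity pattern the llm.py changes rely on: retry conditions combined with `&` so transient errors are retried while `OperationCancelled` never is, and a `before_sleep` hook that raises to break out of the retry loop when a shutdown flag is set. The names `flaky_call` and the module-level `_should_exit` flag are illustrative stand-ins (for `llm.completion` and `shutdown_listener.should_exit()`), not code from the repository.

```python
from tenacity import (
    retry,
    retry_if_exception_type,
    retry_if_not_exception_type,
    stop_after_attempt,
    wait_exponential,
)


class OperationCancelled(Exception):
    """Stand-in for openhands.core.exceptions.OperationCancelled."""


_should_exit = False  # stand-in for shutdown_listener.should_exit()


def log_retry_attempt(retry_state):
    # before_sleep hooks run only when another attempt will follow; raising here
    # propagates out of the @retry loop immediately instead of sleeping.
    if _should_exit:
        raise OperationCancelled('Operation cancelled.')
    print(f'{retry_state.outcome.exception()}. Attempt #{retry_state.attempt_number}')


@retry(
    before_sleep=log_retry_attempt,
    stop=stop_after_attempt(3),
    reraise=True,
    retry=(
        retry_if_exception_type(ConnectionError)  # retry transient errors...
        & retry_if_not_exception_type(OperationCancelled)  # ...but never a cancellation
    ),
    wait=wait_exponential(multiplier=1, min=1, max=2),
)
def flaky_call():
    raise ConnectionError('transient failure')


if __name__ == '__main__':
    try:
        flaky_call()
    except ConnectionError as e:
        print(f'gave up after retries: {e}')
```

Because `before_sleep` only runs when another attempt is about to be scheduled, the cancellation check adds no overhead to a successful call or to the final, re-raised failure.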