diff --git a/superdesk/celery_app/context_task.py b/superdesk/celery_app/context_task.py index ee6fe927b8..c023e57661 100644 --- a/superdesk/celery_app/context_task.py +++ b/superdesk/celery_app/context_task.py @@ -52,8 +52,7 @@ def run_async(self, *args: Any, **kwargs: Any) -> Any: If the event loop is running, returns an asyncio.Task that represents the execution of the coroutine. Otherwise it runs the tasks and returns the result of the task. """ - - loop = asyncio.get_event_loop() + is_always_eager = self._is_always_eager() # We need a wrapper to handle exceptions inside the async function because asyncio # does not propagate them in the same way as synchronous exceptions. This ensures that @@ -67,10 +66,36 @@ async def wrapper(): self.handle_exception(e) return None - if not loop.is_running(): - return loop.run_until_complete(wrapper()) + if is_always_eager: + return asyncio.create_task(wrapper()) + else: + background_tasks = set() + loop = asyncio.get_event_loop() + + # the loop might not be running even if `CELERY_TASK_ALWAYS_EAGER` is False + if not loop.is_running(): + return loop.run_until_complete(wrapper()) + + # **Important** from asyncio documentation + # https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + # Save a reference to the result of this function, to avoid a task disappearing mid-execution. + # The event loop only keeps weak references to tasks. A task that isn’t referenced elsewhere may get + # garbage collected at any time, even before it’s done + task = asyncio.create_task(wrapper()) + background_tasks.add(task) + task.add_done_callback(background_tasks.discard) + return task + + async def apply_async(self, args: Tuple = (), kwargs: Dict = {}, **other_kwargs) -> Any: + """ + Schedules the task asynchronously. Awaits the result if `CELERY_TASK_ALWAYS_EAGER` is True. + """ + # directly run and await the task if eager + if self._is_always_eager(): + async_result = super().apply_async(args=args, kwargs=kwargs, **other_kwargs) + return await async_result.get() - return asyncio.create_task(wrapper()) + return super().apply_async(args=args, kwargs=kwargs, **other_kwargs) def handle_exception(self, exc: Exception) -> None: """ @@ -85,3 +110,7 @@ def on_failure(self, exc: Exception, task_id: str, args: Tuple, kwargs: Dict, ei # TODO-ASYNC: Support async with ``on_failure`` method # async with self.get_current_app().app_context(): self.handle_exception(exc) + + def _is_always_eager(self): + app = self.get_current_app() + return app.config.get("CELERY_TASK_ALWAYS_EAGER", False) diff --git a/tests/celery_app/context_task_test.py b/tests/celery_app/context_task_test.py index d50d8d0953..2de25c81c0 100644 --- a/tests/celery_app/context_task_test.py +++ b/tests/celery_app/context_task_test.py @@ -3,17 +3,19 @@ from superdesk.errors import SuperdeskError from superdesk.celery_app import HybridAppContextTask -from superdesk.tests import AsyncFlaskTestCase, markers +from superdesk.tests import AsyncFlaskTestCase + +# NOTE: all tasks below are in eager mode because of global +# tests settings. See `update_config` function in tests.__init__.py -@markers.requires_async_celery class TestHybridAppContextTask(AsyncFlaskTestCase): async def test_sync_task(self): @self.app.celery.task(base=HybridAppContextTask) def sync_task(): return "sync result" - result = sync_task.apply_async().get() + result = await sync_task.apply_async() self.assertEqual(result, "sync result") async def test_async_task(self): @@ -22,7 +24,7 @@ async def async_task(): await asyncio.sleep(0.1) return "async result" - result = await async_task.apply_async().get() + result = await async_task.apply_async() self.assertEqual(result, "async result") async def test_sync_task_exception(self): @@ -31,7 +33,7 @@ def sync_task_exception(): raise SuperdeskError("Test exception") with patch("superdesk.celery_app.context_task.logger") as mock_logger: - sync_task_exception.apply_async().get(propagate=True) + await sync_task_exception.apply_async() expected_exc = SuperdeskError("Test exception") expected_msg = f"Error handling task: {str(expected_exc)}" mock_logger.exception.assert_called_once_with(expected_msg) @@ -42,7 +44,7 @@ async def async_task_exception(): raise SuperdeskError("Async exception") with patch("superdesk.celery_app.context_task.logger") as mock_logger: - await async_task_exception.apply_async().get() + await async_task_exception.apply_async() expected_exc = SuperdeskError("Async exception") expected_msg = f"Error handling task: {str(expected_exc)}"