Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDESK-7371] Support eager mode in async celery #2745

Merged
merged 1 commit into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 34 additions & 5 deletions superdesk/celery_app/context_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ def run_async(self, *args: Any, **kwargs: Any) -> Any:
If the event loop is running, returns an asyncio.Task that represents the execution of the coroutine.
Otherwise it runs the tasks and returns the result of the task.
"""

loop = asyncio.get_event_loop()
is_always_eager = self._is_always_eager()

# We need a wrapper to handle exceptions inside the async function because asyncio
# does not propagate them in the same way as synchronous exceptions. This ensures that
Expand All @@ -67,10 +66,36 @@ async def wrapper():
self.handle_exception(e)
return None

if not loop.is_running():
return loop.run_until_complete(wrapper())
if is_always_eager:
return asyncio.create_task(wrapper())
else:
background_tasks = set()
loop = asyncio.get_event_loop()

# the loop might not be running even if `CELERY_TASK_ALWAYS_EAGER` is False
if not loop.is_running():
return loop.run_until_complete(wrapper())

# **Important** from asyncio documentation
# https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
# Save a reference to the result of this function, to avoid a task disappearing mid-execution.
# The event loop only keeps weak references to tasks. A task that isn’t referenced elsewhere may get
# garbage collected at any time, even before it’s done
task = asyncio.create_task(wrapper())
background_tasks.add(task)
task.add_done_callback(background_tasks.discard)
return task

async def apply_async(self, args: Tuple = (), kwargs: Dict = {}, **other_kwargs) -> Any:
"""
Schedules the task asynchronously. Awaits the result if `CELERY_TASK_ALWAYS_EAGER` is True.
"""
# directly run and await the task if eager
if self._is_always_eager():
async_result = super().apply_async(args=args, kwargs=kwargs, **other_kwargs)
return await async_result.get()

return asyncio.create_task(wrapper())
return super().apply_async(args=args, kwargs=kwargs, **other_kwargs)

def handle_exception(self, exc: Exception) -> None:
"""
Expand All @@ -85,3 +110,7 @@ def on_failure(self, exc: Exception, task_id: str, args: Tuple, kwargs: Dict, ei
# TODO-ASYNC: Support async with ``on_failure`` method
# async with self.get_current_app().app_context():
self.handle_exception(exc)

def _is_always_eager(self):
app = self.get_current_app()
return app.config.get("CELERY_TASK_ALWAYS_EAGER", False)
14 changes: 8 additions & 6 deletions tests/celery_app/context_task_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@

from superdesk.errors import SuperdeskError
from superdesk.celery_app import HybridAppContextTask
from superdesk.tests import AsyncFlaskTestCase, markers
from superdesk.tests import AsyncFlaskTestCase

# NOTE: all tasks below are in eager mode because of global
# tests settings. See `update_config` function in tests.__init__.py


@markers.requires_async_celery
class TestHybridAppContextTask(AsyncFlaskTestCase):
async def test_sync_task(self):
@self.app.celery.task(base=HybridAppContextTask)
def sync_task():
return "sync result"

result = sync_task.apply_async().get()
result = await sync_task.apply_async()
self.assertEqual(result, "sync result")

async def test_async_task(self):
Expand All @@ -22,7 +24,7 @@ async def async_task():
await asyncio.sleep(0.1)
return "async result"

result = await async_task.apply_async().get()
result = await async_task.apply_async()
self.assertEqual(result, "async result")

async def test_sync_task_exception(self):
Expand All @@ -31,7 +33,7 @@ def sync_task_exception():
raise SuperdeskError("Test exception")

with patch("superdesk.celery_app.context_task.logger") as mock_logger:
sync_task_exception.apply_async().get(propagate=True)
await sync_task_exception.apply_async()
expected_exc = SuperdeskError("Test exception")
expected_msg = f"Error handling task: {str(expected_exc)}"
mock_logger.exception.assert_called_once_with(expected_msg)
Expand All @@ -42,7 +44,7 @@ async def async_task_exception():
raise SuperdeskError("Async exception")

with patch("superdesk.celery_app.context_task.logger") as mock_logger:
await async_task_exception.apply_async().get()
await async_task_exception.apply_async()

expected_exc = SuperdeskError("Async exception")
expected_msg = f"Error handling task: {str(expected_exc)}"
Expand Down
Loading