diff --git a/asynq/asynq_to_async.py b/asynq/asynq_to_async.py index 9885fa7..ce8aa37 100644 --- a/asynq/asynq_to_async.py +++ b/asynq/asynq_to_async.py @@ -27,6 +27,21 @@ def is_asyncio_mode(): return _asyncio_mode > 0 +async def _gather(awaitables): + """Gather awaitables, but wait all other awaitables to finish even if some of them fail.""" + + tasks = [asyncio.ensure_future(awaitable) for awaitable in awaitables] + + # Wait for all tasks to finish, even if some of them fail. + await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + + # mark exceptions are retrieved. + for task in tasks: + task.exception() + + return [task.result() for task in tasks] + + async def resolve_awaitables(x: Any): """ Resolve a possibly-nested collection of awaitables. @@ -38,13 +53,11 @@ async def resolve_awaitables(x: Any): if isinstance(x, BatchItemBase): raise RuntimeError("asynq BatchItem is not supported in asyncio mode") if isinstance(x, list): - return await asyncio.gather(*[resolve_awaitables(item) for item in x]) + return await _gather([resolve_awaitables(item) for item in x]) if isinstance(x, tuple): - return tuple(await asyncio.gather(*[resolve_awaitables(item) for item in x])) + return tuple(await _gather([resolve_awaitables(item) for item in x])) if isinstance(x, dict): - resolutions = await asyncio.gather( - *[resolve_awaitables(value) for value in x.values()] - ) + resolutions = await _gather([resolve_awaitables(value) for value in x.values()]) return {key: resolution for (key, resolution) in zip(x.keys(), resolutions)} if x is None: return None diff --git a/asynq/decorators.py b/asynq/decorators.py index f7233e6..a3b62c8 100644 --- a/asynq/decorators.py +++ b/asynq/decorators.py @@ -141,17 +141,29 @@ def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]: async def wrapped(*_args, **_kwargs): task = asyncio.current_task() with asynq_to_async.AsyncioMode(): - send = None + send, exception = None, None + generator = self.fn(*_args, **_kwargs) while True: resume_contexts_asyncio(task) try: - result = generator.send(send) + if exception is None: + result = generator.send(send) + else: + result = generator.throw( + type(exception), + exception, + exception.__traceback__, + ) except StopIteration as exc: return exc.value pause_contexts_asyncio(task) - send = await asynq_to_async.resolve_awaitables(result) + try: + send = await asynq_to_async.resolve_awaitables(result) + exception = None + except Exception as exc: + exception = exc self.asyncio_fn = wrapped else: diff --git a/asynq/tests/test_asynq_to_async.py b/asynq/tests/test_asynq_to_async.py index c6194f9..df2f32e 100644 --- a/asynq/tests/test_asynq_to_async.py +++ b/asynq/tests/test_asynq_to_async.py @@ -14,6 +14,7 @@ import asyncio +import pytest import time from qcore.asserts import assert_eq @@ -51,6 +52,48 @@ def g(x): assert asyncio.run(g.asyncio(5)) == {"a": [1, 2], "b": (3, 4), "c": 5, "d": 200} +def test_asyncio_exception(): + call_count = 0 + + async def func_success_async(): + nonlocal call_count + await asyncio.sleep(0.25) + call_count += 1 + + @asynq.asynq(asyncio_fn=func_success_async) + def func_success(): + raise NotImplementedError() + + async def func_fail_async(): + nonlocal call_count + await asyncio.sleep(0.05) + call_count += 1 + assert False + + @asynq.asynq(asyncio_fn=func_fail_async) + def func_fail(): + raise NotImplementedError() + + @asynq.asynq() + def func_main(): + with pytest.raises(AssertionError): + # func_fail will fail earlier than func_success + # but this statement should wait for all tasks to finish. + yield [ + func_success.asynq(), + func_fail.asynq(), + func_success.asynq(), + func_fail.asynq(), + func_success.asynq(), + func_success.asynq(), + func_success.asynq(), + func_fail.asynq(), + ] + + asyncio.run(func_main.asyncio()) + assert call_count == 8 + + def test_context(): async def blocking_op(): await asyncio.sleep(0.1)