diff --git a/asynq/asynq_to_async.py b/asynq/asynq_to_async.py index 9885fa7..3aebd5e 100644 --- a/asynq/asynq_to_async.py +++ b/asynq/asynq_to_async.py @@ -27,6 +27,16 @@ def is_asyncio_mode(): return _asyncio_mode > 0 +async def _gather(*awaitables): + """Gather awaitables, but cancel all of them if one of them fails.""" + try: + futures = asyncio.gather(*awaitables) + return await futures + except Exception: + futures.cancel() + raise + + async def resolve_awaitables(x: Any): """ Resolve a possibly-nested collection of awaitables. @@ -38,11 +48,11 @@ async def resolve_awaitables(x: Any): if isinstance(x, BatchItemBase): raise RuntimeError("asynq BatchItem is not supported in asyncio mode") if isinstance(x, list): - return await asyncio.gather(*[resolve_awaitables(item) for item in x]) + return await _gather(*[resolve_awaitables(item) for item in x]) if isinstance(x, tuple): - return tuple(await asyncio.gather(*[resolve_awaitables(item) for item in x])) + return tuple(await _gather(*[resolve_awaitables(item) for item in x])) if isinstance(x, dict): - resolutions = await asyncio.gather( + resolutions = await _gather( *[resolve_awaitables(value) for value in x.values()] ) return {key: resolution for (key, resolution) in zip(x.keys(), resolutions)} diff --git a/asynq/decorators.py b/asynq/decorators.py index f7233e6..e8f1fc9 100644 --- a/asynq/decorators.py +++ b/asynq/decorators.py @@ -15,6 +15,7 @@ import asyncio import inspect +import sys from typing import Any, Coroutine import qcore.decorators @@ -141,17 +142,30 @@ def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]: async def wrapped(*_args, **_kwargs): task = asyncio.current_task() with asynq_to_async.AsyncioMode(): - send = None + send, exception = None, None + generator = self.fn(*_args, **_kwargs) while True: resume_contexts_asyncio(task) try: - result = generator.send(send) + if exception is None: + result = generator.send(send) + else: + result = generator.throw( + type(exception), + exception, + exception.__traceback__, + ) except StopIteration as exc: return exc.value pause_contexts_asyncio(task) - send = await asynq_to_async.resolve_awaitables(result) + try: + send = await asynq_to_async.resolve_awaitables(result) + exception = None + except Exception as exc: + traceback = sys.exc_info()[2] + exception = exc self.asyncio_fn = wrapped else: diff --git a/asynq/tests/test_asynq_to_async.py b/asynq/tests/test_asynq_to_async.py index c6194f9..31b0b19 100644 --- a/asynq/tests/test_asynq_to_async.py +++ b/asynq/tests/test_asynq_to_async.py @@ -14,6 +14,7 @@ import asyncio +import pytest import time from qcore.asserts import assert_eq @@ -51,6 +52,26 @@ def g(x): assert asyncio.run(g.asyncio(5)) == {"a": [1, 2], "b": (3, 4), "c": 5, "d": 200} +def test_asyncio_exception(): + async def f2_async(): + assert False + + @asynq.asynq(asyncio_fn=f2_async) + def f2(): + assert False + + @asynq.asynq() + def f3(): + yield f2.asynq() + + @asynq.asynq() + def f(): + with pytest.raises(AssertionError): + yield [f3.asynq(), f3.asynq()] + + asyncio.run(f.asyncio()) + + def test_context(): async def blocking_op(): await asyncio.sleep(0.1)