From c6a38b8e35e811fbc11372871b496a4a961e7d00 Mon Sep 17 00:00:00 2001 From: Joungjin Lee Date: Sat, 2 Mar 2024 17:34:00 +0900 Subject: [PATCH 1/5] Propagate raised error to caller, cancel all remaining futures --- asynq/asynq_to_async.py | 16 +++++++++++++--- asynq/decorators.py | 20 +++++++++++++++++--- asynq/tests/test_asynq_to_async.py | 21 +++++++++++++++++++++ 3 files changed, 51 insertions(+), 6 deletions(-) diff --git a/asynq/asynq_to_async.py b/asynq/asynq_to_async.py index 9885fa7..3aebd5e 100644 --- a/asynq/asynq_to_async.py +++ b/asynq/asynq_to_async.py @@ -27,6 +27,16 @@ def is_asyncio_mode(): return _asyncio_mode > 0 +async def _gather(*awaitables): + """Gather awaitables, but cancel all of them if one of them fails.""" + try: + futures = asyncio.gather(*awaitables) + return await futures + except Exception: + futures.cancel() + raise + + async def resolve_awaitables(x: Any): """ Resolve a possibly-nested collection of awaitables. @@ -38,11 +48,11 @@ async def resolve_awaitables(x: Any): if isinstance(x, BatchItemBase): raise RuntimeError("asynq BatchItem is not supported in asyncio mode") if isinstance(x, list): - return await asyncio.gather(*[resolve_awaitables(item) for item in x]) + return await _gather(*[resolve_awaitables(item) for item in x]) if isinstance(x, tuple): - return tuple(await asyncio.gather(*[resolve_awaitables(item) for item in x])) + return tuple(await _gather(*[resolve_awaitables(item) for item in x])) if isinstance(x, dict): - resolutions = await asyncio.gather( + resolutions = await _gather( *[resolve_awaitables(value) for value in x.values()] ) return {key: resolution for (key, resolution) in zip(x.keys(), resolutions)} diff --git a/asynq/decorators.py b/asynq/decorators.py index f7233e6..e8f1fc9 100644 --- a/asynq/decorators.py +++ b/asynq/decorators.py @@ -15,6 +15,7 @@ import asyncio import inspect +import sys from typing import Any, Coroutine import qcore.decorators @@ -141,17 +142,30 @@ def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]: async def wrapped(*_args, **_kwargs): task = asyncio.current_task() with asynq_to_async.AsyncioMode(): - send = None + send, exception = None, None + generator = self.fn(*_args, **_kwargs) while True: resume_contexts_asyncio(task) try: - result = generator.send(send) + if exception is None: + result = generator.send(send) + else: + result = generator.throw( + type(exception), + exception, + exception.__traceback__, + ) except StopIteration as exc: return exc.value pause_contexts_asyncio(task) - send = await asynq_to_async.resolve_awaitables(result) + try: + send = await asynq_to_async.resolve_awaitables(result) + exception = None + except Exception as exc: + traceback = sys.exc_info()[2] + exception = exc self.asyncio_fn = wrapped else: diff --git a/asynq/tests/test_asynq_to_async.py b/asynq/tests/test_asynq_to_async.py index c6194f9..31b0b19 100644 --- a/asynq/tests/test_asynq_to_async.py +++ b/asynq/tests/test_asynq_to_async.py @@ -14,6 +14,7 @@ import asyncio +import pytest import time from qcore.asserts import assert_eq @@ -51,6 +52,26 @@ def g(x): assert asyncio.run(g.asyncio(5)) == {"a": [1, 2], "b": (3, 4), "c": 5, "d": 200} +def test_asyncio_exception(): + async def f2_async(): + assert False + + @asynq.asynq(asyncio_fn=f2_async) + def f2(): + assert False + + @asynq.asynq() + def f3(): + yield f2.asynq() + + @asynq.asynq() + def f(): + with pytest.raises(AssertionError): + yield [f3.asynq(), f3.asynq()] + + asyncio.run(f.asyncio()) + + def test_context(): async def blocking_op(): await asyncio.sleep(0.1) From a7dd3eecaf9018fec0fdebea084d08e563b2d924 Mon Sep 17 00:00:00 2001 From: Joungjin Lee Date: Sat, 2 Mar 2024 18:44:36 +0900 Subject: [PATCH 2/5] remove unused variable --- asynq/decorators.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/asynq/decorators.py b/asynq/decorators.py index e8f1fc9..a3b62c8 100644 --- a/asynq/decorators.py +++ b/asynq/decorators.py @@ -15,7 +15,6 @@ import asyncio import inspect -import sys from typing import Any, Coroutine import qcore.decorators @@ -164,7 +163,6 @@ async def wrapped(*_args, **_kwargs): send = await asynq_to_async.resolve_awaitables(result) exception = None except Exception as exc: - traceback = sys.exc_info()[2] exception = exc self.asyncio_fn = wrapped From 31a444c74ee06c512f75b5e05a8b5e180ee39341 Mon Sep 17 00:00:00 2001 From: Joungjin Lee Date: Sat, 2 Mar 2024 22:06:06 +0900 Subject: [PATCH 3/5] Fix: wait for other tasks, don't cancel --- asynq/asynq_to_async.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/asynq/asynq_to_async.py b/asynq/asynq_to_async.py index 3aebd5e..cd2a84b 100644 --- a/asynq/asynq_to_async.py +++ b/asynq/asynq_to_async.py @@ -27,14 +27,12 @@ def is_asyncio_mode(): return _asyncio_mode > 0 -async def _gather(*awaitables): - """Gather awaitables, but cancel all of them if one of them fails.""" - try: - futures = asyncio.gather(*awaitables) - return await futures - except Exception: - futures.cancel() - raise +async def _gather(awaitables): + """Gather awaitables, but wait all other awaitables to finish even if some of them fail.""" + + tasks = [asyncio.ensure_future(awaitable) for awaitable in awaitables] + done, _ = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + return [task.result() for task in done] async def resolve_awaitables(x: Any): @@ -48,13 +46,11 @@ async def resolve_awaitables(x: Any): if isinstance(x, BatchItemBase): raise RuntimeError("asynq BatchItem is not supported in asyncio mode") if isinstance(x, list): - return await _gather(*[resolve_awaitables(item) for item in x]) + return await _gather([resolve_awaitables(item) for item in x]) if isinstance(x, tuple): - return tuple(await _gather(*[resolve_awaitables(item) for item in x])) + return tuple(await _gather([resolve_awaitables(item) for item in x])) if isinstance(x, dict): - resolutions = await _gather( - *[resolve_awaitables(value) for value in x.values()] - ) + resolutions = await _gather([resolve_awaitables(value) for value in x.values()]) return {key: resolution for (key, resolution) in zip(x.keys(), resolutions)} if x is None: return None From c4a8790555171b4e38f632e5424f1d66ad011c53 Mon Sep 17 00:00:00 2001 From: SoulTch Date: Tue, 5 Mar 2024 15:20:11 +0900 Subject: [PATCH 4/5] Update asynq_to_async.py Fix bug --- asynq/asynq_to_async.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/asynq/asynq_to_async.py b/asynq/asynq_to_async.py index cd2a84b..24e7c05 100644 --- a/asynq/asynq_to_async.py +++ b/asynq/asynq_to_async.py @@ -31,8 +31,8 @@ async def _gather(awaitables): """Gather awaitables, but wait all other awaitables to finish even if some of them fail.""" tasks = [asyncio.ensure_future(awaitable) for awaitable in awaitables] - done, _ = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) - return [task.result() for task in done] + await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + return [task.result() for task in tasks] async def resolve_awaitables(x: Any): From ab142ccb1dd0a14c7ac17b96e74e3cd17e56af01 Mon Sep 17 00:00:00 2001 From: Joungjin Lee Date: Tue, 5 Mar 2024 17:13:48 +0900 Subject: [PATCH 5/5] _gather: prevent printing "Task exception was never retrieved", add some comments and tests --- asynq/asynq_to_async.py | 7 +++++ asynq/tests/test_asynq_to_async.py | 44 ++++++++++++++++++++++-------- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/asynq/asynq_to_async.py b/asynq/asynq_to_async.py index 24e7c05..ce8aa37 100644 --- a/asynq/asynq_to_async.py +++ b/asynq/asynq_to_async.py @@ -31,7 +31,14 @@ async def _gather(awaitables): """Gather awaitables, but wait all other awaitables to finish even if some of them fail.""" tasks = [asyncio.ensure_future(awaitable) for awaitable in awaitables] + + # Wait for all tasks to finish, even if some of them fail. await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + + # mark exceptions are retrieved. + for task in tasks: + task.exception() + return [task.result() for task in tasks] diff --git a/asynq/tests/test_asynq_to_async.py b/asynq/tests/test_asynq_to_async.py index 31b0b19..df2f32e 100644 --- a/asynq/tests/test_asynq_to_async.py +++ b/asynq/tests/test_asynq_to_async.py @@ -53,23 +53,45 @@ def g(x): def test_asyncio_exception(): - async def f2_async(): - assert False + call_count = 0 - @asynq.asynq(asyncio_fn=f2_async) - def f2(): + async def func_success_async(): + nonlocal call_count + await asyncio.sleep(0.25) + call_count += 1 + + @asynq.asynq(asyncio_fn=func_success_async) + def func_success(): + raise NotImplementedError() + + async def func_fail_async(): + nonlocal call_count + await asyncio.sleep(0.05) + call_count += 1 assert False - @asynq.asynq() - def f3(): - yield f2.asynq() + @asynq.asynq(asyncio_fn=func_fail_async) + def func_fail(): + raise NotImplementedError() @asynq.asynq() - def f(): + def func_main(): with pytest.raises(AssertionError): - yield [f3.asynq(), f3.asynq()] - - asyncio.run(f.asyncio()) + # func_fail will fail earlier than func_success + # but this statement should wait for all tasks to finish. + yield [ + func_success.asynq(), + func_fail.asynq(), + func_success.asynq(), + func_fail.asynq(), + func_success.asynq(), + func_success.asynq(), + func_success.asynq(), + func_fail.asynq(), + ] + + asyncio.run(func_main.asyncio()) + assert call_count == 8 def test_context():