diff --git a/asynq/asynq_to_async.py b/asynq/asynq_to_async.py index 24e7c05..ce8aa37 100644 --- a/asynq/asynq_to_async.py +++ b/asynq/asynq_to_async.py @@ -31,7 +31,14 @@ async def _gather(awaitables): """Gather awaitables, but wait all other awaitables to finish even if some of them fail.""" tasks = [asyncio.ensure_future(awaitable) for awaitable in awaitables] + + # Wait for all tasks to finish, even if some of them fail. await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + + # mark exceptions are retrieved. + for task in tasks: + task.exception() + return [task.result() for task in tasks] diff --git a/asynq/tests/test_asynq_to_async.py b/asynq/tests/test_asynq_to_async.py index 31b0b19..df2f32e 100644 --- a/asynq/tests/test_asynq_to_async.py +++ b/asynq/tests/test_asynq_to_async.py @@ -53,23 +53,45 @@ def g(x): def test_asyncio_exception(): - async def f2_async(): - assert False + call_count = 0 - @asynq.asynq(asyncio_fn=f2_async) - def f2(): + async def func_success_async(): + nonlocal call_count + await asyncio.sleep(0.25) + call_count += 1 + + @asynq.asynq(asyncio_fn=func_success_async) + def func_success(): + raise NotImplementedError() + + async def func_fail_async(): + nonlocal call_count + await asyncio.sleep(0.05) + call_count += 1 assert False - @asynq.asynq() - def f3(): - yield f2.asynq() + @asynq.asynq(asyncio_fn=func_fail_async) + def func_fail(): + raise NotImplementedError() @asynq.asynq() - def f(): + def func_main(): with pytest.raises(AssertionError): - yield [f3.asynq(), f3.asynq()] - - asyncio.run(f.asyncio()) + # func_fail will fail earlier than func_success + # but this statement should wait for all tasks to finish. + yield [ + func_success.asynq(), + func_fail.asynq(), + func_success.asynq(), + func_fail.asynq(), + func_success.asynq(), + func_success.asynq(), + func_success.asynq(), + func_fail.asynq(), + ] + + asyncio.run(func_main.asyncio()) + assert call_count == 8 def test_context():