Skip to content

Commit

Permalink
_gather: prevent printing "Task exception was never retrieved", add s…
Browse files Browse the repository at this point in the history
…ome comments and tests
  • Loading branch information
SoulTch committed Mar 5, 2024
1 parent c4a8790 commit ab142cc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
7 changes: 7 additions & 0 deletions asynq/asynq_to_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@ async def _gather(awaitables):
"""Gather awaitables, but wait all other awaitables to finish even if some of them fail."""

tasks = [asyncio.ensure_future(awaitable) for awaitable in awaitables]

# Wait for all tasks to finish, even if some of them fail.
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)

# mark exceptions are retrieved.
for task in tasks:
task.exception()

return [task.result() for task in tasks]


Expand Down
44 changes: 33 additions & 11 deletions asynq/tests/test_asynq_to_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,45 @@ def g(x):


def test_asyncio_exception():
async def f2_async():
assert False
call_count = 0

@asynq.asynq(asyncio_fn=f2_async)
def f2():
async def func_success_async():
nonlocal call_count
await asyncio.sleep(0.25)
call_count += 1

@asynq.asynq(asyncio_fn=func_success_async)
def func_success():
raise NotImplementedError()

async def func_fail_async():
nonlocal call_count
await asyncio.sleep(0.05)
call_count += 1
assert False

@asynq.asynq()
def f3():
yield f2.asynq()
@asynq.asynq(asyncio_fn=func_fail_async)
def func_fail():
raise NotImplementedError()

@asynq.asynq()
def f():
def func_main():
with pytest.raises(AssertionError):
yield [f3.asynq(), f3.asynq()]

asyncio.run(f.asyncio())
# func_fail will fail earlier than func_success
# but this statement should wait for all tasks to finish.
yield [
func_success.asynq(),
func_fail.asynq(),
func_success.asynq(),
func_fail.asynq(),
func_success.asynq(),
func_success.asynq(),
func_success.asynq(),
func_fail.asynq(),
]

asyncio.run(func_main.asyncio())
assert call_count == 8


def test_context():
Expand Down

0 comments on commit ab142cc

Please sign in to comment.