Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the behavior consistent for asyncio and asynq: raised error. #136

Merged
merged 5 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions asynq/asynq_to_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,21 @@ def is_asyncio_mode():
return _asyncio_mode > 0


async def _gather(awaitables):
"""Gather awaitables, but wait all other awaitables to finish even if some of them fail."""

tasks = [asyncio.ensure_future(awaitable) for awaitable in awaitables]

# Wait for all tasks to finish, even if some of them fail.
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)

# mark exceptions are retrieved.
for task in tasks:
task.exception()

return [task.result() for task in tasks]


async def resolve_awaitables(x: Any):
"""
Resolve a possibly-nested collection of awaitables.
Expand All @@ -38,13 +53,11 @@ async def resolve_awaitables(x: Any):
if isinstance(x, BatchItemBase):
raise RuntimeError("asynq BatchItem is not supported in asyncio mode")
if isinstance(x, list):
return await asyncio.gather(*[resolve_awaitables(item) for item in x])
return await _gather([resolve_awaitables(item) for item in x])
if isinstance(x, tuple):
return tuple(await asyncio.gather(*[resolve_awaitables(item) for item in x]))
return tuple(await _gather([resolve_awaitables(item) for item in x]))
if isinstance(x, dict):
resolutions = await asyncio.gather(
*[resolve_awaitables(value) for value in x.values()]
)
resolutions = await _gather([resolve_awaitables(value) for value in x.values()])
return {key: resolution for (key, resolution) in zip(x.keys(), resolutions)}
if x is None:
return None
Expand Down
18 changes: 15 additions & 3 deletions asynq/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,29 @@ def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]:
async def wrapped(*_args, **_kwargs):
task = asyncio.current_task()
with asynq_to_async.AsyncioMode():
send = None
send, exception = None, None

generator = self.fn(*_args, **_kwargs)
while True:
resume_contexts_asyncio(task)
try:
result = generator.send(send)
if exception is None:
result = generator.send(send)
else:
result = generator.throw(
type(exception),
exception,
exception.__traceback__,
)
except StopIteration as exc:
return exc.value

pause_contexts_asyncio(task)
send = await asynq_to_async.resolve_awaitables(result)
try:
send = await asynq_to_async.resolve_awaitables(result)
exception = None
except Exception as exc:
exception = exc

self.asyncio_fn = wrapped
else:
Expand Down
43 changes: 43 additions & 0 deletions asynq/tests/test_asynq_to_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import asyncio
import pytest
import time

from qcore.asserts import assert_eq
Expand Down Expand Up @@ -51,6 +52,48 @@ def g(x):
assert asyncio.run(g.asyncio(5)) == {"a": [1, 2], "b": (3, 4), "c": 5, "d": 200}


def test_asyncio_exception():
call_count = 0

async def func_success_async():
nonlocal call_count
await asyncio.sleep(0.25)
call_count += 1

@asynq.asynq(asyncio_fn=func_success_async)
def func_success():
raise NotImplementedError()

async def func_fail_async():
nonlocal call_count
await asyncio.sleep(0.05)
call_count += 1
assert False

@asynq.asynq(asyncio_fn=func_fail_async)
def func_fail():
raise NotImplementedError()

@asynq.asynq()
def func_main():
with pytest.raises(AssertionError):
# func_fail will fail earlier than func_success
# but this statement should wait for all tasks to finish.
yield [
func_success.asynq(),
func_fail.asynq(),
func_success.asynq(),
func_fail.asynq(),
func_success.asynq(),
func_success.asynq(),
func_success.asynq(),
func_fail.asynq(),
]

asyncio.run(func_main.asyncio())
assert call_count == 8


def test_context():
async def blocking_op():
await asyncio.sleep(0.1)
Expand Down
Loading