Skip to content

Commit

Permalink
Propagate raised error to caller, cancel all remaining futures
Browse files Browse the repository at this point in the history
  • Loading branch information
SoulTch committed Mar 2, 2024
1 parent ebbc65f commit c6a38b8
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 6 deletions.
16 changes: 13 additions & 3 deletions asynq/asynq_to_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ def is_asyncio_mode():
return _asyncio_mode > 0


async def _gather(*awaitables):
"""Gather awaitables, but cancel all of them if one of them fails."""
try:
futures = asyncio.gather(*awaitables)
return await futures
except Exception:
futures.cancel()
raise


async def resolve_awaitables(x: Any):
"""
Resolve a possibly-nested collection of awaitables.
Expand All @@ -38,11 +48,11 @@ async def resolve_awaitables(x: Any):
if isinstance(x, BatchItemBase):
raise RuntimeError("asynq BatchItem is not supported in asyncio mode")
if isinstance(x, list):
return await asyncio.gather(*[resolve_awaitables(item) for item in x])
return await _gather(*[resolve_awaitables(item) for item in x])
if isinstance(x, tuple):
return tuple(await asyncio.gather(*[resolve_awaitables(item) for item in x]))
return tuple(await _gather(*[resolve_awaitables(item) for item in x]))
if isinstance(x, dict):
resolutions = await asyncio.gather(
resolutions = await _gather(
*[resolve_awaitables(value) for value in x.values()]
)
return {key: resolution for (key, resolution) in zip(x.keys(), resolutions)}
Expand Down
20 changes: 17 additions & 3 deletions asynq/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import asyncio
import inspect
import sys
from typing import Any, Coroutine

import qcore.decorators
Expand Down Expand Up @@ -141,17 +142,30 @@ def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]:
async def wrapped(*_args, **_kwargs):
task = asyncio.current_task()
with asynq_to_async.AsyncioMode():
send = None
send, exception = None, None

generator = self.fn(*_args, **_kwargs)
while True:
resume_contexts_asyncio(task)
try:
result = generator.send(send)
if exception is None:
result = generator.send(send)
else:
result = generator.throw(
type(exception),
exception,
exception.__traceback__,
)
except StopIteration as exc:
return exc.value

pause_contexts_asyncio(task)
send = await asynq_to_async.resolve_awaitables(result)
try:
send = await asynq_to_async.resolve_awaitables(result)
exception = None
except Exception as exc:
traceback = sys.exc_info()[2]
exception = exc

self.asyncio_fn = wrapped
else:
Expand Down
21 changes: 21 additions & 0 deletions asynq/tests/test_asynq_to_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@


import asyncio
import pytest
import time

from qcore.asserts import assert_eq
Expand Down Expand Up @@ -51,6 +52,26 @@ def g(x):
assert asyncio.run(g.asyncio(5)) == {"a": [1, 2], "b": (3, 4), "c": 5, "d": 200}


def test_asyncio_exception():
async def f2_async():
assert False

@asynq.asynq(asyncio_fn=f2_async)
def f2():
assert False

@asynq.asynq()
def f3():
yield f2.asynq()

@asynq.asynq()
def f():
with pytest.raises(AssertionError):
yield [f3.asynq(), f3.asynq()]

asyncio.run(f.asyncio())


def test_context():
async def blocking_op():
await asyncio.sleep(0.1)
Expand Down

0 comments on commit c6a38b8

Please sign in to comment.