From d5f6c4a1c6a6b44a3761842c539b757d3f79dc86 Mon Sep 17 00:00:00 2001 From: dkang-quora Date: Thu, 14 Mar 2024 10:09:23 +0900 Subject: [PATCH] Fix AsyncProxyDecorator.asyncio() Fixes a bug where @asynq_proxy()-decorated functions have to be awaited one more time in the asyncio mode. --- asynq/decorators.py | 103 ++++++++++++++++------------- asynq/tests/test_asynq_to_async.py | 14 +++- 2 files changed, 69 insertions(+), 48 deletions(-) diff --git a/asynq/decorators.py b/asynq/decorators.py index a3b62c8..3afe92a 100644 --- a/asynq/decorators.py +++ b/asynq/decorators.py @@ -21,13 +21,9 @@ import qcore.helpers as core_helpers import qcore.inspection as core_inspection -from . import async_task, asynq_to_async, futures -from .asynq_to_async import is_asyncio_mode -from .contexts import ( - ASYNCIO_CONTEXT_FIELD, - pause_contexts_asyncio, - resume_contexts_asyncio, -) +from . import async_task, futures +from .asynq_to_async import AsyncioMode, is_asyncio_mode, resolve_awaitables +from .contexts import pause_contexts_asyncio, resume_contexts_asyncio __traceback_hide__ = True @@ -108,6 +104,44 @@ def get_async_or_sync_fn(fn): return fn +def convert_asynq_to_async(fn): + if inspect.isgeneratorfunction(fn): + + async def wrapped(*_args, **_kwargs): + task = asyncio.current_task() + with AsyncioMode(): + send, exception = None, None + + generator = fn(*_args, **_kwargs) + while True: + resume_contexts_asyncio(task) + try: + if exception is None: + result = generator.send(send) + else: + result = generator.throw( + type(exception), exception, exception.__traceback__ + ) + except StopIteration as exc: + return exc.value + + pause_contexts_asyncio(task) + try: + send = await resolve_awaitables(result) + exception = None + except Exception as exc: + exception = exc + + return wrapped + else: + + async def wrapped(*_args, **_kwargs): + with AsyncioMode(): + return fn(*_args, **_kwargs) + + return wrapped + + class PureAsyncDecoratorBinder(qcore.decorators.DecoratorBinder): def is_pure_async_fn(self): return True @@ -136,43 +170,7 @@ def _fn_wrapper(self, args, kwargs): def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]: if self.asyncio_fn is None: - if inspect.isgeneratorfunction(self.fn): - - async def wrapped(*_args, **_kwargs): - task = asyncio.current_task() - with asynq_to_async.AsyncioMode(): - send, exception = None, None - - generator = self.fn(*_args, **_kwargs) - while True: - resume_contexts_asyncio(task) - try: - if exception is None: - result = generator.send(send) - else: - result = generator.throw( - type(exception), - exception, - exception.__traceback__, - ) - except StopIteration as exc: - return exc.value - - pause_contexts_asyncio(task) - try: - send = await asynq_to_async.resolve_awaitables(result) - exception = None - except Exception as exc: - exception = exc - - self.asyncio_fn = wrapped - else: - - async def wrapped(*_args, **_kwargs): - with asynq_to_async.AsyncioMode(): - return self.fn(*_args, **_kwargs) - - self.asyncio_fn = wrapped + self.asyncio_fn = convert_asynq_to_async(self.fn) return self.asyncio_fn(*args, **kwargs) @@ -180,7 +178,7 @@ def __call__(self, *args, **kwargs): return self._call_pure(args, kwargs) def _call_pure(self, args, kwargs): - if asynq_to_async.is_asyncio_mode(): + if is_asyncio_mode(): return self.asyncio(*args, **kwargs) if not self.needs_wrapper: @@ -220,7 +218,7 @@ def name(self): return "@asynq()" def __call__(self, *args, **kwargs): - if asynq_to_async.is_asyncio_mode(): + if is_asyncio_mode(): raise RuntimeError("asyncio mode does not support synchronous calls") return self._call_pure(args, kwargs).value() @@ -265,8 +263,19 @@ def __init__(self, fn, asyncio_fn=None): # we don't need the task class but still need to pass it to the superclass AsyncDecorator.__init__(self, fn, None, asyncio_fn=asyncio_fn) + def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]: + if self.asyncio_fn is None: + asyncio_fn = convert_asynq_to_async(self.fn) + + async def unwrap_coroutine(*args, **kwargs): + return await (await asyncio_fn(*args, **kwargs)) + + self.asyncio_fn = unwrap_coroutine + + return self.asyncio_fn(*args, **kwargs) + def _call_pure(self, args, kwargs): - if asynq_to_async.is_asyncio_mode(): + if is_asyncio_mode(): return self.asyncio(*args, **kwargs) return self.fn(*args, **kwargs) diff --git a/asynq/tests/test_asynq_to_async.py b/asynq/tests/test_asynq_to_async.py index c5b73d3..78ad528 100644 --- a/asynq/tests/test_asynq_to_async.py +++ b/asynq/tests/test_asynq_to_async.py @@ -14,9 +14,9 @@ import asyncio -import pytest import time +import pytest from qcore.asserts import assert_eq import asynq @@ -190,6 +190,18 @@ def jj(x): assert jj.asynq(0).value() == 888 +def test_proxy_passthrough(): + @asynq.asynq() + def f2(): + return 100 + + @asynq.async_proxy() + def f1(): + return f2.asynq() + + assert asyncio.run(f1.asyncio()) == 100 + + def test_proxy_and_bind(): async def async_g(self, x): return x + 20 + B.SELF