Skip to content

Commit

Permalink
Fix AsyncProxyDecorator.asyncio() (#138)
Browse files Browse the repository at this point in the history
Fixes a bug where @asynq_proxy()-decorated functions have to be awaited one more time in the asyncio mode.
  • Loading branch information
dkang-quora authored Mar 14, 2024
1 parent 169ee54 commit 9c1b2d2
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 48 deletions.
103 changes: 56 additions & 47 deletions asynq/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,9 @@
import qcore.helpers as core_helpers
import qcore.inspection as core_inspection

from . import async_task, asynq_to_async, futures
from .asynq_to_async import is_asyncio_mode
from .contexts import (
ASYNCIO_CONTEXT_FIELD,
pause_contexts_asyncio,
resume_contexts_asyncio,
)
from . import async_task, futures
from .asynq_to_async import AsyncioMode, is_asyncio_mode, resolve_awaitables
from .contexts import pause_contexts_asyncio, resume_contexts_asyncio

__traceback_hide__ = True

Expand Down Expand Up @@ -108,6 +104,44 @@ def get_async_or_sync_fn(fn):
return fn


def convert_asynq_to_async(fn):
if inspect.isgeneratorfunction(fn):

async def wrapped(*_args, **_kwargs):
task = asyncio.current_task()
with AsyncioMode():
send, exception = None, None

generator = fn(*_args, **_kwargs)
while True:
resume_contexts_asyncio(task)
try:
if exception is None:
result = generator.send(send)
else:
result = generator.throw(
type(exception), exception, exception.__traceback__
)
except StopIteration as exc:
return exc.value

pause_contexts_asyncio(task)
try:
send = await resolve_awaitables(result)
exception = None
except Exception as exc:
exception = exc

return wrapped
else:

async def wrapped(*_args, **_kwargs):
with AsyncioMode():
return fn(*_args, **_kwargs)

return wrapped


class PureAsyncDecoratorBinder(qcore.decorators.DecoratorBinder):
def is_pure_async_fn(self):
return True
Expand Down Expand Up @@ -136,51 +170,15 @@ def _fn_wrapper(self, args, kwargs):

def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]:
if self.asyncio_fn is None:
if inspect.isgeneratorfunction(self.fn):

async def wrapped(*_args, **_kwargs):
task = asyncio.current_task()
with asynq_to_async.AsyncioMode():
send, exception = None, None

generator = self.fn(*_args, **_kwargs)
while True:
resume_contexts_asyncio(task)
try:
if exception is None:
result = generator.send(send)
else:
result = generator.throw(
type(exception),
exception,
exception.__traceback__,
)
except StopIteration as exc:
return exc.value

pause_contexts_asyncio(task)
try:
send = await asynq_to_async.resolve_awaitables(result)
exception = None
except Exception as exc:
exception = exc

self.asyncio_fn = wrapped
else:

async def wrapped(*_args, **_kwargs):
with asynq_to_async.AsyncioMode():
return self.fn(*_args, **_kwargs)

self.asyncio_fn = wrapped
self.asyncio_fn = convert_asynq_to_async(self.fn)

return self.asyncio_fn(*args, **kwargs)

def __call__(self, *args, **kwargs):
return self._call_pure(args, kwargs)

def _call_pure(self, args, kwargs):
if asynq_to_async.is_asyncio_mode():
if is_asyncio_mode():
return self.asyncio(*args, **kwargs)

if not self.needs_wrapper:
Expand Down Expand Up @@ -220,7 +218,7 @@ def name(self):
return "@asynq()"

def __call__(self, *args, **kwargs):
if asynq_to_async.is_asyncio_mode():
if is_asyncio_mode():
raise RuntimeError("asyncio mode does not support synchronous calls")

return self._call_pure(args, kwargs).value()
Expand Down Expand Up @@ -265,8 +263,19 @@ def __init__(self, fn, asyncio_fn=None):
# we don't need the task class but still need to pass it to the superclass
AsyncDecorator.__init__(self, fn, None, asyncio_fn=asyncio_fn)

def asyncio(self, *args, **kwargs) -> Coroutine[Any, Any, Any]:
if self.asyncio_fn is None:
asyncio_fn = convert_asynq_to_async(self.fn)

async def unwrap_coroutine(*args, **kwargs):
return await (await asyncio_fn(*args, **kwargs))

self.asyncio_fn = unwrap_coroutine

return self.asyncio_fn(*args, **kwargs)

def _call_pure(self, args, kwargs):
if asynq_to_async.is_asyncio_mode():
if is_asyncio_mode():
return self.asyncio(*args, **kwargs)

return self.fn(*args, **kwargs)
Expand Down
14 changes: 13 additions & 1 deletion asynq/tests/test_asynq_to_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@


import asyncio
import pytest
import time

import pytest
from qcore.asserts import assert_eq

import asynq
Expand Down Expand Up @@ -190,6 +190,18 @@ def jj(x):
assert jj.asynq(0).value() == 888


def test_proxy_passthrough():
@asynq.asynq()
def f2():
return 100

@asynq.async_proxy()
def f1():
return f2.asynq()

assert asyncio.run(f1.asyncio()) == 100


def test_proxy_and_bind():
async def async_g(self, x):
return x + 20 + B.SELF
Expand Down

0 comments on commit 9c1b2d2

Please sign in to comment.