diff --git a/asynq/decorators.py b/asynq/decorators.py index 390fe42..f7233e6 100644 --- a/asynq/decorators.py +++ b/asynq/decorators.py @@ -22,6 +22,7 @@ import qcore.inspection as core_inspection from . import async_task, asynq_to_async, futures +from .asynq_to_async import is_asyncio_mode from .contexts import ( ASYNCIO_CONTEXT_FIELD, pause_contexts_asyncio, @@ -317,7 +318,24 @@ def decorate(fn): return decorate -@async_proxy() +async def asyncio_call(fn, *args, **kwargs): + """An asyncio-version of async_call. + + Note that it is not always possible to detect whether an function is a coroutine function or not. + For example, a callable F that returns Coroutine is a coroutine function, but inspect.iscoroutinefunction(F) is False. + + """ + if is_pure_async_fn(fn): + return await fn.asyncio(*args, **kwargs) + elif hasattr(fn, "asynq"): + return await fn.asyncio(*args, **kwargs) + elif hasattr(fn, "async"): + return await getattr(fn, "async")(*args, **kwargs) + else: + return fn(*args, **kwargs) + + +@async_proxy(asyncio_fn=asyncio_call) def async_call(fn, *args, **kwargs): """Use this if you are not sure if fn is async or not. diff --git a/asynq/tests/test_decorators.py b/asynq/tests/test_decorators.py index 2d908a2..62fd726 100644 --- a/asynq/tests/test_decorators.py +++ b/asynq/tests/test_decorators.py @@ -12,16 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from qcore.asserts import assert_eq, assert_is, assert_is_instance, AssertRaises -from asynq import asynq, async_proxy, is_pure_async_fn, async_call, ConstFuture +import asyncio +import pickle + +from qcore.asserts import AssertRaises, assert_eq, assert_is, assert_is_instance + +from asynq import ConstFuture, async_call, async_proxy, asynq, is_pure_async_fn from asynq.decorators import ( - lazy, + AsyncDecorator, get_async_fn, get_async_or_sync_fn, + lazy, make_async_decorator, - AsyncDecorator, ) -import pickle def double_return_value(fun): @@ -259,6 +262,17 @@ def f3(arg, kw=1): assert_eq((10, 5), async_call.asynq(f, 10, 5).value()) assert_eq((10, 7), async_call.asynq(f, 10, kw=7).value()) + @asynq() + def g0(f, *args, **kwargs): + d = yield async_call.asynq(f, *args, **kwargs) + return d + + for f in [f1, f2, f3]: + assert_eq((10, 1), asyncio.run(async_call.asyncio(f, 10))) + assert_eq((10, 1), asyncio.run(g0.asyncio(f, 10))) + assert_eq((10, 5), asyncio.run(g0.asyncio(f, 10, 5))) + assert_eq((10, 7), asyncio.run(g0.asyncio(f, 10, kw=7))) + def test_make_async_decorator(): assert_eq(18, square(3))