diff --git a/asynq/decorators.py b/asynq/decorators.py index 6cd6580..ef5d052 100644 --- a/asynq/decorators.py +++ b/asynq/decorators.py @@ -115,11 +115,12 @@ def is_pure_async_fn(self): class PureAsyncDecorator(qcore.decorators.DecoratorBase): binder_cls = PureAsyncDecoratorBinder - def __init__(self, fn, task_cls, kwargs={}): + def __init__(self, fn, task_cls, kwargs={}, asyncio_fn=None): qcore.decorators.DecoratorBase.__init__(self, fn) self.task_cls = task_cls self.needs_wrapper = core_inspection.is_cython_or_generator(fn) self.kwargs = kwargs + self.asyncio_fn = asyncio_fn def name(self): return "@asynq(pure=True)" @@ -132,10 +133,43 @@ def _fn_wrapper(self, args, kwargs): return yield + def asyncio(self, *args, **kwargs) -> Awaitable[Any]: + if self.asyncio_fn is None: + if inspect.isgeneratorfunction(self.fn): + + async def wrapped(*_args, **_kwargs): + task = asyncio.current_task() + with asynq_to_async.AsyncioMode(): + send = None + generator = self.fn(*_args, **_kwargs) + while True: + resume_contexts_asyncio(task) + try: + result = generator.send(send) + except StopIteration as exc: + return exc.value + + pause_contexts_asyncio(task) + send = await asynq_to_async.resolve_awaitables(result) + + self.asyncio_fn = wrapped + else: + + async def wrapped(*_args, **_kwargs): + with asynq_to_async.AsyncioMode(): + return self.fn(*_args, **_kwargs) + + self.asyncio_fn = wrapped + + return self.asyncio_fn(*args, **kwargs) + def __call__(self, *args, **kwargs): return self._call_pure(args, kwargs) def _call_pure(self, args, kwargs): + if asynq_to_async.is_asyncio_mode(): + return self.asyncio(*args, **kwargs) + if not self.needs_wrapper: result = self._fn_wrapper(args, kwargs) else: @@ -161,48 +195,14 @@ class AsyncDecorator(PureAsyncDecorator): binder_cls = AsyncDecoratorBinder def __init__(self, fn, cls, kwargs={}, asyncio_fn=None): - super().__init__(fn, cls, kwargs) - self.asyncio_fn = asyncio_fn + super().__init__(fn, cls, kwargs, asyncio_fn) def is_pure_async_fn(self): return False def asynq(self, *args, **kwargs): - if asynq_to_async.is_asyncio_mode(): - return self.asyncio(*args, **kwargs) - return self._call_pure(args, kwargs) - def asyncio(self, *args, **kwargs) -> Awaitable[Any]: - if self.asyncio_fn is None: - if inspect.isgeneratorfunction(self.fn): - - async def wrapped(*_args, **_kwargs): - task = asyncio.current_task() - with asynq_to_async.AsyncioMode(): - send = None - generator = self.fn(*_args, **_kwargs) - while True: - resume_contexts_asyncio(task) - try: - result = generator.send(send) - except StopIteration as exc: - return exc.value - - pause_contexts_asyncio(task) - send = await asynq_to_async.resolve_awaitables(result) - - self.asyncio_fn = wrapped - else: - - async def wrapped(*_args, **_kwargs): - with asynq_to_async.AsyncioMode(): - return self.fn(*_args, **_kwargs) - - self.asyncio_fn = wrapped - - return self.asyncio_fn(*args, **kwargs) - def name(self): return "@asynq()" diff --git a/asynq/tests/test_asynq_to_async.py b/asynq/tests/test_asynq_to_async.py index 22f676f..7aa6c68 100644 --- a/asynq/tests/test_asynq_to_async.py +++ b/asynq/tests/test_asynq_to_async.py @@ -105,3 +105,16 @@ def original(x): assert_eq(original(6), 116) assert_eq(asyncio.run(a.f.asyncio(7)), 127) + + +def test_pure(): + @asynq.asynq(pure=True) + def h(): + return 100 + + @asynq.asynq() + def i(): + return (yield h()) + + assert i() == 100 + assert asyncio.run(i.asyncio()) == 100