Skip to content

Commit

Permalink
Further improve typing (#145)
Browse files Browse the repository at this point in the history
Co-authored-by: Jelle Zijlstra <[email protected]>
  • Loading branch information
dkang-quora and JelleZijlstra authored Jul 16, 2024
1 parent d04558f commit 4cdc7a6
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 35 deletions.
82 changes: 56 additions & 26 deletions asynq/decorators.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import (
Any,
Awaitable,
Callable,
Coroutine,
Generator,
Generic,
Mapping,
Optional,
Expand All @@ -12,14 +12,16 @@ from typing import (
)

import qcore.decorators
from typing_extensions import ParamSpec
from typing_extensions import Concatenate, Literal, ParamSpec

from . import async_task, futures

_P = ParamSpec("_P")
_P2 = ParamSpec("_P2")
_T = TypeVar("_T")
_Coroutine = Coroutine[Any, Any, Any]
_CoroutineFn = Callable[..., _Coroutine]
_T2 = TypeVar("_T2")
_G = Generator[Any, Any, _T] # Generator that returns _T
_Coroutine = Coroutine[Any, Any, _T]

def lazy(fn: Callable[_P, _T]) -> Callable[_P, futures.FutureBase[_T]]: ...
def has_async_fn(fn: object) -> bool: ...
Expand All @@ -30,17 +32,28 @@ def get_async_fn(
) -> Optional[Callable[..., futures.FutureBase[Any]]]: ...
def get_async_or_sync_fn(fn: object) -> Any: ...

class PureAsyncDecoratorBinder(qcore.decorators.DecoratorBinder):
class PureAsyncDecoratorBinder(qcore.decorators.DecoratorBinder[_T], Generic[_T, _P]):
def is_pure_async_fn(self) -> bool: ...
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...

class PureAsyncDecorator(qcore.decorators.DecoratorBase, Generic[_T, _P]):
binder_cls = PureAsyncDecoratorBinder

@overload
def __init__(
self,
fn: Callable[_P, Any], # TODO overloads for Generator[Any, Any, _T] and _T
task_cls: Optional[type[futures.FutureBase]],
kwargs: Mapping[str, Any] = ...,
asyncio_fn: Optional[Callable[_P, Awaitable[_T]]] = ...,
asyncio_fn: Optional[Callable[_P, Coroutine[Any, Any, _T]]] = ...,
) -> None: ...
@overload
def __init__(
self,
fn: Callable[_P, Generator[Any, Any, _T]],
task_cls: Optional[type[futures.FutureBase]],
kwargs: Mapping[str, Any] = ...,
asyncio_fn: Optional[Callable[_P, Coroutine[Any, Any, _T]]] = ...,
) -> None: ...
def name(self) -> str: ...
def is_pure_async_fn(self) -> bool: ...
Expand All @@ -50,27 +63,42 @@ class PureAsyncDecorator(qcore.decorators.DecoratorBase, Generic[_T, _P]):
def __call__(
self, *args: Any, **kwargs: Any
) -> Union[_T, futures.FutureBase[_T]]: ...
def __get__(self, owner: Any, cls: Any) -> PureAsyncDecorator[_T, _P]: ... # type: ignore[override]
def __get__(
self: PureAsyncDecorator[_T2, Concatenate[Any, _P2]], owner: Any, cls: Any
) -> PureAsyncDecoratorBinder[_T2, _P2]: ...

class AsyncDecoratorBinder(qcore.decorators.DecoratorBinder, Generic[_T]):
def asynq(self, *args: Any, **kwargs: Any) -> async_task.AsyncTask[_T]: ...
def asyncio(self, *args, **kwargs) -> _Coroutine: ...
class AsyncDecoratorBinder(qcore.decorators.DecoratorBinder, Generic[_T, _P]):
def asynq(
self, *args: _P.args, **kwargs: _P.kwargs
) -> async_task.AsyncTask[_T]: ...
def asyncio(
self, *args: _P.args, **kwargs: _P.kwargs
) -> Coroutine[Any, Any, _T]: ...

class AsyncDecorator(PureAsyncDecorator[_T, _P]):
binder_cls = AsyncDecoratorBinder # type: ignore
@overload
def __init__(
self,
fn: Callable[_P, Any], # TODO overloads for Generator[Any, Any, _T] and _T
cls: Optional[type[futures.FutureBase]],
kwargs: Mapping[str, Any] = ...,
asyncio_fn: Optional[Callable[_P, Awaitable[_T]]] = ...,
asyncio_fn: Optional[Callable[_P, Coroutine[Any, Any, _T]]] = ...,
): ...
@overload
def __init__(
self,
fn: Callable[_P, Generator[Any, Any, _T]],
cls: Optional[type[futures.FutureBase]],
kwargs: Mapping[str, Any] = ...,
asyncio_fn: Optional[Callable[_P, Coroutine[Any, Any, _T]]] = ...,
): ...
def is_pure_async_fn(self) -> bool: ...
def asynq(self, *args: Any, **kwargs: Any) -> async_task.AsyncTask[_T]: ...
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ...
def __get__(self, owner: Any, cls: Any) -> AsyncDecorator[_T, _P]: ... # type: ignore[override]
def __get__(self: PureAsyncDecorator[_T2, Concatenate[Any, _P2]], owner: Any, cls: Any) -> AsyncDecoratorBinder[_T2, _P2]: ... # type: ignore[override]

class AsyncAndSyncPairDecoratorBinder(AsyncDecoratorBinder[_T]): ...
class AsyncAndSyncPairDecoratorBinder(AsyncDecoratorBinder[_T, _P]): ...

class AsyncAndSyncPairDecorator(AsyncDecorator[_T, _P]):
binder_cls = AsyncAndSyncPairDecoratorBinder # type: ignore
Expand All @@ -88,52 +116,54 @@ class AsyncProxyDecorator(AsyncDecorator[_T, _P]):
def __init__(
self,
fn: Callable[..., futures.FutureBase[_T]],
asyncio_fn: Optional[_CoroutineFn] = ...,
asyncio_fn: Optional[Callable[..., Coroutine[Any, Any, _T]]] = ...,
) -> None: ...

class AsyncAndSyncPairProxyDecorator(AsyncProxyDecorator[_T, _P]):
def __init__(
self,
fn: Callable[..., futures.FutureBase[_T]],
sync_fn: Callable[..., _T],
asyncio_fn: Optional[_CoroutineFn] = ...,
asyncio_fn: Optional[Callable[..., Coroutine[Any, Any, _T]]] = ...,
) -> None: ...
def __call__(self, *args: Any, **kwargs: Any) -> _T: ...

class _MkAsyncDecorator:
def __call__(self, fn: Callable[_P, Any]) -> AsyncDecorator[Any, _P]: ...

class _MkPureAsyncDecorator:
def __call__(self, fn: Callable[_P, _T]) -> PureAsyncDecorator[_T, _P]: ...
def __call__(self, fn: Callable[_P, Any]) -> PureAsyncDecorator[Any, _P]: ...

# In reality these two can return other Decorator subclasses, but that doesn't matter for callers.
@overload
def asynq( # type: ignore
def asynq(
*,
sync_fn: Optional[Callable[..., Any]] = ...,
pure: Literal[False] = False,
sync_fn: Optional[Callable[_P, _T]] = ...,
cls: type[futures.FutureBase] = ...,
asyncio_fn: Optional[_CoroutineFn] = ...,
asyncio_fn: Optional[Callable[_P, Coroutine[Any, Any, _T]]] = ...,
**kwargs: Any,
) -> _MkAsyncDecorator: ...
@overload
def asynq(
pure: bool,
sync_fn: Optional[Callable[..., Any]] = ...,
*,
pure: Literal[True],
sync_fn: Optional[Callable[_P, _T]] = ...,
cls: type[futures.FutureBase] = ...,
asyncio_fn: Optional[_CoroutineFn] = ...,
asyncio_fn: Optional[Callable[_P, Coroutine[Any, Any, _T]]] = ...,
**kwargs: Any,
) -> _MkPureAsyncDecorator: ...
@overload
def async_proxy(
*,
sync_fn: Optional[Callable[..., Any]] = ...,
asyncio_fn: Optional[_CoroutineFn] = ...,
sync_fn: Optional[Callable[_P, Union[_T, Generator[Any, Any, _T]]]] = ...,
asyncio_fn: Optional[Callable[..., Coroutine[Any, Any, _T]]] = ...,
) -> _MkAsyncDecorator: ...
@overload
def async_proxy(
pure: bool,
sync_fn: Optional[Callable[..., Any]] = ...,
asyncio_fn: Optional[_CoroutineFn] = ...,
sync_fn: Optional[Callable[_P, Union[_T, Generator[Any, Any, _T]]]] = ...,
asyncio_fn: Optional[Callable[..., Coroutine[Any, Any, _T]]] = ...,
) -> _MkPureAsyncDecorator: ...
@asynq()
def async_call(
Expand Down
31 changes: 26 additions & 5 deletions asynq/tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any, Generator, TYPE_CHECKING
from typing import Any, Generator, TYPE_CHECKING, TypeVar
from typing_extensions import assert_type
from asynq.decorators import async_call, lazy, asynq
from asynq.futures import FutureBase
from asynq.tools import amax

_T = TypeVar("_T")


def test_lazy() -> None:
Expand All @@ -23,13 +26,25 @@ def generator(x: int) -> Generator[Any, Any, str]:
yield None
return str(x)

async def caller() -> None:
# This doesn't work, apparently due to a mypy bug
assert_type(await generator.asyncio(1), str) # type: ignore[assert-type]
assert_type(await non_generator.asyncio(1), str) # type: ignore[assert-type]
@asynq()
def generic(x: _T) -> _T:
return x

class Obj:
@asynq()
def method(self, x: str) -> int:
return int(x)

async def caller(x: str, obj: Obj) -> None:
assert_type(await generator.asyncio(1), Any) # TODO: str
assert_type(await non_generator.asyncio(1), Any) # TODO: str
assert_type(await generic.asyncio(x), Any) # TODO: str
assert_type(await obj.method.asyncio(x), Any) # TODO: int

await non_generator.asyncio() # type: ignore[call-arg]
await generator.asyncio() # type: ignore[call-arg]
await generic.asyncio(1, 2) # type: ignore[call-arg]
await obj.method.asyncio(1) # type: ignore[arg-type]


def test_async_call() -> None:
Expand All @@ -39,3 +54,9 @@ def f(x: int) -> str:
async_call(f, 1)
if TYPE_CHECKING:
async_call(f, 1, task_cls=FutureBase) # TODO: this should be an error


def test_amax(x: int = 1, y: int = 2) -> None:
if TYPE_CHECKING:
assert_type(amax(1, 2, key=len), Any) # TODO: int
assert_type(amax([1, 2], key=len), Any) # TODO: int
12 changes: 8 additions & 4 deletions asynq/tools.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,20 @@ def asorted(
) -> List[_T]: ...
@overload
@asynq()
def amax(__arg: Iterable[_T], key: Optional[Callable[[_T], Any]] = ...) -> _T: ...
def amax(
arg1: _T, arg2: _T, /, *args: _T, key: Optional[Callable[[_T], Any]] = ...
) -> _T: ...
@overload
@asynq()
def amax(*args: _T, key: Optional[Callable[[_T], Any]] = ...) -> _T: ...
def amax(arg: Iterable[_T], /, *, key: Optional[Callable[[_T], Any]] = ...) -> _T: ...
@overload
@asynq()
def amin(__arg: Iterable[_T], key: Optional[Callable[[_T], Any]] = ...) -> _T: ...
def amin(
arg1: _T, arg2: _T, /, *args: _T, key: Optional[Callable[[_T], Any]] = ...
) -> _T: ...
@overload
@asynq()
def amin(*args: _T, key: Optional[Callable[[_T], Any]] = ...) -> _T: ...
def amin(__arg: Iterable[_T], key: Optional[Callable[[_T], Any]] = ...) -> _T: ...
@asynq()
def asift(
pred: Callable[[_T], bool], items: Iterable[_T]
Expand Down

0 comments on commit 4cdc7a6

Please sign in to comment.