Skip to content

Commit

Permalink
Merge pull request #2035 from minrk/types-are-tedious
Browse files Browse the repository at this point in the history
more annotations for zmq.asyncio.Socket
  • Loading branch information
minrk authored Oct 4, 2024
2 parents 25168fc + 5f0b954 commit 60ee2c0
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 45 deletions.
53 changes: 53 additions & 0 deletions mypy_tests/test_socket_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import asyncio

import zmq
import zmq.asyncio


async def main() -> None:
ctx = zmq.asyncio.Context()

# shadow exercise
sync_ctx: zmq.Context = zmq.Context.shadow(ctx)
ctx2: zmq.asyncio.Context = zmq.asyncio.Context.shadow(sync_ctx)
ctx2 = zmq.asyncio.Context(sync_ctx)

url = "tcp://127.0.0.1:5555"
pub = ctx.socket(zmq.PUB)
sub = ctx.socket(zmq.SUB)
pub.bind(url)
sub.connect(url)
sub.subscribe(b"")
await asyncio.sleep(1)

# shadow exercise
sync_sock: zmq.Socket[bytes] = zmq.Socket.shadow(pub)
s2: zmq.asyncio.Socket = zmq.asyncio.Socket(sync_sock)
s2 = zmq.asyncio.Socket.from_socket(sync_sock)

print("sending")
await pub.send(b"plain")
await pub.send(b"plain")
await pub.send_multipart([b"topic", b"Message"])
await pub.send_multipart([b"topic", b"Message"])
await pub.send_string("asdf")
await pub.send_pyobj(123)
await pub.send_json({"a": "5"})

print("receiving")
msg_bytes: bytes = await sub.recv()
msg_frame: zmq.Frame = await sub.recv(copy=False)
msg_list: list[bytes] = await sub.recv_multipart()
msg_frames: list[zmq.Frame] = await sub.recv_multipart(copy=False)
s: str = await sub.recv_string()
obj = await sub.recv_pyobj()
d = await sub.recv_json()

pub.close()
sub.close()


if __name__ == "__main__":
asyncio.run(main())
53 changes: 8 additions & 45 deletions zmq/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@
from collections import deque
from functools import partial
from itertools import chain
from typing import Any, Awaitable, Callable, NamedTuple, TypeVar, cast, overload
from typing import (
Any,
Awaitable,
Callable,
NamedTuple,
TypeVar,
cast,
)

import zmq as _zmq
from zmq import EVENTS, POLLIN, POLLOUT
from zmq._typing import Literal


class _FutureEvent(NamedTuple):
Expand Down Expand Up @@ -260,27 +266,6 @@ def get(self, key):

get.__doc__ = _zmq.Socket.get.__doc__

@overload # type: ignore
def recv_multipart(
self, flags: int = 0, *, track: bool = False
) -> Awaitable[list[bytes]]: ...

@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[list[bytes]]: ...

@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[list[_zmq.Frame]]: # type: ignore
...

@overload
def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ...

def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]:
Expand All @@ -292,19 +277,6 @@ def recv_multipart(
'recv_multipart', dict(flags=flags, copy=copy, track=track)
)

@overload # type: ignore
def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...

@overload
def recv(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[bytes]: ...

@overload
def recv(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[_zmq.Frame]: ...

def recv( # type: ignore
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[bytes | _zmq.Frame]:
Expand Down Expand Up @@ -440,15 +412,6 @@ def cancel_poll(future):

return future

# overrides only necessary for updated types
def recv_string(self, *args, **kwargs) -> Awaitable[str]: # type: ignore
return super().recv_string(*args, **kwargs) # type: ignore

def send_string( # type: ignore
self, s: str, flags: int = 0, encoding: str = 'utf-8'
) -> Awaitable[None]:
return super().send_string(s, flags=flags, encoding=encoding) # type: ignore

def _add_timeout(self, future, timeout):
"""Add a timeout for a send or recv Future"""

Expand Down
92 changes: 92 additions & 0 deletions zmq/_future.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""type annotations for async sockets"""

from __future__ import annotations

from asyncio import Future
from pickle import DEFAULT_PROTOCOL
from typing import Any, Awaitable, Literal, Sequence, TypeVar, overload

import zmq as _zmq

class _AsyncPoller(_zmq.Poller):
_socket_class: type[_AsyncSocket]

def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore

T = TypeVar("T", bound="_AsyncSocket")

class _AsyncSocket(_zmq.Socket[Future]):
@classmethod
def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T: ...
def send( # type: ignore
self,
data: Any,
flags: int = 0,
copy: bool = True,
track: bool = False,
routing_id: int | None = None,
group: str | None = None,
) -> Awaitable[_zmq.MessageTracker | None]: ...
@overload # type: ignore
def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...
@overload
def recv(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[bytes]: ...
@overload
def recv(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[_zmq.Frame]: ...
@overload
def recv(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[bytes | _zmq.Frame]: ...
def send_multipart( # type: ignore
self,
msg_parts: Sequence,
flags: int = 0,
copy: bool = True,
track: bool = False,
routing_id: int | None = None,
group: str | None = None,
) -> Awaitable[_zmq.MessageTracker | None]: ...
@overload # type: ignore
def recv_multipart(
self, flags: int = 0, *, track: bool = False
) -> Awaitable[list[bytes]]: ...
@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[list[bytes]]: ...
@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[list[_zmq.Frame]]: ...
@overload
def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ...

# serialization wrappers

def send_string( # type: ignore
self,
u: str,
flags: int = 0,
copy: bool = True,
*,
encoding: str = 'utf-8',
**kwargs,
) -> Awaitable[_zmq.Frame | None]: ...
def recv_string( # type: ignore
self, flags: int = 0, encoding: str = 'utf-8'
) -> Awaitable[str]: ...
def send_pyobj( # type: ignore
self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs
) -> Awaitable[_zmq.Frame | None]: ...
def recv_pyobj(self, flags: int = 0) -> Awaitable[Any]: ... # type: ignore
def send_json( # type: ignore
self, obj: Any, flags: int = 0, **kwargs
) -> Awaitable[_zmq.Frame | None]: ...
def recv_json(self, flags: int = 0, **kwargs) -> Awaitable[Any]: ... # type: ignore
def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore

0 comments on commit 60ee2c0

Please sign in to comment.