Skip to content

Commit

Permalink
more annotations for zmq.asyncio.Socket
Browse files Browse the repository at this point in the history
missing for rarely used pyobj, json methods

done via type stub instead of in implementation
  • Loading branch information
minrk committed Sep 23, 2024
1 parent 25168fc commit 636365f
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 45 deletions.
53 changes: 53 additions & 0 deletions mypy_tests/test_socket_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from __future__ import annotations

import asyncio

import zmq
import zmq.asyncio


async def main() -> None:
ctx = zmq.asyncio.Context()

# shadow exercise
sync_ctx: zmq.Context = zmq.Context.shadow(ctx)
ctx2: zmq.asyncio.Context = zmq.asyncio.Context.shadow(sync_ctx)
ctx2 = zmq.asyncio.Context(sync_ctx)

url = "tcp://127.0.0.1:5555"
pub = ctx.socket(zmq.PUB)
sub = ctx.socket(zmq.SUB)
pub.bind(url)
sub.connect(url)
sub.subscribe(b"")
await asyncio.sleep(1)

# shadow exercise
sync_sock: zmq.Socket[bytes] = zmq.Socket.shadow(pub)
s2: zmq.asyncio.Socket = zmq.asyncio.Socket(sync_sock)
s2 = zmq.asyncio.Socket.from_socket(sync_sock)

print("sending")
await pub.send(b"plain")
await pub.send(b"plain")
await pub.send_multipart([b"topic", b"Message"])
await pub.send_multipart([b"topic", b"Message"])
await pub.send_string("asdf")
await pub.send_pyobj(123)
await pub.send_json({"a": "5"})

print("receiving")
msg_bytes: bytes = await sub.recv()
msg_frame: zmq.Frame = await sub.recv(copy=False)
msg_list: list[bytes] = await sub.recv_multipart()
msg_frames: list[zmq.Frame] = await sub.recv_multipart(copy=False)
s: str = await sub.recv_string()
obj = await sub.recv_pyobj()
d = await sub.recv_json()

pub.close()
sub.close()


if __name__ == "__main__":
asyncio.run(main())
53 changes: 8 additions & 45 deletions zmq/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@
from collections import deque
from functools import partial
from itertools import chain
from typing import Any, Awaitable, Callable, NamedTuple, TypeVar, cast, overload
from typing import (
Any,
Awaitable,
Callable,
NamedTuple,
TypeVar,
cast,
)

import zmq as _zmq
from zmq import EVENTS, POLLIN, POLLOUT
from zmq._typing import Literal


class _FutureEvent(NamedTuple):
Expand Down Expand Up @@ -260,27 +266,6 @@ def get(self, key):

get.__doc__ = _zmq.Socket.get.__doc__

@overload # type: ignore
def recv_multipart(
self, flags: int = 0, *, track: bool = False
) -> Awaitable[list[bytes]]: ...

@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[list[bytes]]: ...

@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[list[_zmq.Frame]]: # type: ignore
...

@overload
def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ...

def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]:
Expand All @@ -292,19 +277,6 @@ def recv_multipart(
'recv_multipart', dict(flags=flags, copy=copy, track=track)
)

@overload # type: ignore
def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...

@overload
def recv(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[bytes]: ...

@overload
def recv(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[_zmq.Frame]: ...

def recv( # type: ignore
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[bytes | _zmq.Frame]:
Expand Down Expand Up @@ -440,15 +412,6 @@ def cancel_poll(future):

return future

# overrides only necessary for updated types
def recv_string(self, *args, **kwargs) -> Awaitable[str]: # type: ignore
return super().recv_string(*args, **kwargs) # type: ignore

def send_string( # type: ignore
self, s: str, flags: int = 0, encoding: str = 'utf-8'
) -> Awaitable[None]:
return super().send_string(s, flags=flags, encoding=encoding) # type: ignore

def _add_timeout(self, future, timeout):
"""Add a timeout for a send or recv Future"""

Expand Down
91 changes: 91 additions & 0 deletions zmq/_future.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""type annotations for async sockets"""

from __future__ import annotations

from asyncio import Future
from pickle import DEFAULT_PROTOCOL
from typing import Any, Awaitable, Literal, Sequence, TypeVar, overload

import zmq as _zmq

class _AsyncPoller(_zmq.Poller):
_socket_class: type[_AsyncSocket]

def poll(self, timeout=-1) -> Awaitable[list[tuple[Any, int]]]: ... # type: ignore

T = TypeVar("T", bound="_AsyncSocket")

class _AsyncSocket(_zmq.Socket[Future]):
@classmethod
def from_socket(cls: type[T], socket: _zmq.Socket, io_loop: Any = None) -> T: ...
def send( # type: ignore
self,
data: Any,
flags: int = 0,
copy: bool = True,
track: bool = False,
routing_id: int | None = None,
group: str | None = None,
) -> Awaitable[_zmq.MessageTracker | None]: ...
@overload # type: ignore
def recv(self, flags: int = 0, *, track: bool = False) -> Awaitable[bytes]: ...
@overload
def recv(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[bytes]: ...
@overload
def recv(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[_zmq.Frame]: ...
@overload
def recv(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[bytes | _zmq.Frame]: ...
def send_multipart( # type: ignore
self,
msg_parts: Sequence,
flags: int = 0,
copy: bool = True,
track: bool = False,
routing_id: int | None = None,
group: str | None = None,
) -> Awaitable[_zmq.MessageTracker | None]: ...
@overload # type: ignore
def recv_multipart(
self, flags: int = 0, *, track: bool = False
) -> Awaitable[list[bytes]]: ...
@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[True], track: bool = False
) -> Awaitable[list[bytes]]: ...
@overload
def recv_multipart(
self, flags: int = 0, *, copy: Literal[False], track: bool = False
) -> Awaitable[list[_zmq.Frame]]: ...
@overload
def recv_multipart(
self, flags: int = 0, copy: bool = True, track: bool = False
) -> Awaitable[list[bytes] | list[_zmq.Frame]]: ...

# serialization wrappers

def send_string( # type: ignore
self,
u: str,
flags: int = 0,
copy: bool = True,
*,
encoding: str = 'utf-8',
**kwargs,
) -> Awaitable[_zmq.Frame | None]: ...
def recv_string( # type: ignore
self, flags: int = 0, encoding: str = 'utf-8'
) -> Awaitable[str]: ...
def send_pyobj( # type: ignore
self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs
) -> Awaitable[_zmq.Frame | None]: ...
def recv_pyobj(self, flags: int = 0) -> Awaitable[Any]: ... # type: ignore
def send_json( # type: ignore
self, obj: Any, flags: int = 0, **kwargs
) -> Awaitable[_zmq.Frame | None]: ...
def recv_json(self, flags: int = 0, **kwargs) -> Awaitable[Any]: ... # type: ignore

0 comments on commit 636365f

Please sign in to comment.