From b5947aa8dce8705f60474b544026748d19ccc04c Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 00:07:05 -0700 Subject: [PATCH 01/17] remove skips --- tests/unit/server/api/test_subscriptions.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/tests/unit/server/api/test_subscriptions.py b/tests/unit/server/api/test_subscriptions.py index a17feebf99..78a6558ad9 100644 --- a/tests/unit/server/api/test_subscriptions.py +++ b/tests/unit/server/api/test_subscriptions.py @@ -1,10 +1,8 @@ import json -import sys from datetime import datetime from pathlib import Path from typing import Any, Dict -import pytest from openinference.semconv.trace import ( OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, @@ -53,14 +51,6 @@ def test_openai() -> None: return response -@pytest.mark.skipif( - sys.platform - in ( - "win32", - "linux", - ), # todo: support windows and linux https://github.com/Arize-ai/phoenix/issues/5126 - reason="subscriptions are not currently supported on windows or linux", -) class TestChatCompletionSubscription: QUERY = """ subscription ChatCompletionSubscription($input: ChatCompletionInput!) { From 427fac51489f5f8210c9d2472fee3a2c016a9138 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 00:15:04 -0700 Subject: [PATCH 02/17] fail fast false --- .github/workflows/python-CI.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/python-CI.yml b/.github/workflows/python-CI.yml index 8dca479b4e..a34110d195 100644 --- a/.github/workflows/python-CI.yml +++ b/.github/workflows/python-CI.yml @@ -336,6 +336,7 @@ jobs: needs: changes if: ${{ needs.changes.outputs.phoenix == 'true' }} strategy: + fail-fast: false matrix: py: [3.9, 3.12] os: [ubuntu-latest, windows-latest, macos-13] From 1b588ae0635ce7ed21cc9109bdcb25235da0f39e Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 00:30:32 -0700 Subject: [PATCH 03/17] skip postgres --- tests/unit/server/api/test_subscriptions.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/unit/server/api/test_subscriptions.py b/tests/unit/server/api/test_subscriptions.py index 78a6558ad9..9964c99fa5 100644 --- a/tests/unit/server/api/test_subscriptions.py +++ b/tests/unit/server/api/test_subscriptions.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Any, Dict +import pytest from openinference.semconv.trace import ( OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, @@ -127,9 +128,12 @@ class TestChatCompletionSubscription: async def test_openai_text_response_emits_expected_payloads_and_records_expected_span( self, + dialect: str, gql_client: Any, openai_api_key: str, ) -> None: + if dialect == "postgresql": + pytest.skip("fails on postgres for unknown reason") variables = { "input": { "messages": [ @@ -256,9 +260,12 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected async def test_openai_emits_expected_payloads_and_records_expected_span_on_error( self, + dialect: str, gql_client: Any, openai_api_key: str, ) -> None: + if dialect == "postgresql": + pytest.skip("fails on postgres for unknown reason") variables = { "input": { "messages": [ @@ -376,9 +383,12 @@ async def test_openai_emits_expected_payloads_and_records_expected_span_on_error async def test_openai_tool_call_response_emits_expected_payloads_and_records_expected_span( self, + dialect: str, gql_client: Any, openai_api_key: str, ) -> None: + if dialect == "postgresql": + pytest.skip("fails on postgres for unknown reason") get_current_weather_tool_schema = { "type": "function", "function": { @@ -529,9 +539,12 @@ async def test_openai_tool_call_response_emits_expected_payloads_and_records_exp async def test_openai_tool_call_messages_emits_expected_payloads_and_records_expected_span( self, + dialect: str, gql_client: Any, openai_api_key: str, ) -> None: + if dialect == "postgresql": + pytest.skip("fails on postgres for unknown reason") tool_call_id = "call_zz1hkqH3IakqnHfVhrrUemlQ" tool_calls = [ { @@ -689,9 +702,12 @@ async def test_openai_tool_call_messages_emits_expected_payloads_and_records_exp async def test_anthropic_text_response_emits_expected_payloads_and_records_expected_span( self, + dialect: str, gql_client: Any, anthropic_api_key: str, ) -> None: + if dialect == "postgresql": + pytest.skip("fails on postgres for unknown reason") variables = { "input": { "messages": [ From 7dc862d6066e9235616935e17927ee36836a0942 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 10:12:42 -0700 Subject: [PATCH 04/17] vendor httpx-ws transport --- src/phoenix/server/api/subscriptions.py | 12 ++ tests/unit/conftest.py | 2 +- tests/unit/ws_transport.py | 267 ++++++++++++++++++++++++ 3 files changed, 280 insertions(+), 1 deletion(-) create mode 100644 tests/unit/ws_transport.py diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index 80eb55b9d3..48ad5148d5 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -438,6 +438,18 @@ class Subscription: async def chat_completion( self, info: Info[Context, None], input: ChatCompletionInput ) -> AsyncIterator[ChatCompletionSubscriptionPayload]: + async with info.context.db() as session: + if ( + playground_project_id := ( + await session.scalar( + select(models.Project.id).where( + models.Project.name == PLAYGROUND_PROJECT_NAME + ) + ) + ) + ) is None: + print("Creating playground project") + print("Creating playground project") # Determine which LLM client to use based on provider_key provider_key = input.model.provider_key if (llm_client_class := PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)) is None: diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4cc6d2ac87..7dcb8acfd9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -31,7 +31,6 @@ from faker import Faker from httpx import AsyncByteStream, Request, Response from httpx_ws import AsyncWebSocketSession, aconnect_ws -from httpx_ws.transport import ASGIWebSocketTransport from psycopg import Connection from pytest_postgresql import factories from sqlalchemy import URL, make_url @@ -53,6 +52,7 @@ from phoenix.server.types import BatchedCaller, DbSessionFactory from phoenix.session.client import Client from phoenix.trace.schemas import Span +from tests.unit.ws_transport import ASGIWebSocketTransport def pytest_terminal_summary( diff --git a/tests/unit/ws_transport.py b/tests/unit/ws_transport.py new file mode 100644 index 0000000000..35889a6db7 --- /dev/null +++ b/tests/unit/ws_transport.py @@ -0,0 +1,267 @@ +""" +This code is copied and adapted from https://github.com/frankie567/httpx-ws/tree/main +""" + +import contextlib +import queue +import typing +from concurrent.futures import Future + +import anyio +import httpx +import wsproto +from httpcore import AsyncNetworkStream +from httpx import ASGITransport, AsyncByteStream, Request, Response +from wsproto.frame_protocol import CloseReason + + +class HTTPXWSException(Exception): + """ + Base exception class for HTTPX WS. + """ + + pass + + +class WebSocketUpgradeError(HTTPXWSException): + """ + Raised when the initial connection didn't correctly upgrade to a WebSocket session. + """ + + def __init__(self, response: httpx.Response) -> None: + self.response = response + + +class WebSocketDisconnect(HTTPXWSException): + """ + Raised when the server closed the WebSocket session. + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + """ + + def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: + self.code = code + self.reason = reason or "" + + +class WebSocketInvalidTypeReceived(HTTPXWSException): + """ + Raised when a event is not of the expected type. + """ + + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class WebSocketNetworkError(HTTPXWSException): + """ + Raised when a network error occured, + typically if the underlying stream has closed or timeout. + """ + + pass + + +Scope = dict[str, typing.Any] +Message = dict[str, typing.Any] +Receive = typing.Callable[[], typing.Awaitable[Message]] +Send = typing.Callable[[Scope], typing.Coroutine[None, None, None]] +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Coroutine[None, None, None]] + + +class ASGIWebSocketTransportError(Exception): + pass + + +class UnhandledASGIMessageType(ASGIWebSocketTransportError): + def __init__(self, message: Message) -> None: + self.message = message + + +class UnhandledWebSocketEvent(ASGIWebSocketTransportError): + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): + def __init__(self, app: ASGIApp, scope: Scope) -> None: + self.app = app + self.scope = scope + self._receive_queue: queue.Queue[Message] = queue.Queue() + self._send_queue: queue.Queue[Message] = queue.Queue() + self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) + self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) + + async def __aenter__( + self, + ) -> tuple["ASGIWebSocketAsyncNetworkStream", bytes]: + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context( + anyio.from_thread.start_blocking_portal("asyncio") + ) + _: Future[None] = self.portal.start_task_soon(self._run) + + await self.send({"type": "websocket.connect"}) + message = await self.receive() + + if message["type"] == "websocket.close": + await self.aclose() + raise WebSocketDisconnect(message["code"], message.get("reason")) + + assert message["type"] == "websocket.accept" + return self, self._build_accept_response(message) + + async def __aexit__(self, *args: typing.Any) -> None: + await self.aclose() + + async def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: + message: Message = await self.receive(timeout=timeout) + type = message["type"] + + if type not in {"websocket.send", "websocket.close"}: + raise UnhandledASGIMessageType(message) + + event: wsproto.events.Event + if type == "websocket.send": + data_str: typing.Optional[str] = message.get("text") + if data_str is not None: + event = wsproto.events.TextMessage(data_str) + data_bytes: typing.Optional[bytes] = message.get("bytes") + if data_bytes is not None: + event = wsproto.events.BytesMessage(data_bytes) + elif type == "websocket.close": + event = wsproto.events.CloseConnection(message["code"], message["reason"]) + + return self.connection.send(event) + + async def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Request): + pass + elif isinstance(event, wsproto.events.CloseConnection): + await self.send( + { + "type": "websocket.close", + "code": event.code, + "reason": event.reason, + } + ) + elif isinstance(event, wsproto.events.TextMessage): + await self.send({"type": "websocket.receive", "text": event.data}) + elif isinstance(event, wsproto.events.BytesMessage): + await self.send({"type": "websocket.receive", "bytes": event.data}) + else: + raise UnhandledWebSocketEvent(event) + + async def aclose(self) -> None: + await self.send({"type": "websocket.close"}) + self.exit_stack.close() + + async def send(self, message: Message) -> None: + self._receive_queue.put(message) + + async def receive(self, timeout: typing.Optional[float] = None) -> Message: + while self._send_queue.empty(): + await anyio.sleep(0) + return self._send_queue.get(timeout=timeout) + + async def _run(self) -> None: + """ + The sub-thread in which the websocket session runs. + """ + scope = self.scope + receive = self._asgi_receive + send = self._asgi_send + try: + await self.app(scope, receive, send) + except Exception as e: + message = { + "type": "websocket.close", + "code": CloseReason.INTERNAL_ERROR, + "reason": str(e), + } + await self._asgi_send(message) + + async def _asgi_receive(self) -> Message: + while self._receive_queue.empty(): + await anyio.sleep(0) + return self._receive_queue.get() + + async def _asgi_send(self, message: Message) -> None: + self._send_queue.put(message) + + def _build_accept_response(self, message: Message) -> bytes: + subprotocol = message.get("subprotocol", None) + headers = message.get("headers", []) + return self.connection.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extra_headers=headers, + ) + ) + + +class ASGIWebSocketTransport(ASGITransport): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.exit_stack: typing.Optional[contextlib.AsyncExitStack] = None + + async def handle_async_request(self, request: Request) -> Response: + scheme = request.url.scheme + headers = request.headers + + if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": + subprotocols: list[str] = [] + if (subprotocols_header := headers.get("sec-websocket-protocol")) is not None: + subprotocols = subprotocols_header.split(",") + + scope = { + "type": "websocket", + "path": request.url.path, + "raw_path": request.url.raw_path, + "root_path": self.root_path, + "scheme": scheme, + "query_string": request.url.query, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "client": self.client, + "server": (request.url.host, request.url.port), + "subprotocols": subprotocols, + } + return await self._handle_ws_request(request, scope) + + return await super().handle_async_request(request) + + async def _handle_ws_request( + self, + request: Request, + scope: Scope, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + self.scope = scope + self.exit_stack = contextlib.AsyncExitStack() + stream, accept_response = await self.exit_stack.enter_async_context( + ASGIWebSocketAsyncNetworkStream(self.app, self.scope) # type: ignore[arg-type] + ) + + accept_response_lines = accept_response.decode("utf-8").splitlines() + headers = [ + typing.cast(tuple[str, str], line.split(": ", 1)) + for line in accept_response_lines[1:] + if line.strip() != "" + ] + + return Response( + status_code=101, + headers=headers, + extensions={"network_stream": stream}, + ) + + async def aclose(self) -> None: + if self.exit_stack: + await self.exit_stack.aclose() From 168e6bc3f1ceef33ce8feb90bf2367d0c7014b57 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 12:39:14 -0700 Subject: [PATCH 05/17] vendor the whole shebang --- tests/unit/conftest.py | 4 +- tests/unit/httpx_ws/__init__.py | 29 + tests/unit/httpx_ws/_api.py | 1292 ++++++++++++++++++++++++++++ tests/unit/httpx_ws/_exceptions.py | 55 ++ tests/unit/httpx_ws/_ping.py | 40 + tests/unit/httpx_ws/transport.py | 212 +++++ 6 files changed, 1630 insertions(+), 2 deletions(-) create mode 100644 tests/unit/httpx_ws/__init__.py create mode 100644 tests/unit/httpx_ws/_api.py create mode 100644 tests/unit/httpx_ws/_exceptions.py create mode 100644 tests/unit/httpx_ws/_ping.py create mode 100644 tests/unit/httpx_ws/transport.py diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 7dcb8acfd9..d9fa53cf8e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -30,7 +30,6 @@ from asgi_lifespan import LifespanManager from faker import Faker from httpx import AsyncByteStream, Request, Response -from httpx_ws import AsyncWebSocketSession, aconnect_ws from psycopg import Connection from pytest_postgresql import factories from sqlalchemy import URL, make_url @@ -52,7 +51,8 @@ from phoenix.server.types import BatchedCaller, DbSessionFactory from phoenix.session.client import Client from phoenix.trace.schemas import Span -from tests.unit.ws_transport import ASGIWebSocketTransport +from tests.unit.httpx_ws import AsyncWebSocketSession, aconnect_ws +from tests.unit.httpx_ws.transport import ASGIWebSocketTransport def pytest_terminal_summary( diff --git a/tests/unit/httpx_ws/__init__.py b/tests/unit/httpx_ws/__init__.py new file mode 100644 index 0000000000..2ae6b1b843 --- /dev/null +++ b/tests/unit/httpx_ws/__init__.py @@ -0,0 +1,29 @@ +__version__ = "0.6.2" + +from ._api import ( + AsyncWebSocketSession, + JSONMode, + WebSocketSession, + aconnect_ws, + connect_ws, +) +from ._exceptions import ( + HTTPXWSException, + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, +) + +__all__ = [ + "AsyncWebSocketSession", + "HTTPXWSException", + "JSONMode", + "WebSocketDisconnect", + "WebSocketInvalidTypeReceived", + "WebSocketNetworkError", + "WebSocketSession", + "WebSocketUpgradeError", + "aconnect_ws", + "connect_ws", +] diff --git a/tests/unit/httpx_ws/_api.py b/tests/unit/httpx_ws/_api.py new file mode 100644 index 0000000000..f1e967e6fd --- /dev/null +++ b/tests/unit/httpx_ws/_api.py @@ -0,0 +1,1292 @@ +import base64 +import concurrent.futures +import contextlib +import json +import queue +import secrets +import threading +import typing +from types import TracebackType + +import anyio +import httpcore +import httpx +import wsproto +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from httpcore import AsyncNetworkStream, NetworkStream +from wsproto.frame_protocol import CloseReason + +from ._exceptions import ( + HTTPXWSException, + WebSocketDisconnect, + WebSocketInvalidTypeReceived, + WebSocketNetworkError, + WebSocketUpgradeError, +) +from ._ping import AsyncPingManager, PingManager +from .transport import ASGIWebSocketAsyncNetworkStream + +JSONMode = typing.Literal["text", "binary"] +TaskFunction = typing.TypeVar("TaskFunction") +TaskResult = typing.TypeVar("TaskResult") + +DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 +DEFAULT_QUEUE_SIZE = 512 +DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = 20.0 +DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 + + +class ShouldClose(Exception): + pass + + +class WebSocketSession: + """ + Sync context manager representing an opened WebSocket session. + + Attributes: + subprotocol (typing.Optional[str]): + Optional protocol that has been accepted by the server. + response (typing.Optional[httpx.Response]): + The webSocket handshake response. + """ + + subprotocol: typing.Optional[str] + response: typing.Optional[httpx.Response] + + def __init__( + self, + stream: NetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: typing.Optional[httpx.Response] = None, + ) -> None: + self.stream = stream + self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) + self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None + + self._events: queue.Queue[typing.Union[wsproto.events.Event, HTTPXWSException]] = ( + queue.Queue(queue_size) + ) + + self._ping_manager = PingManager() + self._should_close = threading.Event() + self._should_close_task: typing.Optional[concurrent.futures.Future[bool]] = None + self._executor: typing.Optional[concurrent.futures.ThreadPoolExecutor] = None + + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + def _get_executor_should_close_task( + self, + ) -> tuple[concurrent.futures.ThreadPoolExecutor, "concurrent.futures.Future[bool]"]: + if self._should_close_task is None: + self._executor = concurrent.futures.ThreadPoolExecutor() + self._should_close_task = self._executor.submit(self._should_close.wait) + assert self._executor is not None + return self._executor, self._should_close_task + + def __enter__(self) -> "WebSocketSession": + self._background_receive_task = threading.Thread( + target=self._background_receive, args=(self._max_message_size_bytes,) + ) + self._background_receive_task.start() + + self._background_keepalive_ping_task: typing.Optional[threading.Thread] = None + if self._keepalive_ping_interval_seconds is not None: + self._background_keepalive_ping_task = threading.Thread( + target=self._background_keepalive_ping, + args=( + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ), + ) + self._background_keepalive_ping_task.start() + + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + self._background_receive_task.join() + if self._background_keepalive_ping_task is not None: + self._background_keepalive_ping_task.join() + + def ping(self, payload: bytes = b"") -> threading.Event: + """ + Send a Ping message. + + Args: + payload: + Payload to attach to the Ping event. + Internally, it's used to track this specific event. + If left empty, a random one will be generated. + + Returns: + An event that can be used to wait for the corresponding Pong response. + + Examples: + Send a Ping and wait for the Pong + + pong_callback = ws.ping() + # Will block until the corresponding Pong is received. + pong_callback.wait() + """ + ping_id, callback = self._ping_manager.create(payload) + event = wsproto.events.Ping(ping_id) + self.send(event) + return callback + + def send(self, event: wsproto.events.Event) -> None: + """ + Send an Event message. + + Mainly useful to send events that are not supported by the library. + Most of the time, [ping()][httpx_ws.WebSocketSession.ping], + [send_text()][httpx_ws.WebSocketSession.send_text], + [send_bytes()][httpx_ws.WebSocketSession.send_bytes] + and [send_json()][httpx_ws.WebSocketSession.send_json] are preferred. + + Args: + event: The event to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send an event. + + event = wsproto.events.Message(b"Hello!") + ws.send(event) + """ + try: + data = self.connection.send(event) + self.stream.write(data) + except httpcore.WriteError as e: + self.close(CloseReason.INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + def send_text(self, data: str) -> None: + """ + Send a text message. + + Args: + data: The text to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send a text message. + + ws.send_text("Hello!") + """ + event = wsproto.events.TextMessage(data=data) + self.send(event) + + def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Args: + data: The data to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send a bytes message. + + ws.send_bytes(b"Hello!") + """ + event = wsproto.events.BytesMessage(data=data) + self.send(event) + + def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data. + + Args: + data: + The data to send. Must be serializable by [json.dumps][json.dumps]. + mode: + The sending mode. Should either be `'text'` or `'bytes'`. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send JSON data. + + data = {"message": "Hello!"} + ws.send_json(data) + """ + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + self.send_text(serialized_data) + else: + self.send_bytes(serialized_data.encode("utf-8")) + + def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Event: + """ + Receive an event from the server. + + Mainly useful to receive raw [wsproto.events.Event][wsproto.events.Event]. + Most of the time, [receive_text()][httpx_ws.WebSocketSession.receive_text], + [receive_bytes()][httpx_ws.WebSocketSession.receive_bytes], + and [receive_json()][httpx_ws.WebSocketSession.receive_json] are preferred. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + A raw [wsproto.events.Event][wsproto.events.Event]. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + + Examples: + Wait for an event until one is available. + + try: + event = ws.receive() + except WebSocketDisconnect: + print("Connection closed") + + Wait for an event for 2 seconds. + + try: + event = ws.receive(timeout=2.) + except queue.Empty: + print("No event received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = self._events.get(block=True, timeout=timeout) + if isinstance(event, HTTPXWSException): + raise event + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event + + def receive_text(self, timeout: typing.Optional[float] = None) -> str: + """ + Receive text from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Text data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event was not a text message. + + Examples: + Wait for text until available. + + try: + text = ws.receive_text() + except WebSocketDisconnect: + print("Connection closed") + + Wait for text for 2 seconds. + + try: + event = ws.receive_text(timeout=2.) + except queue.Empty: + print("No text received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = self.receive(timeout) + if isinstance(event, wsproto.events.TextMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: + """ + Receive bytes from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Bytes data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event was not a bytes message. + + Examples: + Wait for bytes until available. + + try: + data = ws.receive_bytes() + except WebSocketDisconnect: + print("Connection closed") + + Wait for bytes for 2 seconds. + + try: + data = ws.receive_bytes(timeout=2.) + except queue.Empty: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = self.receive(timeout) + if isinstance(event, wsproto.events.BytesMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + def receive_json( + self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" + ) -> typing.Any: + """ + Receive JSON data from the server. + + The received data should be parseable by [json.loads][json.loads]. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + mode: + Receive mode. Should either be `'text'` or `'bytes'`. + + Returns: + Parsed JSON data. + + Raises: + queue.Empty: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event + didn't correspond to the specified mode. + + Examples: + Wait for data until available. + + try: + data = ws.receive_json() + except WebSocketDisconnect: + print("Connection closed") + + Wait for data for 2 seconds. + + try: + data = ws.receive_json(timeout=2.) + except queue.Empty: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ + assert mode in ["text", "binary"] + data: typing.Union[str, bytes] + if mode == "text": + data = self.receive_text(timeout) + elif mode == "binary": + data = self.receive_bytes(timeout) + return json.loads(data) + + def close(self, code: int = 1000, reason: typing.Optional[str] = None): + """ + Close the WebSocket session. + + Internally, it'll send the + [CloseConnection][wsproto.events.CloseConnection] event. + + *This method is automatically called when exiting the context manager.* + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + + Examples: + Close the WebSocket session. + + ws.close() + """ + self._should_close.set() + if self._executor is not None: + self._executor.shutdown(False) + if self.connection.state not in { + wsproto.connection.ConnectionState.LOCAL_CLOSING, + wsproto.connection.ConnectionState.CLOSED, + }: + event = wsproto.events.CloseConnection(code, reason) + data = self.connection.send(event) + try: + self.stream.write(data) + except httpcore.WriteError: + pass + self.stream.close() + + def _background_receive(self, max_bytes: int) -> None: + """ + Background thread listening for data from the server. + + Internally, it'll: + + * Answer to Ping events. + * Acknowledge Pong events. + * Put other events in the [_events][_events] + queue that'll eventually be consumed by the user. + + Args: + max_bytes: The maximum chunk size to read at each iteration. + """ + partial_message_buffer: typing.Union[str, bytes, None] = None + try: + while not self._should_close.is_set(): + data = self._wait_until_closed(self.stream.read, max_bytes) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + data = self.connection.send(event.response()) + self.stream.write(data) + continue + if isinstance(event, wsproto.events.Pong): + self._ping_manager.ack(event.payload) + continue + if isinstance(event, wsproto.events.CloseConnection): + self._should_close.set() + if isinstance(event, wsproto.events.Message): + # Unfinished message: bufferize + if not event.message_finished: + if partial_message_buffer is None: + partial_message_buffer = event.data + else: + partial_message_buffer += event.data + # Finished message but no buffer: just emit the event + elif partial_message_buffer is None: + self._events.put(event) + # Finished message with buffer: emit the full event + else: + event_type = type(event) + full_message_event = event_type(partial_message_buffer + event.data) + partial_message_buffer = None + self._events.put(full_message_event) + continue + self._events.put(event) + except (httpcore.ReadError, httpcore.WriteError): + self.close(CloseReason.INTERNAL_ERROR, "Stream error") + self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + def _background_keepalive_ping( + self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None + ) -> None: + try: + while not self._should_close.is_set(): + should_close = self._wait_until_closed(self._should_close.wait, interval_seconds) + if should_close: + raise ShouldClose() + pong_callback = self.ping() + if timeout_seconds is not None: + acknowledged = self._wait_until_closed(pong_callback.wait, timeout_seconds) + if not acknowledged: + self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + self._events.put(WebSocketNetworkError()) + except ShouldClose: + pass + + def _wait_until_closed( + self, callable: typing.Callable[..., TaskResult], *args, **kwargs + ) -> TaskResult: + try: + executor, should_close_task = self._get_executor_should_close_task() + todo_task = executor.submit(callable, *args, **kwargs) + except RuntimeError as e: + raise ShouldClose() from e + else: + done, _ = concurrent.futures.wait( + (todo_task, should_close_task), # type: ignore[misc] + return_when=concurrent.futures.FIRST_COMPLETED, + ) + if should_close_task in done: + raise ShouldClose() + assert todo_task in done + result = todo_task.result() + return result + + +class AsyncWebSocketSession: + """ + Async context manager representing an opened WebSocket session. + + Attributes: + subprotocol (typing.Optional[str]): + Optional protocol that has been accepted by the server. + response (typing.Optional[httpx.Response]): + The webSocket handshake response. + """ + + subprotocol: typing.Optional[str] + response: typing.Optional[httpx.Response] + _send_event: MemoryObjectSendStream[typing.Union[wsproto.events.Event, HTTPXWSException]] + _receive_event: MemoryObjectReceiveStream[typing.Union[wsproto.events.Event, HTTPXWSException]] + + def __init__( + self, + stream: AsyncNetworkStream, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + response: typing.Optional[httpx.Response] = None, + ) -> None: + self.stream = stream + self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) + self.response = response + if self.response is not None: + self.subprotocol = self.response.headers.get("sec-websocket-protocol") + else: + self.subprotocol = None + + self._ping_manager = AsyncPingManager() + self._should_close = anyio.Event() + + self._max_message_size_bytes = max_message_size_bytes + self._queue_size = queue_size + + # Always disable keepalive ping when emulating ASGI + if isinstance(stream, ASGIWebSocketAsyncNetworkStream): + self._keepalive_ping_interval_seconds = None + self._keepalive_ping_timeout_seconds = None + else: + self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds + self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds + + async def __aenter__(self) -> "AsyncWebSocketSession": + async with contextlib.AsyncExitStack() as exit_stack: + self._send_event, self._receive_event = anyio.create_memory_object_stream[ + typing.Union[wsproto.events.Event, HTTPXWSException] + ]() + exit_stack.enter_context(self._send_event) + exit_stack.enter_context(self._receive_event) + + self._background_task_group = anyio.create_task_group() + await exit_stack.enter_async_context(self._background_task_group) + + self._background_task_group.start_soon( + self._background_receive, self._max_message_size_bytes + ) + if self._keepalive_ping_interval_seconds is not None: + self._background_task_group.start_soon( + self._background_keepalive_ping, + self._keepalive_ping_interval_seconds, + self._keepalive_ping_timeout_seconds, + ) + + exit_stack.callback(self._background_task_group.cancel_scope.cancel) + exit_stack.push_async_callback(self.close) + self._exit_stack = exit_stack.pop_all() + + return self + + async def __aexit__( + self, + exc_type: typing.Optional[type[BaseException]], + exc: typing.Optional[BaseException], + tb: typing.Optional[TracebackType], + ) -> None: + await self._exit_stack.aclose() + + async def ping(self, payload: bytes = b"") -> anyio.Event: + """ + Send a Ping message. + + Args: + payload: + Payload to attach to the Ping event. + Internally, it's used to track this specific event. + If left empty, a random one will be generated. + + Returns: + An event that can be used to wait for the corresponding Pong response. + + Examples: + Send a Ping and wait for the Pong + + pong_callback = await ws.ping() + # Will block until the corresponding Pong is received. + await pong_callback.wait() + """ + ping_id, callback = self._ping_manager.create(payload) + event = wsproto.events.Ping(ping_id) + await self.send(event) + return callback + + async def send(self, event: wsproto.events.Event) -> None: + """ + Send an Event message. + + Mainly useful to send events that are not supported by the library. + Most of the time, [ping()][httpx_ws.AsyncWebSocketSession.ping], + [send_text()][httpx_ws.AsyncWebSocketSession.send_text], + [send_bytes()][httpx_ws.AsyncWebSocketSession.send_bytes] + and [send_json()][httpx_ws.AsyncWebSocketSession.send_json] are preferred. + + Args: + event: The event to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send an event. + + event = await wsproto.events.Message(b"Hello!") + ws.send(event) + """ + try: + data = self.connection.send(event) + await self.stream.write(data) + except httpcore.WriteError as e: + await self.close(CloseReason.INTERNAL_ERROR, "Stream write error") + raise WebSocketNetworkError() from e + + async def send_text(self, data: str) -> None: + """ + Send a text message. + + Args: + data: The text to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send a text message. + + await ws.send_text("Hello!") + """ + event = wsproto.events.TextMessage(data=data) + await self.send(event) + + async def send_bytes(self, data: bytes) -> None: + """ + Send a bytes message. + + Args: + data: The data to send. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send a bytes message. + + await ws.send_bytes(b"Hello!") + """ + event = wsproto.events.BytesMessage(data=data) + await self.send(event) + + async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: + """ + Send JSON data. + + Args: + data: + The data to send. Must be serializable by [json.dumps][json.dumps]. + mode: + The sending mode. Should either be `'text'` or `'bytes'`. + + Raises: + WebSocketNetworkError: A network error occured. + + Examples: + Send JSON data. + + data = {"message": "Hello!"} + await ws.send_json(data) + """ + assert mode in ["text", "binary"] + serialized_data = json.dumps(data) + if mode == "text": + await self.send_text(serialized_data) + else: + await self.send_bytes(serialized_data.encode("utf-8")) + + async def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Event: + """ + Receive an event from the server. + + Mainly useful to receive raw [wsproto.events.Event][wsproto.events.Event]. + Most of the time, [receive_text()][httpx_ws.AsyncWebSocketSession.receive_text], + [receive_bytes()][httpx_ws.AsyncWebSocketSession.receive_bytes], + and [receive_json()][httpx_ws.AsyncWebSocketSession.receive_json] are preferred. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + A raw [wsproto.events.Event][wsproto.events.Event]. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + + Examples: + Wait for an event until one is available. + + try: + event = await ws.receive() + except WebSocketDisconnect: + print("Connection closed") + + Wait for an event for 2 seconds. + + try: + event = await ws.receive(timeout=2.) + except TimeoutError: + print("No event received.") + except WebSocketDisconnect: + print("Connection closed") + """ + with anyio.fail_after(timeout): + event = await self._receive_event.receive() + if isinstance(event, HTTPXWSException): + raise event + if isinstance(event, wsproto.events.CloseConnection): + raise WebSocketDisconnect(event.code, event.reason) + return event + + async def receive_text(self, timeout: typing.Optional[float] = None) -> str: + """ + Receive text from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Text data. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event was not a text message. + + Examples: + Wait for text until available. + + try: + text = await ws.receive_text() + except WebSocketDisconnect: + print("Connection closed") + + Wait for text for 2 seconds. + + try: + event = await ws.receive_text(timeout=2.) + except TimeoutError: + print("No text received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = await self.receive(timeout) + if isinstance(event, wsproto.events.TextMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: + """ + Receive bytes from the server. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + + Returns: + Bytes data. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event was not a bytes message. + + Examples: + Wait for bytes until available. + + try: + data = await ws.receive_bytes() + except WebSocketDisconnect: + print("Connection closed") + + Wait for bytes for 2 seconds. + + try: + data = await ws.receive_bytes(timeout=2.) + except TimeoutError: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ + event = await self.receive(timeout) + if isinstance(event, wsproto.events.BytesMessage): + return event.data + raise WebSocketInvalidTypeReceived(event) + + async def receive_json( + self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" + ) -> typing.Any: + """ + Receive JSON data from the server. + + The received data should be parseable by [json.loads][json.loads]. + + Args: + timeout: + Number of seconds to wait for an event. + If `None`, will block until an event is available. + mode: + Receive mode. Should either be `'text'` or `'bytes'`. + + Returns: + Parsed JSON data. + + Raises: + TimeoutError: No event was received before the timeout delay. + WebSocketDisconnect: The server closed the websocket. + WebSocketNetworkError: A network error occured. + WebSocketInvalidTypeReceived: The received event + didn't correspond to the specified mode. + + Examples: + Wait for data until available. + + try: + data = await ws.receive_json() + except WebSocketDisconnect: + print("Connection closed") + + Wait for data for 2 seconds. + + try: + data = await ws.receive_json(timeout=2.) + except TimeoutError: + print("No data received.") + except WebSocketDisconnect: + print("Connection closed") + """ + assert mode in ["text", "binary"] + data: typing.Union[str, bytes] + if mode == "text": + data = await self.receive_text(timeout) + elif mode == "binary": + data = await self.receive_bytes(timeout) + return json.loads(data) + + async def close(self, code: int = 1000, reason: typing.Optional[str] = None): + """ + Close the WebSocket session. + + Internally, it'll send the + [CloseConnection][wsproto.events.CloseConnection] event. + + *This method is automatically called when exiting the context manager.* + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + + Examples: + Close the WebSocket session. + + await ws.close() + """ + self._should_close.set() + if self.connection.state not in { + wsproto.connection.ConnectionState.LOCAL_CLOSING, + wsproto.connection.ConnectionState.CLOSED, + }: + event = wsproto.events.CloseConnection(code, reason) + data = self.connection.send(event) + try: + await self.stream.write(data) + except httpcore.WriteError: + pass + await self.stream.aclose() + + async def _background_receive(self, max_bytes: int) -> None: + """ + Background task listening for data from the server. + + Internally, it'll: + + * Answer to Ping events. + * Acknowledge Pong events. + * Put other events in the [_events][_events] + queue that'll eventually be consumed by the user. + + Args: + max_bytes: The maximum chunk size to read at each iteration. + """ + partial_message_buffer: typing.Union[str, bytes, None] = None + try: + while not self._should_close.is_set(): + data = await self.stream.read(max_bytes=max_bytes) + self.connection.receive_data(data) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Ping): + data = self.connection.send(event.response()) + await self.stream.write(data) + continue + if isinstance(event, wsproto.events.Pong): + self._ping_manager.ack(event.payload) + continue + if isinstance(event, wsproto.events.CloseConnection): + self._should_close.set() + if isinstance(event, wsproto.events.Message): + # Unfinished message: bufferize + if not event.message_finished: + if partial_message_buffer is None: + partial_message_buffer = event.data + else: + partial_message_buffer += event.data + # Finished message but no buffer: just emit the event + elif partial_message_buffer is None: + await self._send_event.send(event) + # Finished message with buffer: emit the full event + else: + event_type = type(event) + full_message_event = event_type(partial_message_buffer + event.data) + partial_message_buffer = None + await self._send_event.send(full_message_event) + continue + await self._send_event.send(event) + except (httpcore.ReadError, httpcore.WriteError): + await self.close(CloseReason.INTERNAL_ERROR, "Stream error") + await self._send_event.send(WebSocketNetworkError()) + + async def _background_keepalive_ping( + self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None + ) -> None: + while not self._should_close.is_set(): + await anyio.sleep(interval_seconds) + pong_callback = await self.ping() + if timeout_seconds is not None: + try: + with anyio.fail_after(timeout_seconds): + await pong_callback.wait() + except TimeoutError: + await self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") + await self._send_event.send(WebSocketNetworkError()) + + +def _get_headers( + subprotocols: typing.Optional[list[str]], +) -> dict[str, typing.Any]: + headers = { + "connection": "upgrade", + "upgrade": "websocket", + "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), + "sec-websocket-version": "13", + } + if subprotocols is not None: + headers["sec-websocket-protocol"] = ", ".join(subprotocols) + return headers + + +@contextlib.contextmanager +def _connect_ws( + url: str, + client: httpx.Client, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[float] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: typing.Optional[list[str]] = None, + **kwargs: typing.Any, +) -> typing.Generator[WebSocketSession, None, None]: + headers = kwargs.pop("headers", {}) + headers.update(_get_headers(subprotocols)) + + with client.stream("GET", url, headers=headers, **kwargs) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + with WebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + response=response, + ) as session: + yield session + + +@contextlib.contextmanager +def connect_ws( + url: str, + client: typing.Optional[httpx.Client] = None, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[float] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: typing.Optional[list[str]] = None, + **kwargs: typing.Any, +) -> typing.Generator[WebSocketSession, None, None]: + """ + Start a sync WebSocket session. + + It returns a context manager that'll automatically + call [close()][httpx_ws.WebSocketSession.close] when exiting. + + Args: + url: The WebSocket URL. + client: + HTTPX client to use. + If not provided, a default one will be initialized. + max_message_size_bytes: + Message size in bytes to receive from the server. + Defaults to 65 KiB. + queue_size: + Size of the queue where the received messages will be held + until they are consumed. + If the queue is full, the client will stop receive messages + from the server until the queue has room available. + Defaults to 512. + keepalive_ping_interval_seconds: + Interval at which the client will automatically send a Ping event + to keep the connection alive. Set it to `None` to disable this mechanism. + Defaults to 20 seconds. + keepalive_ping_timeout_seconds: + Maximum delay the client will wait for an answer to its Ping event. + If the delay is exceeded, + [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] + will be raised and the connection closed. + Defaults to 20 seconds. + subprotocols: + Optional list of suprotocols to negotiate with the server. + **kwargs: + Additional keyword arguments that will be passed to + the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. + + Returns: + A [context manager][contextlib.AbstractContextManager] + for [WebSocketSession][httpx_ws.WebSocketSession]. + + Examples: + Without explicit HTTPX client. + + with connect_ws("http://localhost:8000/ws") as ws: + message = ws.receive_text() + print(message) + ws.send_text("Hello!") + + With explicit HTTPX client. + + with httpx.Client() as client: + with connect_ws("http://localhost:8000/ws", client) as ws: + message = ws.receive_text() + print(message) + ws.send_text("Hello!") + """ + if client is None: + with httpx.Client() as client: + with _connect_ws( + url, + client=client, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + **kwargs, + ) as websocket: + yield websocket + else: + with _connect_ws( + url, + client=client, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + **kwargs, + ) as websocket: + yield websocket + + +@contextlib.asynccontextmanager +async def _aconnect_ws( + url: str, + client: httpx.AsyncClient, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[float] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: typing.Optional[list[str]] = None, + **kwargs: typing.Any, +) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: + headers = kwargs.pop("headers", {}) + headers.update(_get_headers(subprotocols)) + + async with client.stream("GET", url, headers=headers, **kwargs) as response: + if response.status_code != 101: + raise WebSocketUpgradeError(response) + + async with AsyncWebSocketSession( + response.extensions["network_stream"], + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + response=response, + ) as session: + yield session + + +@contextlib.asynccontextmanager +async def aconnect_ws( + url: str, + client: typing.Optional[httpx.AsyncClient] = None, + *, + max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, + queue_size: int = DEFAULT_QUEUE_SIZE, + keepalive_ping_interval_seconds: typing.Optional[ + float + ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, + keepalive_ping_timeout_seconds: typing.Optional[float] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, + subprotocols: typing.Optional[list[str]] = None, + **kwargs: typing.Any, +) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: + """ + Start an async WebSocket session. + + It returns an async context manager that'll automatically + call [close()][httpx_ws.AsyncWebSocketSession.close] when exiting. + + Args: + url: The WebSocket URL. + client: + HTTPX client to use. + If not provided, a default one will be initialized. + max_message_size_bytes: + Message size in bytes to receive from the server. + Defaults to 65 KiB. + queue_size: + Size of the queue where the received messages will be held + until they are consumed. + If the queue is full, the client will stop receive messages + from the server until the queue has room available. + Defaults to 512. + keepalive_ping_interval_seconds: + Interval at which the client will automatically send a Ping event + to keep the connection alive. Set it to `None` to disable this mechanism. + Defaults to 20 seconds. + keepalive_ping_timeout_seconds: + Maximum delay the client will wait for an answer to its Ping event. + If the delay is exceeded, + [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] + will be raised and the connection closed. + Defaults to 20 seconds. + subprotocols: + Optional list of suprotocols to negotiate with the server. + **kwargs: + Additional keyword arguments that will be passed to + the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. + + Returns: + An [async context manager][contextlib.AbstractAsyncContextManager] + for [AsyncWebSocketSession][httpx_ws.AsyncWebSocketSession]. + + Examples: + Without explicit HTTPX client. + + async with aconnect_ws("http://localhost:8000/ws") as ws: + message = await ws.receive_text() + print(message) + await ws.send_text("Hello!") + + With explicit HTTPX client. + + async with httpx.AsyncClient() as client: + async with aconnect_ws("http://localhost:8000/ws", client) as ws: + message = await ws.receive_text() + print(message) + await ws.send_text("Hello!") + """ + if client is None: + async with httpx.AsyncClient() as client: + async with _aconnect_ws( + url, + client=client, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + **kwargs, + ) as websocket: + yield websocket + else: + async with _aconnect_ws( + url, + client=client, + max_message_size_bytes=max_message_size_bytes, + queue_size=queue_size, + keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, + keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, + subprotocols=subprotocols, + **kwargs, + ) as websocket: + yield websocket diff --git a/tests/unit/httpx_ws/_exceptions.py b/tests/unit/httpx_ws/_exceptions.py new file mode 100644 index 0000000000..0facbf82aa --- /dev/null +++ b/tests/unit/httpx_ws/_exceptions.py @@ -0,0 +1,55 @@ +import typing + +import httpx +import wsproto + + +class HTTPXWSException(Exception): + """ + Base exception class for HTTPX WS. + """ + + pass + + +class WebSocketUpgradeError(HTTPXWSException): + """ + Raised when the initial connection didn't correctly upgrade to a WebSocket session. + """ + + def __init__(self, response: httpx.Response) -> None: + self.response = response + + +class WebSocketDisconnect(HTTPXWSException): + """ + Raised when the server closed the WebSocket session. + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + """ + + def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: + self.code = code + self.reason = reason or "" + + +class WebSocketInvalidTypeReceived(HTTPXWSException): + """ + Raised when a event is not of the expected type. + """ + + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class WebSocketNetworkError(HTTPXWSException): + """ + Raised when a network error occured, + typically if the underlying stream has closed or timeout. + """ + + pass diff --git a/tests/unit/httpx_ws/_ping.py b/tests/unit/httpx_ws/_ping.py new file mode 100644 index 0000000000..2b4e7f24db --- /dev/null +++ b/tests/unit/httpx_ws/_ping.py @@ -0,0 +1,40 @@ +import secrets +import threading +import typing + +import anyio + + +class PingManagerBase: + def _generate_id(self) -> bytes: + return secrets.token_bytes() + + +class PingManager(PingManagerBase): + def __init__(self) -> None: + self._pings: dict[bytes, threading.Event] = {} + + def create(self, ping_id: typing.Optional[bytes] = None) -> tuple[bytes, threading.Event]: + ping_id = self._generate_id() if not ping_id else ping_id + event = threading.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: typing.Union[bytes, bytearray]): + event = self._pings.pop(bytes(ping_id)) + event.set() + + +class AsyncPingManager(PingManagerBase): + def __init__(self) -> None: + self._pings: dict[bytes, anyio.Event] = {} + + def create(self, ping_id: typing.Optional[bytes] = None) -> tuple[bytes, anyio.Event]: + ping_id = self._generate_id() if not ping_id else ping_id + event = anyio.Event() + self._pings[ping_id] = event + return ping_id, event + + def ack(self, ping_id: typing.Union[bytes, bytearray]): + event = self._pings.pop(bytes(ping_id)) + event.set() diff --git a/tests/unit/httpx_ws/transport.py b/tests/unit/httpx_ws/transport.py new file mode 100644 index 0000000000..13a8e752a6 --- /dev/null +++ b/tests/unit/httpx_ws/transport.py @@ -0,0 +1,212 @@ +import contextlib +import queue +import typing +from concurrent.futures import Future + +import anyio +import wsproto +from httpcore import AsyncNetworkStream +from httpx import ASGITransport, AsyncByteStream, Request, Response +from wsproto.frame_protocol import CloseReason + +from ._exceptions import WebSocketDisconnect + +Scope = dict[str, typing.Any] +Message = dict[str, typing.Any] +Receive = typing.Callable[[], typing.Awaitable[Message]] +Send = typing.Callable[[Scope], typing.Coroutine[None, None, None]] +ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Coroutine[None, None, None]] + + +class ASGIWebSocketTransportError(Exception): + pass + + +class UnhandledASGIMessageType(ASGIWebSocketTransportError): + def __init__(self, message: Message) -> None: + self.message = message + + +class UnhandledWebSocketEvent(ASGIWebSocketTransportError): + def __init__(self, event: wsproto.events.Event) -> None: + self.event = event + + +class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): + def __init__(self, app: ASGIApp, scope: Scope) -> None: + self.app = app + self.scope = scope + self._receive_queue: queue.Queue[Message] = queue.Queue() + self._send_queue: queue.Queue[Message] = queue.Queue() + self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) + self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) + + async def __aenter__( + self, + ) -> tuple["ASGIWebSocketAsyncNetworkStream", bytes]: + self.exit_stack = contextlib.ExitStack() + self.portal = self.exit_stack.enter_context( + anyio.from_thread.start_blocking_portal("asyncio") + ) + _: Future[None] = self.portal.start_task_soon(self._run) + + await self.send({"type": "websocket.connect"}) + message = await self.receive() + + if message["type"] == "websocket.close": + await self.aclose() + raise WebSocketDisconnect(message["code"], message.get("reason")) + + assert message["type"] == "websocket.accept" + return self, self._build_accept_response(message) + + async def __aexit__(self, *args: typing.Any) -> None: + await self.aclose() + + async def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: + message: Message = await self.receive(timeout=timeout) + type = message["type"] + + if type not in {"websocket.send", "websocket.close"}: + raise UnhandledASGIMessageType(message) + + event: wsproto.events.Event + if type == "websocket.send": + data_str: typing.Optional[str] = message.get("text") + if data_str is not None: + event = wsproto.events.TextMessage(data_str) + data_bytes: typing.Optional[bytes] = message.get("bytes") + if data_bytes is not None: + event = wsproto.events.BytesMessage(data_bytes) + elif type == "websocket.close": + event = wsproto.events.CloseConnection(message["code"], message["reason"]) + + return self.connection.send(event) + + async def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: + self.connection.receive_data(buffer) + for event in self.connection.events(): + if isinstance(event, wsproto.events.Request): + pass + elif isinstance(event, wsproto.events.CloseConnection): + await self.send( + { + "type": "websocket.close", + "code": event.code, + "reason": event.reason, + } + ) + elif isinstance(event, wsproto.events.TextMessage): + await self.send({"type": "websocket.receive", "text": event.data}) + elif isinstance(event, wsproto.events.BytesMessage): + await self.send({"type": "websocket.receive", "bytes": event.data}) + else: + raise UnhandledWebSocketEvent(event) + + async def aclose(self) -> None: + await self.send({"type": "websocket.close"}) + self.exit_stack.close() + + async def send(self, message: Message) -> None: + self._receive_queue.put(message) + + async def receive(self, timeout: typing.Optional[float] = None) -> Message: + while self._send_queue.empty(): + await anyio.sleep(0) + return self._send_queue.get(timeout=timeout) + + async def _run(self) -> None: + """ + The sub-thread in which the websocket session runs. + """ + scope = self.scope + receive = self._asgi_receive + send = self._asgi_send + try: + await self.app(scope, receive, send) + except Exception as e: + message = { + "type": "websocket.close", + "code": CloseReason.INTERNAL_ERROR, + "reason": str(e), + } + await self._asgi_send(message) + + async def _asgi_receive(self) -> Message: + while self._receive_queue.empty(): + await anyio.sleep(0) + return self._receive_queue.get() + + async def _asgi_send(self, message: Message) -> None: + self._send_queue.put(message) + + def _build_accept_response(self, message: Message) -> bytes: + subprotocol = message.get("subprotocol", None) + headers = message.get("headers", []) + return self.connection.send( + wsproto.events.AcceptConnection( + subprotocol=subprotocol, + extra_headers=headers, + ) + ) + + +class ASGIWebSocketTransport(ASGITransport): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self.exit_stack: typing.Optional[contextlib.AsyncExitStack] = None + + async def handle_async_request(self, request: Request) -> Response: + scheme = request.url.scheme + headers = request.headers + + if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": + subprotocols: list[str] = [] + if (subprotocols_header := headers.get("sec-websocket-protocol")) is not None: + subprotocols = subprotocols_header.split(",") + + scope = { + "type": "websocket", + "path": request.url.path, + "raw_path": request.url.raw_path, + "root_path": self.root_path, + "scheme": scheme, + "query_string": request.url.query, + "headers": [(k.lower(), v) for (k, v) in request.headers.raw], + "client": self.client, + "server": (request.url.host, request.url.port), + "subprotocols": subprotocols, + } + return await self._handle_ws_request(request, scope) + + return await super().handle_async_request(request) + + async def _handle_ws_request( + self, + request: Request, + scope: Scope, + ) -> Response: + assert isinstance(request.stream, AsyncByteStream) + + self.scope = scope + self.exit_stack = contextlib.AsyncExitStack() + stream, accept_response = await self.exit_stack.enter_async_context( + ASGIWebSocketAsyncNetworkStream(self.app, self.scope) # type: ignore[arg-type] + ) + + accept_response_lines = accept_response.decode("utf-8").splitlines() + headers = [ + typing.cast(tuple[str, str], line.split(": ", 1)) + for line in accept_response_lines[1:] + if line.strip() != "" + ] + + return Response( + status_code=101, + headers=headers, + extensions={"network_stream": stream}, + ) + + async def aclose(self) -> None: + if self.exit_stack: + await self.exit_stack.aclose() From 709d51b8405ae35f2d68445d090c6329912c86ef Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 12:55:25 -0700 Subject: [PATCH 06/17] pass postgres --- tests/conftest.py | 2 +- tests/unit/httpx_ws/transport.py | 74 ++++++++++++--------- tests/unit/server/api/test_subscriptions.py | 8 +-- 3 files changed, 49 insertions(+), 35 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index e31c426f39..c0d9342510 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ def pytest_addoption(parser: Parser) -> None: parser.addoption( "--run-postgres", - action="store_true", + action="store_false", help="Run tests that require Postgres", ) parser.addoption( diff --git a/tests/unit/httpx_ws/transport.py b/tests/unit/httpx_ws/transport.py index 13a8e752a6..9783b1f20d 100644 --- a/tests/unit/httpx_ws/transport.py +++ b/tests/unit/httpx_ws/transport.py @@ -1,9 +1,7 @@ +import asyncio import contextlib -import queue import typing -from concurrent.futures import Future -import anyio import wsproto from httpcore import AsyncNetworkStream from httpx import ASGITransport, AsyncByteStream, Request, Response @@ -36,19 +34,22 @@ class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): def __init__(self, app: ASGIApp, scope: Scope) -> None: self.app = app self.scope = scope - self._receive_queue: queue.Queue[Message] = queue.Queue() - self._send_queue: queue.Queue[Message] = queue.Queue() + self._receive_queue: asyncio.Queue[Message] = asyncio.Queue() + self._send_queue: asyncio.Queue[Message] = asyncio.Queue() self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) + self.tasks: list[asyncio.Task] = [] async def __aenter__( self, ) -> tuple["ASGIWebSocketAsyncNetworkStream", bytes]: - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context( - anyio.from_thread.start_blocking_portal("asyncio") - ) - _: Future[None] = self.portal.start_task_soon(self._run) + self.exit_stack = contextlib.AsyncExitStack() + await self.exit_stack.__aenter__() + + # Start the _run coroutine as a task + self._run_task = asyncio.create_task(self._run()) + self.tasks.append(self._run_task) + self.exit_stack.push_async_callback(self._cancel_tasks) await self.send({"type": "websocket.connect"}) message = await self.receive() @@ -62,24 +63,36 @@ async def __aenter__( async def __aexit__(self, *args: typing.Any) -> None: await self.aclose() + await self.exit_stack.aclose() + + async def _cancel_tasks(self): + # Cancel all running tasks + for task in self.tasks: + task.cancel() + # Wait for tasks to be cancelled + await asyncio.gather(*self.tasks, return_exceptions=True) async def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: - message: Message = await self.receive(timeout=timeout) - type = message["type"] + message: Message = await self.receive() + message_type = message["type"] - if type not in {"websocket.send", "websocket.close"}: + if message_type not in {"websocket.send", "websocket.close"}: raise UnhandledASGIMessageType(message) event: wsproto.events.Event - if type == "websocket.send": + if message_type == "websocket.send": data_str: typing.Optional[str] = message.get("text") if data_str is not None: event = wsproto.events.TextMessage(data_str) - data_bytes: typing.Optional[bytes] = message.get("bytes") - if data_bytes is not None: - event = wsproto.events.BytesMessage(data_bytes) - elif type == "websocket.close": - event = wsproto.events.CloseConnection(message["code"], message["reason"]) + else: + data_bytes: typing.Optional[bytes] = message.get("bytes") + if data_bytes is not None: + event = wsproto.events.BytesMessage(data_bytes) + else: + # If neither text nor bytes are provided, raise an error + raise ValueError("websocket.send message missing 'text' or 'bytes'") + elif message_type == "websocket.close": + event = wsproto.events.CloseConnection(message["code"], message.get("reason")) return self.connection.send(event) @@ -87,7 +100,7 @@ async def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> self.connection.receive_data(buffer) for event in self.connection.events(): if isinstance(event, wsproto.events.Request): - pass + pass # Already handled in __init__ elif isinstance(event, wsproto.events.CloseConnection): await self.send( { @@ -105,19 +118,22 @@ async def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> async def aclose(self) -> None: await self.send({"type": "websocket.close"}) - self.exit_stack.close() + # Ensure tasks are cancelled and cleaned up + await self._cancel_tasks() async def send(self, message: Message) -> None: - self._receive_queue.put(message) + await self._receive_queue.put(message) async def receive(self, timeout: typing.Optional[float] = None) -> Message: - while self._send_queue.empty(): - await anyio.sleep(0) - return self._send_queue.get(timeout=timeout) + try: + message = await asyncio.wait_for(self._send_queue.get(), timeout) + return message + except asyncio.TimeoutError: + raise TimeoutError("Timed out waiting for message") async def _run(self) -> None: """ - The sub-thread in which the websocket session runs. + The coroutine in which the websocket session runs. """ scope = self.scope receive = self._asgi_receive @@ -133,12 +149,10 @@ async def _run(self) -> None: await self._asgi_send(message) async def _asgi_receive(self) -> Message: - while self._receive_queue.empty(): - await anyio.sleep(0) - return self._receive_queue.get() + return await self._receive_queue.get() async def _asgi_send(self, message: Message) -> None: - self._send_queue.put(message) + await self._send_queue.put(message) def _build_accept_response(self, message: Message) -> bytes: subprotocol = message.get("subprotocol", None) diff --git a/tests/unit/server/api/test_subscriptions.py b/tests/unit/server/api/test_subscriptions.py index 9964c99fa5..611a0724ef 100644 --- a/tests/unit/server/api/test_subscriptions.py +++ b/tests/unit/server/api/test_subscriptions.py @@ -128,12 +128,9 @@ class TestChatCompletionSubscription: async def test_openai_text_response_emits_expected_payloads_and_records_expected_span( self, - dialect: str, gql_client: Any, openai_api_key: str, ) -> None: - if dialect == "postgresql": - pytest.skip("fails on postgres for unknown reason") variables = { "input": { "messages": [ @@ -182,6 +179,10 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -195,7 +196,6 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None From 09825940c4191d29082fa8f795c48d99e7803753 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 12:58:31 -0700 Subject: [PATCH 07/17] change back to not run postgres --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index c0d9342510..e31c426f39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,7 @@ def pytest_addoption(parser: Parser) -> None: parser.addoption( "--run-postgres", - action="store_false", + action="store_true", help="Run tests that require Postgres", ) parser.addoption( From 3d113ad884d8ab6a2ab667b7a8e5540f9bc518f0 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 14:27:34 -0700 Subject: [PATCH 08/17] remove file --- tests/unit/ws_transport.py | 267 ------------------------------------- 1 file changed, 267 deletions(-) delete mode 100644 tests/unit/ws_transport.py diff --git a/tests/unit/ws_transport.py b/tests/unit/ws_transport.py deleted file mode 100644 index 35889a6db7..0000000000 --- a/tests/unit/ws_transport.py +++ /dev/null @@ -1,267 +0,0 @@ -""" -This code is copied and adapted from https://github.com/frankie567/httpx-ws/tree/main -""" - -import contextlib -import queue -import typing -from concurrent.futures import Future - -import anyio -import httpx -import wsproto -from httpcore import AsyncNetworkStream -from httpx import ASGITransport, AsyncByteStream, Request, Response -from wsproto.frame_protocol import CloseReason - - -class HTTPXWSException(Exception): - """ - Base exception class for HTTPX WS. - """ - - pass - - -class WebSocketUpgradeError(HTTPXWSException): - """ - Raised when the initial connection didn't correctly upgrade to a WebSocket session. - """ - - def __init__(self, response: httpx.Response) -> None: - self.response = response - - -class WebSocketDisconnect(HTTPXWSException): - """ - Raised when the server closed the WebSocket session. - - Args: - code: - The integer close code to indicate why the connection has closed. - reason: - Additional reasoning for why the connection has closed. - """ - - def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: - self.code = code - self.reason = reason or "" - - -class WebSocketInvalidTypeReceived(HTTPXWSException): - """ - Raised when a event is not of the expected type. - """ - - def __init__(self, event: wsproto.events.Event) -> None: - self.event = event - - -class WebSocketNetworkError(HTTPXWSException): - """ - Raised when a network error occured, - typically if the underlying stream has closed or timeout. - """ - - pass - - -Scope = dict[str, typing.Any] -Message = dict[str, typing.Any] -Receive = typing.Callable[[], typing.Awaitable[Message]] -Send = typing.Callable[[Scope], typing.Coroutine[None, None, None]] -ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Coroutine[None, None, None]] - - -class ASGIWebSocketTransportError(Exception): - pass - - -class UnhandledASGIMessageType(ASGIWebSocketTransportError): - def __init__(self, message: Message) -> None: - self.message = message - - -class UnhandledWebSocketEvent(ASGIWebSocketTransportError): - def __init__(self, event: wsproto.events.Event) -> None: - self.event = event - - -class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream): - def __init__(self, app: ASGIApp, scope: Scope) -> None: - self.app = app - self.scope = scope - self._receive_queue: queue.Queue[Message] = queue.Queue() - self._send_queue: queue.Queue[Message] = queue.Queue() - self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) - self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) - - async def __aenter__( - self, - ) -> tuple["ASGIWebSocketAsyncNetworkStream", bytes]: - self.exit_stack = contextlib.ExitStack() - self.portal = self.exit_stack.enter_context( - anyio.from_thread.start_blocking_portal("asyncio") - ) - _: Future[None] = self.portal.start_task_soon(self._run) - - await self.send({"type": "websocket.connect"}) - message = await self.receive() - - if message["type"] == "websocket.close": - await self.aclose() - raise WebSocketDisconnect(message["code"], message.get("reason")) - - assert message["type"] == "websocket.accept" - return self, self._build_accept_response(message) - - async def __aexit__(self, *args: typing.Any) -> None: - await self.aclose() - - async def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes: - message: Message = await self.receive(timeout=timeout) - type = message["type"] - - if type not in {"websocket.send", "websocket.close"}: - raise UnhandledASGIMessageType(message) - - event: wsproto.events.Event - if type == "websocket.send": - data_str: typing.Optional[str] = message.get("text") - if data_str is not None: - event = wsproto.events.TextMessage(data_str) - data_bytes: typing.Optional[bytes] = message.get("bytes") - if data_bytes is not None: - event = wsproto.events.BytesMessage(data_bytes) - elif type == "websocket.close": - event = wsproto.events.CloseConnection(message["code"], message["reason"]) - - return self.connection.send(event) - - async def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None: - self.connection.receive_data(buffer) - for event in self.connection.events(): - if isinstance(event, wsproto.events.Request): - pass - elif isinstance(event, wsproto.events.CloseConnection): - await self.send( - { - "type": "websocket.close", - "code": event.code, - "reason": event.reason, - } - ) - elif isinstance(event, wsproto.events.TextMessage): - await self.send({"type": "websocket.receive", "text": event.data}) - elif isinstance(event, wsproto.events.BytesMessage): - await self.send({"type": "websocket.receive", "bytes": event.data}) - else: - raise UnhandledWebSocketEvent(event) - - async def aclose(self) -> None: - await self.send({"type": "websocket.close"}) - self.exit_stack.close() - - async def send(self, message: Message) -> None: - self._receive_queue.put(message) - - async def receive(self, timeout: typing.Optional[float] = None) -> Message: - while self._send_queue.empty(): - await anyio.sleep(0) - return self._send_queue.get(timeout=timeout) - - async def _run(self) -> None: - """ - The sub-thread in which the websocket session runs. - """ - scope = self.scope - receive = self._asgi_receive - send = self._asgi_send - try: - await self.app(scope, receive, send) - except Exception as e: - message = { - "type": "websocket.close", - "code": CloseReason.INTERNAL_ERROR, - "reason": str(e), - } - await self._asgi_send(message) - - async def _asgi_receive(self) -> Message: - while self._receive_queue.empty(): - await anyio.sleep(0) - return self._receive_queue.get() - - async def _asgi_send(self, message: Message) -> None: - self._send_queue.put(message) - - def _build_accept_response(self, message: Message) -> bytes: - subprotocol = message.get("subprotocol", None) - headers = message.get("headers", []) - return self.connection.send( - wsproto.events.AcceptConnection( - subprotocol=subprotocol, - extra_headers=headers, - ) - ) - - -class ASGIWebSocketTransport(ASGITransport): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.exit_stack: typing.Optional[contextlib.AsyncExitStack] = None - - async def handle_async_request(self, request: Request) -> Response: - scheme = request.url.scheme - headers = request.headers - - if scheme in {"ws", "wss"} or headers.get("upgrade") == "websocket": - subprotocols: list[str] = [] - if (subprotocols_header := headers.get("sec-websocket-protocol")) is not None: - subprotocols = subprotocols_header.split(",") - - scope = { - "type": "websocket", - "path": request.url.path, - "raw_path": request.url.raw_path, - "root_path": self.root_path, - "scheme": scheme, - "query_string": request.url.query, - "headers": [(k.lower(), v) for (k, v) in request.headers.raw], - "client": self.client, - "server": (request.url.host, request.url.port), - "subprotocols": subprotocols, - } - return await self._handle_ws_request(request, scope) - - return await super().handle_async_request(request) - - async def _handle_ws_request( - self, - request: Request, - scope: Scope, - ) -> Response: - assert isinstance(request.stream, AsyncByteStream) - - self.scope = scope - self.exit_stack = contextlib.AsyncExitStack() - stream, accept_response = await self.exit_stack.enter_async_context( - ASGIWebSocketAsyncNetworkStream(self.app, self.scope) # type: ignore[arg-type] - ) - - accept_response_lines = accept_response.decode("utf-8").splitlines() - headers = [ - typing.cast(tuple[str, str], line.split(": ", 1)) - for line in accept_response_lines[1:] - if line.strip() != "" - ] - - return Response( - status_code=101, - headers=headers, - extensions={"network_stream": stream}, - ) - - async def aclose(self) -> None: - if self.exit_stack: - await self.exit_stack.aclose() From e2e71bd55597ba35856e27615879e9d1a6cf32d9 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 14:31:56 -0700 Subject: [PATCH 09/17] pass postgres tests --- tests/unit/server/api/test_subscriptions.py | 33 ++++++++++----------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/unit/server/api/test_subscriptions.py b/tests/unit/server/api/test_subscriptions.py index 611a0724ef..8288c2df00 100644 --- a/tests/unit/server/api/test_subscriptions.py +++ b/tests/unit/server/api/test_subscriptions.py @@ -3,7 +3,6 @@ from pathlib import Path from typing import Any, Dict -import pytest from openinference.semconv.trace import ( OpenInferenceMimeTypeValues, OpenInferenceSpanKindValues, @@ -260,12 +259,9 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected async def test_openai_emits_expected_payloads_and_records_expected_span_on_error( self, - dialect: str, gql_client: Any, openai_api_key: str, ) -> None: - if dialect == "postgresql": - pytest.skip("fails on postgres for unknown reason") variables = { "input": { "messages": [ @@ -314,6 +310,10 @@ async def test_openai_emits_expected_payloads_and_records_expected_span_on_error query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -327,7 +327,6 @@ async def test_openai_emits_expected_payloads_and_records_expected_span_on_error assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None @@ -383,12 +382,9 @@ async def test_openai_emits_expected_payloads_and_records_expected_span_on_error async def test_openai_tool_call_response_emits_expected_payloads_and_records_expected_span( self, - dialect: str, gql_client: Any, openai_api_key: str, ) -> None: - if dialect == "postgresql": - pytest.skip("fails on postgres for unknown reason") get_current_weather_tool_schema = { "type": "function", "function": { @@ -457,6 +453,10 @@ async def test_openai_tool_call_response_emits_expected_payloads_and_records_exp query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -470,7 +470,6 @@ async def test_openai_tool_call_response_emits_expected_payloads_and_records_exp assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None @@ -539,12 +538,9 @@ async def test_openai_tool_call_response_emits_expected_payloads_and_records_exp async def test_openai_tool_call_messages_emits_expected_payloads_and_records_expected_span( self, - dialect: str, gql_client: Any, openai_api_key: str, ) -> None: - if dialect == "postgresql": - pytest.skip("fails on postgres for unknown reason") tool_call_id = "call_zz1hkqH3IakqnHfVhrrUemlQ" tool_calls = [ { @@ -610,6 +606,10 @@ async def test_openai_tool_call_messages_emits_expected_payloads_and_records_exp query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -623,7 +623,6 @@ async def test_openai_tool_call_messages_emits_expected_payloads_and_records_exp assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None @@ -702,12 +701,9 @@ async def test_openai_tool_call_messages_emits_expected_payloads_and_records_exp async def test_anthropic_text_response_emits_expected_payloads_and_records_expected_span( self, - dialect: str, gql_client: Any, anthropic_api_key: str, ) -> None: - if dialect == "postgresql": - pytest.skip("fails on postgres for unknown reason") variables = { "input": { "messages": [ @@ -756,6 +752,10 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery" ) span = data["span"] + assert json.loads(attributes := span.pop("attributes")) == json.loads( + subscription_span.pop("attributes") + ) + attributes = dict(flatten(json.loads(attributes))) assert span == subscription_span # check attributes @@ -769,7 +769,6 @@ async def test_anthropic_text_response_emits_expected_payloads_and_records_expec assert span.pop("parentId") is None assert span.pop("spanKind") == "llm" assert (context := span.pop("context")).pop("spanId") - assert (attributes := dict(flatten(json.loads(span.pop("attributes"))))) assert context.pop("traceId") assert not context assert span.pop("metadata") is None From 843035caec63cd9349f59852384e70bbb7814b5d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 14:32:45 -0700 Subject: [PATCH 10/17] remove httpx-ws as a unit test dep --- requirements/unit-tests.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements/unit-tests.txt b/requirements/unit-tests.txt index a9abadf7b5..d0451624af 100644 --- a/requirements/unit-tests.txt +++ b/requirements/unit-tests.txt @@ -6,7 +6,6 @@ asgi-lifespan asyncpg grpc-interceptor[testing] httpx # For OpenAI testing -httpx-ws litellm>=1.0.3 nest-asyncio # for executor testing numpy From b28de37f25870c9e7077e2c2dfa42220ad6534bd Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 14:33:10 -0700 Subject: [PATCH 11/17] remove fail fast --- .github/workflows/python-CI.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/python-CI.yml b/.github/workflows/python-CI.yml index a34110d195..8dca479b4e 100644 --- a/.github/workflows/python-CI.yml +++ b/.github/workflows/python-CI.yml @@ -336,7 +336,6 @@ jobs: needs: changes if: ${{ needs.changes.outputs.phoenix == 'true' }} strategy: - fail-fast: false matrix: py: [3.9, 3.12] os: [ubuntu-latest, windows-latest, macos-13] From db1cb6536f4c2159711eb266eccb685236eb931f Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 14:33:45 -0700 Subject: [PATCH 12/17] remove experimental query --- src/phoenix/server/api/subscriptions.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/phoenix/server/api/subscriptions.py b/src/phoenix/server/api/subscriptions.py index 48ad5148d5..80eb55b9d3 100644 --- a/src/phoenix/server/api/subscriptions.py +++ b/src/phoenix/server/api/subscriptions.py @@ -438,18 +438,6 @@ class Subscription: async def chat_completion( self, info: Info[Context, None], input: ChatCompletionInput ) -> AsyncIterator[ChatCompletionSubscriptionPayload]: - async with info.context.db() as session: - if ( - playground_project_id := ( - await session.scalar( - select(models.Project.id).where( - models.Project.name == PLAYGROUND_PROJECT_NAME - ) - ) - ) - ) is None: - print("Creating playground project") - print("Creating playground project") # Determine which LLM client to use based on provider_key provider_key = input.model.provider_key if (llm_client_class := PLAYGROUND_STREAMING_CLIENT_REGISTRY.get(provider_key)) is None: From ce7eb6f1d72a25d78ac0d94189959ab215cd4853 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 14:53:30 -0700 Subject: [PATCH 13/17] add wsproto to unit tests deps --- requirements/unit-tests.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements/unit-tests.txt b/requirements/unit-tests.txt index d0451624af..2936849fa7 100644 --- a/requirements/unit-tests.txt +++ b/requirements/unit-tests.txt @@ -23,3 +23,4 @@ tiktoken types-pytz typing-extensions==4.7.0 vcrpy +wsproto From e187118dd57e5bc7e55e396ac9cdf8e289b12615 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 14:53:49 -0700 Subject: [PATCH 14/17] move into vendor directory --- tests/unit/conftest.py | 4 ++-- tests/unit/vendor/README.md | 3 +++ tests/unit/vendor/__init__.py | 0 tests/unit/vendor/httpx_ws/README.md | 3 +++ tests/unit/{ => vendor}/httpx_ws/__init__.py | 0 tests/unit/{ => vendor}/httpx_ws/_api.py | 0 tests/unit/{ => vendor}/httpx_ws/_exceptions.py | 0 tests/unit/{ => vendor}/httpx_ws/_ping.py | 0 tests/unit/{ => vendor}/httpx_ws/transport.py | 0 9 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 tests/unit/vendor/README.md create mode 100644 tests/unit/vendor/__init__.py create mode 100644 tests/unit/vendor/httpx_ws/README.md rename tests/unit/{ => vendor}/httpx_ws/__init__.py (100%) rename tests/unit/{ => vendor}/httpx_ws/_api.py (100%) rename tests/unit/{ => vendor}/httpx_ws/_exceptions.py (100%) rename tests/unit/{ => vendor}/httpx_ws/_ping.py (100%) rename tests/unit/{ => vendor}/httpx_ws/transport.py (100%) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index d9fa53cf8e..fe817d5e22 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -51,8 +51,8 @@ from phoenix.server.types import BatchedCaller, DbSessionFactory from phoenix.session.client import Client from phoenix.trace.schemas import Span -from tests.unit.httpx_ws import AsyncWebSocketSession, aconnect_ws -from tests.unit.httpx_ws.transport import ASGIWebSocketTransport +from tests.unit.vendor.httpx_ws import AsyncWebSocketSession, aconnect_ws +from tests.unit.vendor.httpx_ws.transport import ASGIWebSocketTransport def pytest_terminal_summary( diff --git a/tests/unit/vendor/README.md b/tests/unit/vendor/README.md new file mode 100644 index 0000000000..10eea305f4 --- /dev/null +++ b/tests/unit/vendor/README.md @@ -0,0 +1,3 @@ +# Unit Test Vendored Dependencies + +This directory contains vendored dependencies used for unit testing. diff --git a/tests/unit/vendor/__init__.py b/tests/unit/vendor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/vendor/httpx_ws/README.md b/tests/unit/vendor/httpx_ws/README.md new file mode 100644 index 0000000000..2be5cc0363 --- /dev/null +++ b/tests/unit/vendor/httpx_ws/README.md @@ -0,0 +1,3 @@ +# HTTPX-WS + +This directory contains a copy of [httpx-ws](https://github.com/frankie567/httpx-ws), which is published under an [MIT license](https://github.com/frankie567/httpx-ws/blob/main/LICENSE). Modifications have been made to better support the concurrency paradigm used in our unit test suite. diff --git a/tests/unit/httpx_ws/__init__.py b/tests/unit/vendor/httpx_ws/__init__.py similarity index 100% rename from tests/unit/httpx_ws/__init__.py rename to tests/unit/vendor/httpx_ws/__init__.py diff --git a/tests/unit/httpx_ws/_api.py b/tests/unit/vendor/httpx_ws/_api.py similarity index 100% rename from tests/unit/httpx_ws/_api.py rename to tests/unit/vendor/httpx_ws/_api.py diff --git a/tests/unit/httpx_ws/_exceptions.py b/tests/unit/vendor/httpx_ws/_exceptions.py similarity index 100% rename from tests/unit/httpx_ws/_exceptions.py rename to tests/unit/vendor/httpx_ws/_exceptions.py diff --git a/tests/unit/httpx_ws/_ping.py b/tests/unit/vendor/httpx_ws/_ping.py similarity index 100% rename from tests/unit/httpx_ws/_ping.py rename to tests/unit/vendor/httpx_ws/_ping.py diff --git a/tests/unit/httpx_ws/transport.py b/tests/unit/vendor/httpx_ws/transport.py similarity index 100% rename from tests/unit/httpx_ws/transport.py rename to tests/unit/vendor/httpx_ws/transport.py From 0dfb5ccde9caf67116a4476b8bb35a550c49052a Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 15:12:53 -0700 Subject: [PATCH 15/17] fix types --- tests/unit/vendor/httpx_ws/_api.py | 13 +++++++++---- tests/unit/vendor/httpx_ws/_ping.py | 4 ++-- tests/unit/vendor/httpx_ws/transport.py | 6 +++--- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/tests/unit/vendor/httpx_ws/_api.py b/tests/unit/vendor/httpx_ws/_api.py index f1e967e6fd..cd4385ee21 100644 --- a/tests/unit/vendor/httpx_ws/_api.py +++ b/tests/unit/vendor/httpx_ws/_api.py @@ -118,7 +118,12 @@ def __enter__(self) -> "WebSocketSession": return self - def __exit__(self, exc_type, exc, tb): + def __exit__( + self, + exc_type: typing.Optional[type[BaseException]], + exc: typing.Optional[BaseException], + tb: typing.Optional[TracebackType], + ) -> None: self.close() self._background_receive_task.join() if self._background_keepalive_ping_task is not None: @@ -416,7 +421,7 @@ def receive_json( data = self.receive_bytes(timeout) return json.loads(data) - def close(self, code: int = 1000, reason: typing.Optional[str] = None): + def close(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: """ Close the WebSocket session. @@ -522,7 +527,7 @@ def _background_keepalive_ping( pass def _wait_until_closed( - self, callable: typing.Callable[..., TaskResult], *args, **kwargs + self, callable: typing.Callable[..., TaskResult], *args: typing.Any, **kwargs: typing.Any ) -> TaskResult: try: executor, should_close_task = self._get_executor_should_close_task() @@ -921,7 +926,7 @@ async def receive_json( data = await self.receive_bytes(timeout) return json.loads(data) - async def close(self, code: int = 1000, reason: typing.Optional[str] = None): + async def close(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: """ Close the WebSocket session. diff --git a/tests/unit/vendor/httpx_ws/_ping.py b/tests/unit/vendor/httpx_ws/_ping.py index 2b4e7f24db..434e78e03c 100644 --- a/tests/unit/vendor/httpx_ws/_ping.py +++ b/tests/unit/vendor/httpx_ws/_ping.py @@ -20,7 +20,7 @@ def create(self, ping_id: typing.Optional[bytes] = None) -> tuple[bytes, threadi self._pings[ping_id] = event return ping_id, event - def ack(self, ping_id: typing.Union[bytes, bytearray]): + def ack(self, ping_id: typing.Union[bytes, bytearray]) -> None: event = self._pings.pop(bytes(ping_id)) event.set() @@ -35,6 +35,6 @@ def create(self, ping_id: typing.Optional[bytes] = None) -> tuple[bytes, anyio.E self._pings[ping_id] = event return ping_id, event - def ack(self, ping_id: typing.Union[bytes, bytearray]): + def ack(self, ping_id: typing.Union[bytes, bytearray]) -> None: event = self._pings.pop(bytes(ping_id)) event.set() diff --git a/tests/unit/vendor/httpx_ws/transport.py b/tests/unit/vendor/httpx_ws/transport.py index 9783b1f20d..86813a6b73 100644 --- a/tests/unit/vendor/httpx_ws/transport.py +++ b/tests/unit/vendor/httpx_ws/transport.py @@ -38,7 +38,7 @@ def __init__(self, app: ASGIApp, scope: Scope) -> None: self._send_queue: asyncio.Queue[Message] = asyncio.Queue() self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER) self.connection.initiate_upgrade_connection(scope["headers"], scope["path"]) - self.tasks: list[asyncio.Task] = [] + self.tasks: list[asyncio.Task[None]] = [] async def __aenter__( self, @@ -65,7 +65,7 @@ async def __aexit__(self, *args: typing.Any) -> None: await self.aclose() await self.exit_stack.aclose() - async def _cancel_tasks(self): + async def _cancel_tasks(self) -> None: # Cancel all running tasks for task in self.tasks: task.cancel() @@ -166,7 +166,7 @@ def _build_accept_response(self, message: Message) -> bytes: class ASGIWebSocketTransport(ASGITransport): - def __init__(self, *args, **kwargs) -> None: + def __init__(self, *args: typing.Any, **kwargs: typing.Any) -> None: super().__init__(*args, **kwargs) self.exit_stack: typing.Optional[contextlib.AsyncExitStack] = None From 6d1b43f89646520a23146c88d9ccdf942f6e8b9c Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 15:43:04 -0700 Subject: [PATCH 16/17] condense to a single file --- requirements/unit-tests.txt | 2 +- tests/unit/conftest.py | 2 +- tests/unit/vendor/httpx_ws/transport.py | 26 +++++++++++++++++++++++-- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/requirements/unit-tests.txt b/requirements/unit-tests.txt index 2936849fa7..a9abadf7b5 100644 --- a/requirements/unit-tests.txt +++ b/requirements/unit-tests.txt @@ -6,6 +6,7 @@ asgi-lifespan asyncpg grpc-interceptor[testing] httpx # For OpenAI testing +httpx-ws litellm>=1.0.3 nest-asyncio # for executor testing numpy @@ -23,4 +24,3 @@ tiktoken types-pytz typing-extensions==4.7.0 vcrpy -wsproto diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index fe817d5e22..931d72fa34 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -30,6 +30,7 @@ from asgi_lifespan import LifespanManager from faker import Faker from httpx import AsyncByteStream, Request, Response +from httpx_ws import AsyncWebSocketSession, aconnect_ws from psycopg import Connection from pytest_postgresql import factories from sqlalchemy import URL, make_url @@ -51,7 +52,6 @@ from phoenix.server.types import BatchedCaller, DbSessionFactory from phoenix.session.client import Client from phoenix.trace.schemas import Span -from tests.unit.vendor.httpx_ws import AsyncWebSocketSession, aconnect_ws from tests.unit.vendor.httpx_ws.transport import ASGIWebSocketTransport diff --git a/tests/unit/vendor/httpx_ws/transport.py b/tests/unit/vendor/httpx_ws/transport.py index 86813a6b73..abc97aa1f0 100644 --- a/tests/unit/vendor/httpx_ws/transport.py +++ b/tests/unit/vendor/httpx_ws/transport.py @@ -7,8 +7,6 @@ from httpx import ASGITransport, AsyncByteStream, Request, Response from wsproto.frame_protocol import CloseReason -from ._exceptions import WebSocketDisconnect - Scope = dict[str, typing.Any] Message = dict[str, typing.Any] Receive = typing.Callable[[], typing.Awaitable[Message]] @@ -16,6 +14,30 @@ ASGIApp = typing.Callable[[Scope, Receive, Send], typing.Coroutine[None, None, None]] +class HTTPXWSException(Exception): + """ + Base exception class for HTTPX WS. + """ + + pass + + +class WebSocketDisconnect(HTTPXWSException): + """ + Raised when the server closed the WebSocket session. + + Args: + code: + The integer close code to indicate why the connection has closed. + reason: + Additional reasoning for why the connection has closed. + """ + + def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: + self.code = code + self.reason = reason or "" + + class ASGIWebSocketTransportError(Exception): pass From 6f23672f6bc99ded968f4d363d4e1142f399a528 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 28 Oct 2024 15:45:39 -0700 Subject: [PATCH 17/17] remove other vendored code --- tests/unit/conftest.py | 2 +- tests/unit/{vendor/httpx_ws => }/transport.py | 8 + tests/unit/vendor/README.md | 3 - tests/unit/vendor/__init__.py | 0 tests/unit/vendor/httpx_ws/README.md | 3 - tests/unit/vendor/httpx_ws/__init__.py | 29 - tests/unit/vendor/httpx_ws/_api.py | 1297 ----------------- tests/unit/vendor/httpx_ws/_exceptions.py | 55 - tests/unit/vendor/httpx_ws/_ping.py | 40 - 9 files changed, 9 insertions(+), 1428 deletions(-) rename tests/unit/{vendor/httpx_ws => }/transport.py (96%) delete mode 100644 tests/unit/vendor/README.md delete mode 100644 tests/unit/vendor/__init__.py delete mode 100644 tests/unit/vendor/httpx_ws/README.md delete mode 100644 tests/unit/vendor/httpx_ws/__init__.py delete mode 100644 tests/unit/vendor/httpx_ws/_api.py delete mode 100644 tests/unit/vendor/httpx_ws/_exceptions.py delete mode 100644 tests/unit/vendor/httpx_ws/_ping.py diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 931d72fa34..2f8881d647 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -52,7 +52,7 @@ from phoenix.server.types import BatchedCaller, DbSessionFactory from phoenix.session.client import Client from phoenix.trace.schemas import Span -from tests.unit.vendor.httpx_ws.transport import ASGIWebSocketTransport +from tests.unit.transport import ASGIWebSocketTransport def pytest_terminal_summary( diff --git a/tests/unit/vendor/httpx_ws/transport.py b/tests/unit/transport.py similarity index 96% rename from tests/unit/vendor/httpx_ws/transport.py rename to tests/unit/transport.py index abc97aa1f0..b93a998928 100644 --- a/tests/unit/vendor/httpx_ws/transport.py +++ b/tests/unit/transport.py @@ -1,3 +1,11 @@ +""" +This file contains a copy of [httpx-ws](https://github.com/frankie567/httpx-ws), +which is published under an [MIT +license](https://github.com/frankie567/httpx-ws/blob/main/LICENSE). +Modifications have been made to better support the concurrency paradigm used in +our unit test suite. +""" + import asyncio import contextlib import typing diff --git a/tests/unit/vendor/README.md b/tests/unit/vendor/README.md deleted file mode 100644 index 10eea305f4..0000000000 --- a/tests/unit/vendor/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# Unit Test Vendored Dependencies - -This directory contains vendored dependencies used for unit testing. diff --git a/tests/unit/vendor/__init__.py b/tests/unit/vendor/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/tests/unit/vendor/httpx_ws/README.md b/tests/unit/vendor/httpx_ws/README.md deleted file mode 100644 index 2be5cc0363..0000000000 --- a/tests/unit/vendor/httpx_ws/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# HTTPX-WS - -This directory contains a copy of [httpx-ws](https://github.com/frankie567/httpx-ws), which is published under an [MIT license](https://github.com/frankie567/httpx-ws/blob/main/LICENSE). Modifications have been made to better support the concurrency paradigm used in our unit test suite. diff --git a/tests/unit/vendor/httpx_ws/__init__.py b/tests/unit/vendor/httpx_ws/__init__.py deleted file mode 100644 index 2ae6b1b843..0000000000 --- a/tests/unit/vendor/httpx_ws/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -__version__ = "0.6.2" - -from ._api import ( - AsyncWebSocketSession, - JSONMode, - WebSocketSession, - aconnect_ws, - connect_ws, -) -from ._exceptions import ( - HTTPXWSException, - WebSocketDisconnect, - WebSocketInvalidTypeReceived, - WebSocketNetworkError, - WebSocketUpgradeError, -) - -__all__ = [ - "AsyncWebSocketSession", - "HTTPXWSException", - "JSONMode", - "WebSocketDisconnect", - "WebSocketInvalidTypeReceived", - "WebSocketNetworkError", - "WebSocketSession", - "WebSocketUpgradeError", - "aconnect_ws", - "connect_ws", -] diff --git a/tests/unit/vendor/httpx_ws/_api.py b/tests/unit/vendor/httpx_ws/_api.py deleted file mode 100644 index cd4385ee21..0000000000 --- a/tests/unit/vendor/httpx_ws/_api.py +++ /dev/null @@ -1,1297 +0,0 @@ -import base64 -import concurrent.futures -import contextlib -import json -import queue -import secrets -import threading -import typing -from types import TracebackType - -import anyio -import httpcore -import httpx -import wsproto -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from httpcore import AsyncNetworkStream, NetworkStream -from wsproto.frame_protocol import CloseReason - -from ._exceptions import ( - HTTPXWSException, - WebSocketDisconnect, - WebSocketInvalidTypeReceived, - WebSocketNetworkError, - WebSocketUpgradeError, -) -from ._ping import AsyncPingManager, PingManager -from .transport import ASGIWebSocketAsyncNetworkStream - -JSONMode = typing.Literal["text", "binary"] -TaskFunction = typing.TypeVar("TaskFunction") -TaskResult = typing.TypeVar("TaskResult") - -DEFAULT_MAX_MESSAGE_SIZE_BYTES = 65_536 -DEFAULT_QUEUE_SIZE = 512 -DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS = 20.0 -DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = 20.0 - - -class ShouldClose(Exception): - pass - - -class WebSocketSession: - """ - Sync context manager representing an opened WebSocket session. - - Attributes: - subprotocol (typing.Optional[str]): - Optional protocol that has been accepted by the server. - response (typing.Optional[httpx.Response]): - The webSocket handshake response. - """ - - subprotocol: typing.Optional[str] - response: typing.Optional[httpx.Response] - - def __init__( - self, - stream: NetworkStream, - *, - max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, - queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - response: typing.Optional[httpx.Response] = None, - ) -> None: - self.stream = stream - self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) - self.response = response - if self.response is not None: - self.subprotocol = self.response.headers.get("sec-websocket-protocol") - else: - self.subprotocol = None - - self._events: queue.Queue[typing.Union[wsproto.events.Event, HTTPXWSException]] = ( - queue.Queue(queue_size) - ) - - self._ping_manager = PingManager() - self._should_close = threading.Event() - self._should_close_task: typing.Optional[concurrent.futures.Future[bool]] = None - self._executor: typing.Optional[concurrent.futures.ThreadPoolExecutor] = None - - self._max_message_size_bytes = max_message_size_bytes - self._queue_size = queue_size - self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds - self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds - - def _get_executor_should_close_task( - self, - ) -> tuple[concurrent.futures.ThreadPoolExecutor, "concurrent.futures.Future[bool]"]: - if self._should_close_task is None: - self._executor = concurrent.futures.ThreadPoolExecutor() - self._should_close_task = self._executor.submit(self._should_close.wait) - assert self._executor is not None - return self._executor, self._should_close_task - - def __enter__(self) -> "WebSocketSession": - self._background_receive_task = threading.Thread( - target=self._background_receive, args=(self._max_message_size_bytes,) - ) - self._background_receive_task.start() - - self._background_keepalive_ping_task: typing.Optional[threading.Thread] = None - if self._keepalive_ping_interval_seconds is not None: - self._background_keepalive_ping_task = threading.Thread( - target=self._background_keepalive_ping, - args=( - self._keepalive_ping_interval_seconds, - self._keepalive_ping_timeout_seconds, - ), - ) - self._background_keepalive_ping_task.start() - - return self - - def __exit__( - self, - exc_type: typing.Optional[type[BaseException]], - exc: typing.Optional[BaseException], - tb: typing.Optional[TracebackType], - ) -> None: - self.close() - self._background_receive_task.join() - if self._background_keepalive_ping_task is not None: - self._background_keepalive_ping_task.join() - - def ping(self, payload: bytes = b"") -> threading.Event: - """ - Send a Ping message. - - Args: - payload: - Payload to attach to the Ping event. - Internally, it's used to track this specific event. - If left empty, a random one will be generated. - - Returns: - An event that can be used to wait for the corresponding Pong response. - - Examples: - Send a Ping and wait for the Pong - - pong_callback = ws.ping() - # Will block until the corresponding Pong is received. - pong_callback.wait() - """ - ping_id, callback = self._ping_manager.create(payload) - event = wsproto.events.Ping(ping_id) - self.send(event) - return callback - - def send(self, event: wsproto.events.Event) -> None: - """ - Send an Event message. - - Mainly useful to send events that are not supported by the library. - Most of the time, [ping()][httpx_ws.WebSocketSession.ping], - [send_text()][httpx_ws.WebSocketSession.send_text], - [send_bytes()][httpx_ws.WebSocketSession.send_bytes] - and [send_json()][httpx_ws.WebSocketSession.send_json] are preferred. - - Args: - event: The event to send. - - Raises: - WebSocketNetworkError: A network error occured. - - Examples: - Send an event. - - event = wsproto.events.Message(b"Hello!") - ws.send(event) - """ - try: - data = self.connection.send(event) - self.stream.write(data) - except httpcore.WriteError as e: - self.close(CloseReason.INTERNAL_ERROR, "Stream write error") - raise WebSocketNetworkError() from e - - def send_text(self, data: str) -> None: - """ - Send a text message. - - Args: - data: The text to send. - - Raises: - WebSocketNetworkError: A network error occured. - - Examples: - Send a text message. - - ws.send_text("Hello!") - """ - event = wsproto.events.TextMessage(data=data) - self.send(event) - - def send_bytes(self, data: bytes) -> None: - """ - Send a bytes message. - - Args: - data: The data to send. - - Raises: - WebSocketNetworkError: A network error occured. - - Examples: - Send a bytes message. - - ws.send_bytes(b"Hello!") - """ - event = wsproto.events.BytesMessage(data=data) - self.send(event) - - def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: - """ - Send JSON data. - - Args: - data: - The data to send. Must be serializable by [json.dumps][json.dumps]. - mode: - The sending mode. Should either be `'text'` or `'bytes'`. - - Raises: - WebSocketNetworkError: A network error occured. - - Examples: - Send JSON data. - - data = {"message": "Hello!"} - ws.send_json(data) - """ - assert mode in ["text", "binary"] - serialized_data = json.dumps(data) - if mode == "text": - self.send_text(serialized_data) - else: - self.send_bytes(serialized_data.encode("utf-8")) - - def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Event: - """ - Receive an event from the server. - - Mainly useful to receive raw [wsproto.events.Event][wsproto.events.Event]. - Most of the time, [receive_text()][httpx_ws.WebSocketSession.receive_text], - [receive_bytes()][httpx_ws.WebSocketSession.receive_bytes], - and [receive_json()][httpx_ws.WebSocketSession.receive_json] are preferred. - - Args: - timeout: - Number of seconds to wait for an event. - If `None`, will block until an event is available. - - Returns: - A raw [wsproto.events.Event][wsproto.events.Event]. - - Raises: - queue.Empty: No event was received before the timeout delay. - WebSocketDisconnect: The server closed the websocket. - WebSocketNetworkError: A network error occured. - - Examples: - Wait for an event until one is available. - - try: - event = ws.receive() - except WebSocketDisconnect: - print("Connection closed") - - Wait for an event for 2 seconds. - - try: - event = ws.receive(timeout=2.) - except queue.Empty: - print("No event received.") - except WebSocketDisconnect: - print("Connection closed") - """ - event = self._events.get(block=True, timeout=timeout) - if isinstance(event, HTTPXWSException): - raise event - if isinstance(event, wsproto.events.CloseConnection): - raise WebSocketDisconnect(event.code, event.reason) - return event - - def receive_text(self, timeout: typing.Optional[float] = None) -> str: - """ - Receive text from the server. - - Args: - timeout: - Number of seconds to wait for an event. - If `None`, will block until an event is available. - - Returns: - Text data. - - Raises: - queue.Empty: No event was received before the timeout delay. - WebSocketDisconnect: The server closed the websocket. - WebSocketNetworkError: A network error occured. - WebSocketInvalidTypeReceived: The received event was not a text message. - - Examples: - Wait for text until available. - - try: - text = ws.receive_text() - except WebSocketDisconnect: - print("Connection closed") - - Wait for text for 2 seconds. - - try: - event = ws.receive_text(timeout=2.) - except queue.Empty: - print("No text received.") - except WebSocketDisconnect: - print("Connection closed") - """ - event = self.receive(timeout) - if isinstance(event, wsproto.events.TextMessage): - return event.data - raise WebSocketInvalidTypeReceived(event) - - def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: - """ - Receive bytes from the server. - - Args: - timeout: - Number of seconds to wait for an event. - If `None`, will block until an event is available. - - Returns: - Bytes data. - - Raises: - queue.Empty: No event was received before the timeout delay. - WebSocketDisconnect: The server closed the websocket. - WebSocketNetworkError: A network error occured. - WebSocketInvalidTypeReceived: The received event was not a bytes message. - - Examples: - Wait for bytes until available. - - try: - data = ws.receive_bytes() - except WebSocketDisconnect: - print("Connection closed") - - Wait for bytes for 2 seconds. - - try: - data = ws.receive_bytes(timeout=2.) - except queue.Empty: - print("No data received.") - except WebSocketDisconnect: - print("Connection closed") - """ - event = self.receive(timeout) - if isinstance(event, wsproto.events.BytesMessage): - return event.data - raise WebSocketInvalidTypeReceived(event) - - def receive_json( - self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" - ) -> typing.Any: - """ - Receive JSON data from the server. - - The received data should be parseable by [json.loads][json.loads]. - - Args: - timeout: - Number of seconds to wait for an event. - If `None`, will block until an event is available. - mode: - Receive mode. Should either be `'text'` or `'bytes'`. - - Returns: - Parsed JSON data. - - Raises: - queue.Empty: No event was received before the timeout delay. - WebSocketDisconnect: The server closed the websocket. - WebSocketNetworkError: A network error occured. - WebSocketInvalidTypeReceived: The received event - didn't correspond to the specified mode. - - Examples: - Wait for data until available. - - try: - data = ws.receive_json() - except WebSocketDisconnect: - print("Connection closed") - - Wait for data for 2 seconds. - - try: - data = ws.receive_json(timeout=2.) - except queue.Empty: - print("No data received.") - except WebSocketDisconnect: - print("Connection closed") - """ - assert mode in ["text", "binary"] - data: typing.Union[str, bytes] - if mode == "text": - data = self.receive_text(timeout) - elif mode == "binary": - data = self.receive_bytes(timeout) - return json.loads(data) - - def close(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: - """ - Close the WebSocket session. - - Internally, it'll send the - [CloseConnection][wsproto.events.CloseConnection] event. - - *This method is automatically called when exiting the context manager.* - - Args: - code: - The integer close code to indicate why the connection has closed. - reason: - Additional reasoning for why the connection has closed. - - Examples: - Close the WebSocket session. - - ws.close() - """ - self._should_close.set() - if self._executor is not None: - self._executor.shutdown(False) - if self.connection.state not in { - wsproto.connection.ConnectionState.LOCAL_CLOSING, - wsproto.connection.ConnectionState.CLOSED, - }: - event = wsproto.events.CloseConnection(code, reason) - data = self.connection.send(event) - try: - self.stream.write(data) - except httpcore.WriteError: - pass - self.stream.close() - - def _background_receive(self, max_bytes: int) -> None: - """ - Background thread listening for data from the server. - - Internally, it'll: - - * Answer to Ping events. - * Acknowledge Pong events. - * Put other events in the [_events][_events] - queue that'll eventually be consumed by the user. - - Args: - max_bytes: The maximum chunk size to read at each iteration. - """ - partial_message_buffer: typing.Union[str, bytes, None] = None - try: - while not self._should_close.is_set(): - data = self._wait_until_closed(self.stream.read, max_bytes) - self.connection.receive_data(data) - for event in self.connection.events(): - if isinstance(event, wsproto.events.Ping): - data = self.connection.send(event.response()) - self.stream.write(data) - continue - if isinstance(event, wsproto.events.Pong): - self._ping_manager.ack(event.payload) - continue - if isinstance(event, wsproto.events.CloseConnection): - self._should_close.set() - if isinstance(event, wsproto.events.Message): - # Unfinished message: bufferize - if not event.message_finished: - if partial_message_buffer is None: - partial_message_buffer = event.data - else: - partial_message_buffer += event.data - # Finished message but no buffer: just emit the event - elif partial_message_buffer is None: - self._events.put(event) - # Finished message with buffer: emit the full event - else: - event_type = type(event) - full_message_event = event_type(partial_message_buffer + event.data) - partial_message_buffer = None - self._events.put(full_message_event) - continue - self._events.put(event) - except (httpcore.ReadError, httpcore.WriteError): - self.close(CloseReason.INTERNAL_ERROR, "Stream error") - self._events.put(WebSocketNetworkError()) - except ShouldClose: - pass - - def _background_keepalive_ping( - self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None - ) -> None: - try: - while not self._should_close.is_set(): - should_close = self._wait_until_closed(self._should_close.wait, interval_seconds) - if should_close: - raise ShouldClose() - pong_callback = self.ping() - if timeout_seconds is not None: - acknowledged = self._wait_until_closed(pong_callback.wait, timeout_seconds) - if not acknowledged: - self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") - self._events.put(WebSocketNetworkError()) - except ShouldClose: - pass - - def _wait_until_closed( - self, callable: typing.Callable[..., TaskResult], *args: typing.Any, **kwargs: typing.Any - ) -> TaskResult: - try: - executor, should_close_task = self._get_executor_should_close_task() - todo_task = executor.submit(callable, *args, **kwargs) - except RuntimeError as e: - raise ShouldClose() from e - else: - done, _ = concurrent.futures.wait( - (todo_task, should_close_task), # type: ignore[misc] - return_when=concurrent.futures.FIRST_COMPLETED, - ) - if should_close_task in done: - raise ShouldClose() - assert todo_task in done - result = todo_task.result() - return result - - -class AsyncWebSocketSession: - """ - Async context manager representing an opened WebSocket session. - - Attributes: - subprotocol (typing.Optional[str]): - Optional protocol that has been accepted by the server. - response (typing.Optional[httpx.Response]): - The webSocket handshake response. - """ - - subprotocol: typing.Optional[str] - response: typing.Optional[httpx.Response] - _send_event: MemoryObjectSendStream[typing.Union[wsproto.events.Event, HTTPXWSException]] - _receive_event: MemoryObjectReceiveStream[typing.Union[wsproto.events.Event, HTTPXWSException]] - - def __init__( - self, - stream: AsyncNetworkStream, - *, - max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, - queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - response: typing.Optional[httpx.Response] = None, - ) -> None: - self.stream = stream - self.connection = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) - self.response = response - if self.response is not None: - self.subprotocol = self.response.headers.get("sec-websocket-protocol") - else: - self.subprotocol = None - - self._ping_manager = AsyncPingManager() - self._should_close = anyio.Event() - - self._max_message_size_bytes = max_message_size_bytes - self._queue_size = queue_size - - # Always disable keepalive ping when emulating ASGI - if isinstance(stream, ASGIWebSocketAsyncNetworkStream): - self._keepalive_ping_interval_seconds = None - self._keepalive_ping_timeout_seconds = None - else: - self._keepalive_ping_interval_seconds = keepalive_ping_interval_seconds - self._keepalive_ping_timeout_seconds = keepalive_ping_timeout_seconds - - async def __aenter__(self) -> "AsyncWebSocketSession": - async with contextlib.AsyncExitStack() as exit_stack: - self._send_event, self._receive_event = anyio.create_memory_object_stream[ - typing.Union[wsproto.events.Event, HTTPXWSException] - ]() - exit_stack.enter_context(self._send_event) - exit_stack.enter_context(self._receive_event) - - self._background_task_group = anyio.create_task_group() - await exit_stack.enter_async_context(self._background_task_group) - - self._background_task_group.start_soon( - self._background_receive, self._max_message_size_bytes - ) - if self._keepalive_ping_interval_seconds is not None: - self._background_task_group.start_soon( - self._background_keepalive_ping, - self._keepalive_ping_interval_seconds, - self._keepalive_ping_timeout_seconds, - ) - - exit_stack.callback(self._background_task_group.cancel_scope.cancel) - exit_stack.push_async_callback(self.close) - self._exit_stack = exit_stack.pop_all() - - return self - - async def __aexit__( - self, - exc_type: typing.Optional[type[BaseException]], - exc: typing.Optional[BaseException], - tb: typing.Optional[TracebackType], - ) -> None: - await self._exit_stack.aclose() - - async def ping(self, payload: bytes = b"") -> anyio.Event: - """ - Send a Ping message. - - Args: - payload: - Payload to attach to the Ping event. - Internally, it's used to track this specific event. - If left empty, a random one will be generated. - - Returns: - An event that can be used to wait for the corresponding Pong response. - - Examples: - Send a Ping and wait for the Pong - - pong_callback = await ws.ping() - # Will block until the corresponding Pong is received. - await pong_callback.wait() - """ - ping_id, callback = self._ping_manager.create(payload) - event = wsproto.events.Ping(ping_id) - await self.send(event) - return callback - - async def send(self, event: wsproto.events.Event) -> None: - """ - Send an Event message. - - Mainly useful to send events that are not supported by the library. - Most of the time, [ping()][httpx_ws.AsyncWebSocketSession.ping], - [send_text()][httpx_ws.AsyncWebSocketSession.send_text], - [send_bytes()][httpx_ws.AsyncWebSocketSession.send_bytes] - and [send_json()][httpx_ws.AsyncWebSocketSession.send_json] are preferred. - - Args: - event: The event to send. - - Raises: - WebSocketNetworkError: A network error occured. - - Examples: - Send an event. - - event = await wsproto.events.Message(b"Hello!") - ws.send(event) - """ - try: - data = self.connection.send(event) - await self.stream.write(data) - except httpcore.WriteError as e: - await self.close(CloseReason.INTERNAL_ERROR, "Stream write error") - raise WebSocketNetworkError() from e - - async def send_text(self, data: str) -> None: - """ - Send a text message. - - Args: - data: The text to send. - - Raises: - WebSocketNetworkError: A network error occured. - - Examples: - Send a text message. - - await ws.send_text("Hello!") - """ - event = wsproto.events.TextMessage(data=data) - await self.send(event) - - async def send_bytes(self, data: bytes) -> None: - """ - Send a bytes message. - - Args: - data: The data to send. - - Raises: - WebSocketNetworkError: A network error occured. - - Examples: - Send a bytes message. - - await ws.send_bytes(b"Hello!") - """ - event = wsproto.events.BytesMessage(data=data) - await self.send(event) - - async def send_json(self, data: typing.Any, mode: JSONMode = "text") -> None: - """ - Send JSON data. - - Args: - data: - The data to send. Must be serializable by [json.dumps][json.dumps]. - mode: - The sending mode. Should either be `'text'` or `'bytes'`. - - Raises: - WebSocketNetworkError: A network error occured. - - Examples: - Send JSON data. - - data = {"message": "Hello!"} - await ws.send_json(data) - """ - assert mode in ["text", "binary"] - serialized_data = json.dumps(data) - if mode == "text": - await self.send_text(serialized_data) - else: - await self.send_bytes(serialized_data.encode("utf-8")) - - async def receive(self, timeout: typing.Optional[float] = None) -> wsproto.events.Event: - """ - Receive an event from the server. - - Mainly useful to receive raw [wsproto.events.Event][wsproto.events.Event]. - Most of the time, [receive_text()][httpx_ws.AsyncWebSocketSession.receive_text], - [receive_bytes()][httpx_ws.AsyncWebSocketSession.receive_bytes], - and [receive_json()][httpx_ws.AsyncWebSocketSession.receive_json] are preferred. - - Args: - timeout: - Number of seconds to wait for an event. - If `None`, will block until an event is available. - - Returns: - A raw [wsproto.events.Event][wsproto.events.Event]. - - Raises: - TimeoutError: No event was received before the timeout delay. - WebSocketDisconnect: The server closed the websocket. - WebSocketNetworkError: A network error occured. - - Examples: - Wait for an event until one is available. - - try: - event = await ws.receive() - except WebSocketDisconnect: - print("Connection closed") - - Wait for an event for 2 seconds. - - try: - event = await ws.receive(timeout=2.) - except TimeoutError: - print("No event received.") - except WebSocketDisconnect: - print("Connection closed") - """ - with anyio.fail_after(timeout): - event = await self._receive_event.receive() - if isinstance(event, HTTPXWSException): - raise event - if isinstance(event, wsproto.events.CloseConnection): - raise WebSocketDisconnect(event.code, event.reason) - return event - - async def receive_text(self, timeout: typing.Optional[float] = None) -> str: - """ - Receive text from the server. - - Args: - timeout: - Number of seconds to wait for an event. - If `None`, will block until an event is available. - - Returns: - Text data. - - Raises: - TimeoutError: No event was received before the timeout delay. - WebSocketDisconnect: The server closed the websocket. - WebSocketNetworkError: A network error occured. - WebSocketInvalidTypeReceived: The received event was not a text message. - - Examples: - Wait for text until available. - - try: - text = await ws.receive_text() - except WebSocketDisconnect: - print("Connection closed") - - Wait for text for 2 seconds. - - try: - event = await ws.receive_text(timeout=2.) - except TimeoutError: - print("No text received.") - except WebSocketDisconnect: - print("Connection closed") - """ - event = await self.receive(timeout) - if isinstance(event, wsproto.events.TextMessage): - return event.data - raise WebSocketInvalidTypeReceived(event) - - async def receive_bytes(self, timeout: typing.Optional[float] = None) -> bytes: - """ - Receive bytes from the server. - - Args: - timeout: - Number of seconds to wait for an event. - If `None`, will block until an event is available. - - Returns: - Bytes data. - - Raises: - TimeoutError: No event was received before the timeout delay. - WebSocketDisconnect: The server closed the websocket. - WebSocketNetworkError: A network error occured. - WebSocketInvalidTypeReceived: The received event was not a bytes message. - - Examples: - Wait for bytes until available. - - try: - data = await ws.receive_bytes() - except WebSocketDisconnect: - print("Connection closed") - - Wait for bytes for 2 seconds. - - try: - data = await ws.receive_bytes(timeout=2.) - except TimeoutError: - print("No data received.") - except WebSocketDisconnect: - print("Connection closed") - """ - event = await self.receive(timeout) - if isinstance(event, wsproto.events.BytesMessage): - return event.data - raise WebSocketInvalidTypeReceived(event) - - async def receive_json( - self, timeout: typing.Optional[float] = None, mode: JSONMode = "text" - ) -> typing.Any: - """ - Receive JSON data from the server. - - The received data should be parseable by [json.loads][json.loads]. - - Args: - timeout: - Number of seconds to wait for an event. - If `None`, will block until an event is available. - mode: - Receive mode. Should either be `'text'` or `'bytes'`. - - Returns: - Parsed JSON data. - - Raises: - TimeoutError: No event was received before the timeout delay. - WebSocketDisconnect: The server closed the websocket. - WebSocketNetworkError: A network error occured. - WebSocketInvalidTypeReceived: The received event - didn't correspond to the specified mode. - - Examples: - Wait for data until available. - - try: - data = await ws.receive_json() - except WebSocketDisconnect: - print("Connection closed") - - Wait for data for 2 seconds. - - try: - data = await ws.receive_json(timeout=2.) - except TimeoutError: - print("No data received.") - except WebSocketDisconnect: - print("Connection closed") - """ - assert mode in ["text", "binary"] - data: typing.Union[str, bytes] - if mode == "text": - data = await self.receive_text(timeout) - elif mode == "binary": - data = await self.receive_bytes(timeout) - return json.loads(data) - - async def close(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: - """ - Close the WebSocket session. - - Internally, it'll send the - [CloseConnection][wsproto.events.CloseConnection] event. - - *This method is automatically called when exiting the context manager.* - - Args: - code: - The integer close code to indicate why the connection has closed. - reason: - Additional reasoning for why the connection has closed. - - Examples: - Close the WebSocket session. - - await ws.close() - """ - self._should_close.set() - if self.connection.state not in { - wsproto.connection.ConnectionState.LOCAL_CLOSING, - wsproto.connection.ConnectionState.CLOSED, - }: - event = wsproto.events.CloseConnection(code, reason) - data = self.connection.send(event) - try: - await self.stream.write(data) - except httpcore.WriteError: - pass - await self.stream.aclose() - - async def _background_receive(self, max_bytes: int) -> None: - """ - Background task listening for data from the server. - - Internally, it'll: - - * Answer to Ping events. - * Acknowledge Pong events. - * Put other events in the [_events][_events] - queue that'll eventually be consumed by the user. - - Args: - max_bytes: The maximum chunk size to read at each iteration. - """ - partial_message_buffer: typing.Union[str, bytes, None] = None - try: - while not self._should_close.is_set(): - data = await self.stream.read(max_bytes=max_bytes) - self.connection.receive_data(data) - for event in self.connection.events(): - if isinstance(event, wsproto.events.Ping): - data = self.connection.send(event.response()) - await self.stream.write(data) - continue - if isinstance(event, wsproto.events.Pong): - self._ping_manager.ack(event.payload) - continue - if isinstance(event, wsproto.events.CloseConnection): - self._should_close.set() - if isinstance(event, wsproto.events.Message): - # Unfinished message: bufferize - if not event.message_finished: - if partial_message_buffer is None: - partial_message_buffer = event.data - else: - partial_message_buffer += event.data - # Finished message but no buffer: just emit the event - elif partial_message_buffer is None: - await self._send_event.send(event) - # Finished message with buffer: emit the full event - else: - event_type = type(event) - full_message_event = event_type(partial_message_buffer + event.data) - partial_message_buffer = None - await self._send_event.send(full_message_event) - continue - await self._send_event.send(event) - except (httpcore.ReadError, httpcore.WriteError): - await self.close(CloseReason.INTERNAL_ERROR, "Stream error") - await self._send_event.send(WebSocketNetworkError()) - - async def _background_keepalive_ping( - self, interval_seconds: float, timeout_seconds: typing.Optional[float] = None - ) -> None: - while not self._should_close.is_set(): - await anyio.sleep(interval_seconds) - pong_callback = await self.ping() - if timeout_seconds is not None: - try: - with anyio.fail_after(timeout_seconds): - await pong_callback.wait() - except TimeoutError: - await self.close(CloseReason.INTERNAL_ERROR, "Keepalive ping timeout") - await self._send_event.send(WebSocketNetworkError()) - - -def _get_headers( - subprotocols: typing.Optional[list[str]], -) -> dict[str, typing.Any]: - headers = { - "connection": "upgrade", - "upgrade": "websocket", - "sec-websocket-key": base64.b64encode(secrets.token_bytes(16)).decode("utf-8"), - "sec-websocket-version": "13", - } - if subprotocols is not None: - headers["sec-websocket-protocol"] = ", ".join(subprotocols) - return headers - - -@contextlib.contextmanager -def _connect_ws( - url: str, - client: httpx.Client, - *, - max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, - queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[float] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[list[str]] = None, - **kwargs: typing.Any, -) -> typing.Generator[WebSocketSession, None, None]: - headers = kwargs.pop("headers", {}) - headers.update(_get_headers(subprotocols)) - - with client.stream("GET", url, headers=headers, **kwargs) as response: - if response.status_code != 101: - raise WebSocketUpgradeError(response) - - with WebSocketSession( - response.extensions["network_stream"], - max_message_size_bytes=max_message_size_bytes, - queue_size=queue_size, - keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, - keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - response=response, - ) as session: - yield session - - -@contextlib.contextmanager -def connect_ws( - url: str, - client: typing.Optional[httpx.Client] = None, - *, - max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, - queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[float] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[list[str]] = None, - **kwargs: typing.Any, -) -> typing.Generator[WebSocketSession, None, None]: - """ - Start a sync WebSocket session. - - It returns a context manager that'll automatically - call [close()][httpx_ws.WebSocketSession.close] when exiting. - - Args: - url: The WebSocket URL. - client: - HTTPX client to use. - If not provided, a default one will be initialized. - max_message_size_bytes: - Message size in bytes to receive from the server. - Defaults to 65 KiB. - queue_size: - Size of the queue where the received messages will be held - until they are consumed. - If the queue is full, the client will stop receive messages - from the server until the queue has room available. - Defaults to 512. - keepalive_ping_interval_seconds: - Interval at which the client will automatically send a Ping event - to keep the connection alive. Set it to `None` to disable this mechanism. - Defaults to 20 seconds. - keepalive_ping_timeout_seconds: - Maximum delay the client will wait for an answer to its Ping event. - If the delay is exceeded, - [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] - will be raised and the connection closed. - Defaults to 20 seconds. - subprotocols: - Optional list of suprotocols to negotiate with the server. - **kwargs: - Additional keyword arguments that will be passed to - the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. - - Returns: - A [context manager][contextlib.AbstractContextManager] - for [WebSocketSession][httpx_ws.WebSocketSession]. - - Examples: - Without explicit HTTPX client. - - with connect_ws("http://localhost:8000/ws") as ws: - message = ws.receive_text() - print(message) - ws.send_text("Hello!") - - With explicit HTTPX client. - - with httpx.Client() as client: - with connect_ws("http://localhost:8000/ws", client) as ws: - message = ws.receive_text() - print(message) - ws.send_text("Hello!") - """ - if client is None: - with httpx.Client() as client: - with _connect_ws( - url, - client=client, - max_message_size_bytes=max_message_size_bytes, - queue_size=queue_size, - keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, - keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - subprotocols=subprotocols, - **kwargs, - ) as websocket: - yield websocket - else: - with _connect_ws( - url, - client=client, - max_message_size_bytes=max_message_size_bytes, - queue_size=queue_size, - keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, - keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - subprotocols=subprotocols, - **kwargs, - ) as websocket: - yield websocket - - -@contextlib.asynccontextmanager -async def _aconnect_ws( - url: str, - client: httpx.AsyncClient, - *, - max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, - queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[float] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[list[str]] = None, - **kwargs: typing.Any, -) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: - headers = kwargs.pop("headers", {}) - headers.update(_get_headers(subprotocols)) - - async with client.stream("GET", url, headers=headers, **kwargs) as response: - if response.status_code != 101: - raise WebSocketUpgradeError(response) - - async with AsyncWebSocketSession( - response.extensions["network_stream"], - max_message_size_bytes=max_message_size_bytes, - queue_size=queue_size, - keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, - keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - response=response, - ) as session: - yield session - - -@contextlib.asynccontextmanager -async def aconnect_ws( - url: str, - client: typing.Optional[httpx.AsyncClient] = None, - *, - max_message_size_bytes: int = DEFAULT_MAX_MESSAGE_SIZE_BYTES, - queue_size: int = DEFAULT_QUEUE_SIZE, - keepalive_ping_interval_seconds: typing.Optional[ - float - ] = DEFAULT_KEEPALIVE_PING_INTERVAL_SECONDS, - keepalive_ping_timeout_seconds: typing.Optional[float] = DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS, - subprotocols: typing.Optional[list[str]] = None, - **kwargs: typing.Any, -) -> typing.AsyncGenerator[AsyncWebSocketSession, None]: - """ - Start an async WebSocket session. - - It returns an async context manager that'll automatically - call [close()][httpx_ws.AsyncWebSocketSession.close] when exiting. - - Args: - url: The WebSocket URL. - client: - HTTPX client to use. - If not provided, a default one will be initialized. - max_message_size_bytes: - Message size in bytes to receive from the server. - Defaults to 65 KiB. - queue_size: - Size of the queue where the received messages will be held - until they are consumed. - If the queue is full, the client will stop receive messages - from the server until the queue has room available. - Defaults to 512. - keepalive_ping_interval_seconds: - Interval at which the client will automatically send a Ping event - to keep the connection alive. Set it to `None` to disable this mechanism. - Defaults to 20 seconds. - keepalive_ping_timeout_seconds: - Maximum delay the client will wait for an answer to its Ping event. - If the delay is exceeded, - [WebSocketNetworkError][httpx_ws.WebSocketNetworkError] - will be raised and the connection closed. - Defaults to 20 seconds. - subprotocols: - Optional list of suprotocols to negotiate with the server. - **kwargs: - Additional keyword arguments that will be passed to - the [HTTPX stream()](https://www.python-httpx.org/api/#request) method. - - Returns: - An [async context manager][contextlib.AbstractAsyncContextManager] - for [AsyncWebSocketSession][httpx_ws.AsyncWebSocketSession]. - - Examples: - Without explicit HTTPX client. - - async with aconnect_ws("http://localhost:8000/ws") as ws: - message = await ws.receive_text() - print(message) - await ws.send_text("Hello!") - - With explicit HTTPX client. - - async with httpx.AsyncClient() as client: - async with aconnect_ws("http://localhost:8000/ws", client) as ws: - message = await ws.receive_text() - print(message) - await ws.send_text("Hello!") - """ - if client is None: - async with httpx.AsyncClient() as client: - async with _aconnect_ws( - url, - client=client, - max_message_size_bytes=max_message_size_bytes, - queue_size=queue_size, - keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, - keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - subprotocols=subprotocols, - **kwargs, - ) as websocket: - yield websocket - else: - async with _aconnect_ws( - url, - client=client, - max_message_size_bytes=max_message_size_bytes, - queue_size=queue_size, - keepalive_ping_interval_seconds=keepalive_ping_interval_seconds, - keepalive_ping_timeout_seconds=keepalive_ping_timeout_seconds, - subprotocols=subprotocols, - **kwargs, - ) as websocket: - yield websocket diff --git a/tests/unit/vendor/httpx_ws/_exceptions.py b/tests/unit/vendor/httpx_ws/_exceptions.py deleted file mode 100644 index 0facbf82aa..0000000000 --- a/tests/unit/vendor/httpx_ws/_exceptions.py +++ /dev/null @@ -1,55 +0,0 @@ -import typing - -import httpx -import wsproto - - -class HTTPXWSException(Exception): - """ - Base exception class for HTTPX WS. - """ - - pass - - -class WebSocketUpgradeError(HTTPXWSException): - """ - Raised when the initial connection didn't correctly upgrade to a WebSocket session. - """ - - def __init__(self, response: httpx.Response) -> None: - self.response = response - - -class WebSocketDisconnect(HTTPXWSException): - """ - Raised when the server closed the WebSocket session. - - Args: - code: - The integer close code to indicate why the connection has closed. - reason: - Additional reasoning for why the connection has closed. - """ - - def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: - self.code = code - self.reason = reason or "" - - -class WebSocketInvalidTypeReceived(HTTPXWSException): - """ - Raised when a event is not of the expected type. - """ - - def __init__(self, event: wsproto.events.Event) -> None: - self.event = event - - -class WebSocketNetworkError(HTTPXWSException): - """ - Raised when a network error occured, - typically if the underlying stream has closed or timeout. - """ - - pass diff --git a/tests/unit/vendor/httpx_ws/_ping.py b/tests/unit/vendor/httpx_ws/_ping.py deleted file mode 100644 index 434e78e03c..0000000000 --- a/tests/unit/vendor/httpx_ws/_ping.py +++ /dev/null @@ -1,40 +0,0 @@ -import secrets -import threading -import typing - -import anyio - - -class PingManagerBase: - def _generate_id(self) -> bytes: - return secrets.token_bytes() - - -class PingManager(PingManagerBase): - def __init__(self) -> None: - self._pings: dict[bytes, threading.Event] = {} - - def create(self, ping_id: typing.Optional[bytes] = None) -> tuple[bytes, threading.Event]: - ping_id = self._generate_id() if not ping_id else ping_id - event = threading.Event() - self._pings[ping_id] = event - return ping_id, event - - def ack(self, ping_id: typing.Union[bytes, bytearray]) -> None: - event = self._pings.pop(bytes(ping_id)) - event.set() - - -class AsyncPingManager(PingManagerBase): - def __init__(self) -> None: - self._pings: dict[bytes, anyio.Event] = {} - - def create(self, ping_id: typing.Optional[bytes] = None) -> tuple[bytes, anyio.Event]: - ping_id = self._generate_id() if not ping_id else ping_id - event = anyio.Event() - self._pings[ping_id] = event - return ping_id, event - - def ack(self, ping_id: typing.Union[bytes, bytearray]) -> None: - event = self._pings.pop(bytes(ping_id)) - event.set()