Skip to content

Commit

Permalink
pass postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Oct 28, 2024
1 parent 168e6bc commit 709d51b
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 35 deletions.
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def pytest_addoption(parser: Parser) -> None:
parser.addoption(
"--run-postgres",
action="store_true",
action="store_false",
help="Run tests that require Postgres",
)
parser.addoption(
Expand Down
74 changes: 44 additions & 30 deletions tests/unit/httpx_ws/transport.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import asyncio
import contextlib
import queue
import typing
from concurrent.futures import Future

import anyio
import wsproto
from httpcore import AsyncNetworkStream
from httpx import ASGITransport, AsyncByteStream, Request, Response
Expand Down Expand Up @@ -36,19 +34,22 @@ class ASGIWebSocketAsyncNetworkStream(AsyncNetworkStream):
def __init__(self, app: ASGIApp, scope: Scope) -> None:
self.app = app
self.scope = scope
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message] = queue.Queue()
self._receive_queue: asyncio.Queue[Message] = asyncio.Queue()
self._send_queue: asyncio.Queue[Message] = asyncio.Queue()
self.connection = wsproto.WSConnection(wsproto.ConnectionType.SERVER)
self.connection.initiate_upgrade_connection(scope["headers"], scope["path"])
self.tasks: list[asyncio.Task] = []

async def __aenter__(
self,
) -> tuple["ASGIWebSocketAsyncNetworkStream", bytes]:
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(
anyio.from_thread.start_blocking_portal("asyncio")
)
_: Future[None] = self.portal.start_task_soon(self._run)
self.exit_stack = contextlib.AsyncExitStack()
await self.exit_stack.__aenter__()

# Start the _run coroutine as a task
self._run_task = asyncio.create_task(self._run())
self.tasks.append(self._run_task)
self.exit_stack.push_async_callback(self._cancel_tasks)

await self.send({"type": "websocket.connect"})
message = await self.receive()
Expand All @@ -62,32 +63,44 @@ async def __aenter__(

async def __aexit__(self, *args: typing.Any) -> None:
await self.aclose()
await self.exit_stack.aclose()

async def _cancel_tasks(self):
# Cancel all running tasks
for task in self.tasks:
task.cancel()
# Wait for tasks to be cancelled
await asyncio.gather(*self.tasks, return_exceptions=True)

async def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
message: Message = await self.receive(timeout=timeout)
type = message["type"]
message: Message = await self.receive()
message_type = message["type"]

if type not in {"websocket.send", "websocket.close"}:
if message_type not in {"websocket.send", "websocket.close"}:
raise UnhandledASGIMessageType(message)

event: wsproto.events.Event
if type == "websocket.send":
if message_type == "websocket.send":
data_str: typing.Optional[str] = message.get("text")
if data_str is not None:
event = wsproto.events.TextMessage(data_str)
data_bytes: typing.Optional[bytes] = message.get("bytes")
if data_bytes is not None:
event = wsproto.events.BytesMessage(data_bytes)
elif type == "websocket.close":
event = wsproto.events.CloseConnection(message["code"], message["reason"])
else:
data_bytes: typing.Optional[bytes] = message.get("bytes")
if data_bytes is not None:
event = wsproto.events.BytesMessage(data_bytes)
else:
# If neither text nor bytes are provided, raise an error
raise ValueError("websocket.send message missing 'text' or 'bytes'")
elif message_type == "websocket.close":
event = wsproto.events.CloseConnection(message["code"], message.get("reason"))

return self.connection.send(event)

async def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
self.connection.receive_data(buffer)
for event in self.connection.events():
if isinstance(event, wsproto.events.Request):
pass
pass # Already handled in __init__
elif isinstance(event, wsproto.events.CloseConnection):
await self.send(
{
Expand All @@ -105,19 +118,22 @@ async def write(self, buffer: bytes, timeout: typing.Optional[float] = None) ->

async def aclose(self) -> None:
await self.send({"type": "websocket.close"})
self.exit_stack.close()
# Ensure tasks are cancelled and cleaned up
await self._cancel_tasks()

async def send(self, message: Message) -> None:
self._receive_queue.put(message)
await self._receive_queue.put(message)

async def receive(self, timeout: typing.Optional[float] = None) -> Message:
while self._send_queue.empty():
await anyio.sleep(0)
return self._send_queue.get(timeout=timeout)
try:
message = await asyncio.wait_for(self._send_queue.get(), timeout)
return message
except asyncio.TimeoutError:
raise TimeoutError("Timed out waiting for message")

async def _run(self) -> None:
"""
The sub-thread in which the websocket session runs.
The coroutine in which the websocket session runs.
"""
scope = self.scope
receive = self._asgi_receive
Expand All @@ -133,12 +149,10 @@ async def _run(self) -> None:
await self._asgi_send(message)

async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
await anyio.sleep(0)
return self._receive_queue.get()
return await self._receive_queue.get()

async def _asgi_send(self, message: Message) -> None:
self._send_queue.put(message)
await self._send_queue.put(message)

def _build_accept_response(self, message: Message) -> bytes:
subprotocol = message.get("subprotocol", None)
Expand Down
8 changes: 4 additions & 4 deletions tests/unit/server/api/test_subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,9 @@ class TestChatCompletionSubscription:

async def test_openai_text_response_emits_expected_payloads_and_records_expected_span(
self,
dialect: str,
gql_client: Any,
openai_api_key: str,
) -> None:
if dialect == "postgresql":
pytest.skip("fails on postgres for unknown reason")
variables = {
"input": {
"messages": [
Expand Down Expand Up @@ -182,6 +179,10 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected
query=self.QUERY, variables={"spanId": span_id}, operation_name="SpanQuery"
)
span = data["span"]
assert json.loads(attributes := span.pop("attributes")) == json.loads(
subscription_span.pop("attributes")
)
attributes = dict(flatten(json.loads(attributes)))
assert span == subscription_span

# check attributes
Expand All @@ -195,7 +196,6 @@ async def test_openai_text_response_emits_expected_payloads_and_records_expected
assert span.pop("parentId") is None
assert span.pop("spanKind") == "llm"
assert (context := span.pop("context")).pop("spanId")
assert (attributes := dict(flatten(json.loads(span.pop("attributes")))))
assert context.pop("traceId")
assert not context
assert span.pop("metadata") is None
Expand Down

0 comments on commit 709d51b

Please sign in to comment.