Skip to content

Commit

Permalink
Fix handling of explict close (#467)
Browse files Browse the repository at this point in the history
The check for next_event was not correct for explict closes as
https://github.com/python-hyper/h11/blob/a2c68948accadc3876dffcf979d98002e4a4ed27/h11/_connection.py#L445

will only return h11.ConnectionClosed as an object and not a type
  • Loading branch information
bdraco authored Oct 25, 2023
1 parent 3cb893c commit 4c3df6c
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 23 deletions.
9 changes: 7 additions & 2 deletions pyhap/accessory.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,9 @@ def to_HAP(self, include_value: bool = True) -> Dict[str, Any]:
"""
return {
HAP_REPR_AID: self.aid,
HAP_REPR_SERVICES: [s.to_HAP(include_value=include_value) for s in self.services],
HAP_REPR_SERVICES: [
s.to_HAP(include_value=include_value) for s in self.services
],
}

def setup_message(self):
Expand Down Expand Up @@ -391,7 +393,10 @@ def to_HAP(self, include_value: bool = True) -> List[Dict[str, Any]]:
.. seealso:: Accessory.to_HAP
"""
return [acc.to_HAP(include_value=include_value) for acc in (super(), *self.accessories.values())]
return [
acc.to_HAP(include_value=include_value)
for acc in (super(), *self.accessories.values())
]

def get_characteristic(self, aid: int, iid: int) -> Optional["Characteristic"]:
""".. seealso:: Accessory.to_HAP"""
Expand Down
26 changes: 16 additions & 10 deletions pyhap/hap_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def __init__(
connections: Dict[str, "HAPServerProtocol"],
accessory_driver: "AccessoryDriver",
) -> None:
self.loop: asyncio.AbstractEventLoop = loop
self.loop = loop
self.conn = h11.Connection(h11.SERVER)
self.connections = connections
self.accessory_driver = accessory_driver
Expand All @@ -55,7 +55,7 @@ def __init__(
self.transport: Optional[asyncio.Transport] = None

self.request: Optional[h11.Request] = None
self.request_body: Optional[bytes] = None
self.request_body: List[bytes] = []
self.response: Optional[HAPResponse] = None

self.last_activity: Optional[float] = None
Expand Down Expand Up @@ -246,27 +246,33 @@ def _process_one_event(self) -> bool:
logger.debug(
"%s (%s): h11 Event: %s", self.peername, self.handler.client_uuid, event
)
if event in (h11.NEED_DATA, h11.ConnectionClosed):
if event is h11.NEED_DATA:
return False

if event is h11.PAUSED:
self.conn.start_next_cycle()
return True

if isinstance(event, h11.Request):
event_type = type(event)
if event_type is h11.ConnectionClosed:
return False

if event_type is h11.Request:
self.request = event
self.request_body = b""
self.request_body = []
return True

if isinstance(event, h11.Data):
self.request_body += event.data
if event_type is h11.Data:
if TYPE_CHECKING:
assert isinstance(event, h11.Data) # nosec
self.request_body.append(event.data)
return True

if isinstance(event, h11.EndOfMessage):
response = self.handler.dispatch(self.request, bytes(self.request_body))
if event_type is h11.EndOfMessage:
response = self.handler.dispatch(self.request, b"".join(self.request_body))
self._process_response(response)
self.request = None
self.request_body = None
self.request_body = []
return True

return self._handle_invalid_conn_state(f"Unexpected event: {event}")
Expand Down
30 changes: 19 additions & 11 deletions pyhap/hap_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
The HAPServer is the point of contact to and from the world.
"""

import asyncio
import logging
import time
from typing import TYPE_CHECKING, Dict, Optional, Tuple

from .hap_protocol import HAPServerProtocol
from .util import callback

if TYPE_CHECKING:
from .accessory_driver import AccessoryDriver

logger = logging.getLogger(__name__)

IDLE_CONNECTION_CHECK_INTERVAL_SECONDS = 120
Expand All @@ -28,17 +33,18 @@ class HAPServer:
implements exclusive access to the send methods.
"""

def __init__(self, addr_port, accessory_handler):
def __init__(
self, addr_port: Tuple[str, int], accessory_handler: "AccessoryDriver"
) -> None:
"""Create a HAP Server."""
self._addr_port = addr_port
self.connections = {} # (address, port): socket
self.connections: Dict[Tuple[str, int], HAPServerProtocol] = {}
self.accessory_handler = accessory_handler
self.server = None
self._serve_task = None
self._connection_cleanup = None
self.loop = None
self.server: Optional[asyncio.Server] = None
self._connection_cleanup: Optional[asyncio.TimerHandle] = None
self.loop: Optional[asyncio.AbstractEventLoop] = None

async def async_start(self, loop):
async def async_start(self, loop: asyncio.AbstractEventLoop) -> None:
"""Start the http-hap server."""
self.loop = loop
self.server = await loop.create_server(
Expand All @@ -49,7 +55,7 @@ async def async_start(self, loop):
self.async_cleanup_connections()

@callback
def async_cleanup_connections(self):
def async_cleanup_connections(self) -> None:
"""Cleanup stale connections."""
now = time.time()
for hap_proto in list(self.connections.values()):
Expand All @@ -59,7 +65,7 @@ def async_cleanup_connections(self):
)

@callback
def async_stop(self):
def async_stop(self) -> None:
"""Stop the server.
This method must be run in the event loop.
Expand All @@ -70,10 +76,12 @@ def async_stop(self):
self.server.close()
self.connections.clear()

def push_event(self, data, client_addr, immediate=False):
def push_event(
self, data: bytes, client_addr: Tuple[str, int], immediate: bool = False
) -> bool:
"""Queue an event to the current connection with the provided data.
:param data: The charateristic changes
:param data: The characteristic changes
:type data: dict
:param client_addr: A client (address, port) tuple to which to send the data.
Expand Down
60 changes: 60 additions & 0 deletions tests/test_hap_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,30 @@

from pyhap import hap_handler, hap_protocol
from pyhap.accessory import Accessory, Bridge
from pyhap.accessory_driver import AccessoryDriver
from pyhap.hap_handler import HAPResponse


class MockTransport(asyncio.Transport): # pylint: disable=abstract-method
"""A mock transport."""

_is_closing: bool = False

def set_write_buffer_limits(self, high=None, low=None):
"""Set the write buffer limits."""

def write_eof(self) -> None:
"""Write EOF to the stream."""

def close(self) -> None:
"""Close the stream."""
self._is_closing = True

def is_closing(self) -> bool:
"""Return True if the transport is closing or closed."""
return self._is_closing


class MockHAPCrypto:
"""Mock HAPCrypto that only returns plaintext."""

Expand Down Expand Up @@ -734,3 +755,42 @@ async def test_does_not_timeout(driver):
assert writer.call_args_list[0][0][0].startswith(b"HTTP/1.1 200 OK\r\n") is True
hap_proto.check_idle(time.time())
assert hap_proto_close.called is False


def test_explicit_close(driver: AccessoryDriver):
"""Test an explicit connection close."""
loop = MagicMock()

transport = MockTransport()
connections = {}

acc = Accessory(driver, "TestAcc", aid=1)
assert acc.aid == 1
service = acc.driver.loader.get_service("TemperatureSensor")
acc.add_service(service)
driver.add_accessory(acc)

hap_proto = hap_protocol.HAPServerProtocol(loop, connections, driver)
hap_proto.connection_made(transport)

hap_proto.hap_crypto = MockHAPCrypto()
hap_proto.handler.is_encrypted = True
assert hap_proto.transport.is_closing() is False

with patch.object(hap_proto.transport, "write") as writer:
hap_proto.data_received(
b"GET /characteristics?id=3762173001.7 HTTP/1.1\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long
)
hap_proto.data_received(
b"GET /characteristics?id=1.5 HTTP/1.1\r\nConnection: close\r\nHost: HASS\\032Bridge\\032YPHW\\032B223AD._hap._tcp.local\r\n\r\n" # pylint: disable=line-too-long
)

assert b"Content-Length:" in writer.call_args_list[0][0][0]
assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[0][0][0]
assert b"-70402" in writer.call_args_list[0][0][0]

assert b"Content-Length:" in writer.call_args_list[1][0][0]
assert b"Transfer-Encoding: chunked\r\n\r\n" not in writer.call_args_list[1][0][0]
assert b"TestAcc" in writer.call_args_list[1][0][0]

assert hap_proto.transport.is_closing() is True

0 comments on commit 4c3df6c

Please sign in to comment.