diff --git a/src/groundlight/client.py b/src/groundlight/client.py index ff48b939..e0250bcf 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -12,8 +12,13 @@ from groundlight.binary_labels import Label, convert_display_label_to_internal, convert_internal_label_to_display from groundlight.config import API_TOKEN_VARIABLE_NAME, API_TOKEN_WEB_URL -from groundlight.images import parse_supported_image_types -from groundlight.internalapi import GroundlightApiClient, NotFoundError, iq_is_confident, sanitize_endpoint_url +from groundlight.images import ByteStreamWrapper, parse_supported_image_types +from groundlight.internalapi import ( + GroundlightApiClient, + NotFoundError, + iq_is_confident, + sanitize_endpoint_url, +) from groundlight.optional_imports import Image, np logger = logging.getLogger("groundlight.sdk") @@ -181,7 +186,8 @@ def submit_image_query( if wait is None: wait = self.DEFAULT_WAIT detector_id = detector.id if isinstance(detector, Detector) else detector - image_bytesio: Union[BytesIO, BufferedReader] = parse_supported_image_types(image) + + image_bytesio: ByteStreamWrapper = parse_supported_image_types(image) raw_image_query = self.image_queries_api.submit_image_query(detector_id=detector_id, body=image_bytesio) image_query = ImageQuery.parse_obj(raw_image_query.to_dict()) diff --git a/src/groundlight/images.py b/src/groundlight/images.py index 9888fbed..11a78cfd 100644 --- a/src/groundlight/images.py +++ b/src/groundlight/images.py @@ -1,10 +1,33 @@ import imghdr -from io import BufferedReader, BytesIO +from io import BufferedReader, BytesIO, IOBase from typing import Union from groundlight.optional_imports import Image, np +class ByteStreamWrapper(IOBase): + """This class acts as a thin wrapper around bytes in order to + maintain files in an open state. This is useful, in particular, + when we want to retry accessing the file without having to re-open it. 
+ """ + + def __init__(self, data: Union[BufferedReader, BytesIO, bytes]) -> None: + super().__init__() + if isinstance(data, (BufferedReader, BytesIO)): + self._data = data.read() + else: + self._data = data + + def read(self) -> bytes: + return self._data + + def getvalue(self) -> bytes: + return self._data + + def close(self) -> None: + pass + + def buffer_from_jpeg_file(image_filename: str) -> BufferedReader: """Get a buffer from an jpeg image file. @@ -29,31 +52,32 @@ def jpeg_from_numpy(img: np.ndarray, jpeg_quality: int = 95) -> bytes: def parse_supported_image_types( image: Union[str, bytes, Image.Image, BytesIO, BufferedReader, np.ndarray], jpeg_quality: int = 95, -) -> Union[BytesIO, BufferedReader]: +) -> ByteStreamWrapper: """Parse the many supported image types into a bytes-stream objects. In some cases we have to JPEG compress. """ if isinstance(image, str): # Assume it is a filename - return buffer_from_jpeg_file(image) + buffer = buffer_from_jpeg_file(image) + return ByteStreamWrapper(data=buffer) if isinstance(image, bytes): # Create a BytesIO object - return BytesIO(image) + return ByteStreamWrapper(data=image) if isinstance(image, Image.Image): # Save PIL image as jpeg in BytesIO bytesio = BytesIO() image.save(bytesio, "jpeg", quality=jpeg_quality) bytesio.seek(0) - return bytesio + return ByteStreamWrapper(data=bytesio) if isinstance(image, (BytesIO, BufferedReader)): # Already in the right format - return image + return ByteStreamWrapper(data=image) if isinstance(image, np.ndarray): # Assume it is in BGR format from opencv - return BytesIO(jpeg_from_numpy(image[:, :, ::-1], jpeg_quality=jpeg_quality)) + return ByteStreamWrapper(data=jpeg_from_numpy(image[:, :, ::-1], jpeg_quality=jpeg_quality)) raise TypeError( ( - "Unsupported type for image. Must be PIL, numpy (H,W,3) RGB, or a JPEG as a filename (str), bytes," + "Unsupported type for image. Must be PIL, numpy (H,W,3) BGR, or a JPEG as a filename (str), bytes," " BytesIO, or BufferedReader." 
), ) diff --git a/src/groundlight/internalapi.py b/src/groundlight/internalapi.py index 02dba1bd..968baebf 100644 --- a/src/groundlight/internalapi.py +++ b/src/groundlight/internalapi.py @@ -1,13 +1,15 @@ import logging import os +import random import time import uuid -from typing import Optional +from functools import wraps +from typing import Callable, Optional from urllib.parse import urlsplit, urlunsplit import requests from model import Detector, ImageQuery -from openapi_client.api_client import ApiClient +from openapi_client.api_client import ApiClient, ApiException from groundlight.status_codes import is_ok @@ -67,9 +69,76 @@ def iq_is_confident(iq: ImageQuery, confidence_threshold: float) -> bool: return iq.result.confidence >= confidence_threshold -class InternalApiError(RuntimeError): - # TODO: We need a better exception hierarchy - pass +class InternalApiError(ApiException, RuntimeError): + # TODO: We should really avoid this double inheritance since + # both `ApiException` and `RuntimeError` are subclasses of + # `Exception`. Error handling might become more complex since + # the two super classes cross paths. + # pylint: disable=useless-super-delegation + def __init__(self, status=None, reason=None, http_resp=None): + super().__init__(status, reason, http_resp) + + +class RequestsRetryDecorator: + """ + Decorate a function to retry sending HTTP requests. + + Tries to re-execute the decorated function in case the execution + fails due to a server error (HTTP Error code 500 - 599). + Retry attempts are executed while exponentially backing off by a factor + of 2 with full jitter (picking a random delay time between 0 and the + maximum delay time). 
+ + """ + + def __init__( + self, + initial_delay: float = 0.2, + exponential_backoff: int = 2, + status_code_range: tuple = (500, 600), + max_retries: int = 3, + ): + self.initial_delay = initial_delay + self.exponential_backoff = exponential_backoff + self.status_code_range = range(*status_code_range) + self.max_retries = max_retries + + def __call__(self, function: Callable) -> Callable: + """:param function: The function to invoke.""" + + @wraps(function) + def decorated(*args, **kwargs): + delay = self.initial_delay + retry_count = 0 + + while retry_count <= self.max_retries: + try: + return function(*args, **kwargs) + except ApiException as e: + is_retryable = (e.status is not None) and (e.status in self.status_code_range) + if not is_retryable: + raise e + if retry_count == self.max_retries: + raise InternalApiError(reason="Maximum retries reached") from e + + if is_retryable: + status_code = e.status + if status_code in self.status_code_range: + logger.warning( + ( + f"Current HTTP response status: {status_code}. 
" + f"Remaining retries: {self.max_retries - retry_count}" + ), + exc_info=True, + ) + # This is implementing a full jitter strategy + random_delay = random.uniform(0, delay) + time.sleep(random_delay) + + retry_count += 1 + delay *= self.exponential_backoff + + return decorated class GroundlightApiClient(ApiClient): @@ -80,6 +149,7 @@ class GroundlightApiClient(ApiClient): REQUEST_ID_HEADER = "X-Request-Id" + @RequestsRetryDecorator() def call_api(self, *args, **kwargs): """Adds a request-id header to each API call.""" # Note we don't look for header_param in kwargs here, because this method is only called in one place @@ -97,7 +167,6 @@ def call_api(self, *args, **kwargs): # The methods below will eventually go away when we move to properly model # these methods with OpenAPI # - def _headers(self) -> dict: request_id = _generate_request_id() return { @@ -106,6 +175,7 @@ def _headers(self) -> dict: "X-Request-Id": request_id, } + @RequestsRetryDecorator() def _add_label(self, image_query_id: str, label: str) -> dict: """Temporary internal call to add a label to an image query. Not supported.""" # TODO: Properly model this with OpenApi spec. @@ -126,11 +196,14 @@ def _add_label(self, image_query_id: str, label: str) -> dict: if not is_ok(response.status_code): raise InternalApiError( - f"Error adding label to image query {image_query_id} status={response.status_code} {response.text}", + status=response.status_code, + reason=f"Error adding label to image query {image_query_id}", + http_resp=response, ) return response.json() + @RequestsRetryDecorator() def _get_detector_by_name(self, name: str) -> Detector: """Get a detector by name. For now, we use the list detectors API directly. 
@@ -141,9 +214,7 @@ def _get_detector_by_name(self, name: str) -> Detector: response = requests.request("GET", url, headers=headers) if not is_ok(response.status_code): - raise InternalApiError( - f"Error getting detector by name '{name}' (status={response.status_code}): {response.text}", - ) + raise InternalApiError(status=response.status_code, http_resp=response) parsed = response.json() diff --git a/src/groundlight/status_codes.py b/src/groundlight/status_codes.py index 6a40e935..c002a156 100644 --- a/src/groundlight/status_codes.py +++ b/src/groundlight/status_codes.py @@ -1,14 +1,15 @@ # Helper functions for checking HTTP status codes. -OK_MIN = 200 -OK_MAX = 299 -USER_ERROR_MIN = 400 -USER_ERROR_MAX = 499 + +# We can use range because of Python's lazy evaluation. Thus, the values +# in the range are actually not generated, so we still get O(1) time complexity +OK_RANGE = range(200, 300) +USER_ERROR_RANGE = range(400, 500) def is_ok(status_code: int) -> bool: - return OK_MIN <= status_code <= OK_MAX + return status_code in OK_RANGE def is_user_error(status_code: int) -> bool: - return USER_ERROR_MIN <= status_code <= USER_ERROR_MAX + return status_code in USER_ERROR_RANGE diff --git a/test/unit/test_http_retries.py b/test/unit/test_http_retries.py new file mode 100644 index 00000000..aa7876c7 --- /dev/null +++ b/test/unit/test_http_retries.py @@ -0,0 +1,142 @@ +from datetime import datetime +from typing import Any, Callable +from unittest import mock + +import pytest +from groundlight import Groundlight +from groundlight.binary_labels import Label +from groundlight.internalapi import InternalApiError +from model import Detector + +DEFAULT_CONFIDENCE_THRESHOLD = 0.9 +DETECTOR_NAME = f"test detector_{datetime.utcnow().strftime('%Y=%m-%d %H:%M:%S')}" +TOTAL_RETRIES = 3 +STATUS_CODES = range(500, 505) +IMAGE_FILE = "test/assets/dog.jpeg" + + +@pytest.fixture(name="gl") +def groundlight_fixture() -> Groundlight: + "Creates a Groundlight client" + gl = Groundlight() 
+ return gl + + +@pytest.fixture(name="detector") +def detector_fixture(gl: Groundlight) -> Detector: + return gl.get_or_create_detector( + name=DETECTOR_NAME, query="Is there a dog?", confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD + ) + + +def test_create_detector_attempts_retries(gl: Groundlight): + run_test( + mocked_call="urllib3.PoolManager.request", + api_method=gl.create_detector, + expected_call_counts=TOTAL_RETRIES + 1, + name=DETECTOR_NAME, + query="Is there a dog?", + confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD, + ) + + +def test_get_or_create_detector_attempts_retries(gl: Groundlight): + run_test( + mocked_call="urllib3.PoolManager.request", + api_method=gl.get_or_create_detector, + expected_call_counts=TOTAL_RETRIES + 1, + name=DETECTOR_NAME, + query="Is there a dog?", + confidence_threshold=DEFAULT_CONFIDENCE_THRESHOLD, + ) + + +def test_get_detector_attempts_retries(gl: Groundlight, detector: Detector): + run_test( + mocked_call="urllib3.PoolManager.request", + api_method=gl.get_detector, + expected_call_counts=TOTAL_RETRIES + 1, + id=detector.id, + ) + + +def test_get_detector_by_name_attempts_retries(gl: Groundlight): + run_test( + mocked_call="requests.request", + api_method=gl.get_detector_by_name, + expected_call_counts=TOTAL_RETRIES + 1, + name=DETECTOR_NAME, + ) + + +def test_list_detectors_attempts_retries(gl: Groundlight): + run_test( + mocked_call="urllib3.PoolManager.request", api_method=gl.list_detectors, expected_call_counts=TOTAL_RETRIES + 1 + ) + + +def test_submit_image_query_attempts_retries(gl: Groundlight): + run_test( + mocked_call="urllib3.PoolManager.request", + api_method=gl.submit_image_query, + expected_call_counts=TOTAL_RETRIES + 1, + detector=DETECTOR_NAME, + image=IMAGE_FILE, + wait=1, + ) + + +def test_get_image_query_attempts_retries(gl: Groundlight, detector: Detector): + image_query = gl.submit_image_query(detector=detector.id, image=IMAGE_FILE) + + run_test( + mocked_call="urllib3.PoolManager.request", + 
api_method=gl.get_image_query, + expected_call_counts=TOTAL_RETRIES + 1, + id=image_query.id, + ) + + +def test_list_image_queries_attempts_retries(gl: Groundlight): + run_test( + mocked_call="urllib3.PoolManager.request", + api_method=gl.list_image_queries, + expected_call_counts=TOTAL_RETRIES + 1, + ) + + +def test_add_label_attempts_retries(gl: Groundlight, detector: Detector): + image_query = gl.submit_image_query(detector=detector.id, image=IMAGE_FILE) + run_test( + mocked_call="requests.request", + api_method=gl.add_label, + expected_call_counts=TOTAL_RETRIES + 1, + image_query=image_query, + label=Label.YES, + ) + + run_test( + mocked_call="requests.request", + api_method=gl.add_label, + expected_call_counts=TOTAL_RETRIES + 1, + image_query=image_query, + label="NO", + ) + + +def run_test(mocked_call: str, api_method: Callable[..., Any], expected_call_counts: int, **kwargs): + with mock.patch(mocked_call) as mock_request: + for status_code in STATUS_CODES: + mock_request.return_value.status = status_code + + with pytest.raises(InternalApiError): + api_method(**kwargs) + + assert mock_request.call_count == expected_call_counts + mock_request.reset_mock() + + +def test_submit_image_query_succeeds_after_retry(gl: Groundlight, detector: Detector): + # TODO: figure out a good way to test `submit_image_query` such that it fails + # the first few times, but eventually succeeds. 
+ pass diff --git a/test/unit/test_imagefuncs.py b/test/unit/test_imagefuncs.py index e0c19cb1..a159639d 100644 --- a/test/unit/test_imagefuncs.py +++ b/test/unit/test_imagefuncs.py @@ -2,6 +2,7 @@ # ruff: noqa: F403,F405 # pylint: disable=wildcard-import,unused-wildcard-import,redefined-outer-name,import-outside-toplevel import tempfile +from io import BytesIO import pytest from groundlight.images import * @@ -42,7 +43,7 @@ def test_pil_support(): img = Image.new("RGB", (640, 480)) jpeg = parse_supported_image_types(img) - assert isinstance(jpeg, BytesIO) + assert isinstance(jpeg, ByteStreamWrapper) # Now try to parse the BytesIO object as an image jpeg_bytes = jpeg.getvalue() @@ -69,3 +70,32 @@ def test_pil_support_ref(): f.seek(0) img2 = Image.open(f) assert img2.size == (509, 339) + + +def test_byte_stream_wrapper(): + """ + Test that we can call `read` and `close` repeatedly many times on a + ByteStreamWrapper and get the same output. + """ + + def run_test(byte_stream: ByteStreamWrapper): + previous_bytes = byte_stream.read() + + current_attempt, total_attempts = 0, 5 + + while current_attempt < total_attempts: + new_bytes = byte_stream.read() + assert new_bytes == previous_bytes + byte_stream.close() + + current_attempt += 1 + + image = "test/assets/dog.jpeg" + buffer = buffer_from_jpeg_file(image_filename=image) + + buffered_reader = ByteStreamWrapper(data=buffer) + with open(image, "rb") as image_file: + bytes_io = ByteStreamWrapper(data=BytesIO(image_file.read())) + + run_test(buffered_reader) + run_test(bytes_io)