Skip to content

Commit

Permalink
Add Retries to the Python SDK with Exponential Backoff (#70)
Browse files Browse the repository at this point in the history
* add decorator for retries

* add unit test

* add request_mock in pyproject.toml

* remove unnecessary property from internal api client

* add more unit tests

* fix poetry lock file

* fix linting

* inherit from ApiException in InternalApiException

* remove lock file

* disable linting for useless-super-delegation

* disable useless-super-delegation linting

* allow caching image byte stream for subsequent access when file is closed

* add a wrapper class to bytes

* fix linting

* forgotten return statement

* add unit test for ByteStreamWrapper

* fix linting

* fix linting

---------

Co-authored-by: Blaise Munyampirwa <[email protected]>
  • Loading branch information
blaise-muhirwa and Blaise Munyampirwa authored Jun 30, 2023
1 parent b3b4cdc commit bd0f361
Show file tree
Hide file tree
Showing 6 changed files with 302 additions and 28 deletions.
12 changes: 9 additions & 3 deletions src/groundlight/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@

from groundlight.binary_labels import Label, convert_display_label_to_internal, convert_internal_label_to_display
from groundlight.config import API_TOKEN_VARIABLE_NAME, API_TOKEN_WEB_URL
from groundlight.images import parse_supported_image_types
from groundlight.internalapi import GroundlightApiClient, NotFoundError, iq_is_confident, sanitize_endpoint_url
from groundlight.images import ByteStreamWrapper, parse_supported_image_types
from groundlight.internalapi import (
GroundlightApiClient,
NotFoundError,
iq_is_confident,
sanitize_endpoint_url,
)
from groundlight.optional_imports import Image, np

logger = logging.getLogger("groundlight.sdk")
Expand Down Expand Up @@ -181,7 +186,8 @@ def submit_image_query(
if wait is None:
wait = self.DEFAULT_WAIT
detector_id = detector.id if isinstance(detector, Detector) else detector
image_bytesio: Union[BytesIO, BufferedReader] = parse_supported_image_types(image)

image_bytesio: ByteStreamWrapper = parse_supported_image_types(image)

raw_image_query = self.image_queries_api.submit_image_query(detector_id=detector_id, body=image_bytesio)
image_query = ImageQuery.parse_obj(raw_image_query.to_dict())
Expand Down
40 changes: 32 additions & 8 deletions src/groundlight/images.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,33 @@
import imghdr
from io import BufferedReader, BytesIO
from io import BufferedReader, BytesIO, IOBase
from typing import Union

from groundlight.optional_imports import Image, np


class ByteStreamWrapper(IOBase):
"""This class acts as a thin wrapper around bytes in order to
maintain files in an open state. This is useful, in particular,
when we want to retry accessing the file without having to re-open it.
"""

def __init__(self, data: Union[BufferedReader, BytesIO, bytes]) -> None:
super().__init__()
if isinstance(data, (BufferedReader, BytesIO)):
self._data = data.read()
else:
self._data = data

def read(self) -> bytes:
return self._data

def getvalue(self) -> bytes:
return self._data

def close(self) -> None:
pass


def buffer_from_jpeg_file(image_filename: str) -> BufferedReader:
"""Get a buffer from an jpeg image file.
Expand All @@ -29,31 +52,32 @@ def jpeg_from_numpy(img: np.ndarray, jpeg_quality: int = 95) -> bytes:
def parse_supported_image_types(
image: Union[str, bytes, Image.Image, BytesIO, BufferedReader, np.ndarray],
jpeg_quality: int = 95,
) -> Union[BytesIO, BufferedReader]:
) -> ByteStreamWrapper:
"""Parse the many supported image types into a bytes-stream objects.
In some cases we have to JPEG compress.
"""
if isinstance(image, str):
# Assume it is a filename
return buffer_from_jpeg_file(image)
buffer = buffer_from_jpeg_file(image)
return ByteStreamWrapper(data=buffer)
if isinstance(image, bytes):
# Create a BytesIO object
return BytesIO(image)
return ByteStreamWrapper(data=image)
if isinstance(image, Image.Image):
# Save PIL image as jpeg in BytesIO
bytesio = BytesIO()
image.save(bytesio, "jpeg", quality=jpeg_quality)
bytesio.seek(0)
return bytesio
return ByteStreamWrapper(data=bytesio)
if isinstance(image, (BytesIO, BufferedReader)):
# Already in the right format
return image
return ByteStreamWrapper(data=image)
if isinstance(image, np.ndarray):
# Assume it is in BGR format from opencv
return BytesIO(jpeg_from_numpy(image[:, :, ::-1], jpeg_quality=jpeg_quality))
return ByteStreamWrapper(data=jpeg_from_numpy(image[:, :, ::-1], jpeg_quality=jpeg_quality))
raise TypeError(
(
"Unsupported type for image. Must be PIL, numpy (H,W,3) RGB, or a JPEG as a filename (str), bytes,"
"Unsupported type for image. Must be PIL, numpy (H,W,3) BGR, or a JPEG as a filename (str), bytes,"
" BytesIO, or BufferedReader."
),
)
91 changes: 81 additions & 10 deletions src/groundlight/internalapi.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging
import os
import random
import time
import uuid
from typing import Optional
from functools import wraps
from typing import Callable, Optional
from urllib.parse import urlsplit, urlunsplit

import requests
from model import Detector, ImageQuery
from openapi_client.api_client import ApiClient
from openapi_client.api_client import ApiClient, ApiException

from groundlight.status_codes import is_ok

Expand Down Expand Up @@ -67,9 +69,76 @@ def iq_is_confident(iq: ImageQuery, confidence_threshold: float) -> bool:
return iq.result.confidence >= confidence_threshold


class InternalApiError(RuntimeError):
# TODO: We need a better exception hierarchy
pass
class InternalApiError(ApiException, RuntimeError):
# TODO: We should really avoid this double inheritance since
# both `ApiException` and `RuntimeError` are subclasses of
# `Exception`. Error handling might become more complex since
# the two super classes cross paths.
# pylint: disable=useless-super-delegation
def __init__(self, status=None, reason=None, http_resp=None):
super().__init__(status, reason, http_resp)


class RequestsRetryDecorator:
"""
Decorate a function to retry sending HTTP requests.
Tries to re-execute the decorated function in case the execution
fails due to a server error (HTTP Error code 500 - 599).
Retry attempts are executed while exponentially backing off by a factor
of 2 with full jitter (picking a random delay time between 0 and the
maximum delay time).
"""

def __init__(
self,
initial_delay: float = 0.2,
exponential_backoff: int = 2,
status_code_range: tuple = (500, 600),
max_retries: int = 3,
):
self.initial_delay = initial_delay
self.exponential_backoff = exponential_backoff
self.status_code_range = range(*status_code_range)
self.max_retries = max_retries

def __call__(self, function: Callable) -> Callable:
""":param callable: The function to invoke."""

@wraps(function)
def decorated(*args, **kwargs):
delay = self.initial_delay
retry_count = 0

while retry_count <= self.max_retries:
try:
return function(*args, **kwargs)
except ApiException as e:
is_retryable = (e.status is not None) and (e.status in self.status_code_range)
if not is_retryable:
raise e
if retry_count == self.max_retries:
raise InternalApiError(reason="Maximum retries reached") from e

if is_retryable:
status_code = e.status
if status_code in self.status_code_range:
logger.warning(
(
f"Current HTTP response status: {status_code}. "
f"Remaining retries: {self.max_retries - retry_count}"
),
exc_info=True,
)
# This is implementing a full jitter strategy
random_delay = random.uniform(0, delay)
time.sleep(random_delay)

retry_count += 1
delay *= self.exponential_backoff

return decorated


class GroundlightApiClient(ApiClient):
Expand All @@ -80,6 +149,7 @@ class GroundlightApiClient(ApiClient):

REQUEST_ID_HEADER = "X-Request-Id"

@RequestsRetryDecorator()
def call_api(self, *args, **kwargs):
"""Adds a request-id header to each API call."""
# Note we don't look for header_param in kwargs here, because this method is only called in one place
Expand All @@ -97,7 +167,6 @@ def call_api(self, *args, **kwargs):
# The methods below will eventually go away when we move to properly model
# these methods with OpenAPI
#

def _headers(self) -> dict:
request_id = _generate_request_id()
return {
Expand All @@ -106,6 +175,7 @@ def _headers(self) -> dict:
"X-Request-Id": request_id,
}

@RequestsRetryDecorator()
def _add_label(self, image_query_id: str, label: str) -> dict:
"""Temporary internal call to add a label to an image query. Not supported."""
# TODO: Properly model this with OpenApi spec.
Expand All @@ -126,11 +196,14 @@ def _add_label(self, image_query_id: str, label: str) -> dict:

if not is_ok(response.status_code):
raise InternalApiError(
f"Error adding label to image query {image_query_id} status={response.status_code} {response.text}",
status=response.status_code,
reason=f"Error adding label to image query {image_query_id}",
http_resp=response,
)

return response.json()

@RequestsRetryDecorator()
def _get_detector_by_name(self, name: str) -> Detector:
"""Get a detector by name. For now, we use the list detectors API directly.
Expand All @@ -141,9 +214,7 @@ def _get_detector_by_name(self, name: str) -> Detector:
response = requests.request("GET", url, headers=headers)

if not is_ok(response.status_code):
raise InternalApiError(
f"Error getting detector by name '{name}' (status={response.status_code}): {response.text}",
)
raise InternalApiError(status=response.status_code, http_resp=response)

parsed = response.json()

Expand Down
13 changes: 7 additions & 6 deletions src/groundlight/status_codes.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Helper functions for checking HTTP status codes.

OK_MIN = 200
OK_MAX = 299
USER_ERROR_MIN = 400
USER_ERROR_MAX = 499

# We can use range because of Python's lazy evaluation. Thus, the values
# in the range are actually not generated, so we still get O(1) time complexity
OK_RANGE = range(200, 300)
USER_ERROR_RANGE = range(400, 500)


def is_ok(status_code: int) -> bool:
return OK_MIN <= status_code <= OK_MAX
return status_code in OK_RANGE


def is_user_error(status_code: int) -> bool:
return USER_ERROR_MIN <= status_code <= USER_ERROR_MAX
return status_code in USER_ERROR_RANGE
Loading

0 comments on commit bd0f361

Please sign in to comment.