diff --git a/pyproject.toml b/pyproject.toml index dbff7c63..fea4d249 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ packages = [ {include = "**/*.py", from = "src"}, ] readme = "README.md" -version = "0.10.1" +version = "0.11.0" [tool.poetry.dependencies] certifi = "^2021.10.8" diff --git a/src/groundlight/client.py b/src/groundlight/client.py index 72a662fb..65e0b97b 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -165,12 +165,13 @@ def list_image_queries(self, page: int = 1, page_size: int = 10) -> PaginatedIma image_queries.results = [self._fixup_image_query(iq) for iq in image_queries.results] return image_queries - def submit_image_query( + def submit_image_query( # noqa: PLR0913 # pylint: disable=too-many-arguments self, detector: Union[Detector, str], image: Union[str, bytes, Image.Image, BytesIO, BufferedReader, np.ndarray], wait: Optional[float] = None, human_review: Optional[str] = None, + inspection_id: Optional[str] = None, ) -> ImageQuery: """Evaluates an image with Groundlight. :param detector: the Detector object, or string id of a detector like `det_12345` @@ -187,9 +188,12 @@ def submit_image_query( only if the ML prediction is not confident. If set to `ALWAYS`, always send the image query for human review. If set to `NEVER`, never send the image query for human review. + :param inspection_id: Most users will omit this. For accounts with Inspection Reports enabled, + this is the ID of the inspection to associate with the image query. """ if wait is None: wait = self.DEFAULT_WAIT + detector_id = detector.id if isinstance(detector, Detector) else detector image_bytesio: ByteStreamWrapper = parse_supported_image_types(image) @@ -203,8 +207,17 @@ def submit_image_query( if human_review is not None: params["human_review"] = human_review - raw_image_query = self.image_queries_api.submit_image_query(**params) - image_query = ImageQuery.parse_obj(raw_image_query.to_dict()) + # If no inspection_id is provided, we submit the image query using image_queries_api (autogenerated via OpenAPI) + # However, our autogenerated code does not currently support inspection_id, so if an inspection_id was + # provided, we use the private API client instead. + if inspection_id is None: + raw_image_query = self.image_queries_api.submit_image_query(**params) + image_query = ImageQuery.parse_obj(raw_image_query.to_dict()) + else: + params["inspection_id"] = inspection_id + iq_id = self.api_client.submit_image_query_with_inspection(**params) + image_query = self.get_image_query(iq_id) + if wait: threshold = self.get_detector(detector).confidence_threshold image_query = self.wait_for_confident_result(image_query, confidence_threshold=threshold, timeout_sec=wait) @@ -212,7 +225,7 @@ def submit_image_query( def wait_for_confident_result( self, - image_query: ImageQuery, + image_query: Union[ImageQuery, str], confidence_threshold: float, timeout_sec: float = 30.0, ) -> ImageQuery: @@ -222,7 +235,10 @@ def wait_for_confident_result( :param confidence_threshold: The minimum confidence level required to return before the timeout. :param timeout_sec: The maximum number of seconds to wait. """ - # TODO: Add support for ImageQuery id instead of object. + # Convert from image_query_id to ImageQuery if needed. + if isinstance(image_query, str): + image_query = self.get_image_query(image_query) + start_time = time.time() next_delay = self.POLLING_INITIAL_DELAY target_delay = 0.0 @@ -263,3 +279,27 @@ def add_label(self, image_query: Union[ImageQuery, str], label: Union[Label, str api_label = convert_display_label_to_internal(image_query_id, label) return self.api_client._add_label(image_query_id, api_label) # pylint: disable=protected-access + + def start_inspection(self) -> str: + """For users with Inspection Reports enabled only. + Starts an inspection report and returns the id of the inspection. + """ + return self.api_client.start_inspection() + + def update_inspection_metadata(self, inspection_id: str, user_provided_key: str, user_provided_value: str) -> None: + """For users with Inspection Reports enabled only. + Add/update inspection metadata with the user_provided_key and user_provided_value. + """ + self.api_client.update_inspection_metadata(inspection_id, user_provided_key, user_provided_value) + + def stop_inspection(self, inspection_id: str) -> str: + """For users with Inspection Reports enabled only. + Stops an inspection and raises an exception if the response from the server + indicates that the inspection was not successfully stopped. + Returns a str with result of the inspection (either PASS or FAIL). + """ + return self.api_client.stop_inspection(inspection_id) + + def update_detector_confidence_threshold(self, detector_id: str, confidence_threshold: float) -> None: + """Updates the confidence threshold of a detector given a detector_id.""" + self.api_client.update_detector_confidence_threshold(detector_id, confidence_threshold) diff --git a/src/groundlight/internalapi.py b/src/groundlight/internalapi.py index 968baebf..95243cd9 100644 --- a/src/groundlight/internalapi.py +++ b/src/groundlight/internalapi.py @@ -1,16 +1,18 @@ +import json import logging import os import random import time import uuid from functools import wraps -from typing import Callable, Optional +from typing import Callable, Dict, Optional, Union from urllib.parse import urlsplit, urlunsplit import requests from model import Detector, ImageQuery from openapi_client.api_client import ApiClient, ApiException +from groundlight.images import ByteStreamWrapper from groundlight.status_codes import is_ok logger = logging.getLogger("groundlight.sdk") @@ -225,3 +227,182 @@ def _get_detector_by_name(self, name: str) -> Detector: f"We found multiple ({parsed['count']}) detectors with the same name. This shouldn't happen.", ) return Detector.parse_obj(parsed["results"][0]) + + @RequestsRetryDecorator() + def submit_image_query_with_inspection( # noqa: PLR0913 # pylint: disable=too-many-arguments + self, + detector_id: str, + patience_time: float, + body: ByteStreamWrapper, + inspection_id: str, + human_review: str = "DEFAULT", + ) -> str: + """Submits an image query to the API and returns the ID of the image query. + The image query will be associated to the inspection_id provided. + """ + + url = f"{self.configuration.host}/posichecks" + + params: Dict[str, Union[str, float, bool]] = { + "inspection_id": inspection_id, + "predictor_id": detector_id, + "patience_time": patience_time, + } + + # In the API, 'send_notification' is used to control human_review escalation. This will eventually + # be deprecated, but for now we need to support it in the following manner: + if human_review == "ALWAYS": + params["send_notification"] = True + elif human_review == "NEVER": + params["send_notification"] = False + else: + pass # don't send the send_notifications param, allow "DEFAULT" behavior + + headers = self._headers() + headers["Content-Type"] = "image/jpeg" + + response = requests.request("POST", url, headers=headers, params=params, data=body.read()) + + if not is_ok(response.status_code): + logger.info(response) + raise InternalApiError( + status=response.status_code, + reason=f"Error submitting image query with inspection ID {inspection_id} on detector {detector_id}", + http_resp=response, + ) + + return response.json()["id"] + + @RequestsRetryDecorator() + def start_inspection(self) -> str: + """Starts an inspection, returns the ID.""" + url = f"{self.configuration.host}/inspections" + + headers = self._headers() + + response = requests.request("POST", url, headers=headers, json={}) + + if not is_ok(response.status_code): + raise InternalApiError( + status=response.status_code, + reason="Error starting inspection.", + http_resp=response, + ) + + return response.json()["id"] + + @RequestsRetryDecorator() + def update_inspection_metadata(self, inspection_id: str, user_provided_key: str, user_provided_value: str) -> None: + """Add/update inspection metadata with the user_provided_key and user_provided_value. + + The API stores inspections metadata in two ways: + 1) At the top level of the inspection with user_provided_id_key and user_provided_id_value. This is a + kind of "primary" piece of metadata for the inspection. Only one key/value pair is allowed at this level. + 2) In the user_metadata field as a dictionary. Multiple key/value pairs are allowed at this level. + + The first piece of metadata presented to an inspection will be assumed to be the user_provided_id_key and + user_provided_id_value. All subsequent pieces metadata will be stored in the user_metadata field. + + """ + url = f"{self.configuration.host}/inspections/{inspection_id}" + + headers = self._headers() + + # Get inspection in order to find out: + # 1) if user_provided_id_key has been set + # 2) if the inspection is closed + response = requests.request("GET", url, headers=headers) + + if not is_ok(response.status_code): + raise InternalApiError( + status=response.status_code, + reason=f"Error getting inspection details for inspection {inspection_id}.", + http_resp=response, + ) + if response.json()["status"] == "COMPLETE": + raise ValueError(f"Inspection {inspection_id} is closed. Metadata cannot be added.") + + payload = {} + + # Set the user_provided_id_key and user_provided_id_value if they were not previously set. + response_json = response.json() + if not response_json.get("user_provided_id_key"): + payload["user_provided_id_key"] = user_provided_key + payload["user_provided_id_value"] = user_provided_value + + # Get the existing keys and values in user_metadata (if any) so that we don't overwrite them. + metadata = response_json["user_metadata"] + if not metadata: + metadata = {} + + # Submit the new metadata + metadata[user_provided_key] = user_provided_value + payload["user_metadata_json"] = json.dumps(metadata) + response = requests.request("PATCH", url, headers=headers, json=payload) + + if not is_ok(response.status_code): + raise InternalApiError( + status=response.status_code, + reason=f"Error updating inspection metadata on inspection {inspection_id}.", + http_resp=response, + ) + + @RequestsRetryDecorator() + def stop_inspection(self, inspection_id: str) -> str: + """Stops an inspection and raises an exception if the response from the server does not indicate success. + Returns a string that indicates the result (either PASS or FAIL). The URCap requires this. + """ + url = f"{self.configuration.host}/inspections/{inspection_id}" + + headers = self._headers() + + # Closing an inspection generates a new inspection PDF. Therefore, if the inspection + # is already closed, just return "COMPLETE" to avoid unnecessarily generating a new PDF. + response = requests.request("GET", url, headers=headers) + + if not is_ok(response.status_code): + raise InternalApiError( + status=response.status_code, + reason=f"Error checking the status of {inspection_id}.", + http_resp=response, + ) + + if response.json().get("status") == "COMPLETE": + return "COMPLETE" + + payload = {"status": "COMPLETE"} + + response = requests.request("PATCH", url, headers=headers, json=payload) + + if not is_ok(response.status_code): + raise InternalApiError( + status=response.status_code, + reason=f"Error stopping inspection {inspection_id}.", + http_resp=response, + ) + + return response.json()["result"] + + @RequestsRetryDecorator() + def update_detector_confidence_threshold(self, detector_id: str, confidence_threshold: float) -> None: + """Updates the confidence threshold of a detector.""" + + # The API does not validate the confidence threshold, + # so we will validate it here and raise an exception if necessary. + if confidence_threshold < 0 or confidence_threshold > 1: + raise ValueError(f"Confidence threshold must be between 0 and 1. Got {confidence_threshold}.") + + url = f"{self.configuration.host}/predictors/{detector_id}" + + headers = self._headers() + + payload = {"confidence_threshold": confidence_threshold} + + response = requests.request("PATCH", url, headers=headers, json=payload) + + if not is_ok(response.status_code): + raise InternalApiError( + status=response.status_code, + reason=f"Error updating detector: {detector_id}.", + http_resp=response, + ) diff --git a/test/integration/test_groundlight.py b/test/integration/test_groundlight.py index 85a49653..5c836714 100644 --- a/test/integration/test_groundlight.py +++ b/test/integration/test_groundlight.py @@ -8,7 +8,7 @@ import pytest from groundlight import Groundlight from groundlight.binary_labels import VALID_DISPLAY_LABELS, DeprecatedLabel, Label, convert_internal_label_to_display -from groundlight.internalapi import NotFoundError +from groundlight.internalapi import InternalApiError, NotFoundError from groundlight.optional_imports import * from groundlight.status_codes import is_user_error from model import ClassificationResult, Detector, ImageQuery, PaginatedDetectorList, PaginatedImageQueryList @@ -210,7 +210,6 @@ def test_submit_image_query_jpeg_truncated(gl: Groundlight, detector: Detector): with pytest.raises(openapi_client.exceptions.ApiException) as exc_info: _image_query = gl.submit_image_query(detector=detector.id, image=jpeg_truncated) exc_value = exc_info.value - print(f"exc_info = {exc_info}") assert is_user_error(exc_value.status) @@ -420,3 +419,98 @@ def submit_noisy_image(image, label=None): return assert False, "The detector performance has not improved after two minutes" + + +def test_start_inspection(gl: Groundlight): + inspection_id = gl.start_inspection() + + assert isinstance(inspection_id, str) + assert "inspect_" in inspection_id + + +def test_update_inspection_metadata_success(gl: Groundlight): + """Starts an inspection and adds a couple pieces of metadata to it. + This should succeed. If there are any errors, an exception will be raised. + """ + inspection_id = gl.start_inspection() + + user_provided_key = "Inspector" + user_provided_value = "Bob" + gl.update_inspection_metadata(inspection_id, user_provided_key, user_provided_value) + + user_provided_key = "Engine ID" + user_provided_value = "1234" + gl.update_inspection_metadata(inspection_id, user_provided_key, user_provided_value) + + +def test_update_inspection_metadata_failure(gl: Groundlight): + """Attempts to add metadata to an inspection after it is closed. + Should raise an exception. + """ + inspection_id = gl.start_inspection() + + _ = gl.stop_inspection(inspection_id) + + with pytest.raises(ValueError): + user_provided_key = "Inspector" + user_provided_value = "Bob" + gl.update_inspection_metadata(inspection_id, user_provided_key, user_provided_value) + + +def test_update_inspection_metadata_invalid_inspection_id(gl: Groundlight): + """Attempt to update metadata for an inspection that doesn't exist. + Should raise an InternalApiError. + """ + + inspection_id = "some_invalid_inspection_id" + user_provided_key = "Operator" + user_provided_value = "Bob" + + with pytest.raises(InternalApiError): + gl.update_inspection_metadata(inspection_id, user_provided_key, user_provided_value) + + +def test_stop_inspection_pass(gl: Groundlight, detector: Detector): + """Starts an inspection, submits a query with the inspection ID that should pass, stops + the inspection, checks the result. + """ + inspection_id = gl.start_inspection() + + _ = gl.submit_image_query(detector=detector, image="test/assets/dog.jpeg", inspection_id=inspection_id) + + assert gl.stop_inspection(inspection_id) == "PASS" + + +def test_stop_inspection_fail(gl: Groundlight, detector: Detector): + """Starts an inspection, submits a query that should fail, stops + the inspection, checks the result. + """ + inspection_id = gl.start_inspection() + + iq = gl.submit_image_query(detector=detector, image="test/assets/cat.jpeg", inspection_id=inspection_id) + gl.add_label(iq, Label.NO) # labeling it NO just to be sure the inspection fails + + assert gl.stop_inspection(inspection_id) == "FAIL" + + +def test_stop_inspection_with_invalid_id(gl: Groundlight): + inspection_id = "some_invalid_inspection_id" + + with pytest.raises(InternalApiError): + gl.stop_inspection(inspection_id) + + +def test_update_detector_confidence_threshold_success(gl: Groundlight, detector: Detector): + """Updates the confidence threshold for a detector. This should succeed.""" + gl.update_detector_confidence_threshold(detector.id, 0.77) + + +def test_update_detector_confidence_threshold_failure(gl: Groundlight, detector: Detector): + """Attempts to update the confidence threshold for a detector to invalid values. + Should raise ValueError exceptions. + """ + with pytest.raises(ValueError): + gl.update_detector_confidence_threshold(detector.id, 77) # too high + + with pytest.raises(ValueError): + gl.update_detector_confidence_threshold(detector.id, -1) # too low