From 1e149f0e9279ff61e72a7895e2915a6c28ff2ab4 Mon Sep 17 00:00:00 2001 From: brandon Date: Fri, 4 Oct 2024 18:53:00 -0700 Subject: [PATCH] Adding count mode tests --- src/groundlight/client.py | 48 ++++++++++++++++++++++------- src/groundlight/experimental_api.py | 38 +++++++++++++++++++++-- test/unit/test_experimental.py | 10 ++++++ 3 files changed, 82 insertions(+), 14 deletions(-) diff --git a/src/groundlight/client.py b/src/groundlight/client.py index 35eca62b..21dbd375 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -244,6 +244,38 @@ def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetecto ) return PaginatedDetectorList.parse_obj(obj.to_dict()) + def _prep_create_detector( + self, + name: str, + query: str, + *, + group_name: Optional[str] = None, + confidence_threshold: Optional[float] = None, + patience_time: Optional[float] = None, + pipeline_config: Optional[str] = None, + metadata: Union[dict, str, None] = None, + ) -> Detector: + """ + A helper function to prepare the input for creating a detector. Individual create_detector + methods may add to the input before calling the API. + """ + detector_creation_input = DetectorCreationInputRequest( + name=name, + query=query, + pipeline_config=pipeline_config, + ) + if group_name is not None: + detector_creation_input.group_name = group_name + if metadata is not None: + detector_creation_input.metadata = str(url_encode_dict(metadata, name="metadata", size_limit_bytes=1024)) + if confidence_threshold: + detector_creation_input.confidence_threshold = confidence_threshold + if isinstance(patience_time, int): + patience_time = float(patience_time) + if patience_time: + detector_creation_input.patience_time = patience_time + return detector_creation_input + def create_detector( # noqa: PLR0913 self, name: str, @@ -279,21 +311,15 @@ def create_detector( # noqa: PLR0913 :return: Detector """ - detector_creation_input = DetectorCreationInputRequest( + detector_creation_input = self._prep_create_detector( name=name, query=query, + group_name=group_name, + confidence_threshold=confidence_threshold, + patience_time=patience_time, pipeline_config=pipeline_config, + metadata=metadata, ) - if group_name is not None: - detector_creation_input.group_name = group_name - if metadata is not None: - detector_creation_input.metadata = str(url_encode_dict(metadata, name="metadata", size_limit_bytes=1024)) - if confidence_threshold: - detector_creation_input.confidence_threshold = confidence_threshold - if isinstance(patience_time, int): - patience_time = float(patience_time) - if patience_time: - detector_creation_input.patience_time = patience_time obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT) return Detector.parse_obj(obj.to_dict()) diff --git a/src/groundlight/experimental_api.py b/src/groundlight/experimental_api.py index 97ea089d..d7498b96 100644 --- a/src/groundlight/experimental_api.py +++ b/src/groundlight/experimental_api.py @@ -8,7 +8,7 @@ import json from io import BufferedReader, BytesIO -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import requests from groundlight_openapi_client.api.actions_api import ActionsApi @@ -20,18 +20,19 @@ from groundlight_openapi_client.model.b_box_geometry_request import BBoxGeometryRequest from groundlight_openapi_client.model.channel_enum import ChannelEnum from groundlight_openapi_client.model.condition_request import ConditionRequest +from groundlight_openapi_client.model.count_mode_configuration_serializer import CountModeConfigurationSerializer from groundlight_openapi_client.model.detector_group_request import DetectorGroupRequest from groundlight_openapi_client.model.label_value_request import LabelValueRequest from groundlight_openapi_client.model.roi_request import ROIRequest from groundlight_openapi_client.model.rule_request import RuleRequest from groundlight_openapi_client.model.verb_enum import VerbEnum -from model import ROI, BBoxGeometry, Detector, DetectorGroup, ImageQuery, PaginatedRuleList, Rule +from model import ROI, BBoxGeometry, Detector, DetectorGroup, ImageQuery, ModeEnum, PaginatedRuleList, Rule from groundlight.binary_labels import Label, convert_display_label_to_internal from groundlight.images import parse_supported_image_types from groundlight.optional_imports import Image, np -from .client import Groundlight +from .client import DEFAULT_REQUEST_TIMEOUT, Groundlight class ExperimentalApi(Groundlight): @@ -305,3 +306,34 @@ def reset_detector(self, detector: Union[str, Detector]) -> None: if isinstance(detector, Detector): detector = detector.id self.detector_reset_api.reset_detector(detector) + + def create_counting_detector( + self, + name: str, + query: str, + *, + max_count: Optional[int] = None, + group_name: Optional[str] = None, + confidence_threshold: Optional[float] = None, + patience_time: Optional[float] = None, + pipeline_config: Optional[str] = None, + metadata: Union[dict, str, None] = None, + ) -> Detector: + """ + Creates a counting detector with the given name and query + """ + + detector_creation_input = self._prep_create_detector( + name=name, + query=query, + group_name=group_name, + confidence_threshold=confidence_threshold, + patience_time=patience_time, + pipeline_config=pipeline_config, + metadata=metadata, + ) + detector_creation_input.mode = ModeEnum.COUNT + mode_config = CountModeConfigurationSerializer(max_count=max_count) + detector_creation_input.mode_configuration = mode_config + obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT) + return Detector.parse_obj(obj.to_dict()) \ No newline at end of file diff --git a/test/unit/test_experimental.py b/test/unit/test_experimental.py index 5a94a8a9..163f1de4 100644 --- a/test/unit/test_experimental.py +++ b/test/unit/test_experimental.py @@ -43,3 +43,13 @@ def test_submit_multiple_rois(gl_experimental: ExperimentalApi, image_query_no: label_name = "dog" roi = gl_experimental.create_roi(label_name, (0, 0), (0.5, 0.5)) gl_experimental.add_label(image_query_no, "YES", [roi] * 3) + +def test_counting_detector(gl_experimental: ExperimentalApi): + """ + verify that we can create and submit to a counting detector + """ + name = f"Test {datetime.utcnow()}" + created_detector = gl_experimental.create_counting_detector(name, "How many dogs") + assert created_detector is not None + count_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg") + assert count_iq.count is not None \ No newline at end of file