Skip to content

Commit

Permalink
Adding count mode tests
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-groundlight committed Oct 5, 2024
1 parent 03e283d commit 1e149f0
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 14 deletions.
48 changes: 37 additions & 11 deletions src/groundlight/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,38 @@ def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetecto
)
return PaginatedDetectorList.parse_obj(obj.to_dict())

def _prep_create_detector(
self,
name: str,
query: str,
*,
group_name: Optional[str] = None,
confidence_threshold: Optional[float] = None,
patience_time: Optional[float] = None,
pipeline_config: Optional[str] = None,
metadata: Union[dict, str, None] = None,
) -> Detector:
"""
A helper function to prepare the input for creating a detector. Individual create_detector
methods may add to the input before calling the API.
"""
detector_creation_input = DetectorCreationInputRequest(
name=name,
query=query,
pipeline_config=pipeline_config,
)
if group_name is not None:
detector_creation_input.group_name = group_name
if metadata is not None:
detector_creation_input.metadata = str(url_encode_dict(metadata, name="metadata", size_limit_bytes=1024))
if confidence_threshold:
detector_creation_input.confidence_threshold = confidence_threshold
if isinstance(patience_time, int):
patience_time = float(patience_time)
if patience_time:
detector_creation_input.patience_time = patience_time
return detector_creation_input

def create_detector( # noqa: PLR0913
self,
name: str,
Expand Down Expand Up @@ -279,21 +311,15 @@ def create_detector( # noqa: PLR0913
:return: Detector
"""

detector_creation_input = DetectorCreationInputRequest(
detector_creation_input = self._prep_create_detector(
name=name,
query=query,
group_name=group_name,
confidence_threshold=confidence_threshold,
patience_time=patience_time,
pipeline_config=pipeline_config,
metadata=metadata,
)
if group_name is not None:
detector_creation_input.group_name = group_name
if metadata is not None:
detector_creation_input.metadata = str(url_encode_dict(metadata, name="metadata", size_limit_bytes=1024))
if confidence_threshold:
detector_creation_input.confidence_threshold = confidence_threshold
if isinstance(patience_time, int):
patience_time = float(patience_time)
if patience_time:
detector_creation_input.patience_time = patience_time
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
return Detector.parse_obj(obj.to_dict())

Expand Down
38 changes: 35 additions & 3 deletions src/groundlight/experimental_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import json
from io import BufferedReader, BytesIO
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import requests
from groundlight_openapi_client.api.actions_api import ActionsApi
Expand All @@ -20,18 +20,19 @@
from groundlight_openapi_client.model.b_box_geometry_request import BBoxGeometryRequest
from groundlight_openapi_client.model.channel_enum import ChannelEnum
from groundlight_openapi_client.model.condition_request import ConditionRequest
from groundlight_openapi_client.model.count_mode_configuration_serializer import CountModeConfigurationSerializer
from groundlight_openapi_client.model.detector_group_request import DetectorGroupRequest
from groundlight_openapi_client.model.label_value_request import LabelValueRequest
from groundlight_openapi_client.model.roi_request import ROIRequest
from groundlight_openapi_client.model.rule_request import RuleRequest
from groundlight_openapi_client.model.verb_enum import VerbEnum
from model import ROI, BBoxGeometry, Detector, DetectorGroup, ImageQuery, PaginatedRuleList, Rule
from model import ROI, BBoxGeometry, Detector, DetectorGroup, ImageQuery, ModeEnum, PaginatedRuleList, Rule

from groundlight.binary_labels import Label, convert_display_label_to_internal
from groundlight.images import parse_supported_image_types
from groundlight.optional_imports import Image, np

from .client import Groundlight
from .client import DEFAULT_REQUEST_TIMEOUT, Groundlight


class ExperimentalApi(Groundlight):
Expand Down Expand Up @@ -305,3 +306,34 @@ def reset_detector(self, detector: Union[str, Detector]) -> None:
if isinstance(detector, Detector):
detector = detector.id
self.detector_reset_api.reset_detector(detector)

def create_counting_detector(
self,
name: str,
query: str,
*,
max_count: Optional[int] = None,
group_name: Optional[str] = None,
confidence_threshold: Optional[float] = None,
patience_time: Optional[float] = None,
pipeline_config: Optional[str] = None,
metadata: Union[dict, str, None] = None,
) -> Detector:
"""
Creates a counting detector with the given name and query
"""

detector_creation_input = self._prep_create_detector(
name=name,
query=query,
group_name=group_name,
confidence_threshold=confidence_threshold,
patience_time=patience_time,
pipeline_config=pipeline_config,
metadata=metadata,
)
detector_creation_input.mode = ModeEnum.COUNT
mode_config = CountModeConfigurationSerializer(max_count=max_count)
detector_creation_input.mode_configuration = mode_config
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
return Detector.parse_obj(obj.to_dict())
10 changes: 10 additions & 0 deletions test/unit/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,13 @@ def test_submit_multiple_rois(gl_experimental: ExperimentalApi, image_query_no:
label_name = "dog"
roi = gl_experimental.create_roi(label_name, (0, 0), (0.5, 0.5))
gl_experimental.add_label(image_query_no, "YES", [roi] * 3)

def test_counting_detector(gl_experimental: ExperimentalApi):
"""
verify that we can create and submit to a counting detector
"""
name = f"Test {datetime.utcnow()}"
created_detector = gl_experimental.create_counting_detector(name, "How many dogs")
assert created_detector is not None
count_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
assert count_iq.count is not None

0 comments on commit 1e149f0

Please sign in to comment.