Skip to content

Commit

Permalink
create and pull count detector/iq
Browse files Browse the repository at this point in the history
  • Loading branch information
brandon-groundlight committed Oct 5, 2024
1 parent 1e149f0 commit f3e0535
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@


def lazy_import():
from groundlight_openapi_client.model.count_mode_configuration_serializer import CountModeConfigurationSerializer
from groundlight_openapi_client.model.count_mode_configuration import CountModeConfiguration
from groundlight_openapi_client.model.mode_enum import ModeEnum
from groundlight_openapi_client.model.multi_class_mode_configuration_serializer import (
MultiClassModeConfigurationSerializer,
)
from groundlight_openapi_client.model.multi_class_mode_configuration import MultiClassModeConfiguration

globals()["CountModeConfigurationSerializer"] = CountModeConfigurationSerializer
globals()["CountModeConfiguration"] = CountModeConfiguration
globals()["ModeEnum"] = ModeEnum
globals()["MultiClassModeConfigurationSerializer"] = MultiClassModeConfigurationSerializer
globals()["MultiClassModeConfiguration"] = MultiClassModeConfiguration


class DetectorCreationInputRequest(ModelNormal):
Expand Down
12 changes: 2 additions & 10 deletions generated/model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: public-api.yaml
# timestamp: 2024-10-05T01:13:27+00:00
# timestamp: 2024-10-05T01:53:23+00:00

from __future__ import annotations

Expand Down Expand Up @@ -234,14 +234,6 @@ class ActionList(RootModel[List[Action]]):
root: List[Action]


class CountModeConfigurationSerializer(RootModel[Any]):
root: Any


class MultiClassModeConfigurationSerializer(RootModel[Any]):
root: Any


class AllNotes(BaseModel):
"""
Serializes all notes for a given detector, grouped by type as listed in UserProfile.NoteCategoryChoices
Expand Down Expand Up @@ -331,7 +323,7 @@ class DetectorCreationInputRequest(BaseModel):
" MULTI_CLASS"
),
)
mode_configuration: Optional[Union[CountModeConfigurationSerializer, MultiClassModeConfigurationSerializer]] = None
mode_configuration: Optional[Union[CountModeConfiguration, MultiClassModeConfiguration]] = None


class ImageQuery(BaseModel):
Expand Down
4 changes: 2 additions & 2 deletions spec/public-api.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -819,8 +819,8 @@ components:
* `MULTI_CLASS` - MULTI_CLASS
mode_configuration:
oneOf:
- $ref: '#/components/schemas/CountModeConfigurationSerializer'
- $ref: '#/components/schemas/MultiClassModeConfigurationSerializer'
- $ref: '#/components/schemas/CountModeConfiguration'
- $ref: '#/components/schemas/MultiClassModeConfiguration'
nullable: true
required:
- name
Expand Down
21 changes: 11 additions & 10 deletions src/groundlight/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from groundlight_openapi_client.model.label_value_request import LabelValueRequest
from model import (
ROI,
BinaryClassificationResult,
Detector,
ImageQuery,
PaginatedDetectorList,
Expand Down Expand Up @@ -188,7 +189,7 @@ def _fixup_image_query(iq: ImageQuery) -> ImageQuery:
# Note: This might go away once we clean up the mapping logic server-side.

# we have to check that result is not None because the server will return a result of None if want_async=True
if iq.result is not None:
if isinstance(iq.result, BinaryClassificationResult):
iq.result.label = convert_internal_label_to_display(iq, iq.result.label)
return iq

Expand Down Expand Up @@ -245,15 +246,15 @@ def list_detectors(self, page: int = 1, page_size: int = 10) -> PaginatedDetecto
return PaginatedDetectorList.parse_obj(obj.to_dict())

def _prep_create_detector(
self,
name: str,
query: str,
*,
group_name: Optional[str] = None,
confidence_threshold: Optional[float] = None,
patience_time: Optional[float] = None,
pipeline_config: Optional[str] = None,
metadata: Union[dict, str, None] = None,
self,
name: str,
query: str,
*,
group_name: Optional[str] = None,
confidence_threshold: Optional[float] = None,
patience_time: Optional[float] = None,
pipeline_config: Optional[str] = None,
metadata: Union[dict, str, None] = None,
) -> Detector:
"""
A helper function to prepare the input for creating a detector. Individual create_detector
Expand Down
9 changes: 6 additions & 3 deletions src/groundlight/experimental_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from groundlight_openapi_client.model.b_box_geometry_request import BBoxGeometryRequest
from groundlight_openapi_client.model.channel_enum import ChannelEnum
from groundlight_openapi_client.model.condition_request import ConditionRequest
from groundlight_openapi_client.model.count_mode_configuration_serializer import CountModeConfigurationSerializer
from groundlight_openapi_client.model.count_mode_configuration import CountModeConfiguration
from groundlight_openapi_client.model.detector_group_request import DetectorGroupRequest
from groundlight_openapi_client.model.label_value_request import LabelValueRequest
from groundlight_openapi_client.model.roi_request import ROIRequest
Expand Down Expand Up @@ -333,7 +333,10 @@ def create_counting_detector(
metadata=metadata,
)
detector_creation_input.mode = ModeEnum.COUNT
mode_config = CountModeConfigurationSerializer(max_count=max_count)
# TODO: pull the BE defined default
if max_count is None:
max_count = 10
mode_config = CountModeConfiguration(max_count=max_count)
detector_creation_input.mode_configuration = mode_config
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
return Detector.parse_obj(obj.to_dict())
return Detector.parse_obj(obj.to_dict())
3 changes: 2 additions & 1 deletion test/unit/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def test_submit_multiple_rois(gl_experimental: ExperimentalApi, image_query_no:
roi = gl_experimental.create_roi(label_name, (0, 0), (0.5, 0.5))
gl_experimental.add_label(image_query_no, "YES", [roi] * 3)


def test_counting_detector(gl_experimental: ExperimentalApi):
"""
verify that we can create and submit to a counting detector
Expand All @@ -52,4 +53,4 @@ def test_counting_detector(gl_experimental: ExperimentalApi):
created_detector = gl_experimental.create_counting_detector(name, "How many dogs")
assert created_detector is not None
count_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
assert count_iq.count is not None
assert count_iq.result.count is not None

0 comments on commit f3e0535

Please sign in to comment.