diff --git a/src/groundlight/experimental_api.py b/src/groundlight/experimental_api.py
index 98c3754e..9e6a8319 100644
--- a/src/groundlight/experimental_api.py
+++ b/src/groundlight/experimental_api.py
@@ -20,6 +20,7 @@
 from groundlight_openapi_client.api.image_queries_api import ImageQueriesApi
 from groundlight_openapi_client.api.notes_api import NotesApi
 from groundlight_openapi_client.model.action_request import ActionRequest
+from groundlight_openapi_client.model.bounding_box_mode_configuration import BoundingBoxModeConfiguration
 from groundlight_openapi_client.model.channel_enum import ChannelEnum
 from groundlight_openapi_client.model.condition_request import ConditionRequest
 from groundlight_openapi_client.model.count_mode_configuration import CountModeConfiguration
@@ -902,10 +903,12 @@ def create_counting_detector(  # noqa: PLR0913 # pylint: disable=too-many-argume
             metadata=metadata,
         )
         detector_creation_input.mode = ModeEnum.COUNT
-        # TODO: pull the BE defined default
+
         if max_count is None:
-            max_count = 10
-        mode_config = CountModeConfiguration(max_count=max_count, class_name=class_name)
+            mode_config = CountModeConfiguration(class_name=class_name)
+        else:
+            mode_config = CountModeConfiguration(max_count=max_count, class_name=class_name)
+
         detector_creation_input.mode_configuration = mode_config
         obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
         return Detector.parse_obj(obj.to_dict())
@@ -974,6 +977,81 @@ def create_multiclass_detector(  # noqa: PLR0913 # pylint: disable=too-many-argu
         obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
         return Detector.parse_obj(obj.to_dict())
 
+    def create_bounding_box_detector(  # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-locals
+        self,
+        name: str,
+        query: str,
+        class_name: str,
+        *,
+        max_num_bboxes: Optional[int] = None,
+        group_name: Optional[str] = None,
+        confidence_threshold: Optional[float] = None,
+        patience_time: Optional[float] = None,
+        pipeline_config: Optional[str] = None,
+        metadata: Union[dict, str, None] = None,
+    ) -> Detector:
+        """
+        Creates a bounding box detector that can detect objects in images, up to a specified maximum number of
+        bounding boxes.
+
+        **Example usage**::
+
+            gl = ExperimentalApi()
+
+            # Create a detector that finds up to 5 people
+            detector = gl.create_bounding_box_detector(
+                name="people_counter",
+                query="Draw a bounding box around each person in the image",
+                class_name="person",
+                max_num_bboxes=5,
+                confidence_threshold=0.9,
+                patience_time=30.0
+            )
+
+            # Use the detector to find people in an image
+            image_query = gl.ask_ml(detector, "path/to/image.jpg")
+            print(f"Confidence: {image_query.result.confidence}")
+            print(f"Bounding boxes: {image_query.result.rois}")
+
+        :param name: A short, descriptive name for the detector.
+        :param query: A question about the object to detect in the image.
+        :param class_name: The class name of the object to detect.
+        :param max_num_bboxes: Maximum number of bounding boxes to detect (default: 10)
+        :param group_name: Optional name of a group to organize related detectors together.
+        :param confidence_threshold: A value that sets the minimum confidence level required for the ML model's
+            predictions. If confidence is below this threshold, the query may be sent for human review.
+        :param patience_time: The maximum time in seconds that Groundlight will attempt to generate a
+            confident prediction before falling back to human review. Defaults to 30 seconds.
+        :param pipeline_config: Advanced usage only. Configuration string needed to instantiate a specific
+            prediction pipeline for this detector.
+        :param metadata: A dictionary or JSON string containing custom key/value pairs to associate with
+            the detector (limited to 1KB). This metadata can be used to store additional
+            information like location, purpose, or related system IDs. You can retrieve this
+            metadata later by calling `get_detector()`.
+
+        :return: The created Detector object
+        """
+
+        detector_creation_input = self._prep_create_detector(
+            name=name,
+            query=query,
+            group_name=group_name,
+            confidence_threshold=confidence_threshold,
+            patience_time=patience_time,
+            pipeline_config=pipeline_config,
+            metadata=metadata,
+        )
+        detector_creation_input.mode = ModeEnum.BOUNDING_BOX
+
+        if max_num_bboxes is None:
+            mode_config = BoundingBoxModeConfiguration(class_name=class_name)
+        else:
+            mode_config = BoundingBoxModeConfiguration(max_num_bboxes=max_num_bboxes, class_name=class_name)
+
+        detector_creation_input.mode_configuration = mode_config
+        obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
+        return Detector.parse_obj(obj.to_dict())
+
     def _download_mlbinary_url(self, detector: Union[str, Detector]) -> EdgeModelInfo:
         """
         Gets a temporary presigned URL to download the model binaries for the given detector, along
diff --git a/test/integration/test_groundlight.py b/test/integration/test_groundlight.py
index 545803ac..c51187e3 100644
--- a/test/integration/test_groundlight.py
+++ b/test/integration/test_groundlight.py
@@ -17,6 +17,7 @@
 from ksuid import KsuidMs
 from model import (
     BinaryClassificationResult,
+    BoundingBoxResult,
     CountingResult,
     Detector,
     ImageQuery,
@@ -35,6 +36,7 @@ def is_valid_display_result(result: Any) -> bool:
         not isinstance(result, BinaryClassificationResult)
         and not isinstance(result, CountingResult)
         and not isinstance(result, MultiClassificationResult)
+        and not isinstance(result, BoundingBoxResult)
     ):
         return False
 
diff --git a/test/unit/test_experimental.py b/test/unit/test_experimental.py
index 0b597649..ee5a7940 100644
--- a/test/unit/test_experimental.py
+++ b/test/unit/test_experimental.py
@@ -1,5 +1,5 @@
 import time
-from datetime import datetime
+from datetime import datetime, timezone
 
 import pytest
 from groundlight import ExperimentalApi
@@ -145,3 +145,52 @@ def test_multiclass_detector(gl_experimental: ExperimentalApi):
     mc_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
     assert mc_iq.result.label is not None
     assert mc_iq.result.label in class_names
+
+
+@pytest.mark.skip(
+    reason=(
+        "General users currently can't use bounding box detectors. If you have questions, reach out"
+        " to Groundlight support, or upgrade your plan."
+    )
+)
+def test_bounding_box_detector(gl_experimental: ExperimentalApi):
+    """
+    Verify that we can create and submit to a bounding box detector
+    """
+    name = f"Test {datetime.now(timezone.utc)}"
+    created_detector = gl_experimental.create_bounding_box_detector(
+        name, "Draw a bounding box around each dog in the image", "dog"
+    )
+    assert created_detector is not None
+    bbox_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
+    assert bbox_iq.result.label is not None
+    assert bbox_iq.rois is not None
+
+
+@pytest.mark.skip(
+    reason=(
+        "General users currently can't use bounding box detectors. If you have questions, reach out"
+        " to Groundlight support, or upgrade your plan."
+    )
+)
+def test_bounding_box_detector_async(gl_experimental: ExperimentalApi):
+    """
+    Verify that we can create and submit to a bounding box detector with ask_async
+    """
+    name = f"Test {datetime.now(timezone.utc)}"
+    created_detector = gl_experimental.create_bounding_box_detector(
+        name, "Draw a bounding box around each dog in the image", "dog"
+    )
+    assert created_detector is not None
+    async_iq = gl_experimental.ask_async(created_detector, "test/assets/dog.jpeg")
+
+    # attempting to access fields within the result should raise an exception
+    with pytest.raises(AttributeError):
+        _ = async_iq.result.label  # type: ignore
+    with pytest.raises(AttributeError):
+        _ = async_iq.result.confidence  # type: ignore
+
+    time.sleep(5)
+    # you should be able to get a "real" result by retrieving an updated image query object from the server
+    _image_query = gl_experimental.get_image_query(id=async_iq.id)
+    assert _image_query.result is not None
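
For reference, a minimal usage sketch of the new create_bounding_box_detector method added above. It assumes the standard ExperimentalApi client from this SDK; the detector name, image path, and printing each ROI object directly are illustrative choices, not guaranteed by this diff.

    from groundlight import ExperimentalApi

    gl = ExperimentalApi()

    # max_num_bboxes is optional; when omitted, the backend-defined default is used,
    # mirroring the CountModeConfiguration change in this diff.
    detector = gl.create_bounding_box_detector(
        name="dog_finder",  # illustrative name
        query="Draw a bounding box around each dog in the image",
        class_name="dog",
        confidence_threshold=0.9,
    )

    # Submit an image and inspect the bounding box result (path is illustrative).
    iq = gl.ask_ml(detector, "path/to/image.jpg")
    print(f"Confidence: {iq.result.confidence}")
    for roi in iq.rois or []:
        print(roi)  # one ROI per detected bounding box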