Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 81 additions & 3 deletions src/groundlight/experimental_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from groundlight_openapi_client.api.image_queries_api import ImageQueriesApi
from groundlight_openapi_client.api.notes_api import NotesApi
from groundlight_openapi_client.model.action_request import ActionRequest
from groundlight_openapi_client.model.bounding_box_mode_configuration import BoundingBoxModeConfiguration
from groundlight_openapi_client.model.channel_enum import ChannelEnum
from groundlight_openapi_client.model.condition_request import ConditionRequest
from groundlight_openapi_client.model.count_mode_configuration import CountModeConfiguration
Expand Down Expand Up @@ -902,10 +903,12 @@ def create_counting_detector( # noqa: PLR0913 # pylint: disable=too-many-argume
metadata=metadata,
)
detector_creation_input.mode = ModeEnum.COUNT
# TODO: pull the BE defined default

if max_count is None:
max_count = 10
mode_config = CountModeConfiguration(max_count=max_count, class_name=class_name)
mode_config = CountModeConfiguration(class_name=class_name)
else:
mode_config = CountModeConfiguration(max_count=max_count, class_name=class_name)

detector_creation_input.mode_configuration = mode_config
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
return Detector.parse_obj(obj.to_dict())
Expand Down Expand Up @@ -974,6 +977,81 @@ def create_multiclass_detector( # noqa: PLR0913 # pylint: disable=too-many-argu
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
return Detector.parse_obj(obj.to_dict())

def create_bounding_box_detector( # noqa: PLR0913 # pylint: disable=too-many-arguments, too-many-locals
self,
name: str,
query: str,
class_name: str,
*,
max_num_bboxes: Optional[int] = None,
group_name: Optional[str] = None,
confidence_threshold: Optional[float] = None,
patience_time: Optional[float] = None,
pipeline_config: Optional[str] = None,
metadata: Union[dict, str, None] = None,
) -> Detector:
"""
Creates a bounding box detector that can detect objects in images up to a specified maximum number of bounding
boxes.

**Example usage**::

gl = ExperimentalApi()

# Create a detector that counts people up to 5
detector = gl.create_bounding_box_detector(
name="people_counter",
query="Draw a bounding box around each person in the image",
class_name="person",
max_num_bboxes=5,
confidence_threshold=0.9,
patience_time=30.0
)

# Use the detector to find people in an image
image_query = gl.ask_ml(detector, "path/to/image.jpg")
print(f"Confidence: {image_query.result.confidence}")
print(f"Bounding boxes: {image_query.result.rois}")

:param name: A short, descriptive name for the detector.
:param query: A question about the object to detect in the image.
:param class_name: The class name of the object to detect.
:param max_num_bboxes: Maximum number of bounding boxes to detect (default: 10)
:param group_name: Optional name of a group to organize related detectors together.
:param confidence_threshold: A value that sets the minimum confidence level required for the ML model's
predictions. If confidence is below this threshold, the query may be sent for human review.
:param patience_time: The maximum time in seconds that Groundlight will attempt to generate a
confident prediction before falling back to human review. Defaults to 30 seconds.
:param pipeline_config: Advanced usage only. Configuration string needed to instantiate a specific
prediction pipeline for this detector.
:param metadata: A dictionary or JSON string containing custom key/value pairs to associate with
the detector (limited to 1KB). This metadata can be used to store additional
information like location, purpose, or related system IDs. You can retrieve this
metadata later by calling `get_detector()`.

:return: The created Detector object
"""

detector_creation_input = self._prep_create_detector(
name=name,
query=query,
group_name=group_name,
confidence_threshold=confidence_threshold,
patience_time=patience_time,
pipeline_config=pipeline_config,
metadata=metadata,
)
detector_creation_input.mode = ModeEnum.BOUNDING_BOX

if max_num_bboxes is None:
mode_config = BoundingBoxModeConfiguration(class_name=class_name)
else:
mode_config = BoundingBoxModeConfiguration(max_num_bboxes=max_num_bboxes, class_name=class_name)

detector_creation_input.mode_configuration = mode_config
obj = self.detectors_api.create_detector(detector_creation_input, _request_timeout=DEFAULT_REQUEST_TIMEOUT)
return Detector.parse_obj(obj.to_dict())

def _download_mlbinary_url(self, detector: Union[str, Detector]) -> EdgeModelInfo:
"""
Gets a temporary presigned URL to download the model binaries for the given detector, along
Expand Down
2 changes: 2 additions & 0 deletions test/integration/test_groundlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ksuid import KsuidMs
from model import (
BinaryClassificationResult,
BoundingBoxResult,
CountingResult,
Detector,
ImageQuery,
Expand All @@ -35,6 +36,7 @@ def is_valid_display_result(result: Any) -> bool:
not isinstance(result, BinaryClassificationResult)
and not isinstance(result, CountingResult)
and not isinstance(result, MultiClassificationResult)
and not isinstance(result, BoundingBoxResult)
):
return False

Expand Down
51 changes: 50 additions & 1 deletion test/unit/test_experimental.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from datetime import datetime
from datetime import datetime, timezone

import pytest
from groundlight import ExperimentalApi
Expand Down Expand Up @@ -145,3 +145,52 @@ def test_multiclass_detector(gl_experimental: ExperimentalApi):
mc_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
assert mc_iq.result.label is not None
assert mc_iq.result.label in class_names


@pytest.mark.skip(
reason=(
"General users currently currently can't use bounding box detectors. If you have questions, reach out"
" to Groundlight support, or upgrade your plan."
)
)
def test_bounding_box_detector(gl_experimental: ExperimentalApi):
"""
Verify that we can create and submit to a bounding box detector
"""
name = f"Test {datetime.now(timezone.utc)}"
created_detector = gl_experimental.create_bounding_box_detector(
name, "Draw a bounding box around each dog in the image", "dog"
)
assert created_detector is not None
bbox_iq = gl_experimental.submit_image_query(created_detector, "test/assets/dog.jpeg")
assert bbox_iq.result.label is not None
assert bbox_iq.rois is not None


@pytest.mark.skip(
reason=(
"General users currently currently can't use bounding box detectors. If you have questions, reach out"
" to Groundlight support, or upgrade your plan."
)
)
def test_bounding_box_detector_async(gl_experimental: ExperimentalApi):
"""
Verify that we can create and submit to a bounding box detector with ask_async
"""
name = f"Test {datetime.now(timezone.utc)}"
created_detector = gl_experimental.create_bounding_box_detector(
name, "Draw a bounding box around each dog in the image", "dog"
)
assert created_detector is not None
async_iq = gl_experimental.ask_async(created_detector, "test/assets/dog.jpeg")

# attempting to access fields within the result should raise an exception
with pytest.raises(AttributeError):
_ = async_iq.result.label # type: ignore
with pytest.raises(AttributeError):
_ = async_iq.result.confidence # type: ignore

time.sleep(5)
# you should be able to get a "real" result by retrieving an updated image query object from the server
_image_query = gl_experimental.get_image_query(id=async_iq.id)
assert _image_query.result is not None
Loading