diff --git a/assets/test.png b/assets/test.png
new file mode 100644
index 0000000000..5fcad46f3f
Binary files /dev/null and b/assets/test.png differ
diff --git a/dimos/models/vl/README.md b/dimos/models/vl/README.md
index 3a8353c69a..c252d47957 100644
--- a/dimos/models/vl/README.md
+++ b/dimos/models/vl/README.md
@@ -20,3 +20,48 @@ image = Image.from_file("path/to/your/image.jpg")
 response = model.query(image.data, "What do you see in this image?")
 print(response)
 ```
+
+## Moondream Hosted Model
+
+The `MoondreamHostedVlModel` class provides access to the hosted Moondream API for fast vision-language tasks.
+
+**Prerequisites:**
+
+You must export your API key before using the model:
+```bash
+export MOONDREAM_API_KEY="your_api_key_here"
+```
+
+### Capabilities
+
+The model supports four modes of operation:
+
+1. **Caption**: Generate a description of the image.
+2. **Query**: Ask natural language questions about the image.
+3. **Detect**: Find bounding boxes for specific objects.
+4. **Point**: Locate the center points of specific objects.
+
+### Example Usage
+
+```python
+from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel
+from dimos.msgs.sensor_msgs import Image
+
+model = MoondreamHostedVlModel()
+image = Image.from_file("path/to/image.jpg")
+
+# 1. Caption
+print(f"Caption: {model.caption(image)}")
+
+# 2. Query
+print(f"Answer: {model.query(image, 'Is there a person in the image?')}")
+
+# 3. Detect (returns ImageDetections2D)
+detections = model.query_detections(image, "person")
+for det in detections.detections:
+    print(f"Found person at {det.bbox}")
+
+# 4. Point (returns list of (x, y) coordinates)
+points = model.point(image, "person")
+print(f"Person centers: {points}")
+```
diff --git a/dimos/models/vl/__init__.py b/dimos/models/vl/__init__.py
index 8cb0a7944b..3ea4a28453 100644
--- a/dimos/models/vl/__init__.py
+++ b/dimos/models/vl/__init__.py
@@ -1,2 +1,4 @@
 from dimos.models.vl.base import VlModel
+from dimos.models.vl.moondream import MoondreamVlModel
+from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel
 from dimos.models.vl.qwen import QwenVlModel
diff --git a/dimos/models/vl/moondream_hosted.py b/dimos/models/vl/moondream_hosted.py
new file mode 100644
index 0000000000..1bf0f43e67
--- /dev/null
+++ b/dimos/models/vl/moondream_hosted.py
@@ -0,0 +1,133 @@
+import os
+import warnings
+from functools import cached_property
+
+import moondream as md
+import numpy as np
+from PIL import Image as PILImage
+
+from dimos.models.vl.base import VlModel
+from dimos.msgs.sensor_msgs import Image
+from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D
+
+
+class MoondreamHostedVlModel(VlModel):
+    _api_key: str | None
+
+    def __init__(self, api_key: str | None = None) -> None:
+        self._api_key = api_key
+
+    @cached_property
+    def _client(self) -> md.vl:
+        api_key = self._api_key or os.getenv("MOONDREAM_API_KEY")
+        if not api_key:
+            raise ValueError(
+                "Moondream API key must be provided or set in MOONDREAM_API_KEY environment variable"
+            )
+        return md.vl(api_key=api_key)
+
+    def _to_pil_image(self, image: Image | np.ndarray) -> PILImage.Image:
+        if isinstance(image, np.ndarray):
+            warnings.warn(
+                "MoondreamHostedVlModel should receive standard dimos Image type, not a numpy array",
+                DeprecationWarning,
+                stacklevel=3,
+            )
+            image = Image.from_numpy(image)
+
+        rgb_image = image.to_rgb()
+        return PILImage.fromarray(rgb_image.data)
+
+    def query(self, image: Image | np.ndarray, query: str, **kwargs) -> str:
+        pil_image = self._to_pil_image(image)
+
+        result = self._client.query(pil_image, query)
+        return result.get("answer", str(result))
+
+    def caption(self, image: Image | np.ndarray, length: str = "normal") -> str:
+        """Generate a caption for the image.
+
+        Args:
+            image: Input image
+            length: Caption length ("normal", "short", "long")
+        """
+        pil_image = self._to_pil_image(image)
+        result = self._client.caption(pil_image, length=length)
+        return result.get("caption", str(result))
+
+    def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetections2D:
+        """Detect objects using Moondream's hosted detect method.
+
+        Args:
+            image: Input image
+            query: Object query (e.g., "person", "car")
+            max_objects: Maximum number of objects to detect (not directly supported by hosted API args in docs,
+                but we handle the output)
+
+        Returns:
+            ImageDetections2D containing detected bounding boxes
+        """
+        pil_image = self._to_pil_image(image)
+
+        # API docs: detect(image, object) -> {"objects": [...]}
+        result = self._client.detect(pil_image, query)
+        objects = result.get("objects", [])
+
+        # Convert to ImageDetections2D
+        image_detections = ImageDetections2D(image)
+        height, width = image.height, image.width
+
+        for track_id, obj in enumerate(objects):
+            # Expected format from docs: Region with x_min, y_min, x_max, y_max
+            # Assuming normalized coordinates as per local model and standard VLM behavior
+            x_min_norm = obj.get("x_min", 0.0)
+            y_min_norm = obj.get("y_min", 0.0)
+            x_max_norm = obj.get("x_max", 1.0)
+            y_max_norm = obj.get("y_max", 1.0)
+
+            x1 = x_min_norm * width
+            y1 = y_min_norm * height
+            x2 = x_max_norm * width
+            y2 = y_max_norm * height
+
+            bbox = (x1, y1, x2, y2)
+
+            detection = Detection2DBBox(
+                bbox=bbox,
+                track_id=track_id,
+                class_id=-1,
+                confidence=1.0,
+                name=query,
+                ts=image.ts,
+                image=image,
+            )
+
+            if detection.is_valid():
+                image_detections.detections.append(detection)
+
+        return image_detections
+
+    def point(self, image: Image, query: str) -> list[tuple[float, float]]:
+        """Get coordinates of specific objects in an image.
+
+        Args:
+            image: Input image
+            query: Object query
+
+        Returns:
+            List of (x, y) pixel coordinates
+        """
+        pil_image = self._to_pil_image(image)
+        result = self._client.point(pil_image, query)
+        points = result.get("points", [])
+
+        pixel_points = []
+        height, width = image.height, image.width
+
+        for p in points:
+            x_norm = p.get("x", 0.0)
+            y_norm = p.get("y", 0.0)
+            pixel_points.append((x_norm * width, y_norm * height))
+
+        return pixel_points
+
diff --git a/dimos/models/vl/test_moondream_hosted.py b/dimos/models/vl/test_moondream_hosted.py
new file mode 100644
index 0000000000..dd18b993a6
--- /dev/null
+++ b/dimos/models/vl/test_moondream_hosted.py
@@ -0,0 +1,96 @@
+import os
+import time
+import pytest
+from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel
+from dimos.msgs.sensor_msgs import Image
+from dimos.perception.detection.type import ImageDetections2D
+
+# Skip all tests in this module if API key is missing
+pytestmark = pytest.mark.skipif(
+    not os.getenv("MOONDREAM_API_KEY"),
+    reason="MOONDREAM_API_KEY not set"
+)
+
+@pytest.fixture
+def model():
+    return MoondreamHostedVlModel()
+
+@pytest.fixture
+def test_image():
+    image_path = os.path.join(os.getcwd(), "assets/test.png")
+    if not os.path.exists(image_path):
+        pytest.skip(f"Test image not found at {image_path}")
+    return Image.from_file(image_path)
+
+def test_caption(model, test_image):
+    """Test generating a caption."""
+    print("\n--- Testing Caption ---")
+    caption = model.caption(test_image)
+    print(f"Caption: {caption}")
+    assert isinstance(caption, str)
+    assert len(caption) > 0
+
+def test_query(model, test_image):
+    """Test querying the image."""
+    print("\n--- Testing Query ---")
+    question = "Is there an xbox controller in the image?"
+    answer = model.query(test_image, question)
+    print(f"Question: {question}")
+    print(f"Answer: {answer}")
+    assert isinstance(answer, str)
+    assert len(answer) > 0
+    # The answer should likely be positive given the user's prompt
+    assert "yes" in answer.lower() or "controller" in answer.lower()
+
+def test_query_latency(model, test_image):
+    """Test that a simple query returns in under 1 second."""
+    print("\n--- Testing Query Latency ---")
+    question = "What is this?"
+
+    # Warmup (optional, but good practice if first call establishes connection)
+    # model.query(test_image, "warmup")
+
+    start_time = time.perf_counter()
+    model.query(test_image, question)
+    end_time = time.perf_counter()
+
+    duration = end_time - start_time
+    print(f"Query took {duration:.4f} seconds")
+
+    assert duration < 1.0, f"Query took too long: {duration:.4f}s > 1.0s"
+
+@pytest.mark.parametrize("subject", ["xbox controller", "lip balm"])
+def test_detect(model, test_image, subject):
+    """Test detecting objects."""
+    print(f"\n--- Testing Detect: {subject} ---")
+    detections = model.query_detections(test_image, subject)
+
+    assert isinstance(detections, ImageDetections2D)
+    print(f"Found {len(detections.detections)} detections for {subject}")
+
+    # We expect to find at least one of each in the provided test image
+    assert len(detections.detections) > 0
+
+    for det in detections.detections:
+        assert det.is_valid()
+        assert det.name == subject
+        # Check if bbox coordinates are within image dimensions
+        x1, y1, x2, y2 = det.bbox
+        assert 0 <= x1 < x2 <= test_image.width
+        assert 0 <= y1 < y2 <= test_image.height
+
+@pytest.mark.parametrize("subject", ["xbox controller", "lip balm"])
+def test_point(model, test_image, subject):
+    """Test pointing at objects."""
+    print(f"\n--- Testing Point: {subject} ---")
+    points = model.point(test_image, subject)
+
+    print(f"Found {len(points)} points for {subject}: {points}")
+    assert isinstance(points, list)
+    assert len(points) > 0
+
+    for x, y in points:
+        assert isinstance(x, (int, float))
+        assert isinstance(y, (int, float))
+        assert 0 <= x <= test_image.width
+        assert 0 <= y <= test_image.height
diff --git a/pyproject.toml b/pyproject.toml
index 1631baed36..4f2d866ffa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
     "openai",
     "anthropic>=0.19.0",
     "cerebras-cloud-sdk",
+    "moondream",
     "numpy>=1.26.4,<2.0.0",
     "colorlog==6.9.0",
     "yapf==0.40.2",