Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added assets/test.png
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably be in LFS. You can make it similar to data/.lfs/cafe.jpg.tar.gz to get it from LFS with get_data.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
45 changes: 45 additions & 0 deletions dimos/models/vl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,48 @@ image = Image.from_file("path/to/your/image.jpg")
response = model.query(image.data, "What do you see in this image?")
print(response)
```

## Moondream Hosted Model

The `MoondreamHostedVlModel` class provides access to the hosted Moondream API for fast vision-language tasks.

**Prerequisites:**

You must export your API key before using the model:
```bash
export MOONDREAM_API_KEY="your_api_key_here"
```

### Capabilities

The model supports four modes of operation:

1. **Caption**: Generate a description of the image.
2. **Query**: Ask natural language questions about the image.
3. **Detect**: Find bounding boxes for specific objects.
4. **Point**: Locate the center points of specific objects.

### Example Usage

```python
from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel
from dimos.msgs.sensor_msgs import Image

model = MoondreamHostedVlModel()
image = Image.from_file("path/to/image.jpg")

# 1. Caption
print(f"Caption: {model.caption(image)}")

# 2. Query
print(f"Answer: {model.query(image, 'Is there a person in the image?')}")

# 3. Detect (returns ImageDetections2D)
detections = model.query_detections(image, "person")
for det in detections.detections:
print(f"Found person at {det.bbox}")

# 4. Point (returns list of (x, y) coordinates)
points = model.point(image, "person")
print(f"Person centers: {points}")
```
2 changes: 2 additions & 0 deletions dimos/models/vl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from dimos.models.vl.base import VlModel
from dimos.models.vl.moondream import MoondreamVlModel
from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel
from dimos.models.vl.qwen import QwenVlModel
133 changes: 133 additions & 0 deletions dimos/models/vl/moondream_hosted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import os
import warnings
from functools import cached_property

import moondream as md
import numpy as np
from PIL import Image as PILImage

from dimos.models.vl.base import VlModel
from dimos.msgs.sensor_msgs import Image
from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D


class MoondreamHostedVlModel(VlModel):
    """Vision-language model backed by the hosted Moondream API.

    Supports captioning, free-form querying, object detection, and object
    pointing. An API key must be supplied to the constructor or exported in
    the MOONDREAM_API_KEY environment variable; the client is created lazily
    on first use.
    """

    # Explicit key from the constructor; None means "read the env var later".
    _api_key: str | None

    def __init__(self, api_key: str | None = None) -> None:
        self._api_key = api_key

    @cached_property
    def _client(self) -> md.vl:
        """Create (once) and return the hosted Moondream client.

        Raises:
            ValueError: if no key was passed and MOONDREAM_API_KEY is unset.
        """
        api_key = self._api_key or os.getenv("MOONDREAM_API_KEY")
        if not api_key:
            raise ValueError(
                "Moondream API key must be provided or set in MOONDREAM_API_KEY environment variable"
            )
        return md.vl(api_key=api_key)

    def _to_pil_image(self, image: Image | np.ndarray) -> PILImage.Image:
        """Convert a dimos Image (or, deprecated, a raw numpy array) to RGB PIL."""
        if isinstance(image, np.ndarray):
            # Raw arrays are still accepted for backward compatibility, but
            # callers should migrate to the dimos Image type.
            warnings.warn(
                "MoondreamHostedVlModel should receive standard dimos Image type, not a numpy array",
                DeprecationWarning,
                stacklevel=3,
            )
            image = Image.from_numpy(image)

        rgb_image = image.to_rgb()
        return PILImage.fromarray(rgb_image.data)

    def query(self, image: Image | np.ndarray, query: str, **kwargs) -> str:
        """Ask a natural-language question about the image.

        Args:
            image: Input image
            query: Question to ask about the image

        Returns:
            The model's answer text; falls back to the stringified raw
            response if the expected "answer" key is missing.
        """
        pil_image = self._to_pil_image(image)

        result = self._client.query(pil_image, query)
        return result.get("answer", str(result))

    def caption(self, image: Image | np.ndarray, length: str = "normal") -> str:
        """Generate a caption for the image.

        Args:
            image: Input image
            length: Caption length ("normal", "short", "long")

        Returns:
            The caption text; falls back to the stringified raw response if
            the expected "caption" key is missing.
        """
        pil_image = self._to_pil_image(image)
        result = self._client.caption(pil_image, length=length)
        return result.get("caption", str(result))

    def query_detections(self, image: Image, query: str, **kwargs) -> ImageDetections2D:
        """Detect objects using Moondream's hosted detect method.

        Args:
            image: Input image
            query: Object query (e.g., "person", "car")

        Returns:
            ImageDetections2D containing detected bounding boxes, scaled to
            pixel coordinates. The hosted API is assumed to return
            normalized [0, 1] coordinates, matching the local model's
            behavior — TODO confirm against the API docs.
        """
        pil_image = self._to_pil_image(image)

        # API docs: detect(image, object) -> {"objects": [...]}
        result = self._client.detect(pil_image, query)
        objects = result.get("objects", [])

        image_detections = ImageDetections2D(image)
        height, width = image.height, image.width

        for track_id, obj in enumerate(objects):
            # Each object is a region with x_min/y_min/x_max/y_max keys;
            # missing keys default to the full image extent.
            x1 = obj.get("x_min", 0.0) * width
            y1 = obj.get("y_min", 0.0) * height
            x2 = obj.get("x_max", 1.0) * width
            y2 = obj.get("y_max", 1.0) * height

            detection = Detection2DBBox(
                bbox=(x1, y1, x2, y2),
                track_id=track_id,
                # The hosted API reports neither class ids nor confidences,
                # so use sentinel values and label with the query string.
                class_id=-1,
                confidence=1.0,
                name=query,
                ts=image.ts,
                image=image,
            )

            if detection.is_valid():
                image_detections.detections.append(detection)

        return image_detections

    def point(self, image: Image, query: str) -> list[tuple[float, float]]:
        """Get coordinates of specific objects in an image.

        Args:
            image: Input image
            query: Object query

        Returns:
            List of (x, y) pixel coordinates (normalized API output scaled
            by the image dimensions).
        """
        pil_image = self._to_pil_image(image)
        result = self._client.point(pil_image, query)
        points = result.get("points", [])

        height, width = image.height, image.width
        return [
            (p.get("x", 0.0) * width, p.get("y", 0.0) * height)
            for p in points
        ]

96 changes: 96 additions & 0 deletions dimos/models/vl/test_moondream_hosted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import os
import time
import pytest
from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel
from dimos.msgs.sensor_msgs import Image
from dimos.perception.detection.type import ImageDetections2D

# Without a hosted-API key none of these tests can run; skip the whole module.
pytestmark = pytest.mark.skipif(
    not os.getenv("MOONDREAM_API_KEY"),
    reason="MOONDREAM_API_KEY not set",
)

@pytest.fixture
def model():
    """Hosted Moondream model instance shared by the tests below."""
    return MoondreamHostedVlModel()

@pytest.fixture
def test_image():
    """Load the checked-in test asset, skipping when it is absent."""
    image_path = os.path.join(os.getcwd(), "assets/test.png")
    if os.path.exists(image_path):
        return Image.from_file(image_path)
    pytest.skip(f"Test image not found at {image_path}")

def test_caption(model, test_image):
    """Captioning returns a non-empty string."""
    print("\n--- Testing Caption ---")
    result = model.caption(test_image)
    print(f"Caption: {result}")
    assert isinstance(result, str)
    assert len(result) > 0

def test_query(model, test_image):
    """Querying the image yields a non-empty, plausibly affirmative answer."""
    print("\n--- Testing Query ---")
    prompt = "Is there an xbox controller in the image?"
    reply = model.query(test_image, prompt)
    print(f"Question: {prompt}")
    print(f"Answer: {reply}")
    assert isinstance(reply, str)
    assert len(reply) > 0
    # The test image contains a controller, so expect an affirmative answer.
    lowered = reply.lower()
    assert "yes" in lowered or "controller" in lowered

def test_query_latency(model, test_image):
    """A simple hosted query should round-trip in under 1 second.

    NOTE(review): this asserts on wall-clock latency of a network call, so
    it is inherently flaky on slow connections or under API load — consider
    marking it as a slow/perf test or relaxing the threshold in CI.
    """
    print("\n--- Testing Query Latency ---")
    question = "What is this?"

    start_time = time.perf_counter()
    model.query(test_image, question)
    duration = time.perf_counter() - start_time

    print(f"Query took {duration:.4f} seconds")

    assert duration < 1.0, f"Query took too long: {duration:.4f}s > 1.0s"

@pytest.mark.parametrize("subject", ["xbox controller", "lip balm"])
def test_detect(model, test_image, subject):
    """Detection returns at least one valid, in-bounds box per known object."""
    print(f"\n--- Testing Detect: {subject} ---")
    result = model.query_detections(test_image, subject)

    assert isinstance(result, ImageDetections2D)
    print(f"Found {len(result.detections)} detections for {subject}")

    # Both parametrized objects appear in the checked-in test image.
    assert len(result.detections) > 0

    width, height = test_image.width, test_image.height
    for detection in result.detections:
        assert detection.is_valid()
        assert detection.name == subject
        x1, y1, x2, y2 = detection.bbox
        # Boxes must be non-degenerate and lie fully inside the image.
        assert 0 <= x1 < x2 <= width
        assert 0 <= y1 < y2 <= height

@pytest.mark.parametrize("subject", ["xbox controller", "lip balm"])
def test_point(model, test_image, subject):
    """Pointing returns in-bounds numeric (x, y) pixel coordinates."""
    print(f"\n--- Testing Point: {subject} ---")
    result = model.point(test_image, subject)

    print(f"Found {len(result)} points for {subject}: {result}")
    assert isinstance(result, list)
    assert len(result) > 0

    for px, py in result:
        assert isinstance(px, (int, float))
        assert isinstance(py, (int, float))
        # Point centers must land inside the image bounds.
        assert 0 <= px <= test_image.width
        assert 0 <= py <= test_image.height
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [
"openai",
"anthropic>=0.19.0",
"cerebras-cloud-sdk",
"moondream",
"numpy>=1.26.4,<2.0.0",
"colorlog==6.9.0",
"yapf==0.40.2",
Expand Down
Loading