From 7e99d3fb2a8566162b0d8f52459c5692cb780614 Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Fri, 9 Jan 2026 11:29:15 -0800 Subject: [PATCH 01/21] temporal memory + vlm agent + blueprints --- dimos/agents/temp/webcam_agent.py | 2 +- dimos/agents/vlm_agent.py | 21 +- dimos/core/test_blueprints.py | 12 +- dimos/models/vl/__init__.py | 2 + dimos/models/vl/base.py | 71 ++- dimos/models/vl/openai.py | 159 +++++ dimos/models/vl/qwen.py | 56 ++ dimos/perception/clip_filter.py | 178 ++++++ dimos/perception/temporal_memory.py | 551 ++++++++++++++++++ dimos/perception/temporal_memory_example.py | 123 ++++ .../perception/test_temporal_memory_module.py | 231 ++++++++ dimos/perception/videorag_utils.py | 457 +++++++++++++++ dimos/robot/all_blueprints.py | 1 + dimos/spec/perception.py | 5 - dimos/stream/video_operators.py | 2 +- 15 files changed, 1844 insertions(+), 27 deletions(-) create mode 100644 dimos/models/vl/openai.py create mode 100644 dimos/perception/clip_filter.py create mode 100644 dimos/perception/temporal_memory.py create mode 100644 dimos/perception/temporal_memory_example.py create mode 100644 dimos/perception/test_temporal_memory_module.py create mode 100644 dimos/perception/videorag_utils.py diff --git a/dimos/agents/temp/webcam_agent.py b/dimos/agents/temp/webcam_agent.py index b09ec2e1d8..98ae0a903b 100644 --- a/dimos/agents/temp/webcam_agent.py +++ b/dimos/agents/temp/webcam_agent.py @@ -115,7 +115,7 @@ def main() -> None: ), hardware=lambda: Webcam( camera_index=0, - fps=15, + frequency=15, stereo_slice="left", camera_info=zed.CameraInfo.SingleWebcam, ), diff --git a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py index 2600a7ab50..542fb4a180 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from langchain_core.messages import AIMessage, HumanMessage +from typing import Any + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from dimos.agents.llm_init import build_llm, build_system_message from dimos.agents.spec import AgentSpec, AnyMessage @@ -71,15 +73,20 @@ def _extract_text(self, msg: HumanMessage) -> str: return str(part.get("text", "")) return str(content) - def _invoke(self, msg: HumanMessage) -> AIMessage: + def _invoke(self, msg: HumanMessage, **kwargs: Any) -> AIMessage: messages = [self._system_message, msg] - response = self._llm.invoke(messages) + response = self._llm.invoke(messages, **kwargs) self.append_history([msg, response]) # type: ignore[arg-type] return response # type: ignore[return-value] - def _invoke_image(self, image: Image, query: str) -> AIMessage: + def _invoke_image( + self, image: Image, query: str, response_format: dict | None = None + ) -> AIMessage: content = [{"type": "text", "text": query}, *image.agent_encode()] - return self._invoke(HumanMessage(content=content)) + kwargs: dict[str, Any] = {} + if response_format: + kwargs["response_format"] = response_format + return self._invoke(HumanMessage(content=content), **kwargs) @rpc def clear_history(self): # type: ignore[no-untyped-def] @@ -110,8 +117,8 @@ def query(self, query: str): # type: ignore[no-untyped-def] return response.content @rpc - def query_image(self, image: Image, query: str): # type: ignore[no-untyped-def] - response = self._invoke_image(image, query) + def query_image(self, image: Image, query: str, response_format: dict | None = None): # type: ignore[no-untyped-def] + response = self._invoke_image(image, query, response_format=response_format) return response.content diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index 7a99a23abe..54313f1a84 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -27,7 +27,6 @@ autoconnect, ) from dimos.core.core import rpc -from 
dimos.core.global_config import GlobalConfig from dimos.core.module import Module from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.rpc_client import RpcCall @@ -35,11 +34,6 @@ from dimos.core.transport import LCMTransport from dimos.protocol import pubsub -# Disable Rerun for tests (prevents viewer spawn and gRPC flush errors) -_BUILD_WITHOUT_RERUN = { - "global_config": GlobalConfig(rerun_enabled=False, viewer_backend="foxglove"), -} - class Scratch: pass @@ -167,7 +161,7 @@ def test_build_happy_path() -> None: blueprint_set = autoconnect(module_a(), module_b(), module_c()) - coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) + coordinator = blueprint_set.build() try: assert isinstance(coordinator, ModuleCoordinator) @@ -303,7 +297,7 @@ class TargetModule(Module): assert ("color_image", Data1) not in blueprint_set._all_name_types # Build and verify connections work - coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) + coordinator = blueprint_set.build() try: source_instance = coordinator.get_instance(SourceModule) @@ -356,7 +350,7 @@ def test_future_annotations_autoconnect() -> None: blueprint_set = autoconnect(FutureModuleOut.blueprint(), FutureModuleIn.blueprint()) - coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) + coordinator = blueprint_set.build() try: out_instance = coordinator.get_instance(FutureModuleOut) diff --git a/dimos/models/vl/__init__.py b/dimos/models/vl/__init__.py index 6f120f9141..e4bb68e03c 100644 --- a/dimos/models/vl/__init__.py +++ b/dimos/models/vl/__init__.py @@ -2,6 +2,7 @@ from dimos.models.vl.florence import Florence2Model from dimos.models.vl.moondream import MoondreamVlModel from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel +from dimos.models.vl.openai import OpenAIVlModel from dimos.models.vl.qwen import QwenVlModel __all__ = [ @@ -9,6 +10,7 @@ "Florence2Model", "MoondreamHostedVlModel", "MoondreamVlModel", + "OpenAIVlModel", "QwenVlModel", "VlModel", ] diff 
--git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index 93caba4de7..a9cc80978f 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -2,11 +2,16 @@ from dataclasses import dataclass import json import logging +from typing import Any import warnings from dimos.core.resource import Resource from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D +from dimos.perception.detection.type import ( + Detection2DBBox, + Detection2DPoint, + ImageDetections2D, +) from dimos.protocol.service import Configurable # type: ignore[attr-defined] from dimos.utils.data import get_data from dimos.utils.decorators import retry @@ -82,7 +87,9 @@ def vlm_detection_to_detection2d( try: coords = [float(vlm_detection[i]) for i in range(1, 5)] except (ValueError, TypeError) as e: - logger.debug(f"Invalid VLM detection coordinates: {vlm_detection[1:]}. Error: {e}") + logger.debug( + f"Invalid VLM detection coordinates: {vlm_detection[1:]}. Error: {e}" + ) return None bbox = (coords[0], coords[1], coords[2], coords[3]) @@ -125,7 +132,9 @@ def vlm_point_to_detection2d_point( return None if len(vlm_point) != 3: - logger.debug(f"Invalid VLM point length: {len(vlm_point)}, expected 3. Got: {vlm_point}") + logger.debug( + f"Invalid VLM point length: {len(vlm_point)}, expected 3. Got: {vlm_point}" + ) return None # Extract label @@ -169,6 +178,15 @@ class VlModel(Captioner, Resource, Configurable[VlModelConfig]): default_config = VlModelConfig config: VlModelConfig + @abstractmethod + def is_set_up(self) -> None: + """Verify that the VLM is properly configured (e.g., API key is set). + + Raises: + ValueError: If the VLM is not properly configured + """ + ... + def _prepare_image(self, image: Image) -> tuple[Image, float]: """Prepare image for inference, applying any configured transformations. 
@@ -180,9 +198,50 @@ def _prepare_image(self, image: Image) -> tuple[Image, float]: return image.resize_to_fit(max_w, max_h) return image, 1.0 + def __getstate__(self) -> dict[str, Any]: + """Exclude unpicklable attributes when serializing. + + Subclasses should override to handle their own unpicklable attributes + (e.g., API clients, cached properties). + """ + state = self.__dict__.copy() + # Remove common unpicklable attributes (may not exist in all subclasses) + state.pop("_client", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + """Restore object from pickled state. + + Subclasses should override to reinitialize their own unpicklable attributes + and reload any necessary configuration (e.g., API keys from environment). + """ + self.__dict__.update(state) + # Clear cached properties that may have been removed + if "_client" in self.__dict__: + del self.__dict__["_client"] + @abstractmethod def query(self, image: Image, query: str, **kwargs) -> str: ... # type: ignore[no-untyped-def] + def query_multi_images(self, images: list[Image], query: str, **kwargs) -> str: # type: ignore[no-untyped-def] + """Query VLM with multiple images in a single request. + + This is useful for temporal reasoning across multiple frames or + multi-view analysis. The VLM can see all images together and reason + about relationships between them. + + Subclasses must override this method for models that support multi-image input + (e.g., GPT-4V, Qwen). + + Args: + images: List of input images (e.g., frames from a video window) + query: Question to ask about all images together + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not support multi-image queries. " + "Subclasses must override query_multi_images() to provide this functionality." + ) + def query_batch(self, images: list[Image], query: str, **kwargs) -> list[str]: # type: ignore[no-untyped-def] """Query multiple images with the same question. 
@@ -329,7 +388,11 @@ def query_points( for track_id, point_tuple in enumerate(point_tuples): # Scale coordinates back to original image size if resized - if scale != 1.0 and isinstance(point_tuple, (list, tuple)) and len(point_tuple) == 3: + if ( + scale != 1.0 + and isinstance(point_tuple, (list, tuple)) + and len(point_tuple) == 3 + ): point_tuple = [ point_tuple[0], # label point_tuple[1] / scale, # x diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py new file mode 100644 index 0000000000..5907774561 --- /dev/null +++ b/dimos/models/vl/openai.py @@ -0,0 +1,159 @@ +from dataclasses import dataclass +from functools import cached_property +import os +from typing import Any + +import numpy as np +from openai import OpenAI + +from dimos.models.vl.base import VlModel, VlModelConfig +from dimos.msgs.sensor_msgs import Image +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +@dataclass +class OpenAIVlModelConfig(VlModelConfig): + model_name: str = "gpt-4o-mini" + api_key: str | None = None + + +class OpenAIVlModel(VlModel): + default_config = OpenAIVlModelConfig + config: OpenAIVlModelConfig + + def is_set_up(self) -> None: + """ + Verify that OpenAI API key is configured. 
+ """ + api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable" + ) + + def __getstate__(self) -> dict[str, Any]: + """Exclude unpicklable attributes when serializing.""" + state = super().__getstate__() + # _client is already removed by base class, but ensure it's gone + state.pop("_client", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + """Restore object from pickled state and reload API key if needed.""" + super().__setstate__(state) + + # Reload API key from environment if config doesn't have it + # This is important when unpickling on Dask workers where env vars may differ + if not self.config.api_key: + api_key = os.getenv("OPENAI_API_KEY") + if api_key: + self.config.api_key = api_key + + # Verify setup (will raise ValueError if API key is still missing) + self.is_set_up() + + @cached_property + def _client(self) -> OpenAI: + api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable" + ) + + return OpenAI(api_key=api_key) + + def query(self, image: Image | np.ndarray, query: str, response_format: dict | None = None, **kwargs) -> str: # type: ignore[override, type-arg, no-untyped-def] + if isinstance(image, np.ndarray): + import warnings + + warnings.warn( + "OpenAIVlModel.query should receive standard dimos Image type, not a numpy array", + DeprecationWarning, + stacklevel=2, + ) + + image = Image.from_numpy(image) + + # Apply auto_resize if configured + image, _ = self._prepare_image(image) + + img_base64 = image.to_base64() + + api_kwargs: dict[str, Any] = { + "model": self.config.model_name, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_base64}"}, + }, + {"type": "text", "text": query}, + ], + 
} + ], + } + + if response_format: + api_kwargs["response_format"] = response_format + + response = self._client.chat.completions.create(**api_kwargs) + + return response.choices[0].message.content # type: ignore[return-value] + + def query_multi_images( + self, images: list[Image], query: str, response_format: dict | None = None, **kwargs + ) -> str: # type: ignore[no-untyped-def, override] + """Query VLM with multiple images (for temporal/multi-view reasoning). + + Args: + images: List of images to analyze together + query: Question about all images + response_format: Optional response format for structured output + - {"type": "json_object"} for JSON mode + - {"type": "json_schema", "json_schema": {...}} for schema enforcement + + Returns: + Response from the model + """ + if not images: + raise ValueError("Must provide at least one image") + + # Build content with multiple images + content: list[dict] = [] # type: ignore[type-arg] + + # Add all images first + for img in images: + # Apply auto_resize if configured + prepared_img, _ = self._prepare_image(img) + img_base64 = prepared_img.to_base64() + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_base64}"}, + } + ) + + # Add query text last + content.append({"type": "text", "text": query}) + + # Build messages + messages = [{"role": "user", "content": content}] + + # Call API with optional response_format + api_kwargs: dict[str, Any] = {"model": self.config.model_name, "messages": messages} + if response_format: + api_kwargs["response_format"] = response_format + + response = self._client.chat.completions.create(**api_kwargs) + + return response.choices[0].message.content # type: ignore[return-value] + + def stop(self) -> None: + """Release the OpenAI client.""" + if "_client" in self.__dict__: + del self.__dict__["_client"] + diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index b1d3d6f036..f7e3c9c733 100644 --- a/dimos/models/vl/qwen.py +++ 
b/dimos/models/vl/qwen.py @@ -21,6 +21,16 @@ class QwenVlModel(VlModel): default_config = QwenVlModelConfig config: QwenVlModelConfig + def is_set_up(self) -> None: + """ + Verify that Alibaba API key is configured. + """ + api_key = self.config.api_key or os.getenv("ALIBABA_API_KEY") + if not api_key: + raise ValueError( + "Alibaba API key must be provided or set in ALIBABA_API_KEY environment variable" + ) + @cached_property def _client(self) -> OpenAI: api_key = self.config.api_key or os.getenv("ALIBABA_API_KEY") @@ -69,6 +79,52 @@ def query(self, image: Image | np.ndarray, query: str) -> str: # type: ignore[o return response.choices[0].message.content # type: ignore[return-value] + def query_multi_images(self, images: list[Image], query: str, response_format: dict | None = None) -> str: # type: ignore[no-untyped-def, override] + """Query VLM with multiple images (for temporal/multi-view reasoning). + + Args: + images: List of images to analyze together + query: Question about all images + response_format: Optional response format for structured output + - {"type": "json_object"} for JSON mode + - {"type": "json_schema", "json_schema": {...}} for schema enforcement + + Returns: + Response from the model + """ + if not images: + raise ValueError("Must provide at least one image") + + # Build content with multiple images + content: list[dict] = [] # type: ignore[type-arg] + + # Add all images first + for img in images: + # Apply auto_resize if configured + prepared_img, _ = self._prepare_image(img) + img_base64 = prepared_img.to_base64() + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_base64}"}, + } + ) + + # Add query text last + content.append({"type": "text", "text": query}) + + # Build messages + messages = [{"role": "user", "content": content}] + + # Call API with optional response_format + api_kwargs = {"model": self.config.model_name, "messages": messages} + if response_format: + api_kwargs["response_format"] 
= response_format + + response = self._client.chat.completions.create(**api_kwargs) + + return response.choices[0].message.content # type: ignore[return-value] + def stop(self) -> None: """Release the OpenAI client.""" if "_client" in self.__dict__: diff --git a/dimos/perception/clip_filter.py b/dimos/perception/clip_filter.py new file mode 100644 index 0000000000..24e43df1e2 --- /dev/null +++ b/dimos/perception/clip_filter.py @@ -0,0 +1,178 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +CLIP-based frame filtering for selecting diverse frames from video windows. + +Adapted from videorag/clip_filter.py - uses CLIP embeddings to select the most +visually diverse frames from a window, reducing VLM costs while maintaining coverage. +""" + +import logging +from typing import Any + +from dimos.msgs.sensor_msgs import Image +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +# Optional CLIP imports +try: + import clip + from PIL import Image as PILImage + import torch + + CLIP_AVAILABLE = True +except ImportError: + CLIP_AVAILABLE = False + logger.warning( + "CLIP not available. Install with: pip install torch torchvision openai-clip. " + "Frame filtering will fall back to simple sampling." 
+ ) + + +class CLIPFrameFilter: + """Filter video frames using CLIP embeddings to select diverse frames.""" + + def __init__(self, model_name: str = "ViT-B/32", device: str | None = None): + """ + Initialize CLIP frame filter. + + Args: + model_name: CLIP model name (e.g., "ViT-B/32", "ViT-L/14") + device: Device to use ("cuda", "cpu", or None for auto-detect) + """ + if not CLIP_AVAILABLE: + raise ImportError( + "CLIP is not available. Install with: pip install torch torchvision openai-clip" + ) + + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Loading CLIP model {model_name} on {self.device}") + self.model, self.preprocess = clip.load(model_name, device=self.device) + logger.info("CLIP model loaded successfully") + + def _image_to_pil(self, image: Image) -> PILImage.Image: + """Convert dimos Image to PIL Image.""" + # Get numpy array from dimos Image + img_array = image.data # Assumes Image has .data attribute with numpy array + + # Convert to PIL + return PILImage.fromarray(img_array) + + def _encode_images(self, images: list[Image]) -> torch.Tensor: + """Encode images using CLIP. + + Args: + images: List of dimos Images + + Returns: + Tensor of normalized CLIP embeddings, shape (N, embedding_dim) + """ + # Convert to PIL and preprocess + pil_images = [self._image_to_pil(img) for img in images] + preprocessed = [self.preprocess(img) for img in pil_images] + + # Stack and encode + image_tensor = torch.stack(preprocessed).to(self.device) + with torch.no_grad(): + embeddings = self.model.encode_image(image_tensor) + # Normalize embeddings + embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) + + return embeddings + + def select_diverse_frames(self, frames: list[Any], max_frames: int = 3) -> list[Any]: + """ + Select diverse frames using greedy farthest-point sampling. + + This selects frames that are maximally different from each other in CLIP + embedding space, ensuring good visual coverage of the window. 
+ + Algorithm: + 1. Always include first frame (temporal anchor) + 2. Iteratively select frame most different from already-selected frames + 3. Continue until we have max_frames frames + + Args: + frames: List of Frame objects with .image attribute + max_frames: Maximum number of frames to select + + Returns: + List of selected Frame objects (subset of input frames) + """ + if len(frames) <= max_frames: + return frames + + # Extract images from frames + images = [f.image for f in frames] + + # Encode all images + embeddings = self._encode_images(images) + + # Greedy farthest-point sampling + selected_indices = [0] # Always include first frame + remaining_indices = list(range(1, len(frames))) + + while len(selected_indices) < max_frames and remaining_indices: + selected_embs = embeddings[selected_indices] + remaining_embs = embeddings[remaining_indices] + + # Compute similarities between remaining and selected + # Shape: (num_remaining, num_selected) + similarities = remaining_embs @ selected_embs.T + + # For each remaining frame, find its max similarity to any selected frame + # Shape: (num_remaining,) + max_similarities = similarities.max(dim=1)[0] + + # Select frame with minimum max similarity (most different from all selected) + best_idx = max_similarities.argmin().item() + + selected_indices.append(remaining_indices[best_idx]) + remaining_indices.pop(best_idx) + + # Return frames in temporal order (sorted by index) + return [frames[i] for i in sorted(selected_indices)] + + def close(self) -> None: + """Clean up CLIP model.""" + if hasattr(self, "model"): + del self.model + if hasattr(self, "preprocess"): + del self.preprocess + + +def select_diverse_frames_simple(frames: list[Any], max_frames: int = 3) -> list[Any]: + """ + Fallback frame selection when CLIP is not available. + + Uses simple uniform sampling across the window. 
+ + Args: + frames: List of Frame objects + max_frames: Maximum number of frames to select + + Returns: + List of selected Frame objects + """ + if len(frames) <= max_frames: + return frames + + # Sample uniformly across window + indices = [int(i * len(frames) / max_frames) for i in range(max_frames)] + return [frames[i] for i in indices] + + +__all__ = ["CLIP_AVAILABLE", "CLIPFrameFilter", "select_diverse_frames_simple"] diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py new file mode 100644 index 0000000000..c5c76b91d2 --- /dev/null +++ b/dimos/perception/temporal_memory.py @@ -0,0 +1,551 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Temporal Memory module for creating entity-based temporal understanding of video streams. + +This module implements a sophisticated temporal memory system inspired by VideoRAG, +using VLM (Vision-Language Model) API calls to maintain entity rosters, rolling summaries, +and temporal relationships across video frames. 
+""" + +from collections import deque +from dataclasses import dataclass +import json +import os +from pathlib import Path +import threading +import time +from typing import Any + +from reactivex import interval +from reactivex.disposable import Disposable + +from dimos import spec +from dimos.agents import skill +from dimos.core import DimosCluster, In, rpc +from dimos.core.skill_module import SkillModule +from dimos.models.vl.base import VlModel +from dimos.msgs.sensor_msgs import Image +from dimos.perception.clip_filter import ( + CLIP_AVAILABLE, + CLIPFrameFilter, + select_diverse_frames_simple, +) +from dimos.perception.videorag_utils import ( + apply_summary_update, + build_query_prompt, + build_summary_prompt, + build_window_prompt, + get_structured_output_format, + parse_window_response, + update_state_from_window, +) +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +@dataclass +class Frame: + frame_index: int + timestamp_s: float + image: Image + + +@dataclass +class TemporalMemoryConfig: + fps: float = 1.0 + window_s: float = 2.0 + stride_s: float = 2.0 + summary_interval_s: float = 10.0 + max_frames_per_window: int = 3 + frame_buffer_size: int = 300 + output_dir: str | Path | None = None + max_tokens: int = 900 + temperature: float = 0.2 + use_clip_filtering: bool = True + clip_model: str = "ViT-B/32" + + +def default_state() -> dict[str, Any]: + return { + "entity_roster": [], + "rolling_summary": "", + "chunk_buffer": [], + "next_summary_at_s": 0.0, + "last_present": [], + } + + +class TemporalMemory(SkillModule): + """ + builds temporal understanding of video streams using vlms. + + processes frames reactively, maintains entity rosters, tracks temporal + relationships, builds rolling summaries. responds to queries about current + state and recent events. 
+ """ + + color_image: In[Image] + + def __init__( + self, vlm: VlModel | None = None, config: TemporalMemoryConfig | None = None + ) -> None: + super().__init__() + + self._vlm = vlm # Can be None for blueprint usage + self.config = config or TemporalMemoryConfig() + + # single lock protects all state + self._state_lock = threading.Lock() + + # protected state + self._state = default_state() + self._state["next_summary_at_s"] = float(self.config.summary_interval_s) + self._frame_buffer: deque[Frame] = deque(maxlen=self.config.frame_buffer_size) + self._recent_windows: deque[dict[str, Any]] = deque(maxlen=50) + self._frame_count = 0 + self._last_analysis_time = 0.0 + self._video_start_wall_time: float | None = None + + # clip filter + self._clip_filter: CLIPFrameFilter | None = None + if self.config.use_clip_filtering and CLIP_AVAILABLE: + try: + self._clip_filter = CLIPFrameFilter(model_name=self.config.clip_model) + logger.info("clip filtering enabled") + except Exception as e: + logger.warning(f"clip init failed: {e}") + self.config.use_clip_filtering = False + elif self.config.use_clip_filtering: + logger.warning("clip not available") + self.config.use_clip_filtering = False + + # output directory + if self.config.output_dir: + self._output_path = Path(self.config.output_dir) + self._output_path.mkdir(parents=True, exist_ok=True) + self._evidence_file = self._output_path / "evidence.jsonl" + self._state_file = self._output_path / "state.json" + self._entities_file = self._output_path / "entities.json" + self._frames_index_file = self._output_path / "frames_index.jsonl" + logger.info(f"artifacts save to: {self._output_path}") + + logger.info( + f"temporalmemory init: fps={self.config.fps}, " + f"window={self.config.window_s}s, stride={self.config.stride_s}s" + ) + + @property + def vlm(self) -> VlModel: + """Get or create VLM instance lazily.""" + if self._vlm is None: + from dimos.models.vl.openai import OpenAIVlModel + + # Load API key from environment + 
api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OPENAI_API_KEY environment variable not set. " + "Either set it or pass a vlm instance to TemporalMemory constructor." + ) + self._vlm = OpenAIVlModel(api_key=api_key) + logger.info("Created OpenAIVlModel from OPENAI_API_KEY environment variable") + return self._vlm + + def __getstate__(self) -> dict[str, Any]: + state = super().__getstate__() + if state is None: + # Parent doesn't implement __getstate__, so we need to manually exclude unpicklable attrs + state = self.__dict__.copy() + # Remove unpicklable attributes + state.pop("_disposables", None) + state.pop("_loop", None) + state.pop("_state_lock", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + # Let parent restore what it needs + if hasattr(super(), "__setstate__"): + super().__setstate__(state) + else: + self.__dict__.update(state) + # Recreate unpicklable attributes + self._state_lock = threading.Lock() + + @rpc + def start(self) -> None: + super().start() + + with self._state_lock: + if self._video_start_wall_time is None: + self._video_start_wall_time = time.time() + + def on_frame(image: Image) -> None: + with self._state_lock: + video_start = self._video_start_wall_time + if image.ts is not None: + timestamp_s = image.ts - video_start + else: + timestamp_s = time.time() - video_start + + frame = Frame( + frame_index=self._frame_count, + timestamp_s=timestamp_s, + image=image, + ) + self._frame_buffer.append(frame) + self._frame_count += 1 + + unsub_image = self.color_image.subscribe(on_frame) + self._disposables.add(Disposable(unsub_image)) + + self._disposables.add( + interval(self.config.stride_s).subscribe(lambda _: self._analyze_window()) + ) + + logger.info("temporalmemory started") + + @rpc + def stop(self) -> None: + self.save_state() + self.save_entities() + self.save_frames_index() + + if self._clip_filter: + self._clip_filter.close() + self._clip_filter = None + + # Stop all stream 
transports to clean up LCM/shared memory threads + for stream in list(self.inputs.values()) + list(self.outputs.values()): + if stream.transport is not None and hasattr(stream.transport, "stop"): + stream.transport.stop() + stream._transport = None + + super().stop() + logger.info("temporalmemory stopped") + + def _format_timestamp(self, seconds: float) -> str: + m = int(seconds // 60) + s = seconds - 60 * m + return f"{m:02d}:{s:06.3f}" + + def _analyze_window(self) -> None: + try: + # get snapshot + with self._state_lock: + if not self._frame_buffer: + return + current_time = self._frame_buffer[-1].timestamp_s + if current_time - self._last_analysis_time < self.config.stride_s: + return + + frames_needed = max(1, int(self.config.fps * self.config.window_s)) + if len(self._frame_buffer) < frames_needed: + return + + window_frames = list(self._frame_buffer)[-frames_needed:] + state_snapshot = self._state.copy() + + w_start = window_frames[0].timestamp_s + w_end = window_frames[-1].timestamp_s + + # filter frames + if len(window_frames) > self.config.max_frames_per_window: + if self._clip_filter: + window_frames = self._clip_filter.select_diverse_frames( + window_frames, max_frames=self.config.max_frames_per_window + ) + else: + window_frames = select_diverse_frames_simple( + window_frames, max_frames=self.config.max_frames_per_window + ) + + logger.info(f"analyzing [{w_start:.1f}-{w_end:.1f}s] with {len(window_frames)} frames") + + # build prompt + query = build_window_prompt( + w_start=w_start, + w_end=w_end, + frame_count=len(window_frames), + state=state_snapshot, + ) + + # query vlm (slow, outside lock) + # use middle frame for window analysis + try: + middle_frame = window_frames[len(window_frames) // 2] + response_format = get_structured_output_format() + response_text = self._vlm.query( + middle_frame.image, query, response_format=response_format + ) + except Exception as e: + logger.error(f"vlm agent query failed [{w_start:.1f}-{w_end:.1f}s]: {e}") + with 
self._state_lock: + self._last_analysis_time = w_end + return + + # parse response + parsed = parse_window_response(response_text, w_start, w_end, len(window_frames)) + + if "_error" in parsed: + logger.error(f"parse error: {parsed['_error']}") + else: + logger.info(f"parsed. caption: {parsed.get('caption', '')[:100]}") + + # update state + with self._state_lock: + needs_summary = update_state_from_window( + self._state, parsed, w_end, self.config.summary_interval_s + ) + self._recent_windows.append(parsed) + self._last_analysis_time = w_end + + # save evidence + if self.config.output_dir: + self._append_evidence(parsed) + + # update summary if needed + if needs_summary: + logger.info(f"updating summary at t≈{w_end:.1f}s") + self._update_rolling_summary(w_end) + + # periodic save + with self._state_lock: + window_count = len(self._recent_windows) + + if window_count % 10 == 0: + self.save_state() + self.save_entities() + + except Exception as e: + logger.error(f"error analyzing window: {e}", exc_info=True) + + def _update_rolling_summary(self, w_end: float) -> None: + try: + # get state + with self._state_lock: + rolling_summary = str(self._state.get("rolling_summary", "")) + chunk_buffer = list(self._state.get("chunk_buffer", [])) + if self._frame_buffer: + latest_frame = self._frame_buffer[-1].image + else: + latest_frame = None + + if not chunk_buffer or not latest_frame: + return + + # build prompt + prompt = build_summary_prompt( + rolling_summary=rolling_summary, + chunk_windows=chunk_buffer, + ) + + # query vlm (slow, outside lock) + try: + summary_text = self._vlm.query(latest_frame, prompt) + if summary_text and summary_text.strip(): + with self._state_lock: + apply_summary_update( + self._state, summary_text, w_end, self.config.summary_interval_s + ) + logger.info(f"updated summary: {summary_text[:100]}...") + except Exception as e: + logger.error(f"summary update failed: {e}", exc_info=True) + + except Exception as e: + logger.error(f"error updating 
summary: {e}", exc_info=True) + + @skill() + def query(self, question: str) -> str: + """ + Answer a question about the video stream using temporal memory. + + Args: + question: Question to ask about the video stream + + Returns: + Answer based on temporal memory state and video context + """ + # read state + with self._state_lock: + entity_roster = list(self._state.get("entity_roster", [])) + rolling_summary = str(self._state.get("rolling_summary", "")) + last_present = list(self._state.get("last_present", [])) + recent_windows = list(self._recent_windows) + latest_frame = self._frame_buffer[-1].image if self._frame_buffer else None + + if not latest_frame: + return "no frames available" + + # build context + currently_present = {e["id"] for e in last_present if isinstance(e, dict) and "id" in e} + for window in recent_windows[-3:]: + for entity in window.get("entities_present", []): + if isinstance(entity, dict) and isinstance(entity.get("id"), str): + currently_present.add(entity["id"]) + + context = { + "entity_roster": entity_roster, + "rolling_summary": rolling_summary, + "currently_present_entities": sorted(currently_present), + "recent_windows_count": len(recent_windows), + "timestamp": time.time(), + } + + # build query prompt using videorag utils + prompt = build_query_prompt(question=question, context=context) + + # query vlm (slow, outside lock) + try: + answer_text = self.vlm.query(latest_frame, prompt) + return answer_text.strip() + except Exception as e: + logger.error(f"query failed: {e}", exc_info=True) + return f"error: {e}" + + def clear_history(self) -> None: + """Clear temporal memory state.""" + with self._state_lock: + self._state = default_state() + self._state["next_summary_at_s"] = float(self.config.summary_interval_s) + self._recent_windows.clear() + logger.info("cleared history") + + def get_state(self) -> dict[str, Any]: + with self._state_lock: + return { + "entity_count": len(self._state.get("entity_roster", [])), + "entities": 
list(self._state.get("entity_roster", [])), + "rolling_summary": str(self._state.get("rolling_summary", "")), + "frame_count": self._frame_count, + "buffer_size": len(self._frame_buffer), + "recent_windows": len(self._recent_windows), + "currently_present": list(self._state.get("last_present", [])), + } + + def get_entity_roster(self) -> list[dict[str, Any]]: + with self._state_lock: + return list(self._state.get("entity_roster", [])) + + def get_rolling_summary(self) -> str: + with self._state_lock: + return str(self._state.get("rolling_summary", "")) + + def save_state(self) -> bool: + if not self.config.output_dir: + return False + try: + with self._state_lock: + state_copy = self._state.copy() + with open(self._state_file, "w") as f: + json.dump(state_copy, f, indent=2, ensure_ascii=False) + logger.info(f"saved state to {self._state_file}") + return True + except Exception as e: + logger.error(f"save state failed: {e}", exc_info=True) + return False + + def _append_evidence(self, evidence: dict[str, Any]) -> None: + try: + with open(self._evidence_file, "a") as f: + f.write(json.dumps(evidence, ensure_ascii=False) + "\n") + except Exception as e: + logger.error(f"append evidence failed: {e}") + + def save_entities(self) -> bool: + if not self.config.output_dir: + return False + try: + with self._state_lock: + entity_roster = list(self._state.get("entity_roster", [])) + with open(self._entities_file, "w") as f: + json.dump(entity_roster, f, indent=2, ensure_ascii=False) + logger.info(f"saved {len(entity_roster)} entities") + return True + except Exception as e: + logger.error(f"save entities failed: {e}", exc_info=True) + return False + + def save_frames_index(self) -> bool: + if not self.config.output_dir: + return False + try: + with self._state_lock: + frames = list(self._frame_buffer) + + frames_index = [ + { + "frame_index": f.frame_index, + "timestamp_s": f.timestamp_s, + "timestamp": self._format_timestamp(f.timestamp_s), + } + for f in frames + ] + + if 
frames_index: + with open(self._frames_index_file, "w", encoding="utf-8") as f: + for rec in frames_index: + f.write(json.dumps(rec, ensure_ascii=False) + "\n") + logger.info(f"saved {len(frames_index)} frames") + return True + except Exception as e: + logger.error(f"save frames failed: {e}", exc_info=True) + return False + + +def deploy( + dimos: DimosCluster, + camera: spec.Camera, + vlm: VlModel | None = None, + config: TemporalMemoryConfig | None = None, +) -> TemporalMemory: + """ + Deploy TemporalMemory with a camera. + + Args: + dimos: DimosCluster instance + camera: Camera module + vlm: Optional VlModel instance (will create OpenAIVlModel if None) + config: Optional TemporalMemoryConfig + + Returns: + Deployed TemporalMemory module + """ + if vlm is None: + from dimos.models.vl.openai import OpenAIVlModel + + # Load API key from environment + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError("OPENAI_API_KEY environment variable not set") + vlm = OpenAIVlModel(api_key=api_key) + + temporal_memory = dimos.deploy(TemporalMemory, vlm=vlm, config=config) # type: ignore[attr-defined] + + if camera.color_image.transport is None: + from dimos.core.transport import JpegShmTransport + + transport = JpegShmTransport("/temporal_memory/color_image") + camera.color_image.transport = transport + + temporal_memory.color_image.connect(camera.color_image) + temporal_memory.start() + return temporal_memory + + +temporal_memory = TemporalMemory.blueprint + +__all__ = ["Frame", "TemporalMemory", "TemporalMemoryConfig", "deploy", "temporal_memory"] diff --git a/dimos/perception/temporal_memory_example.py b/dimos/perception/temporal_memory_example.py new file mode 100644 index 0000000000..59f669e758 --- /dev/null +++ b/dimos/perception/temporal_memory_example.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# Copyright 2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example usage of TemporalMemory module with a VLM. + +This example demonstrates how to: +1. Deploy a camera module +2. Deploy TemporalMemory with the camera +3. Query the temporal memory about entities and events +""" + +import os +from pathlib import Path + +from dotenv import load_dotenv + +from dimos import core +from dimos.hardware.sensors.camera.module import CameraModule +from dimos.hardware.sensors.camera.webcam import Webcam +from dimos.perception.temporal_memory import ( + TemporalMemoryConfig, + deploy, +) + +# Load environment variables +load_dotenv() + + +def example_usage(): + """Example of how to use TemporalMemory.""" + # Initialize variables to None for cleanup + temporal_memory = None + camera = None + dimos = None + + try: + # Create Dimos cluster + dimos = core.start(1) + # Deploy camera module + camera = dimos.deploy(CameraModule, hardware=lambda: Webcam(camera_index=0)) + camera.start() + + # Deploy temporal memory using the deploy function + output_dir = Path("./temporal_memory_output") + temporal_memory = deploy( + dimos, + camera, + vlm=None, # Will auto-create OpenAIVlModel if None + config=TemporalMemoryConfig( + fps=1.0, # Process 1 frame per second + window_s=2.0, # Analyze 2-second windows + stride_s=2.0, # New window every 2 seconds + summary_interval_s=10.0, # Update rolling summary every 10 seconds + max_frames_per_window=3, # Max 3 frames per window + output_dir=output_dir, + ), + 
) + + print("TemporalMemory deployed and started!") + print(f"Artifacts will be saved to: {output_dir}") + + # Let it run for a bit to build context + print("Building temporal context... (wait ~15 seconds)") + import time + + time.sleep(15) + + # Query the temporal memory + questions = [ + "What entities are currently visible?", + "What has happened in the last few seconds?", + "Are there any people in the scene?", + "Describe the main activity happening now", + ] + + for question in questions: + print(f"\nQuestion: {question}") + answer = temporal_memory.query(question) + print(f"Answer: {answer}") + + # Get current state + state = temporal_memory.get_state() + print("\n=== Current State ===") + print(f"Entity count: {state['entity_count']}") + print(f"Frame count: {state['frame_count']}") + print(f"Rolling summary: {state['rolling_summary']}") + print(f"Entities: {state['entities']}") + + # Get entity roster + entities = temporal_memory.get_entity_roster() + print("\n=== Entity Roster ===") + for entity in entities: + print(f" {entity['id']}: {entity['descriptor']}") + + # Stop when done + temporal_memory.stop() + camera.stop() + print("\nTemporalMemory stopped") + + finally: + if temporal_memory is not None: + temporal_memory.stop() + if camera is not None: + camera.stop() + if dimos is not None: + dimos.close_all() + + +if __name__ == "__main__": + example_usage() diff --git a/dimos/perception/test_temporal_memory_module.py b/dimos/perception/test_temporal_memory_module.py new file mode 100644 index 0000000000..45750fd139 --- /dev/null +++ b/dimos/perception/test_temporal_memory_module.py @@ -0,0 +1,231 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import pathlib +import tempfile +import time + +from dotenv import load_dotenv +import pytest +from reactivex import operators as ops + +from dimos import core +from dimos.core import Module, Out, rpc +from dimos.models.vl.openai import OpenAIVlModel +from dimos.msgs.sensor_msgs import Image +from dimos.perception.temporal_memory import TemporalMemory, TemporalMemoryConfig +from dimos.protocol import pubsub +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay + +# Load environment variables +load_dotenv() + +logger = setup_logger() + +pubsub.lcm.autoconf() + + +class VideoReplayModule(Module): + """Module that replays video data from TimedSensorReplay.""" + + video_out: Out[Image] + + def __init__(self, video_path: str) -> None: + super().__init__() + self.video_path = video_path + + @rpc + def start(self) -> None: + """Start replaying video data.""" + # Use TimedSensorReplay to replay video frames + video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy) + + # Subscribe to the replay stream and publish to LCM + self._disposables.add( + video_replay.stream() + .pipe( + ops.sample(1), # Sample every 1 second + ops.take(10), # Only take 10 frames total + ) + .subscribe(self.video_out.publish) + ) + + logger.info("VideoReplayModule started") + + @rpc + def stop(self) -> None: + """Stop replaying video data.""" + # Stop all stream transports to clean up LCM loop threads + for stream in list(self.outputs.values()): + if 
stream.transport is not None and hasattr(stream.transport, "stop"): + stream.transport.stop() + stream._transport = None + super().stop() + logger.info("VideoReplayModule stopped") + + +@pytest.mark.lcm +@pytest.mark.gpu +@pytest.mark.module +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.") +class TestTemporalMemoryModule: + @pytest.fixture(scope="function") + def temp_dir(self): + """Create a temporary directory for test data.""" + temp_dir = tempfile.mkdtemp(prefix="temporal_memory_test_") + yield temp_dir + + @pytest.fixture(scope="function") + def dimos_cluster(self): + """Create and cleanup Dimos cluster.""" + dimos = core.start(1) + yield dimos + dimos.close_all() + + @pytest.fixture(scope="function") + def video_module(self, dimos_cluster): + """Create and cleanup video replay module.""" + data_path = get_data("unitree_office_walk") + video_path = os.path.join(data_path, "video") + video_module = dimos_cluster.deploy(VideoReplayModule, video_path) + video_module.video_out.transport = core.LCMTransport("/test_video", Image) + yield video_module + try: + video_module.stop() + except Exception as e: + logger.warning(f"Failed to stop video_module: {e}") + + @pytest.fixture(scope="function") + def temporal_memory(self, dimos_cluster, temp_dir): + """Create and cleanup temporal memory module.""" + output_dir = os.path.join(temp_dir, "temporal_memory_output") + # Create OpenAIVlModel instance + api_key = os.getenv("OPENAI_API_KEY") + vlm = OpenAIVlModel(api_key=api_key) + + temporal_memory = dimos_cluster.deploy( + TemporalMemory, + vlm=vlm, + config=TemporalMemoryConfig( + fps=1.0, # Process 1 frame per second + window_s=2.0, # Analyze 2-second windows + stride_s=2.0, # New window every 2 seconds + summary_interval_s=10.0, # Update rolling summary every 10 seconds + max_frames_per_window=3, # Max 3 frames per window + output_dir=output_dir, + ), + ) + yield temporal_memory + try: + temporal_memory.stop() + except Exception as 
e: + logger.warning(f"Failed to stop temporal_memory: {e}") + + @pytest.mark.asyncio + async def test_temporal_memory_module_with_replay( + self, dimos_cluster, video_module, temporal_memory, temp_dir + ): + """Test TemporalMemory module with TimedSensorReplay inputs.""" + # Connect streams + temporal_memory.color_image.connect(video_module.video_out) + + # Start all modules + video_module.start() + temporal_memory.start() + logger.info("All modules started, processing in background...") + + # Wait for frames to be processed with timeout + timeout = 15.0 # 15 second timeout + start_time = time.time() + + # Keep checking state while modules are running + while (time.time() - start_time) < timeout: + state = temporal_memory.get_state() + if state["frame_count"] > 0: + logger.info( + f"Frames processing - Frame count: {state['frame_count']}, " + f"Buffer size: {state['buffer_size']}, " + f"Entity count: {state['entity_count']}" + ) + if state["frame_count"] >= 3: # Wait for at least 3 frames + break + await asyncio.sleep(0.5) + else: + # Timeout reached + state = temporal_memory.get_state() + logger.error( + f"Timeout after {timeout}s - Frame count: {state['frame_count']}, " + f"Buffer size: {state['buffer_size']}" + ) + raise AssertionError(f"No frames processed within {timeout} seconds") + + await asyncio.sleep(3) # Wait for more processing + + # Test get_state() RPC method + mid_state = temporal_memory.get_state() + logger.info( + f"Mid-test state - Frame count: {mid_state['frame_count']}, " + f"Entity count: {mid_state['entity_count']}, " + f"Recent windows: {mid_state['recent_windows']}" + ) + assert mid_state["frame_count"] >= state["frame_count"], ( + "Frame count should increase or stay same" + ) + + # Test query() RPC method + answer = temporal_memory.query("What entities are currently visible?") + logger.info(f"Query result: {answer[:200]}...") + assert len(answer) > 0, "Query should return a non-empty answer" + + # Test get_entity_roster() RPC method + 
entities = temporal_memory.get_entity_roster() + logger.info(f"Entity roster has {len(entities)} entities") + assert isinstance(entities, list), "Entity roster should be a list" + + # Test get_rolling_summary() RPC method + summary = temporal_memory.get_rolling_summary() + logger.info(f"Rolling summary: {summary[:200] if summary else 'empty'}...") + assert isinstance(summary, str), "Rolling summary should be a string" + + final_state = temporal_memory.get_state() + logger.info( + f"Final state - Frame count: {final_state['frame_count']}, " + f"Entity count: {final_state['entity_count']}, " + f"Recent windows: {final_state['recent_windows']}" + ) + + video_module.stop() + temporal_memory.stop() + logger.info("Stopped modules") + + # Wait a bit for file operations to complete + await asyncio.sleep(0.5) + + # Verify files were created - stop() already saved them + output_dir = os.path.join(temp_dir, "temporal_memory_output") + output_path = pathlib.Path(output_dir) + assert output_path.exists(), f"Output directory should exist: {output_dir}" + assert (output_path / "state.json").exists(), "state.json should exist" + assert (output_path / "entities.json").exists(), "entities.json should exist" + assert (output_path / "frames_index.jsonl").exists(), "frames_index.jsonl should exist" + + logger.info("All temporal memory module tests passed!") + + +if __name__ == "__main__": + pytest.main(["-v", "-s", __file__]) diff --git a/dimos/perception/videorag_utils.py b/dimos/perception/videorag_utils.py new file mode 100644 index 0000000000..393e79cb3e --- /dev/null +++ b/dimos/perception/videorag_utils.py @@ -0,0 +1,457 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +VideoRAG utilities for temporal memory - adapted from videorag/evidence.py + +This module ports the sophisticated prompts and logic from VideoRAG for use +with dimos's VlModel abstraction instead of OpenAI API directly. +""" + +import json +from typing import Any + +from dimos.utils.llm_utils import extract_json + + +def next_entity_id_hint(roster: Any) -> str: + """Generate next entity ID based on existing roster (e.g., E1, E2, E3...).""" + if not isinstance(roster, list): + return "E1" + max_n = 0 + for e in roster: + if not isinstance(e, dict): + continue + eid = e.get("id") + if isinstance(eid, str) and eid.startswith("E"): + tail = eid[1:] + if tail.isdigit(): + max_n = max(max_n, int(tail)) + return f"E{max_n + 1}" + + +def clamp_text(text: str, max_chars: int) -> str: + """Clamp text to maximum characters.""" + if len(text) <= max_chars: + return text + return text[:max_chars] + "..." + + +def build_window_prompt( + *, + w_start: float, + w_end: float, + frame_count: int, + state: dict[str, Any], +) -> str: + """ + Build comprehensive VLM prompt for analyzing a video window. + + This is adapted from videorag's build_window_messages() but formatted + as a single text prompt for VlModel.query() instead of OpenAI's messages format. + + Args: + w_start: Window start time in seconds + w_end: Window end time in seconds + frame_count: Number of frames in this window + state: Current temporal memory state (entity_roster, rolling_summary, etc.) 
+ + Returns: + Formatted prompt string + """ + roster = state.get("entity_roster", []) + rolling_summary = state.get("rolling_summary", "") + next_id = next_entity_id_hint(roster) + + # System instructions (from VideoRAG) + system_context = """You analyze short sequences of video frames. +You must stay grounded in what is visible. +Do not identify real people or guess names/identities; describe people anonymously. +Extract general entities (people, objects, screens, text, locations) and relations between them. +Use stable entity IDs like E1, E2 based on the provided roster.""" + + # Main prompt (from VideoRAG's build_window_messages) + prompt = f"""{system_context} + +Time window: [{w_start:.3f}, {w_end:.3f}) seconds +Number of frames: {frame_count} + +Existing entity roster (may be empty): +{json.dumps(roster, ensure_ascii=False)} + +Rolling summary so far (may be empty): +{clamp_text(str(rolling_summary), 1500)} + +Task: +1) Write a dense, grounded caption describing what is visible across the frames in this time window. +2) Identify which existing roster entities appear in these frames. +3) Add any new salient entities (people/objects/screens/text/locations) with a short grounded descriptor. +4) Extract grounded relations/events between entities (e.g., looks_at, holds, uses, walks_past, speaks_to (inferred)). + +New entity IDs must start at: {next_id} + +Rules (important): +- You MUST stay grounded in what is visible in the provided frames. +- You MUST NOT mention any entity ID unless it appears in the provided roster OR you include it in new_entities in this same output. +- If the roster is empty, introduce any salient entities you reference (start with E1, E2, ...). +- Do not invent on-screen text: only include text you can read. +- If a relation is inferred (e.g., speaks_to without audio), include it but lower confidence and explain the visual cues. 
+ +Output JSON ONLY with this schema: +{{ + "window": {{"start_s": {w_start:.3f}, "end_s": {w_end:.3f}}}, + "caption": "dense grounded description", + "entities_present": [{{"id": "E1", "confidence": 0.0-1.0}}], + "new_entities": [{{"id": "E3", "type": "person|object|screen|text|location|other", "descriptor": "..."}}], + "relations": [ + {{ + "type": "speaks_to|looks_at|holds|uses|moves|gesture|scene_change|other", + "subject": "E1|unknown", + "object": "E2|unknown", + "confidence": 0.0-1.0, + "evidence": ["describe which frames show this"], + "notes": "short, grounded" + }} + ], + "on_screen_text": ["verbatim snippets"], + "uncertainties": ["things that are unclear"], + "confidence": 0.0-1.0 +}} +""" + return prompt + + +# JSON schema for window responses (from VideoRAG) +WINDOW_RESPONSE_SCHEMA = { + "type": "object", + "properties": { + "window": { + "type": "object", + "properties": {"start_s": {"type": "number"}, "end_s": {"type": "number"}}, + "required": ["start_s", "end_s"], + }, + "caption": {"type": "string"}, + "entities_present": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "confidence": {"type": "number", "minimum": 0, "maximum": 1}, + }, + "required": ["id"], + }, + }, + "new_entities": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "type": { + "type": "string", + "enum": ["person", "object", "screen", "text", "location", "other"], + }, + "descriptor": {"type": "string"}, + }, + "required": ["id", "type"], + }, + }, + "relations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"type": "string"}, + "subject": {"type": "string"}, + "object": {"type": "string"}, + "confidence": {"type": "number", "minimum": 0, "maximum": 1}, + "evidence": {"type": "array", "items": {"type": "string"}}, + "notes": {"type": "string"}, + }, + "required": ["type", "subject", "object"], + }, + }, + "on_screen_text": {"type": 
"array", "items": {"type": "string"}}, + "uncertainties": {"type": "array", "items": {"type": "string"}}, + "confidence": {"type": "number", "minimum": 0, "maximum": 1}, + }, + "required": ["window", "caption"], +} + + +def build_summary_prompt( + *, + rolling_summary: str, + chunk_windows: list[dict[str, Any]], +) -> str: + """ + Build prompt for updating rolling summary. + + This is adapted from videorag's build_summary_messages() but formatted + as a single text prompt for VlModel.query(). + + Args: + rolling_summary: Current rolling summary text + chunk_windows: List of recent window results to incorporate + + Returns: + Formatted prompt string + """ + # System context (from VideoRAG) + system_context = """You summarize timestamped video-window logs into a concise rolling summary. +Stay grounded in the provided window captions/relations. +Do not invent entities or rename entity IDs; preserve IDs like E1, E2 exactly. +You MAY incorporate new entity IDs if they appear in the provided chunk windows (e.g., in new_entities). +Be concise, but keep relevant entity continuity and key relations.""" + + prompt = f"""{system_context} + +Update the rolling summary using the newest chunk. + +Previous rolling summary (may be empty): +{clamp_text(rolling_summary, 2500)} + +New chunk windows (JSON): +{json.dumps(chunk_windows, ensure_ascii=False)} + +Output a concise summary as PLAIN TEXT (no JSON, no code fences). +Length constraints (important): +- Target <= 120 words total. +- Hard cap <= 900 characters. +""" + return prompt + + +def build_query_prompt( + *, + question: str, + context: dict[str, Any], +) -> str: + """ + Build prompt for querying temporal memory. + + Args: + question: User's question about the video stream + context: Context dict containing entity_roster, rolling_summary, etc. + + Returns: + Formatted prompt string + """ + prompt = f"""Answer the following question about the video stream using the provided context. 
+ +**Question:** {question} + +**Context:** +{json.dumps(context, indent=2, ensure_ascii=False)} + +**Instructions:** +- Entities have stable IDs like E1, E2, etc. +- The 'currently_present_entities' list shows which entities are visible now +- If an entity is NOT in 'currently_present_entities', it is no longer visible +- Answer based ONLY on the provided context +- If information isn't available, say so clearly + +Provide a concise answer. +""" + return prompt + + +def parse_window_response( + response_text: str, w_start: float, w_end: float, frame_count: int +) -> dict[str, Any]: + """ + Parse VLM response for a window analysis. + + Args: + response_text: Raw text response from VLM + w_start: Window start time + w_end: Window end time + frame_count: Number of frames in window + + Returns: + Parsed dictionary with defaults filled in + """ + # Try to extract JSON (handles code fences) + parsed = extract_json(response_text) + if parsed is None: + raise ValueError(f"Failed to parse response: {response_text}") + + # Ensure we return a dict (extract_json can return a list) + if isinstance(parsed, list): + # If we got a list, wrap it in a dict with a default structure + # This shouldn't happen with proper structured output, but handle gracefully + return { + "window": {"start": w_start, "end": w_end}, + "caption": "", + "entities_present": [], + "new_entities": [], + "relations": [], + "on_screen_text": [], + "_error": f"Unexpected list response: {parsed}", + } + + # Ensure it's a dict + if not isinstance(parsed, dict): + raise ValueError(f"Expected dict or list, got {type(parsed)}: {parsed}") + + return parsed + + +def update_state_from_window( + state: dict[str, Any], + parsed: dict[str, Any], + w_end: float, + summary_interval_s: float, +) -> bool: + """ + Update temporal memory state from a parsed window result. + + This implements the state update logic from VideoRAG's generate_evidence(). 
+ + Args: + state: Current state dictionary (modified in place) + parsed: Parsed window result + w_end: Window end time + summary_interval_s: How often to trigger summary updates + + Returns: + True if summary update is needed, False otherwise + """ + # Skip if there was an error + if "_error" in parsed: + return False + + new_entities = parsed.get("new_entities", []) + present = parsed.get("entities_present", []) + + # Handle new entities + if new_entities: + roster = list(state.get("entity_roster", [])) + known = {e.get("id") for e in roster if isinstance(e, dict)} + for e in new_entities: + if isinstance(e, dict) and e.get("id") not in known: + roster.append(e) + known.add(e.get("id")) + state["entity_roster"] = roster + + # Handle referenced entities (auto-add if mentioned but not in roster) + roster = list(state.get("entity_roster", [])) + known = {e.get("id") for e in roster if isinstance(e, dict)} + referenced: set[str] = set() + for p in present or []: + if isinstance(p, dict) and isinstance(p.get("id"), str): + referenced.add(p["id"]) + for rel in parsed.get("relations") or []: + if isinstance(rel, dict): + for k in ("subject", "object"): + v = rel.get(k) + if isinstance(v, str) and v != "unknown": + referenced.add(v) + for rid in sorted(referenced): + if rid not in known: + roster.append( + { + "id": rid, + "type": "other", + "descriptor": "unknown (auto-added; rerun recommended)", + } + ) + known.add(rid) + state["entity_roster"] = roster + state["last_present"] = present + + # Add to chunk buffer + chunk_buffer = state.get("chunk_buffer", []) + if not isinstance(chunk_buffer, list): + chunk_buffer = [] + chunk_buffer.append( + { + "window": parsed.get("window"), + "caption": parsed.get("caption", ""), + "entities_present": parsed.get("entities_present", []), + "new_entities": parsed.get("new_entities", []), + "relations": parsed.get("relations", []), + "on_screen_text": parsed.get("on_screen_text", []), + } + ) + state["chunk_buffer"] = chunk_buffer + + 
# Check if summary update is needed + if summary_interval_s > 0: + next_at = float(state.get("next_summary_at_s", summary_interval_s)) + if w_end + 1e-6 >= next_at and chunk_buffer: + return True # Need to update summary + + return False + + +def apply_summary_update( + state: dict[str, Any], summary_text: str, w_end: float, summary_interval_s: float +) -> None: + """ + Apply a summary update to the state. + + Args: + state: State dictionary (modified in place) + summary_text: New summary text + w_end: Current window end time + summary_interval_s: Summary update interval + """ + if summary_text and summary_text.strip(): + state["rolling_summary"] = summary_text.strip() + state["chunk_buffer"] = [] + + # Advance next_summary_at_s + next_at = float(state.get("next_summary_at_s", summary_interval_s)) + while next_at <= w_end + 1e-6: + next_at += float(summary_interval_s) + state["next_summary_at_s"] = next_at + + +def get_structured_output_format() -> dict[str, Any]: + """ + Get OpenAI-compatible structured output format for window responses. + + This uses the json_schema mode available in OpenAI API (GPT-4o mini) to enforce + the VideoRAG response schema. 
+ + Returns: + Dictionary for response_format parameter: + {"type": "json_schema", "json_schema": {...}} + """ + + return { + "type": "json_schema", + "json_schema": { + "name": "video_window_analysis", + "description": "Analysis of a video window with entities and relations", + "schema": WINDOW_RESPONSE_SCHEMA, + "strict": False, # Allow additional fields + }, + } + + +__all__ = [ + "WINDOW_RESPONSE_SCHEMA", + "apply_summary_update", + "build_query_prompt", + "build_summary_prompt", + "build_window_prompt", + "clamp_text", + "get_structured_output_format", + "next_entity_id_hint", + "parse_window_response", + "update_state_from_window", +] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index f989098f05..877debe6f4 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -21,6 +21,7 @@ "unitree-go2-nav": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:nav", "unitree-go2-detection": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:detection", "unitree-go2-spatial": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:spatial", + "unitree-go2-temporal-memory": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:temporal_memory", "unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic", "unitree-go2-agentic-mcp": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_mcp", "unitree-go2-agentic-ollama": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_ollama", diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py index 4dc682523f..f2d43e1363 100644 --- a/dimos/spec/perception.py +++ b/dimos/spec/perception.py @@ -27,10 +27,5 @@ class Camera(Image): _camera_info: CameraInfo -class DepthCamera(Camera): - depth_image: Out[ImageMsg] - depth_camera_info: Out[CameraInfo] - - class Pointcloud(Protocol): pointcloud: Out[PointCloud2] diff --git a/dimos/stream/video_operators.py b/dimos/stream/video_operators.py index 548bba7598..558972e155 100644 --- 
a/dimos/stream/video_operators.py +++ b/dimos/stream/video_operators.py @@ -16,7 +16,7 @@ from collections.abc import Callable from datetime import datetime, timedelta from enum import Enum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import cv2 import numpy as np From 0076db99196dd41889f8c3abd08da61f447a8517 Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Fri, 9 Jan 2026 13:35:43 -0800 Subject: [PATCH 02/21] fixing module issue and style --- dimos/agents/temp/webcam_agent.py | 2 +- dimos/perception/temporal_memory.py | 96 ++++++++++++++++--- .../unitree_webrtc/unitree_go2_blueprints.py | 6 ++ dimos/spec/perception.py | 5 + 4 files changed, 94 insertions(+), 15 deletions(-) diff --git a/dimos/agents/temp/webcam_agent.py b/dimos/agents/temp/webcam_agent.py index 98ae0a903b..b09ec2e1d8 100644 --- a/dimos/agents/temp/webcam_agent.py +++ b/dimos/agents/temp/webcam_agent.py @@ -115,7 +115,7 @@ def main() -> None: ), hardware=lambda: Webcam( camera_index=0, - frequency=15, + fps=15, stereo_slice="left", camera_info=zed.CameraInfo.SingleWebcam, ), diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index c5c76b91d2..8eea9cfbcc 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -35,6 +35,7 @@ from dimos import spec from dimos.agents import skill from dimos.core import DimosCluster, In, rpc +from dimos.core.module import ModuleConfig from dimos.core.skill_module import SkillModule from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image @@ -65,7 +66,7 @@ class Frame: @dataclass -class TemporalMemoryConfig: +class TemporalMemoryConfig(ModuleConfig): fps: float = 1.0 window_s: float = 2.0 stride_s: float = 2.0 @@ -165,25 +166,87 @@ def vlm(self) -> VlModel: logger.info("Created OpenAIVlModel from OPENAI_API_KEY environment variable") return self._vlm + @rpc + def set_AgentSpec_register_skills(self, callable) -> None: # type: 
ignore[no-untyped-def] + """Override SkillModule to pass self directly instead of RPCClient. + + This avoids pickle issues with RPCClient's WeakSet. Since we've implemented + proper __getstate__/__setstate__, self can be safely pickled and sent across workers. + """ + from dimos.core.rpc_client import RpcCall + + if isinstance(callable, RpcCall): + callable.set_rpc(self.rpc) # type: ignore[arg-type] + callable(self) + + @rpc + def set_MCPModule_register_skills(self, callable) -> None: # type: ignore[no-untyped-def] + """Override SkillModule to pass self directly instead of RPCClient. + + This avoids pickle issues with RPCClient's WeakSet. The instance is pickled + with minimal state for skill introspection, but actual skill execution + happens via RPC back to this original instance with full state. + """ + from dimos.core.rpc_client import RpcCall + + if isinstance(callable, RpcCall): + callable.set_rpc(self.rpc) # type: ignore[arg-type] + callable(self) + def __getstate__(self) -> dict[str, Any]: + """Pickle with minimal state needed for skill introspection. + + The agent needs to introspect @skill() methods, which may access properties. + We preserve simple attributes but set unpicklable objects to None. 
+ """ + # Start with parent's state (which properly handles ModuleBase attributes) state = super().__getstate__() if state is None: - # Parent doesn't implement __getstate__, so we need to manually exclude unpicklable attrs - state = self.__dict__.copy() - # Remove unpicklable attributes - state.pop("_disposables", None) - state.pop("_loop", None) - state.pop("_state_lock", None) + state = {} + + # Override with our minimal state + state.update( + { + "__class__": self.__class__, + "config": self.config, + # Preserve simple state attributes (set to safe defaults) + "_vlm": None, # VLM instance - unpicklable, set to None + "_state": default_state(), # Simple dict - can pickle + "_frame_buffer": None, # Deque - set to None to avoid issues + "_recent_windows": None, # Deque - set to None + "_frame_count": 0, + "_last_analysis_time": 0.0, + "_video_start_wall_time": None, + "_clip_filter": None, # CLIPFrameFilter - unpicklable + # Output paths (simple strings/Paths - can pickle) + "_output_path": getattr(self, "_output_path", None), + "_evidence_file": getattr(self, "_evidence_file", None), + "_state_file": getattr(self, "_state_file", None), + "_entities_file": getattr(self, "_entities_file", None), + "_frames_index_file": getattr(self, "_frames_index_file", None), + } + ) return state def __setstate__(self, state: dict[str, Any]) -> None: - # Let parent restore what it needs - if hasattr(super(), "__setstate__"): - super().__setstate__(state) - else: - self.__dict__.update(state) - # Recreate unpicklable attributes - self._state_lock = threading.Lock() + """Restore minimal state after unpickling. + + This creates a minimal shell for skill introspection. The actual skill + execution happens via RPC back to the original instance with full state. + """ + # First let parent restore its critical attributes (_disposables, _loop, _rpc, etc.) 
+ super().__setstate__(state) + + # Then restore our specific attributes + self.__dict__.update(state) + + # Recreate critical attributes that need special handling + if not hasattr(self, "_state_lock") or self._state_lock is None: + self._state_lock = threading.Lock() + if not hasattr(self, "_frame_buffer") or self._frame_buffer is None: + self._frame_buffer = deque(maxlen=self.config.frame_buffer_size if self.config else 300) + if not hasattr(self, "_recent_windows") or self._recent_windows is None: + self._recent_windows = deque(maxlen=50) @rpc def start(self) -> None: @@ -416,6 +479,7 @@ def query(self, question: str) -> str: logger.error(f"query failed: {e}", exc_info=True) return f"error: {e}" + @rpc def clear_history(self) -> None: """Clear temporal memory state.""" with self._state_lock: @@ -424,6 +488,7 @@ def clear_history(self) -> None: self._recent_windows.clear() logger.info("cleared history") + @rpc def get_state(self) -> dict[str, Any]: with self._state_lock: return { @@ -436,14 +501,17 @@ def get_state(self) -> dict[str, Any]: "currently_present": list(self._state.get("last_present", [])), } + @rpc def get_entity_roster(self) -> list[dict[str, Any]]: with self._state_lock: return list(self._state.get("entity_roster", [])) + @rpc def get_rolling_summary(self) -> str: with self._state_lock: return str(self._state.get("rolling_summary", "")) + @rpc def save_state(self) -> bool: if not self.config.output_dir: return False diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py index 7629644ed6..2e459989ce 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -45,6 +45,7 @@ ) from dimos.perception.detection.moduleDB import ObjectDBModule, detectionDB_module from dimos.perception.spatial_perception import spatial_memory +from dimos.perception.temporal_memory import temporal_memory from dimos.protocol.mcp.mcp import 
MCPModule from dimos.robot.foxglove_bridge import foxglove_bridge from dimos.robot.unitree.connection.go2 import GO2Connection, go2_connection @@ -198,3 +199,8 @@ vlm_agent(), vlm_stream_tester(), ) + +temporal_memory = autoconnect( + agentic, + temporal_memory(), +) diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py index f2d43e1363..4dc682523f 100644 --- a/dimos/spec/perception.py +++ b/dimos/spec/perception.py @@ -27,5 +27,10 @@ class Camera(Image): _camera_info: CameraInfo +class DepthCamera(Camera): + depth_image: Out[ImageMsg] + depth_camera_info: Out[CameraInfo] + + class Pointcloud(Protocol): pointcloud: Out[PointCloud2] From 6a68b5ba2fb01c4fd013c11cbc8ac8a03432556e Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Sat, 10 Jan 2026 04:26:52 +0200 Subject: [PATCH 03/21] fix skill registration --- dimos/perception/temporal_memory.py | 82 ----------------------------- 1 file changed, 82 deletions(-) diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index 8eea9cfbcc..684ab610cf 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -166,88 +166,6 @@ def vlm(self) -> VlModel: logger.info("Created OpenAIVlModel from OPENAI_API_KEY environment variable") return self._vlm - @rpc - def set_AgentSpec_register_skills(self, callable) -> None: # type: ignore[no-untyped-def] - """Override SkillModule to pass self directly instead of RPCClient. - - This avoids pickle issues with RPCClient's WeakSet. Since we've implemented - proper __getstate__/__setstate__, self can be safely pickled and sent across workers. - """ - from dimos.core.rpc_client import RpcCall - - if isinstance(callable, RpcCall): - callable.set_rpc(self.rpc) # type: ignore[arg-type] - callable(self) - - @rpc - def set_MCPModule_register_skills(self, callable) -> None: # type: ignore[no-untyped-def] - """Override SkillModule to pass self directly instead of RPCClient. 
- - This avoids pickle issues with RPCClient's WeakSet. The instance is pickled - with minimal state for skill introspection, but actual skill execution - happens via RPC back to this original instance with full state. - """ - from dimos.core.rpc_client import RpcCall - - if isinstance(callable, RpcCall): - callable.set_rpc(self.rpc) # type: ignore[arg-type] - callable(self) - - def __getstate__(self) -> dict[str, Any]: - """Pickle with minimal state needed for skill introspection. - - The agent needs to introspect @skill() methods, which may access properties. - We preserve simple attributes but set unpicklable objects to None. - """ - # Start with parent's state (which properly handles ModuleBase attributes) - state = super().__getstate__() - if state is None: - state = {} - - # Override with our minimal state - state.update( - { - "__class__": self.__class__, - "config": self.config, - # Preserve simple state attributes (set to safe defaults) - "_vlm": None, # VLM instance - unpicklable, set to None - "_state": default_state(), # Simple dict - can pickle - "_frame_buffer": None, # Deque - set to None to avoid issues - "_recent_windows": None, # Deque - set to None - "_frame_count": 0, - "_last_analysis_time": 0.0, - "_video_start_wall_time": None, - "_clip_filter": None, # CLIPFrameFilter - unpicklable - # Output paths (simple strings/Paths - can pickle) - "_output_path": getattr(self, "_output_path", None), - "_evidence_file": getattr(self, "_evidence_file", None), - "_state_file": getattr(self, "_state_file", None), - "_entities_file": getattr(self, "_entities_file", None), - "_frames_index_file": getattr(self, "_frames_index_file", None), - } - ) - return state - - def __setstate__(self, state: dict[str, Any]) -> None: - """Restore minimal state after unpickling. - - This creates a minimal shell for skill introspection. The actual skill - execution happens via RPC back to the original instance with full state. 
- """ - # First let parent restore its critical attributes (_disposables, _loop, _rpc, etc.) - super().__setstate__(state) - - # Then restore our specific attributes - self.__dict__.update(state) - - # Recreate critical attributes that need special handling - if not hasattr(self, "_state_lock") or self._state_lock is None: - self._state_lock = threading.Lock() - if not hasattr(self, "_frame_buffer") or self._frame_buffer is None: - self._frame_buffer = deque(maxlen=self.config.frame_buffer_size if self.config else 300) - if not hasattr(self, "_recent_windows") or self._recent_windows is None: - self._recent_windows = deque(maxlen=50) - @rpc def start(self) -> None: super().start() From 56fd322484cd368fa9db2285b5396f8593150bea Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Fri, 9 Jan 2026 19:30:25 -0800 Subject: [PATCH 04/21] removing state functions unpickable --- dimos/perception/temporal_memory.py | 90 +++-------------------------- 1 file changed, 7 insertions(+), 83 deletions(-) diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index 8eea9cfbcc..205d20b753 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -166,87 +166,11 @@ def vlm(self) -> VlModel: logger.info("Created OpenAIVlModel from OPENAI_API_KEY environment variable") return self._vlm - @rpc - def set_AgentSpec_register_skills(self, callable) -> None: # type: ignore[no-untyped-def] - """Override SkillModule to pass self directly instead of RPCClient. - - This avoids pickle issues with RPCClient's WeakSet. Since we've implemented - proper __getstate__/__setstate__, self can be safely pickled and sent across workers. 
- """ - from dimos.core.rpc_client import RpcCall - - if isinstance(callable, RpcCall): - callable.set_rpc(self.rpc) # type: ignore[arg-type] - callable(self) - - @rpc - def set_MCPModule_register_skills(self, callable) -> None: # type: ignore[no-untyped-def] - """Override SkillModule to pass self directly instead of RPCClient. - - This avoids pickle issues with RPCClient's WeakSet. The instance is pickled - with minimal state for skill introspection, but actual skill execution - happens via RPC back to this original instance with full state. - """ - from dimos.core.rpc_client import RpcCall - - if isinstance(callable, RpcCall): - callable.set_rpc(self.rpc) # type: ignore[arg-type] - callable(self) - - def __getstate__(self) -> dict[str, Any]: - """Pickle with minimal state needed for skill introspection. - - The agent needs to introspect @skill() methods, which may access properties. - We preserve simple attributes but set unpicklable objects to None. - """ - # Start with parent's state (which properly handles ModuleBase attributes) - state = super().__getstate__() - if state is None: - state = {} - - # Override with our minimal state - state.update( - { - "__class__": self.__class__, - "config": self.config, - # Preserve simple state attributes (set to safe defaults) - "_vlm": None, # VLM instance - unpicklable, set to None - "_state": default_state(), # Simple dict - can pickle - "_frame_buffer": None, # Deque - set to None to avoid issues - "_recent_windows": None, # Deque - set to None - "_frame_count": 0, - "_last_analysis_time": 0.0, - "_video_start_wall_time": None, - "_clip_filter": None, # CLIPFrameFilter - unpicklable - # Output paths (simple strings/Paths - can pickle) - "_output_path": getattr(self, "_output_path", None), - "_evidence_file": getattr(self, "_evidence_file", None), - "_state_file": getattr(self, "_state_file", None), - "_entities_file": getattr(self, "_entities_file", None), - "_frames_index_file": getattr(self, "_frames_index_file", 
None), - } - ) - return state - - def __setstate__(self, state: dict[str, Any]) -> None: - """Restore minimal state after unpickling. - - This creates a minimal shell for skill introspection. The actual skill - execution happens via RPC back to the original instance with full state. - """ - # First let parent restore its critical attributes (_disposables, _loop, _rpc, etc.) - super().__setstate__(state) - - # Then restore our specific attributes - self.__dict__.update(state) - - # Recreate critical attributes that need special handling - if not hasattr(self, "_state_lock") or self._state_lock is None: - self._state_lock = threading.Lock() - if not hasattr(self, "_frame_buffer") or self._frame_buffer is None: - self._frame_buffer = deque(maxlen=self.config.frame_buffer_size if self.config else 300) - if not hasattr(self, "_recent_windows") or self._recent_windows is None: - self._recent_windows = deque(maxlen=50) + # Use default SkillModule behavior: + # - set_AgentSpec_register_skills passes RPCClient(self) to agent + # - RPCClient.__reduce__ handles pickle by not serializing LCMRPC/WeakSet + # - SkillModule.__getstate__ returns None (empty shell for routing) + # - Skills execute on original instance via RPC, not on pickled shell @rpc def start(self) -> None: @@ -351,7 +275,7 @@ def _analyze_window(self) -> None: try: middle_frame = window_frames[len(window_frames) // 2] response_format = get_structured_output_format() - response_text = self._vlm.query( + response_text = self.vlm.query( middle_frame.image, query, response_format=response_format ) except Exception as e: @@ -418,7 +342,7 @@ def _update_rolling_summary(self, w_end: float) -> None: # query vlm (slow, outside lock) try: - summary_text = self._vlm.query(latest_frame, prompt) + summary_text = self.vlm.query(latest_frame, prompt) if summary_text and summary_text.strip(): with self._state_lock: apply_summary_update( From 5ef91c27a38de3e36813546a6434313fc81b4b6f Mon Sep 17 00:00:00 2001 From: clairebookworm 
Date: Sat, 10 Jan 2026 17:05:24 -0800 Subject: [PATCH 05/21] inheritancefixes and memory management --- dimos/models/vl/base.py | 66 ++++++-------------------- dimos/models/vl/openai.py | 73 ++++++----------------------- dimos/models/vl/qwen.py | 53 +++++++-------------- dimos/perception/temporal_memory.py | 50 +++++++++++--------- 4 files changed, 72 insertions(+), 170 deletions(-) diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index a9cc80978f..a7faae84a3 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -87,9 +87,7 @@ def vlm_detection_to_detection2d( try: coords = [float(vlm_detection[i]) for i in range(1, 5)] except (ValueError, TypeError) as e: - logger.debug( - f"Invalid VLM detection coordinates: {vlm_detection[1:]}. Error: {e}" - ) + logger.debug(f"Invalid VLM detection coordinates: {vlm_detection[1:]}. Error: {e}") return None bbox = (coords[0], coords[1], coords[2], coords[3]) @@ -132,9 +130,7 @@ def vlm_point_to_detection2d_point( return None if len(vlm_point) != 3: - logger.debug( - f"Invalid VLM point length: {len(vlm_point)}, expected 3. Got: {vlm_point}" - ) + logger.debug(f"Invalid VLM point length: {len(vlm_point)}, expected 3. Got: {vlm_point}") return None # Extract label @@ -198,69 +194,37 @@ def _prepare_image(self, image: Image) -> tuple[Image, float]: return image.resize_to_fit(max_w, max_h) return image, 1.0 - def __getstate__(self) -> dict[str, Any]: - """Exclude unpicklable attributes when serializing. - - Subclasses should override to handle their own unpicklable attributes - (e.g., API clients, cached properties). - """ - state = self.__dict__.copy() - # Remove common unpicklable attributes (may not exist in all subclasses) - state.pop("_client", None) - return state - - def __setstate__(self, state: dict[str, Any]) -> None: - """Restore object from pickled state. 
- - Subclasses should override to reinitialize their own unpicklable attributes - and reload any necessary configuration (e.g., API keys from environment). - """ - self.__dict__.update(state) - # Clear cached properties that may have been removed - if "_client" in self.__dict__: - del self.__dict__["_client"] + # Note: No custom pickle methods needed. In practice, VlModel instances + # are only stored in SkillModules, which use empty-shell pickling + # (SkillModule.__getstate__ returns None). Therefore VlModel is never + # actually pickled and doesn't need to handle unpicklable _client attributes. @abstractmethod def query(self, image: Image, query: str, **kwargs) -> str: ... # type: ignore[no-untyped-def] - def query_multi_images(self, images: list[Image], query: str, **kwargs) -> str: # type: ignore[no-untyped-def] - """Query VLM with multiple images in a single request. - - This is useful for temporal reasoning across multiple frames or - multi-view analysis. The VLM can see all images together and reason - about relationships between them. - - Subclasses must override this method for models that support multi-image input - (e.g., GPT-4V, Qwen). - - Args: - images: List of input images (e.g., frames from a video window) - query: Question to ask about all images together - """ - raise NotImplementedError( - f"{self.__class__.__name__} does not support multi-image queries. " - "Subclasses must override query_multi_images() to provide this functionality." - ) - - def query_batch(self, images: list[Image], query: str, **kwargs) -> list[str]: # type: ignore[no-untyped-def] + def query_batch( + self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any + ) -> list[str]: # type: ignore[no-untyped-def] """Query multiple images with the same question. Default implementation calls query() for each image sequentially. - Subclasses may override for more efficient batched inference. 
+ Subclasses may override for efficient batched inference. Args: images: List of input images - query: Question to ask about each image + query: Question to ask about all images + response_format: Optional response format for structured output + **kwargs: Additional arguments Returns: List of responses, one per image """ warnings.warn( - f"{self.__class__.__name__}.query_batch() is using default sequential implementation. " + f"{self.__class__.__name__}.query_batch() using sequential implementation. " "Override for efficient batched inference.", stacklevel=2, ) - return [self.query(image, query, **kwargs) for image in images] + return [self.query(image, query, response_format=response_format, **kwargs) for image in images] def query_multi(self, image: Image, queries: list[str], **kwargs) -> list[str]: # type: ignore[no-untyped-def] """Query a single image with multiple different questions. diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index 5907774561..4607e29bbe 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -33,27 +33,6 @@ def is_set_up(self) -> None: "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable" ) - def __getstate__(self) -> dict[str, Any]: - """Exclude unpicklable attributes when serializing.""" - state = super().__getstate__() - # _client is already removed by base class, but ensure it's gone - state.pop("_client", None) - return state - - def __setstate__(self, state: dict[str, Any]) -> None: - """Restore object from pickled state and reload API key if needed.""" - super().__setstate__(state) - - # Reload API key from environment if config doesn't have it - # This is important when unpickling on Dask workers where env vars may differ - if not self.config.api_key: - api_key = os.getenv("OPENAI_API_KEY") - if api_key: - self.config.api_key = api_key - - # Verify setup (will raise ValueError if API key is still missing) - self.is_set_up() - @cached_property def _client(self) -> 
OpenAI: api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") @@ -104,53 +83,29 @@ def query(self, image: Image | np.ndarray, query: str, response_format: dict | N return response.choices[0].message.content # type: ignore[return-value] - def query_multi_images( - self, images: list[Image], query: str, response_format: dict | None = None, **kwargs - ) -> str: # type: ignore[no-untyped-def, override] - """Query VLM with multiple images (for temporal/multi-view reasoning). - - Args: - images: List of images to analyze together - query: Question about all images - response_format: Optional response format for structured output - - {"type": "json_object"} for JSON mode - - {"type": "json_schema", "json_schema": {...}} for schema enforcement - - Returns: - Response from the model - """ + def query_batch( + self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any + ) -> list[str]: # type: ignore[override] + """Query VLM with multiple images using a single API call.""" if not images: - raise ValueError("Must provide at least one image") - - # Build content with multiple images - content: list[dict] = [] # type: ignore[type-arg] - - # Add all images first - for img in images: - # Apply auto_resize if configured - prepared_img, _ = self._prepare_image(img) - img_base64 = prepared_img.to_base64() - content.append( - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{img_base64}"}, - } - ) - - # Add query text last + return [] + + content: list[dict[str, Any]] = [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"}, + } + for img in images + ] content.append({"type": "text", "text": query}) - # Build messages messages = [{"role": "user", "content": content}] - - # Call API with optional response_format api_kwargs: dict[str, Any] = {"model": self.config.model_name, "messages": messages} if response_format: api_kwargs["response_format"] = 
response_format response = self._client.chat.completions.create(**api_kwargs) - - return response.choices[0].message.content # type: ignore[return-value] + return [response.choices[0].message.content] # type: ignore[list-item] def stop(self) -> None: """Release the OpenAI client.""" diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index f7e3c9c733..c508b8cc72 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from functools import cached_property import os +from typing import Any import numpy as np from openai import OpenAI @@ -79,51 +80,29 @@ def query(self, image: Image | np.ndarray, query: str) -> str: # type: ignore[o return response.choices[0].message.content # type: ignore[return-value] - def query_multi_images(self, images: list[Image], query: str, response_format: dict | None = None) -> str: # type: ignore[no-untyped-def, override] - """Query VLM with multiple images (for temporal/multi-view reasoning). 
- - Args: - images: List of images to analyze together - query: Question about all images - response_format: Optional response format for structured output - - {"type": "json_object"} for JSON mode - - {"type": "json_schema", "json_schema": {...}} for schema enforcement - - Returns: - Response from the model - """ + def query_batch( + self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any + ) -> list[str]: # type: ignore[override] + """Query VLM with multiple images using a single API call.""" if not images: - raise ValueError("Must provide at least one image") - - # Build content with multiple images - content: list[dict] = [] # type: ignore[type-arg] - - # Add all images first - for img in images: - # Apply auto_resize if configured - prepared_img, _ = self._prepare_image(img) - img_base64 = prepared_img.to_base64() - content.append( - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{img_base64}"}, - } - ) - - # Add query text last + return [] + + content: list[dict[str, Any]] = [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"}, + } + for img in images + ] content.append({"type": "text", "text": query}) - # Build messages messages = [{"role": "user", "content": content}] - - # Call API with optional response_format - api_kwargs = {"model": self.config.model_name, "messages": messages} + api_kwargs: dict[str, Any] = {"model": self.config.model_name, "messages": messages} if response_format: api_kwargs["response_format"] = response_format response = self._client.chat.completions.create(**api_kwargs) - - return response.choices[0].message.content # type: ignore[return-value] + return [response.choices[0].message.content] # type: ignore[list-item] def stop(self) -> None: """Release the OpenAI client.""" diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index 3da5d24ca7..3e0fb8b1bd 100644 --- 
a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -72,7 +72,7 @@ class TemporalMemoryConfig(ModuleConfig): stride_s: float = 2.0 summary_interval_s: float = 10.0 max_frames_per_window: int = 3 - frame_buffer_size: int = 300 + frame_buffer_size: int = 200 output_dir: str | Path | None = None max_tokens: int = 900 temperature: float = 0.2 @@ -209,13 +209,20 @@ def stop(self) -> None: self._clip_filter.close() self._clip_filter = None + # Clear buffers to release image references + with self._state_lock: + self._frame_buffer.clear() + self._recent_windows.clear() + self._state = default_state() + + super().stop() + # Stop all stream transports to clean up LCM/shared memory threads for stream in list(self.inputs.values()) + list(self.outputs.values()): if stream.transport is not None and hasattr(stream.transport, "stop"): stream.transport.stop() stream._transport = None - super().stop() logger.info("temporalmemory stopped") def _format_timestamp(self, seconds: float) -> str: @@ -265,15 +272,27 @@ def _analyze_window(self) -> None: ) # query vlm (slow, outside lock) - # use middle frame for window analysis + # use query_batch for multiple frames to send all filtered frames in one API call try: - middle_frame = window_frames[len(window_frames) // 2] response_format = get_structured_output_format() - response_text = self.vlm.query( - middle_frame.image, query, response_format=response_format - ) + if len(window_frames) > 1: + # Use query_batch to send all filtered frames in one API call + # This gives the model more temporal context + frame_images = [frame.image for frame in window_frames] + responses = self.vlm.query_batch( + frame_images, query, response_format=response_format + ) + # query_batch returns list[str] with one response for all images + response_text = responses[0] if responses else "" + + # TODO: clear image data from analyzed frames & only keep metadata if the frame_buffer is still too big + else: + # Single frame - use 
regular query + response_text = self.vlm.query( + window_frames[0].image, query, response_format=response_format + ) except Exception as e: - logger.error(f"vlm agent query failed [{w_start:.1f}-{w_end:.1f}s]: {e}") + logger.error(f"vlm query failed [{w_start:.1f}-{w_end:.1f}s]: {e}") with self._state_lock: self._last_analysis_time = w_end return @@ -353,12 +372,6 @@ def _update_rolling_summary(self, w_end: float) -> None: def query(self, question: str) -> str: """ Answer a question about the video stream using temporal memory. - - Args: - question: Question to ask about the video stream - - Returns: - Answer based on temporal memory state and video context """ # read state with self._state_lock: @@ -500,15 +513,6 @@ def deploy( ) -> TemporalMemory: """ Deploy TemporalMemory with a camera. - - Args: - dimos: DimosCluster instance - camera: Camera module - vlm: Optional VlModel instance (will create OpenAIVlModel if None) - config: Optional TemporalMemoryConfig - - Returns: - Deployed TemporalMemory module """ if vlm is None: from dimos.models.vl.openai import OpenAIVlModel From 20bf28ed0aae197a94cec8b79b613d6805f211ec Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Sat, 10 Jan 2026 17:12:30 -0800 Subject: [PATCH 06/21] docstring for query --- dimos/perception/temporal_memory.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index 3e0fb8b1bd..22a34f95a2 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -370,8 +370,23 @@ def _update_rolling_summary(self, w_end: float) -> None: @skill() def query(self, question: str) -> str: - """ - Answer a question about the video stream using temporal memory. + """Answer a question about the video stream using temporal memory. 
+ + This skill analyzes the current video stream and temporal memory state + to answer questions about what is happening, what entities are present, + and recent events. + + Example: + query("What entities are currently visible?") + query("Do you see a wall in the video stream?") + + Args: + question (str): The question to ask about the video stream. + Examples: "What entities are visible?", "What happened recently?", + "Is there a person in the scene?" + + Returns: + str: Answer to the question based on temporal memory and current video frame. """ # read state with self._state_lock: From b76d8014e2f7d71396c67af9249628e109d6cbaa Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Sat, 10 Jan 2026 17:28:34 -0800 Subject: [PATCH 07/21] microcommit: fixing memory buffer --- dimos/perception/temporal_memory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index 22a34f95a2..2bea005a9e 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -71,7 +71,7 @@ class TemporalMemoryConfig(ModuleConfig): window_s: float = 2.0 stride_s: float = 2.0 summary_interval_s: float = 10.0 - max_frames_per_window: int = 3 + max_frames_per_window: int = 50 frame_buffer_size: int = 200 output_dir: str | Path | None = None max_tokens: int = 900 @@ -286,6 +286,8 @@ def _analyze_window(self) -> None: response_text = responses[0] if responses else "" # TODO: clear image data from analyzed frames & only keep metadata if the frame_buffer is still too big + for frame in window_frames: + frame.image = None else: # Single frame - use regular query response_text = self.vlm.query( From 2184a342be65c6d4fd94763feb428ef60a93e586 Mon Sep 17 00:00:00 2001 From: Stash Pomichter Date: Mon, 12 Jan 2026 16:28:45 -0800 Subject: [PATCH 08/21] sharpness filter and simplified frame filtering --- dimos/perception/temporal_memory.py | 69 ++++++++++++++++++++++++----- 1 file changed, 
57 insertions(+), 12 deletions(-) diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index 2bea005a9e..f55b080cd9 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -28,6 +28,7 @@ import threading import time from typing import Any +import numpy as np from reactivex import interval from reactivex.disposable import Disposable @@ -44,6 +45,8 @@ CLIPFrameFilter, select_diverse_frames_simple, ) +from dimos.msgs.sensor_msgs.Image import sharpness_barrier +from reactivex import Subject from dimos.perception.videorag_utils import ( apply_summary_update, build_query_prompt, @@ -71,8 +74,8 @@ class TemporalMemoryConfig(ModuleConfig): window_s: float = 2.0 stride_s: float = 2.0 summary_interval_s: float = 10.0 - max_frames_per_window: int = 50 - frame_buffer_size: int = 200 + max_frames_per_window: int = 3 + frame_buffer_size: int = 50 output_dir: str | Path | None = None max_tokens: int = 900 temperature: float = 0.2 @@ -143,6 +146,12 @@ def __init__( self._entities_file = self._output_path / "entities.json" self._frames_index_file = self._output_path / "frames_index.jsonl" logger.info(f"artifacts save to: {self._output_path}") + # else: + # self._output_path = None + + # # frames directory for saving images + # self._frames_dir = Path("temporal_memory_frames") + # self._frames_dir.mkdir(parents=True, exist_ok=True) logger.info( f"temporalmemory init: fps={self.config.fps}, " @@ -188,9 +197,26 @@ def on_frame(image: Image) -> None: image=image, ) self._frame_buffer.append(frame) + + # Save image to frames directory + # frame_filename = f"frame_{self._frame_count:06d}_{image.frame_id or 'unknown'}.jpg" + # frame_path = self._frames_dir / frame_filename + # try: + # image.save(str(frame_path)) + # except Exception as e: + # logger.warning(f"Failed to save frame {self._frame_count}: {e}") + self._frame_count += 1 - unsub_image = self.color_image.subscribe(on_frame) + # pipe through sharpness filter 
before buffering + frame_subject = Subject() + self._disposables.add( + frame_subject.pipe( + sharpness_barrier(self.config.fps) + ).subscribe(on_frame) + ) + + unsub_image = self.color_image.subscribe(frame_subject.on_next) self._disposables.add(Disposable(unsub_image)) self._disposables.add( @@ -230,6 +256,17 @@ def _format_timestamp(self, seconds: float) -> str: s = seconds - 60 * m return f"{m:02d}:{s:06.3f}" + def _is_scene_stale(self, frames: list[Frame]) -> bool: + """skip if scene hasn't changed meaningfully""" + if len(frames) < 2: + return False + first_img = frames[0].image + last_img = frames[-1].image + if first_img is None or last_img is None: + return False + diff = np.abs(first_img.data.astype(float) - last_img.data.astype(float)) + return diff.mean() < 5.0 # tune this threshold + def _analyze_window(self) -> None: try: # get snapshot @@ -247,19 +284,27 @@ def _analyze_window(self) -> None: window_frames = list(self._frame_buffer)[-frames_needed:] state_snapshot = self._state.copy() + # add this check early, before any filtering or VLM calls + if self._is_scene_stale(window_frames): + logger.debug(f"skipping stale window [{w_start:.1f}-{w_end:.1f}s]") + with self._state_lock: + self._last_analysis_time = w_end + return + w_start = window_frames[0].timestamp_s w_end = window_frames[-1].timestamp_s # filter frames - if len(window_frames) > self.config.max_frames_per_window: - if self._clip_filter: - window_frames = self._clip_filter.select_diverse_frames( - window_frames, max_frames=self.config.max_frames_per_window - ) - else: - window_frames = select_diverse_frames_simple( - window_frames, max_frames=self.config.max_frames_per_window - ) + # NOTE: no longer using clip filter for now (alternative: sharpness barrier and stale scene check) + # if len(window_frames) > self.config.max_frames_per_window: + # if self._clip_filter: + # window_frames = self._clip_filter.select_diverse_frames( + # window_frames, max_frames=self.config.max_frames_per_window + 
# ) + # else: + window_frames = select_diverse_frames_simple( + window_frames, max_frames=self.config.max_frames_per_window + ) logger.info(f"analyzing [{w_start:.1f}-{w_end:.1f}s] with {len(window_frames)} frames") From 42fd62927e5dbeba61de6a9027fae3cd368c3b4f Mon Sep 17 00:00:00 2001 From: spomichter <12108168+spomichter@users.noreply.github.com> Date: Tue, 13 Jan 2026 00:29:20 +0000 Subject: [PATCH 09/21] CI code cleanup --- dimos/perception/temporal_memory.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index f55b080cd9..1479b118b5 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -28,9 +28,9 @@ import threading import time from typing import Any -import numpy as np -from reactivex import interval +import numpy as np +from reactivex import Subject, interval from reactivex.disposable import Disposable from dimos import spec @@ -40,13 +40,12 @@ from dimos.core.skill_module import SkillModule from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import sharpness_barrier from dimos.perception.clip_filter import ( CLIP_AVAILABLE, CLIPFrameFilter, select_diverse_frames_simple, ) -from dimos.msgs.sensor_msgs.Image import sharpness_barrier -from reactivex import Subject from dimos.perception.videorag_utils import ( apply_summary_update, build_query_prompt, @@ -148,7 +147,7 @@ def __init__( logger.info(f"artifacts save to: {self._output_path}") # else: # self._output_path = None - + # # frames directory for saving images # self._frames_dir = Path("temporal_memory_frames") # self._frames_dir.mkdir(parents=True, exist_ok=True) @@ -197,7 +196,7 @@ def on_frame(image: Image) -> None: image=image, ) self._frame_buffer.append(frame) - + # Save image to frames directory # frame_filename = f"frame_{self._frame_count:06d}_{image.frame_id or 'unknown'}.jpg" # 
frame_path = self._frames_dir / frame_filename @@ -205,15 +204,13 @@ def on_frame(image: Image) -> None: # image.save(str(frame_path)) # except Exception as e: # logger.warning(f"Failed to save frame {self._frame_count}: {e}") - + self._frame_count += 1 # pipe through sharpness filter before buffering frame_subject = Subject() self._disposables.add( - frame_subject.pipe( - sharpness_barrier(self.config.fps) - ).subscribe(on_frame) + frame_subject.pipe(sharpness_barrier(self.config.fps)).subscribe(on_frame) ) unsub_image = self.color_image.subscribe(frame_subject.on_next) From 5f1116b8ed81f583af35e115734a28f0807e2073 Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Mon, 12 Jan 2026 18:36:40 -0800 Subject: [PATCH 10/21] initial graph database implementation --- dimos/perception/entity_graph_db.py | 793 ++++++++++++++++++++ dimos/perception/temporal_memory.py | 444 +++++++++-- dimos/perception/temporal_memory_example.py | 16 + dimos/perception/videorag_utils.py | 279 ++++++- 4 files changed, 1466 insertions(+), 66 deletions(-) create mode 100644 dimos/perception/entity_graph_db.py diff --git a/dimos/perception/entity_graph_db.py b/dimos/perception/entity_graph_db.py new file mode 100644 index 0000000000..41fde5554a --- /dev/null +++ b/dimos/perception/entity_graph_db.py @@ -0,0 +1,793 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Entity Graph Database for storing and querying entity relationships. 
+ +Maintains three types of graphs: +1. Relations Graph: Interactions between entities (holds, looks_at, talks_to, etc.) +2. Distance Graph: Spatial distances between entities +3. Semantic Graph: Conceptual relationships (goes_with, part_of, used_for, etc.) + +All graphs share the same entity nodes but have different edge types. +""" + +import json +from pathlib import Path +import sqlite3 +import threading +from typing import Any + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class EntityGraphDB: + """ + SQLite-based graph database for entity relationships. + + Thread-safe implementation using connection-per-thread pattern. + All graphs share the same entity nodes but maintain separate edge tables. + """ + + def __init__(self, db_path: str | Path) -> None: + """ + Initialize the entity graph database. + + Args: + db_path: Path to the SQLite database file + """ + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + + # Thread-local storage for connections + self._local = threading.local() + + # Initialize schema + self._init_schema() + + logger.info(f"EntityGraphDB initialized at {self.db_path}") + + def _get_connection(self) -> sqlite3.Connection: + """Get thread-local database connection.""" + if not hasattr(self._local, "conn"): + self._local.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self._local.conn.row_factory = sqlite3.Row + return self._local.conn + + def _init_schema(self) -> None: + """Initialize database schema.""" + conn = self._get_connection() + cursor = conn.cursor() + + # Entities table (shared nodes) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS entities ( + entity_id TEXT PRIMARY KEY, + entity_type TEXT NOT NULL, + descriptor TEXT, + first_seen_ts REAL NOT NULL, + last_seen_ts REAL NOT NULL, + metadata TEXT + ) + """) + + # Relations table (Graph 1: Interactions) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relations ( + id INTEGER PRIMARY 
KEY AUTOINCREMENT, + relation_type TEXT NOT NULL, + subject_id TEXT NOT NULL, + object_id TEXT NOT NULL, + confidence REAL DEFAULT 1.0, + timestamp_s REAL NOT NULL, + evidence TEXT, + notes TEXT, + FOREIGN KEY (subject_id) REFERENCES entities(entity_id), + FOREIGN KEY (object_id) REFERENCES entities(entity_id) + ) + """) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_subject ON relations(subject_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_object ON relations(object_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_type ON relations(relation_type)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_time ON relations(timestamp_s)") + + # Distances table (Graph 2: Spatial) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS distances ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + entity_a_id TEXT NOT NULL, + entity_b_id TEXT NOT NULL, + distance_meters REAL, + distance_category TEXT, + confidence REAL DEFAULT 1.0, + timestamp_s REAL NOT NULL, + method TEXT, + FOREIGN KEY (entity_a_id) REFERENCES entities(entity_id), + FOREIGN KEY (entity_b_id) REFERENCES entities(entity_id) + ) + """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_distances_pair ON distances(entity_a_id, entity_b_id)" + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_distances_time ON distances(timestamp_s)") + + # Semantic relations table (Graph 3: Knowledge) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS semantic_relations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + relation_type TEXT NOT NULL, + entity_a_id TEXT NOT NULL, + entity_b_id TEXT NOT NULL, + confidence REAL DEFAULT 1.0, + learned_from TEXT, + first_observed_ts REAL NOT NULL, + last_observed_ts REAL NOT NULL, + observation_count INTEGER DEFAULT 1, + FOREIGN KEY (entity_a_id) REFERENCES entities(entity_id), + FOREIGN KEY (entity_b_id) REFERENCES entities(entity_id) + ) + """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_semantic_pair ON semantic_relations(entity_a_id, 
entity_b_id)" + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_semantic_type ON semantic_relations(relation_type)" + ) + + conn.commit() + + def upsert_entity( + self, + entity_id: str, + entity_type: str, + descriptor: str, + timestamp_s: float, + metadata: dict[str, Any] | None = None, + ) -> None: + """ + Insert or update an entity. + + Args: + entity_id: Unique entity identifier (e.g., "E1") + entity_type: Type of entity (person, object, location, etc.) + descriptor: Text description of the entity + timestamp_s: Timestamp when entity was observed + metadata: Optional additional metadata + """ + conn = self._get_connection() + cursor = conn.cursor() + + metadata_json = json.dumps(metadata) if metadata else None + + cursor.execute( + """ + INSERT INTO entities (entity_id, entity_type, descriptor, first_seen_ts, last_seen_ts, metadata) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(entity_id) DO UPDATE SET + last_seen_ts = ?, + descriptor = COALESCE(excluded.descriptor, descriptor), + metadata = COALESCE(excluded.metadata, metadata) + """, + ( + entity_id, + entity_type, + descriptor, + timestamp_s, + timestamp_s, + metadata_json, + timestamp_s, + ), + ) + + conn.commit() + logger.debug(f"Upserted entity {entity_id} (type={entity_type})") + + def get_entity(self, entity_id: str) -> dict[str, Any] | None: + """ + Get an entity by ID. + """ + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute("SELECT * FROM entities WHERE entity_id = ?", (entity_id,)) + row = cursor.fetchone() + + if row is None: + return None + + return { + "entity_id": row["entity_id"], + "entity_type": row["entity_type"], + "descriptor": row["descriptor"], + "first_seen_ts": row["first_seen_ts"], + "last_seen_ts": row["last_seen_ts"], + "metadata": json.loads(row["metadata"]) if row["metadata"] else None, + } + + def get_all_entities(self, entity_type: str | None = None) -> list[dict[str, Any]]: + """ + Get all entities, optionally filtered by type. 
+ """ + conn = self._get_connection() + cursor = conn.cursor() + + if entity_type: + cursor.execute( + "SELECT * FROM entities WHERE entity_type = ? ORDER BY last_seen_ts DESC", + (entity_type,), + ) + else: + cursor.execute("SELECT * FROM entities ORDER BY last_seen_ts DESC") + + rows = cursor.fetchall() + return [ + { + "entity_id": row["entity_id"], + "entity_type": row["entity_type"], + "descriptor": row["descriptor"], + "first_seen_ts": row["first_seen_ts"], + "last_seen_ts": row["last_seen_ts"], + "metadata": json.loads(row["metadata"]) if row["metadata"] else None, + } + for row in rows + ] + + def add_relation( + self, + relation_type: str, + subject_id: str, + object_id: str, + confidence: float, + timestamp_s: float, + evidence: list[str] | None = None, + notes: str | None = None, + ) -> None: + """ + Add a relation between two entities. + + Args: + relation_type: Type of relation (holds, looks_at, talks_to, etc.) + subject_id: Subject entity ID + object_id: Object entity ID + confidence: Confidence score (0.0 to 1.0) + timestamp_s: Timestamp when relation was observed + evidence: Optional list of evidence strings + notes: Optional notes + """ + conn = self._get_connection() + cursor = conn.cursor() + + evidence_json = json.dumps(evidence) if evidence else None + + cursor.execute( + """ + INSERT INTO relations (relation_type, subject_id, object_id, confidence, timestamp_s, evidence, notes) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + (relation_type, subject_id, object_id, confidence, timestamp_s, evidence_json, notes), + ) + + conn.commit() + logger.debug(f"Added relation: {subject_id} --{relation_type}--> {object_id}") + + def get_relations_for_entity( + self, + entity_id: str, + relation_type: str | None = None, + time_window: tuple[float, float] | None = None, + ) -> list[dict[str, Any]]: + """ + Get all relations involving an entity. 
+ + Args: + entity_id: Entity ID + relation_type: Optional filter by relation type + time_window: Optional (start_ts, end_ts) tuple + + Returns: + List of relation dicts + """ + conn = self._get_connection() + cursor = conn.cursor() + + query = """ + SELECT * FROM relations + WHERE (subject_id = ? OR object_id = ?) + """ + params: list[Any] = [entity_id, entity_id] + + if relation_type: + query += " AND relation_type = ?" + params.append(relation_type) + + if time_window: + query += " AND timestamp_s BETWEEN ? AND ?" + params.extend(time_window) + + query += " ORDER BY timestamp_s DESC" + + cursor.execute(query, params) + rows = cursor.fetchall() + + return [ + { + "id": row["id"], + "relation_type": row["relation_type"], + "subject_id": row["subject_id"], + "object_id": row["object_id"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + "evidence": json.loads(row["evidence"]) if row["evidence"] else None, + "notes": row["notes"], + } + for row in rows + ] + + def get_recent_relations(self, limit: int = 50) -> list[dict[str, Any]]: + """Get most recent relations.""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + SELECT * FROM relations + ORDER BY timestamp_s DESC + LIMIT ? + """, + (limit,), + ) + + rows = cursor.fetchall() + return [ + { + "id": row["id"], + "relation_type": row["relation_type"], + "subject_id": row["subject_id"], + "object_id": row["object_id"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + "evidence": json.loads(row["evidence"]) if row["evidence"] else None, + "notes": row["notes"], + } + for row in rows + ] + + # ==================== Distance Operations (Graph 2) ==================== + + def add_distance( + self, + entity_a_id: str, + entity_b_id: str, + distance_meters: float | None, + distance_category: str | None, + confidence: float, + timestamp_s: float, + method: str, + ) -> None: + """ + Add distance measurement between two entities. 
+ + Args: + entity_a_id: First entity ID + entity_b_id: Second entity ID + distance_meters: Distance in meters (can be None if only categorical) + distance_category: Category (near/medium/far) + confidence: Confidence score + timestamp_s: Timestamp of measurement + method: Method used (vlm, depth_estimation, bbox) + """ + conn = self._get_connection() + cursor = conn.cursor() + + # Normalize order to avoid duplicates (store alphabetically) + if entity_a_id > entity_b_id: + entity_a_id, entity_b_id = entity_b_id, entity_a_id + + cursor.execute( + """ + INSERT INTO distances (entity_a_id, entity_b_id, distance_meters, distance_category, + confidence, timestamp_s, method) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + entity_a_id, + entity_b_id, + distance_meters, + distance_category, + confidence, + timestamp_s, + method, + ), + ) + + conn.commit() + logger.debug( + f"Added distance: {entity_a_id} <--> {entity_b_id}: {distance_meters}m ({distance_category})" + ) + + def get_distance( + self, + entity_a_id: str, + entity_b_id: str, + latest_only: bool = True, + ) -> dict[str, Any] | None: + """ + Get distance between two entities. + + Args: + entity_a_id: First entity ID + entity_b_id: Second entity ID + latest_only: If True, return only the most recent measurement + + Returns: + Distance dict or None + """ + conn = self._get_connection() + cursor = conn.cursor() + + # Normalize order + if entity_a_id > entity_b_id: + entity_a_id, entity_b_id = entity_b_id, entity_a_id + + if latest_only: + cursor.execute( + """ + SELECT * FROM distances + WHERE entity_a_id = ? AND entity_b_id = ? + ORDER BY timestamp_s DESC + LIMIT 1 + """, + (entity_a_id, entity_b_id), + ) + else: + cursor.execute( + """ + SELECT * FROM distances + WHERE entity_a_id = ? AND entity_b_id = ? 
+ ORDER BY timestamp_s DESC + """, + (entity_a_id, entity_b_id), + ) + + row = cursor.fetchone() + if row is None: + return None + + return { + "entity_a_id": row["entity_a_id"], + "entity_b_id": row["entity_b_id"], + "distance_meters": row["distance_meters"], + "distance_category": row["distance_category"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + "method": row["method"], + } + + def get_nearby_entities( + self, + entity_id: str, + max_distance: float, + latest_only: bool = True, + ) -> list[dict[str, Any]]: + """ + Find entities within a distance threshold. + + Args: + entity_id: Reference entity ID + max_distance: Maximum distance in meters + latest_only: If True, use only latest measurements + + Returns: + List of nearby entities with distances + """ + conn = self._get_connection() + cursor = conn.cursor() + + if latest_only: + # Get latest distance for each pair + query = """ + SELECT d.*, e.entity_type, e.descriptor + FROM distances d + INNER JOIN entities e ON ( + CASE + WHEN d.entity_a_id = ? THEN e.entity_id = d.entity_b_id + WHEN d.entity_b_id = ? THEN e.entity_id = d.entity_a_id + END + ) + WHERE (d.entity_a_id = ? OR d.entity_b_id = ?) + AND d.distance_meters IS NOT NULL + AND d.distance_meters <= ? + AND d.id IN ( + SELECT MAX(id) FROM distances + WHERE (entity_a_id = d.entity_a_id AND entity_b_id = d.entity_b_id) + GROUP BY entity_a_id, entity_b_id + ) + ORDER BY d.distance_meters ASC + """ + cursor.execute(query, (entity_id, entity_id, entity_id, entity_id, max_distance)) + else: + query = """ + SELECT d.*, e.entity_type, e.descriptor + FROM distances d + INNER JOIN entities e ON ( + CASE + WHEN d.entity_a_id = ? THEN e.entity_id = d.entity_b_id + WHEN d.entity_b_id = ? THEN e.entity_id = d.entity_a_id + END + ) + WHERE (d.entity_a_id = ? OR d.entity_b_id = ?) + AND d.distance_meters IS NOT NULL + AND d.distance_meters <= ? 
+ ORDER BY d.distance_meters ASC + """ + cursor.execute(query, (entity_id, entity_id, entity_id, entity_id, max_distance)) + + rows = cursor.fetchall() + return [ + { + "entity_id": row["entity_b_id"] + if row["entity_a_id"] == entity_id + else row["entity_a_id"], + "entity_type": row["entity_type"], + "descriptor": row["descriptor"], + "distance_meters": row["distance_meters"], + "distance_category": row["distance_category"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + } + for row in rows + ] + + def add_semantic_relation( + self, + relation_type: str, + entity_a_id: str, + entity_b_id: str, + confidence: float, + learned_from: str, + timestamp_s: float, + ) -> None: + """ + Add or update a semantic relation. + + Args: + relation_type: Relation type (goes_with, opposite_of, part_of, used_for) + entity_a_id: First entity ID + entity_b_id: Second entity ID + confidence: Confidence score + learned_from: Source (llm, knowledge_base, observation) + timestamp_s: Timestamp when learned + """ + conn = self._get_connection() + cursor = conn.cursor() + + # Normalize order for symmetric relations + if entity_a_id > entity_b_id: + entity_a_id, entity_b_id = entity_b_id, entity_a_id + + # Check if relation exists + cursor.execute( + """ + SELECT id, observation_count, confidence FROM semantic_relations + WHERE relation_type = ? AND entity_a_id = ? AND entity_b_id = ? + """, + (relation_type, entity_a_id, entity_b_id), + ) + + existing = cursor.fetchone() + + if existing: + # Update existing relation (increase confidence, increment count) + new_count = existing["observation_count"] + 1 + new_confidence = min( + 1.0, existing["confidence"] + 0.1 + ) # Increase confidence with observations + + cursor.execute( + """ + UPDATE semantic_relations + SET last_observed_ts = ?, + observation_count = ?, + confidence = ? + WHERE id = ? 
+ """, + (timestamp_s, new_count, new_confidence, existing["id"]), + ) + else: + # Insert new relation + cursor.execute( + """ + INSERT INTO semantic_relations + (relation_type, entity_a_id, entity_b_id, confidence, learned_from, + first_observed_ts, last_observed_ts, observation_count) + VALUES (?, ?, ?, ?, ?, ?, ?, 1) + """, + ( + relation_type, + entity_a_id, + entity_b_id, + confidence, + learned_from, + timestamp_s, + timestamp_s, + ), + ) + + conn.commit() + logger.debug(f"Added semantic relation: {entity_a_id} --{relation_type}--> {entity_b_id}") + + def get_semantic_relations( + self, + entity_id: str | None = None, + relation_type: str | None = None, + ) -> list[dict[str, Any]]: + """ + Get semantic relations, optionally filtered. + + Args: + entity_id: Optional filter by entity + relation_type: Optional filter by relation type + + Returns: + List of semantic relation dicts + """ + conn = self._get_connection() + cursor = conn.cursor() + + query = "SELECT * FROM semantic_relations WHERE 1=1" + params: list[Any] = [] + + if entity_id: + query += " AND (entity_a_id = ? OR entity_b_id = ?)" + params.extend([entity_id, entity_id]) + + if relation_type: + query += " AND relation_type = ?" + params.append(relation_type) + + query += " ORDER BY confidence DESC, observation_count DESC" + + cursor.execute(query, params) + rows = cursor.fetchall() + + return [ + { + "id": row["id"], + "relation_type": row["relation_type"], + "entity_a_id": row["entity_a_id"], + "entity_b_id": row["entity_b_id"], + "confidence": row["confidence"], + "learned_from": row["learned_from"], + "first_observed_ts": row["first_observed_ts"], + "last_observed_ts": row["last_observed_ts"], + "observation_count": row["observation_count"], + } + for row in rows + ] + + # querying + + def get_entity_neighborhood( + self, + entity_id: str, + max_hops: int = 2, + include_distances: bool = True, + include_semantics: bool = True, + ) -> dict[str, Any]: + """ + Get entity neighborhood (BFS traversal). 
+ + Args: + entity_id: Starting entity ID + max_hops: Maximum number of hops to traverse + include_distances: Include distance graph + include_semantics: Include semantic graph + + Returns: + Dict with entities, relations, distances, and semantics + """ + visited_entities = {entity_id} + current_level = {entity_id} + all_relations = [] + all_distances = [] + all_semantics = [] + + for _ in range(max_hops): + next_level = set() + + for ent_id in current_level: + # Get relations + relations = self.get_relations_for_entity(ent_id) + all_relations.extend(relations) + + for rel in relations: + other_id = ( + rel["object_id"] if rel["subject_id"] == ent_id else rel["subject_id"] + ) + if other_id not in visited_entities: + next_level.add(other_id) + visited_entities.add(other_id) + + # Get distances + if include_distances: + distances = self.get_nearby_entities(ent_id, max_distance=10.0) + all_distances.extend(distances) + for dist in distances: + other_id = dist["entity_id"] + if other_id not in visited_entities: + next_level.add(other_id) + visited_entities.add(other_id) + + # Get semantic relations + if include_semantics: + semantics = self.get_semantic_relations(entity_id=ent_id) + all_semantics.extend(semantics) + for sem in semantics: + other_id = ( + sem["entity_b_id"] + if sem["entity_a_id"] == ent_id + else sem["entity_a_id"] + ) + if other_id not in visited_entities: + next_level.add(other_id) + visited_entities.add(other_id) + + current_level = next_level + if not current_level: + break + + # Get all entity details + entities = [self.get_entity(ent_id) for ent_id in visited_entities] + entities = [e for e in entities if e is not None] + + return { + "center_entity": entity_id, + "entities": entities, + "relations": all_relations, + "distances": all_distances, + "semantic_relations": all_semantics, + "num_hops": max_hops, + } + + def get_stats(self) -> dict[str, Any]: + """Get database statistics.""" + conn = self._get_connection() + cursor = conn.cursor() + + 
cursor.execute("SELECT COUNT(*) as count FROM entities") + entity_count = cursor.fetchone()["count"] + + cursor.execute("SELECT COUNT(*) as count FROM relations") + relation_count = cursor.fetchone()["count"] + + cursor.execute("SELECT COUNT(*) as count FROM distances") + distance_count = cursor.fetchone()["count"] + + cursor.execute("SELECT COUNT(*) as count FROM semantic_relations") + semantic_count = cursor.fetchone()["count"] + + return { + "entities": entity_count, + "relations": relation_count, + "distances": distance_count, + "semantic_relations": semantic_count, + } + + def close(self) -> None: + """Close database connection.""" + if hasattr(self._local, "conn"): + self._local.conn.close() + del self._local.conn diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index 1479b118b5..f77de9193f 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -46,12 +46,18 @@ CLIPFrameFilter, select_diverse_frames_simple, ) +from dimos.perception.entity_graph_db import EntityGraphDB from dimos.perception.videorag_utils import ( apply_summary_update, + build_batch_distance_estimation_prompt, build_query_prompt, build_summary_prompt, build_window_prompt, + default_state, + extract_time_window, + format_timestamp, get_structured_output_format, + parse_batch_distance_response, parse_window_response, update_state_from_window, ) @@ -59,6 +65,126 @@ logger = setup_logger() +# Constants +STALE_SCENE_THRESHOLD = 5.0 # Skip window if scene hasn't changed (perceptual hash distance) +MAX_DISTANCE_PAIRS = 5 # Max entity pairs to estimate distances for per window +MAX_RELATIONS_PER_ENTITY = 10 # Max relations to include in graph context +NEARBY_DISTANCE_METERS = 5.0 # Distance threshold for "nearby" entities +MAX_RECENT_WINDOWS = 50 # Max recent windows to keep in memory + + +# Pure functions +def is_scene_stale(frames: list["Frame"]) -> bool: + """Check if scene hasn't changed meaningfully between first and 
last frame. + + Args: + frames: List of frames to check + + Returns: + True if scene is stale (hasn't changed enough), False otherwise + """ + if len(frames) < 2: + return False + first_img = frames[0].image + last_img = frames[-1].image + if first_img is None or last_img is None: + return False + diff = np.abs(first_img.data.astype(float) - last_img.data.astype(float)) + return diff.mean() < STALE_SCENE_THRESHOLD + + +def build_graph_context( + graph_db: "EntityGraphDB", entity_ids: list[str], time_window_s: float | None = None +) -> dict[str, Any]: + """Build enriched context from graph database for given entities. + + Args: + graph_db: Entity graph database instance + entity_ids: List of entity IDs to get context for + time_window_s: Optional time window in seconds (e.g., 3600 for last hour) + + Returns: + Dictionary with graph context including relationships, distances, and semantics + """ + if not graph_db or not entity_ids: + return {} + + try: + graph_context: dict[str, Any] = { + "relationships": [], + "spatial_info": [], + "semantic_knowledge": [], + } + + # Convert time_window_s to a (start_ts, end_ts) tuple if provided + time_window_tuple = None + if time_window_s is not None: + current_time = time.time() + time_window_tuple = (current_time - time_window_s, current_time) + + # Get recent relationships for each entity + for entity_id in entity_ids: + # Get relationships (Graph 1: interactions) + relations = graph_db.get_relations_for_entity( + entity_id=entity_id, + relation_type=None, # all types + time_window=time_window_tuple, + ) + for rel in relations[-MAX_RELATIONS_PER_ENTITY:]: + graph_context["relationships"].append( + { + "subject": rel["subject_id"], + "relation": rel["relation_type"], + "object": rel["object_id"], + "confidence": rel["confidence"], + "when": rel["timestamp_s"], + } + ) + + # Get spatial relationships (Graph 2: distances) + # Find entities nearby this one + nearby = graph_db.get_nearby_entities( + entity_id=entity_id, 
max_distance=NEARBY_DISTANCE_METERS, latest_only=True + ) + for dist in nearby: + graph_context["spatial_info"].append( + { + "entity_a": dist["entity_a"], + "entity_b": dist["entity_b"], + "distance": dist.get("distance_m"), + "category": dist.get("distance_category"), + "confidence": dist["confidence"], + } + ) + + # Get semantic knowledge (Graph 3: conceptual relations) + semantic_rels = graph_db.get_semantic_relations( + entity_id=entity_id, + relation_type=None, # all types + ) + for sem in semantic_rels: + graph_context["semantic_knowledge"].append( + { + "entity_a": sem["entity_a"], + "relation": sem["relation_type"], + "entity_b": sem["entity_b"], + "confidence": sem["confidence"], + "observations": sem["observation_count"], + } + ) + + # Get graph statistics for context + if entity_ids: + stats = graph_db.get_stats() + graph_context["total_entities"] = stats.get("total_entities", 0) + graph_context["total_relations"] = stats.get("total_relations", 0) + + return graph_context + + except Exception as e: + logger.warning(f"failed to build graph context: {e}") + return {} + @dataclass class Frame: @@ -82,16 +208,6 @@ class TemporalMemoryConfig(ModuleConfig): clip_model: str = "ViT-B/32" -def default_state() -> dict[str, Any]: - return { - "entity_roster": [], - "rolling_summary": "", - "chunk_buffer": [], - "next_summary_at_s": 0.0, - "last_present": [], - } - - class TemporalMemory(SkillModule): """ builds temporal understanding of video streams using vlms. 
@@ -113,12 +229,13 @@ def __init__( # single lock protects all state self._state_lock = threading.Lock() + self._stopped = False # protected state self._state = default_state() self._state["next_summary_at_s"] = float(self.config.summary_interval_s) self._frame_buffer: deque[Frame] = deque(maxlen=self.config.frame_buffer_size) - self._recent_windows: deque[dict[str, Any]] = deque(maxlen=50) + self._recent_windows: deque[dict[str, Any]] = deque(maxlen=MAX_RECENT_WINDOWS) self._frame_count = 0 self._last_analysis_time = 0.0 self._video_start_wall_time: float | None = None @@ -144,13 +261,13 @@ def __init__( self._state_file = self._output_path / "state.json" self._entities_file = self._output_path / "entities.json" self._frames_index_file = self._output_path / "frames_index.jsonl" - logger.info(f"artifacts save to: {self._output_path}") - # else: - # self._output_path = None - # # frames directory for saving images - # self._frames_dir = Path("temporal_memory_frames") - # self._frames_dir.mkdir(parents=True, exist_ok=True) + # Initialize entity graph database + self._graph_db = EntityGraphDB(db_path=self._output_path / "entity_graph.db") + + logger.info(f"artifacts save to: {self._output_path}") + else: + self._graph_db = None logger.info( f"temporalmemory init: fps={self.config.fps}, " @@ -179,6 +296,7 @@ def start(self) -> None: super().start() with self._state_lock: + self._stopped = False if self._video_start_wall_time is None: self._video_start_wall_time = time.time() @@ -197,14 +315,6 @@ def on_frame(image: Image) -> None: ) self._frame_buffer.append(frame) - # Save image to frames directory - # frame_filename = f"frame_{self._frame_count:06d}_{image.frame_id or 'unknown'}.jpg" - # frame_path = self._frames_dir / frame_filename - # try: - # image.save(str(frame_path)) - # except Exception as e: - # logger.warning(f"Failed to save frame {self._frame_count}: {e}") - self._frame_count += 1 # pipe through sharpness filter before buffering @@ -224,10 +334,32 @@ def 
on_frame(image: Image) -> None: @rpc def stop(self) -> None: - self.save_state() - self.save_entities() + # Save state before clearing (bypass _stopped check by saving directly) + if self.config.output_dir: + try: + with self._state_lock: + state_copy = self._state.copy() + entity_roster = list(self._state.get("entity_roster", [])) + with open(self._state_file, "w") as f: + json.dump(state_copy, f, indent=2, ensure_ascii=False) + logger.info(f"saved state to {self._state_file}") + with open(self._entities_file, "w") as f: + json.dump(entity_roster, f, indent=2, ensure_ascii=False) + logger.info(f"saved {len(entity_roster)} entities") + except Exception as e: + logger.error(f"save failed during stop: {e}", exc_info=True) + self.save_frames_index() + # Set stopped flag and clear state + with self._state_lock: + self._stopped = True + + # Close graph database + if self._graph_db: + self._graph_db.close() + self._graph_db = None + if self._clip_filter: self._clip_filter.close() self._clip_filter = None @@ -248,22 +380,6 @@ def stop(self) -> None: logger.info("temporalmemory stopped") - def _format_timestamp(self, seconds: float) -> str: - m = int(seconds // 60) - s = seconds - 60 * m - return f"{m:02d}:{s:06.3f}" - - def _is_scene_stale(self, frames: list[Frame]) -> bool: - """skip if scene hasn't changed meaningfully""" - if len(frames) < 2: - return False - first_img = frames[0].image - last_img = frames[-1].image - if first_img is None or last_img is None: - return False - diff = np.abs(first_img.data.astype(float) - last_img.data.astype(float)) - return diff.mean() < 5.0 # tune this threshold - def _analyze_window(self) -> None: try: # get snapshot @@ -281,16 +397,16 @@ def _analyze_window(self) -> None: window_frames = list(self._frame_buffer)[-frames_needed:] state_snapshot = self._state.copy() + w_start = window_frames[0].timestamp_s + w_end = window_frames[-1].timestamp_s + # add this check early, before any filtering or VLM calls - if 
self._is_scene_stale(window_frames): + if is_scene_stale(window_frames): logger.debug(f"skipping stale window [{w_start:.1f}-{w_end:.1f}s]") with self._state_lock: self._last_analysis_time = w_end return - w_start = window_frames[0].timestamp_s - w_end = window_frames[-1].timestamp_s - # filter frames # NOTE: no longer using clip filter for now (alternative: sharpness barrier and stale scene check) # if len(window_frames) > self.config.max_frames_per_window: @@ -326,10 +442,6 @@ def _analyze_window(self) -> None: ) # query_batch returns list[str] with one response for all images response_text = responses[0] if responses else "" - - # TODO: clear image data from analyzed frames & only keep metadata if the frame_buffer is still too big - for frame in window_frames: - frame.image = None else: # Single frame - use regular query response_text = self.vlm.query( @@ -349,6 +461,14 @@ def _analyze_window(self) -> None: else: logger.info(f"parsed. caption: {parsed.get('caption', '')[:100]}") + # estimate distances between entities + # note: batched into single VLM call for efficiency + if self._graph_db and len(window_frames) > 0: + # use the middle frame from the window for distance estimation + mid_frame = window_frames[len(window_frames) // 2] + if mid_frame.image: + self._estimate_distances(parsed, mid_frame.image, w_end) + # update state with self._state_lock: needs_summary = update_state_from_window( @@ -357,6 +477,10 @@ def _analyze_window(self) -> None: self._recent_windows.append(parsed) self._last_analysis_time = w_end + # save to graph database + if self._graph_db: + self._save_to_graph_db(parsed, w_end) + # save evidence if self.config.output_dir: self._append_evidence(parsed) @@ -381,6 +505,8 @@ def _update_rolling_summary(self, w_end: float) -> None: try: # get state with self._state_lock: + if self._stopped: + return rolling_summary = str(self._state.get("rolling_summary", "")) chunk_buffer = list(self._state.get("chunk_buffer", [])) if self._frame_buffer: @@ 
-402,35 +528,115 @@ def _update_rolling_summary(self, w_end: float) -> None: summary_text = self.vlm.query(latest_frame, prompt) if summary_text and summary_text.strip(): with self._state_lock: + if self._stopped: + return apply_summary_update( self._state, summary_text, w_end, self.config.summary_interval_s ) logger.info(f"updated summary: {summary_text[:100]}...") + # Save state after summary update to persist entities + if self.config.output_dir and not self._stopped: + self.save_state() + self.save_entities() except Exception as e: logger.error(f"summary update failed: {e}", exc_info=True) except Exception as e: logger.error(f"error updating summary: {e}", exc_info=True) + def _estimate_distances(self, parsed: dict[str, Any], frame_image, timestamp_s: float) -> None: + """Estimate distances between entities using VLM and save to graph database. + + Batches all entity pairs into a single VLM call for efficiency. + + Args: + parsed: Parsed window data containing entities + frame_image: Frame image to analyze + timestamp_s: Timestamp for the distance observations + """ + if not self._graph_db or not frame_image: + return + + # Collect all entities present in this window + all_entities = [] + for entity in parsed.get("new_entities", []): + all_entities.append(entity) + for entity in parsed.get("entities_present", []): + if isinstance(entity, dict) and "id" in entity: + all_entities.append(entity) + + # Need at least 2 entities to estimate distances + if len(all_entities) < 2: + return + + # Generate entity pairs (avoid duplicates by only doing i < j) + entity_pairs = [] + for i in range(len(all_entities)): + for j in range(i + 1, len(all_entities)): + entity_pairs.append((all_entities[i], all_entities[j])) + + # Limit to avoid excessive prompt length + entity_pairs = entity_pairs[:MAX_DISTANCE_PAIRS] + + if not entity_pairs: + return + + try: + # Build batched prompt for all pairs + prompt = build_batch_distance_estimation_prompt(entity_pairs) + + # Single VLM call 
for all pairs + response = self.vlm.query(frame_image, prompt) + + # Parse all distance estimates + results = parse_batch_distance_response(response, entity_pairs) + + # Save all valid distances to database + for result in results: + if result["category"] in ["near", "medium", "far"]: + self._graph_db.add_distance( + entity_a=result["entity_a_id"], + entity_b=result["entity_b_id"], + distance_m=result.get("distance_m"), + distance_category=result["category"], + confidence=result.get("confidence", 0.5), + timestamp_s=timestamp_s, + method="vlm_estimation_batch", + ) + logger.debug( + f"estimated distance {result['entity_a_id']}-{result['entity_b_id']}: " + f"{result['category']} ({result.get('distance_m')}m)" + ) + + except Exception as e: + logger.warning(f"failed to estimate distances: {e}") + @skill() def query(self, question: str) -> str: - """Answer a question about the video stream using temporal memory. + """Answer a question about the video stream using temporal memory and graph knowledge. This skill analyzes the current video stream and temporal memory state to answer questions about what is happening, what entities are present, - and recent events. + recent events, spatial relationships, and conceptual knowledge. + + The system automatically accesses three knowledge graphs: + - Interactions: relationships between entities (holds, looks_at, talks_to) + - Spatial: distance and proximity information + - Semantic: conceptual relationships (goes_with, used_for, etc.) Example: query("What entities are currently visible?") - query("Do you see a wall in the video stream?") + query("What did I do last week?") + query("Where did I leave my keys?") + query("What objects are near the person?") Args: question (str): The question to ask about the video stream. Examples: "What entities are visible?", "What happened recently?", - "Is there a person in the scene?" + "Is there a person in the scene?", "What am I holding?" 
Returns: - str: Answer to the question based on temporal memory and current video frame. + str: Answer based on temporal memory, graph knowledge, and current frame. """ # read state with self._state_lock: @@ -443,7 +649,7 @@ def query(self, question: str) -> str: if not latest_frame: return "no frames available" - # build context + # build context from temporal state currently_present = {e["id"] for e in last_present if isinstance(e, dict) and "id" in e} for window in recent_windows[-3:]: for entity in window.get("entities_present", []): @@ -458,6 +664,18 @@ def query(self, question: str) -> str: "timestamp": time.time(), } + # enhance context with graph database knowledge + if self._graph_db and currently_present: + # Extract time window from question using VLM + time_window_s = extract_time_window(question, self.vlm, latest_frame) + + graph_context = build_graph_context( + graph_db=self._graph_db, + entity_ids=list(currently_present), + time_window_s=time_window_s, + ) + context["graph_knowledge"] = graph_context + # build query prompt using videorag utils prompt = build_query_prompt(question=question, context=context) @@ -501,12 +719,36 @@ def get_rolling_summary(self) -> str: with self._state_lock: return str(self._state.get("rolling_summary", "")) + @rpc + def get_graph_db_stats(self) -> dict[str, Any]: + """Get statistics and sample data from the graph database.""" + if not self._graph_db: + logger.warning("graph database not initialized") + return {"stats": {}, "entities": [], "recent_relations": []} + + try: + stats = self._graph_db.get_stats() + all_entities = self._graph_db.get_all_entities() + recent_relations = self._graph_db.get_recent_relations(limit=10) + + return { + "stats": stats, + "entities": all_entities, + "recent_relations": recent_relations[:5], # just show top 5 + } + except Exception as e: + logger.error(f"failed to get graph db stats: {e}", exc_info=True) + return {"stats": {}, "entities": [], "recent_relations": []} + @rpc def 
save_state(self) -> bool: if not self.config.output_dir: return False try: with self._state_lock: + # Don't save if stopped (state has been cleared) + if self._stopped: + return False state_copy = self._state.copy() with open(self._state_file, "w") as f: json.dump(state_copy, f, indent=2, ensure_ascii=False) @@ -523,11 +765,85 @@ def _append_evidence(self, evidence: dict[str, Any]) -> None: except Exception as e: logger.error(f"append evidence failed: {e}") + def _save_to_graph_db(self, parsed: dict[str, Any], timestamp_s: float) -> None: + """Save parsed window data to the entity graph database.""" + if not self._graph_db: + return + + try: + # Save new entities + for entity in parsed.get("new_entities", []): + self._graph_db.upsert_entity( + entity_id=entity["id"], + entity_type=entity["type"], + descriptor=entity["descriptor"], + timestamp_s=timestamp_s, + ) + + # Save existing entities (update last_seen) + for entity in parsed.get("entities_present", []): + if isinstance(entity, dict) and "id" in entity: + # Only update with descriptor if we have one, otherwise pass empty to preserve existing + descriptor = entity.get("descriptor") + if descriptor: + self._graph_db.upsert_entity( + entity_id=entity["id"], + entity_type=entity.get("type", "unknown"), + descriptor=descriptor, + timestamp_s=timestamp_s, + ) + else: + # Just update last_seen without changing descriptor + # Get existing entity to preserve its descriptor + existing = self._graph_db.get_entity(entity["id"]) + if existing: + self._graph_db.upsert_entity( + entity_id=entity["id"], + entity_type=existing["entity_type"], + descriptor=existing["descriptor"], + timestamp_s=timestamp_s, + ) + + # Save relations + for relation in parsed.get("relations", []): + # Parse subject/object (format: "E1|person" or just "E1") + subject_id = ( + relation["subject"].split("|")[0] + if "|" in relation["subject"] + else relation["subject"] + ) + object_id = ( + relation["object"].split("|")[0] + if "|" in 
relation["object"] + else relation["object"] + ) + + self._graph_db.add_relation( + relation_type=relation["type"], + subject_id=subject_id, + object_id=object_id, + confidence=relation.get("confidence", 1.0), + timestamp_s=timestamp_s, + evidence=relation.get("evidence"), + notes=relation.get("notes"), + ) + + logger.debug( + f"Saved window data to graph DB: {len(parsed.get('new_entities', []))} new entities, " + f"{len(parsed.get('relations', []))} relations" + ) + + except Exception as e: + logger.error(f"Failed to save to graph DB: {e}", exc_info=True) + def save_entities(self) -> bool: if not self.config.output_dir: return False try: with self._state_lock: + # Don't save if stopped (state has been cleared) + if self._stopped: + return False entity_roster = list(self._state.get("entity_roster", [])) with open(self._entities_file, "w") as f: json.dump(entity_roster, f, indent=2, ensure_ascii=False) @@ -548,7 +864,7 @@ def save_frames_index(self) -> bool: { "frame_index": f.frame_index, "timestamp_s": f.timestamp_s, - "timestamp": self._format_timestamp(f.timestamp_s), + "timestamp": format_timestamp(f.timestamp_s), } for f in frames ] @@ -564,19 +880,25 @@ def save_frames_index(self) -> bool: return False +# Backwards compatibility: import deploy from separate module +# from dimos.perception.temporal_memory_deploy import deploy def deploy( dimos: DimosCluster, camera: spec.Camera, vlm: VlModel | None = None, config: TemporalMemoryConfig | None = None, ) -> TemporalMemory: - """ - Deploy TemporalMemory with a camera. + """Deploy TemporalMemory with a camera. 
+ + Args: + dimos: Dimos cluster instance + camera: Camera module to connect to + vlm: Optional VLM instance (creates OpenAI VLM if None) + config: Optional temporal memory configuration """ if vlm is None: from dimos.models.vl.openai import OpenAIVlModel - # Load API key from environment api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise ValueError("OPENAI_API_KEY environment variable not set") diff --git a/dimos/perception/temporal_memory_example.py b/dimos/perception/temporal_memory_example.py index 59f669e758..a455b7d7c8 100644 --- a/dimos/perception/temporal_memory_example.py +++ b/dimos/perception/temporal_memory_example.py @@ -105,6 +105,22 @@ def example_usage(): for entity in entities: print(f" {entity['id']}: {entity['descriptor']}") + # Check graph database stats + graph_stats = temporal_memory.get_graph_db_stats() + print("\n=== Graph Database Stats ===") + if "error" in graph_stats: + print(f"Error: {graph_stats['error']}") + else: + print(f"Stats: {graph_stats['stats']}") + print(f"\nEntities in DB ({len(graph_stats['entities'])}):") + for entity in graph_stats["entities"]: + print(f" {entity['entity_id']} ({entity['entity_type']}): {entity['descriptor']}") + print(f"\nRecent relations ({len(graph_stats['recent_relations'])}):") + for rel in graph_stats["recent_relations"]: + print( + f" {rel['subject_id']} --{rel['relation_type']}--> {rel['object_id']} (confidence: {rel['confidence']:.2f})" + ) + # Stop when done temporal_memory.stop() camera.stop() diff --git a/dimos/perception/videorag_utils.py b/dimos/perception/videorag_utils.py index 393e79cb3e..ab6df13025 100644 --- a/dimos/perception/videorag_utils.py +++ b/dimos/perception/videorag_utils.py @@ -20,9 +20,26 @@ """ import json +import re from typing import Any +from dimos.models.vl.base import VlModel +from dimos.msgs.sensor_msgs import Image from dimos.utils.llm_utils import extract_json +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +def 
default_state() -> dict[str, Any]: + """Create default temporal memory state dictionary.""" + return { + "entity_roster": [], + "rolling_summary": "", + "chunk_buffer": [], + "next_summary_at_s": 0.0, + "last_present": [], + } def next_entity_id_hint(roster: Any) -> str: @@ -48,6 +65,13 @@ def clamp_text(text: str, max_chars: int) -> str: return text[:max_chars] + "..." +def format_timestamp(seconds: float) -> str: + """Format seconds as MM:SS.mmm timestamp string.""" + m = int(seconds // 60) + s = seconds - 60 * m + return f"{m:02d}:{s:06.3f}" + + def build_window_prompt( *, w_start: float, @@ -249,6 +273,13 @@ def build_query_prompt( Returns: Formatted prompt string """ + currently_present = context.get("currently_present_entities", []) + currently_present_str = ( + f"Entities recently detected in recent windows: {currently_present}" + if currently_present + else "No entities were detected in recent windows (list is empty)" + ) + prompt = f"""Answer the following question about the video stream using the provided context. **Question:** {question} @@ -256,18 +287,252 @@ def build_query_prompt( **Context:** {json.dumps(context, indent=2, ensure_ascii=False)} -**Instructions:** +**Important Notes:** - Entities have stable IDs like E1, E2, etc. 
-- The 'currently_present_entities' list shows which entities are visible now -- If an entity is NOT in 'currently_present_entities', it is no longer visible -- Answer based ONLY on the provided context -- If information isn't available, say so clearly +- The 'currently_present_entities' list contains entity IDs that were detected in recent video windows (not necessarily in the current frame you're viewing) +- {currently_present_str} +- The 'entity_roster' contains all known entities with their descriptions +- The 'rolling_summary' describes what has happened over time +- If 'currently_present_entities' is empty, it means no entities were detected in recent windows, but entities may still exist in the roster from earlier +- Answer based on the provided context (entity_roster, rolling_summary, currently_present_entities) AND what you see in the current frame +- If the context says entities were present but you don't see them in the current frame, mention both: what was recently detected AND what you currently see Provide a concise answer. """ return prompt +def extract_time_window( + question: str, + vlm: VlModel, + latest_frame: Image | None = None, +) -> float | None: + """Extract time window from question using VLM with example-based learning. + + Uses a few example keywords as patterns, then asks VLM to extrapolate + similar time references and return seconds. 
+ + Args: + question: User's question + vlm: VLM instance to use for extraction + latest_frame: Optional frame (required for VLM call, but image is ignored) + + Returns: + Time window in seconds, or None if no time reference found + """ + question_lower = question.lower() + + # Quick check for common patterns (fast path) + if "last week" in question_lower or "past week" in question_lower: + return 7 * 24 * 3600 + if "today" in question_lower or "last hour" in question_lower: + return 3600 + if "recently" in question_lower or "recent" in question_lower: + return 600 + + # Use VLM to extract time reference from question + # Provide examples and let VLM extrapolate similar patterns + # Note: latest_frame is required by VLM interface but image content is ignored + if not latest_frame: + return None + + extraction_prompt = f"""Extract any time reference from this question and convert it to seconds. + +Question: {question} + +Examples of time references and their conversions: +- "last week" or "past week" -> 604800 seconds (7 days) +- "yesterday" -> 86400 seconds (1 day) +- "today" or "last hour" -> 3600 seconds (1 hour) +- "recently" or "recent" -> 600 seconds (10 minutes) +- "few minutes ago" -> 300 seconds (5 minutes) +- "just now" -> 60 seconds (1 minute) + +Extrapolate similar patterns (e.g., "2 days ago", "this morning", "last month", etc.) +and convert to seconds. If no time reference is found, return "none". + +Return ONLY a number (seconds) or the word "none". 
Do not include any explanation.""" + + try: + response = vlm.query(latest_frame, extraction_prompt) + response = response.strip().lower() + + if "none" in response or not response: + return None + + # Extract number from response + numbers = re.findall(r"\d+(?:\.\d+)?", response) + if numbers: + seconds = float(numbers[0]) + # Sanity check: reasonable time windows (1 second to 1 year) + if 1 <= seconds <= 365 * 24 * 3600: + return seconds + except Exception as e: + logger.debug(f"Time extraction failed: {e}") + + return None + + +def build_distance_estimation_prompt( + *, + entity_a_descriptor: str, + entity_a_id: str, + entity_b_descriptor: str, + entity_b_id: str, +) -> str: + """ + Build prompt for estimating distance between two entities. + + Args: + entity_a_descriptor: Description of first entity + entity_a_id: ID of first entity + entity_b_descriptor: Description of second entity + entity_b_id: ID of second entity + + Returns: + Formatted prompt string for distance estimation + """ + prompt = f"""Look at this image and estimate the distance between these two entities: + +Entity A: {entity_a_descriptor} (ID: {entity_a_id}) +Entity B: {entity_b_descriptor} (ID: {entity_b_id}) + +Provide: +1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m) +2. Approximate distance in meters (best guess) +3. Confidence: 0.0-1.0 (how certain are you?) + +Respond in this format: +category: [near/medium/far] +distance_m: [number] +confidence: [0.0-1.0] +reasoning: [brief explanation]""" + return prompt + + +def build_batch_distance_estimation_prompt( + entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]], +) -> str: + """ + Build prompt for estimating distances between multiple entity pairs in one call. 
+ + Args: + entity_pairs: List of (entity_a, entity_b) tuples, each entity is a dict with 'id' and 'descriptor' + + Returns: + Formatted prompt string for batched distance estimation + """ + pairs_text = [] + for i, (entity_a, entity_b) in enumerate(entity_pairs, 1): + pairs_text.append( + f"Pair {i}:\n" + f" Entity A: {entity_a['descriptor']} (ID: {entity_a['id']})\n" + f" Entity B: {entity_b['descriptor']} (ID: {entity_b['id']})" + ) + + prompt = f"""Look at this image and estimate the distances between the following entity pairs: + +{chr(10).join(pairs_text)} + +For each pair, provide: +1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m) +2. Approximate distance in meters (best guess) +3. Confidence: 0.0-1.0 (how certain are you?) + +Respond in this format (one block per pair): +Pair 1: +category: [near/medium/far] +distance_m: [number] +confidence: [0.0-1.0] + +Pair 2: +category: [near/medium/far] +distance_m: [number] +confidence: [0.0-1.0] + +(etc.)""" + return prompt + + +def parse_batch_distance_response( + response: str, entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]] +) -> list[dict[str, Any]]: + """ + Parse batched distance estimation response. 
+ + Args: + response: VLM response text + entity_pairs: Original entity pairs used in the prompt + + Returns: + List of dicts with keys: entity_a_id, entity_b_id, category, distance_m, confidence + """ + results = [] + lines = response.strip().split("\n") + + current_pair_idx = None + category = None + distance_m = None + confidence = 0.5 + + for line in lines: + line = line.strip() + + # Check for pair marker + if line.startswith("Pair "): + # Save previous pair if exists + if current_pair_idx is not None and category: + entity_a, entity_b = entity_pairs[current_pair_idx] + results.append( + { + "entity_a_id": entity_a["id"], + "entity_b_id": entity_b["id"], + "category": category, + "distance_m": distance_m, + "confidence": confidence, + } + ) + + # Start new pair + try: + pair_num = int(line.split()[1].rstrip(":")) + current_pair_idx = pair_num - 1 # Convert to 0-indexed + category = None + distance_m = None + confidence = 0.5 + except (IndexError, ValueError): + continue + + # Parse distance fields + elif line.startswith("category:"): + category = line.split(":", 1)[1].strip().lower() + elif line.startswith("distance_m:"): + try: + distance_m = float(line.split(":", 1)[1].strip()) + except (ValueError, IndexError): + pass + elif line.startswith("confidence:"): + try: + confidence = float(line.split(":", 1)[1].strip()) + except (ValueError, IndexError): + pass + + # Save last pair + if current_pair_idx is not None and category and current_pair_idx < len(entity_pairs): + entity_a, entity_b = entity_pairs[current_pair_idx] + results.append( + { + "entity_a_id": entity_a["id"], + "entity_b_id": entity_b["id"], + "category": category, + "distance_m": distance_m, + "confidence": confidence, + } + ) + + return results + + def parse_window_response( response_text: str, w_start: float, w_end: float, frame_count: int ) -> dict[str, Any]: @@ -446,10 +711,14 @@ def get_structured_output_format() -> dict[str, Any]: __all__ = [ "WINDOW_RESPONSE_SCHEMA", 
"apply_summary_update", + "build_distance_estimation_prompt", "build_query_prompt", "build_summary_prompt", "build_window_prompt", "clamp_text", + "default_state", + "extract_time_window", + "format_timestamp", "get_structured_output_format", "next_entity_id_hint", "parse_window_response", From 3a7003942e9cb46e523532ae28d637700c140845 Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Tue, 13 Jan 2026 18:44:36 -0800 Subject: [PATCH 11/21] db implementation, working and stylized, best reply is unitree_go2_office_walk2 --- dimos/models/vl/qwen.py | 6 +- dimos/perception/clip_filter.py | 65 +- dimos/perception/entity_graph_db.py | 279 ++++++- dimos/perception/temporal_memory.py | 669 +++++----------- dimos/perception/temporal_memory_deploy.py | 59 ++ dimos/perception/temporal_memory_example.py | 12 +- dimos/perception/temporal_utils/__init__.py | 60 ++ .../perception/temporal_utils/graph_utils.py | 206 +++++ dimos/perception/temporal_utils/helpers.py | 72 ++ dimos/perception/temporal_utils/parsers.py | 156 ++++ dimos/perception/temporal_utils/prompts.py | 353 +++++++++ dimos/perception/temporal_utils/state.py | 139 ++++ dimos/perception/videorag_utils.py | 726 ------------------ dimos/robot/unitree/connection/go2.py | 3 +- 14 files changed, 1557 insertions(+), 1248 deletions(-) create mode 100644 dimos/perception/temporal_memory_deploy.py create mode 100644 dimos/perception/temporal_utils/__init__.py create mode 100644 dimos/perception/temporal_utils/graph_utils.py create mode 100644 dimos/perception/temporal_utils/helpers.py create mode 100644 dimos/perception/temporal_utils/parsers.py create mode 100644 dimos/perception/temporal_utils/prompts.py create mode 100644 dimos/perception/temporal_utils/state.py delete mode 100644 dimos/perception/videorag_utils.py diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index c508b8cc72..2b3808211b 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -88,10 +88,10 @@ def query_batch( return [] 
content: list[dict[str, Any]] = [ - { - "type": "image_url", + { + "type": "image_url", "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"}, - } + } for img in images ] content.append({"type": "text", "text": query}) diff --git a/dimos/perception/clip_filter.py b/dimos/perception/clip_filter.py index 24e43df1e2..34f9da3912 100644 --- a/dimos/perception/clip_filter.py +++ b/dimos/perception/clip_filter.py @@ -22,6 +22,8 @@ import logging from typing import Any +import numpy as np + from dimos.msgs.sensor_msgs import Image from dimos.utils.logging_config import setup_logger @@ -175,4 +177,65 @@ def select_diverse_frames_simple(frames: list[Any], max_frames: int = 3) -> list return [frames[i] for i in indices] -__all__ = ["CLIP_AVAILABLE", "CLIPFrameFilter", "select_diverse_frames_simple"] +def adaptive_keyframes( + frames: list, + min_frames: int = 3, + max_frames: int = 5, + change_threshold: float = 15.0, +) -> list: + """select frames based on visual change, adaptive count.""" + if len(frames) <= min_frames: + return frames + + # compute frame-to-frame differences + diffs = [] + for i in range(1, len(frames)): + prev = frames[i - 1].image.data.astype(float) + curr = frames[i].image.data.astype(float) + diffs.append(np.abs(curr - prev).mean()) + + total_motion = sum(diffs) + + # adaptive N: more motion → more frames + n_frames = int(np.clip(total_motion / change_threshold, min_frames, max_frames)) + + # pick frames at change peaks (local maxima) + # always include first and last + keyframe_indices = [0, len(frames) - 1] # always + + # find peaks in diff signal + for i in range(1, len(diffs) - 1): + if ( + diffs[i] > diffs[i - 1] + and diffs[i] > diffs[i + 1] + and diffs[i] > change_threshold * 0.5 + ): + keyframe_indices.append(i + 1) # +1 bc diff[i] is between frame i and i+1 + + # if too many peaks, subsample; if too few, add uniform samples + if len(keyframe_indices) > n_frames: + # keep first, last, and highest-diff peaks + 
middle_indices = [i for i in keyframe_indices if i not in (0, len(frames) - 1)] + middle_diffs = [diffs[i - 1] for i in middle_indices] + sorted_by_diff = sorted(zip(middle_diffs, middle_indices, strict=False), reverse=True) + keep = [idx for _, idx in sorted_by_diff[: n_frames - 2]] + keyframe_indices = sorted([0, *keep, len(frames) - 1]) + elif len(keyframe_indices) < n_frames: + # fill in uniformly from remaining candidates + needed = n_frames - len(keyframe_indices) + candidates = sorted(set(range(len(frames))) - set(keyframe_indices)) + if candidates: + # Calculate step, ensuring it's at least 1 + step = max(1, len(candidates) // (needed + 1)) + uniform_fill = candidates[::step][:needed] + keyframe_indices = sorted(set(keyframe_indices) | set(uniform_fill)) + + return [frames[i] for i in keyframe_indices] + + +__all__ = [ + "CLIP_AVAILABLE", + "CLIPFrameFilter", + "adaptive_keyframes", + "select_diverse_frames_simple", +] diff --git a/dimos/perception/entity_graph_db.py b/dimos/perception/entity_graph_db.py index 41fde5554a..414a2d66b7 100644 --- a/dimos/perception/entity_graph_db.py +++ b/dimos/perception/entity_graph_db.py @@ -27,10 +27,14 @@ from pathlib import Path import sqlite3 import threading -from typing import Any +from typing import TYPE_CHECKING, Any from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from dimos.models.vl.base import VlModel + from dimos.msgs.sensor_msgs import Image + logger = setup_logger() @@ -83,6 +87,12 @@ def _init_schema(self) -> None: metadata TEXT ) """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_entities_first_seen ON entities(first_seen_ts)" + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_entities_last_seen ON entities(last_seen_ts)" + ) # Relations table (Graph 1: Interactions) cursor.execute(""" @@ -218,9 +228,7 @@ def get_entity(self, entity_id: str) -> dict[str, Any] | None: } def get_all_entities(self, entity_type: str | None = None) -> list[dict[str, Any]]: - """ - Get all 
entities, optionally filtered by type. - """ + """Get all entities, optionally filtered by type.""" conn = self._get_connection() cursor = conn.cursor() @@ -245,6 +253,42 @@ def get_all_entities(self, entity_type: str | None = None) -> list[dict[str, Any for row in rows ] + def get_entities_by_time( + self, + time_window: tuple[float, float], + first_seen: bool = True, + ) -> list[dict[str, Any]]: + """Get entities first/last seen within a time window. + + Args: + time_window: (start_ts, end_ts) tuple in seconds + first_seen: If True, filter by first_seen_ts. If False, filter by last_seen_ts. + + Returns: + List of entities seen within the time window + """ + conn = self._get_connection() + cursor = conn.cursor() + + ts_field = "first_seen_ts" if first_seen else "last_seen_ts" + cursor.execute( + f"SELECT * FROM entities WHERE {ts_field} BETWEEN ? AND ? ORDER BY {ts_field} DESC", + time_window, + ) + + rows = cursor.fetchall() + return [ + { + "entity_id": row["entity_id"], + "entity_type": row["entity_type"], + "descriptor": row["descriptor"], + "first_seen_ts": row["first_seen_ts"], + "last_seen_ts": row["last_seen_ts"], + "metadata": json.loads(row["metadata"]) if row["metadata"] else None, + } + for row in rows + ] + def add_relation( self, relation_type: str, @@ -422,15 +466,12 @@ def get_distance( self, entity_a_id: str, entity_b_id: str, - latest_only: bool = True, ) -> dict[str, Any] | None: - """ - Get distance between two entities. + """Get most recent distance between two entities. Args: entity_a_id: First entity ID entity_b_id: Second entity ID - latest_only: If True, return only the most recent measurement Returns: Distance dict or None @@ -442,25 +483,15 @@ def get_distance( if entity_a_id > entity_b_id: entity_a_id, entity_b_id = entity_b_id, entity_a_id - if latest_only: - cursor.execute( - """ - SELECT * FROM distances - WHERE entity_a_id = ? AND entity_b_id = ? 
- ORDER BY timestamp_s DESC - LIMIT 1 - """, - (entity_a_id, entity_b_id), - ) - else: - cursor.execute( - """ - SELECT * FROM distances - WHERE entity_a_id = ? AND entity_b_id = ? - ORDER BY timestamp_s DESC - """, - (entity_a_id, entity_b_id), - ) + cursor.execute( + """ + SELECT * FROM distances + WHERE entity_a_id = ? AND entity_b_id = ? + ORDER BY timestamp_s DESC + LIMIT 1 + """, + (entity_a_id, entity_b_id), + ) row = cursor.fetchone() if row is None: @@ -476,6 +507,49 @@ def get_distance( "method": row["method"], } + def get_distance_history( + self, + entity_a_id: str, + entity_b_id: str, + ) -> list[dict[str, Any]]: + """Get all distance measurements between two entities. + + Args: + entity_a_id: First entity ID + entity_b_id: Second entity ID + + Returns: + List of distance dicts, most recent first + """ + conn = self._get_connection() + cursor = conn.cursor() + + # Normalize order + if entity_a_id > entity_b_id: + entity_a_id, entity_b_id = entity_b_id, entity_a_id + + cursor.execute( + """ + SELECT * FROM distances + WHERE entity_a_id = ? AND entity_b_id = ? 
+ ORDER BY timestamp_s DESC + """, + (entity_a_id, entity_b_id), + ) + + return [ + { + "entity_a_id": row["entity_a_id"], + "entity_b_id": row["entity_b_id"], + "distance_meters": row["distance_meters"], + "distance_category": row["distance_category"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + "method": row["method"], + } + for row in cursor.fetchall() + ] + def get_nearby_entities( self, entity_id: str, @@ -786,6 +860,157 @@ def get_stats(self) -> dict[str, Any]: "semantic_relations": semantic_count, } + def get_summary(self, recent_relations_limit: int = 5) -> dict[str, Any]: + """Get stats, all entities, and recent relations.""" + return { + "stats": self.get_stats(), + "entities": self.get_all_entities(), + "recent_relations": self.get_recent_relations(limit=recent_relations_limit), + } + + def save_window_data(self, parsed: dict[str, Any], timestamp_s: float) -> None: + """Save parsed window data (entities and relations) to the graph database.""" + try: + # Save new entities + for entity in parsed.get("new_entities", []): + self.upsert_entity( + entity_id=entity["id"], + entity_type=entity["type"], + descriptor=entity.get("descriptor", "unknown"), + timestamp_s=timestamp_s, + ) + + # Save existing entities (update last_seen) + for entity in parsed.get("entities_present", []): + if isinstance(entity, dict) and "id" in entity: + descriptor = entity.get("descriptor") + if descriptor: + self.upsert_entity( + entity_id=entity["id"], + entity_type=entity.get("type", "unknown"), + descriptor=descriptor, + timestamp_s=timestamp_s, + ) + else: + existing = self.get_entity(entity["id"]) + if existing: + self.upsert_entity( + entity_id=entity["id"], + entity_type=existing["entity_type"], + descriptor=existing["descriptor"], + timestamp_s=timestamp_s, + ) + + # Save relations + for relation in parsed.get("relations", []): + subject_id = ( + relation["subject"].split("|")[0] + if "|" in relation["subject"] + else relation["subject"] + ) + 
object_id = ( + relation["object"].split("|")[0] + if "|" in relation["object"] + else relation["object"] + ) + + self.add_relation( + relation_type=relation["type"], + subject_id=subject_id, + object_id=object_id, + confidence=relation.get("confidence", 1.0), + timestamp_s=timestamp_s, + evidence=relation.get("evidence"), + notes=relation.get("notes"), + ) + + except Exception as e: + logger.error(f"Failed to save window data to graph DB: {e}", exc_info=True) + + def estimate_and_save_distances( + self, + parsed: dict[str, Any], + frame_image: "Image", + vlm: "VlModel", + timestamp_s: float, + max_distance_pairs: int = 5, + ) -> None: + """Estimate distances between entities using VLM and save to database. + + Args: + parsed: Parsed window data containing entities + frame_image: Frame image to analyze + vlm: VLM instance for distance estimation + timestamp_s: Timestamp for the distance measurements + max_distance_pairs: Maximum number of entity pairs to estimate + """ + if not frame_image: + return + + # Import here to avoid circular dependency + from dimos.perception import temporal_utils as tu + + # Collect entities with descriptors + # new_entities have descriptors from VLM + enriched_entities = [] + for entity in parsed.get("new_entities", []): + if isinstance(entity, dict) and "id" in entity: + enriched_entities.append( + {"id": entity["id"], "descriptor": entity.get("descriptor", "unknown")} + ) + + # entities_present only have IDs - need to fetch descriptors from DB + for entity in parsed.get("entities_present", []): + if isinstance(entity, dict) and "id" in entity: + entity_id = entity["id"] + # Fetch descriptor from DB + db_entity = self.get_entity(entity_id) + if db_entity: + enriched_entities.append( + {"id": entity_id, "descriptor": db_entity.get("descriptor", "unknown")} + ) + + if len(enriched_entities) < 2: + return + + # Generate pairs without existing distances + pairs = [ + (enriched_entities[i], enriched_entities[j]) + for i in 
range(len(enriched_entities)) + for j in range(i + 1, len(enriched_entities)) + if not self.get_distance(enriched_entities[i]["id"], enriched_entities[j]["id"]) + ][:max_distance_pairs] + + if not pairs: + return + + try: + response = vlm.query(frame_image, tu.build_batch_distance_estimation_prompt(pairs)) + for r in tu.parse_batch_distance_response(response, pairs): + if r["category"] in ("near", "medium", "far"): + self.add_distance( + entity_a_id=r["entity_a_id"], + entity_b_id=r["entity_b_id"], + distance_meters=r.get("distance_m"), + distance_category=r["category"], + confidence=r.get("confidence", 0.5), + timestamp_s=timestamp_s, + method="vlm", + ) + except Exception as e: + logger.warning(f"Failed to estimate distances: {e}", exc_info=True) + + def commit(self) -> None: + """Commit all pending transactions and ensure data is flushed to disk.""" + if hasattr(self._local, "conn"): + conn = self._local.conn + conn.commit() + # Force checkpoint to ensure WAL data is written to main database file + try: + conn.execute("PRAGMA wal_checkpoint(FULL)") + except Exception: + pass # Ignore if WAL is not enabled + def close(self) -> None: """Close database connection.""" if hasattr(self._local, "conn"): diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index f77de9193f..2534854102 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -29,163 +29,32 @@ import time from typing import Any -import numpy as np from reactivex import Subject, interval from reactivex.disposable import Disposable -from dimos import spec from dimos.agents import skill -from dimos.core import DimosCluster, In, rpc +from dimos.core import In, rpc from dimos.core.module import ModuleConfig from dimos.core.skill_module import SkillModule from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import sharpness_barrier +from dimos.perception import temporal_utils as 
tu from dimos.perception.clip_filter import ( CLIP_AVAILABLE, CLIPFrameFilter, + adaptive_keyframes, select_diverse_frames_simple, ) from dimos.perception.entity_graph_db import EntityGraphDB -from dimos.perception.videorag_utils import ( - apply_summary_update, - build_batch_distance_estimation_prompt, - build_query_prompt, - build_summary_prompt, - build_window_prompt, - default_state, - extract_time_window, - format_timestamp, - get_structured_output_format, - parse_batch_distance_response, - parse_window_response, - update_state_from_window, -) from dimos.utils.logging_config import setup_logger logger = setup_logger() # Constants -STALE_SCENE_THRESHOLD = 5.0 # Skip window if scene hasn't changed (perceptual hash distance) -MAX_DISTANCE_PAIRS = 5 # Max entity pairs to estimate distances for per window -MAX_RELATIONS_PER_ENTITY = 10 # Max relations to include in graph context -NEARBY_DISTANCE_METERS = 5.0 # Distance threshold for "nearby" entities MAX_RECENT_WINDOWS = 50 # Max recent windows to keep in memory -# Pure functions -def is_scene_stale(frames: list["Frame"]) -> bool: - """Check if scene hasn't changed meaningfully between first and last frame. - - Args: - frames: List of frames to check - - Returns: - True if scene is stale (hasn't changed enough), False otherwise - """ - if len(frames) < 2: - return False - first_img = frames[0].image - last_img = frames[-1].image - if first_img is None or last_img is None: - return False - diff = np.abs(first_img.data.astype(float) - last_img.data.astype(float)) - return diff.mean() < STALE_SCENE_THRESHOLD - - -def build_graph_context( - graph_db: "EntityGraphDB", entity_ids: list[str], time_window_s: float | None = None -) -> dict[str, Any]: - """Build enriched context from graph database for given entities. 
- - Args: - graph_db: Entity graph database instance - entity_ids: List of entity IDs to get context for - time_window_s: Optional time window in seconds (e.g., 3600 for last hour) - - Returns: - Dictionary with graph context including relationships, distances, and semantics - """ - if not graph_db or not entity_ids: - return {} - - try: - graph_context: dict[str, Any] = { - "relationships": [], - "spatial_info": [], - "semantic_knowledge": [], - } - - # Convert time_window_s to a (start_ts, end_ts) tuple if provided - time_window_tuple = None - if time_window_s is not None: - current_time = time.time() - time_window_tuple = (current_time - time_window_s, current_time) - - # Get recent relationships for each entity - for entity_id in entity_ids: - # Get relationships (Graph 1: interactions) - relations = graph_db.get_relations_for_entity( - entity_id=entity_id, - relation_type=None, # all types - time_window=time_window_tuple, - ) - for rel in relations[-MAX_RELATIONS_PER_ENTITY:]: - graph_context["relationships"].append( - { - "subject": rel["subject_id"], - "relation": rel["relation_type"], - "object": rel["object_id"], - "confidence": rel["confidence"], - "when": rel["timestamp_s"], - } - ) - - # Get spatial relationships (Graph 2: distances) - # Find entities nearby this one - nearby = graph_db.get_nearby_entities( - entity_id=entity_id, max_distance=NEARBY_DISTANCE_METERS, latest_only=True - ) - for dist in nearby: - graph_context["spatial_info"].append( - { - "entity_a": dist["entity_a"], - "entity_b": dist["entity_b"], - "distance": dist.get("distance_m"), - "category": dist.get("distance_category"), - "confidence": dist["confidence"], - } - ) - - # Get semantic knowledge (Graph 3: conceptual relations) - semantic_rels = graph_db.get_semantic_relations( - entity_id=entity_id, - relation_type=None, # all types - ) - for sem in semantic_rels: - graph_context["semantic_knowledge"].append( - { - "entity_a": sem["entity_a"], - "relation": sem["relation_type"], - 
"entity_b": sem["entity_b"], - "confidence": sem["confidence"], - "observations": sem["observation_count"], - } - ) - - # Get graph statistics for context - if entity_ids: - stats = graph_db.get_stats() - graph_context["total_entities"] = stats.get("total_entities", 0) - graph_context["total_relations"] = stats.get("total_relations", 0) - - return graph_context - - except Exception as e: - logger.warning(f"failed to build graph context: {e}") - return {} - - @dataclass class Frame: frame_index: int @@ -195,17 +64,35 @@ class Frame: @dataclass class TemporalMemoryConfig(ModuleConfig): + # Frame processing fps: float = 1.0 window_s: float = 2.0 stride_s: float = 2.0 summary_interval_s: float = 10.0 max_frames_per_window: int = 3 frame_buffer_size: int = 50 + + # Output output_dir: str | Path | None = None + + # VLM parameters max_tokens: int = 900 temperature: float = 0.2 + + # Frame filtering use_clip_filtering: bool = True clip_model: str = "ViT-B/32" + stale_scene_threshold: float = 5.0 + + # Graph database + persistent_memory: bool = True # Keep graph across sessions + clear_memory_on_start: bool = False # Wipe DB on startup + enable_distance_estimation: bool = True # Estimate entity distances + max_distance_pairs: int = 5 # Max entity pairs per window + + # Graph context + max_relations_per_entity: int = 10 # Max relations in query context + nearby_distance_meters: float = 5.0 # "Nearby" threshold class TemporalMemory(SkillModule): @@ -232,12 +119,12 @@ def __init__( self._stopped = False # protected state - self._state = default_state() + self._state = tu.default_state() self._state["next_summary_at_s"] = float(self.config.summary_interval_s) self._frame_buffer: deque[Frame] = deque(maxlen=self.config.frame_buffer_size) self._recent_windows: deque[dict[str, Any]] = deque(maxlen=MAX_RECENT_WINDOWS) self._frame_count = 0 - self._last_analysis_time = 0.0 + self._last_analysis_time = -float("inf") # Allow first analysis immediately self._video_start_wall_time: 
float | None = None # clip filter @@ -254,6 +141,7 @@ def __init__( self.config.use_clip_filtering = False # output directory + self._graph_db: EntityGraphDB | None if self.config.output_dir: self._output_path = Path(self.config.output_dir) self._output_path.mkdir(parents=True, exist_ok=True) @@ -303,6 +191,8 @@ def start(self) -> None: def on_frame(image: Image) -> None: with self._state_lock: video_start = self._video_start_wall_time + if video_start is None: + return # Not started yet if image.ts is not None: timestamp_s = image.ts - video_start else: @@ -314,11 +204,10 @@ def on_frame(image: Image) -> None: image=image, ) self._frame_buffer.append(frame) - self._frame_count += 1 # pipe through sharpness filter before buffering - frame_subject = Subject() + frame_subject: Subject[Image] = Subject() self._disposables.add( frame_subject.pipe(sharpness_barrier(self.config.fps)).subscribe(on_frame) ) @@ -326,6 +215,7 @@ def on_frame(image: Image) -> None: unsub_image = self.color_image.subscribe(frame_subject.on_next) self._disposables.add(Disposable(unsub_image)) + # Schedule window analysis every stride_s seconds self._disposables.add( interval(self.config.stride_s).subscribe(lambda _: self._analyze_window()) ) @@ -355,8 +245,9 @@ def stop(self) -> None: with self._state_lock: self._stopped = True - # Close graph database + # Save and close graph database if self._graph_db: + self._graph_db.commit() # save all pending transactions self._graph_db.close() self._graph_db = None @@ -368,7 +259,7 @@ def stop(self) -> None: with self._state_lock: self._frame_buffer.clear() self._recent_windows.clear() - self._state = default_state() + self._state = tu.default_state() super().stop() @@ -380,236 +271,151 @@ def stop(self) -> None: logger.info("temporalmemory stopped") - def _analyze_window(self) -> None: + def _get_window_frames(self) -> tuple[list[Frame], dict[str, Any]] | None: + """Extract window frames from buffer with guards.""" + with self._state_lock: + if not 
self._frame_buffer: + return None + current_time = self._frame_buffer[-1].timestamp_s + if current_time - self._last_analysis_time < self.config.stride_s: + return None + frames_needed = max(1, int(self.config.fps * self.config.window_s)) + if len(self._frame_buffer) < frames_needed: + return None + window_frames = list(self._frame_buffer)[-frames_needed:] + state_snapshot = self._state.copy() + return window_frames, state_snapshot + + def _query_vlm_for_window( + self, + window_frames: list[Frame], + state_snapshot: dict[str, Any], + w_start: float, + w_end: float, + ) -> str | None: + """Query VLM for window analysis.""" + query = tu.build_window_prompt( + w_start=w_start, w_end=w_end, frame_count=len(window_frames), state=state_snapshot + ) try: - # get snapshot - with self._state_lock: - if not self._frame_buffer: - return - current_time = self._frame_buffer[-1].timestamp_s - if current_time - self._last_analysis_time < self.config.stride_s: - return - - frames_needed = max(1, int(self.config.fps * self.config.window_s)) - if len(self._frame_buffer) < frames_needed: - return + fmt = tu.get_structured_output_format() + if len(window_frames) > 1: + responses = self.vlm.query_batch( + [f.image for f in window_frames], query, response_format=fmt + ) + return responses[0] if responses else "" + else: + return self.vlm.query(window_frames[0].image, query, response_format=fmt) + except Exception as e: + logger.error(f"vlm query failed [{w_start:.1f}-{w_end:.1f}s]: {e}", exc_info=True) + return None - window_frames = list(self._frame_buffer)[-frames_needed:] - state_snapshot = self._state.copy() + def _save_window_artifacts(self, parsed: dict[str, Any], w_end: float) -> None: + """Save window data to graph DB and evidence file.""" + if self._graph_db: + self._graph_db.save_window_data(parsed, w_end) + if self.config.output_dir: + self._append_evidence(parsed) - w_start = window_frames[0].timestamp_s - w_end = window_frames[-1].timestamp_s + def _analyze_window(self) -> 
None: + """Analyze a temporal window of frames using VLM.""" + # Extract window frames with guards + result = self._get_window_frames() + if result is None: + return + window_frames, state_snapshot = result + w_start, w_end = window_frames[0].timestamp_s, window_frames[-1].timestamp_s - # add this check early, before any filtering or VLM calls - if is_scene_stale(window_frames): - logger.debug(f"skipping stale window [{w_start:.1f}-{w_end:.1f}s]") - with self._state_lock: - self._last_analysis_time = w_end - return + # Skip if scene hasn't changed + if tu.is_scene_stale(window_frames, self.config.stale_scene_threshold): + with self._state_lock: + self._last_analysis_time = w_end + return - # filter frames - # NOTE: no longer using clip filter for now (alternative: sharpness barrier and stale scene check) - # if len(window_frames) > self.config.max_frames_per_window: - # if self._clip_filter: - # window_frames = self._clip_filter.select_diverse_frames( - # window_frames, max_frames=self.config.max_frames_per_window - # ) - # else: - window_frames = select_diverse_frames_simple( + # Select diverse frames for analysis + window_frames = ( + adaptive_keyframes( # TODO: unclear if clip vs. diverse vs. 
this solution is best window_frames, max_frames=self.config.max_frames_per_window ) + ) + logger.info(f"analyzing [{w_start:.1f}-{w_end:.1f}s] with {len(window_frames)} frames") - logger.info(f"analyzing [{w_start:.1f}-{w_end:.1f}s] with {len(window_frames)} frames") - - # build prompt - query = build_window_prompt( - w_start=w_start, - w_end=w_end, - frame_count=len(window_frames), - state=state_snapshot, - ) - - # query vlm (slow, outside lock) - # use query_batch for multiple frames to send all filtered frames in one API call - try: - response_format = get_structured_output_format() - if len(window_frames) > 1: - # Use query_batch to send all filtered frames in one API call - # This gives the model more temporal context - frame_images = [frame.image for frame in window_frames] - responses = self.vlm.query_batch( - frame_images, query, response_format=response_format - ) - # query_batch returns list[str] with one response for all images - response_text = responses[0] if responses else "" - else: - # Single frame - use regular query - response_text = self.vlm.query( - window_frames[0].image, query, response_format=response_format - ) - except Exception as e: - logger.error(f"vlm query failed [{w_start:.1f}-{w_end:.1f}s]: {e}") - with self._state_lock: - self._last_analysis_time = w_end - return - - # parse response - parsed = parse_window_response(response_text, w_start, w_end, len(window_frames)) - - if "_error" in parsed: - logger.error(f"parse error: {parsed['_error']}") - else: - logger.info(f"parsed. 
caption: {parsed.get('caption', '')[:100]}") - - # estimate distances between entities - # note: batched into single VLM call for efficiency - if self._graph_db and len(window_frames) > 0: - # use the middle frame from the window for distance estimation - mid_frame = window_frames[len(window_frames) // 2] - if mid_frame.image: - self._estimate_distances(parsed, mid_frame.image, w_end) - - # update state + # Query VLM and parse response + response_text = self._query_vlm_for_window(window_frames, state_snapshot, w_start, w_end) + if response_text is None: with self._state_lock: - needs_summary = update_state_from_window( - self._state, parsed, w_end, self.config.summary_interval_s - ) - self._recent_windows.append(parsed) self._last_analysis_time = w_end + return - # save to graph database - if self._graph_db: - self._save_to_graph_db(parsed, w_end) - - # save evidence - if self.config.output_dir: - self._append_evidence(parsed) - - # update summary if needed - if needs_summary: - logger.info(f"updating summary at t≈{w_end:.1f}s") - self._update_rolling_summary(w_end) + parsed = tu.parse_window_response(response_text, w_start, w_end, len(window_frames)) + if "_error" in parsed: + logger.error(f"parse error: {parsed['_error']}") + else: + logger.info(f"parsed. 
caption: {parsed.get('caption', '')[:100]}") + + # Start distance estimation in background + if self._graph_db and window_frames and self.config.enable_distance_estimation: + mid_frame = window_frames[len(window_frames) // 2] + if mid_frame.image: + threading.Thread( + target=self._graph_db.estimate_and_save_distances, + args=(parsed, mid_frame.image, self.vlm, w_end, self.config.max_distance_pairs), + daemon=True, + ).start() + + # Update temporal state + with self._state_lock: + needs_summary = tu.update_state_from_window( + self._state, parsed, w_end, self.config.summary_interval_s + ) + self._recent_windows.append(parsed) + self._last_analysis_time = w_end - # periodic save - with self._state_lock: - window_count = len(self._recent_windows) + # Save artifacts + self._save_window_artifacts(parsed, w_end) - if window_count % 10 == 0: - self.save_state() - self.save_entities() + # Trigger summary update if needed + if needs_summary: + logger.info(f"updating summary at t≈{w_end:.1f}s") + self._update_rolling_summary(w_end) - except Exception as e: - logger.error(f"error analyzing window: {e}", exc_info=True) + # Periodic state saves + with self._state_lock: + window_count = len(self._recent_windows) + if window_count % 10 == 0: + self.save_state() + self.save_entities() def _update_rolling_summary(self, w_end: float) -> None: - try: - # get state - with self._state_lock: - if self._stopped: - return - rolling_summary = str(self._state.get("rolling_summary", "")) - chunk_buffer = list(self._state.get("chunk_buffer", [])) - if self._frame_buffer: - latest_frame = self._frame_buffer[-1].image - else: - latest_frame = None - - if not chunk_buffer or not latest_frame: + with self._state_lock: + if self._stopped: return + rolling_summary = str(self._state.get("rolling_summary", "")) + chunk_buffer = list(self._state.get("chunk_buffer", [])) + latest_frame = self._frame_buffer[-1].image if self._frame_buffer else None - # build prompt - prompt = build_summary_prompt( - 
rolling_summary=rolling_summary, - chunk_windows=chunk_buffer, - ) - - # query vlm (slow, outside lock) - try: - summary_text = self.vlm.query(latest_frame, prompt) - if summary_text and summary_text.strip(): - with self._state_lock: - if self._stopped: - return - apply_summary_update( - self._state, summary_text, w_end, self.config.summary_interval_s - ) - logger.info(f"updated summary: {summary_text[:100]}...") - # Save state after summary update to persist entities - if self.config.output_dir and not self._stopped: - self.save_state() - self.save_entities() - except Exception as e: - logger.error(f"summary update failed: {e}", exc_info=True) - - except Exception as e: - logger.error(f"error updating summary: {e}", exc_info=True) - - def _estimate_distances(self, parsed: dict[str, Any], frame_image, timestamp_s: float) -> None: - """Estimate distances between entities using VLM and save to graph database. - - Batches all entity pairs into a single VLM call for efficiency. - - Args: - parsed: Parsed window data containing entities - frame_image: Frame image to analyze - timestamp_s: Timestamp for the distance observations - """ - if not self._graph_db or not frame_image: - return - - # Collect all entities present in this window - all_entities = [] - for entity in parsed.get("new_entities", []): - all_entities.append(entity) - for entity in parsed.get("entities_present", []): - if isinstance(entity, dict) and "id" in entity: - all_entities.append(entity) - - # Need at least 2 entities to estimate distances - if len(all_entities) < 2: + if not chunk_buffer or not latest_frame: return - # Generate entity pairs (avoid duplicates by only doing i < j) - entity_pairs = [] - for i in range(len(all_entities)): - for j in range(i + 1, len(all_entities)): - entity_pairs.append((all_entities[i], all_entities[j])) - - # Limit to avoid excessive prompt length - entity_pairs = entity_pairs[:MAX_DISTANCE_PAIRS] - - if not entity_pairs: - return + prompt = 
tu.build_summary_prompt( + rolling_summary=rolling_summary, chunk_windows=chunk_buffer + ) try: - # Build batched prompt for all pairs - prompt = build_batch_distance_estimation_prompt(entity_pairs) - - # Single VLM call for all pairs - response = self.vlm.query(frame_image, prompt) - - # Parse all distance estimates - results = parse_batch_distance_response(response, entity_pairs) - - # Save all valid distances to database - for result in results: - if result["category"] in ["near", "medium", "far"]: - self._graph_db.add_distance( - entity_a=result["entity_a_id"], - entity_b=result["entity_b_id"], - distance_m=result.get("distance_m"), - distance_category=result["category"], - confidence=result.get("confidence", 0.5), - timestamp_s=timestamp_s, - method="vlm_estimation_batch", - ) - logger.debug( - f"estimated distance {result['entity_a_id']}-{result['entity_b_id']}: " - f"{result['category']} ({result.get('distance_m')}m)" + summary_text = self.vlm.query(latest_frame, prompt) + if summary_text and summary_text.strip(): + with self._state_lock: + if self._stopped: + return + tu.apply_summary_update( + self._state, summary_text, w_end, self.config.summary_interval_s ) - + logger.info(f"updated summary: {summary_text[:100]}...") + if self.config.output_dir and not self._stopped: + self.save_state() + self.save_entities() except Exception as e: - logger.warning(f"failed to estimate distances: {e}") + logger.error(f"summary update failed: {e}", exc_info=True) @skill() def query(self, question: str) -> str: @@ -644,17 +450,28 @@ def query(self, question: str) -> str: rolling_summary = str(self._state.get("rolling_summary", "")) last_present = list(self._state.get("last_present", [])) recent_windows = list(self._recent_windows) - latest_frame = self._frame_buffer[-1].image if self._frame_buffer else None + if self._frame_buffer: + latest_frame = self._frame_buffer[-1].image + current_video_time_s = self._frame_buffer[-1].timestamp_s + else: + latest_frame = None + 
current_video_time_s = 0.0 if not latest_frame: return "no frames available" # build context from temporal state + # Include entities from last_present and recent windows (both entities_present and new_entities) currently_present = {e["id"] for e in last_present if isinstance(e, dict) and "id" in e} for window in recent_windows[-3:]: + # Add entities that were present for entity in window.get("entities_present", []): if isinstance(entity, dict) and isinstance(entity.get("id"), str): currently_present.add(entity["id"]) + # Also include newly detected entities (they're present now) + for entity in window.get("new_entities", []): + if isinstance(entity, dict) and isinstance(entity.get("id"), str): + currently_present.add(entity["id"]) context = { "entity_roster": entity_roster, @@ -667,17 +484,20 @@ def query(self, question: str) -> str: # enhance context with graph database knowledge if self._graph_db and currently_present: # Extract time window from question using VLM - time_window_s = extract_time_window(question, self.vlm, latest_frame) + time_window_s = tu.extract_time_window(question, self.vlm, latest_frame) - graph_context = build_graph_context( + graph_context = tu.build_graph_context( graph_db=self._graph_db, entity_ids=list(currently_present), time_window_s=time_window_s, + max_relations_per_entity=self.config.max_relations_per_entity, + nearby_distance_meters=self.config.nearby_distance_meters, + current_video_time_s=current_video_time_s, ) context["graph_knowledge"] = graph_context - # build query prompt using videorag utils - prompt = build_query_prompt(question=question, context=context) + # build query prompt using temporal utils + prompt = tu.build_query_prompt(question=question, context=context) # query vlm (slow, outside lock) try: @@ -688,13 +508,18 @@ def query(self, question: str) -> str: return f"error: {e}" @rpc - def clear_history(self) -> None: + def clear_history(self) -> bool: """Clear temporal memory state.""" - with self._state_lock: - 
self._state = default_state() - self._state["next_summary_at_s"] = float(self.config.summary_interval_s) - self._recent_windows.clear() - logger.info("cleared history") + try: + with self._state_lock: + self._state = tu.default_state() + self._state["next_summary_at_s"] = float(self.config.summary_interval_s) + self._recent_windows.clear() + logger.info("cleared history") + return True + except Exception as e: + logger.error(f"clear_history failed: {e}", exc_info=True) + return False @rpc def get_state(self) -> dict[str, Any]: @@ -723,22 +548,8 @@ def get_rolling_summary(self) -> str: def get_graph_db_stats(self) -> dict[str, Any]: """Get statistics and sample data from the graph database.""" if not self._graph_db: - logger.warning("graph database not initialized") - return {"stats": {}, "entities": [], "recent_relations": []} - - try: - stats = self._graph_db.get_stats() - all_entities = self._graph_db.get_all_entities() - recent_relations = self._graph_db.get_recent_relations(limit=10) - - return { - "stats": stats, - "entities": all_entities, - "recent_relations": recent_relations[:5], # just show top 5 - } - except Exception as e: - logger.error(f"failed to get graph db stats: {e}", exc_info=True) return {"stats": {}, "entities": [], "recent_relations": []} + return self._graph_db.get_summary() @rpc def save_state(self) -> bool: @@ -763,78 +574,7 @@ def _append_evidence(self, evidence: dict[str, Any]) -> None: with open(self._evidence_file, "a") as f: f.write(json.dumps(evidence, ensure_ascii=False) + "\n") except Exception as e: - logger.error(f"append evidence failed: {e}") - - def _save_to_graph_db(self, parsed: dict[str, Any], timestamp_s: float) -> None: - """Save parsed window data to the entity graph database.""" - if not self._graph_db: - return - - try: - # Save new entities - for entity in parsed.get("new_entities", []): - self._graph_db.upsert_entity( - entity_id=entity["id"], - entity_type=entity["type"], - descriptor=entity["descriptor"], - 
timestamp_s=timestamp_s, - ) - - # Save existing entities (update last_seen) - for entity in parsed.get("entities_present", []): - if isinstance(entity, dict) and "id" in entity: - # Only update with descriptor if we have one, otherwise pass empty to preserve existing - descriptor = entity.get("descriptor") - if descriptor: - self._graph_db.upsert_entity( - entity_id=entity["id"], - entity_type=entity.get("type", "unknown"), - descriptor=descriptor, - timestamp_s=timestamp_s, - ) - else: - # Just update last_seen without changing descriptor - # Get existing entity to preserve its descriptor - existing = self._graph_db.get_entity(entity["id"]) - if existing: - self._graph_db.upsert_entity( - entity_id=entity["id"], - entity_type=existing["entity_type"], - descriptor=existing["descriptor"], - timestamp_s=timestamp_s, - ) - - # Save relations - for relation in parsed.get("relations", []): - # Parse subject/object (format: "E1|person" or just "E1") - subject_id = ( - relation["subject"].split("|")[0] - if "|" in relation["subject"] - else relation["subject"] - ) - object_id = ( - relation["object"].split("|")[0] - if "|" in relation["object"] - else relation["object"] - ) - - self._graph_db.add_relation( - relation_type=relation["type"], - subject_id=subject_id, - object_id=object_id, - confidence=relation.get("confidence", 1.0), - timestamp_s=timestamp_s, - evidence=relation.get("evidence"), - notes=relation.get("notes"), - ) - - logger.debug( - f"Saved window data to graph DB: {len(parsed.get('new_entities', []))} new entities, " - f"{len(parsed.get('relations', []))} relations" - ) - - except Exception as e: - logger.error(f"Failed to save to graph DB: {e}", exc_info=True) + logger.error(f"append evidence failed: {e}", exc_info=True) def save_entities(self) -> bool: if not self.config.output_dir: @@ -864,7 +604,7 @@ def save_frames_index(self) -> bool: { "frame_index": f.frame_index, "timestamp_s": f.timestamp_s, - "timestamp": format_timestamp(f.timestamp_s), + 
"timestamp": tu.format_timestamp(f.timestamp_s), } for f in frames ] @@ -880,43 +620,6 @@ def save_frames_index(self) -> bool: return False -# Backwards compatibility: import deploy from separate module -# from dimos.perception.temporal_memory_deploy import deploy -def deploy( - dimos: DimosCluster, - camera: spec.Camera, - vlm: VlModel | None = None, - config: TemporalMemoryConfig | None = None, -) -> TemporalMemory: - """Deploy TemporalMemory with a camera. - - Args: - dimos: Dimos cluster instance - camera: Camera module to connect to - vlm: Optional VLM instance (creates OpenAI VLM if None) - config: Optional temporal memory configuration - """ - if vlm is None: - from dimos.models.vl.openai import OpenAIVlModel - - api_key = os.getenv("OPENAI_API_KEY") - if not api_key: - raise ValueError("OPENAI_API_KEY environment variable not set") - vlm = OpenAIVlModel(api_key=api_key) - - temporal_memory = dimos.deploy(TemporalMemory, vlm=vlm, config=config) # type: ignore[attr-defined] - - if camera.color_image.transport is None: - from dimos.core.transport import JpegShmTransport - - transport = JpegShmTransport("/temporal_memory/color_image") - camera.color_image.transport = transport - - temporal_memory.color_image.connect(camera.color_image) - temporal_memory.start() - return temporal_memory - - temporal_memory = TemporalMemory.blueprint -__all__ = ["Frame", "TemporalMemory", "TemporalMemoryConfig", "deploy", "temporal_memory"] +__all__ = ["Frame", "TemporalMemory", "TemporalMemoryConfig", "temporal_memory"] diff --git a/dimos/perception/temporal_memory_deploy.py b/dimos/perception/temporal_memory_deploy.py new file mode 100644 index 0000000000..bbb8c8ea0a --- /dev/null +++ b/dimos/perception/temporal_memory_deploy.py @@ -0,0 +1,59 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Deployment helpers for TemporalMemory module. +""" + +import os + +from dimos import spec +from dimos.core import DimosCluster +from dimos.models.vl.base import VlModel +from dimos.perception.temporal_memory import TemporalMemory, TemporalMemoryConfig + + +def deploy( + dimos: DimosCluster, + camera: spec.Camera, + vlm: VlModel | None = None, + config: TemporalMemoryConfig | None = None, +) -> TemporalMemory: + """Deploy TemporalMemory with a camera. + + Args: + dimos: Dimos cluster instance + camera: Camera module to connect to + vlm: Optional VLM instance (creates OpenAI VLM if None) + config: Optional temporal memory configuration + """ + if vlm is None: + from dimos.models.vl.openai import OpenAIVlModel + + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError("OPENAI_API_KEY environment variable not set") + vlm = OpenAIVlModel(api_key=api_key) + + temporal_memory = dimos.deploy(TemporalMemory, vlm=vlm, config=config) # type: ignore[attr-defined] + + if camera.color_image.transport is None: + from dimos.core.transport import JpegShmTransport + + transport = JpegShmTransport("/temporal_memory/color_image") + camera.color_image.transport = transport + + temporal_memory.color_image.connect(camera.color_image) + temporal_memory.start() + return temporal_memory diff --git a/dimos/perception/temporal_memory_example.py b/dimos/perception/temporal_memory_example.py index a455b7d7c8..13deca3a59 100644 --- a/dimos/perception/temporal_memory_example.py +++ b/dimos/perception/temporal_memory_example.py @@ -30,10 +30,8 @@ from dimos 
import core from dimos.hardware.sensors.camera.module import CameraModule from dimos.hardware.sensors.camera.webcam import Webcam -from dimos.perception.temporal_memory import ( - TemporalMemoryConfig, - deploy, -) +from dimos.perception.temporal_memory import TemporalMemoryConfig +from dimos.perception.temporal_memory_deploy import deploy # Load environment variables load_dotenv() @@ -76,14 +74,14 @@ def example_usage(): print("Building temporal context... (wait ~15 seconds)") import time - time.sleep(15) + time.sleep(20) # Query the temporal memory questions = [ - "What entities are currently visible?", - "What has happened in the last few seconds?", "Are there any people in the scene?", "Describe the main activity happening now", + "What has happened in the last few seconds?", + "What entities are currently visible?", ] for question in questions: diff --git a/dimos/perception/temporal_utils/__init__.py b/dimos/perception/temporal_utils/__init__.py new file mode 100644 index 0000000000..64950bee8a --- /dev/null +++ b/dimos/perception/temporal_utils/__init__.py @@ -0,0 +1,60 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Temporal memory utilities for temporal memory. includes helper functions +and prompts that are used to build the prompt for the VLM. 
+""" + +# Re-export everything from submodules +from .graph_utils import build_graph_context, extract_time_window +from .helpers import clamp_text, format_timestamp, is_scene_stale, next_entity_id_hint +from .parsers import parse_batch_distance_response, parse_window_response +from .prompts import ( + WINDOW_RESPONSE_SCHEMA, + build_batch_distance_estimation_prompt, + build_distance_estimation_prompt, + build_query_prompt, + build_summary_prompt, + build_window_prompt, + get_structured_output_format, +) +from .state import apply_summary_update, default_state, update_state_from_window + +__all__ = [ + # Schema + "WINDOW_RESPONSE_SCHEMA", + # State management + "apply_summary_update", + # Prompts + "build_batch_distance_estimation_prompt", + "build_distance_estimation_prompt", + # Graph utils + "build_graph_context", + "build_query_prompt", + "build_summary_prompt", + "build_window_prompt", + # Helpers + "clamp_text", + "default_state", + "extract_time_window", + "format_timestamp", + "get_structured_output_format", + "is_scene_stale", + "next_entity_id_hint", + # Parsers + "parse_batch_distance_response", + "parse_window_response", + "update_state_from_window", +] diff --git a/dimos/perception/temporal_utils/graph_utils.py b/dimos/perception/temporal_utils/graph_utils.py new file mode 100644 index 0000000000..075641021b --- /dev/null +++ b/dimos/perception/temporal_utils/graph_utils.py @@ -0,0 +1,206 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Graph database utility functions for temporal memory.""" + +import re +import time +from typing import TYPE_CHECKING, Any + +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.models.vl.base import VlModel + from dimos.msgs.sensor_msgs import Image + from dimos.perception.entity_graph_db import EntityGraphDB + +logger = setup_logger() + + +def extract_time_window( + question: str, + vlm: "VlModel", + latest_frame: "Image | None" = None, +) -> float | None: + """Extract time window from question using VLM with example-based learning. + + Uses a few example keywords as patterns, then asks VLM to extrapolate + similar time references and return seconds. + + Args: + question: User's question + vlm: VLM instance to use for extraction + latest_frame: Optional frame (required for VLM call, but image is ignored) + + Returns: + Time window in seconds, or None if no time reference found + """ + question_lower = question.lower() + + # Quick check for common patterns (fast path) + if "last week" in question_lower or "past week" in question_lower: + return 7 * 24 * 3600 + if "today" in question_lower or "last hour" in question_lower: + return 3600 + if "recently" in question_lower or "recent" in question_lower: + return 600 + + # Use VLM to extract time reference from question + # Provide examples and let VLM extrapolate similar patterns + # Note: latest_frame is required by VLM interface but image content is ignored + if not latest_frame: + return None + + extraction_prompt = f"""Extract any time reference from this question and convert it to seconds. 
+ +Question: {question} + +Examples of time references and their conversions: +- "last week" or "past week" -> 604800 seconds (7 days) +- "yesterday" -> 86400 seconds (1 day) +- "today" or "last hour" -> 3600 seconds (1 hour) +- "recently" or "recent" -> 600 seconds (10 minutes) +- "few minutes ago" -> 300 seconds (5 minutes) +- "just now" -> 60 seconds (1 minute) + +Extrapolate similar patterns (e.g., "2 days ago", "this morning", "last month", etc.) +and convert to seconds. If no time reference is found, return "none". + +Return ONLY a number (seconds) or the word "none". Do not include any explanation.""" + + try: + response = vlm.query(latest_frame, extraction_prompt) + response = response.strip().lower() + + if "none" in response or not response: + return None + + # Extract number from response + numbers = re.findall(r"\d+(?:\.\d+)?", response) + if numbers: + seconds = float(numbers[0]) + # Sanity check: reasonable time windows (1 second to 1 year) + if 1 <= seconds <= 365 * 24 * 3600: + return seconds + except Exception as e: + logger.debug(f"Time extraction failed: {e}") + + return None + + +def build_graph_context( + graph_db: "EntityGraphDB", + entity_ids: list[str], + time_window_s: float | None = None, + max_relations_per_entity: int = 10, + nearby_distance_meters: float = 5.0, + current_video_time_s: float | None = None, +) -> dict[str, Any]: + """Build enriched context from graph database for given entities. + + Args: + graph_db: Entity graph database instance + entity_ids: List of entity IDs to get context for + time_window_s: Optional time window in seconds (e.g., 3600 for last hour) + max_relations_per_entity: Maximum relations to include per entity (default: 10) + nearby_distance_meters: Distance threshold for "nearby" entities (default: 5.0) + current_video_time_s: Current video timestamp in seconds (for time window queries). + If None, uses latest entity timestamp from DB as reference. 
+ + Returns: + Dictionary with graph context including relationships, distances, and semantics + """ + if not graph_db or not entity_ids: + return {} + + try: + graph_context: dict[str, Any] = { + "relationships": [], + "spatial_info": [], + "semantic_knowledge": [], + } + + # Convert time_window_s to a (start_ts, end_ts) tuple if provided + # Use video-relative timestamps, not wall-clock time + time_window_tuple = None + if time_window_s is not None: + if current_video_time_s is not None: + ref_time = current_video_time_s + else: + # Fallback: get the latest timestamp from entities in DB + all_entities = graph_db.get_all_entities() + ref_time = max((e.get("last_seen_ts", 0) for e in all_entities), default=0) + time_window_tuple = (max(0, ref_time - time_window_s), ref_time) + + # Get recent relationships for each entity + for entity_id in entity_ids: + # Get relationships (Graph 1: interactions) + relations = graph_db.get_relations_for_entity( + entity_id=entity_id, + relation_type=None, # all types + time_window=time_window_tuple, + ) + for rel in relations[-max_relations_per_entity:]: + graph_context["relationships"].append( + { + "subject": rel["subject_id"], + "relation": rel["relation_type"], + "object": rel["object_id"], + "confidence": rel["confidence"], + "when": rel["timestamp_s"], + } + ) + + # Get spatial relationships (Graph 2: distances) + nearby = graph_db.get_nearby_entities( + entity_id=entity_id, max_distance=nearby_distance_meters, latest_only=True + ) + for dist in nearby: + graph_context["spatial_info"].append( + { + "entity_a": entity_id, + "entity_b": dist["entity_id"], + "distance": dist.get("distance_meters"), + "category": dist.get("distance_category"), + "confidence": dist["confidence"], + } + ) + + # Get semantic knowledge (Graph 3: conceptual relations) + semantic_rels = graph_db.get_semantic_relations( + entity_id=entity_id, + relation_type=None, + ) + for sem in semantic_rels: + graph_context["semantic_knowledge"].append( + { + 
"entity_a": sem["entity_a_id"], + "relation": sem["relation_type"], + "entity_b": sem["entity_b_id"], + "confidence": sem["confidence"], + "observations": sem["observation_count"], + } + ) + + # Get graph statistics for context + if entity_ids: + stats = graph_db.get_stats() + graph_context["total_entities"] = stats.get("entities", 0) + graph_context["total_relations"] = stats.get("relations", 0) + + return graph_context + + except Exception as e: + logger.warning(f"failed to build graph context: {e}") + return {} diff --git a/dimos/perception/temporal_utils/helpers.py b/dimos/perception/temporal_utils/helpers.py new file mode 100644 index 0000000000..ecaa4cc2d3 --- /dev/null +++ b/dimos/perception/temporal_utils/helpers.py @@ -0,0 +1,72 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Helper utility functions for temporal memory.""" + +from typing import TYPE_CHECKING, Any + +import numpy as np + +if TYPE_CHECKING: + from dimos.perception.temporal_memory import Frame + + +def next_entity_id_hint(roster: Any) -> str: + """Generate next entity ID based on existing roster (e.g., E1, E2, E3...).""" + if not isinstance(roster, list): + return "E1" + max_n = 0 + for e in roster: + if not isinstance(e, dict): + continue + eid = e.get("id") + if isinstance(eid, str) and eid.startswith("E"): + tail = eid[1:] + if tail.isdigit(): + max_n = max(max_n, int(tail)) + return f"E{max_n + 1}" + + +def clamp_text(text: str, max_chars: int) -> str: + """Clamp text to maximum characters.""" + if len(text) <= max_chars: + return text + return text[:max_chars] + "..." + + +def format_timestamp(seconds: float) -> str: + """Format seconds as MM:SS.mmm timestamp string.""" + m = int(seconds // 60) + s = seconds - 60 * m + return f"{m:02d}:{s:06.3f}" + + +def is_scene_stale(frames: list["Frame"], stale_threshold: float = 5.0) -> bool: + """Check if scene hasn't changed meaningfully between first and last frame. + + Args: + frames: List of frames to check + stale_threshold: Threshold for mean pixel difference (default: 5.0) + + Returns: + True if scene is stale (hasn't changed enough), False otherwise + """ + if len(frames) < 2: + return False + first_img = frames[0].image + last_img = frames[-1].image + if first_img is None or last_img is None: + return False + diff = np.abs(first_img.data.astype(float) - last_img.data.astype(float)) + return bool(diff.mean() < stale_threshold) diff --git a/dimos/perception/temporal_utils/parsers.py b/dimos/perception/temporal_utils/parsers.py new file mode 100644 index 0000000000..a9b1a05d9f --- /dev/null +++ b/dimos/perception/temporal_utils/parsers.py @@ -0,0 +1,156 @@ +# Copyright 2025-2026 Dimensional Inc. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Response parsing functions for VLM outputs.""" + +from typing import Any + +from dimos.utils.llm_utils import extract_json + + +def parse_batch_distance_response( + response: str, entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]] +) -> list[dict[str, Any]]: + """ + Parse batched distance estimation response. + + Args: + response: VLM response text + entity_pairs: Original entity pairs used in the prompt + + Returns: + List of dicts with keys: entity_a_id, entity_b_id, category, distance_m, confidence + """ + results = [] + lines = response.strip().split("\n") + + current_pair_idx = None + category = None + distance_m = None + confidence = 0.5 + + for line in lines: + line = line.strip() + + # Check for pair marker + if line.startswith("Pair "): + # Save previous pair if exists + if current_pair_idx is not None and category: + entity_a, entity_b = entity_pairs[current_pair_idx] + results.append( + { + "entity_a_id": entity_a["id"], + "entity_b_id": entity_b["id"], + "category": category, + "distance_m": distance_m, + "confidence": confidence, + } + ) + + # Start new pair + try: + pair_num = int(line.split()[1].rstrip(":")) + current_pair_idx = pair_num - 1 # Convert to 0-indexed + category = None + distance_m = None + confidence = 0.5 + except (IndexError, ValueError): + continue + + # Parse distance fields + elif line.startswith("category:"): + category = line.split(":", 1)[1].strip().lower() + elif 
line.startswith("distance_m:"): + try: + distance_m = float(line.split(":", 1)[1].strip()) + except (ValueError, IndexError): + pass + elif line.startswith("confidence:"): + try: + confidence = float(line.split(":", 1)[1].strip()) + except (ValueError, IndexError): + pass + + # Save last pair + if current_pair_idx is not None and category and current_pair_idx < len(entity_pairs): + entity_a, entity_b = entity_pairs[current_pair_idx] + results.append( + { + "entity_a_id": entity_a["id"], + "entity_b_id": entity_b["id"], + "category": category, + "distance_m": distance_m, + "confidence": confidence, + } + ) + + return results + + +def parse_window_response( + response_text: str, w_start: float, w_end: float, frame_count: int +) -> dict[str, Any]: + """ + Parse VLM response for a window analysis. + + Args: + response_text: Raw text response from VLM + w_start: Window start time + w_end: Window end time + frame_count: Number of frames in window + + Returns: + Parsed dictionary with defaults filled in. If parsing fails, returns + a dict with "_error" key instead of raising. 
+ """ + # Try to extract JSON (handles code fences) + parsed = extract_json(response_text) + if parsed is None: + return { + "window": {"start_s": w_start, "end_s": w_end}, + "caption": "", + "entities_present": [], + "new_entities": [], + "relations": [], + "on_screen_text": [], + "_error": f"Failed to parse JSON from response: {response_text[:200]}...", + } + + # Ensure we return a dict (extract_json can return a list) + if isinstance(parsed, list): + # If we got a list, wrap it in a dict with a default structure + # This shouldn't happen with proper structured output, but handle gracefully + return { + "window": {"start_s": w_start, "end_s": w_end}, + "caption": "", + "entities_present": [], + "new_entities": [], + "relations": [], + "on_screen_text": [], + "_error": f"Unexpected list response: {parsed}", + } + + # Ensure it's a dict + if not isinstance(parsed, dict): + return { + "window": {"start_s": w_start, "end_s": w_end}, + "caption": "", + "entities_present": [], + "new_entities": [], + "relations": [], + "on_screen_text": [], + "_error": f"Expected dict or list, got {type(parsed)}: {parsed}", + } + + return parsed diff --git a/dimos/perception/temporal_utils/prompts.py b/dimos/perception/temporal_utils/prompts.py new file mode 100644 index 0000000000..61399fd3f1 --- /dev/null +++ b/dimos/perception/temporal_utils/prompts.py @@ -0,0 +1,353 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Prompt building functions for VLM queries.""" + +import json +from typing import Any + +from .helpers import clamp_text, next_entity_id_hint + +# JSON schema for window responses (from VideoRAG) +WINDOW_RESPONSE_SCHEMA = { + "type": "object", + "properties": { + "window": { + "type": "object", + "properties": {"start_s": {"type": "number"}, "end_s": {"type": "number"}}, + "required": ["start_s", "end_s"], + }, + "caption": {"type": "string"}, + "entities_present": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "confidence": {"type": "number", "minimum": 0, "maximum": 1}, + }, + "required": ["id"], + }, + }, + "new_entities": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "type": { + "type": "string", + "enum": ["person", "object", "screen", "text", "location", "other"], + }, + "descriptor": {"type": "string"}, + }, + "required": ["id", "type"], + }, + }, + "relations": { + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"type": "string"}, + "subject": {"type": "string"}, + "object": {"type": "string"}, + "confidence": {"type": "number", "minimum": 0, "maximum": 1}, + "evidence": {"type": "array", "items": {"type": "string"}}, + "notes": {"type": "string"}, + }, + "required": ["type", "subject", "object"], + }, + }, + "on_screen_text": {"type": "array", "items": {"type": "string"}}, + "uncertainties": {"type": "array", "items": {"type": "string"}}, + "confidence": {"type": "number", "minimum": 0, "maximum": 1}, + }, + "required": ["window", "caption"], +} + + +def build_window_prompt( + *, + w_start: float, + w_end: float, + frame_count: int, + state: dict[str, Any], +) -> str: + """ + Build comprehensive VLM prompt for analyzing a video window. + + This is adapted from videorag's build_window_messages() but formatted + as a single text prompt for VlModel.query() instead of OpenAI's messages format. 
+ + Args: + w_start: Window start time in seconds + w_end: Window end time in seconds + frame_count: Number of frames in this window + state: Current temporal memory state (entity_roster, rolling_summary, etc.) + + Returns: + Formatted prompt string + """ + roster = state.get("entity_roster", []) + rolling_summary = state.get("rolling_summary", "") + next_id = next_entity_id_hint(roster) + + # System instructions (from VideoRAG) + system_context = """You analyze short sequences of video frames. +You must stay grounded in what is visible. +Do not identify real people or guess names/identities; describe people anonymously. +Extract general entities (people, objects, screens, text, locations) and relations between them. +Use stable entity IDs like E1, E2 based on the provided roster.""" + + # Main prompt (from VideoRAG's build_window_messages) + prompt = f"""{system_context} + +Time window: [{w_start:.3f}, {w_end:.3f}) seconds +Number of frames: {frame_count} + +Existing entity roster (may be empty): +{json.dumps(roster, ensure_ascii=False)} + +Rolling summary so far (may be empty): +{clamp_text(str(rolling_summary), 1500)} + +Task: +1) Write a dense, grounded caption describing what is visible across the frames in this time window. +2) Identify which existing roster entities appear in these frames. +3) Add any new salient entities (people/objects/screens/text/locations) with a short grounded descriptor. +4) Extract grounded relations/events between entities (e.g., looks_at, holds, uses, walks_past, speaks_to (inferred)). + +New entity IDs must start at: {next_id} + +Rules (important): +- You MUST stay grounded in what is visible in the provided frames. +- You MUST NOT mention any entity ID unless it appears in the provided roster OR you include it in new_entities in this same output. +- If the roster is empty, introduce any salient entities you reference (start with E1, E2, ...). +- Do not invent on-screen text: only include text you can read. 
+- If a relation is inferred (e.g., speaks_to without audio), include it but lower confidence and explain the visual cues. + +Output JSON ONLY with this schema: +{{ + "window": {{"start_s": {w_start:.3f}, "end_s": {w_end:.3f}}}, + "caption": "dense grounded description", + "entities_present": [{{"id": "E1", "confidence": 0.0-1.0}}], + "new_entities": [{{"id": "E3", "type": "person|object|screen|text|location|other", "descriptor": "..."}}], + "relations": [ + {{ + "type": "speaks_to|looks_at|holds|uses|moves|gesture|scene_change|other", + "subject": "E1|unknown", + "object": "E2|unknown", + "confidence": 0.0-1.0, + "evidence": ["describe which frames show this"], + "notes": "short, grounded" + }} + ], + "on_screen_text": ["verbatim snippets"], + "uncertainties": ["things that are unclear"], + "confidence": 0.0-1.0 +}} +""" + return prompt + + +def build_summary_prompt( + *, + rolling_summary: str, + chunk_windows: list[dict[str, Any]], +) -> str: + """ + Build prompt for updating rolling summary. + + This is adapted from videorag's build_summary_messages() but formatted + as a single text prompt for VlModel.query(). + + Args: + rolling_summary: Current rolling summary text + chunk_windows: List of recent window results to incorporate + + Returns: + Formatted prompt string + """ + # System context (from VideoRAG) + system_context = """You summarize timestamped video-window logs into a concise rolling summary. +Stay grounded in the provided window captions/relations. +Do not invent entities or rename entity IDs; preserve IDs like E1, E2 exactly. +You MAY incorporate new entity IDs if they appear in the provided chunk windows (e.g., in new_entities). +Be concise, but keep relevant entity continuity and key relations.""" + + prompt = f"""{system_context} + +Update the rolling summary using the newest chunk. 
+ +Previous rolling summary (may be empty): +{clamp_text(rolling_summary, 2500)} + +New chunk windows (JSON): +{json.dumps(chunk_windows, ensure_ascii=False)} + +Output a concise summary as PLAIN TEXT (no JSON, no code fences). +Length constraints (important): +- Target <= 120 words total. +- Hard cap <= 900 characters. +""" + return prompt + + +def build_query_prompt( + *, + question: str, + context: dict[str, Any], +) -> str: + """ + Build prompt for querying temporal memory. + + Args: + question: User's question about the video stream + context: Context dict containing entity_roster, rolling_summary, etc. + + Returns: + Formatted prompt string + """ + currently_present = context.get("currently_present_entities", []) + currently_present_str = ( + f"Entities recently detected in recent windows: {currently_present}" + if currently_present + else "No entities were detected in recent windows (list is empty)" + ) + + prompt = f"""Answer the following question about the video stream using the provided context. + +**Question:** {question} + +**Context:** +{json.dumps(context, indent=2, ensure_ascii=False)} + +**Important Notes:** +- Entities have stable IDs like E1, E2, etc. +- The 'currently_present_entities' list contains entity IDs that were detected in recent video windows (not necessarily in the current frame you're viewing) +- {currently_present_str} +- The 'entity_roster' contains all known entities with their descriptions +- The 'rolling_summary' describes what has happened over time +- If 'currently_present_entities' is empty, it means no entities were detected in recent windows, but entities may still exist in the roster from earlier +- Answer based on the provided context (entity_roster, rolling_summary, currently_present_entities) AND what you see in the current frame +- If the context says entities were present but you don't see them in the current frame, mention both: what was recently detected AND what you currently see + +Provide a concise answer. 
+""" + return prompt + + +def build_distance_estimation_prompt( + *, + entity_a_descriptor: str, + entity_a_id: str, + entity_b_descriptor: str, + entity_b_id: str, +) -> str: + """ + Build prompt for estimating distance between two entities. + + Args: + entity_a_descriptor: Description of first entity + entity_a_id: ID of first entity + entity_b_descriptor: Description of second entity + entity_b_id: ID of second entity + + Returns: + Formatted prompt string for distance estimation + """ + prompt = f"""Look at this image and estimate the distance between these two entities: + +Entity A: {entity_a_descriptor} (ID: {entity_a_id}) +Entity B: {entity_b_descriptor} (ID: {entity_b_id}) + +Provide: +1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m) +2. Approximate distance in meters (best guess) +3. Confidence: 0.0-1.0 (how certain are you?) + +Respond in this format: +category: [near/medium/far] +distance_m: [number] +confidence: [0.0-1.0] +reasoning: [brief explanation]""" + return prompt + + +def build_batch_distance_estimation_prompt( + entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]], +) -> str: + """ + Build prompt for estimating distances between multiple entity pairs in one call. + + Args: + entity_pairs: List of (entity_a, entity_b) tuples, each entity is a dict with 'id' and 'descriptor' + + Returns: + Formatted prompt string for batched distance estimation + """ + pairs_text = [] + for i, (entity_a, entity_b) in enumerate(entity_pairs, 1): + pairs_text.append( + f"Pair {i}:\n" + f" Entity A: {entity_a['descriptor']} (ID: {entity_a['id']})\n" + f" Entity B: {entity_b['descriptor']} (ID: {entity_b['id']})" + ) + + prompt = f"""Look at this image and estimate the distances between the following entity pairs: + +{chr(10).join(pairs_text)} + +For each pair, provide: +1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m) +2. Approximate distance in meters (best guess) +3. Confidence: 0.0-1.0 (how certain are you?) 
+ +Respond in this format (one block per pair): +Pair 1: +category: [near/medium/far] +distance_m: [number] +confidence: [0.0-1.0] + +Pair 2: +category: [near/medium/far] +distance_m: [number] +confidence: [0.0-1.0] + +(etc.)""" + return prompt + + +def get_structured_output_format() -> dict[str, Any]: + """ + Get OpenAI-compatible structured output format for window responses. + + This uses the json_schema mode available in OpenAI API (GPT-4o mini) to enforce + the VideoRAG response schema. + + Returns: + Dictionary for response_format parameter: + {"type": "json_schema", "json_schema": {...}} + """ + + return { + "type": "json_schema", + "json_schema": { + "name": "video_window_analysis", + "description": "Analysis of a video window with entities and relations", + "schema": WINDOW_RESPONSE_SCHEMA, + "strict": False, # Allow additional fields + }, + } diff --git a/dimos/perception/temporal_utils/state.py b/dimos/perception/temporal_utils/state.py new file mode 100644 index 0000000000..9cdfbe4931 --- /dev/null +++ b/dimos/perception/temporal_utils/state.py @@ -0,0 +1,139 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""State management functions for temporal memory.""" + +from typing import Any + + +def default_state() -> dict[str, Any]: + """Create default temporal memory state dictionary.""" + return { + "entity_roster": [], + "rolling_summary": "", + "chunk_buffer": [], + "next_summary_at_s": 0.0, + "last_present": [], + } + + +def update_state_from_window( + state: dict[str, Any], + parsed: dict[str, Any], + w_end: float, + summary_interval_s: float, +) -> bool: + """ + Update temporal memory state from a parsed window result. + + This implements the state update logic from VideoRAG's generate_evidence(). + + Args: + state: Current state dictionary (modified in place) + parsed: Parsed window result + w_end: Window end time + summary_interval_s: How often to trigger summary updates + + Returns: + True if summary update is needed, False otherwise + """ + # Skip if there was an error + if "_error" in parsed: + return False + + new_entities = parsed.get("new_entities", []) + present = parsed.get("entities_present", []) + + # Handle new entities + if new_entities: + roster = list(state.get("entity_roster", [])) + known = {e.get("id") for e in roster if isinstance(e, dict)} + for e in new_entities: + if isinstance(e, dict) and e.get("id") not in known: + roster.append(e) + known.add(e.get("id")) + state["entity_roster"] = roster + + # Handle referenced entities (auto-add if mentioned but not in roster) + roster = list(state.get("entity_roster", [])) + known = {e.get("id") for e in roster if isinstance(e, dict)} + referenced: set[str] = set() + for p in present or []: + if isinstance(p, dict) and isinstance(p.get("id"), str): + referenced.add(p["id"]) + for rel in parsed.get("relations") or []: + if isinstance(rel, dict): + for k in ("subject", "object"): + v = rel.get(k) + if isinstance(v, str) and v != "unknown": + referenced.add(v) + for rid in sorted(referenced): + if rid not in known: + roster.append( + { + "id": rid, + "type": "other", + "descriptor": "unknown 
(auto-added; rerun recommended)", + } + ) + known.add(rid) + state["entity_roster"] = roster + state["last_present"] = present + + # Add to chunk buffer + chunk_buffer = state.get("chunk_buffer", []) + if not isinstance(chunk_buffer, list): + chunk_buffer = [] + chunk_buffer.append( + { + "window": parsed.get("window"), + "caption": parsed.get("caption", ""), + "entities_present": parsed.get("entities_present", []), + "new_entities": parsed.get("new_entities", []), + "relations": parsed.get("relations", []), + "on_screen_text": parsed.get("on_screen_text", []), + } + ) + state["chunk_buffer"] = chunk_buffer + + # Check if summary update is needed + if summary_interval_s > 0: + next_at = float(state.get("next_summary_at_s", summary_interval_s)) + if w_end + 1e-6 >= next_at and chunk_buffer: + return True # Need to update summary + + return False + + +def apply_summary_update( + state: dict[str, Any], summary_text: str, w_end: float, summary_interval_s: float +) -> None: + """ + Apply a summary update to the state. + + Args: + state: State dictionary (modified in place) + summary_text: New summary text + w_end: Current window end time + summary_interval_s: Summary update interval + """ + if summary_text and summary_text.strip(): + state["rolling_summary"] = summary_text.strip() + state["chunk_buffer"] = [] + + # Advance next_summary_at_s + next_at = float(state.get("next_summary_at_s", summary_interval_s)) + while next_at <= w_end + 1e-6: + next_at += float(summary_interval_s) + state["next_summary_at_s"] = next_at diff --git a/dimos/perception/videorag_utils.py b/dimos/perception/videorag_utils.py deleted file mode 100644 index ab6df13025..0000000000 --- a/dimos/perception/videorag_utils.py +++ /dev/null @@ -1,726 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -VideoRAG utilities for temporal memory - adapted from videorag/evidence.py - -This module ports the sophisticated prompts and logic from VideoRAG for use -with dimos's VlModel abstraction instead of OpenAI API directly. -""" - -import json -import re -from typing import Any - -from dimos.models.vl.base import VlModel -from dimos.msgs.sensor_msgs import Image -from dimos.utils.llm_utils import extract_json -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() - - -def default_state() -> dict[str, Any]: - """Create default temporal memory state dictionary.""" - return { - "entity_roster": [], - "rolling_summary": "", - "chunk_buffer": [], - "next_summary_at_s": 0.0, - "last_present": [], - } - - -def next_entity_id_hint(roster: Any) -> str: - """Generate next entity ID based on existing roster (e.g., E1, E2, E3...).""" - if not isinstance(roster, list): - return "E1" - max_n = 0 - for e in roster: - if not isinstance(e, dict): - continue - eid = e.get("id") - if isinstance(eid, str) and eid.startswith("E"): - tail = eid[1:] - if tail.isdigit(): - max_n = max(max_n, int(tail)) - return f"E{max_n + 1}" - - -def clamp_text(text: str, max_chars: int) -> str: - """Clamp text to maximum characters.""" - if len(text) <= max_chars: - return text - return text[:max_chars] + "..." 
- - -def format_timestamp(seconds: float) -> str: - """Format seconds as MM:SS.mmm timestamp string.""" - m = int(seconds // 60) - s = seconds - 60 * m - return f"{m:02d}:{s:06.3f}" - - -def build_window_prompt( - *, - w_start: float, - w_end: float, - frame_count: int, - state: dict[str, Any], -) -> str: - """ - Build comprehensive VLM prompt for analyzing a video window. - - This is adapted from videorag's build_window_messages() but formatted - as a single text prompt for VlModel.query() instead of OpenAI's messages format. - - Args: - w_start: Window start time in seconds - w_end: Window end time in seconds - frame_count: Number of frames in this window - state: Current temporal memory state (entity_roster, rolling_summary, etc.) - - Returns: - Formatted prompt string - """ - roster = state.get("entity_roster", []) - rolling_summary = state.get("rolling_summary", "") - next_id = next_entity_id_hint(roster) - - # System instructions (from VideoRAG) - system_context = """You analyze short sequences of video frames. -You must stay grounded in what is visible. -Do not identify real people or guess names/identities; describe people anonymously. -Extract general entities (people, objects, screens, text, locations) and relations between them. -Use stable entity IDs like E1, E2 based on the provided roster.""" - - # Main prompt (from VideoRAG's build_window_messages) - prompt = f"""{system_context} - -Time window: [{w_start:.3f}, {w_end:.3f}) seconds -Number of frames: {frame_count} - -Existing entity roster (may be empty): -{json.dumps(roster, ensure_ascii=False)} - -Rolling summary so far (may be empty): -{clamp_text(str(rolling_summary), 1500)} - -Task: -1) Write a dense, grounded caption describing what is visible across the frames in this time window. -2) Identify which existing roster entities appear in these frames. -3) Add any new salient entities (people/objects/screens/text/locations) with a short grounded descriptor. 
-4) Extract grounded relations/events between entities (e.g., looks_at, holds, uses, walks_past, speaks_to (inferred)). - -New entity IDs must start at: {next_id} - -Rules (important): -- You MUST stay grounded in what is visible in the provided frames. -- You MUST NOT mention any entity ID unless it appears in the provided roster OR you include it in new_entities in this same output. -- If the roster is empty, introduce any salient entities you reference (start with E1, E2, ...). -- Do not invent on-screen text: only include text you can read. -- If a relation is inferred (e.g., speaks_to without audio), include it but lower confidence and explain the visual cues. - -Output JSON ONLY with this schema: -{{ - "window": {{"start_s": {w_start:.3f}, "end_s": {w_end:.3f}}}, - "caption": "dense grounded description", - "entities_present": [{{"id": "E1", "confidence": 0.0-1.0}}], - "new_entities": [{{"id": "E3", "type": "person|object|screen|text|location|other", "descriptor": "..."}}], - "relations": [ - {{ - "type": "speaks_to|looks_at|holds|uses|moves|gesture|scene_change|other", - "subject": "E1|unknown", - "object": "E2|unknown", - "confidence": 0.0-1.0, - "evidence": ["describe which frames show this"], - "notes": "short, grounded" - }} - ], - "on_screen_text": ["verbatim snippets"], - "uncertainties": ["things that are unclear"], - "confidence": 0.0-1.0 -}} -""" - return prompt - - -# JSON schema for window responses (from VideoRAG) -WINDOW_RESPONSE_SCHEMA = { - "type": "object", - "properties": { - "window": { - "type": "object", - "properties": {"start_s": {"type": "number"}, "end_s": {"type": "number"}}, - "required": ["start_s", "end_s"], - }, - "caption": {"type": "string"}, - "entities_present": { - "type": "array", - "items": { - "type": "object", - "properties": { - "id": {"type": "string"}, - "confidence": {"type": "number", "minimum": 0, "maximum": 1}, - }, - "required": ["id"], - }, - }, - "new_entities": { - "type": "array", - "items": { - "type": 
"object", - "properties": { - "id": {"type": "string"}, - "type": { - "type": "string", - "enum": ["person", "object", "screen", "text", "location", "other"], - }, - "descriptor": {"type": "string"}, - }, - "required": ["id", "type"], - }, - }, - "relations": { - "type": "array", - "items": { - "type": "object", - "properties": { - "type": {"type": "string"}, - "subject": {"type": "string"}, - "object": {"type": "string"}, - "confidence": {"type": "number", "minimum": 0, "maximum": 1}, - "evidence": {"type": "array", "items": {"type": "string"}}, - "notes": {"type": "string"}, - }, - "required": ["type", "subject", "object"], - }, - }, - "on_screen_text": {"type": "array", "items": {"type": "string"}}, - "uncertainties": {"type": "array", "items": {"type": "string"}}, - "confidence": {"type": "number", "minimum": 0, "maximum": 1}, - }, - "required": ["window", "caption"], -} - - -def build_summary_prompt( - *, - rolling_summary: str, - chunk_windows: list[dict[str, Any]], -) -> str: - """ - Build prompt for updating rolling summary. - - This is adapted from videorag's build_summary_messages() but formatted - as a single text prompt for VlModel.query(). - - Args: - rolling_summary: Current rolling summary text - chunk_windows: List of recent window results to incorporate - - Returns: - Formatted prompt string - """ - # System context (from VideoRAG) - system_context = """You summarize timestamped video-window logs into a concise rolling summary. -Stay grounded in the provided window captions/relations. -Do not invent entities or rename entity IDs; preserve IDs like E1, E2 exactly. -You MAY incorporate new entity IDs if they appear in the provided chunk windows (e.g., in new_entities). -Be concise, but keep relevant entity continuity and key relations.""" - - prompt = f"""{system_context} - -Update the rolling summary using the newest chunk. 
- -Previous rolling summary (may be empty): -{clamp_text(rolling_summary, 2500)} - -New chunk windows (JSON): -{json.dumps(chunk_windows, ensure_ascii=False)} - -Output a concise summary as PLAIN TEXT (no JSON, no code fences). -Length constraints (important): -- Target <= 120 words total. -- Hard cap <= 900 characters. -""" - return prompt - - -def build_query_prompt( - *, - question: str, - context: dict[str, Any], -) -> str: - """ - Build prompt for querying temporal memory. - - Args: - question: User's question about the video stream - context: Context dict containing entity_roster, rolling_summary, etc. - - Returns: - Formatted prompt string - """ - currently_present = context.get("currently_present_entities", []) - currently_present_str = ( - f"Entities recently detected in recent windows: {currently_present}" - if currently_present - else "No entities were detected in recent windows (list is empty)" - ) - - prompt = f"""Answer the following question about the video stream using the provided context. - -**Question:** {question} - -**Context:** -{json.dumps(context, indent=2, ensure_ascii=False)} - -**Important Notes:** -- Entities have stable IDs like E1, E2, etc. -- The 'currently_present_entities' list contains entity IDs that were detected in recent video windows (not necessarily in the current frame you're viewing) -- {currently_present_str} -- The 'entity_roster' contains all known entities with their descriptions -- The 'rolling_summary' describes what has happened over time -- If 'currently_present_entities' is empty, it means no entities were detected in recent windows, but entities may still exist in the roster from earlier -- Answer based on the provided context (entity_roster, rolling_summary, currently_present_entities) AND what you see in the current frame -- If the context says entities were present but you don't see them in the current frame, mention both: what was recently detected AND what you currently see - -Provide a concise answer. 
-""" - return prompt - - -def extract_time_window( - question: str, - vlm: VlModel, - latest_frame: Image | None = None, -) -> float | None: - """Extract time window from question using VLM with example-based learning. - - Uses a few example keywords as patterns, then asks VLM to extrapolate - similar time references and return seconds. - - Args: - question: User's question - vlm: VLM instance to use for extraction - latest_frame: Optional frame (required for VLM call, but image is ignored) - - Returns: - Time window in seconds, or None if no time reference found - """ - question_lower = question.lower() - - # Quick check for common patterns (fast path) - if "last week" in question_lower or "past week" in question_lower: - return 7 * 24 * 3600 - if "today" in question_lower or "last hour" in question_lower: - return 3600 - if "recently" in question_lower or "recent" in question_lower: - return 600 - - # Use VLM to extract time reference from question - # Provide examples and let VLM extrapolate similar patterns - # Note: latest_frame is required by VLM interface but image content is ignored - if not latest_frame: - return None - - extraction_prompt = f"""Extract any time reference from this question and convert it to seconds. - -Question: {question} - -Examples of time references and their conversions: -- "last week" or "past week" -> 604800 seconds (7 days) -- "yesterday" -> 86400 seconds (1 day) -- "today" or "last hour" -> 3600 seconds (1 hour) -- "recently" or "recent" -> 600 seconds (10 minutes) -- "few minutes ago" -> 300 seconds (5 minutes) -- "just now" -> 60 seconds (1 minute) - -Extrapolate similar patterns (e.g., "2 days ago", "this morning", "last month", etc.) -and convert to seconds. If no time reference is found, return "none". - -Return ONLY a number (seconds) or the word "none". 
Do not include any explanation.""" - - try: - response = vlm.query(latest_frame, extraction_prompt) - response = response.strip().lower() - - if "none" in response or not response: - return None - - # Extract number from response - numbers = re.findall(r"\d+(?:\.\d+)?", response) - if numbers: - seconds = float(numbers[0]) - # Sanity check: reasonable time windows (1 second to 1 year) - if 1 <= seconds <= 365 * 24 * 3600: - return seconds - except Exception as e: - logger.debug(f"Time extraction failed: {e}") - - return None - - -def build_distance_estimation_prompt( - *, - entity_a_descriptor: str, - entity_a_id: str, - entity_b_descriptor: str, - entity_b_id: str, -) -> str: - """ - Build prompt for estimating distance between two entities. - - Args: - entity_a_descriptor: Description of first entity - entity_a_id: ID of first entity - entity_b_descriptor: Description of second entity - entity_b_id: ID of second entity - - Returns: - Formatted prompt string for distance estimation - """ - prompt = f"""Look at this image and estimate the distance between these two entities: - -Entity A: {entity_a_descriptor} (ID: {entity_a_id}) -Entity B: {entity_b_descriptor} (ID: {entity_b_id}) - -Provide: -1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m) -2. Approximate distance in meters (best guess) -3. Confidence: 0.0-1.0 (how certain are you?) - -Respond in this format: -category: [near/medium/far] -distance_m: [number] -confidence: [0.0-1.0] -reasoning: [brief explanation]""" - return prompt - - -def build_batch_distance_estimation_prompt( - entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]], -) -> str: - """ - Build prompt for estimating distances between multiple entity pairs in one call. 
- - Args: - entity_pairs: List of (entity_a, entity_b) tuples, each entity is a dict with 'id' and 'descriptor' - - Returns: - Formatted prompt string for batched distance estimation - """ - pairs_text = [] - for i, (entity_a, entity_b) in enumerate(entity_pairs, 1): - pairs_text.append( - f"Pair {i}:\n" - f" Entity A: {entity_a['descriptor']} (ID: {entity_a['id']})\n" - f" Entity B: {entity_b['descriptor']} (ID: {entity_b['id']})" - ) - - prompt = f"""Look at this image and estimate the distances between the following entity pairs: - -{chr(10).join(pairs_text)} - -For each pair, provide: -1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m) -2. Approximate distance in meters (best guess) -3. Confidence: 0.0-1.0 (how certain are you?) - -Respond in this format (one block per pair): -Pair 1: -category: [near/medium/far] -distance_m: [number] -confidence: [0.0-1.0] - -Pair 2: -category: [near/medium/far] -distance_m: [number] -confidence: [0.0-1.0] - -(etc.)""" - return prompt - - -def parse_batch_distance_response( - response: str, entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]] -) -> list[dict[str, Any]]: - """ - Parse batched distance estimation response. 
- - Args: - response: VLM response text - entity_pairs: Original entity pairs used in the prompt - - Returns: - List of dicts with keys: entity_a_id, entity_b_id, category, distance_m, confidence - """ - results = [] - lines = response.strip().split("\n") - - current_pair_idx = None - category = None - distance_m = None - confidence = 0.5 - - for line in lines: - line = line.strip() - - # Check for pair marker - if line.startswith("Pair "): - # Save previous pair if exists - if current_pair_idx is not None and category: - entity_a, entity_b = entity_pairs[current_pair_idx] - results.append( - { - "entity_a_id": entity_a["id"], - "entity_b_id": entity_b["id"], - "category": category, - "distance_m": distance_m, - "confidence": confidence, - } - ) - - # Start new pair - try: - pair_num = int(line.split()[1].rstrip(":")) - current_pair_idx = pair_num - 1 # Convert to 0-indexed - category = None - distance_m = None - confidence = 0.5 - except (IndexError, ValueError): - continue - - # Parse distance fields - elif line.startswith("category:"): - category = line.split(":", 1)[1].strip().lower() - elif line.startswith("distance_m:"): - try: - distance_m = float(line.split(":", 1)[1].strip()) - except (ValueError, IndexError): - pass - elif line.startswith("confidence:"): - try: - confidence = float(line.split(":", 1)[1].strip()) - except (ValueError, IndexError): - pass - - # Save last pair - if current_pair_idx is not None and category and current_pair_idx < len(entity_pairs): - entity_a, entity_b = entity_pairs[current_pair_idx] - results.append( - { - "entity_a_id": entity_a["id"], - "entity_b_id": entity_b["id"], - "category": category, - "distance_m": distance_m, - "confidence": confidence, - } - ) - - return results - - -def parse_window_response( - response_text: str, w_start: float, w_end: float, frame_count: int -) -> dict[str, Any]: - """ - Parse VLM response for a window analysis. 
- - Args: - response_text: Raw text response from VLM - w_start: Window start time - w_end: Window end time - frame_count: Number of frames in window - - Returns: - Parsed dictionary with defaults filled in - """ - # Try to extract JSON (handles code fences) - parsed = extract_json(response_text) - if parsed is None: - raise ValueError(f"Failed to parse response: {response_text}") - - # Ensure we return a dict (extract_json can return a list) - if isinstance(parsed, list): - # If we got a list, wrap it in a dict with a default structure - # This shouldn't happen with proper structured output, but handle gracefully - return { - "window": {"start": w_start, "end": w_end}, - "caption": "", - "entities_present": [], - "new_entities": [], - "relations": [], - "on_screen_text": [], - "_error": f"Unexpected list response: {parsed}", - } - - # Ensure it's a dict - if not isinstance(parsed, dict): - raise ValueError(f"Expected dict or list, got {type(parsed)}: {parsed}") - - return parsed - - -def update_state_from_window( - state: dict[str, Any], - parsed: dict[str, Any], - w_end: float, - summary_interval_s: float, -) -> bool: - """ - Update temporal memory state from a parsed window result. - - This implements the state update logic from VideoRAG's generate_evidence(). 
- - Args: - state: Current state dictionary (modified in place) - parsed: Parsed window result - w_end: Window end time - summary_interval_s: How often to trigger summary updates - - Returns: - True if summary update is needed, False otherwise - """ - # Skip if there was an error - if "_error" in parsed: - return False - - new_entities = parsed.get("new_entities", []) - present = parsed.get("entities_present", []) - - # Handle new entities - if new_entities: - roster = list(state.get("entity_roster", [])) - known = {e.get("id") for e in roster if isinstance(e, dict)} - for e in new_entities: - if isinstance(e, dict) and e.get("id") not in known: - roster.append(e) - known.add(e.get("id")) - state["entity_roster"] = roster - - # Handle referenced entities (auto-add if mentioned but not in roster) - roster = list(state.get("entity_roster", [])) - known = {e.get("id") for e in roster if isinstance(e, dict)} - referenced: set[str] = set() - for p in present or []: - if isinstance(p, dict) and isinstance(p.get("id"), str): - referenced.add(p["id"]) - for rel in parsed.get("relations") or []: - if isinstance(rel, dict): - for k in ("subject", "object"): - v = rel.get(k) - if isinstance(v, str) and v != "unknown": - referenced.add(v) - for rid in sorted(referenced): - if rid not in known: - roster.append( - { - "id": rid, - "type": "other", - "descriptor": "unknown (auto-added; rerun recommended)", - } - ) - known.add(rid) - state["entity_roster"] = roster - state["last_present"] = present - - # Add to chunk buffer - chunk_buffer = state.get("chunk_buffer", []) - if not isinstance(chunk_buffer, list): - chunk_buffer = [] - chunk_buffer.append( - { - "window": parsed.get("window"), - "caption": parsed.get("caption", ""), - "entities_present": parsed.get("entities_present", []), - "new_entities": parsed.get("new_entities", []), - "relations": parsed.get("relations", []), - "on_screen_text": parsed.get("on_screen_text", []), - } - ) - state["chunk_buffer"] = chunk_buffer - - 
# Check if summary update is needed - if summary_interval_s > 0: - next_at = float(state.get("next_summary_at_s", summary_interval_s)) - if w_end + 1e-6 >= next_at and chunk_buffer: - return True # Need to update summary - - return False - - -def apply_summary_update( - state: dict[str, Any], summary_text: str, w_end: float, summary_interval_s: float -) -> None: - """ - Apply a summary update to the state. - - Args: - state: State dictionary (modified in place) - summary_text: New summary text - w_end: Current window end time - summary_interval_s: Summary update interval - """ - if summary_text and summary_text.strip(): - state["rolling_summary"] = summary_text.strip() - state["chunk_buffer"] = [] - - # Advance next_summary_at_s - next_at = float(state.get("next_summary_at_s", summary_interval_s)) - while next_at <= w_end + 1e-6: - next_at += float(summary_interval_s) - state["next_summary_at_s"] = next_at - - -def get_structured_output_format() -> dict[str, Any]: - """ - Get OpenAI-compatible structured output format for window responses. - - This uses the json_schema mode available in OpenAI API (GPT-4o mini) to enforce - the VideoRAG response schema. 
- - Returns: - Dictionary for response_format parameter: - {"type": "json_schema", "json_schema": {...}} - """ - - return { - "type": "json_schema", - "json_schema": { - "name": "video_window_analysis", - "description": "Analysis of a video window with entities and relations", - "schema": WINDOW_RESPONSE_SCHEMA, - "strict": False, # Allow additional fields - }, - } - - -__all__ = [ - "WINDOW_RESPONSE_SCHEMA", - "apply_summary_update", - "build_distance_estimation_prompt", - "build_query_prompt", - "build_summary_prompt", - "build_window_prompt", - "clamp_text", - "default_state", - "extract_time_window", - "format_timestamp", - "get_structured_output_format", - "next_entity_id_hint", - "parse_window_response", - "update_state_from_window", -] diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py index 34f81e2bbf..084adba8da 100644 --- a/dimos/robot/unitree/connection/go2.py +++ b/dimos/robot/unitree/connection/go2.py @@ -81,7 +81,8 @@ def _camera_info_static() -> CameraInfo: class ReplayConnection(UnitreeWebRTCConnection): - dir_name = "unitree_go2_bigoffice" + # dir_name = "unitree_go2_bigoffice" + dir_name = "unitree_go2_office_walk2" # we don't want UnitreeWebRTCConnection to init def __init__( # type: ignore[no-untyped-def] From 581471f2f707083dceaebbc962c7af7a067c8acb Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Tue, 13 Jan 2026 19:11:57 -0800 Subject: [PATCH 12/21] type checking issues --- dimos/perception/clip_filter.py | 189 ++++++++++---------- dimos/perception/entity_graph_db.py | 2 +- dimos/perception/temporal_memory.py | 10 +- dimos/perception/temporal_memory_deploy.py | 2 +- dimos/perception/temporal_memory_example.py | 6 +- 5 files changed, 111 insertions(+), 98 deletions(-) diff --git a/dimos/perception/clip_filter.py b/dimos/perception/clip_filter.py index 34f9da3912..3996dbbdb8 100644 --- a/dimos/perception/clip_filter.py +++ b/dimos/perception/clip_filter.py @@ -31,129 +31,135 @@ # Optional CLIP imports 
try: - import clip - from PIL import Image as PILImage - import torch + import clip # type: ignore[import-untyped] + from PIL import Image as PILImage # type: ignore[import-untyped] + import torch # type: ignore[import-untyped] CLIP_AVAILABLE = True -except ImportError: +except (ImportError, RuntimeError) as e: CLIP_AVAILABLE = False logger.warning( - "CLIP not available. Install with: pip install torch torchvision openai-clip. " + f"CLIP not available: {e}. Install with: pip install torch torchvision openai-clip. " "Frame filtering will fall back to simple sampling." ) + # Define stub for type annotations when PIL is not available + class PILImage: + Image = None -class CLIPFrameFilter: - """Filter video frames using CLIP embeddings to select diverse frames.""" - def __init__(self, model_name: str = "ViT-B/32", device: str | None = None): - """ - Initialize CLIP frame filter. +if CLIP_AVAILABLE: - Args: - model_name: CLIP model name (e.g., "ViT-B/32", "ViT-L/14") - device: Device to use ("cuda", "cpu", or None for auto-detect) - """ - if not CLIP_AVAILABLE: - raise ImportError( - "CLIP is not available. Install with: pip install torch torchvision openai-clip" - ) + class CLIPFrameFilter: + """Filter video frames using CLIP embeddings to select diverse frames.""" - self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") - logger.info(f"Loading CLIP model {model_name} on {self.device}") - self.model, self.preprocess = clip.load(model_name, device=self.device) - logger.info("CLIP model loaded successfully") + def __init__(self, model_name: str = "ViT-B/32", device: str | None = None): + """ + Initialize CLIP frame filter. 
- def _image_to_pil(self, image: Image) -> PILImage.Image: - """Convert dimos Image to PIL Image.""" - # Get numpy array from dimos Image - img_array = image.data # Assumes Image has .data attribute with numpy array + Args: + model_name: CLIP model name (e.g., "ViT-B/32", "ViT-L/14") + device: Device to use ("cuda", "cpu", or None for auto-detect) + """ + if not CLIP_AVAILABLE: + raise ImportError( + "CLIP is not available. Install with: pip install torch torchvision openai-clip" + ) - # Convert to PIL - return PILImage.fromarray(img_array) + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Loading CLIP model {model_name} on {self.device}") + self.model, self.preprocess = clip.load(model_name, device=self.device) + logger.info("CLIP model loaded successfully") - def _encode_images(self, images: list[Image]) -> torch.Tensor: - """Encode images using CLIP. + def _image_to_pil(self, image: Image) -> "PILImage.Image": + """Convert dimos Image to PIL Image.""" + # Get numpy array from dimos Image + img_array = image.data # Assumes Image has .data attribute with numpy array - Args: - images: List of dimos Images + # Convert to PIL + return PILImage.fromarray(img_array) - Returns: - Tensor of normalized CLIP embeddings, shape (N, embedding_dim) - """ - # Convert to PIL and preprocess - pil_images = [self._image_to_pil(img) for img in images] - preprocessed = [self.preprocess(img) for img in pil_images] + def _encode_images(self, images: list[Image]) -> "torch.Tensor": + """Encode images using CLIP. 
- # Stack and encode - image_tensor = torch.stack(preprocessed).to(self.device) - with torch.no_grad(): - embeddings = self.model.encode_image(image_tensor) - # Normalize embeddings - embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) + Args: + images: List of dimos Images - return embeddings + Returns: + Tensor of normalized CLIP embeddings, shape (N, embedding_dim) + """ + # Convert to PIL and preprocess + pil_images = [self._image_to_pil(img) for img in images] + preprocessed = [self.preprocess(img) for img in pil_images] - def select_diverse_frames(self, frames: list[Any], max_frames: int = 3) -> list[Any]: - """ - Select diverse frames using greedy farthest-point sampling. + # Stack and encode + image_tensor = torch.stack(preprocessed).to(self.device) + with torch.no_grad(): + embeddings = self.model.encode_image(image_tensor) + # Normalize embeddings + embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) - This selects frames that are maximally different from each other in CLIP - embedding space, ensuring good visual coverage of the window. + return embeddings - Algorithm: - 1. Always include first frame (temporal anchor) - 2. Iteratively select frame most different from already-selected frames - 3. Continue until we have max_frames frames + def select_diverse_frames(self, frames: list[Any], max_frames: int = 3) -> list[Any]: + """ + Select diverse frames using greedy farthest-point sampling. - Args: - frames: List of Frame objects with .image attribute - max_frames: Maximum number of frames to select + This selects frames that are maximally different from each other in CLIP + embedding space, ensuring good visual coverage of the window. - Returns: - List of selected Frame objects (subset of input frames) - """ - if len(frames) <= max_frames: - return frames + Algorithm: + 1. Always include first frame (temporal anchor) + 2. Iteratively select frame most different from already-selected frames + 3. 
Continue until we have max_frames frames - # Extract images from frames - images = [f.image for f in frames] + Args: + frames: List of Frame objects with .image attribute + max_frames: Maximum number of frames to select - # Encode all images - embeddings = self._encode_images(images) + Returns: + List of selected Frame objects (subset of input frames) + """ + if len(frames) <= max_frames: + return frames - # Greedy farthest-point sampling - selected_indices = [0] # Always include first frame - remaining_indices = list(range(1, len(frames))) + # Extract images from frames + images = [f.image for f in frames] - while len(selected_indices) < max_frames and remaining_indices: - selected_embs = embeddings[selected_indices] - remaining_embs = embeddings[remaining_indices] + # Encode all images + embeddings = self._encode_images(images) - # Compute similarities between remaining and selected - # Shape: (num_remaining, num_selected) - similarities = remaining_embs @ selected_embs.T + # Greedy farthest-point sampling + selected_indices = [0] # Always include first frame + remaining_indices = list(range(1, len(frames))) - # For each remaining frame, find its max similarity to any selected frame - # Shape: (num_remaining,) - max_similarities = similarities.max(dim=1)[0] + while len(selected_indices) < max_frames and remaining_indices: + selected_embs = embeddings[selected_indices] + remaining_embs = embeddings[remaining_indices] + + # Compute similarities between remaining and selected + # Shape: (num_remaining, num_selected) + similarities = remaining_embs @ selected_embs.T + + # For each remaining frame, find its max similarity to any selected frame + # Shape: (num_remaining,) + max_similarities = similarities.max(dim=1)[0] # Select frame with minimum max similarity (most different from all selected) - best_idx = max_similarities.argmin().item() + best_idx = int(max_similarities.argmin().item()) selected_indices.append(remaining_indices[best_idx]) 
remaining_indices.pop(best_idx) - # Return frames in temporal order (sorted by index) - return [frames[i] for i in sorted(selected_indices)] + # Return frames in temporal order (sorted by index) + return [frames[i] for i in sorted(selected_indices)] - def close(self) -> None: - """Clean up CLIP model.""" - if hasattr(self, "model"): - del self.model - if hasattr(self, "preprocess"): - del self.preprocess + def close(self) -> None: + """Clean up CLIP model.""" + if hasattr(self, "model"): + del self.model + if hasattr(self, "preprocess"): + del self.preprocess def select_diverse_frames_simple(frames: list[Any], max_frames: int = 3) -> list[Any]: @@ -178,11 +184,11 @@ def select_diverse_frames_simple(frames: list[Any], max_frames: int = 3) -> list def adaptive_keyframes( - frames: list, + frames: list[Any], min_frames: int = 3, max_frames: int = 5, change_threshold: float = 15.0, -) -> list: +) -> list[Any]: """select frames based on visual change, adaptive count.""" if len(frames) <= min_frames: return frames @@ -235,7 +241,10 @@ def adaptive_keyframes( __all__ = [ "CLIP_AVAILABLE", - "CLIPFrameFilter", + "CLIPFrameFilter" if CLIP_AVAILABLE else None, "adaptive_keyframes", "select_diverse_frames_simple", ] + +# Filter out None values +__all__ = [item for item in __all__ if item is not None] diff --git a/dimos/perception/entity_graph_db.py b/dimos/perception/entity_graph_db.py index 414a2d66b7..06d9fbc76a 100644 --- a/dimos/perception/entity_graph_db.py +++ b/dimos/perception/entity_graph_db.py @@ -69,7 +69,7 @@ def _get_connection(self) -> sqlite3.Connection: if not hasattr(self._local, "conn"): self._local.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) self._local.conn.row_factory = sqlite3.Row - return self._local.conn + return self._local.conn # type: ignore def _init_schema(self) -> None: """Initialize database schema.""" diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/temporal_memory.py index 2534854102..61906498e0 100644 
--- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/temporal_memory.py @@ -42,10 +42,14 @@ from dimos.perception import temporal_utils as tu from dimos.perception.clip_filter import ( CLIP_AVAILABLE, - CLIPFrameFilter, adaptive_keyframes, select_diverse_frames_simple, ) + +try: + from dimos.perception.clip_filter import CLIPFrameFilter +except ImportError: + CLIPFrameFilter = type(None) # type: ignore[misc,assignment] from dimos.perception.entity_graph_db import EntityGraphDB from dimos.utils.logging_config import setup_logger @@ -112,7 +116,7 @@ def __init__( super().__init__() self._vlm = vlm # Can be None for blueprint usage - self.config = config or TemporalMemoryConfig() + self.config: TemporalMemoryConfig = config or TemporalMemoryConfig() # single lock protects all state self._state_lock = threading.Lock() @@ -267,7 +271,7 @@ def stop(self) -> None: for stream in list(self.inputs.values()) + list(self.outputs.values()): if stream.transport is not None and hasattr(stream.transport, "stop"): stream.transport.stop() - stream._transport = None + stream._transport = None # type: ignore[attr-defined,assignment] logger.info("temporalmemory stopped") diff --git a/dimos/perception/temporal_memory_deploy.py b/dimos/perception/temporal_memory_deploy.py index bbb8c8ea0a..668977d8b7 100644 --- a/dimos/perception/temporal_memory_deploy.py +++ b/dimos/perception/temporal_memory_deploy.py @@ -56,4 +56,4 @@ def deploy( temporal_memory.color_image.connect(camera.color_image) temporal_memory.start() - return temporal_memory + return temporal_memory # type: ignore[return-value,no-any-return] diff --git a/dimos/perception/temporal_memory_example.py b/dimos/perception/temporal_memory_example.py index 13deca3a59..791c96c3e1 100644 --- a/dimos/perception/temporal_memory_example.py +++ b/dimos/perception/temporal_memory_example.py @@ -37,7 +37,7 @@ load_dotenv() -def example_usage(): +def example_usage() -> None: """Example of how to use TemporalMemory.""" # Initialize 
variables to None for cleanup temporal_memory = None @@ -48,7 +48,7 @@ def example_usage(): # Create Dimos cluster dimos = core.start(1) # Deploy camera module - camera = dimos.deploy(CameraModule, hardware=lambda: Webcam(camera_index=0)) + camera = dimos.deploy(CameraModule, hardware=lambda: Webcam(camera_index=0)) # type: ignore[attr-defined] camera.start() # Deploy temporal memory using the deploy function @@ -130,7 +130,7 @@ def example_usage(): if camera is not None: camera.stop() if dimos is not None: - dimos.close_all() + dimos.close_all() # type: ignore[attr-defined] if __name__ == "__main__": From e0112e2855205a41bfffaad05b000d7709b2f3a8 Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Wed, 14 Jan 2026 12:48:18 -0800 Subject: [PATCH 13/21] final edits, move into experimental, revert non-memory code edits, typechecking --- dimos/agents/vlm_agent.py | 16 +- dimos/core/test_blueprints.py | 12 +- dimos/models/vl/base.py | 9 - dimos/models/vl/moondream.py | 2 +- dimos/models/vl/openai.py | 16 +- dimos/models/vl/qwen.py | 14 +- dimos/perception/clip_filter.py | 250 ------------------ dimos/perception/experimental/clip_filter.py | 172 ++++++++++++ .../{ => experimental}/entity_graph_db.py | 4 +- .../{ => experimental}/temporal_memory.py | 14 +- .../temporal_memory_deploy.py | 2 +- .../temporal_memory_example.py | 4 +- .../temporal_utils/__init__.py | 0 .../temporal_utils/graph_utils.py | 2 +- .../temporal_utils/helpers.py | 4 +- .../temporal_utils/parsers.py | 0 .../temporal_utils/prompts.py | 0 .../temporal_utils/state.py | 0 .../test_temporal_memory_module.py | 2 +- dimos/robot/unitree/connection/go2.py | 3 +- .../unitree_webrtc/unitree_go2_blueprints.py | 2 +- 21 files changed, 217 insertions(+), 311 deletions(-) delete mode 100644 dimos/perception/clip_filter.py create mode 100644 dimos/perception/experimental/clip_filter.py rename dimos/perception/{ => experimental}/entity_graph_db.py (99%) rename dimos/perception/{ => experimental}/temporal_memory.py 
(98%) rename dimos/perception/{ => experimental}/temporal_memory_deploy.py (95%) rename dimos/perception/{ => experimental}/temporal_memory_example.py (96%) rename dimos/perception/{ => experimental}/temporal_utils/__init__.py (100%) rename dimos/perception/{ => experimental}/temporal_utils/graph_utils.py (99%) rename dimos/perception/{ => experimental}/temporal_utils/helpers.py (93%) rename dimos/perception/{ => experimental}/temporal_utils/parsers.py (100%) rename dimos/perception/{ => experimental}/temporal_utils/prompts.py (100%) rename dimos/perception/{ => experimental}/temporal_utils/state.py (100%) rename dimos/perception/{ => experimental}/test_temporal_memory_module.py (98%) diff --git a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py index 542fb4a180..d591d8ed0b 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -33,7 +33,7 @@ class VLMAgent(AgentSpec): query_stream: In[HumanMessage] answer_stream: Out[AIMessage] - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._llm = build_llm(self.config) self._latest_image: Image | None = None @@ -80,7 +80,7 @@ def _invoke(self, msg: HumanMessage, **kwargs: Any) -> AIMessage: return response # type: ignore[return-value] def _invoke_image( - self, image: Image, query: str, response_format: dict | None = None + self, image: Image, query: str, response_format: dict[str, Any] | None = None ) -> AIMessage: content = [{"type": "text", "text": query}, *image.agent_encode()] kwargs: dict[str, Any] = {} @@ -89,7 +89,7 @@ def _invoke_image( return self._invoke(HumanMessage(content=content), **kwargs) @rpc - def clear_history(self): # type: ignore[no-untyped-def] + def clear_history(self) -> None: self._history.clear() def append_history(self, *msgs: list[AIMessage | HumanMessage]) -> None: @@ -102,9 +102,7 @@ def history(self) -> list[AnyMessage]: return [self._system_message, 
*self._history] @rpc - def register_skills( # type: ignore[no-untyped-def] - self, container, run_implicit_name: str | None = None - ) -> None: + def register_skills(self, container: Any, run_implicit_name: str | None = None) -> None: logger.warning( "VLMAgent does not manage skills; register_skills is a no-op", container=str(container), @@ -112,12 +110,14 @@ def register_skills( # type: ignore[no-untyped-def] ) @rpc - def query(self, query: str): # type: ignore[no-untyped-def] + def query(self, query: str) -> str: response = self._invoke(HumanMessage(query)) return response.content @rpc - def query_image(self, image: Image, query: str, response_format: dict | None = None): # type: ignore[no-untyped-def] + def query_image( + self, image: Image, query: str, response_format: dict[str, Any] | None = None + ) -> str: response = self._invoke_image(image, query, response_format=response_format) return response.content diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index 54313f1a84..7a99a23abe 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -27,6 +27,7 @@ autoconnect, ) from dimos.core.core import rpc +from dimos.core.global_config import GlobalConfig from dimos.core.module import Module from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.rpc_client import RpcCall @@ -34,6 +35,11 @@ from dimos.core.transport import LCMTransport from dimos.protocol import pubsub +# Disable Rerun for tests (prevents viewer spawn and gRPC flush errors) +_BUILD_WITHOUT_RERUN = { + "global_config": GlobalConfig(rerun_enabled=False, viewer_backend="foxglove"), +} + class Scratch: pass @@ -161,7 +167,7 @@ def test_build_happy_path() -> None: blueprint_set = autoconnect(module_a(), module_b(), module_c()) - coordinator = blueprint_set.build() + coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) try: assert isinstance(coordinator, ModuleCoordinator) @@ -297,7 +303,7 @@ class TargetModule(Module): assert 
("color_image", Data1) not in blueprint_set._all_name_types # Build and verify connections work - coordinator = blueprint_set.build() + coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) try: source_instance = coordinator.get_instance(SourceModule) @@ -350,7 +356,7 @@ def test_future_annotations_autoconnect() -> None: blueprint_set = autoconnect(FutureModuleOut.blueprint(), FutureModuleIn.blueprint()) - coordinator = blueprint_set.build() + coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN) try: out_instance = coordinator.get_instance(FutureModuleOut) diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index a7faae84a3..91b411f3b7 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -174,15 +174,6 @@ class VlModel(Captioner, Resource, Configurable[VlModelConfig]): default_config = VlModelConfig config: VlModelConfig - @abstractmethod - def is_set_up(self) -> None: - """Verify that the VLM is properly configured (e.g., API key is set). - - Raises: - ValueError: If the VLM is not properly configured - """ - ... - def _prepare_image(self, image: Image) -> tuple[Image, float]: """Prepare image for inference, applying any configured transformations. diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index f31611e867..cb74a59484 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -68,7 +68,7 @@ def query(self, image: Image | np.ndarray, query: str, **kwargs) -> str: # type return str(result) - def query_batch(self, images: list[Image], query: str, **kwargs) -> list[str]: # type: ignore[no-untyped-def] + def query_batch(self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs) -> list[str]: # type: ignore[no-untyped-def,override] """Query multiple images with the same question. 
Note: moondream2's batch_answer is not truly batched - it processes diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index 4607e29bbe..69a53a2c3d 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -23,16 +23,6 @@ class OpenAIVlModel(VlModel): default_config = OpenAIVlModelConfig config: OpenAIVlModelConfig - def is_set_up(self) -> None: - """ - Verify that OpenAI API key is configured. - """ - api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") - if not api_key: - raise ValueError( - "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable" - ) - @cached_property def _client(self) -> OpenAI: api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") @@ -81,7 +71,7 @@ def query(self, image: Image | np.ndarray, query: str, response_format: dict | N response = self._client.chat.completions.create(**api_kwargs) - return response.choices[0].message.content # type: ignore[return-value] + return response.choices[0].message.content # type: ignore[return-value,no-any-return] def query_batch( self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any @@ -105,7 +95,9 @@ def query_batch( api_kwargs["response_format"] = response_format response = self._client.chat.completions.create(**api_kwargs) - return [response.choices[0].message.content] # type: ignore[list-item] + content = response.choices[0].message.content or "" + # Return one response per image (same response since API analyzes all images together) + return [content] * len(images) def stop(self) -> None: """Release the OpenAI client.""" diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index 2b3808211b..62fa2a0dab 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -22,16 +22,6 @@ class QwenVlModel(VlModel): default_config = QwenVlModelConfig config: QwenVlModelConfig - def is_set_up(self) -> None: - """ - Verify that Alibaba API key is configured. 
- """ - api_key = self.config.api_key or os.getenv("ALIBABA_API_KEY") - if not api_key: - raise ValueError( - "Alibaba API key must be provided or set in ALIBABA_API_KEY environment variable" - ) - @cached_property def _client(self) -> OpenAI: api_key = self.config.api_key or os.getenv("ALIBABA_API_KEY") @@ -102,7 +92,9 @@ def query_batch( api_kwargs["response_format"] = response_format response = self._client.chat.completions.create(**api_kwargs) - return [response.choices[0].message.content] # type: ignore[list-item] + content = response.choices[0].message.content or "" + # Return one response per image (same response since API analyzes all images together) + return [content] * len(images) def stop(self) -> None: """Release the OpenAI client.""" diff --git a/dimos/perception/clip_filter.py b/dimos/perception/clip_filter.py deleted file mode 100644 index 3996dbbdb8..0000000000 --- a/dimos/perception/clip_filter.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright 2025-2026 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -CLIP-based frame filtering for selecting diverse frames from video windows. - -Adapted from videorag/clip_filter.py - uses CLIP embeddings to select the most -visually diverse frames from a window, reducing VLM costs while maintaining coverage. 
-""" - -import logging -from typing import Any - -import numpy as np - -from dimos.msgs.sensor_msgs import Image -from dimos.utils.logging_config import setup_logger - -logger = setup_logger() - -# Optional CLIP imports -try: - import clip # type: ignore[import-untyped] - from PIL import Image as PILImage # type: ignore[import-untyped] - import torch # type: ignore[import-untyped] - - CLIP_AVAILABLE = True -except (ImportError, RuntimeError) as e: - CLIP_AVAILABLE = False - logger.warning( - f"CLIP not available: {e}. Install with: pip install torch torchvision openai-clip. " - "Frame filtering will fall back to simple sampling." - ) - - # Define stub for type annotations when PIL is not available - class PILImage: - Image = None - - -if CLIP_AVAILABLE: - - class CLIPFrameFilter: - """Filter video frames using CLIP embeddings to select diverse frames.""" - - def __init__(self, model_name: str = "ViT-B/32", device: str | None = None): - """ - Initialize CLIP frame filter. - - Args: - model_name: CLIP model name (e.g., "ViT-B/32", "ViT-L/14") - device: Device to use ("cuda", "cpu", or None for auto-detect) - """ - if not CLIP_AVAILABLE: - raise ImportError( - "CLIP is not available. Install with: pip install torch torchvision openai-clip" - ) - - self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") - logger.info(f"Loading CLIP model {model_name} on {self.device}") - self.model, self.preprocess = clip.load(model_name, device=self.device) - logger.info("CLIP model loaded successfully") - - def _image_to_pil(self, image: Image) -> "PILImage.Image": - """Convert dimos Image to PIL Image.""" - # Get numpy array from dimos Image - img_array = image.data # Assumes Image has .data attribute with numpy array - - # Convert to PIL - return PILImage.fromarray(img_array) - - def _encode_images(self, images: list[Image]) -> "torch.Tensor": - """Encode images using CLIP. 
- - Args: - images: List of dimos Images - - Returns: - Tensor of normalized CLIP embeddings, shape (N, embedding_dim) - """ - # Convert to PIL and preprocess - pil_images = [self._image_to_pil(img) for img in images] - preprocessed = [self.preprocess(img) for img in pil_images] - - # Stack and encode - image_tensor = torch.stack(preprocessed).to(self.device) - with torch.no_grad(): - embeddings = self.model.encode_image(image_tensor) - # Normalize embeddings - embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) - - return embeddings - - def select_diverse_frames(self, frames: list[Any], max_frames: int = 3) -> list[Any]: - """ - Select diverse frames using greedy farthest-point sampling. - - This selects frames that are maximally different from each other in CLIP - embedding space, ensuring good visual coverage of the window. - - Algorithm: - 1. Always include first frame (temporal anchor) - 2. Iteratively select frame most different from already-selected frames - 3. Continue until we have max_frames frames - - Args: - frames: List of Frame objects with .image attribute - max_frames: Maximum number of frames to select - - Returns: - List of selected Frame objects (subset of input frames) - """ - if len(frames) <= max_frames: - return frames - - # Extract images from frames - images = [f.image for f in frames] - - # Encode all images - embeddings = self._encode_images(images) - - # Greedy farthest-point sampling - selected_indices = [0] # Always include first frame - remaining_indices = list(range(1, len(frames))) - - while len(selected_indices) < max_frames and remaining_indices: - selected_embs = embeddings[selected_indices] - remaining_embs = embeddings[remaining_indices] - - # Compute similarities between remaining and selected - # Shape: (num_remaining, num_selected) - similarities = remaining_embs @ selected_embs.T - - # For each remaining frame, find its max similarity to any selected frame - # Shape: (num_remaining,) - max_similarities = 
similarities.max(dim=1)[0] - - # Select frame with minimum max similarity (most different from all selected) - best_idx = int(max_similarities.argmin().item()) - - selected_indices.append(remaining_indices[best_idx]) - remaining_indices.pop(best_idx) - - # Return frames in temporal order (sorted by index) - return [frames[i] for i in sorted(selected_indices)] - - def close(self) -> None: - """Clean up CLIP model.""" - if hasattr(self, "model"): - del self.model - if hasattr(self, "preprocess"): - del self.preprocess - - -def select_diverse_frames_simple(frames: list[Any], max_frames: int = 3) -> list[Any]: - """ - Fallback frame selection when CLIP is not available. - - Uses simple uniform sampling across the window. - - Args: - frames: List of Frame objects - max_frames: Maximum number of frames to select - - Returns: - List of selected Frame objects - """ - if len(frames) <= max_frames: - return frames - - # Sample uniformly across window - indices = [int(i * len(frames) / max_frames) for i in range(max_frames)] - return [frames[i] for i in indices] - - -def adaptive_keyframes( - frames: list[Any], - min_frames: int = 3, - max_frames: int = 5, - change_threshold: float = 15.0, -) -> list[Any]: - """select frames based on visual change, adaptive count.""" - if len(frames) <= min_frames: - return frames - - # compute frame-to-frame differences - diffs = [] - for i in range(1, len(frames)): - prev = frames[i - 1].image.data.astype(float) - curr = frames[i].image.data.astype(float) - diffs.append(np.abs(curr - prev).mean()) - - total_motion = sum(diffs) - - # adaptive N: more motion → more frames - n_frames = int(np.clip(total_motion / change_threshold, min_frames, max_frames)) - - # pick frames at change peaks (local maxima) - # always include first and last - keyframe_indices = [0, len(frames) - 1] # always - - # find peaks in diff signal - for i in range(1, len(diffs) - 1): - if ( - diffs[i] > diffs[i - 1] - and diffs[i] > diffs[i + 1] - and diffs[i] > 
change_threshold * 0.5 - ): - keyframe_indices.append(i + 1) # +1 bc diff[i] is between frame i and i+1 - - # if too many peaks, subsample; if too few, add uniform samples - if len(keyframe_indices) > n_frames: - # keep first, last, and highest-diff peaks - middle_indices = [i for i in keyframe_indices if i not in (0, len(frames) - 1)] - middle_diffs = [diffs[i - 1] for i in middle_indices] - sorted_by_diff = sorted(zip(middle_diffs, middle_indices, strict=False), reverse=True) - keep = [idx for _, idx in sorted_by_diff[: n_frames - 2]] - keyframe_indices = sorted([0, *keep, len(frames) - 1]) - elif len(keyframe_indices) < n_frames: - # fill in uniformly from remaining candidates - needed = n_frames - len(keyframe_indices) - candidates = sorted(set(range(len(frames))) - set(keyframe_indices)) - if candidates: - # Calculate step, ensuring it's at least 1 - step = max(1, len(candidates) // (needed + 1)) - uniform_fill = candidates[::step][:needed] - keyframe_indices = sorted(set(keyframe_indices) | set(uniform_fill)) - - return [frames[i] for i in keyframe_indices] - - -__all__ = [ - "CLIP_AVAILABLE", - "CLIPFrameFilter" if CLIP_AVAILABLE else None, - "adaptive_keyframes", - "select_diverse_frames_simple", -] - -# Filter out None values -__all__ = [item for item in __all__ if item is not None] diff --git a/dimos/perception/experimental/clip_filter.py b/dimos/perception/experimental/clip_filter.py new file mode 100644 index 0000000000..b9b17a0dc0 --- /dev/null +++ b/dimos/perception/experimental/clip_filter.py @@ -0,0 +1,172 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CLIP-based frame filtering for selecting diverse frames from video windows.""" + +from typing import Any + +import numpy as np + +from dimos.msgs.sensor_msgs import Image +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +try: + import clip # type: ignore + from PIL import Image as PILImage # type: ignore + import torch # type: ignore + + CLIP_AVAILABLE = True +except ImportError as e: + CLIP_AVAILABLE = False + logger.info(f"CLIP unavailable ({e}), using simple frame sampling") + + +def _get_image_data(image: Image) -> np.ndarray: + """Extract numpy array from Image.""" + if not hasattr(image, "data"): + raise AttributeError(f"Image missing .data attribute: {type(image)}") + return image.data + + +if CLIP_AVAILABLE: + + class CLIPFrameFilter: + """Filter video frames using CLIP embeddings for diversity.""" + + def __init__(self, model_name: str = "ViT-B/32", device: str | None = None): + if not CLIP_AVAILABLE: + raise ImportError( + "CLIP not available. 
Install: pip install torch torchvision openai-clip" + ) + + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Loading CLIP {model_name} on {self.device}") + self.model, self.preprocess = clip.load(model_name, device=self.device) + + def _encode_images(self, images: list[Image]) -> "torch.Tensor": + """Encode images using CLIP.""" + pil_images = [PILImage.fromarray(_get_image_data(img)) for img in images] + preprocessed = torch.stack([self.preprocess(img) for img in pil_images]).to(self.device) + + with torch.no_grad(): + embeddings = self.model.encode_image(preprocessed) + embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) + + return embeddings # type: ignore + + def select_diverse_frames(self, frames: list[Any], max_frames: int = 3) -> list[Any]: + """Select diverse frames using greedy farthest-point sampling in CLIP space.""" + if len(frames) <= max_frames: + return frames + + embeddings = self._encode_images([f.image for f in frames]) + + # Greedy farthest-point sampling + selected_indices = [0] # Always include first frame + remaining_indices = list(range(1, len(frames))) + + while len(selected_indices) < max_frames and remaining_indices: + # Compute similarities: (num_remaining, num_selected) + similarities = embeddings[remaining_indices] @ embeddings[selected_indices].T + # Find max similarity for each remaining frame + max_similarities = similarities.max(dim=1)[0] + # Select frame most different from all selected + best_idx = int(max_similarities.argmin().item()) + + selected_indices.append(remaining_indices[best_idx]) + remaining_indices.pop(best_idx) + + return [frames[i] for i in sorted(selected_indices)] + + def close(self) -> None: + """Clean up CLIP model.""" + if hasattr(self, "model"): + del self.model + if hasattr(self, "preprocess"): + del self.preprocess + + +def select_diverse_frames_simple(frames: list[Any], max_frames: int = 3) -> list[Any]: + """Fallback frame selection: uniform sampling 
across window.""" + if len(frames) <= max_frames: + return frames + indices = [int(i * len(frames) / max_frames) for i in range(max_frames)] + return [frames[i] for i in indices] + + +def adaptive_keyframes( + frames: list[Any], + min_frames: int = 3, + max_frames: int = 5, + change_threshold: float = 15.0, +) -> list[Any]: + """Select frames based on visual change, adaptive count.""" + if len(frames) <= min_frames: + return frames + + # Compute frame-to-frame differences + try: + diffs = [ + np.abs( + _get_image_data(frames[i].image).astype(float) + - _get_image_data(frames[i - 1].image).astype(float) + ).mean() + for i in range(1, len(frames)) + ] + except (AttributeError, ValueError) as e: + logger.warning(f"Failed to compute frame diffs: {e}. Falling back to uniform sampling.") + return select_diverse_frames_simple(frames, max_frames) + + total_motion = sum(diffs) + n_frames = int(np.clip(total_motion / change_threshold, min_frames, max_frames)) + + # Always include first and last + keyframe_indices = {0, len(frames) - 1} + + # Add peaks in diff signal + for i in range(1, len(diffs) - 1): + if ( + diffs[i] > diffs[i - 1] + and diffs[i] > diffs[i + 1] + and diffs[i] > change_threshold * 0.5 + ): + keyframe_indices.add(i + 1) + + # Adjust count + if len(keyframe_indices) > n_frames: + # Keep first, last, and highest-diff peaks + middle = [i for i in keyframe_indices if i not in (0, len(frames) - 1)] + middle_by_diff = sorted(middle, key=lambda i: diffs[i - 1], reverse=True) + keyframe_indices = {0, len(frames) - 1, *middle_by_diff[: n_frames - 2]} + elif len(keyframe_indices) < n_frames: + # Fill uniformly from remaining + needed = n_frames - len(keyframe_indices) + candidates = sorted(set(range(len(frames))) - keyframe_indices) + if candidates: + step = max(1, len(candidates) // (needed + 1)) + keyframe_indices.update(candidates[::step][:needed]) + + return [frames[i] for i in sorted(keyframe_indices)] + + +__all__ = [ + "CLIP_AVAILABLE", + "adaptive_keyframes", 
+ "select_diverse_frames_simple", +] + +if CLIP_AVAILABLE: + __all__.append("CLIPFrameFilter") diff --git a/dimos/perception/entity_graph_db.py b/dimos/perception/experimental/entity_graph_db.py similarity index 99% rename from dimos/perception/entity_graph_db.py rename to dimos/perception/experimental/entity_graph_db.py index 06d9fbc76a..f1e6b75938 100644 --- a/dimos/perception/entity_graph_db.py +++ b/dimos/perception/experimental/entity_graph_db.py @@ -67,7 +67,7 @@ def __init__(self, db_path: str | Path) -> None: def _get_connection(self) -> sqlite3.Connection: """Get thread-local database connection.""" if not hasattr(self._local, "conn"): - self._local.conn = sqlite3.connect(str(self.db_path), check_same_thread=False) + self._local.conn = sqlite3.connect(str(self.db_path)) self._local.conn.row_factory = sqlite3.Row return self._local.conn # type: ignore @@ -948,7 +948,7 @@ def estimate_and_save_distances( return # Import here to avoid circular dependency - from dimos.perception import temporal_utils as tu + from dimos.perception.experimental import temporal_utils as tu # Collect entities with descriptors # new_entities have descriptors from VLM diff --git a/dimos/perception/temporal_memory.py b/dimos/perception/experimental/temporal_memory.py similarity index 98% rename from dimos/perception/temporal_memory.py rename to dimos/perception/experimental/temporal_memory.py index 61906498e0..44653f4492 100644 --- a/dimos/perception/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory.py @@ -39,18 +39,18 @@ from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import sharpness_barrier -from dimos.perception import temporal_utils as tu -from dimos.perception.clip_filter import ( +from dimos.perception.experimental import temporal_utils as tu +from dimos.perception.experimental.clip_filter import ( CLIP_AVAILABLE, adaptive_keyframes, select_diverse_frames_simple, ) try: - from 
dimos.perception.clip_filter import CLIPFrameFilter + from dimos.perception.experimental.clip_filter import CLIPFrameFilter except ImportError: CLIPFrameFilter = type(None) # type: ignore[misc,assignment] -from dimos.perception.entity_graph_db import EntityGraphDB +from dimos.perception.experimental.entity_graph_db import EntityGraphDB from dimos.utils.logging_config import setup_logger logger = setup_logger() @@ -270,8 +270,10 @@ def stop(self) -> None: # Stop all stream transports to clean up LCM/shared memory threads for stream in list(self.inputs.values()) + list(self.outputs.values()): if stream.transport is not None and hasattr(stream.transport, "stop"): - stream.transport.stop() - stream._transport = None # type: ignore[attr-defined,assignment] + try: + stream.transport.stop() + except Exception as e: + logger.warning(f"Failed to stop stream transport: {e}") logger.info("temporalmemory stopped") diff --git a/dimos/perception/temporal_memory_deploy.py b/dimos/perception/experimental/temporal_memory_deploy.py similarity index 95% rename from dimos/perception/temporal_memory_deploy.py rename to dimos/perception/experimental/temporal_memory_deploy.py index 668977d8b7..2d58fc4e55 100644 --- a/dimos/perception/temporal_memory_deploy.py +++ b/dimos/perception/experimental/temporal_memory_deploy.py @@ -21,7 +21,7 @@ from dimos import spec from dimos.core import DimosCluster from dimos.models.vl.base import VlModel -from dimos.perception.temporal_memory import TemporalMemory, TemporalMemoryConfig +from dimos.perception.experimental.temporal_memory import TemporalMemory, TemporalMemoryConfig def deploy( diff --git a/dimos/perception/temporal_memory_example.py b/dimos/perception/experimental/temporal_memory_example.py similarity index 96% rename from dimos/perception/temporal_memory_example.py rename to dimos/perception/experimental/temporal_memory_example.py index 791c96c3e1..c553f0de00 100644 --- a/dimos/perception/temporal_memory_example.py +++ 
b/dimos/perception/experimental/temporal_memory_example.py @@ -30,8 +30,8 @@ from dimos import core from dimos.hardware.sensors.camera.module import CameraModule from dimos.hardware.sensors.camera.webcam import Webcam -from dimos.perception.temporal_memory import TemporalMemoryConfig -from dimos.perception.temporal_memory_deploy import deploy +from dimos.perception.experimental.temporal_memory import TemporalMemoryConfig +from dimos.perception.experimental.temporal_memory_deploy import deploy # Load environment variables load_dotenv() diff --git a/dimos/perception/temporal_utils/__init__.py b/dimos/perception/experimental/temporal_utils/__init__.py similarity index 100% rename from dimos/perception/temporal_utils/__init__.py rename to dimos/perception/experimental/temporal_utils/__init__.py diff --git a/dimos/perception/temporal_utils/graph_utils.py b/dimos/perception/experimental/temporal_utils/graph_utils.py similarity index 99% rename from dimos/perception/temporal_utils/graph_utils.py rename to dimos/perception/experimental/temporal_utils/graph_utils.py index 075641021b..500516b51d 100644 --- a/dimos/perception/temporal_utils/graph_utils.py +++ b/dimos/perception/experimental/temporal_utils/graph_utils.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image - from dimos.perception.entity_graph_db import EntityGraphDB + from dimos.perception.experimental.entity_graph_db import EntityGraphDB logger = setup_logger() diff --git a/dimos/perception/temporal_utils/helpers.py b/dimos/perception/experimental/temporal_utils/helpers.py similarity index 93% rename from dimos/perception/temporal_utils/helpers.py rename to dimos/perception/experimental/temporal_utils/helpers.py index ecaa4cc2d3..dccdd55d46 100644 --- a/dimos/perception/temporal_utils/helpers.py +++ b/dimos/perception/experimental/temporal_utils/helpers.py @@ -19,7 +19,7 @@ import numpy as np if TYPE_CHECKING: - from 
dimos.perception.temporal_memory import Frame + from dimos.perception.experimental.temporal_memory import Frame def next_entity_id_hint(roster: Any) -> str: @@ -68,5 +68,7 @@ def is_scene_stale(frames: list["Frame"], stale_threshold: float = 5.0) -> bool: last_img = frames[-1].image if first_img is None or last_img is None: return False + if not hasattr(first_img, "data") or not hasattr(last_img, "data"): + return False diff = np.abs(first_img.data.astype(float) - last_img.data.astype(float)) return bool(diff.mean() < stale_threshold) diff --git a/dimos/perception/temporal_utils/parsers.py b/dimos/perception/experimental/temporal_utils/parsers.py similarity index 100% rename from dimos/perception/temporal_utils/parsers.py rename to dimos/perception/experimental/temporal_utils/parsers.py diff --git a/dimos/perception/temporal_utils/prompts.py b/dimos/perception/experimental/temporal_utils/prompts.py similarity index 100% rename from dimos/perception/temporal_utils/prompts.py rename to dimos/perception/experimental/temporal_utils/prompts.py diff --git a/dimos/perception/temporal_utils/state.py b/dimos/perception/experimental/temporal_utils/state.py similarity index 100% rename from dimos/perception/temporal_utils/state.py rename to dimos/perception/experimental/temporal_utils/state.py diff --git a/dimos/perception/test_temporal_memory_module.py b/dimos/perception/experimental/test_temporal_memory_module.py similarity index 98% rename from dimos/perception/test_temporal_memory_module.py rename to dimos/perception/experimental/test_temporal_memory_module.py index 45750fd139..08232cb11a 100644 --- a/dimos/perception/test_temporal_memory_module.py +++ b/dimos/perception/experimental/test_temporal_memory_module.py @@ -26,7 +26,7 @@ from dimos.core import Module, Out, rpc from dimos.models.vl.openai import OpenAIVlModel from dimos.msgs.sensor_msgs import Image -from dimos.perception.temporal_memory import TemporalMemory, TemporalMemoryConfig +from 
dimos.perception.experimental.temporal_memory import TemporalMemory, TemporalMemoryConfig from dimos.protocol import pubsub from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py index 0e4c9a5d7a..29792c5203 100644 --- a/dimos/robot/unitree/connection/go2.py +++ b/dimos/robot/unitree/connection/go2.py @@ -77,8 +77,7 @@ def _camera_info_static() -> CameraInfo: class ReplayConnection(UnitreeWebRTCConnection): - # dir_name = "unitree_go2_bigoffice" - dir_name = "unitree_go2_office_walk2" + dir_name = "unitree_go2_bigoffice" # we don't want UnitreeWebRTCConnection to init def __init__( # type: ignore[no-untyped-def] diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py index 2dcd1cf812..be53702412 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -46,8 +46,8 @@ replanning_a_star_planner, ) from dimos.perception.detection.module3D import Detection3DModule, detection3d_module +from dimos.perception.experimental.temporal_memory import temporal_memory from dimos.perception.spatial_perception import spatial_memory -from dimos.perception.temporal_memory import temporal_memory from dimos.protocol.mcp.mcp import MCPModule from dimos.robot.foxglove_bridge import foxglove_bridge import dimos.robot.unitree.connection.go2 as _go2_mod From 70032887b3cb2b0b0172e93a1a95444f6b81ff4e Mon Sep 17 00:00:00 2001 From: clairebookworm Date: Wed, 14 Jan 2026 19:15:58 -0800 Subject: [PATCH 14/21] persistent db flag enabled in config --- dimos/agents/vlm_agent.py | 6 +- dimos/models/vl/openai.py | 4 +- dimos/models/vl/qwen.py | 4 +- dimos/perception/experimental/clip_filter.py | 6 +- .../experimental/temporal_memory.py | 93 +++++++++++++------ 5 files changed, 74 insertions(+), 39 deletions(-) diff --git 
a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py index d591d8ed0b..b523cfbaf8 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -112,14 +112,16 @@ def register_skills(self, container: Any, run_implicit_name: str | None = None) @rpc def query(self, query: str) -> str: response = self._invoke(HumanMessage(query)) - return response.content + content = response.content + return content if isinstance(content, str) else str(content) @rpc def query_image( self, image: Image, query: str, response_format: dict[str, Any] | None = None ) -> str: response = self._invoke_image(image, query, response_format=response_format) - return response.content + content = response.content + return content if isinstance(content, str) else str(content) vlm_agent = VLMAgent.blueprint diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py index 69a53a2c3d..f596f1ee1e 100644 --- a/dimos/models/vl/openai.py +++ b/dimos/models/vl/openai.py @@ -95,9 +95,9 @@ def query_batch( api_kwargs["response_format"] = response_format response = self._client.chat.completions.create(**api_kwargs) - content = response.choices[0].message.content or "" + response_text = response.choices[0].message.content or "" # Return one response per image (same response since API analyzes all images together) - return [content] * len(images) + return [response_text] * len(images) def stop(self) -> None: """Release the OpenAI client.""" diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index 62fa2a0dab..93b31bf74c 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -92,9 +92,9 @@ def query_batch( api_kwargs["response_format"] = response_format response = self._client.chat.completions.create(**api_kwargs) - content = response.choices[0].message.content or "" + response_text = response.choices[0].message.content or "" # Return one response per image (same response since API analyzes all images together) - return [content] * len(images) + return 
[response_text] * len(images) def stop(self) -> None: """Release the OpenAI client.""" diff --git a/dimos/perception/experimental/clip_filter.py b/dimos/perception/experimental/clip_filter.py index b9b17a0dc0..4273278865 100644 --- a/dimos/perception/experimental/clip_filter.py +++ b/dimos/perception/experimental/clip_filter.py @@ -14,7 +14,7 @@ """CLIP-based frame filtering for selecting diverse frames from video windows.""" -from typing import Any +from typing import Any, cast import numpy as np @@ -34,11 +34,11 @@ logger.info(f"CLIP unavailable ({e}), using simple frame sampling") -def _get_image_data(image: Image) -> np.ndarray: +def _get_image_data(image: Image) -> np.ndarray[Any, Any]: """Extract numpy array from Image.""" if not hasattr(image, "data"): raise AttributeError(f"Image missing .data attribute: {type(image)}") - return image.data + return cast("np.ndarray[Any, Any]", image.data) if CLIP_AVAILABLE: diff --git a/dimos/perception/experimental/temporal_memory.py b/dimos/perception/experimental/temporal_memory.py index 44653f4492..efc85b1608 100644 --- a/dimos/perception/experimental/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory.py @@ -128,21 +128,26 @@ def __init__( self._frame_buffer: deque[Frame] = deque(maxlen=self.config.frame_buffer_size) self._recent_windows: deque[dict[str, Any]] = deque(maxlen=MAX_RECENT_WINDOWS) self._frame_count = 0 - self._last_analysis_time = -float("inf") # Allow first analysis immediately + # Start at -inf so first analysis passes stride_s check regardless of elapsed time + self._last_analysis_time = -float("inf") self._video_start_wall_time: float | None = None - # clip filter + # Track background distance estimation threads + self._distance_threads: list[threading.Thread] = [] + + # clip filter - use instance state to avoid mutating shared config self._clip_filter: CLIPFrameFilter | None = None - if self.config.use_clip_filtering and CLIP_AVAILABLE: + self._use_clip_filtering = 
self.config.use_clip_filtering + if self._use_clip_filtering and CLIP_AVAILABLE: try: self._clip_filter = CLIPFrameFilter(model_name=self.config.clip_model) logger.info("clip filtering enabled") except Exception as e: logger.warning(f"clip init failed: {e}") - self.config.use_clip_filtering = False - elif self.config.use_clip_filtering: + self._use_clip_filtering = False + elif self._use_clip_filtering: logger.warning("clip not available") - self.config.use_clip_filtering = False + self._use_clip_filtering = False # output directory self._graph_db: EntityGraphDB | None @@ -154,8 +159,18 @@ def __init__( self._entities_file = self._output_path / "entities.json" self._frames_index_file = self._output_path / "frames_index.jsonl" - # Initialize entity graph database - self._graph_db = EntityGraphDB(db_path=self._output_path / "entity_graph.db") + db_path = self._output_path / "entity_graph.db" + if not self.config.persistent_memory or self.config.clear_memory_on_start: + if db_path.exists(): + db_path.unlink() + reason = ( + "non-persistent mode" + if not self.config.persistent_memory + else "clear_memory_on_start=True" + ) + logger.info(f"Deleted existing database: {reason}") + + self._graph_db = EntityGraphDB(db_path=db_path) logger.info(f"artifacts save to: {self._output_path}") else: @@ -172,7 +187,6 @@ def vlm(self) -> VlModel: if self._vlm is None: from dimos.models.vl.openai import OpenAIVlModel - # Load API key from environment api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise ValueError( @@ -210,7 +224,6 @@ def on_frame(image: Image) -> None: self._frame_buffer.append(frame) self._frame_count += 1 - # pipe through sharpness filter before buffering frame_subject: Subject[Image] = Subject() self._disposables.add( frame_subject.pipe(sharpness_barrier(self.config.fps)).subscribe(on_frame) @@ -244,22 +257,30 @@ def stop(self) -> None: logger.error(f"save failed during stop: {e}", exc_info=True) self.save_frames_index() - - # Set stopped flag and clear 
state with self._state_lock: self._stopped = True - # Save and close graph database + # Wait for background distance estimation threads to complete before closing DB + if self._distance_threads: + logger.info(f"Waiting for {len(self._distance_threads)} distance estimation threads...") + for thread in self._distance_threads: + thread.join(timeout=10.0) # Wait max 10s per thread + self._distance_threads.clear() + if self._graph_db: + db_path = self._graph_db.db_path self._graph_db.commit() # save all pending transactions self._graph_db.close() self._graph_db = None + if not self.config.persistent_memory and db_path.exists(): + db_path.unlink() + logger.info("Deleted non-persistent database") + if self._clip_filter: self._clip_filter.close() self._clip_filter = None - # Clear buffers to release image references with self._state_lock: self._frame_buffer.clear() self._recent_windows.clear() @@ -268,6 +289,7 @@ def stop(self) -> None: super().stop() # Stop all stream transports to clean up LCM/shared memory threads + # Note: We use public stream.transport API and rely on transport.stop() to clean up for stream in list(self.inputs.values()) + list(self.outputs.values()): if stream.transport is not None and hasattr(stream.transport, "stop"): try: @@ -356,18 +378,21 @@ def _analyze_window(self) -> None: parsed = tu.parse_window_response(response_text, w_start, w_end, len(window_frames)) if "_error" in parsed: logger.error(f"parse error: {parsed['_error']}") - else: - logger.info(f"parsed. caption: {parsed.get('caption', '')[:100]}") + # else: + # logger.info(f"parsed. 
caption: {parsed.get('caption', '')[:100]}") # Start distance estimation in background if self._graph_db and window_frames and self.config.enable_distance_estimation: mid_frame = window_frames[len(window_frames) // 2] if mid_frame.image: - threading.Thread( + thread = threading.Thread( target=self._graph_db.estimate_and_save_distances, args=(parsed, mid_frame.image, self.vlm, w_end, self.config.max_distance_pairs), daemon=True, - ).start() + ) + thread.start() + self._distance_threads = [t for t in self._distance_threads if t.is_alive()] + self._distance_threads.append(thread) # Update temporal state with self._state_lock: @@ -488,19 +513,24 @@ def query(self, question: str) -> str: } # enhance context with graph database knowledge - if self._graph_db and currently_present: + if self._graph_db: # Extract time window from question using VLM time_window_s = tu.extract_time_window(question, self.vlm, latest_frame) - graph_context = tu.build_graph_context( - graph_db=self._graph_db, - entity_ids=list(currently_present), - time_window_s=time_window_s, - max_relations_per_entity=self.config.max_relations_per_entity, - nearby_distance_meters=self.config.nearby_distance_meters, - current_video_time_s=current_video_time_s, - ) - context["graph_knowledge"] = graph_context + # Query graph for ALL entities in roster (not just currently present) + # This allows queries about entities that disappeared or were seen in the past + all_entity_ids = [e["id"] for e in entity_roster if isinstance(e, dict) and "id" in e] + + if all_entity_ids: + graph_context = tu.build_graph_context( + graph_db=self._graph_db, + entity_ids=all_entity_ids, + time_window_s=time_window_s, + max_relations_per_entity=self.config.max_relations_per_entity, + nearby_distance_meters=self.config.nearby_distance_meters, + current_video_time_s=current_video_time_s, + ) + context["graph_knowledge"] = graph_context # build query prompt using temporal utils prompt = tu.build_query_prompt(question=question, 
context=context) @@ -552,7 +582,10 @@ def get_rolling_summary(self) -> str: @rpc def get_graph_db_stats(self) -> dict[str, Any]: - """Get statistics and sample data from the graph database.""" + """Get statistics and sample data from the graph database. + + Returns empty structures when no database is available (no-error pattern). + """ if not self._graph_db: return {"stats": {}, "entities": [], "recent_relations": []} return self._graph_db.get_summary() @@ -619,7 +652,7 @@ def save_frames_index(self) -> bool: with open(self._frames_index_file, "w", encoding="utf-8") as f: for rec in frames_index: f.write(json.dumps(rec, ensure_ascii=False) + "\n") - logger.info(f"saved {len(frames_index)} frames") + logger.info(f"saved {len(frames_index)} frames") return True except Exception as e: logger.error(f"save frames failed: {e}", exc_info=True) From fef34b78f0acb828a20654845dfaaf075ca5a66f Mon Sep 17 00:00:00 2001 From: stash Date: Thu, 15 Jan 2026 03:00:34 -0800 Subject: [PATCH 15/21] Fix test to not run in CI due to LFS pull --- dimos/models/vl/base.py | 32 ++++--------------- .../test_temporal_memory_module.py | 4 +-- 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index 91b411f3b7..93caba4de7 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -2,16 +2,11 @@ from dataclasses import dataclass import json import logging -from typing import Any import warnings from dimos.core.resource import Resource from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.type import ( - Detection2DBBox, - Detection2DPoint, - ImageDetections2D, -) +from dimos.perception.detection.type import Detection2DBBox, Detection2DPoint, ImageDetections2D from dimos.protocol.service import Configurable # type: ignore[attr-defined] from dimos.utils.data import get_data from dimos.utils.decorators import retry @@ -185,37 +180,28 @@ def _prepare_image(self, image: Image) -> tuple[Image, float]: return 
image.resize_to_fit(max_w, max_h) return image, 1.0 - # Note: No custom pickle methods needed. In practice, VlModel instances - # are only stored in SkillModules, which use empty-shell pickling - # (SkillModule.__getstate__ returns None). Therefore VlModel is never - # actually pickled and doesn't need to handle unpicklable _client attributes. - @abstractmethod def query(self, image: Image, query: str, **kwargs) -> str: ... # type: ignore[no-untyped-def] - def query_batch( - self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any - ) -> list[str]: # type: ignore[no-untyped-def] + def query_batch(self, images: list[Image], query: str, **kwargs) -> list[str]: # type: ignore[no-untyped-def] """Query multiple images with the same question. Default implementation calls query() for each image sequentially. - Subclasses may override for efficient batched inference. + Subclasses may override for more efficient batched inference. Args: images: List of input images - query: Question to ask about all images - response_format: Optional response format for structured output - **kwargs: Additional arguments + query: Question to ask about each image Returns: List of responses, one per image """ warnings.warn( - f"{self.__class__.__name__}.query_batch() using sequential implementation. " + f"{self.__class__.__name__}.query_batch() is using default sequential implementation. " "Override for efficient batched inference.", stacklevel=2, ) - return [self.query(image, query, response_format=response_format, **kwargs) for image in images] + return [self.query(image, query, **kwargs) for image in images] def query_multi(self, image: Image, queries: list[str], **kwargs) -> list[str]: # type: ignore[no-untyped-def] """Query a single image with multiple different questions. 
@@ -343,11 +329,7 @@ def query_points( for track_id, point_tuple in enumerate(point_tuples): # Scale coordinates back to original image size if resized - if ( - scale != 1.0 - and isinstance(point_tuple, (list, tuple)) - and len(point_tuple) == 3 - ): + if scale != 1.0 and isinstance(point_tuple, (list, tuple)) and len(point_tuple) == 3: point_tuple = [ point_tuple[0], # label point_tuple[1] / scale, # x diff --git a/dimos/perception/experimental/test_temporal_memory_module.py b/dimos/perception/experimental/test_temporal_memory_module.py index 08232cb11a..b4af1b4020 100644 --- a/dimos/perception/experimental/test_temporal_memory_module.py +++ b/dimos/perception/experimental/test_temporal_memory_module.py @@ -79,9 +79,7 @@ def stop(self) -> None: logger.info("VideoReplayModule stopped") -@pytest.mark.lcm -@pytest.mark.gpu -@pytest.mark.module +@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM replay + dataset not CI-safe.") @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.") class TestTemporalMemoryModule: @pytest.fixture(scope="function") From 73476932ca2bbbf7466a814805c5bbf5369ed5de Mon Sep 17 00:00:00 2001 From: stash Date: Thu, 15 Jan 2026 03:37:00 -0800 Subject: [PATCH 16/21] Fix CLIP filter to use dimensional clip --- dimos/perception/experimental/clip_filter.py | 39 ++++++++++---------- 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/dimos/perception/experimental/clip_filter.py b/dimos/perception/experimental/clip_filter.py index 4273278865..8faac3fad8 100644 --- a/dimos/perception/experimental/clip_filter.py +++ b/dimos/perception/experimental/clip_filter.py @@ -24,10 +24,10 @@ logger = setup_logger() try: - import clip # type: ignore - from PIL import Image as PILImage # type: ignore import torch # type: ignore + from dimos.models.embedding.clip import CLIPModel # type: ignore + CLIP_AVAILABLE = True except ImportError as e: CLIP_AVAILABLE = False @@ -48,24 +48,24 @@ class CLIPFrameFilter: def 
__init__(self, model_name: str = "ViT-B/32", device: str | None = None): if not CLIP_AVAILABLE: - raise ImportError( - "CLIP not available. Install: pip install torch torchvision openai-clip" - ) + raise ImportError("CLIP not available. Install transformers[torch].") - self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") - logger.info(f"Loading CLIP {model_name} on {self.device}") - self.model, self.preprocess = clip.load(model_name, device=self.device) + resolved_name = ( + "openai/clip-vit-base-patch32" if model_name == "ViT-B/32" else model_name + ) + if device is None: + self._model = CLIPModel(model_name=resolved_name) + else: + self._model = CLIPModel(model_name=resolved_name, device=device) + logger.info(f"Loading CLIP {resolved_name} on {self._model.device}") def _encode_images(self, images: list[Image]) -> "torch.Tensor": """Encode images using CLIP.""" - pil_images = [PILImage.fromarray(_get_image_data(img)) for img in images] - preprocessed = torch.stack([self.preprocess(img) for img in pil_images]).to(self.device) - - with torch.no_grad(): - embeddings = self.model.encode_image(preprocessed) - embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) - - return embeddings # type: ignore + embeddings = self._model.embed(*images) + if not isinstance(embeddings, list): + embeddings = [embeddings] + vectors = [e.to_torch(self._model.device) for e in embeddings] + return torch.stack(vectors) def select_diverse_frames(self, frames: list[Any], max_frames: int = 3) -> list[Any]: """Select diverse frames using greedy farthest-point sampling in CLIP space.""" @@ -93,10 +93,9 @@ def select_diverse_frames(self, frames: list[Any], max_frames: int = 3) -> list[ def close(self) -> None: """Clean up CLIP model.""" - if hasattr(self, "model"): - del self.model - if hasattr(self, "preprocess"): - del self.preprocess + if hasattr(self, "_model"): + self._model.stop() + del self._model def select_diverse_frames_simple(frames: list[Any], 
max_frames: int = 3) -> list[Any]: From 52c2fd86a31419dead1301bd63ac9b0adf2d9345 Mon Sep 17 00:00:00 2001 From: stash Date: Thu, 15 Jan 2026 03:39:26 -0800 Subject: [PATCH 17/21] Add path to temporal memory --- dimos/perception/experimental/temporal_memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/perception/experimental/temporal_memory.py b/dimos/perception/experimental/temporal_memory.py index efc85b1608..4f21e56aaf 100644 --- a/dimos/perception/experimental/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory.py @@ -77,7 +77,7 @@ class TemporalMemoryConfig(ModuleConfig): frame_buffer_size: int = 50 # Output - output_dir: str | Path | None = None + output_dir: str | Path | None = "assets/temporal_memory" # VLM parameters max_tokens: int = 900 From 0aa85458fb0201eaf5a11ef2f2b112f6fe00c291 Mon Sep 17 00:00:00 2001 From: stash Date: Thu, 15 Jan 2026 03:41:20 -0800 Subject: [PATCH 18/21] revert video operators --- dimos/stream/video_operators.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/stream/video_operators.py b/dimos/stream/video_operators.py index 558972e155..548bba7598 100644 --- a/dimos/stream/video_operators.py +++ b/dimos/stream/video_operators.py @@ -16,7 +16,7 @@ from collections.abc import Callable from datetime import datetime, timedelta from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import cv2 import numpy as np From 6a01797d5d0840fc071a15bd15c6f2e8039eae1c Mon Sep 17 00:00:00 2001 From: stash Date: Thu, 15 Jan 2026 03:42:53 -0800 Subject: [PATCH 19/21] Revert moondream --- dimos/models/vl/moondream.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index cb74a59484..f31611e867 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -68,7 +68,7 @@ def query(self, image: Image | np.ndarray, query: str, **kwargs) -> str: # type 
return str(result) - def query_batch(self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs) -> list[str]: # type: ignore[no-untyped-def,override] + def query_batch(self, images: list[Image], query: str, **kwargs) -> list[str]: # type: ignore[no-untyped-def] """Query multiple images with the same question. Note: moondream2's batch_answer is not truly batched - it processes From 6b2175dfbb9242bf26a9a96c0e2a207076d78f48 Mon Sep 17 00:00:00 2001 From: stash Date: Thu, 15 Jan 2026 04:18:11 -0800 Subject: [PATCH 20/21] added temporal memory docs --- dimos/perception/experimental/README.md | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 dimos/perception/experimental/README.md diff --git a/dimos/perception/experimental/README.md b/dimos/perception/experimental/README.md new file mode 100644 index 0000000000..9ef5f6cb22 --- /dev/null +++ b/dimos/perception/experimental/README.md @@ -0,0 +1,32 @@ +Temporal memory runs "Temporal/Spatial RAG" on streamed videos building a continuous entity-based
+memory over time. It uses a VLM to extract evidence in sliding windows, tracks
+entities across windows, maintains a rolling summary, and stores relations in a graph network.

+Methodology
+1) Sample frames at a target FPS and analyze them in sliding windows.
+2) Extract dense evidence with a VLM (caption + entities + relations).
+3) Update rolling summary for global context.
+4) Persist per-window evidence + entity graph for query-time context.
+
+Setup
+- Put your OpenAI key in `.env`:
+  `OPENAI_API_KEY=...`
+- Install dimensional dependencies
+
+Quickstart
+To run: `dimos --replay run unitree-go2-temporal-memory`
+
+In another terminal: `humancli` to chat with the agent and run memory queries. 
+ +Artifacts +By default, artifacts are written under `assets/temporal_memory`: +- `evidence.jsonl` (window evidence: captions, entities, relations) +- `state.json` (rolling summary + roster state) +- `entities.json` (current entity roster) +- `frames_index.jsonl` (timestamps for saved frames; written on stop) +- `entity_graph.db` (SQLite graph of relations/distances) + +Notes +- Evidence is extracted in sliding windows, so queries can refer to recent or past entities. +- Distance estimation can run in the background to enrich graph relations. +- If you want a different output directory, set `TemporalMemoryConfig(output_dir=...)`. From 0d156557c2b26adcdb5a2bd22590b758ee46dc70 Mon Sep 17 00:00:00 2001 From: stash Date: Thu, 15 Jan 2026 04:37:33 -0800 Subject: [PATCH 21/21] Refactor move to /experimental/temporal_memory --- dimos/perception/experimental/__init__.py | 15 +++++++ .../{ => temporal_memory}/README.md | 0 .../experimental/temporal_memory/__init__.py | 24 ++++++++++++ .../{ => temporal_memory}/clip_filter.py | 0 .../{ => temporal_memory}/entity_graph_db.py | 2 +- .../temporal_memory/temporal_memory.md | 39 +++++++++++++++++++ .../{ => temporal_memory}/temporal_memory.py | 10 +++-- .../temporal_memory_deploy.py | 3 +- .../temporal_memory_example.py | 5 ++- .../temporal_utils/__init__.py | 0 .../temporal_utils/graph_utils.py | 3 +- .../temporal_utils/helpers.py | 2 +- .../temporal_utils/parsers.py | 0 .../temporal_utils/prompts.py | 0 .../temporal_utils/state.py | 0 .../test_temporal_memory_module.py | 0 16 files changed, 93 insertions(+), 10 deletions(-) create mode 100644 dimos/perception/experimental/__init__.py rename dimos/perception/experimental/{ => temporal_memory}/README.md (100%) create mode 100644 dimos/perception/experimental/temporal_memory/__init__.py rename dimos/perception/experimental/{ => temporal_memory}/clip_filter.py (100%) rename dimos/perception/experimental/{ => temporal_memory}/entity_graph_db.py (99%) create mode 100644 
dimos/perception/experimental/temporal_memory/temporal_memory.md rename dimos/perception/experimental/{ => temporal_memory}/temporal_memory.py (98%) rename dimos/perception/experimental/{ => temporal_memory}/temporal_memory_deploy.py (95%) rename dimos/perception/experimental/{ => temporal_memory}/temporal_memory_example.py (96%) rename dimos/perception/experimental/{ => temporal_memory}/temporal_utils/__init__.py (100%) rename dimos/perception/experimental/{ => temporal_memory}/temporal_utils/graph_utils.py (99%) rename dimos/perception/experimental/{ => temporal_memory}/temporal_utils/helpers.py (97%) rename dimos/perception/experimental/{ => temporal_memory}/temporal_utils/parsers.py (100%) rename dimos/perception/experimental/{ => temporal_memory}/temporal_utils/prompts.py (100%) rename dimos/perception/experimental/{ => temporal_memory}/temporal_utils/state.py (100%) rename dimos/perception/experimental/{ => temporal_memory}/test_temporal_memory_module.py (100%) diff --git a/dimos/perception/experimental/__init__.py b/dimos/perception/experimental/__init__.py new file mode 100644 index 0000000000..39ef33521d --- /dev/null +++ b/dimos/perception/experimental/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Experimental perception modules.""" diff --git a/dimos/perception/experimental/README.md b/dimos/perception/experimental/temporal_memory/README.md similarity index 100% rename from dimos/perception/experimental/README.md rename to dimos/perception/experimental/temporal_memory/README.md diff --git a/dimos/perception/experimental/temporal_memory/__init__.py b/dimos/perception/experimental/temporal_memory/__init__.py new file mode 100644 index 0000000000..3cc61601ce --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Temporal memory package.""" + +from .temporal_memory import Frame, TemporalMemory, TemporalMemoryConfig, temporal_memory + +__all__ = [ + "Frame", + "TemporalMemory", + "TemporalMemoryConfig", + "temporal_memory", +] diff --git a/dimos/perception/experimental/clip_filter.py b/dimos/perception/experimental/temporal_memory/clip_filter.py similarity index 100% rename from dimos/perception/experimental/clip_filter.py rename to dimos/perception/experimental/temporal_memory/clip_filter.py diff --git a/dimos/perception/experimental/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py similarity index 99% rename from dimos/perception/experimental/entity_graph_db.py rename to dimos/perception/experimental/temporal_memory/entity_graph_db.py index f1e6b75938..7109459f40 100644 --- a/dimos/perception/experimental/entity_graph_db.py +++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py @@ -948,7 +948,7 @@ def estimate_and_save_distances( return # Import here to avoid circular dependency - from dimos.perception.experimental import temporal_utils as tu + from . import temporal_utils as tu # Collect entities with descriptors # new_entities have descriptors from VLM diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.md b/dimos/perception/experimental/temporal_memory/temporal_memory.md new file mode 100644 index 0000000000..0eaa3df893 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.md @@ -0,0 +1,39 @@ +Dimensional Temporal Memory is a lightweight Video RAG pipeline for building +entity-centric memory over live or replayed video streams. It uses a VLM to +extract evidence in sliding windows, tracks entities across time, maintains a +rolling summary, and persists relations in a compact graph for query-time context. + +How It Works +1) Sample frames at a target FPS and analyze them in sliding windows. +2) Extract dense evidence with a VLM (caption + entities + relations). 
+3) Update a rolling summary for global context. +4) Persist per-window evidence and the entity graph for fast queries. + +Setup +- Add your OpenAI key to `.env`: + `OPENAI_API_KEY=...` +- Install dependencies (recommended set from repo install guide): + `uv sync --extra dev --extra cpu --extra sim --extra drone` + +`uv sync` installs the locked dependency set from `uv.lock` to match the repo's +known-good environment. `uv pip install ...` behaves like pip (ad-hoc installs) +and can drift from the lockfile. + +Quickstart +- Run Temporal Memory on a replay: + `dimos --replay run unitree-go2-temporal-memory` +- In another terminal, open a chat session: + `humancli` + +Artifacts +By default, artifacts are written under `assets/temporal_memory`: +- `evidence.jsonl` (window evidence: captions, entities, relations) +- `state.json` (rolling summary + roster state) +- `entities.json` (current entity roster) +- `frames_index.jsonl` (timestamps for saved frames; written on stop) +- `entity_graph.db` (SQLite graph of relations/distances) + +Notes +- Evidence is extracted in sliding windows; queries can reference recent or past entities. +- Distance estimation can run in the background to enrich graph relations. +- Change the output location via `TemporalMemoryConfig(output_dir=...)`. 
diff --git a/dimos/perception/experimental/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py similarity index 98% rename from dimos/perception/experimental/temporal_memory.py rename to dimos/perception/experimental/temporal_memory/temporal_memory.py index 4f21e56aaf..a328186173 100644 --- a/dimos/perception/experimental/temporal_memory.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -39,20 +39,22 @@ from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import sharpness_barrier -from dimos.perception.experimental import temporal_utils as tu -from dimos.perception.experimental.clip_filter import ( + +from . import temporal_utils as tu +from .clip_filter import ( CLIP_AVAILABLE, adaptive_keyframes, select_diverse_frames_simple, ) try: - from dimos.perception.experimental.clip_filter import CLIPFrameFilter + from .clip_filter import CLIPFrameFilter except ImportError: CLIPFrameFilter = type(None) # type: ignore[misc,assignment] -from dimos.perception.experimental.entity_graph_db import EntityGraphDB from dimos.utils.logging_config import setup_logger +from .entity_graph_db import EntityGraphDB + logger = setup_logger() # Constants diff --git a/dimos/perception/experimental/temporal_memory_deploy.py b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py similarity index 95% rename from dimos/perception/experimental/temporal_memory_deploy.py rename to dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py index 2d58fc4e55..611385630e 100644 --- a/dimos/perception/experimental/temporal_memory_deploy.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py @@ -21,7 +21,8 @@ from dimos import spec from dimos.core import DimosCluster from dimos.models.vl.base import VlModel -from dimos.perception.experimental.temporal_memory import TemporalMemory, TemporalMemoryConfig + +from 
.temporal_memory import TemporalMemory, TemporalMemoryConfig def deploy( diff --git a/dimos/perception/experimental/temporal_memory_example.py b/dimos/perception/experimental/temporal_memory/temporal_memory_example.py similarity index 96% rename from dimos/perception/experimental/temporal_memory_example.py rename to dimos/perception/experimental/temporal_memory/temporal_memory_example.py index c553f0de00..df435df3cc 100644 --- a/dimos/perception/experimental/temporal_memory_example.py +++ b/dimos/perception/experimental/temporal_memory/temporal_memory_example.py @@ -30,8 +30,9 @@ from dimos import core from dimos.hardware.sensors.camera.module import CameraModule from dimos.hardware.sensors.camera.webcam import Webcam -from dimos.perception.experimental.temporal_memory import TemporalMemoryConfig -from dimos.perception.experimental.temporal_memory_deploy import deploy + +from .temporal_memory import TemporalMemoryConfig +from .temporal_memory_deploy import deploy # Load environment variables load_dotenv() diff --git a/dimos/perception/experimental/temporal_utils/__init__.py b/dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py similarity index 100% rename from dimos/perception/experimental/temporal_utils/__init__.py rename to dimos/perception/experimental/temporal_memory/temporal_utils/__init__.py diff --git a/dimos/perception/experimental/temporal_utils/graph_utils.py b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py similarity index 99% rename from dimos/perception/experimental/temporal_utils/graph_utils.py rename to dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py index 500516b51d..bc55f7c65c 100644 --- a/dimos/perception/experimental/temporal_utils/graph_utils.py +++ b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py @@ -23,7 +23,8 @@ if TYPE_CHECKING: from dimos.models.vl.base import VlModel from dimos.msgs.sensor_msgs import Image - from 
dimos.perception.experimental.entity_graph_db import EntityGraphDB + + from ..entity_graph_db import EntityGraphDB logger = setup_logger() diff --git a/dimos/perception/experimental/temporal_utils/helpers.py b/dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py similarity index 97% rename from dimos/perception/experimental/temporal_utils/helpers.py rename to dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py index dccdd55d46..513feb65a4 100644 --- a/dimos/perception/experimental/temporal_utils/helpers.py +++ b/dimos/perception/experimental/temporal_memory/temporal_utils/helpers.py @@ -19,7 +19,7 @@ import numpy as np if TYPE_CHECKING: - from dimos.perception.experimental.temporal_memory import Frame + from ..temporal_memory import Frame def next_entity_id_hint(roster: Any) -> str: diff --git a/dimos/perception/experimental/temporal_utils/parsers.py b/dimos/perception/experimental/temporal_memory/temporal_utils/parsers.py similarity index 100% rename from dimos/perception/experimental/temporal_utils/parsers.py rename to dimos/perception/experimental/temporal_memory/temporal_utils/parsers.py diff --git a/dimos/perception/experimental/temporal_utils/prompts.py b/dimos/perception/experimental/temporal_memory/temporal_utils/prompts.py similarity index 100% rename from dimos/perception/experimental/temporal_utils/prompts.py rename to dimos/perception/experimental/temporal_memory/temporal_utils/prompts.py diff --git a/dimos/perception/experimental/temporal_utils/state.py b/dimos/perception/experimental/temporal_memory/temporal_utils/state.py similarity index 100% rename from dimos/perception/experimental/temporal_utils/state.py rename to dimos/perception/experimental/temporal_memory/temporal_utils/state.py diff --git a/dimos/perception/experimental/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py similarity index 100% rename from 
dimos/perception/experimental/test_temporal_memory_module.py rename to dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py