diff --git a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py index 2600a7ab50..b523cfbaf8 100644 --- a/dimos/agents/vlm_agent.py +++ b/dimos/agents/vlm_agent.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from langchain_core.messages import AIMessage, HumanMessage +from typing import Any + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from dimos.agents.llm_init import build_llm, build_system_message from dimos.agents.spec import AgentSpec, AnyMessage @@ -31,7 +33,7 @@ class VLMAgent(AgentSpec): query_stream: In[HumanMessage] answer_stream: Out[AIMessage] - def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._llm = build_llm(self.config) self._latest_image: Image | None = None @@ -71,18 +73,23 @@ def _extract_text(self, msg: HumanMessage) -> str: return str(part.get("text", "")) return str(content) - def _invoke(self, msg: HumanMessage) -> AIMessage: + def _invoke(self, msg: HumanMessage, **kwargs: Any) -> AIMessage: messages = [self._system_message, msg] - response = self._llm.invoke(messages) + response = self._llm.invoke(messages, **kwargs) self.append_history([msg, response]) # type: ignore[arg-type] return response # type: ignore[return-value] - def _invoke_image(self, image: Image, query: str) -> AIMessage: + def _invoke_image( + self, image: Image, query: str, response_format: dict[str, Any] | None = None + ) -> AIMessage: content = [{"type": "text", "text": query}, *image.agent_encode()] - return self._invoke(HumanMessage(content=content)) + kwargs: dict[str, Any] = {} + if response_format: + kwargs["response_format"] = response_format + return self._invoke(HumanMessage(content=content), **kwargs) @rpc - def clear_history(self): # type: ignore[no-untyped-def] + def clear_history(self) -> None: self._history.clear() def 
append_history(self, *msgs: list[AIMessage | HumanMessage]) -> None: @@ -95,9 +102,7 @@ def history(self) -> list[AnyMessage]: return [self._system_message, *self._history] @rpc - def register_skills( # type: ignore[no-untyped-def] - self, container, run_implicit_name: str | None = None - ) -> None: + def register_skills(self, container: Any, run_implicit_name: str | None = None) -> None: logger.warning( "VLMAgent does not manage skills; register_skills is a no-op", container=str(container), @@ -105,14 +110,18 @@ def register_skills( # type: ignore[no-untyped-def] ) @rpc - def query(self, query: str): # type: ignore[no-untyped-def] + def query(self, query: str) -> str: response = self._invoke(HumanMessage(query)) - return response.content + content = response.content + return content if isinstance(content, str) else str(content) @rpc - def query_image(self, image: Image, query: str): # type: ignore[no-untyped-def] - response = self._invoke_image(image, query) - return response.content + def query_image( + self, image: Image, query: str, response_format: dict[str, Any] | None = None + ) -> str: + response = self._invoke_image(image, query, response_format=response_format) + content = response.content + return content if isinstance(content, str) else str(content) vlm_agent = VLMAgent.blueprint diff --git a/dimos/models/vl/__init__.py b/dimos/models/vl/__init__.py index 6f120f9141..e4bb68e03c 100644 --- a/dimos/models/vl/__init__.py +++ b/dimos/models/vl/__init__.py @@ -2,6 +2,7 @@ from dimos.models.vl.florence import Florence2Model from dimos.models.vl.moondream import MoondreamVlModel from dimos.models.vl.moondream_hosted import MoondreamHostedVlModel +from dimos.models.vl.openai import OpenAIVlModel from dimos.models.vl.qwen import QwenVlModel __all__ = [ @@ -9,6 +10,7 @@ "Florence2Model", "MoondreamHostedVlModel", "MoondreamVlModel", + "OpenAIVlModel", "QwenVlModel", "VlModel", ] diff --git a/dimos/models/vl/openai.py b/dimos/models/vl/openai.py new file mode 
100644 index 0000000000..f596f1ee1e --- /dev/null +++ b/dimos/models/vl/openai.py @@ -0,0 +1,106 @@ +from dataclasses import dataclass +from functools import cached_property +import os +from typing import Any + +import numpy as np +from openai import OpenAI + +from dimos.models.vl.base import VlModel, VlModelConfig +from dimos.msgs.sensor_msgs import Image +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +@dataclass +class OpenAIVlModelConfig(VlModelConfig): + model_name: str = "gpt-4o-mini" + api_key: str | None = None + + +class OpenAIVlModel(VlModel): + default_config = OpenAIVlModelConfig + config: OpenAIVlModelConfig + + @cached_property + def _client(self) -> OpenAI: + api_key = self.config.api_key or os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OpenAI API key must be provided or set in OPENAI_API_KEY environment variable" + ) + + return OpenAI(api_key=api_key) + + def query(self, image: Image | np.ndarray, query: str, response_format: dict | None = None, **kwargs) -> str: # type: ignore[override, type-arg, no-untyped-def] + if isinstance(image, np.ndarray): + import warnings + + warnings.warn( + "OpenAIVlModel.query should receive standard dimos Image type, not a numpy array", + DeprecationWarning, + stacklevel=2, + ) + + image = Image.from_numpy(image) + + # Apply auto_resize if configured + image, _ = self._prepare_image(image) + + img_base64 = image.to_base64() + + api_kwargs: dict[str, Any] = { + "model": self.config.model_name, + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{img_base64}"}, + }, + {"type": "text", "text": query}, + ], + } + ], + } + + if response_format: + api_kwargs["response_format"] = response_format + + response = self._client.chat.completions.create(**api_kwargs) + + return response.choices[0].message.content # type: ignore[return-value,no-any-return] + + def query_batch( + self, images: 
list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any + ) -> list[str]: # type: ignore[override] + """Query VLM with multiple images using a single API call.""" + if not images: + return [] + + content: list[dict[str, Any]] = [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"}, + } + for img in images + ] + content.append({"type": "text", "text": query}) + + messages = [{"role": "user", "content": content}] + api_kwargs: dict[str, Any] = {"model": self.config.model_name, "messages": messages} + if response_format: + api_kwargs["response_format"] = response_format + + response = self._client.chat.completions.create(**api_kwargs) + response_text = response.choices[0].message.content or "" + # Return one response per image (same response since API analyzes all images together) + return [response_text] * len(images) + + def stop(self) -> None: + """Release the OpenAI client.""" + if "_client" in self.__dict__: + del self.__dict__["_client"] + diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index b1d3d6f036..93b31bf74c 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from functools import cached_property import os +from typing import Any import numpy as np from openai import OpenAI @@ -69,6 +70,32 @@ def query(self, image: Image | np.ndarray, query: str) -> str: # type: ignore[o return response.choices[0].message.content # type: ignore[return-value] + def query_batch( + self, images: list[Image], query: str, response_format: dict[str, Any] | None = None, **kwargs: Any + ) -> list[str]: # type: ignore[override] + """Query VLM with multiple images using a single API call.""" + if not images: + return [] + + content: list[dict[str, Any]] = [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{self._prepare_image(img)[0].to_base64()}"}, + } + for img in images + ] + 
content.append({"type": "text", "text": query}) + + messages = [{"role": "user", "content": content}] + api_kwargs: dict[str, Any] = {"model": self.config.model_name, "messages": messages} + if response_format: + api_kwargs["response_format"] = response_format + + response = self._client.chat.completions.create(**api_kwargs) + response_text = response.choices[0].message.content or "" + # Return one response per image (same response since API analyzes all images together) + return [response_text] * len(images) + def stop(self) -> None: """Release the OpenAI client.""" if "_client" in self.__dict__: diff --git a/dimos/perception/experimental/__init__.py b/dimos/perception/experimental/__init__.py new file mode 100644 index 0000000000..39ef33521d --- /dev/null +++ b/dimos/perception/experimental/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experimental perception modules.""" diff --git a/dimos/perception/experimental/temporal_memory/README.md b/dimos/perception/experimental/temporal_memory/README.md new file mode 100644 index 0000000000..9ef5f6cb22 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/README.md @@ -0,0 +1,32 @@ +Temporal memory runs "Temporal/Spatial RAG" on streamed videos building a continuous entity-based +memory over time. 
It uses a VLM to extract evidence in sliding windows, tracks +entities across windows, maintains a rolling summary, and stores relations in a graph network. + +Methodology +1) Sample frames at a target FPS and analyze them in sliding windows. +2) Extract dense evidence with a VLM (caption + entities + relations). +3) Update rolling summary for global context. +4) Persist per-window evidence + entity graph for query-time context. + +Setup +- Put your OpenAI key in `.env`: + `OPENAI_API_KEY=...` +- Install dimensional dependencies + +Quickstart +To run: `dimos --replay run unitree-go2-temporal-memory` + +In another terminal: `humancli` to chat with the agent and run memory queries. + +Artifacts +By default, artifacts are written under `assets/temporal_memory`: +- `evidence.jsonl` (window evidence: captions, entities, relations) +- `state.json` (rolling summary + roster state) +- `entities.json` (current entity roster) +- `frames_index.jsonl` (timestamps for saved frames; written on stop) +- `entity_graph.db` (SQLite graph of relations/distances) + +Notes +- Evidence is extracted in sliding windows, so queries can refer to recent or past entities. +- Distance estimation can run in the background to enrich graph relations. +- If you want a different output directory, set `TemporalMemoryConfig(output_dir=...)`. diff --git a/dimos/perception/experimental/temporal_memory/__init__.py b/dimos/perception/experimental/temporal_memory/__init__.py new file mode 100644 index 0000000000..3cc61601ce --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Temporal memory package.""" + +from .temporal_memory import Frame, TemporalMemory, TemporalMemoryConfig, temporal_memory + +__all__ = [ + "Frame", + "TemporalMemory", + "TemporalMemoryConfig", + "temporal_memory", +] diff --git a/dimos/perception/experimental/temporal_memory/clip_filter.py b/dimos/perception/experimental/temporal_memory/clip_filter.py new file mode 100644 index 0000000000..8faac3fad8 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/clip_filter.py @@ -0,0 +1,171 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""CLIP-based frame filtering for selecting diverse frames from video windows.""" + +from typing import Any, cast + +import numpy as np + +from dimos.msgs.sensor_msgs import Image +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +try: + import torch # type: ignore + + from dimos.models.embedding.clip import CLIPModel # type: ignore + + CLIP_AVAILABLE = True +except ImportError as e: + CLIP_AVAILABLE = False + logger.info(f"CLIP unavailable ({e}), using simple frame sampling") + + +def _get_image_data(image: Image) -> np.ndarray[Any, Any]: + """Extract numpy array from Image.""" + if not hasattr(image, "data"): + raise AttributeError(f"Image missing .data attribute: {type(image)}") + return cast("np.ndarray[Any, Any]", image.data) + + +if CLIP_AVAILABLE: + + class CLIPFrameFilter: + """Filter video frames using CLIP embeddings for diversity.""" + + def __init__(self, model_name: str = "ViT-B/32", device: str | None = None): + if not CLIP_AVAILABLE: + raise ImportError("CLIP not available. 
Install transformers[torch].") + + resolved_name = ( + "openai/clip-vit-base-patch32" if model_name == "ViT-B/32" else model_name + ) + if device is None: + self._model = CLIPModel(model_name=resolved_name) + else: + self._model = CLIPModel(model_name=resolved_name, device=device) + logger.info(f"Loading CLIP {resolved_name} on {self._model.device}") + + def _encode_images(self, images: list[Image]) -> "torch.Tensor": + """Encode images using CLIP.""" + embeddings = self._model.embed(*images) + if not isinstance(embeddings, list): + embeddings = [embeddings] + vectors = [e.to_torch(self._model.device) for e in embeddings] + return torch.stack(vectors) + + def select_diverse_frames(self, frames: list[Any], max_frames: int = 3) -> list[Any]: + """Select diverse frames using greedy farthest-point sampling in CLIP space.""" + if len(frames) <= max_frames: + return frames + + embeddings = self._encode_images([f.image for f in frames]) + + # Greedy farthest-point sampling + selected_indices = [0] # Always include first frame + remaining_indices = list(range(1, len(frames))) + + while len(selected_indices) < max_frames and remaining_indices: + # Compute similarities: (num_remaining, num_selected) + similarities = embeddings[remaining_indices] @ embeddings[selected_indices].T + # Find max similarity for each remaining frame + max_similarities = similarities.max(dim=1)[0] + # Select frame most different from all selected + best_idx = int(max_similarities.argmin().item()) + + selected_indices.append(remaining_indices[best_idx]) + remaining_indices.pop(best_idx) + + return [frames[i] for i in sorted(selected_indices)] + + def close(self) -> None: + """Clean up CLIP model.""" + if hasattr(self, "_model"): + self._model.stop() + del self._model + + +def select_diverse_frames_simple(frames: list[Any], max_frames: int = 3) -> list[Any]: + """Fallback frame selection: uniform sampling across window.""" + if len(frames) <= max_frames: + return frames + indices = [int(i * 
len(frames) / max_frames) for i in range(max_frames)] + return [frames[i] for i in indices] + + +def adaptive_keyframes( + frames: list[Any], + min_frames: int = 3, + max_frames: int = 5, + change_threshold: float = 15.0, +) -> list[Any]: + """Select frames based on visual change, adaptive count.""" + if len(frames) <= min_frames: + return frames + + # Compute frame-to-frame differences + try: + diffs = [ + np.abs( + _get_image_data(frames[i].image).astype(float) + - _get_image_data(frames[i - 1].image).astype(float) + ).mean() + for i in range(1, len(frames)) + ] + except (AttributeError, ValueError) as e: + logger.warning(f"Failed to compute frame diffs: {e}. Falling back to uniform sampling.") + return select_diverse_frames_simple(frames, max_frames) + + total_motion = sum(diffs) + n_frames = int(np.clip(total_motion / change_threshold, min_frames, max_frames)) + + # Always include first and last + keyframe_indices = {0, len(frames) - 1} + + # Add peaks in diff signal + for i in range(1, len(diffs) - 1): + if ( + diffs[i] > diffs[i - 1] + and diffs[i] > diffs[i + 1] + and diffs[i] > change_threshold * 0.5 + ): + keyframe_indices.add(i + 1) + + # Adjust count + if len(keyframe_indices) > n_frames: + # Keep first, last, and highest-diff peaks + middle = [i for i in keyframe_indices if i not in (0, len(frames) - 1)] + middle_by_diff = sorted(middle, key=lambda i: diffs[i - 1], reverse=True) + keyframe_indices = {0, len(frames) - 1, *middle_by_diff[: n_frames - 2]} + elif len(keyframe_indices) < n_frames: + # Fill uniformly from remaining + needed = n_frames - len(keyframe_indices) + candidates = sorted(set(range(len(frames))) - keyframe_indices) + if candidates: + step = max(1, len(candidates) // (needed + 1)) + keyframe_indices.update(candidates[::step][:needed]) + + return [frames[i] for i in sorted(keyframe_indices)] + + +__all__ = [ + "CLIP_AVAILABLE", + "adaptive_keyframes", + "select_diverse_frames_simple", +] + +if CLIP_AVAILABLE: + 
__all__.append("CLIPFrameFilter") diff --git a/dimos/perception/experimental/temporal_memory/entity_graph_db.py b/dimos/perception/experimental/temporal_memory/entity_graph_db.py new file mode 100644 index 0000000000..7109459f40 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/entity_graph_db.py @@ -0,0 +1,1018 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Entity Graph Database for storing and querying entity relationships. + +Maintains three types of graphs: +1. Relations Graph: Interactions between entities (holds, looks_at, talks_to, etc.) +2. Distance Graph: Spatial distances between entities +3. Semantic Graph: Conceptual relationships (goes_with, part_of, used_for, etc.) + +All graphs share the same entity nodes but have different edge types. +""" + +import json +from pathlib import Path +import sqlite3 +import threading +from typing import TYPE_CHECKING, Any + +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.models.vl.base import VlModel + from dimos.msgs.sensor_msgs import Image + +logger = setup_logger() + + +class EntityGraphDB: + """ + SQLite-based graph database for entity relationships. + + Thread-safe implementation using connection-per-thread pattern. + All graphs share the same entity nodes but maintain separate edge tables. + """ + + def __init__(self, db_path: str | Path) -> None: + """ + Initialize the entity graph database. 
+ + Args: + db_path: Path to the SQLite database file + """ + self.db_path = Path(db_path) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + + # Thread-local storage for connections + self._local = threading.local() + + # Initialize schema + self._init_schema() + + logger.info(f"EntityGraphDB initialized at {self.db_path}") + + def _get_connection(self) -> sqlite3.Connection: + """Get thread-local database connection.""" + if not hasattr(self._local, "conn"): + self._local.conn = sqlite3.connect(str(self.db_path)) + self._local.conn.row_factory = sqlite3.Row + return self._local.conn # type: ignore + + def _init_schema(self) -> None: + """Initialize database schema.""" + conn = self._get_connection() + cursor = conn.cursor() + + # Entities table (shared nodes) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS entities ( + entity_id TEXT PRIMARY KEY, + entity_type TEXT NOT NULL, + descriptor TEXT, + first_seen_ts REAL NOT NULL, + last_seen_ts REAL NOT NULL, + metadata TEXT + ) + """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_entities_first_seen ON entities(first_seen_ts)" + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_entities_last_seen ON entities(last_seen_ts)" + ) + + # Relations table (Graph 1: Interactions) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS relations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + relation_type TEXT NOT NULL, + subject_id TEXT NOT NULL, + object_id TEXT NOT NULL, + confidence REAL DEFAULT 1.0, + timestamp_s REAL NOT NULL, + evidence TEXT, + notes TEXT, + FOREIGN KEY (subject_id) REFERENCES entities(entity_id), + FOREIGN KEY (object_id) REFERENCES entities(entity_id) + ) + """) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_subject ON relations(subject_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_object ON relations(object_id)") + cursor.execute("CREATE INDEX IF NOT EXISTS idx_relations_type ON relations(relation_type)") + cursor.execute("CREATE INDEX IF NOT EXISTS 
idx_relations_time ON relations(timestamp_s)") + + # Distances table (Graph 2: Spatial) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS distances ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + entity_a_id TEXT NOT NULL, + entity_b_id TEXT NOT NULL, + distance_meters REAL, + distance_category TEXT, + confidence REAL DEFAULT 1.0, + timestamp_s REAL NOT NULL, + method TEXT, + FOREIGN KEY (entity_a_id) REFERENCES entities(entity_id), + FOREIGN KEY (entity_b_id) REFERENCES entities(entity_id) + ) + """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_distances_pair ON distances(entity_a_id, entity_b_id)" + ) + cursor.execute("CREATE INDEX IF NOT EXISTS idx_distances_time ON distances(timestamp_s)") + + # Semantic relations table (Graph 3: Knowledge) + cursor.execute(""" + CREATE TABLE IF NOT EXISTS semantic_relations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + relation_type TEXT NOT NULL, + entity_a_id TEXT NOT NULL, + entity_b_id TEXT NOT NULL, + confidence REAL DEFAULT 1.0, + learned_from TEXT, + first_observed_ts REAL NOT NULL, + last_observed_ts REAL NOT NULL, + observation_count INTEGER DEFAULT 1, + FOREIGN KEY (entity_a_id) REFERENCES entities(entity_id), + FOREIGN KEY (entity_b_id) REFERENCES entities(entity_id) + ) + """) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_semantic_pair ON semantic_relations(entity_a_id, entity_b_id)" + ) + cursor.execute( + "CREATE INDEX IF NOT EXISTS idx_semantic_type ON semantic_relations(relation_type)" + ) + + conn.commit() + + def upsert_entity( + self, + entity_id: str, + entity_type: str, + descriptor: str, + timestamp_s: float, + metadata: dict[str, Any] | None = None, + ) -> None: + """ + Insert or update an entity. + + Args: + entity_id: Unique entity identifier (e.g., "E1") + entity_type: Type of entity (person, object, location, etc.) 
+ descriptor: Text description of the entity + timestamp_s: Timestamp when entity was observed + metadata: Optional additional metadata + """ + conn = self._get_connection() + cursor = conn.cursor() + + metadata_json = json.dumps(metadata) if metadata else None + + cursor.execute( + """ + INSERT INTO entities (entity_id, entity_type, descriptor, first_seen_ts, last_seen_ts, metadata) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(entity_id) DO UPDATE SET + last_seen_ts = ?, + descriptor = COALESCE(excluded.descriptor, descriptor), + metadata = COALESCE(excluded.metadata, metadata) + """, + ( + entity_id, + entity_type, + descriptor, + timestamp_s, + timestamp_s, + metadata_json, + timestamp_s, + ), + ) + + conn.commit() + logger.debug(f"Upserted entity {entity_id} (type={entity_type})") + + def get_entity(self, entity_id: str) -> dict[str, Any] | None: + """ + Get an entity by ID. + """ + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute("SELECT * FROM entities WHERE entity_id = ?", (entity_id,)) + row = cursor.fetchone() + + if row is None: + return None + + return { + "entity_id": row["entity_id"], + "entity_type": row["entity_type"], + "descriptor": row["descriptor"], + "first_seen_ts": row["first_seen_ts"], + "last_seen_ts": row["last_seen_ts"], + "metadata": json.loads(row["metadata"]) if row["metadata"] else None, + } + + def get_all_entities(self, entity_type: str | None = None) -> list[dict[str, Any]]: + """Get all entities, optionally filtered by type.""" + conn = self._get_connection() + cursor = conn.cursor() + + if entity_type: + cursor.execute( + "SELECT * FROM entities WHERE entity_type = ? 
ORDER BY last_seen_ts DESC", + (entity_type,), + ) + else: + cursor.execute("SELECT * FROM entities ORDER BY last_seen_ts DESC") + + rows = cursor.fetchall() + return [ + { + "entity_id": row["entity_id"], + "entity_type": row["entity_type"], + "descriptor": row["descriptor"], + "first_seen_ts": row["first_seen_ts"], + "last_seen_ts": row["last_seen_ts"], + "metadata": json.loads(row["metadata"]) if row["metadata"] else None, + } + for row in rows + ] + + def get_entities_by_time( + self, + time_window: tuple[float, float], + first_seen: bool = True, + ) -> list[dict[str, Any]]: + """Get entities first/last seen within a time window. + + Args: + time_window: (start_ts, end_ts) tuple in seconds + first_seen: If True, filter by first_seen_ts. If False, filter by last_seen_ts. + + Returns: + List of entities seen within the time window + """ + conn = self._get_connection() + cursor = conn.cursor() + + ts_field = "first_seen_ts" if first_seen else "last_seen_ts" + cursor.execute( + f"SELECT * FROM entities WHERE {ts_field} BETWEEN ? AND ? ORDER BY {ts_field} DESC", + time_window, + ) + + rows = cursor.fetchall() + return [ + { + "entity_id": row["entity_id"], + "entity_type": row["entity_type"], + "descriptor": row["descriptor"], + "first_seen_ts": row["first_seen_ts"], + "last_seen_ts": row["last_seen_ts"], + "metadata": json.loads(row["metadata"]) if row["metadata"] else None, + } + for row in rows + ] + + def add_relation( + self, + relation_type: str, + subject_id: str, + object_id: str, + confidence: float, + timestamp_s: float, + evidence: list[str] | None = None, + notes: str | None = None, + ) -> None: + """ + Add a relation between two entities. + + Args: + relation_type: Type of relation (holds, looks_at, talks_to, etc.) 
+ subject_id: Subject entity ID + object_id: Object entity ID + confidence: Confidence score (0.0 to 1.0) + timestamp_s: Timestamp when relation was observed + evidence: Optional list of evidence strings + notes: Optional notes + """ + conn = self._get_connection() + cursor = conn.cursor() + + evidence_json = json.dumps(evidence) if evidence else None + + cursor.execute( + """ + INSERT INTO relations (relation_type, subject_id, object_id, confidence, timestamp_s, evidence, notes) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + (relation_type, subject_id, object_id, confidence, timestamp_s, evidence_json, notes), + ) + + conn.commit() + logger.debug(f"Added relation: {subject_id} --{relation_type}--> {object_id}") + + def get_relations_for_entity( + self, + entity_id: str, + relation_type: str | None = None, + time_window: tuple[float, float] | None = None, + ) -> list[dict[str, Any]]: + """ + Get all relations involving an entity. + + Args: + entity_id: Entity ID + relation_type: Optional filter by relation type + time_window: Optional (start_ts, end_ts) tuple + + Returns: + List of relation dicts + """ + conn = self._get_connection() + cursor = conn.cursor() + + query = """ + SELECT * FROM relations + WHERE (subject_id = ? OR object_id = ?) + """ + params: list[Any] = [entity_id, entity_id] + + if relation_type: + query += " AND relation_type = ?" + params.append(relation_type) + + if time_window: + query += " AND timestamp_s BETWEEN ? AND ?" 
+ params.extend(time_window) + + query += " ORDER BY timestamp_s DESC" + + cursor.execute(query, params) + rows = cursor.fetchall() + + return [ + { + "id": row["id"], + "relation_type": row["relation_type"], + "subject_id": row["subject_id"], + "object_id": row["object_id"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + "evidence": json.loads(row["evidence"]) if row["evidence"] else None, + "notes": row["notes"], + } + for row in rows + ] + + def get_recent_relations(self, limit: int = 50) -> list[dict[str, Any]]: + """Get most recent relations.""" + conn = self._get_connection() + cursor = conn.cursor() + + cursor.execute( + """ + SELECT * FROM relations + ORDER BY timestamp_s DESC + LIMIT ? + """, + (limit,), + ) + + rows = cursor.fetchall() + return [ + { + "id": row["id"], + "relation_type": row["relation_type"], + "subject_id": row["subject_id"], + "object_id": row["object_id"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + "evidence": json.loads(row["evidence"]) if row["evidence"] else None, + "notes": row["notes"], + } + for row in rows + ] + + # ==================== Distance Operations (Graph 2) ==================== + + def add_distance( + self, + entity_a_id: str, + entity_b_id: str, + distance_meters: float | None, + distance_category: str | None, + confidence: float, + timestamp_s: float, + method: str, + ) -> None: + """ + Add distance measurement between two entities. 
+ + Args: + entity_a_id: First entity ID + entity_b_id: Second entity ID + distance_meters: Distance in meters (can be None if only categorical) + distance_category: Category (near/medium/far) + confidence: Confidence score + timestamp_s: Timestamp of measurement + method: Method used (vlm, depth_estimation, bbox) + """ + conn = self._get_connection() + cursor = conn.cursor() + + # Normalize order to avoid duplicates (store alphabetically) + if entity_a_id > entity_b_id: + entity_a_id, entity_b_id = entity_b_id, entity_a_id + + cursor.execute( + """ + INSERT INTO distances (entity_a_id, entity_b_id, distance_meters, distance_category, + confidence, timestamp_s, method) + VALUES (?, ?, ?, ?, ?, ?, ?) + """, + ( + entity_a_id, + entity_b_id, + distance_meters, + distance_category, + confidence, + timestamp_s, + method, + ), + ) + + conn.commit() + logger.debug( + f"Added distance: {entity_a_id} <--> {entity_b_id}: {distance_meters}m ({distance_category})" + ) + + def get_distance( + self, + entity_a_id: str, + entity_b_id: str, + ) -> dict[str, Any] | None: + """Get most recent distance between two entities. + + Args: + entity_a_id: First entity ID + entity_b_id: Second entity ID + + Returns: + Distance dict or None + """ + conn = self._get_connection() + cursor = conn.cursor() + + # Normalize order + if entity_a_id > entity_b_id: + entity_a_id, entity_b_id = entity_b_id, entity_a_id + + cursor.execute( + """ + SELECT * FROM distances + WHERE entity_a_id = ? AND entity_b_id = ? 
+ ORDER BY timestamp_s DESC + LIMIT 1 + """, + (entity_a_id, entity_b_id), + ) + + row = cursor.fetchone() + if row is None: + return None + + return { + "entity_a_id": row["entity_a_id"], + "entity_b_id": row["entity_b_id"], + "distance_meters": row["distance_meters"], + "distance_category": row["distance_category"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + "method": row["method"], + } + + def get_distance_history( + self, + entity_a_id: str, + entity_b_id: str, + ) -> list[dict[str, Any]]: + """Get all distance measurements between two entities. + + Args: + entity_a_id: First entity ID + entity_b_id: Second entity ID + + Returns: + List of distance dicts, most recent first + """ + conn = self._get_connection() + cursor = conn.cursor() + + # Normalize order + if entity_a_id > entity_b_id: + entity_a_id, entity_b_id = entity_b_id, entity_a_id + + cursor.execute( + """ + SELECT * FROM distances + WHERE entity_a_id = ? AND entity_b_id = ? + ORDER BY timestamp_s DESC + """, + (entity_a_id, entity_b_id), + ) + + return [ + { + "entity_a_id": row["entity_a_id"], + "entity_b_id": row["entity_b_id"], + "distance_meters": row["distance_meters"], + "distance_category": row["distance_category"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + "method": row["method"], + } + for row in cursor.fetchall() + ] + + def get_nearby_entities( + self, + entity_id: str, + max_distance: float, + latest_only: bool = True, + ) -> list[dict[str, Any]]: + """ + Find entities within a distance threshold. 
+ + Args: + entity_id: Reference entity ID + max_distance: Maximum distance in meters + latest_only: If True, use only latest measurements + + Returns: + List of nearby entities with distances + """ + conn = self._get_connection() + cursor = conn.cursor() + + if latest_only: + # Get latest distance for each pair + query = """ + SELECT d.*, e.entity_type, e.descriptor + FROM distances d + INNER JOIN entities e ON ( + CASE + WHEN d.entity_a_id = ? THEN e.entity_id = d.entity_b_id + WHEN d.entity_b_id = ? THEN e.entity_id = d.entity_a_id + END + ) + WHERE (d.entity_a_id = ? OR d.entity_b_id = ?) + AND d.distance_meters IS NOT NULL + AND d.distance_meters <= ? + AND d.id IN ( + SELECT MAX(id) FROM distances + WHERE (entity_a_id = d.entity_a_id AND entity_b_id = d.entity_b_id) + GROUP BY entity_a_id, entity_b_id + ) + ORDER BY d.distance_meters ASC + """ + cursor.execute(query, (entity_id, entity_id, entity_id, entity_id, max_distance)) + else: + query = """ + SELECT d.*, e.entity_type, e.descriptor + FROM distances d + INNER JOIN entities e ON ( + CASE + WHEN d.entity_a_id = ? THEN e.entity_id = d.entity_b_id + WHEN d.entity_b_id = ? THEN e.entity_id = d.entity_a_id + END + ) + WHERE (d.entity_a_id = ? OR d.entity_b_id = ?) + AND d.distance_meters IS NOT NULL + AND d.distance_meters <= ? 
+ ORDER BY d.distance_meters ASC + """ + cursor.execute(query, (entity_id, entity_id, entity_id, entity_id, max_distance)) + + rows = cursor.fetchall() + return [ + { + "entity_id": row["entity_b_id"] + if row["entity_a_id"] == entity_id + else row["entity_a_id"], + "entity_type": row["entity_type"], + "descriptor": row["descriptor"], + "distance_meters": row["distance_meters"], + "distance_category": row["distance_category"], + "confidence": row["confidence"], + "timestamp_s": row["timestamp_s"], + } + for row in rows + ] + + def add_semantic_relation( + self, + relation_type: str, + entity_a_id: str, + entity_b_id: str, + confidence: float, + learned_from: str, + timestamp_s: float, + ) -> None: + """ + Add or update a semantic relation. + + Args: + relation_type: Relation type (goes_with, opposite_of, part_of, used_for) + entity_a_id: First entity ID + entity_b_id: Second entity ID + confidence: Confidence score + learned_from: Source (llm, knowledge_base, observation) + timestamp_s: Timestamp when learned + """ + conn = self._get_connection() + cursor = conn.cursor() + + # Normalize order for symmetric relations + if entity_a_id > entity_b_id: + entity_a_id, entity_b_id = entity_b_id, entity_a_id + + # Check if relation exists + cursor.execute( + """ + SELECT id, observation_count, confidence FROM semantic_relations + WHERE relation_type = ? AND entity_a_id = ? AND entity_b_id = ? + """, + (relation_type, entity_a_id, entity_b_id), + ) + + existing = cursor.fetchone() + + if existing: + # Update existing relation (increase confidence, increment count) + new_count = existing["observation_count"] + 1 + new_confidence = min( + 1.0, existing["confidence"] + 0.1 + ) # Increase confidence with observations + + cursor.execute( + """ + UPDATE semantic_relations + SET last_observed_ts = ?, + observation_count = ?, + confidence = ? + WHERE id = ? 
+ """, + (timestamp_s, new_count, new_confidence, existing["id"]), + ) + else: + # Insert new relation + cursor.execute( + """ + INSERT INTO semantic_relations + (relation_type, entity_a_id, entity_b_id, confidence, learned_from, + first_observed_ts, last_observed_ts, observation_count) + VALUES (?, ?, ?, ?, ?, ?, ?, 1) + """, + ( + relation_type, + entity_a_id, + entity_b_id, + confidence, + learned_from, + timestamp_s, + timestamp_s, + ), + ) + + conn.commit() + logger.debug(f"Added semantic relation: {entity_a_id} --{relation_type}--> {entity_b_id}") + + def get_semantic_relations( + self, + entity_id: str | None = None, + relation_type: str | None = None, + ) -> list[dict[str, Any]]: + """ + Get semantic relations, optionally filtered. + + Args: + entity_id: Optional filter by entity + relation_type: Optional filter by relation type + + Returns: + List of semantic relation dicts + """ + conn = self._get_connection() + cursor = conn.cursor() + + query = "SELECT * FROM semantic_relations WHERE 1=1" + params: list[Any] = [] + + if entity_id: + query += " AND (entity_a_id = ? OR entity_b_id = ?)" + params.extend([entity_id, entity_id]) + + if relation_type: + query += " AND relation_type = ?" + params.append(relation_type) + + query += " ORDER BY confidence DESC, observation_count DESC" + + cursor.execute(query, params) + rows = cursor.fetchall() + + return [ + { + "id": row["id"], + "relation_type": row["relation_type"], + "entity_a_id": row["entity_a_id"], + "entity_b_id": row["entity_b_id"], + "confidence": row["confidence"], + "learned_from": row["learned_from"], + "first_observed_ts": row["first_observed_ts"], + "last_observed_ts": row["last_observed_ts"], + "observation_count": row["observation_count"], + } + for row in rows + ] + + # querying + + def get_entity_neighborhood( + self, + entity_id: str, + max_hops: int = 2, + include_distances: bool = True, + include_semantics: bool = True, + ) -> dict[str, Any]: + """ + Get entity neighborhood (BFS traversal). 
+ + Args: + entity_id: Starting entity ID + max_hops: Maximum number of hops to traverse + include_distances: Include distance graph + include_semantics: Include semantic graph + + Returns: + Dict with entities, relations, distances, and semantics + """ + visited_entities = {entity_id} + current_level = {entity_id} + all_relations = [] + all_distances = [] + all_semantics = [] + + for _ in range(max_hops): + next_level = set() + + for ent_id in current_level: + # Get relations + relations = self.get_relations_for_entity(ent_id) + all_relations.extend(relations) + + for rel in relations: + other_id = ( + rel["object_id"] if rel["subject_id"] == ent_id else rel["subject_id"] + ) + if other_id not in visited_entities: + next_level.add(other_id) + visited_entities.add(other_id) + + # Get distances + if include_distances: + distances = self.get_nearby_entities(ent_id, max_distance=10.0) + all_distances.extend(distances) + for dist in distances: + other_id = dist["entity_id"] + if other_id not in visited_entities: + next_level.add(other_id) + visited_entities.add(other_id) + + # Get semantic relations + if include_semantics: + semantics = self.get_semantic_relations(entity_id=ent_id) + all_semantics.extend(semantics) + for sem in semantics: + other_id = ( + sem["entity_b_id"] + if sem["entity_a_id"] == ent_id + else sem["entity_a_id"] + ) + if other_id not in visited_entities: + next_level.add(other_id) + visited_entities.add(other_id) + + current_level = next_level + if not current_level: + break + + # Get all entity details + entities = [self.get_entity(ent_id) for ent_id in visited_entities] + entities = [e for e in entities if e is not None] + + return { + "center_entity": entity_id, + "entities": entities, + "relations": all_relations, + "distances": all_distances, + "semantic_relations": all_semantics, + "num_hops": max_hops, + } + + def get_stats(self) -> dict[str, Any]: + """Get database statistics.""" + conn = self._get_connection() + cursor = conn.cursor() + + 
cursor.execute("SELECT COUNT(*) as count FROM entities") + entity_count = cursor.fetchone()["count"] + + cursor.execute("SELECT COUNT(*) as count FROM relations") + relation_count = cursor.fetchone()["count"] + + cursor.execute("SELECT COUNT(*) as count FROM distances") + distance_count = cursor.fetchone()["count"] + + cursor.execute("SELECT COUNT(*) as count FROM semantic_relations") + semantic_count = cursor.fetchone()["count"] + + return { + "entities": entity_count, + "relations": relation_count, + "distances": distance_count, + "semantic_relations": semantic_count, + } + + def get_summary(self, recent_relations_limit: int = 5) -> dict[str, Any]: + """Get stats, all entities, and recent relations.""" + return { + "stats": self.get_stats(), + "entities": self.get_all_entities(), + "recent_relations": self.get_recent_relations(limit=recent_relations_limit), + } + + def save_window_data(self, parsed: dict[str, Any], timestamp_s: float) -> None: + """Save parsed window data (entities and relations) to the graph database.""" + try: + # Save new entities + for entity in parsed.get("new_entities", []): + self.upsert_entity( + entity_id=entity["id"], + entity_type=entity["type"], + descriptor=entity.get("descriptor", "unknown"), + timestamp_s=timestamp_s, + ) + + # Save existing entities (update last_seen) + for entity in parsed.get("entities_present", []): + if isinstance(entity, dict) and "id" in entity: + descriptor = entity.get("descriptor") + if descriptor: + self.upsert_entity( + entity_id=entity["id"], + entity_type=entity.get("type", "unknown"), + descriptor=descriptor, + timestamp_s=timestamp_s, + ) + else: + existing = self.get_entity(entity["id"]) + if existing: + self.upsert_entity( + entity_id=entity["id"], + entity_type=existing["entity_type"], + descriptor=existing["descriptor"], + timestamp_s=timestamp_s, + ) + + # Save relations + for relation in parsed.get("relations", []): + subject_id = ( + relation["subject"].split("|")[0] + if "|" in 
relation["subject"] + else relation["subject"] + ) + object_id = ( + relation["object"].split("|")[0] + if "|" in relation["object"] + else relation["object"] + ) + + self.add_relation( + relation_type=relation["type"], + subject_id=subject_id, + object_id=object_id, + confidence=relation.get("confidence", 1.0), + timestamp_s=timestamp_s, + evidence=relation.get("evidence"), + notes=relation.get("notes"), + ) + + except Exception as e: + logger.error(f"Failed to save window data to graph DB: {e}", exc_info=True) + + def estimate_and_save_distances( + self, + parsed: dict[str, Any], + frame_image: "Image", + vlm: "VlModel", + timestamp_s: float, + max_distance_pairs: int = 5, + ) -> None: + """Estimate distances between entities using VLM and save to database. + + Args: + parsed: Parsed window data containing entities + frame_image: Frame image to analyze + vlm: VLM instance for distance estimation + timestamp_s: Timestamp for the distance measurements + max_distance_pairs: Maximum number of entity pairs to estimate + """ + if not frame_image: + return + + # Import here to avoid circular dependency + from . 
import temporal_utils as tu + + # Collect entities with descriptors + # new_entities have descriptors from VLM + enriched_entities = [] + for entity in parsed.get("new_entities", []): + if isinstance(entity, dict) and "id" in entity: + enriched_entities.append( + {"id": entity["id"], "descriptor": entity.get("descriptor", "unknown")} + ) + + # entities_present only have IDs - need to fetch descriptors from DB + for entity in parsed.get("entities_present", []): + if isinstance(entity, dict) and "id" in entity: + entity_id = entity["id"] + # Fetch descriptor from DB + db_entity = self.get_entity(entity_id) + if db_entity: + enriched_entities.append( + {"id": entity_id, "descriptor": db_entity.get("descriptor", "unknown")} + ) + + if len(enriched_entities) < 2: + return + + # Generate pairs without existing distances + pairs = [ + (enriched_entities[i], enriched_entities[j]) + for i in range(len(enriched_entities)) + for j in range(i + 1, len(enriched_entities)) + if not self.get_distance(enriched_entities[i]["id"], enriched_entities[j]["id"]) + ][:max_distance_pairs] + + if not pairs: + return + + try: + response = vlm.query(frame_image, tu.build_batch_distance_estimation_prompt(pairs)) + for r in tu.parse_batch_distance_response(response, pairs): + if r["category"] in ("near", "medium", "far"): + self.add_distance( + entity_a_id=r["entity_a_id"], + entity_b_id=r["entity_b_id"], + distance_meters=r.get("distance_m"), + distance_category=r["category"], + confidence=r.get("confidence", 0.5), + timestamp_s=timestamp_s, + method="vlm", + ) + except Exception as e: + logger.warning(f"Failed to estimate distances: {e}", exc_info=True) + + def commit(self) -> None: + """Commit all pending transactions and ensure data is flushed to disk.""" + if hasattr(self._local, "conn"): + conn = self._local.conn + conn.commit() + # Force checkpoint to ensure WAL data is written to main database file + try: + conn.execute("PRAGMA wal_checkpoint(FULL)") + except Exception: + pass # Ignore 
if WAL is not enabled + + def close(self) -> None: + """Close database connection.""" + if hasattr(self._local, "conn"): + self._local.conn.close() + del self._local.conn diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.md b/dimos/perception/experimental/temporal_memory/temporal_memory.md new file mode 100644 index 0000000000..0eaa3df893 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.md @@ -0,0 +1,39 @@ +Dimensional Temporal Memory is a lightweight Video RAG pipeline for building +entity-centric memory over live or replayed video streams. It uses a VLM to +extract evidence in sliding windows, tracks entities across time, maintains a +rolling summary, and persists relations in a compact graph for query-time context. + +How It Works +1) Sample frames at a target FPS and analyze them in sliding windows. +2) Extract dense evidence with a VLM (caption + entities + relations). +3) Update a rolling summary for global context. +4) Persist per-window evidence and the entity graph for fast queries. + +Setup +- Add your OpenAI key to `.env`: + `OPENAI_API_KEY=...` +- Install dependencies (recommended set from repo install guide): + `uv sync --extra dev --extra cpu --extra sim --extra drone` + +`uv sync` installs the locked dependency set from `uv.lock` to match the repo's +known-good environment. `uv pip install ...` behaves like pip (ad-hoc installs) +and can drift from the lockfile. 
+ +Quickstart +- Run Temporal Memory on a replay: + `dimos --replay run unitree-go2-temporal-memory` +- In another terminal, open a chat session: + `humancli` + +Artifacts +By default, artifacts are written under `assets/temporal_memory`: +- `evidence.jsonl` (window evidence: captions, entities, relations) +- `state.json` (rolling summary + roster state) +- `entities.json` (current entity roster) +- `frames_index.jsonl` (timestamps for saved frames; written on stop) +- `entity_graph.db` (SQLite graph of relations/distances) + +Notes +- Evidence is extracted in sliding windows; queries can reference recent or past entities. +- Distance estimation can run in the background to enrich graph relations. +- Change the output location via `TemporalMemoryConfig(output_dir=...)`. diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory.py b/dimos/perception/experimental/temporal_memory/temporal_memory.py new file mode 100644 index 0000000000..a328186173 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/temporal_memory.py @@ -0,0 +1,666 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Temporal Memory module for creating entity-based temporal understanding of video streams. + +This module implements a sophisticated temporal memory system inspired by VideoRAG, +using VLM (Vision-Language Model) API calls to maintain entity rosters, rolling summaries, +and temporal relationships across video frames. 
+""" + +from collections import deque +from dataclasses import dataclass +import json +import os +from pathlib import Path +import threading +import time +from typing import Any + +from reactivex import Subject, interval +from reactivex.disposable import Disposable + +from dimos.agents import skill +from dimos.core import In, rpc +from dimos.core.module import ModuleConfig +from dimos.core.skill_module import SkillModule +from dimos.models.vl.base import VlModel +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs.Image import sharpness_barrier + +from . import temporal_utils as tu +from .clip_filter import ( + CLIP_AVAILABLE, + adaptive_keyframes, + select_diverse_frames_simple, +) + +try: + from .clip_filter import CLIPFrameFilter +except ImportError: + CLIPFrameFilter = type(None) # type: ignore[misc,assignment] +from dimos.utils.logging_config import setup_logger + +from .entity_graph_db import EntityGraphDB + +logger = setup_logger() + +# Constants +MAX_RECENT_WINDOWS = 50 # Max recent windows to keep in memory + + +@dataclass +class Frame: + frame_index: int + timestamp_s: float + image: Image + + +@dataclass +class TemporalMemoryConfig(ModuleConfig): + # Frame processing + fps: float = 1.0 + window_s: float = 2.0 + stride_s: float = 2.0 + summary_interval_s: float = 10.0 + max_frames_per_window: int = 3 + frame_buffer_size: int = 50 + + # Output + output_dir: str | Path | None = "assets/temporal_memory" + + # VLM parameters + max_tokens: int = 900 + temperature: float = 0.2 + + # Frame filtering + use_clip_filtering: bool = True + clip_model: str = "ViT-B/32" + stale_scene_threshold: float = 5.0 + + # Graph database + persistent_memory: bool = True # Keep graph across sessions + clear_memory_on_start: bool = False # Wipe DB on startup + enable_distance_estimation: bool = True # Estimate entity distances + max_distance_pairs: int = 5 # Max entity pairs per window + + # Graph context + max_relations_per_entity: int = 10 # Max relations in 
query context + nearby_distance_meters: float = 5.0 # "Nearby" threshold + + +class TemporalMemory(SkillModule): + """ + builds temporal understanding of video streams using vlms. + + processes frames reactively, maintains entity rosters, tracks temporal + relationships, builds rolling summaries. responds to queries about current + state and recent events. + """ + + color_image: In[Image] + + def __init__( + self, vlm: VlModel | None = None, config: TemporalMemoryConfig | None = None + ) -> None: + super().__init__() + + self._vlm = vlm # Can be None for blueprint usage + self.config: TemporalMemoryConfig = config or TemporalMemoryConfig() + + # single lock protects all state + self._state_lock = threading.Lock() + self._stopped = False + + # protected state + self._state = tu.default_state() + self._state["next_summary_at_s"] = float(self.config.summary_interval_s) + self._frame_buffer: deque[Frame] = deque(maxlen=self.config.frame_buffer_size) + self._recent_windows: deque[dict[str, Any]] = deque(maxlen=MAX_RECENT_WINDOWS) + self._frame_count = 0 + # Start at -inf so first analysis passes stride_s check regardless of elapsed time + self._last_analysis_time = -float("inf") + self._video_start_wall_time: float | None = None + + # Track background distance estimation threads + self._distance_threads: list[threading.Thread] = [] + + # clip filter - use instance state to avoid mutating shared config + self._clip_filter: CLIPFrameFilter | None = None + self._use_clip_filtering = self.config.use_clip_filtering + if self._use_clip_filtering and CLIP_AVAILABLE: + try: + self._clip_filter = CLIPFrameFilter(model_name=self.config.clip_model) + logger.info("clip filtering enabled") + except Exception as e: + logger.warning(f"clip init failed: {e}") + self._use_clip_filtering = False + elif self._use_clip_filtering: + logger.warning("clip not available") + self._use_clip_filtering = False + + # output directory + self._graph_db: EntityGraphDB | None + if 
self.config.output_dir: + self._output_path = Path(self.config.output_dir) + self._output_path.mkdir(parents=True, exist_ok=True) + self._evidence_file = self._output_path / "evidence.jsonl" + self._state_file = self._output_path / "state.json" + self._entities_file = self._output_path / "entities.json" + self._frames_index_file = self._output_path / "frames_index.jsonl" + + db_path = self._output_path / "entity_graph.db" + if not self.config.persistent_memory or self.config.clear_memory_on_start: + if db_path.exists(): + db_path.unlink() + reason = ( + "non-persistent mode" + if not self.config.persistent_memory + else "clear_memory_on_start=True" + ) + logger.info(f"Deleted existing database: {reason}") + + self._graph_db = EntityGraphDB(db_path=db_path) + + logger.info(f"artifacts save to: {self._output_path}") + else: + self._graph_db = None + + logger.info( + f"temporalmemory init: fps={self.config.fps}, " + f"window={self.config.window_s}s, stride={self.config.stride_s}s" + ) + + @property + def vlm(self) -> VlModel: + """Get or create VLM instance lazily.""" + if self._vlm is None: + from dimos.models.vl.openai import OpenAIVlModel + + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OPENAI_API_KEY environment variable not set. " + "Either set it or pass a vlm instance to TemporalMemory constructor." 
+ ) + self._vlm = OpenAIVlModel(api_key=api_key) + logger.info("Created OpenAIVlModel from OPENAI_API_KEY environment variable") + return self._vlm + + @rpc + def start(self) -> None: + super().start() + + with self._state_lock: + self._stopped = False + if self._video_start_wall_time is None: + self._video_start_wall_time = time.time() + + def on_frame(image: Image) -> None: + with self._state_lock: + video_start = self._video_start_wall_time + if video_start is None: + return # Not started yet + if image.ts is not None: + timestamp_s = image.ts - video_start + else: + timestamp_s = time.time() - video_start + + frame = Frame( + frame_index=self._frame_count, + timestamp_s=timestamp_s, + image=image, + ) + self._frame_buffer.append(frame) + self._frame_count += 1 + + frame_subject: Subject[Image] = Subject() + self._disposables.add( + frame_subject.pipe(sharpness_barrier(self.config.fps)).subscribe(on_frame) + ) + + unsub_image = self.color_image.subscribe(frame_subject.on_next) + self._disposables.add(Disposable(unsub_image)) + + # Schedule window analysis every stride_s seconds + self._disposables.add( + interval(self.config.stride_s).subscribe(lambda _: self._analyze_window()) + ) + + logger.info("temporalmemory started") + + @rpc + def stop(self) -> None: + # Save state before clearing (bypass _stopped check by saving directly) + if self.config.output_dir: + try: + with self._state_lock: + state_copy = self._state.copy() + entity_roster = list(self._state.get("entity_roster", [])) + with open(self._state_file, "w") as f: + json.dump(state_copy, f, indent=2, ensure_ascii=False) + logger.info(f"saved state to {self._state_file}") + with open(self._entities_file, "w") as f: + json.dump(entity_roster, f, indent=2, ensure_ascii=False) + logger.info(f"saved {len(entity_roster)} entities") + except Exception as e: + logger.error(f"save failed during stop: {e}", exc_info=True) + + self.save_frames_index() + with self._state_lock: + self._stopped = True + + # Wait for 
background distance estimation threads to complete before closing DB + if self._distance_threads: + logger.info(f"Waiting for {len(self._distance_threads)} distance estimation threads...") + for thread in self._distance_threads: + thread.join(timeout=10.0) # Wait max 10s per thread + self._distance_threads.clear() + + if self._graph_db: + db_path = self._graph_db.db_path + self._graph_db.commit() # save all pending transactions + self._graph_db.close() + self._graph_db = None + + if not self.config.persistent_memory and db_path.exists(): + db_path.unlink() + logger.info("Deleted non-persistent database") + + if self._clip_filter: + self._clip_filter.close() + self._clip_filter = None + + with self._state_lock: + self._frame_buffer.clear() + self._recent_windows.clear() + self._state = tu.default_state() + + super().stop() + + # Stop all stream transports to clean up LCM/shared memory threads + # Note: We use public stream.transport API and rely on transport.stop() to clean up + for stream in list(self.inputs.values()) + list(self.outputs.values()): + if stream.transport is not None and hasattr(stream.transport, "stop"): + try: + stream.transport.stop() + except Exception as e: + logger.warning(f"Failed to stop stream transport: {e}") + + logger.info("temporalmemory stopped") + + def _get_window_frames(self) -> tuple[list[Frame], dict[str, Any]] | None: + """Extract window frames from buffer with guards.""" + with self._state_lock: + if not self._frame_buffer: + return None + current_time = self._frame_buffer[-1].timestamp_s + if current_time - self._last_analysis_time < self.config.stride_s: + return None + frames_needed = max(1, int(self.config.fps * self.config.window_s)) + if len(self._frame_buffer) < frames_needed: + return None + window_frames = list(self._frame_buffer)[-frames_needed:] + state_snapshot = self._state.copy() + return window_frames, state_snapshot + + def _query_vlm_for_window( + self, + window_frames: list[Frame], + state_snapshot: dict[str, 
Any], + w_start: float, + w_end: float, + ) -> str | None: + """Query VLM for window analysis.""" + query = tu.build_window_prompt( + w_start=w_start, w_end=w_end, frame_count=len(window_frames), state=state_snapshot + ) + try: + fmt = tu.get_structured_output_format() + if len(window_frames) > 1: + responses = self.vlm.query_batch( + [f.image for f in window_frames], query, response_format=fmt + ) + return responses[0] if responses else "" + else: + return self.vlm.query(window_frames[0].image, query, response_format=fmt) + except Exception as e: + logger.error(f"vlm query failed [{w_start:.1f}-{w_end:.1f}s]: {e}", exc_info=True) + return None + + def _save_window_artifacts(self, parsed: dict[str, Any], w_end: float) -> None: + """Save window data to graph DB and evidence file.""" + if self._graph_db: + self._graph_db.save_window_data(parsed, w_end) + if self.config.output_dir: + self._append_evidence(parsed) + + def _analyze_window(self) -> None: + """Analyze a temporal window of frames using VLM.""" + # Extract window frames with guards + result = self._get_window_frames() + if result is None: + return + window_frames, state_snapshot = result + w_start, w_end = window_frames[0].timestamp_s, window_frames[-1].timestamp_s + + # Skip if scene hasn't changed + if tu.is_scene_stale(window_frames, self.config.stale_scene_threshold): + with self._state_lock: + self._last_analysis_time = w_end + return + + # Select diverse frames for analysis + window_frames = ( + adaptive_keyframes( # TODO: unclear if clip vs. diverse vs. 
this solution is best + window_frames, max_frames=self.config.max_frames_per_window + ) + ) + logger.info(f"analyzing [{w_start:.1f}-{w_end:.1f}s] with {len(window_frames)} frames") + + # Query VLM and parse response + response_text = self._query_vlm_for_window(window_frames, state_snapshot, w_start, w_end) + if response_text is None: + with self._state_lock: + self._last_analysis_time = w_end + return + + parsed = tu.parse_window_response(response_text, w_start, w_end, len(window_frames)) + if "_error" in parsed: + logger.error(f"parse error: {parsed['_error']}") + # else: + # logger.info(f"parsed. caption: {parsed.get('caption', '')[:100]}") + + # Start distance estimation in background + if self._graph_db and window_frames and self.config.enable_distance_estimation: + mid_frame = window_frames[len(window_frames) // 2] + if mid_frame.image: + thread = threading.Thread( + target=self._graph_db.estimate_and_save_distances, + args=(parsed, mid_frame.image, self.vlm, w_end, self.config.max_distance_pairs), + daemon=True, + ) + thread.start() + self._distance_threads = [t for t in self._distance_threads if t.is_alive()] + self._distance_threads.append(thread) + + # Update temporal state + with self._state_lock: + needs_summary = tu.update_state_from_window( + self._state, parsed, w_end, self.config.summary_interval_s + ) + self._recent_windows.append(parsed) + self._last_analysis_time = w_end + + # Save artifacts + self._save_window_artifacts(parsed, w_end) + + # Trigger summary update if needed + if needs_summary: + logger.info(f"updating summary at t≈{w_end:.1f}s") + self._update_rolling_summary(w_end) + + # Periodic state saves + with self._state_lock: + window_count = len(self._recent_windows) + if window_count % 10 == 0: + self.save_state() + self.save_entities() + + def _update_rolling_summary(self, w_end: float) -> None: + with self._state_lock: + if self._stopped: + return + rolling_summary = str(self._state.get("rolling_summary", "")) + chunk_buffer = 
    @skill()
    def query(self, question: str) -> str:
        """Answer a question about the video stream using temporal memory and graph knowledge.

        This skill analyzes the current video stream and temporal memory state
        to answer questions about what is happening, what entities are present,
        recent events, spatial relationships, and conceptual knowledge.

        The system automatically accesses three knowledge graphs:
        - Interactions: relationships between entities (holds, looks_at, talks_to)
        - Spatial: distance and proximity information
        - Semantic: conceptual relationships (goes_with, used_for, etc.)

        Example:
            query("What entities are currently visible?")
            query("What did I do last week?")
            query("Where did I leave my keys?")
            query("What objects are near the person?")

        Args:
            question (str): The question to ask about the video stream.
                Examples: "What entities are visible?", "What happened recently?",
                "Is there a person in the scene?", "What am I holding?"

        Returns:
            str: Answer based on temporal memory, graph knowledge, and current frame.
        """
        # read state
        # Snapshot everything under the lock once; the slow VLM call below
        # deliberately happens outside the lock.
        with self._state_lock:
            entity_roster = list(self._state.get("entity_roster", []))
            rolling_summary = str(self._state.get("rolling_summary", ""))
            last_present = list(self._state.get("last_present", []))
            recent_windows = list(self._recent_windows)
            if self._frame_buffer:
                latest_frame = self._frame_buffer[-1].image
                current_video_time_s = self._frame_buffer[-1].timestamp_s
            else:
                latest_frame = None
                current_video_time_s = 0.0

        # NOTE(review): truthiness test on the Image — assumes Image defines no
        # surprising __bool__/__len__; `latest_frame is None` would be more
        # explicit. Confirm against the Image type.
        if not latest_frame:
            return "no frames available"

        # build context from temporal state
        # Include entities from last_present and recent windows (both entities_present and new_entities)
        currently_present = {e["id"] for e in last_present if isinstance(e, dict) and "id" in e}
        for window in recent_windows[-3:]:
            # Add entities that were present
            for entity in window.get("entities_present", []):
                if isinstance(entity, dict) and isinstance(entity.get("id"), str):
                    currently_present.add(entity["id"])
            # Also include newly detected entities (they're present now)
            for entity in window.get("new_entities", []):
                if isinstance(entity, dict) and isinstance(entity.get("id"), str):
                    currently_present.add(entity["id"])

        context = {
            "entity_roster": entity_roster,
            "rolling_summary": rolling_summary,
            "currently_present_entities": sorted(currently_present),
            "recent_windows_count": len(recent_windows),
            "timestamp": time.time(),
        }

        # enhance context with graph database knowledge
        if self._graph_db:
            # Extract time window from question using VLM
            time_window_s = tu.extract_time_window(question, self.vlm, latest_frame)

            # Query graph for ALL entities in roster (not just currently present)
            # This allows queries about entities that disappeared or were seen in the past
            all_entity_ids = [e["id"] for e in entity_roster if isinstance(e, dict) and "id" in e]

            if all_entity_ids:
                graph_context = tu.build_graph_context(
                    graph_db=self._graph_db,
                    entity_ids=all_entity_ids,
                    time_window_s=time_window_s,
                    max_relations_per_entity=self.config.max_relations_per_entity,
                    nearby_distance_meters=self.config.nearby_distance_meters,
                    current_video_time_s=current_video_time_s,
                )
                context["graph_knowledge"] = graph_context

        # build query prompt using temporal utils
        prompt = tu.build_query_prompt(question=question, context=context)

        # query vlm (slow, outside lock)
        try:
            answer_text = self.vlm.query(latest_frame, prompt)
            return answer_text.strip()
        except Exception as e:
            logger.error(f"query failed: {e}", exc_info=True)
            return f"error: {e}"

    @rpc
    def clear_history(self) -> bool:
        """Clear temporal memory state.

        Resets the state dict and recent-window deque under the lock and
        re-arms the next summary deadline.

        Returns:
            bool: True on success, False if the reset raised.
        """
        try:
            with self._state_lock:
                self._state = tu.default_state()
                self._state["next_summary_at_s"] = float(self.config.summary_interval_s)
                self._recent_windows.clear()
            logger.info("cleared history")
            return True
        except Exception as e:
            logger.error(f"clear_history failed: {e}", exc_info=True)
            return False

    @rpc
    def get_state(self) -> dict[str, Any]:
        """Return a lock-consistent snapshot of counters, roster and summary."""
        with self._state_lock:
            return {
                "entity_count": len(self._state.get("entity_roster", [])),
                "entities": list(self._state.get("entity_roster", [])),
                "rolling_summary": str(self._state.get("rolling_summary", "")),
                "frame_count": self._frame_count,
                "buffer_size": len(self._frame_buffer),
                "recent_windows": len(self._recent_windows),
                "currently_present": list(self._state.get("last_present", [])),
            }

    @rpc
    def get_entity_roster(self) -> list[dict[str, Any]]:
        """Return a copy of the current entity roster."""
        with self._state_lock:
            return list(self._state.get("entity_roster", []))

    @rpc
    def get_rolling_summary(self) -> str:
        """Return the current rolling summary text."""
        with self._state_lock:
            return str(self._state.get("rolling_summary", ""))
+ """ + if not self._graph_db: + return {"stats": {}, "entities": [], "recent_relations": []} + return self._graph_db.get_summary() + + @rpc + def save_state(self) -> bool: + if not self.config.output_dir: + return False + try: + with self._state_lock: + # Don't save if stopped (state has been cleared) + if self._stopped: + return False + state_copy = self._state.copy() + with open(self._state_file, "w") as f: + json.dump(state_copy, f, indent=2, ensure_ascii=False) + logger.info(f"saved state to {self._state_file}") + return True + except Exception as e: + logger.error(f"save state failed: {e}", exc_info=True) + return False + + def _append_evidence(self, evidence: dict[str, Any]) -> None: + try: + with open(self._evidence_file, "a") as f: + f.write(json.dumps(evidence, ensure_ascii=False) + "\n") + except Exception as e: + logger.error(f"append evidence failed: {e}", exc_info=True) + + def save_entities(self) -> bool: + if not self.config.output_dir: + return False + try: + with self._state_lock: + # Don't save if stopped (state has been cleared) + if self._stopped: + return False + entity_roster = list(self._state.get("entity_roster", [])) + with open(self._entities_file, "w") as f: + json.dump(entity_roster, f, indent=2, ensure_ascii=False) + logger.info(f"saved {len(entity_roster)} entities") + return True + except Exception as e: + logger.error(f"save entities failed: {e}", exc_info=True) + return False + + def save_frames_index(self) -> bool: + if not self.config.output_dir: + return False + try: + with self._state_lock: + frames = list(self._frame_buffer) + + frames_index = [ + { + "frame_index": f.frame_index, + "timestamp_s": f.timestamp_s, + "timestamp": tu.format_timestamp(f.timestamp_s), + } + for f in frames + ] + + if frames_index: + with open(self._frames_index_file, "w", encoding="utf-8") as f: + for rec in frames_index: + f.write(json.dumps(rec, ensure_ascii=False) + "\n") + logger.info(f"saved {len(frames_index)} frames") + return True + except 
Exception as e: + logger.error(f"save frames failed: {e}", exc_info=True) + return False + + +temporal_memory = TemporalMemory.blueprint + +__all__ = ["Frame", "TemporalMemory", "TemporalMemoryConfig", "temporal_memory"] diff --git a/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py new file mode 100644 index 0000000000..611385630e --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/temporal_memory_deploy.py @@ -0,0 +1,60 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Deployment helpers for TemporalMemory module. +""" + +import os + +from dimos import spec +from dimos.core import DimosCluster +from dimos.models.vl.base import VlModel + +from .temporal_memory import TemporalMemory, TemporalMemoryConfig + + +def deploy( + dimos: DimosCluster, + camera: spec.Camera, + vlm: VlModel | None = None, + config: TemporalMemoryConfig | None = None, +) -> TemporalMemory: + """Deploy TemporalMemory with a camera. 
def deploy(
    dimos: DimosCluster,
    camera: spec.Camera,
    vlm: VlModel | None = None,
    config: TemporalMemoryConfig | None = None,
) -> TemporalMemory:
    """Deploy TemporalMemory with a camera.

    Args:
        dimos: Dimos cluster instance
        camera: Camera module to connect to
        vlm: Optional VLM instance (creates OpenAI VLM if None)
        config: Optional temporal memory configuration

    Returns:
        The started TemporalMemory module, subscribed to the camera's
        color image stream.

    Raises:
        ValueError: If ``vlm`` is None and OPENAI_API_KEY is not set.
    """
    if vlm is None:
        # Lazy import so the OpenAI dependency is only needed on this path.
        from dimos.models.vl.openai import OpenAIVlModel

        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set")
        vlm = OpenAIVlModel(api_key=api_key)

    temporal_memory = dimos.deploy(TemporalMemory, vlm=vlm, config=config)  # type: ignore[attr-defined]

    # Give the camera stream a shared-memory JPEG transport if none is set yet.
    if camera.color_image.transport is None:
        from dimos.core.transport import JpegShmTransport

        transport = JpegShmTransport("/temporal_memory/color_image")
        camera.color_image.transport = transport

    temporal_memory.color_image.connect(camera.color_image)
    temporal_memory.start()
    return temporal_memory  # type: ignore[return-value,no-any-return]
def example_usage() -> None:
    """Example of how to use TemporalMemory.

    Deploys a webcam-backed camera plus TemporalMemory, lets the pipeline
    build context for ~20 seconds, runs a few example queries, then prints
    the resulting state, entity roster and graph-database statistics.
    """
    import time

    # Initialize variables to None so the finally-block can clean up
    # regardless of how far setup got before an error.
    temporal_memory = None
    camera = None
    dimos = None

    try:
        # Create Dimos cluster
        dimos = core.start(1)
        # Deploy camera module
        camera = dimos.deploy(CameraModule, hardware=lambda: Webcam(camera_index=0))  # type: ignore[attr-defined]
        camera.start()

        # Deploy temporal memory using the deploy function
        output_dir = Path("./temporal_memory_output")
        temporal_memory = deploy(
            dimos,
            camera,
            vlm=None,  # Will auto-create OpenAIVlModel if None
            config=TemporalMemoryConfig(
                fps=1.0,  # Process 1 frame per second
                window_s=2.0,  # Analyze 2-second windows
                stride_s=2.0,  # New window every 2 seconds
                summary_interval_s=10.0,  # Update rolling summary every 10 seconds
                max_frames_per_window=3,  # Max 3 frames per window
                output_dir=output_dir,
            ),
        )

        print("TemporalMemory deployed and started!")
        print(f"Artifacts will be saved to: {output_dir}")

        # Let it run for a bit to build context.
        # (Message fixed to match the actual sleep below: it said ~15 s
        # while sleeping 20 s.)
        print("Building temporal context... (wait ~20 seconds)")
        time.sleep(20)

        # Query the temporal memory
        questions = [
            "Are there any people in the scene?",
            "Describe the main activity happening now",
            "What has happened in the last few seconds?",
            "What entities are currently visible?",
        ]

        for question in questions:
            print(f"\nQuestion: {question}")
            answer = temporal_memory.query(question)
            print(f"Answer: {answer}")

        # Get current state
        state = temporal_memory.get_state()
        print("\n=== Current State ===")
        print(f"Entity count: {state['entity_count']}")
        print(f"Frame count: {state['frame_count']}")
        print(f"Rolling summary: {state['rolling_summary']}")
        print(f"Entities: {state['entities']}")

        # Get entity roster
        entities = temporal_memory.get_entity_roster()
        print("\n=== Entity Roster ===")
        for entity in entities:
            print(f"  {entity['id']}: {entity['descriptor']}")

        # Check graph database stats
        graph_stats = temporal_memory.get_graph_db_stats()
        print("\n=== Graph Database Stats ===")
        if "error" in graph_stats:
            print(f"Error: {graph_stats['error']}")
        else:
            print(f"Stats: {graph_stats['stats']}")
            print(f"\nEntities in DB ({len(graph_stats['entities'])}):")
            for entity in graph_stats["entities"]:
                print(f"  {entity['entity_id']} ({entity['entity_type']}): {entity['descriptor']}")
            print(f"\nRecent relations ({len(graph_stats['recent_relations'])}):")
            for rel in graph_stats["recent_relations"]:
                print(
                    f"  {rel['subject_id']} --{rel['relation_type']}--> {rel['object_id']} (confidence: {rel['confidence']:.2f})"
                )

        print("\nTemporalMemory stopped")

    finally:
        # Single cleanup path. The explicit stop() calls that previously
        # also sat at the end of the try-block made every module stop twice
        # on the success path (once in try, once here) — removed.
        if temporal_memory is not None:
            temporal_memory.stop()
        if camera is not None:
            camera.stop()
        if dimos is not None:
            dimos.close_all()  # type: ignore[attr-defined]


if __name__ == "__main__":
    example_usage()
# Copyright 2025-2026 Dimensional Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utilities for the temporal memory module: state management, VLM response
parsers, graph-database helpers, and the prompt builders used to construct
prompts for the VLM.
"""

# Re-export everything from submodules
from .graph_utils import build_graph_context, extract_time_window
from .helpers import clamp_text, format_timestamp, is_scene_stale, next_entity_id_hint
from .parsers import parse_batch_distance_response, parse_window_response
from .prompts import (
    WINDOW_RESPONSE_SCHEMA,
    build_batch_distance_estimation_prompt,
    build_distance_estimation_prompt,
    build_query_prompt,
    build_summary_prompt,
    build_window_prompt,
    get_structured_output_format,
)
from .state import apply_summary_update, default_state, update_state_from_window

__all__ = [
    # Schema
    "WINDOW_RESPONSE_SCHEMA",
    # State management
    "apply_summary_update",
    # Prompts
    "build_batch_distance_estimation_prompt",
    "build_distance_estimation_prompt",
    # Graph utils
    "build_graph_context",
    "build_query_prompt",
    "build_summary_prompt",
    "build_window_prompt",
    # Helpers
    "clamp_text",
    "default_state",
    "extract_time_window",
    "format_timestamp",
    "get_structured_output_format",
    "is_scene_stale",
    "next_entity_id_hint",
    # Parsers
    "parse_batch_distance_response",
    "parse_window_response",
    "update_state_from_window",
]
+ "is_scene_stale", + "next_entity_id_hint", + # Parsers + "parse_batch_distance_response", + "parse_window_response", + "update_state_from_window", +] diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py new file mode 100644 index 0000000000..bc55f7c65c --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/temporal_utils/graph_utils.py @@ -0,0 +1,207 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Graph database utility functions for temporal memory.""" + +import re +import time +from typing import TYPE_CHECKING, Any + +from dimos.utils.logging_config import setup_logger + +if TYPE_CHECKING: + from dimos.models.vl.base import VlModel + from dimos.msgs.sensor_msgs import Image + + from ..entity_graph_db import EntityGraphDB + +logger = setup_logger() + + +def extract_time_window( + question: str, + vlm: "VlModel", + latest_frame: "Image | None" = None, +) -> float | None: + """Extract time window from question using VLM with example-based learning. + + Uses a few example keywords as patterns, then asks VLM to extrapolate + similar time references and return seconds. 
+ + Args: + question: User's question + vlm: VLM instance to use for extraction + latest_frame: Optional frame (required for VLM call, but image is ignored) + + Returns: + Time window in seconds, or None if no time reference found + """ + question_lower = question.lower() + + # Quick check for common patterns (fast path) + if "last week" in question_lower or "past week" in question_lower: + return 7 * 24 * 3600 + if "today" in question_lower or "last hour" in question_lower: + return 3600 + if "recently" in question_lower or "recent" in question_lower: + return 600 + + # Use VLM to extract time reference from question + # Provide examples and let VLM extrapolate similar patterns + # Note: latest_frame is required by VLM interface but image content is ignored + if not latest_frame: + return None + + extraction_prompt = f"""Extract any time reference from this question and convert it to seconds. + +Question: {question} + +Examples of time references and their conversions: +- "last week" or "past week" -> 604800 seconds (7 days) +- "yesterday" -> 86400 seconds (1 day) +- "today" or "last hour" -> 3600 seconds (1 hour) +- "recently" or "recent" -> 600 seconds (10 minutes) +- "few minutes ago" -> 300 seconds (5 minutes) +- "just now" -> 60 seconds (1 minute) + +Extrapolate similar patterns (e.g., "2 days ago", "this morning", "last month", etc.) +and convert to seconds. If no time reference is found, return "none". + +Return ONLY a number (seconds) or the word "none". 
def build_graph_context(
    graph_db: "EntityGraphDB",
    entity_ids: list[str],
    time_window_s: float | None = None,
    max_relations_per_entity: int = 10,
    nearby_distance_meters: float = 5.0,
    current_video_time_s: float | None = None,
) -> dict[str, Any]:
    """Build enriched context from graph database for given entities.

    Args:
        graph_db: Entity graph database instance
        entity_ids: List of entity IDs to get context for
        time_window_s: Optional time window in seconds (e.g., 3600 for last hour)
        max_relations_per_entity: Maximum relations to include per entity (default: 10)
        nearby_distance_meters: Distance threshold for "nearby" entities (default: 5.0)
        current_video_time_s: Current video timestamp in seconds (for time window queries).
            If None, uses latest entity timestamp from DB as reference.

    Returns:
        Dictionary with graph context including relationships, distances, and semantics.
        Returns {} when inputs are empty or any DB call fails (best-effort helper).
    """
    if not graph_db or not entity_ids:
        return {}

    try:
        graph_context: dict[str, Any] = {
            "relationships": [],
            "spatial_info": [],
            "semantic_knowledge": [],
        }

        # Convert time_window_s to a (start_ts, end_ts) tuple if provided
        # Use video-relative timestamps, not wall-clock time
        time_window_tuple = None
        if time_window_s is not None:
            if current_video_time_s is not None:
                ref_time = current_video_time_s
            else:
                # Fallback: get the latest timestamp from entities in DB
                all_entities = graph_db.get_all_entities()
                ref_time = max((e.get("last_seen_ts", 0) for e in all_entities), default=0)
            # Clamp the window start at 0 so it never goes before the video began.
            time_window_tuple = (max(0, ref_time - time_window_s), ref_time)

        # Get recent relationships for each entity
        for entity_id in entity_ids:
            # Get relationships (Graph 1: interactions)
            relations = graph_db.get_relations_for_entity(
                entity_id=entity_id,
                relation_type=None,  # all types
                time_window=time_window_tuple,
            )
            # Keep only the last N relations per entity to bound prompt size.
            for rel in relations[-max_relations_per_entity:]:
                graph_context["relationships"].append(
                    {
                        "subject": rel["subject_id"],
                        "relation": rel["relation_type"],
                        "object": rel["object_id"],
                        "confidence": rel["confidence"],
                        "when": rel["timestamp_s"],
                    }
                )

            # Get spatial relationships (Graph 2: distances)
            nearby = graph_db.get_nearby_entities(
                entity_id=entity_id, max_distance=nearby_distance_meters, latest_only=True
            )
            for dist in nearby:
                graph_context["spatial_info"].append(
                    {
                        "entity_a": entity_id,
                        "entity_b": dist["entity_id"],
                        "distance": dist.get("distance_meters"),
                        "category": dist.get("distance_category"),
                        "confidence": dist["confidence"],
                    }
                )

            # Get semantic knowledge (Graph 3: conceptual relations)
            semantic_rels = graph_db.get_semantic_relations(
                entity_id=entity_id,
                relation_type=None,
            )
            for sem in semantic_rels:
                graph_context["semantic_knowledge"].append(
                    {
                        "entity_a": sem["entity_a_id"],
                        "relation": sem["relation_type"],
                        "entity_b": sem["entity_b_id"],
                        "confidence": sem["confidence"],
                        "observations": sem["observation_count"],
                    }
                )

        # Get graph statistics for context
        # NOTE(review): entity_ids is always non-empty here (guarded by the
        # early return above), so this check is redundant but harmless.
        if entity_ids:
            stats = graph_db.get_stats()
            graph_context["total_entities"] = stats.get("entities", 0)
            graph_context["total_relations"] = stats.get("relations", 0)

        return graph_context

    except Exception as e:
        logger.warning(f"failed to build graph context: {e}")
        return {}
"""Helper utility functions for temporal memory."""

from typing import TYPE_CHECKING, Any

import numpy as np

if TYPE_CHECKING:
    from ..temporal_memory import Frame


def next_entity_id_hint(roster: Any) -> str:
    """Return the next free entity ID for *roster* (e.g., E1, E2, E3...)."""
    if not isinstance(roster, list):
        return "E1"
    highest = 0
    for entry in roster:
        if not isinstance(entry, dict):
            continue
        candidate = entry.get("id")
        if (
            isinstance(candidate, str)
            and candidate.startswith("E")
            and candidate[1:].isdigit()
        ):
            highest = max(highest, int(candidate[1:]))
    return f"E{highest + 1}"


def clamp_text(text: str, max_chars: int) -> str:
    """Truncate *text* to *max_chars* characters, appending '...' when cut."""
    return text if len(text) <= max_chars else text[:max_chars] + "..."


def format_timestamp(seconds: float) -> str:
    """Render *seconds* as an MM:SS.mmm timestamp string."""
    minutes, remainder = divmod(seconds, 60)
    return f"{int(minutes):02d}:{remainder:06.3f}"


def is_scene_stale(frames: list["Frame"], stale_threshold: float = 5.0) -> bool:
    """Check whether the scene barely changed between first and last frame.

    Args:
        frames: Frames to inspect (only the two endpoints are compared)
        stale_threshold: Mean absolute pixel difference below which the
            scene counts as unchanged (default: 5.0)

    Returns:
        True if scene is stale (hasn't changed enough), False otherwise
    """
    if len(frames) < 2:
        return False
    first, last = frames[0].image, frames[-1].image
    if first is None or last is None:
        return False
    if not (hasattr(first, "data") and hasattr(last, "data")):
        return False
    mean_delta = np.abs(first.data.astype(float) - last.data.astype(float)).mean()
    return bool(mean_delta < stale_threshold)
def parse_batch_distance_response(
    response: str, entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]]
) -> list[dict[str, Any]]:
    """
    Parse batched distance estimation response.

    Args:
        response: VLM response text
        entity_pairs: Original entity pairs used in the prompt

    Returns:
        List of dicts with keys: entity_a_id, entity_b_id, category, distance_m, confidence

    Notes:
        Pair numbers in the response that do not map to an entry of
        ``entity_pairs`` (e.g. "Pair 0" or an index past the end) are now
        ignored; previously the mid-response save path indexed
        ``entity_pairs`` unchecked, raising IndexError (or silently wrapping
        "Pair 0" to the last pair via index -1).
    """
    results: list[dict[str, Any]] = []

    current_pair_idx: int | None = None
    category: str | None = None
    distance_m: float | None = None
    confidence = 0.5

    def flush() -> None:
        # Append the pair currently being parsed, if complete and in range.
        # The bounds check guards BOTH save paths (mid-response and final).
        if (
            current_pair_idx is not None
            and category
            and 0 <= current_pair_idx < len(entity_pairs)
        ):
            entity_a, entity_b = entity_pairs[current_pair_idx]
            results.append(
                {
                    "entity_a_id": entity_a["id"],
                    "entity_b_id": entity_b["id"],
                    "category": category,
                    "distance_m": distance_m,
                    "confidence": confidence,
                }
            )

    for line in response.strip().split("\n"):
        line = line.strip()

        # Check for pair marker
        if line.startswith("Pair "):
            # A new marker closes out the previous pair.
            flush()
            # Reset per-pair fields before parsing the marker so a malformed
            # marker cannot cause the previous pair to be emitted twice.
            category = None
            distance_m = None
            confidence = 0.5
            try:
                pair_num = int(line.split()[1].rstrip(":"))
                current_pair_idx = pair_num - 1  # Convert to 0-indexed
            except (IndexError, ValueError):
                current_pair_idx = None

        # Parse distance fields
        elif line.startswith("category:"):
            category = line.split(":", 1)[1].strip().lower()
        elif line.startswith("distance_m:"):
            try:
                distance_m = float(line.split(":", 1)[1].strip())
            except (ValueError, IndexError):
                pass
        elif line.startswith("confidence:"):
            try:
                confidence = float(line.split(":", 1)[1].strip())
            except (ValueError, IndexError):
                pass

    # Save last pair
    flush()
    return results
def parse_window_response(
    response_text: str, w_start: float, w_end: float, frame_count: int
) -> dict[str, Any]:
    """
    Parse VLM response for a window analysis.

    Args:
        response_text: Raw text response from VLM
        w_start: Window start time
        w_end: Window end time
        frame_count: Number of frames in window

    Returns:
        Parsed dictionary with defaults filled in. If parsing fails, returns
        a dict with "_error" key instead of raising.
    """

    def fallback(error: str) -> dict[str, Any]:
        # Skeleton window result used whenever the VLM output is unusable.
        return {
            "window": {"start_s": w_start, "end_s": w_end},
            "caption": "",
            "entities_present": [],
            "new_entities": [],
            "relations": [],
            "on_screen_text": [],
            "_error": error,
        }

    # extract_json copes with markdown code fences around the JSON payload.
    parsed = extract_json(response_text)
    if parsed is None:
        return fallback(f"Failed to parse JSON from response: {response_text[:200]}...")

    # extract_json can also yield a list; structured output should prevent
    # this, but handle it gracefully rather than raising downstream.
    if isinstance(parsed, list):
        return fallback(f"Unexpected list response: {parsed}")

    if not isinstance(parsed, dict):
        return fallback(f"Expected dict or list, got {type(parsed)}: {parsed}")

    return parsed
"""Prompt building functions for VLM queries."""

import json
from typing import Any

from .helpers import clamp_text, next_entity_id_hint

# JSON schema for window responses (from VideoRAG).
# Describes the structured output expected from the per-window VLM call:
# a dense caption plus entity/relation/on-screen-text extractions.
WINDOW_RESPONSE_SCHEMA = {
    "type": "object",
    "properties": {
        # The analyzed time span, echoed back by the model.
        "window": {
            "type": "object",
            "properties": {"start_s": {"type": "number"}, "end_s": {"type": "number"}},
            "required": ["start_s", "end_s"],
        },
        # Dense grounded description of what is visible in the window.
        "caption": {"type": "string"},
        # Roster entities re-observed in this window.
        "entities_present": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "confidence": {"type": "number", "minimum": 0, "maximum": 1},
                },
                "required": ["id"],
            },
        },
        # Newly introduced entities, each with a type and short descriptor.
        "new_entities": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "id": {"type": "string"},
                    "type": {
                        "type": "string",
                        "enum": ["person", "object", "screen", "text", "location", "other"],
                    },
                    "descriptor": {"type": "string"},
                },
                "required": ["id", "type"],
            },
        },
        # Grounded relations/events between entities.
        "relations": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "type": {"type": "string"},
                    "subject": {"type": "string"},
                    "object": {"type": "string"},
                    "confidence": {"type": "number", "minimum": 0, "maximum": 1},
                    "evidence": {"type": "array", "items": {"type": "string"}},
                    "notes": {"type": "string"},
                },
                "required": ["type", "subject", "object"],
            },
        },
        # Verbatim readable on-screen text snippets.
        "on_screen_text": {"type": "array", "items": {"type": "string"}},
        # Free-form list of things the model was unsure about.
        "uncertainties": {"type": "array", "items": {"type": "string"}},
        # Overall confidence for the whole window analysis.
        "confidence": {"type": "number", "minimum": 0, "maximum": 1},
    },
    "required": ["window", "caption"],
}
def build_window_prompt(
    *,
    w_start: float,
    w_end: float,
    frame_count: int,
    state: dict[str, Any],
) -> str:
    """
    Build comprehensive VLM prompt for analyzing a video window.

    This is adapted from videorag's build_window_messages() but formatted
    as a single text prompt for VlModel.query() instead of OpenAI's messages format.

    Args:
        w_start: Window start time in seconds
        w_end: Window end time in seconds
        frame_count: Number of frames in this window
        state: Current temporal memory state (entity_roster, rolling_summary, etc.)

    Returns:
        Formatted prompt string
    """
    roster = state.get("entity_roster", [])
    rolling_summary = state.get("rolling_summary", "")
    # Tell the model where new entity IDs start so they never collide with
    # existing roster IDs.
    next_id = next_entity_id_hint(roster)

    # System instructions (from VideoRAG)
    system_context = """You analyze short sequences of video frames.
You must stay grounded in what is visible.
Do not identify real people or guess names/identities; describe people anonymously.
Extract general entities (people, objects, screens, text, locations) and relations between them.
Use stable entity IDs like E1, E2 based on the provided roster."""

    # Main prompt (from VideoRAG's build_window_messages).
    # Note the doubled braces {{ }} — required to emit literal JSON braces
    # from inside an f-string.
    prompt = f"""{system_context}

Time window: [{w_start:.3f}, {w_end:.3f}) seconds
Number of frames: {frame_count}

Existing entity roster (may be empty):
{json.dumps(roster, ensure_ascii=False)}

Rolling summary so far (may be empty):
{clamp_text(str(rolling_summary), 1500)}

Task:
1) Write a dense, grounded caption describing what is visible across the frames in this time window.
2) Identify which existing roster entities appear in these frames.
3) Add any new salient entities (people/objects/screens/text/locations) with a short grounded descriptor.
4) Extract grounded relations/events between entities (e.g., looks_at, holds, uses, walks_past, speaks_to (inferred)).

New entity IDs must start at: {next_id}

Rules (important):
- You MUST stay grounded in what is visible in the provided frames.
- You MUST NOT mention any entity ID unless it appears in the provided roster OR you include it in new_entities in this same output.
- If the roster is empty, introduce any salient entities you reference (start with E1, E2, ...).
- Do not invent on-screen text: only include text you can read.
- If a relation is inferred (e.g., speaks_to without audio), include it but lower confidence and explain the visual cues.

Output JSON ONLY with this schema:
{{
  "window": {{"start_s": {w_start:.3f}, "end_s": {w_end:.3f}}},
  "caption": "dense grounded description",
  "entities_present": [{{"id": "E1", "confidence": 0.0-1.0}}],
  "new_entities": [{{"id": "E3", "type": "person|object|screen|text|location|other", "descriptor": "..."}}],
  "relations": [
    {{
      "type": "speaks_to|looks_at|holds|uses|moves|gesture|scene_change|other",
      "subject": "E1|unknown",
      "object": "E2|unknown",
      "confidence": 0.0-1.0,
      "evidence": ["describe which frames show this"],
      "notes": "short, grounded"
    }}
  ],
  "on_screen_text": ["verbatim snippets"],
  "uncertainties": ["things that are unclear"],
  "confidence": 0.0-1.0
}}
"""
    return prompt


def build_summary_prompt(
    *,
    rolling_summary: str,
    chunk_windows: list[dict[str, Any]],
) -> str:
    """
    Build prompt for updating rolling summary.

    This is adapted from videorag's build_summary_messages() but formatted
    as a single text prompt for VlModel.query().

    Args:
        rolling_summary: Current rolling summary text
        chunk_windows: List of recent window results to incorporate

    Returns:
        Formatted prompt string
    """
    # System context (from VideoRAG)
    system_context = """You summarize timestamped video-window logs into a concise rolling summary.
Stay grounded in the provided window captions/relations.
Do not invent entities or rename entity IDs; preserve IDs like E1, E2 exactly.
You MAY incorporate new entity IDs if they appear in the provided chunk windows (e.g., in new_entities).
Be concise, but keep relevant entity continuity and key relations."""

    # Previous summary is clamped so the prompt cannot grow without bound.
    prompt = f"""{system_context}

Update the rolling summary using the newest chunk.

Previous rolling summary (may be empty):
{clamp_text(rolling_summary, 2500)}

New chunk windows (JSON):
{json.dumps(chunk_windows, ensure_ascii=False)}

Output a concise summary as PLAIN TEXT (no JSON, no code fences).
Length constraints (important):
- Target <= 120 words total.
- Hard cap <= 900 characters.
"""
    return prompt


def build_query_prompt(
    *,
    question: str,
    context: dict[str, Any],
) -> str:
    """
    Build prompt for querying temporal memory.

    Args:
        question: User's question about the video stream
        context: Context dict containing entity_roster, rolling_summary, etc.

    Returns:
        Formatted prompt string
    """
    currently_present = context.get("currently_present_entities", [])
    # Pre-render the recent-detections sentence so the prompt reads naturally
    # whether or not the list is empty.
    currently_present_str = (
        f"Entities recently detected in recent windows: {currently_present}"
        if currently_present
        else "No entities were detected in recent windows (list is empty)"
    )

    prompt = f"""Answer the following question about the video stream using the provided context.

**Question:** {question}

**Context:**
{json.dumps(context, indent=2, ensure_ascii=False)}

**Important Notes:**
- Entities have stable IDs like E1, E2, etc.
- The 'currently_present_entities' list contains entity IDs that were detected in recent video windows (not necessarily in the current frame you're viewing)
- {currently_present_str}
- The 'entity_roster' contains all known entities with their descriptions
- The 'rolling_summary' describes what has happened over time
- If 'currently_present_entities' is empty, it means no entities were detected in recent windows, but entities may still exist in the roster from earlier
- Answer based on the provided context (entity_roster, rolling_summary, currently_present_entities) AND what you see in the current frame
- If the context says entities were present but you don't see them in the current frame, mention both: what was recently detected AND what you currently see

Provide a concise answer.
"""
    return prompt
+- The 'currently_present_entities' list contains entity IDs that were detected in recent video windows (not necessarily in the current frame you're viewing) +- {currently_present_str} +- The 'entity_roster' contains all known entities with their descriptions +- The 'rolling_summary' describes what has happened over time +- If 'currently_present_entities' is empty, it means no entities were detected in recent windows, but entities may still exist in the roster from earlier +- Answer based on the provided context (entity_roster, rolling_summary, currently_present_entities) AND what you see in the current frame +- If the context says entities were present but you don't see them in the current frame, mention both: what was recently detected AND what you currently see + +Provide a concise answer. +""" + return prompt + + +def build_distance_estimation_prompt( + *, + entity_a_descriptor: str, + entity_a_id: str, + entity_b_descriptor: str, + entity_b_id: str, +) -> str: + """ + Build prompt for estimating distance between two entities. + + Args: + entity_a_descriptor: Description of first entity + entity_a_id: ID of first entity + entity_b_descriptor: Description of second entity + entity_b_id: ID of second entity + + Returns: + Formatted prompt string for distance estimation + """ + prompt = f"""Look at this image and estimate the distance between these two entities: + +Entity A: {entity_a_descriptor} (ID: {entity_a_id}) +Entity B: {entity_b_descriptor} (ID: {entity_b_id}) + +Provide: +1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m) +2. Approximate distance in meters (best guess) +3. Confidence: 0.0-1.0 (how certain are you?) 
+ +Respond in this format: +category: [near/medium/far] +distance_m: [number] +confidence: [0.0-1.0] +reasoning: [brief explanation]""" + return prompt + + +def build_batch_distance_estimation_prompt( + entity_pairs: list[tuple[dict[str, Any], dict[str, Any]]], +) -> str: + """ + Build prompt for estimating distances between multiple entity pairs in one call. + + Args: + entity_pairs: List of (entity_a, entity_b) tuples, each entity is a dict with 'id' and 'descriptor' + + Returns: + Formatted prompt string for batched distance estimation + """ + pairs_text = [] + for i, (entity_a, entity_b) in enumerate(entity_pairs, 1): + pairs_text.append( + f"Pair {i}:\n" + f" Entity A: {entity_a['descriptor']} (ID: {entity_a['id']})\n" + f" Entity B: {entity_b['descriptor']} (ID: {entity_b['id']})" + ) + + prompt = f"""Look at this image and estimate the distances between the following entity pairs: + +{chr(10).join(pairs_text)} + +For each pair, provide: +1. Distance category: "near" (< 1m), "medium" (1-3m), or "far" (> 3m) +2. Approximate distance in meters (best guess) +3. Confidence: 0.0-1.0 (how certain are you?) + +Respond in this format (one block per pair): +Pair 1: +category: [near/medium/far] +distance_m: [number] +confidence: [0.0-1.0] + +Pair 2: +category: [near/medium/far] +distance_m: [number] +confidence: [0.0-1.0] + +(etc.)""" + return prompt + + +def get_structured_output_format() -> dict[str, Any]: + """ + Get OpenAI-compatible structured output format for window responses. + + This uses the json_schema mode available in OpenAI API (GPT-4o mini) to enforce + the VideoRAG response schema. 
+ + Returns: + Dictionary for response_format parameter: + {"type": "json_schema", "json_schema": {...}} + """ + + return { + "type": "json_schema", + "json_schema": { + "name": "video_window_analysis", + "description": "Analysis of a video window with entities and relations", + "schema": WINDOW_RESPONSE_SCHEMA, + "strict": False, # Allow additional fields + }, + } diff --git a/dimos/perception/experimental/temporal_memory/temporal_utils/state.py b/dimos/perception/experimental/temporal_memory/temporal_utils/state.py new file mode 100644 index 0000000000..9cdfbe4931 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/temporal_utils/state.py @@ -0,0 +1,139 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""State management functions for temporal memory.""" + +from typing import Any + + +def default_state() -> dict[str, Any]: + """Create default temporal memory state dictionary.""" + return { + "entity_roster": [], + "rolling_summary": "", + "chunk_buffer": [], + "next_summary_at_s": 0.0, + "last_present": [], + } + + +def update_state_from_window( + state: dict[str, Any], + parsed: dict[str, Any], + w_end: float, + summary_interval_s: float, +) -> bool: + """ + Update temporal memory state from a parsed window result. + + This implements the state update logic from VideoRAG's generate_evidence(). 
+ + Args: + state: Current state dictionary (modified in place) + parsed: Parsed window result + w_end: Window end time + summary_interval_s: How often to trigger summary updates + + Returns: + True if summary update is needed, False otherwise + """ + # Skip if there was an error + if "_error" in parsed: + return False + + new_entities = parsed.get("new_entities", []) + present = parsed.get("entities_present", []) + + # Handle new entities + if new_entities: + roster = list(state.get("entity_roster", [])) + known = {e.get("id") for e in roster if isinstance(e, dict)} + for e in new_entities: + if isinstance(e, dict) and e.get("id") not in known: + roster.append(e) + known.add(e.get("id")) + state["entity_roster"] = roster + + # Handle referenced entities (auto-add if mentioned but not in roster) + roster = list(state.get("entity_roster", [])) + known = {e.get("id") for e in roster if isinstance(e, dict)} + referenced: set[str] = set() + for p in present or []: + if isinstance(p, dict) and isinstance(p.get("id"), str): + referenced.add(p["id"]) + for rel in parsed.get("relations") or []: + if isinstance(rel, dict): + for k in ("subject", "object"): + v = rel.get(k) + if isinstance(v, str) and v != "unknown": + referenced.add(v) + for rid in sorted(referenced): + if rid not in known: + roster.append( + { + "id": rid, + "type": "other", + "descriptor": "unknown (auto-added; rerun recommended)", + } + ) + known.add(rid) + state["entity_roster"] = roster + state["last_present"] = present + + # Add to chunk buffer + chunk_buffer = state.get("chunk_buffer", []) + if not isinstance(chunk_buffer, list): + chunk_buffer = [] + chunk_buffer.append( + { + "window": parsed.get("window"), + "caption": parsed.get("caption", ""), + "entities_present": parsed.get("entities_present", []), + "new_entities": parsed.get("new_entities", []), + "relations": parsed.get("relations", []), + "on_screen_text": parsed.get("on_screen_text", []), + } + ) + state["chunk_buffer"] = chunk_buffer + + 
# Check if summary update is needed + if summary_interval_s > 0: + next_at = float(state.get("next_summary_at_s", summary_interval_s)) + if w_end + 1e-6 >= next_at and chunk_buffer: + return True # Need to update summary + + return False + + +def apply_summary_update( + state: dict[str, Any], summary_text: str, w_end: float, summary_interval_s: float +) -> None: + """ + Apply a summary update to the state. + + Args: + state: State dictionary (modified in place) + summary_text: New summary text + w_end: Current window end time + summary_interval_s: Summary update interval + """ + if summary_text and summary_text.strip(): + state["rolling_summary"] = summary_text.strip() + state["chunk_buffer"] = [] + + # Advance next_summary_at_s + next_at = float(state.get("next_summary_at_s", summary_interval_s)) + while next_at <= w_end + 1e-6: + next_at += float(summary_interval_s) + state["next_summary_at_s"] = next_at diff --git a/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py new file mode 100644 index 0000000000..b4af1b4020 --- /dev/null +++ b/dimos/perception/experimental/temporal_memory/test_temporal_memory_module.py @@ -0,0 +1,229 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import os +import pathlib +import tempfile +import time + +from dotenv import load_dotenv +import pytest +from reactivex import operators as ops + +from dimos import core +from dimos.core import Module, Out, rpc +from dimos.models.vl.openai import OpenAIVlModel +from dimos.msgs.sensor_msgs import Image +from dimos.perception.experimental.temporal_memory import TemporalMemory, TemporalMemoryConfig +from dimos.protocol import pubsub +from dimos.utils.data import get_data +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay + +# Load environment variables +load_dotenv() + +logger = setup_logger() + +pubsub.lcm.autoconf() + + +class VideoReplayModule(Module): + """Module that replays video data from TimedSensorReplay.""" + + video_out: Out[Image] + + def __init__(self, video_path: str) -> None: + super().__init__() + self.video_path = video_path + + @rpc + def start(self) -> None: + """Start replaying video data.""" + # Use TimedSensorReplay to replay video frames + video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy) + + # Subscribe to the replay stream and publish to LCM + self._disposables.add( + video_replay.stream() + .pipe( + ops.sample(1), # Sample every 1 second + ops.take(10), # Only take 10 frames total + ) + .subscribe(self.video_out.publish) + ) + + logger.info("VideoReplayModule started") + + @rpc + def stop(self) -> None: + """Stop replaying video data.""" + # Stop all stream transports to clean up LCM loop threads + for stream in list(self.outputs.values()): + if stream.transport is not None and hasattr(stream.transport, "stop"): + stream.transport.stop() + stream._transport = None + super().stop() + logger.info("VideoReplayModule stopped") + + +@pytest.mark.skipif(bool(os.getenv("CI")), reason="LCM replay + dataset not CI-safe.") +@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set.") +class TestTemporalMemoryModule: + 
@pytest.fixture(scope="function") + def temp_dir(self): + """Create a temporary directory for test data.""" + temp_dir = tempfile.mkdtemp(prefix="temporal_memory_test_") + yield temp_dir + + @pytest.fixture(scope="function") + def dimos_cluster(self): + """Create and cleanup Dimos cluster.""" + dimos = core.start(1) + yield dimos + dimos.close_all() + + @pytest.fixture(scope="function") + def video_module(self, dimos_cluster): + """Create and cleanup video replay module.""" + data_path = get_data("unitree_office_walk") + video_path = os.path.join(data_path, "video") + video_module = dimos_cluster.deploy(VideoReplayModule, video_path) + video_module.video_out.transport = core.LCMTransport("/test_video", Image) + yield video_module + try: + video_module.stop() + except Exception as e: + logger.warning(f"Failed to stop video_module: {e}") + + @pytest.fixture(scope="function") + def temporal_memory(self, dimos_cluster, temp_dir): + """Create and cleanup temporal memory module.""" + output_dir = os.path.join(temp_dir, "temporal_memory_output") + # Create OpenAIVlModel instance + api_key = os.getenv("OPENAI_API_KEY") + vlm = OpenAIVlModel(api_key=api_key) + + temporal_memory = dimos_cluster.deploy( + TemporalMemory, + vlm=vlm, + config=TemporalMemoryConfig( + fps=1.0, # Process 1 frame per second + window_s=2.0, # Analyze 2-second windows + stride_s=2.0, # New window every 2 seconds + summary_interval_s=10.0, # Update rolling summary every 10 seconds + max_frames_per_window=3, # Max 3 frames per window + output_dir=output_dir, + ), + ) + yield temporal_memory + try: + temporal_memory.stop() + except Exception as e: + logger.warning(f"Failed to stop temporal_memory: {e}") + + @pytest.mark.asyncio + async def test_temporal_memory_module_with_replay( + self, dimos_cluster, video_module, temporal_memory, temp_dir + ): + """Test TemporalMemory module with TimedSensorReplay inputs.""" + # Connect streams + temporal_memory.color_image.connect(video_module.video_out) + + # Start 
all modules + video_module.start() + temporal_memory.start() + logger.info("All modules started, processing in background...") + + # Wait for frames to be processed with timeout + timeout = 15.0 # 15 second timeout + start_time = time.time() + + # Keep checking state while modules are running + while (time.time() - start_time) < timeout: + state = temporal_memory.get_state() + if state["frame_count"] > 0: + logger.info( + f"Frames processing - Frame count: {state['frame_count']}, " + f"Buffer size: {state['buffer_size']}, " + f"Entity count: {state['entity_count']}" + ) + if state["frame_count"] >= 3: # Wait for at least 3 frames + break + await asyncio.sleep(0.5) + else: + # Timeout reached + state = temporal_memory.get_state() + logger.error( + f"Timeout after {timeout}s - Frame count: {state['frame_count']}, " + f"Buffer size: {state['buffer_size']}" + ) + raise AssertionError(f"No frames processed within {timeout} seconds") + + await asyncio.sleep(3) # Wait for more processing + + # Test get_state() RPC method + mid_state = temporal_memory.get_state() + logger.info( + f"Mid-test state - Frame count: {mid_state['frame_count']}, " + f"Entity count: {mid_state['entity_count']}, " + f"Recent windows: {mid_state['recent_windows']}" + ) + assert mid_state["frame_count"] >= state["frame_count"], ( + "Frame count should increase or stay same" + ) + + # Test query() RPC method + answer = temporal_memory.query("What entities are currently visible?") + logger.info(f"Query result: {answer[:200]}...") + assert len(answer) > 0, "Query should return a non-empty answer" + + # Test get_entity_roster() RPC method + entities = temporal_memory.get_entity_roster() + logger.info(f"Entity roster has {len(entities)} entities") + assert isinstance(entities, list), "Entity roster should be a list" + + # Test get_rolling_summary() RPC method + summary = temporal_memory.get_rolling_summary() + logger.info(f"Rolling summary: {summary[:200] if summary else 'empty'}...") + assert 
isinstance(summary, str), "Rolling summary should be a string" + + final_state = temporal_memory.get_state() + logger.info( + f"Final state - Frame count: {final_state['frame_count']}, " + f"Entity count: {final_state['entity_count']}, " + f"Recent windows: {final_state['recent_windows']}" + ) + + video_module.stop() + temporal_memory.stop() + logger.info("Stopped modules") + + # Wait a bit for file operations to complete + await asyncio.sleep(0.5) + + # Verify files were created - stop() already saved them + output_dir = os.path.join(temp_dir, "temporal_memory_output") + output_path = pathlib.Path(output_dir) + assert output_path.exists(), f"Output directory should exist: {output_dir}" + assert (output_path / "state.json").exists(), "state.json should exist" + assert (output_path / "entities.json").exists(), "entities.json should exist" + assert (output_path / "frames_index.jsonl").exists(), "frames_index.jsonl should exist" + + logger.info("All temporal memory module tests passed!") + + +if __name__ == "__main__": + pytest.main(["-v", "-s", __file__]) diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 05658e11c2..eebc0002a2 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -21,6 +21,7 @@ "unitree-go2-nav": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:nav", "unitree-go2-detection": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:detection", "unitree-go2-spatial": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:spatial", + "unitree-go2-temporal-memory": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:temporal_memory", "unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic", "unitree-go2-agentic-mcp": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_mcp", "unitree-go2-agentic-ollama": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_ollama", diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py 
b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py index 75211ad3a9..be53702412 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -46,6 +46,7 @@ replanning_a_star_planner, ) from dimos.perception.detection.module3D import Detection3DModule, detection3d_module +from dimos.perception.experimental.temporal_memory import temporal_memory as temporal_memory_blueprint from dimos.perception.spatial_perception import spatial_memory from dimos.protocol.mcp.mcp import MCPModule from dimos.robot.foxglove_bridge import foxglove_bridge @@ -208,3 +209,11 @@ vlm_agent(), vlm_stream_tester(), ) + +# Compose the temporal-memory blueprint on top of the agentic stack. +# The imported factory is aliased (`temporal_memory_blueprint`) so the +# composed blueprint can take the public name `temporal_memory` — which +# all_blueprints.py resolves — without shadowing the factory it calls. +temporal_memory = autoconnect( + agentic, + temporal_memory_blueprint(), +)