diff --git a/dimos/agents/__init__.py b/dimos/agents/__init__.py index 9e1dd2df77..2bac584249 100644 --- a/dimos/agents/__init__.py +++ b/dimos/agents/__init__.py @@ -9,7 +9,19 @@ from dimos.agents.agent import Agent, deploy from dimos.agents.spec import AgentSpec +from dimos.agents.vlm_agent import VLMAgent +from dimos.agents.vlm_stream_tester import VlmStreamTester from dimos.protocol.skill.skill import skill from dimos.protocol.skill.type import Output, Reducer, Stream -__all__ = ["Agent", "AgentSpec", "Output", "Reducer", "Stream", "deploy", "skill"] +__all__ = [ + "Agent", + "AgentSpec", + "Output", + "Reducer", + "Stream", + "VLMAgent", + "VlmStreamTester", + "deploy", + "skill", +] diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index bf5ded4f00..e9c7c5d7b9 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -19,7 +19,6 @@ from typing import Any, TypedDict import uuid -from langchain.chat_models import init_chat_model from langchain_core.messages import ( AIMessage, HumanMessage, @@ -27,11 +26,9 @@ ToolCall, ToolMessage, ) -from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline -from dimos.agents.ollama_agent import ensure_ollama_model +from dimos.agents.llm_init import build_llm, build_system_message from dimos.agents.spec import AgentSpec, Model, Provider -from dimos.agents.system_prompt import SYSTEM_PROMPT from dimos.core import DimosCluster, rpc from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateDict from dimos.protocol.skill.skill import SkillContainer @@ -175,40 +172,9 @@ def __init__( # type: ignore[no-untyped-def] self._agent_id = str(uuid.uuid4()) self._agent_stopped = False - if self.config.system_prompt: - if isinstance(self.config.system_prompt, str): - self.system_message = SystemMessage(self.config.system_prompt + SYSTEM_MSG_APPEND) - else: - self.config.system_prompt.content += SYSTEM_MSG_APPEND # type: ignore[operator] - self.system_message = self.config.system_prompt - else: - self.system_message = SystemMessage(SYSTEM_PROMPT + SYSTEM_MSG_APPEND) - + self.system_message = build_system_message(self.config, append=SYSTEM_MSG_APPEND) self.publish(self.system_message) - - # Use provided model instance if available, otherwise initialize from config - if self.config.model_instance: - self._llm = self.config.model_instance - else: - # For Ollama provider, ensure the model is available before initializing - if self.config.provider.value.lower() == "ollama": - ensure_ollama_model(self.config.model) - - # For HuggingFace, we need to create a pipeline and wrap it in ChatHuggingFace - if self.config.provider.value.lower() == "huggingface": - llm = HuggingFacePipeline.from_model_id( - model_id=self.config.model, - task="text-generation", - pipeline_kwargs={ - "max_new_tokens": 512, - "temperature": 0.7, - }, - ) - self._llm = ChatHuggingFace(llm=llm, model_id=self.config.model) - else: - self._llm = init_chat_model( # type: ignore[call-overload] - model_provider=self.config.provider, model=self.config.model - ) + self._llm = build_llm(self.config) @rpc def get_agent_id(self) -> str: diff --git a/dimos/agents/llm_init.py b/dimos/agents/llm_init.py new file mode 100644 index 0000000000..eb8c33c631 --- /dev/null +++ b/dimos/agents/llm_init.py @@ -0,0 +1,62 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import cast + +from langchain.chat_models import init_chat_model +from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import SystemMessage +from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline + +from dimos.agents.ollama_agent import ensure_ollama_model +from dimos.agents.spec import AgentConfig +from dimos.agents.system_prompt import SYSTEM_PROMPT + + +def build_llm(config: AgentConfig) -> BaseChatModel: + if config.model_instance: + return config.model_instance + + if config.provider.value.lower() == "ollama": + ensure_ollama_model(config.model) + + if config.provider.value.lower() == "huggingface": + llm = HuggingFacePipeline.from_model_id( + model_id=config.model, + task="text-generation", + pipeline_kwargs={ + "max_new_tokens": 512, + "temperature": 0.7, + }, + ) + return ChatHuggingFace(llm=llm, model_id=config.model) + + return cast( + "BaseChatModel", + init_chat_model( # type: ignore[call-overload] + model_provider=config.provider, + model=config.model, + ), + ) + + +def build_system_message(config: AgentConfig, *, append: str = "") -> SystemMessage: + if config.system_prompt: + if isinstance(config.system_prompt, str): + return SystemMessage(config.system_prompt + append) + if append: + config.system_prompt.content += append # type: ignore[operator] + return config.system_prompt + + return SystemMessage(SYSTEM_PROMPT + append) diff --git a/dimos/agents/vlm_agent.py b/dimos/agents/vlm_agent.py new file mode 100644 index 0000000000..0757a59d22 --- /dev/null +++ b/dimos/agents/vlm_agent.py @@ -0,0 +1,120 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage + +from dimos.agents.llm_init import build_llm, build_system_message +from dimos.agents.spec import AgentSpec, AnyMessage +from dimos.core import rpc +from dimos.core.stream import In, Out +from dimos.msgs.sensor_msgs import Image +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class VLMAgent(AgentSpec): + """Stream-first agent for vision queries with optional RPC access.""" + + color_image: In[Image] + query_stream: In[HumanMessage] + answer_stream: Out[AIMessage] + + def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] + super().__init__(*args, **kwargs) + self._llm = build_llm(self.config) + self._latest_image: Image | None = None + self._history: list[AIMessage | HumanMessage] = [] + self._system_message = build_system_message(self.config) + self.publish(self._system_message) + + @rpc + def start(self) -> None: + super().start() + self._disposables.add(self.color_image.subscribe(self._on_image)) # type: ignore[arg-type] + self._disposables.add(self.query_stream.subscribe(self._on_query)) # type: ignore[arg-type] + + @rpc + def stop(self) -> None: + super().stop() + + def _on_image(self, image: Image) -> None: + self._latest_image = image + + def _on_query(self, msg: HumanMessage) -> None: + if not self._latest_image: + self.answer_stream.publish(AIMessage(content="No image available yet.")) + return + + query_text = self._extract_text(msg) + response = self._invoke_image(self._latest_image, query_text) + self.answer_stream.publish(response) + + def _extract_text(self, msg: HumanMessage) -> str: + content = msg.content + if isinstance(content, str): + return content + if isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + return str(part.get("text", "")) + return str(content) + + def _invoke(self, msg: HumanMessage) -> AIMessage: + messages = [self._system_message, msg] + response = self._llm.invoke(messages) + self.append_history([msg, response]) # type: ignore[arg-type] + return response # type: ignore[return-value] + + def _invoke_image(self, image: Image, query: str) -> AIMessage: + content = [{"type": "text", "text": query}, *image.agent_encode()] + return self._invoke(HumanMessage(content=content)) + + @rpc + def clear_history(self): # type: ignore[no-untyped-def] + self._history.clear() + + def append_history(self, *msgs: list[AIMessage | HumanMessage]) -> None: + for msg_list in msgs: + for msg in msg_list: + self.publish(msg) # type: ignore[arg-type] + self._history.extend(msg_list) + + def history(self) -> list[AnyMessage]: + return [self._system_message, *self._history] + + @rpc + def register_skills( # type: ignore[no-untyped-def] + self, container, run_implicit_name: str | None = None + ) -> None: + logger.warning( + "VLMAgent does not manage skills; register_skills is a no-op", + container=str(container), + run_implicit_name=run_implicit_name, + ) + + @rpc + def query(self, query: str): # type: ignore[no-untyped-def] + response = self._invoke(HumanMessage(query)) + return response.content + + @rpc + def query_image(self, image: Image, query: str): # type: ignore[no-untyped-def] + response = self._invoke_image(image, query) + return response.content + + +vlm_agent = VLMAgent.blueprint + +__all__ = ["VLMAgent", "vlm_agent"] diff --git a/dimos/agents/vlm_stream_tester.py b/dimos/agents/vlm_stream_tester.py new file mode 100644 index 0000000000..79bb802a03 --- /dev/null +++ b/dimos/agents/vlm_stream_tester.py @@ -0,0 +1,179 @@ +# Copyright 2025-2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time + +from langchain_core.messages import AIMessage, HumanMessage + +from dimos.core import Module, rpc +from dimos.core.stream import In, Out +from dimos.msgs.sensor_msgs import Image +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + + +class VlmStreamTester(Module): + """Smoke-test VLMAgent with replayed images and stream queries.""" + + color_image: In[Image] + query_stream: Out[HumanMessage] + answer_stream: In[AIMessage] + + rpc_calls: list[str] = [ + "VLMAgent.query_image", + ] + + def __init__( # type: ignore[no-untyped-def] + self, + prompt: str = "What do you see?", + num_queries: int = 10, + query_interval_s: float = 2.0, + max_image_age_s: float = 1.5, + max_image_gap_s: float = 1.5, + ) -> None: + super().__init__() + self._prompt = prompt + self._num_queries = num_queries + self._query_interval_s = query_interval_s + self._max_image_age_s = max_image_age_s + self._max_image_gap_s = max_image_gap_s + self._latest_image: Image | None = None + self._latest_image_wall_ts: float | None = None + self._last_image_wall_ts: float | None = None + self._max_gap_seen_s = 0.0 + self._answer_count = 0 + self._stop_event = threading.Event() + self._worker: threading.Thread | None = None + + @rpc + def start(self) -> None: + super().start() + self._disposables.add(self.color_image.subscribe(self._on_image)) # type: ignore[arg-type] + self._disposables.add(self.answer_stream.subscribe(self._on_answer)) # type: ignore[arg-type] + self._worker = threading.Thread(target=self._run_queries, daemon=True) + self._worker.start() + + @rpc + def stop(self) -> None: + self._stop_event.set() + if self._worker and self._worker.is_alive(): + self._worker.join(timeout=1.0) + super().stop() + + def _on_image(self, image: Image) -> None: + now = time.time() + if self._last_image_wall_ts is not None: + gap = now - self._last_image_wall_ts + if gap > self._max_gap_seen_s: + self._max_gap_seen_s = gap + self._last_image_wall_ts = now + self._latest_image_wall_ts = now + self._latest_image = image + + def _on_answer(self, msg: AIMessage) -> None: + self._answer_count += 1 + logger.info( + "VLMAgent stream answer", + count=self._answer_count, + content=msg.content, + ) + + def _run_queries(self) -> None: + try: + while not self._stop_event.is_set() and self._latest_image is None: + time.sleep(0.05) + + self._run_stream_queries() + self._run_rpc_queries() + except Exception as exc: + logger.exception("VlmStreamTester query loop failed", error=str(exc)) + finally: + if self._max_gap_seen_s > self._max_image_gap_s: + logger.warning( + "Image stream gap exceeded threshold", + max_gap_s=self._max_gap_seen_s, + threshold_s=self._max_image_gap_s, + ) + + def _run_stream_queries(self) -> None: + for idx in range(self._num_queries): + if self._stop_event.is_set(): + break + if self._latest_image is None: + logger.warning("No image available for stream query.") + break + + image_age = None + if self._latest_image_wall_ts is not None: + image_age = time.time() - self._latest_image_wall_ts + if image_age > self._max_image_age_s: + logger.warning( + "Latest image is stale", + age_s=image_age, + max_age_s=self._max_image_age_s, + ) + + logger.info("Sending stream query", index=idx + 1, total=self._num_queries) + self.query_stream.publish( + HumanMessage(content=f"{self._prompt} (stream query {idx + 1}/{self._num_queries})") + ) + time.sleep(self._query_interval_s) + + def _run_rpc_queries(self) -> None: + rpc_query = None + try: + rpc_query = self.get_rpc_calls("VLMAgent.query_image") + except Exception as exc: + logger.warning("RPC query_image lookup failed", error=str(exc)) + return + + for idx in range(self._num_queries): + if self._stop_event.is_set(): + break + if self._latest_image is None: + logger.warning("No image available for RPC query.") + break + + image_age = None + if self._latest_image_wall_ts is not None: + image_age = time.time() - self._latest_image_wall_ts + if image_age > self._max_image_age_s: + logger.warning( + "Latest image is stale", + age_s=image_age, + max_age_s=self._max_image_age_s, + ) + + logger.info("Sending RPC query", index=idx + 1, total=self._num_queries) + try: + response = rpc_query( + self._latest_image, + f"{self._prompt} (rpc query {idx + 1}/{self._num_queries})", + ) + logger.info( + "VLMAgent RPC answer", + query_index=idx + 1, + image_age_s=image_age, + content=response, + ) + except Exception as exc: + logger.warning("RPC query_image failed", error=str(exc)) + time.sleep(self._query_interval_s) + + +vlm_stream_tester = VlmStreamTester.blueprint + +__all__ = ["VlmStreamTester", "vlm_stream_tester"] diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 9b118cbd60..f989098f05 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -25,6 +25,7 @@ "unitree-go2-agentic-mcp": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_mcp", "unitree-go2-agentic-ollama": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_ollama", "unitree-go2-agentic-huggingface": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_huggingface", + "unitree-go2-vlm-stream-test": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:vlm_stream_test", "unitree-g1": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard", "unitree-g1-sim": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard_sim", "unitree-g1-basic": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic_ros", diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py index 46d951650c..7629644ed6 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -27,6 +27,8 @@ from dimos.agents.skills.navigation import navigation_skill from dimos.agents.skills.speak_skill import speak_skill from dimos.agents.spec import Provider +from dimos.agents.vlm_agent import vlm_agent +from dimos.agents.vlm_stream_tester import vlm_stream_tester from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.blueprints import autoconnect from dimos.core.transport import JpegLcmTransport, JpegShmTransport, LCMTransport, pSHMTransport @@ -190,3 +192,9 @@ ), _common_agentic, ) + +vlm_stream_test = autoconnect( + basic, + vlm_agent(), + vlm_stream_tester(), +)