-
Notifications
You must be signed in to change notification settings - Fork 159
class VLMAgent(AgentSpec, Module) for streamed VLM queries over Transport #960
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1ec549c
5c6ee9f
1486f48
7ef72b7
a704e20
671c022
8f21f22
e70b8a1
f2502f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| # Copyright 2025-2026 Dimensional Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from typing import cast | ||
|
|
||
| from langchain.chat_models import init_chat_model | ||
| from langchain_core.language_models.chat_models import BaseChatModel | ||
| from langchain_core.messages import SystemMessage | ||
| from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline | ||
|
|
||
| from dimos.agents.ollama_agent import ensure_ollama_model | ||
| from dimos.agents.spec import AgentConfig | ||
| from dimos.agents.system_prompt import SYSTEM_PROMPT | ||
|
|
||
|
|
||
def build_llm(config: AgentConfig) -> BaseChatModel:
    """Construct the chat model described by *config*.

    Resolution order:
      1. If ``config.model_instance`` is set, it is returned as-is (caller
         supplied a ready-made model).
      2. For the ``ollama`` provider, the model is pulled/verified locally
         first via ``ensure_ollama_model``, then built through
         ``init_chat_model`` like any other provider.
      3. For the ``huggingface`` provider, a local text-generation pipeline
         is wrapped in ``ChatHuggingFace``.
      4. Otherwise, defer to LangChain's ``init_chat_model`` factory.

    Args:
        config: Agent configuration carrying provider, model id, and an
            optional pre-built model instance.

    Returns:
        A ``BaseChatModel`` ready for ``.invoke()``.
    """
    if config.model_instance:
        return config.model_instance

    # Normalize once; provider is an enum whose value may vary in case.
    provider = config.provider.value.lower()

    if provider == "ollama":
        # Make sure the model is available locally before LangChain tries
        # to talk to the Ollama server.
        ensure_ollama_model(config.model)

    if provider == "huggingface":
        llm = HuggingFacePipeline.from_model_id(
            model_id=config.model,
            task="text-generation",
            pipeline_kwargs={
                "max_new_tokens": 512,
                "temperature": 0.7,
            },
        )
        return ChatHuggingFace(llm=llm, model_id=config.model)

    return cast(
        "BaseChatModel",
        init_chat_model(  # type: ignore[call-overload]
            model_provider=config.provider,
            model=config.model,
        ),
    )
|
|
||
|
|
||
def build_system_message(config: AgentConfig, *, append: str = "") -> SystemMessage:
    """Build the agent's system message, optionally appending extra text.

    Uses ``config.system_prompt`` when provided (either a plain string or a
    pre-built message object), otherwise falls back to the package-level
    ``SYSTEM_PROMPT``.

    Args:
        config: Agent configuration; ``system_prompt`` may be a ``str``, a
            message object, or unset/falsy.
        append: Optional suffix appended to the prompt text.

    Returns:
        A ``SystemMessage``. Never mutates ``config.system_prompt`` — the
        previous implementation did ``config.system_prompt.content += append``
        in place, so calling this twice duplicated the suffix on the shared
        config object.
    """
    if config.system_prompt:
        if isinstance(config.system_prompt, str):
            return SystemMessage(config.system_prompt + append)
        if append:
            # Build a fresh message instead of mutating the config's prompt.
            return SystemMessage(config.system_prompt.content + append)  # type: ignore[operator]
        return config.system_prompt

    return SystemMessage(SYSTEM_PROMPT + append)
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,120 @@ | ||||||
| # Copyright 2025-2026 Dimensional Inc. | ||||||
| # | ||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
| # you may not use this file except in compliance with the License. | ||||||
| # You may obtain a copy of the License at | ||||||
| # | ||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||
| # | ||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
| # See the License for the specific language governing permissions and | ||||||
| # limitations under the License. | ||||||
|
|
||||||
| from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | ||||||
|
|
||||||
| from dimos.agents.llm_init import build_llm, build_system_message | ||||||
| from dimos.agents.spec import AgentSpec, AnyMessage | ||||||
| from dimos.core import rpc | ||||||
| from dimos.core.stream import In, Out | ||||||
| from dimos.msgs.sensor_msgs import Image | ||||||
| from dimos.utils.logging_config import setup_logger | ||||||
|
|
||||||
# Module-level logger configured by the project's logging setup helper.
logger = setup_logger()
|
|
||||||
|
|
||||||
class VLMAgent(AgentSpec):
    """Stream-first agent for vision queries with optional RPC access.

    Subscribes to an image stream and a query stream; each query is answered
    against the most recently seen image and the reply is published on
    ``answer_stream``. Synchronous one-off access is also available through
    the ``query`` / ``query_image`` RPCs.
    """

    # Latest camera frame; only the most recent frame is kept (see _on_image).
    color_image: In[Image]
    # Incoming user questions about the current image.
    query_stream: In[HumanMessage]
    # Model replies, one per query.
    answer_stream: Out[AIMessage]

    def __init__(self, *args, **kwargs) -> None:  # type: ignore[no-untyped-def]
        super().__init__(*args, **kwargs)
        # Chat model resolved from the agent config (provider/model id).
        self._llm = build_llm(self.config)
        # Most recent frame from color_image; None until the first frame arrives.
        self._latest_image: Image | None = None
        # Conversation history, excluding the system message (see history()).
        self._history: list[AIMessage | HumanMessage] = []
        self._system_message = build_system_message(self.config)
        # Publish the system message immediately at construction time.
        # NOTE(review): this happens before start() — presumably so the
        # system prompt is visible on the history topic; confirm intent.
        self.publish(self._system_message)

    @rpc
    def start(self) -> None:
        """Start the agent and wire up the input-stream subscriptions."""
        super().start()
        # Subscriptions are tracked in _disposables so stop()/teardown in the
        # base class can release them.
        self._disposables.add(self.color_image.subscribe(self._on_image))  # type: ignore[arg-type]
        self._disposables.add(self.query_stream.subscribe(self._on_query))  # type: ignore[arg-type]

    @rpc
    def stop(self) -> None:
        """Stop the agent (base class handles subscription disposal)."""
        super().stop()

    def _on_image(self, image: Image) -> None:
        # Keep only the latest frame; older frames are dropped.
        self._latest_image = image

    def _on_query(self, msg: HumanMessage) -> None:
        # Queries arriving before any image has been seen get a canned reply
        # rather than an error.
        if not self._latest_image:
            self.answer_stream.publish(AIMessage(content="No image available yet."))
            return

        query_text = self._extract_text(msg)
        response = self._invoke_image(self._latest_image, query_text)
        self.answer_stream.publish(response)

    def _extract_text(self, msg: HumanMessage) -> str:
        """Pull the plain-text portion out of a (possibly multimodal) message.

        Returns the first ``{"type": "text"}`` part for list-style content,
        the string itself for plain-string content, and ``str(content)`` as a
        last-resort fallback.
        """
        content = msg.content
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            for part in content:
                if isinstance(part, dict) and part.get("type") == "text":
                    return str(part.get("text", ""))
        return str(content)

    def _invoke(self, msg: HumanMessage) -> AIMessage:
        """Send *msg* (plus the system message) to the LLM and record history.

        Note: only the system message and the current message are sent —
        prior history is recorded but not replayed to the model.
        """
        messages = [self._system_message, msg]
        response = self._llm.invoke(messages)
        self.append_history([msg, response])  # type: ignore[arg-type]
        return response  # type: ignore[return-value]

    def _invoke_image(self, image: Image, query: str) -> AIMessage:
        """Build a multimodal message (text + encoded image) and invoke the LLM."""
        content = [{"type": "text", "text": query}, *image.agent_encode()]
        return self._invoke(HumanMessage(content=content))

    @rpc
    def clear_history(self):  # type: ignore[no-untyped-def]
        """Drop all recorded conversation history (system message is kept)."""
        self._history.clear()

    def append_history(self, *msgs: list[AIMessage | HumanMessage]) -> None:
        """Publish and record one or more lists of messages.

        Each positional argument is a list of messages; every message is
        published individually, then the whole list is appended to history.
        """
        for msg_list in msgs:
            for msg in msg_list:
                self.publish(msg)  # type: ignore[arg-type]
            self._history.extend(msg_list)

    def history(self) -> list[AnyMessage]:
        """Return the full conversation: system message followed by history."""
        return [self._system_message, *self._history]

    @rpc
    def register_skills(  # type: ignore[no-untyped-def]
        self, container, run_implicit_name: str | None = None
    ) -> None:
        # Intentional no-op: VLMAgent is run one-off and does not own a skill
        # coordinator (per PR discussion, multi-agent skill topics are not yet
        # separated, so inheriting Agent's loop would clash with tool replies).
        logger.warning(
            "VLMAgent does not manage skills; register_skills is a no-op",
            container=str(container),
            run_implicit_name=run_implicit_name,
        )

    @rpc
    def query(self, query: str):  # type: ignore[no-untyped-def]
        """Text-only RPC: ask the LLM directly, returning the reply content."""
        response = self._invoke(HumanMessage(query))
        return response.content

    @rpc
    def query_image(self, image: Image, query: str):  # type: ignore[no-untyped-def]
        """RPC: ask the LLM about a caller-supplied image (not the stream's)."""
        response = self._invoke_image(image, query)
        return response.content
|
|
||||||
|
|
||||||
# Blueprint handle for wiring a VLMAgent into a launch/graph configuration.
vlm_agent = VLMAgent.blueprint

__all__ = ["VLMAgent", "vlm_agent"]
Uh oh!
There was an error while loading. Please reload this page.