diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index d902dba3da..7c37febec0 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -27,7 +27,9 @@ ToolCall, ToolMessage, ) +from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline +from dimos.agents2.ollama_agent import ensure_ollama_model from dimos.agents2.spec import AgentSpec, Model, Provider from dimos.agents2.system_prompt import get_system_prompt from dimos.core import DimosCluster, rpc @@ -192,9 +194,25 @@ def __init__( # type: ignore[no-untyped-def] if self.config.model_instance: self._llm = self.config.model_instance else: - self._llm = init_chat_model( # type: ignore[call-overload] - model_provider=self.config.provider, model=self.config.model - ) + # For Ollama provider, ensure the model is available before initializing + if self.config.provider.value.lower() == "ollama": + ensure_ollama_model(self.config.model) + + # For HuggingFace, we need to create a pipeline and wrap it in ChatHuggingFace + if self.config.provider.value.lower() == "huggingface": + llm = HuggingFacePipeline.from_model_id( + model_id=self.config.model, + task="text-generation", + pipeline_kwargs={ + "max_new_tokens": 512, + "temperature": 0.7, + }, + ) + self._llm = ChatHuggingFace(llm=llm, model_id=self.config.model) + else: + self._llm = init_chat_model( # type: ignore[call-overload] + model_provider=self.config.provider, model=self.config.model + ) @rpc def get_agent_id(self) -> str: @@ -278,7 +296,19 @@ def _get_state() -> str: # history() builds our message history dynamically # ensures we include latest system state, but not old ones. - msg = self._llm.invoke(self.history()) # type: ignore[no-untyped-call] + messages = self.history() # type: ignore[no-untyped-call] + + # Some LLMs don't work without any human messages. Add an initial one. + if len(messages) == 1 and isinstance(messages[0], SystemMessage): + messages.append( + HumanMessage( + "Everything is initialized. I'll let you know when you should act." + ) + ) + self.append_history(messages[-1]) + + msg = self._llm.invoke(messages) + self.append_history(msg) # type: ignore[arg-type] logger.info(f"Agent response: {msg.content}") diff --git a/dimos/agents2/ollama_agent.py b/dimos/agents2/ollama_agent.py new file mode 100644 index 0000000000..26179d7418 --- /dev/null +++ b/dimos/agents2/ollama_agent.py @@ -0,0 +1,39 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ollama + +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(__file__) + + +def ensure_ollama_model(model_name: str) -> None: + available_models = ollama.list() + model_exists = any(model_name == m.model for m in available_models.models) + if not model_exists: + logger.info(f"Ollama model '{model_name}' not found. Pulling...") + ollama.pull(model_name) + + +def ollama_installed() -> str | None: + try: + ollama.list() + return None + except Exception: + return ( + "Cannot connect to Ollama daemon. Please ensure Ollama is installed and running.\n" + "\n" + " For installation instructions, visit https://ollama.com/download" + ) diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 2512da4a9d..525c2de42a 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -14,11 +14,12 @@ from abc import ABC from collections import defaultdict -from collections.abc import Mapping +from collections.abc import Callable, Mapping from dataclasses import dataclass, field from functools import cached_property, reduce import inspect import operator +import sys from types import MappingProxyType from typing import Any, Literal, get_args, get_origin @@ -56,6 +57,7 @@ class ModuleBlueprintSet: remapping_map: Mapping[tuple[type[Module], str], str] = field( default_factory=lambda: MappingProxyType({}) ) + requirement_checks: tuple[Callable[[], str | None], ...] = field(default_factory=tuple) def transports(self, transports: dict[tuple[str, type], Any]) -> "ModuleBlueprintSet": return ModuleBlueprintSet( @@ -63,6 +65,7 @@ def transports(self, transports: dict[tuple[str, type], Any]) -> "ModuleBlueprin transport_map=MappingProxyType({**self.transport_map, **transports}), global_config_overrides=self.global_config_overrides, remapping_map=self.remapping_map, + requirement_checks=self.requirement_checks, ) def global_config(self, **kwargs: Any) -> "ModuleBlueprintSet": @@ -71,6 +74,7 @@ def global_config(self, **kwargs: Any) -> "ModuleBlueprintSet": transport_map=self.transport_map, global_config_overrides=MappingProxyType({**self.global_config_overrides, **kwargs}), remapping_map=self.remapping_map, + requirement_checks=self.requirement_checks, ) def remappings(self, remappings: list[tuple[type[Module], str, str]]) -> "ModuleBlueprintSet": @@ -83,6 +87,16 @@ def remappings(self, remappings: list[tuple[type[Module], str, str]]) -> "Module transport_map=self.transport_map, global_config_overrides=self.global_config_overrides, remapping_map=MappingProxyType(remappings_dict), + requirement_checks=self.requirement_checks, + ) + + def requirements(self, *checks: Callable[[], str | None]) -> "ModuleBlueprintSet": + return ModuleBlueprintSet( + blueprints=self.blueprints, + transport_map=self.transport_map, + global_config_overrides=self.global_config_overrides, + remapping_map=self.remapping_map, + requirement_checks=self.requirement_checks + tuple(checks), ) def _get_transport_for(self, name: str, type: type) -> Any: @@ -110,6 +124,21 @@ def _all_name_types(self) -> set[tuple[str, type]]: def _is_name_unique(self, name: str) -> bool: return sum(1 for n, _ in self._all_name_types if n == name) == 1 + def _check_requirements(self) -> None: + errors = [] + red = "\033[31m" + reset = "\033[0m" + + for check in self.requirement_checks: + error = check() + if error: + errors.append(error) + + if errors: + for error in errors: + print(f"{red}Error: {error}{reset}", file=sys.stderr) + sys.exit(1) + def _verify_no_name_conflicts(self) -> None: name_to_types = defaultdict(set) name_to_modules = defaultdict(list) @@ -244,6 +273,7 @@ def build(self, global_config: GlobalConfig | None = None) -> ModuleCoordinator: global_config = GlobalConfig() global_config = global_config.model_copy(update=self.global_config_overrides) + self._check_requirements() self._verify_no_name_conflicts() module_coordinator = ModuleCoordinator(global_config=global_config) @@ -295,12 +325,14 @@ def autoconnect(*blueprints: ModuleBlueprintSet) -> ModuleBlueprintSet: all_remappings = dict( # type: ignore[var-annotated] reduce(operator.iadd, [list(x.remapping_map.items()) for x in blueprints], []) ) + all_requirement_checks = tuple(check for bs in blueprints for check in bs.requirement_checks) return ModuleBlueprintSet( blueprints=all_blueprints, transport_map=MappingProxyType(all_transports), global_config_overrides=MappingProxyType(all_config_overrides), remapping_map=MappingProxyType(all_remappings), + requirement_checks=all_requirement_checks, ) diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 0838ce66fc..89f5c0e5d3 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -22,6 +22,8 @@ "unitree-go2-jpegshm": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard_with_jpegshm", "unitree-go2-jpeglcm": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard_with_jpeglcm", "unitree-go2-agentic": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic", + "unitree-go2-agentic-ollama": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_ollama", + "unitree-go2-agentic-huggingface": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:agentic_huggingface", "unitree-g1": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard", "unitree-g1-bt-nav": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:standard_bt_nav", "unitree-g1-basic": "dimos.robot.unitree_webrtc.unitree_g1_blueprints:basic_ros", diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py index d2f006848a..e1383ae0f5 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -18,7 +18,9 @@ from dimos.agents2.agent import llm_agent from dimos.agents2.cli.human import human_input +from dimos.agents2.ollama_agent import ollama_installed from dimos.agents2.skills.navigation import navigation_skill +from dimos.agents2.spec import Provider from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core.blueprints import autoconnect from dimos.core.transport import JpegLcmTransport, JpegShmTransport, LCMTransport, pSHMTransport @@ -107,10 +109,34 @@ ), ) -agentic = autoconnect( - standard, - llm_agent(), +_common_agentic = autoconnect( human_input(), navigation_skill(), unitree_skills(), ) + +agentic = autoconnect( + standard, + llm_agent(), + _common_agentic, +) + +agentic_ollama = autoconnect( + standard, + llm_agent( + model="qwen3:8b", + provider=Provider.OLLAMA, # type: ignore[attr-defined] + ), + _common_agentic, +).requirements( + ollama_installed, +) + +agentic_huggingface = autoconnect( + standard, + llm_agent( + model="Qwen/Qwen2.5-1.5B-Instruct", + provider=Provider.HUGGINGFACE, # type: ignore[attr-defined] + ), + _common_agentic, +) diff --git a/pyproject.toml b/pyproject.toml index f43da29e61..2a5eb6f5ff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,10 @@ dependencies = [ "langchain-core>=0.3.72", "langchain-openai>=0.3.28", "langchain-text-splitters>=0.3.9", + "langchain-huggingface>=0.3.1", + "langchain-ollama>=0.3.10", + "bitsandbytes>=0.48.2,<1.0", + "ollama>=0.6.0", # Class Extraction "pydantic",