Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
1 change: 0 additions & 1 deletion bin/mypy-strict
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ run_mypy() {
export MYPYPATH=/opt/ros/jazzy/lib/python3.12/site-packages

mypy_args=(
--config-file mypy_strict.ini
--show-error-codes
--hide-error-context
--no-pretty
Expand Down
165 changes: 86 additions & 79 deletions dimos/agents/agent.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion dimos/agents/agent_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def add_image(self, image: Image | AgentImage) -> None:
if isinstance(image, Image):
# Convert to AgentImage
agent_image = AgentImage(
base64_jpeg=image.agent_encode(),
base64_jpeg=image.agent_encode(), # type: ignore[arg-type]
width=image.width,
height=image.height,
metadata={"format": image.format.value, "frame_id": image.frame_id},
Expand Down
6 changes: 3 additions & 3 deletions dimos/agents/agent_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,15 @@ def to_openai_format(self) -> dict[str, Any]:
msg["content"] = self.content
else:
# Content is already a list of content blocks
msg["content"] = self.content
msg["content"] = self.content # type: ignore[assignment]

# Add tool calls if present
if self.tool_calls:
# Handle both ToolCall objects and dicts
if isinstance(self.tool_calls[0], dict):
msg["tool_calls"] = self.tool_calls
msg["tool_calls"] = self.tool_calls # type: ignore[assignment]
else:
msg["tool_calls"] = [
msg["tool_calls"] = [ # type: ignore[assignment]
{
"id": tc.id,
"type": "function",
Expand Down
44 changes: 22 additions & 22 deletions dimos/agents/claude_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

# Response object compatible with LLMAgent
class ResponseMessage:
def __init__(self, content: str = "", tool_calls=None, thinking_blocks=None) -> None:
def __init__(self, content: str = "", tool_calls=None, thinking_blocks=None) -> None: # type: ignore[no-untyped-def]
self.content = content
self.tool_calls = tool_calls or []
self.thinking_blocks = thinking_blocks or []
Expand Down Expand Up @@ -85,9 +85,9 @@ def __init__(
dev_name: str,
agent_type: str = "Vision",
query: str = "What do you see?",
input_query_stream: Observable | None = None,
input_video_stream: Observable | None = None,
input_data_stream: Observable | None = None,
input_query_stream: Observable | None = None, # type: ignore[type-arg]
input_video_stream: Observable | None = None, # type: ignore[type-arg]
input_data_stream: Observable | None = None, # type: ignore[type-arg]
output_dir: str = os.path.join(os.getcwd(), "assets", "agent"),
agent_memory: AbstractAgentSemanticMemory | None = None,
system_query: str | None = None,
Expand Down Expand Up @@ -158,7 +158,7 @@ def __init__(

# Claude-specific parameters
self.thinking_budget_tokens = thinking_budget_tokens
self.claude_api_params = {} # Will store params for Claude API calls
self.claude_api_params = {} # type: ignore[var-annotated] # Will store params for Claude API calls

# Configure skills
self.skills = skills
Expand Down Expand Up @@ -217,7 +217,7 @@ def _add_context_to_memory(self) -> None:
),
]
for doc_id, text in context_data:
self.agent_memory.add_vector(doc_id, text)
self.agent_memory.add_vector(doc_id, text) # type: ignore[no-untyped-call]

def _convert_tools_to_claude_format(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Expand Down Expand Up @@ -258,15 +258,15 @@ def _convert_tools_to_claude_format(self, tools: list[dict[str, Any]]) -> list[d

return claude_tools

def _build_prompt(
def _build_prompt( # type: ignore[override]
self,
messages: list,
messages: list, # type: ignore[type-arg]
base64_image: str | list[str] | None = None,
dimensions: tuple[int, int] | None = None,
override_token_limit: bool = False,
rag_results: str = "",
thinking_budget_tokens: int | None = None,
) -> list:
) -> list: # type: ignore[type-arg]
"""Builds a prompt message specifically for Claude API, using local messages copy."""
"""Builds a prompt message specifically for Claude API.

Expand Down Expand Up @@ -347,9 +347,9 @@ def _build_prompt(

# Store the parameters for use in _send_query and return them
self.claude_api_params = claude_params.copy()
return messages, claude_params
return messages, claude_params # type: ignore[return-value]

def _send_query(self, messages: list, claude_params: dict) -> Any:
def _send_query(self, messages: list, claude_params: dict) -> Any: # type: ignore[override, type-arg]
"""Sends the query to Anthropic's API using streaming for better thinking visualization.

Args:
Expand Down Expand Up @@ -397,7 +397,7 @@ def _send_query(self, messages: list, claude_params: dict) -> Any:
block_type = event.content_block.type
current_block = {
"type": block_type,
"id": event.index,
"id": event.index, # type: ignore[dict-item]
"content": "",
"signature": None,
}
Expand All @@ -413,7 +413,7 @@ def _send_query(self, messages: list, claude_params: dict) -> Any:
elif event.delta.type == "text_delta":
# Accumulate text content
text_content += event.delta.text
current_block["content"] += event.delta.text
current_block["content"] += event.delta.text # type: ignore[operator]
memory_file.write(f"{event.delta.text}")
memory_file.flush()

Expand Down Expand Up @@ -463,9 +463,9 @@ def _send_query(self, messages: list, claude_params: dict) -> Any:
# Process tool use blocks when they're complete
if hasattr(event, "content_block"):
tool_block = event.content_block
tool_id = tool_block.id
tool_name = tool_block.name
tool_input = tool_block.input
tool_id = tool_block.id # type: ignore[union-attr]
tool_name = tool_block.name # type: ignore[union-attr]
tool_input = tool_block.input # type: ignore[union-attr]

# Create a tool call object for LLMAgent compatibility
tool_call_obj = type(
Expand Down Expand Up @@ -537,7 +537,7 @@ def _send_query(self, messages: list, claude_params: dict) -> Any:

def _observable_query(
self,
observer: Observer,
observer: Observer, # type: ignore[name-defined]
base64_image: str | None = None,
dimensions: tuple[int, int] | None = None,
override_token_limit: bool = False,
Expand Down Expand Up @@ -605,7 +605,7 @@ def _observable_query(

# Handle tool calls if present
if response_message.tool_calls:
self._handle_tooling(response_message, messages)
self._handle_tooling(response_message, messages) # type: ignore[no-untyped-call]

# At the end, append only new messages (including tool-use/results) to the global conversation history under a lock
import threading
Expand Down Expand Up @@ -633,7 +633,7 @@ def _observable_query(
self.response_subject.on_next(error_message)
observer.on_completed()

def _handle_tooling(self, response_message, messages):
def _handle_tooling(self, response_message, messages): # type: ignore[no-untyped-def]
"""Executes tools and appends tool-use/result blocks to messages."""
if not hasattr(response_message, "tool_calls") or not response_message.tool_calls:
logger.info("No tool calls found in response message")
Expand All @@ -658,7 +658,7 @@ def _handle_tooling(self, response_message, messages):
try:
# Execute the tool
args = json.loads(tool_call.function.arguments)
tool_result = self.skills.call(tool_call.function.name, **args)
tool_result = self.skills.call(tool_call.function.name, **args) # type: ignore[union-attr]

# Check if the result is an error message
if isinstance(tool_result, str) and (
Expand Down Expand Up @@ -698,7 +698,7 @@ def _handle_tooling(self, response_message, messages):
}
)

def _tooling_callback(self, response_message) -> None:
def _tooling_callback(self, response_message) -> None: # type: ignore[no-untyped-def]
"""Runs the observable query for each tool call in the current response_message"""
if not hasattr(response_message, "tool_calls") or not response_message.tool_calls:
return
Expand All @@ -716,7 +716,7 @@ def _tooling_callback(self, response_message) -> None:
# Continue processing even if the callback fails
pass

def _debug_api_call(self, claude_params: dict):
def _debug_api_call(self, claude_params: dict): # type: ignore[no-untyped-def, type-arg]
"""Debugging function to log API calls with truncated base64 data."""
# Remove tools to reduce verbosity
import copy
Expand Down
20 changes: 10 additions & 10 deletions dimos/agents/memory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


class AbstractAgentSemanticMemory: # AbstractAgentMemory):
def __init__(self, connection_type: str = "local", **kwargs) -> None:
def __init__(self, connection_type: str = "local", **kwargs) -> None: # type: ignore[no-untyped-def]
"""
Initialize with dynamic connection parameters.
Args:
Expand All @@ -53,26 +53,26 @@ def __init__(self, connection_type: str = "local", **kwargs) -> None:

try:
if connection_type == "remote":
self.connect()
self.connect() # type: ignore[no-untyped-call]
elif connection_type == "local":
self.create()
self.create() # type: ignore[no-untyped-call]
except Exception as e:
self.logger.error("Failed to initialize database connection: %s", str(e), exc_info=True)
raise AgentMemoryConnectionError(
"Initialization failed due to an unexpected error.", cause=e
) from e

@abstractmethod
def connect(self):
def connect(self): # type: ignore[no-untyped-def]
"""Establish a connection to the data store using dynamic parameters specified during initialization."""

@abstractmethod
def create(self):
def create(self): # type: ignore[no-untyped-def]
"""Create a local instance of the data store tailored to specific requirements."""

## Create ##
@abstractmethod
def add_vector(self, vector_id, vector_data):
def add_vector(self, vector_id, vector_data): # type: ignore[no-untyped-def]
"""Add a vector to the database.
Args:
vector_id (any): Unique identifier for the vector.
Expand All @@ -81,14 +81,14 @@ def add_vector(self, vector_id, vector_data):

## Read ##
@abstractmethod
def get_vector(self, vector_id):
def get_vector(self, vector_id): # type: ignore[no-untyped-def]
"""Retrieve a vector from the database by its identifier.
Args:
vector_id (any): The identifier of the vector to retrieve.
"""

@abstractmethod
def query(self, query_texts, n_results: int = 4, similarity_threshold=None):
def query(self, query_texts, n_results: int = 4, similarity_threshold=None): # type: ignore[no-untyped-def]
"""Performs a semantic search in the vector database.

Args:
Expand All @@ -109,7 +109,7 @@ def query(self, query_texts, n_results: int = 4, similarity_threshold=None):

## Update ##
@abstractmethod
def update_vector(self, vector_id, new_vector_data):
def update_vector(self, vector_id, new_vector_data): # type: ignore[no-untyped-def]
"""Update an existing vector in the database.
Args:
vector_id (any): The identifier of the vector to update.
Expand All @@ -118,7 +118,7 @@ def update_vector(self, vector_id, new_vector_data):

## Delete ##
@abstractmethod
def delete_vector(self, vector_id):
def delete_vector(self, vector_id): # type: ignore[no-untyped-def]
"""Delete a vector from the database using its identifier.
Args:
vector_id (any): The identifier of the vector to delete.
Expand Down
40 changes: 20 additions & 20 deletions dimos/agents/memory/chroma_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ def __init__(self, collection_name: str = "my_collection") -> None:
self.embeddings = None
super().__init__(connection_type="local")

def connect(self):
def connect(self): # type: ignore[no-untyped-def]
# Stub
return super().connect()
return super().connect() # type: ignore[no-untyped-call, safe-super]

def create(self):
def create(self): # type: ignore[no-untyped-def]
"""Create the embedding function and initialize the Chroma database.
This method must be implemented by child classes."""
raise NotImplementedError("Child classes must implement this method")

def add_vector(self, vector_id, vector_data):
def add_vector(self, vector_id, vector_data): # type: ignore[no-untyped-def]
"""Add a vector to the ChromaDB collection."""
if not self.db_connection:
raise Exception("Collection not initialized. Call connect() first.")
Expand All @@ -51,12 +51,12 @@ def add_vector(self, vector_id, vector_data):
metadatas=[{"name": vector_id}],
)

def get_vector(self, vector_id):
def get_vector(self, vector_id): # type: ignore[no-untyped-def]
"""Retrieve a vector from the ChromaDB by its identifier."""
result = self.db_connection.get(include=["embeddings"], ids=[vector_id])
result = self.db_connection.get(include=["embeddings"], ids=[vector_id]) # type: ignore[attr-defined]
return result

def query(self, query_texts, n_results: int = 4, similarity_threshold=None):
def query(self, query_texts, n_results: int = 4, similarity_threshold=None): # type: ignore[no-untyped-def]
"""Query the collection with a specific text and return up to n results."""
if not self.db_connection:
raise Exception("Collection not initialized. Call connect() first.")
Expand All @@ -71,11 +71,11 @@ def query(self, query_texts, n_results: int = 4, similarity_threshold=None):
documents = self.db_connection.similarity_search(query=query_texts, k=n_results)
return [(doc, None) for doc in documents]

def update_vector(self, vector_id, new_vector_data):
def update_vector(self, vector_id, new_vector_data): # type: ignore[no-untyped-def]
# TODO
return super().connect()
return super().connect() # type: ignore[no-untyped-call, safe-super]

def delete_vector(self, vector_id):
def delete_vector(self, vector_id): # type: ignore[no-untyped-def]
"""Delete a vector from the ChromaDB using its identifier."""
if not self.db_connection:
raise Exception("Collection not initialized. Call connect() first.")
Expand All @@ -102,22 +102,22 @@ def __init__(
self.dimensions = dimensions
super().__init__(collection_name=collection_name)

def create(self):
def create(self): # type: ignore[no-untyped-def]
"""Connect to OpenAI API and create the ChromaDB client."""
# Get OpenAI key
self.OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not self.OPENAI_API_KEY:
raise Exception("OpenAI key was not specified.")

# Set embeddings
self.embeddings = OpenAIEmbeddings(
self.embeddings = OpenAIEmbeddings( # type: ignore[assignment]
model=self.model,
dimensions=self.dimensions,
api_key=self.OPENAI_API_KEY,
api_key=self.OPENAI_API_KEY, # type: ignore[arg-type]
)

# Create the database
self.db_connection = Chroma(
self.db_connection = Chroma( # type: ignore[assignment]
collection_name=self.collection_name,
embedding_function=self.embeddings,
collection_metadata={"hnsw:space": "cosine"},
Expand Down Expand Up @@ -148,26 +148,26 @@ def create(self) -> None:
# Use CUDA if available, otherwise fall back to CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
self.model = SentenceTransformer(self.model_name, device=device)
self.model = SentenceTransformer(self.model_name, device=device) # type: ignore[name-defined]

# Create a custom embedding class that implements the embed_query method
class SentenceTransformerEmbeddings:
def __init__(self, model) -> None:
def __init__(self, model) -> None: # type: ignore[no-untyped-def]
self.model = model

def embed_query(self, text: str):
def embed_query(self, text: str): # type: ignore[no-untyped-def]
"""Embed a single query text."""
return self.model.encode(text, normalize_embeddings=True).tolist()

def embed_documents(self, texts: Sequence[str]):
def embed_documents(self, texts: Sequence[str]): # type: ignore[no-untyped-def]
"""Embed multiple documents/texts."""
return self.model.encode(texts, normalize_embeddings=True).tolist()

# Create an instance of our custom embeddings class
self.embeddings = SentenceTransformerEmbeddings(self.model)
self.embeddings = SentenceTransformerEmbeddings(self.model) # type: ignore[assignment]

# Create the database
self.db_connection = Chroma(
self.db_connection = Chroma( # type: ignore[assignment]
collection_name=self.collection_name,
embedding_function=self.embeddings,
collection_metadata={"hnsw:space": "cosine"},
Expand Down
Loading
Loading