97 changes: 76 additions & 21 deletions dimos/agents/agent.py
@@ -175,7 +175,9 @@ def __init__(
self.image_detail: str = "low"
self.max_input_tokens_per_request: int = max_input_tokens_per_request
self.max_output_tokens_per_request: int = max_output_tokens_per_request
- self.max_tokens_per_request: int = self.max_input_tokens_per_request + self.max_output_tokens_per_request
+ self.max_tokens_per_request: int = (
+     self.max_input_tokens_per_request + self.max_output_tokens_per_request
+ )
self.rag_query_n: int = 4
self.rag_similarity_threshold: float = 0.45
self.frame_processor: Optional[FrameProcessor] = None
@@ -200,10 +202,14 @@ def __init__(
RxOps.map(
lambda combined: {
"query": combined[0],
"objects": combined[1] if len(combined) > 1 else "No object data available",
"objects": combined[1]
if len(combined) > 1
else "No object data available",
}
),
- RxOps.map(lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}"),
+ RxOps.map(
+     lambda data: f"{data['query']}\n\nCurrent objects detected:\n{data['objects']}"
+ ),
RxOps.do_action(
lambda x: print(f"\033[34mEnriched query: {x.split(chr(10))[0]}\033[0m")
or [print(f"\033[34m{line}\033[0m") for line in x.split(chr(10))[1:]]
@@ -222,7 +228,9 @@ def __init__(
# Define a query extractor for the merged stream
query_extractor = lambda emission: (emission[0], emission[1][0])
self.disposables.add(
- self.subscribe_to_image_processing(self.merged_stream, query_extractor=query_extractor)
+ self.subscribe_to_image_processing(
+     self.merged_stream, query_extractor=query_extractor
+ )
)
else:
# If no merged stream, fall back to individual streams
@@ -250,7 +258,9 @@ def _get_rag_context(self) -> Tuple[str, str]:
and condensed results (for use in the prompt).
"""
results = self.agent_memory.query(
- query_texts=self.query, n_results=self.rag_query_n, similarity_threshold=self.rag_similarity_threshold
+ query_texts=self.query,
+ n_results=self.rag_query_n,
+ similarity_threshold=self.rag_similarity_threshold,
)
formatted_results = "\n".join(
f"Document ID: {doc.id}\nMetadata: {doc.metadata}\nContent: {doc.page_content}\nScore: {score}\n"
@@ -334,7 +344,12 @@ def _tooling_callback(message, messages, response_message, skill_library: SkillL
result = skill_library.call(name, **args)
logger.info(f"Function Call Results: {result}")
new_messages.append(
{"role": "tool", "tool_call_id": tool_call.id, "content": str(result), "name": name}
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": str(result),
"name": name,
}
)
if has_called_tools:
logger.info("Sending Another Query.")
@@ -347,7 +362,9 @@ def _tooling_callback(message, messages, response_message, skill_library: SkillL
return None

if response_message.tool_calls is not None:
- return _tooling_callback(response_message, messages, response_message, self.skill_library)
+ return _tooling_callback(
+     response_message, messages, response_message, self.skill_library
+ )
return None

def _observable_query(
@@ -373,7 +390,9 @@ def _observable_query(
try:
self._update_query(incoming_query)
_, condensed_results = self._get_rag_context()
- messages = self._build_prompt(base64_image, dimensions, override_token_limit, condensed_results)
+ messages = self._build_prompt(
+     base64_image, dimensions, override_token_limit, condensed_results
+ )
# logger.debug(f"Sending Query: {messages}")
logger.info("Sending Query.")
response_message = self._send_query(messages)
@@ -391,13 +410,19 @@ def _observable_query(
final_msg = (
response_message.parsed
if hasattr(response_message, "parsed") and response_message.parsed
- else (response_message.content if hasattr(response_message, "content") else response_message)
+ else (
+     response_message.content
+     if hasattr(response_message, "content")
+     else response_message
+ )
)
observer.on_next(final_msg)
self.response_subject.on_next(final_msg)
else:
response_message_2 = self._handle_tooling(response_message, messages)
- final_msg = response_message_2 if response_message_2 is not None else response_message
+ final_msg = (
+     response_message_2 if response_message_2 is not None else response_message
+ )
if isinstance(final_msg, BaseModel): # TODO: Test
final_msg = str(final_msg.content)
observer.on_next(final_msg)
@@ -440,7 +465,9 @@ def _log_response_to_file(self, response, output_dir: str = None):
file.write(f"{self.dev_name}: {response}\n")
logger.info(f"LLM Response [{self.dev_name}]: {response}")

- def subscribe_to_image_processing(self, frame_observable: Observable, query_extractor=None) -> Disposable:
+ def subscribe_to_image_processing(
+     self, frame_observable: Observable, query_extractor=None
+ ) -> Disposable:
"""Subscribes to a stream of video frames for processing.

This method sets up a subscription to process incoming video frames.
@@ -480,7 +507,9 @@ def _process_frame(emission) -> Observable:
RxOps.subscribe_on(self.pool_scheduler),
MyOps.print_emission(id="D", **print_emission_args),
MyVidOps.with_jpeg_export(
- self.frame_processor, suffix=f"{self.dev_name}_frame_", save_limit=_MAX_SAVED_FRAMES
+ self.frame_processor,
+ suffix=f"{self.dev_name}_frame_",
+ save_limit=_MAX_SAVED_FRAMES,
),
MyOps.print_emission(id="E", **print_emission_args),
MyVidOps.encode_image(),
@@ -562,7 +591,9 @@ def _process_query(query) -> Observable:
return just(query).pipe(
MyOps.print_emission(id="Pr A", **print_emission_args),
RxOps.flat_map(
- lambda query: create(lambda observer, _: self._observable_query(observer, incoming_query=query))
+ lambda query: create(
+     lambda observer, _: self._observable_query(observer, incoming_query=query)
+ )
),
MyOps.print_emission(id="Pr B", **print_emission_args),
)
@@ -612,7 +643,9 @@ def get_response_observable(self) -> Observable:
Observable: An observable that emits string responses from the agent.
"""
return self.response_subject.pipe(
- RxOps.observe_on(self.pool_scheduler), RxOps.subscribe_on(self.pool_scheduler), RxOps.share()
+ RxOps.observe_on(self.pool_scheduler),
+ RxOps.subscribe_on(self.pool_scheduler),
+ RxOps.share(),
)

def run_observable_query(self, query_text: str, **kwargs) -> Observable:
@@ -631,7 +664,11 @@ def run_observable_query(self, query_text: str, **kwargs) -> Observable:
Returns:
Observable: An observable that emits the response as a string.
"""
- return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text, **kwargs))
+ return create(
+     lambda observer, _: self._observable_query(
+         observer, incoming_query=query_text, **kwargs
+     )
+ )

def dispose_all(self):
"""Disposes of all active subscriptions managed by this agent."""
@@ -749,7 +786,9 @@ def __init__(
self.response_model = response_model if response_model is not None else NOT_GIVEN
self.model_name = model_name
self.tokenizer = tokenizer or OpenAITokenizer(model_name=self.model_name)
- self.prompt_builder = prompt_builder or PromptBuilder(self.model_name, tokenizer=self.tokenizer)
+ self.prompt_builder = prompt_builder or PromptBuilder(
+     self.model_name, tokenizer=self.tokenizer
+ )
self.rag_query_n = rag_query_n
self.rag_similarity_threshold = rag_similarity_threshold
self.image_detail = image_detail
@@ -767,8 +806,14 @@ def __init__(
def _add_context_to_memory(self):
"""Adds initial context to the agent's memory."""
context_data = [
("id0", "Optical Flow is a technique used to track the movement of objects in a video sequence."),
("id1", "Edge Detection is a technique used to identify the boundaries of objects in an image."),
(
"id0",
"Optical Flow is a technique used to track the movement of objects in a video sequence.",
),
(
"id1",
"Edge Detection is a technique used to identify the boundaries of objects in an image.",
),
("id2", "Video is a sequence of frames captured at regular intervals."),
(
"id3",
@@ -805,15 +850,23 @@ def _send_query(self, messages: list) -> Any:
model=self.model_name,
messages=messages,
response_format=self.response_model,
- tools=(self.skill_library.get_tools() if self.skill_library is not None else NOT_GIVEN),
+ tools=(
+     self.skill_library.get_tools()
+     if self.skill_library is not None
+     else NOT_GIVEN
+ ),
max_tokens=self.max_output_tokens_per_request,
)
else:
response = self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=self.max_output_tokens_per_request,
- tools=(self.skill_library.get_tools() if self.skill_library is not None else NOT_GIVEN),
+ tools=(
+     self.skill_library.get_tools()
+     if self.skill_library is not None
+     else NOT_GIVEN
+ ),
)
response_message = response.choices[0].message
if response_message is None:
@@ -843,7 +896,9 @@ def stream_query(self, query_text: str) -> Observable:
Returns:
Observable: An observable that emits the response as a string.
"""
- return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+ return create(
+     lambda observer, _: self._observable_query(observer, incoming_query=query_text)
+ )


# endregion OpenAIAgent Subclass (OpenAI-Specific Implementation)
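For reviewers tracing the reactive API this reformatting touches, here is a minimal usage sketch. Only the method names, signatures, and documented return types appear in this diff; the constructor arguments below are assumptions for illustration, not the agent's actual required parameters:

```python
from dimos.agents.agent import OpenAIAgent

# Hypothetical construction; the real required arguments may differ.
agent = OpenAIAgent(dev_name="demo_agent", model_name="gpt-4o")

# run_observable_query returns an Observable that emits the response string.
agent.run_observable_query("Describe the scene.").subscribe(
    on_next=lambda response: print(f"Agent replied: {response}"),
    on_error=lambda err: print(f"Query failed: {err}"),
)

# get_response_observable exposes a shared, multicast stream of every
# response the agent produces, scheduled on its thread-pool scheduler.
agent.get_response_observable().subscribe(lambda r: print(f"Broadcast: {r}"))
```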
12 changes: 9 additions & 3 deletions dimos/agents/agent_ctransformers_gguf.py
@@ -141,7 +141,9 @@ def __init__(

self.tokenizer = CTransformersTokenizerAdapter(self.model)

- self.prompt_builder = prompt_builder or PromptBuilder(self.model_name, tokenizer=self.tokenizer)
+ self.prompt_builder = prompt_builder or PromptBuilder(
+     self.model_name, tokenizer=self.tokenizer
+ )

self.max_output_tokens_per_request = max_output_tokens_per_request

@@ -152,7 +154,9 @@ def __init__(

# Ensure only one input stream is provided.
if self.input_video_stream is not None and self.input_query_stream is not None:
raise ValueError("More than one input stream provided. Please provide only one input stream.")
raise ValueError(
"More than one input stream provided. Please provide only one input stream."
)

if self.input_video_stream is not None:
logger.info("Subscribing to input video stream...")
@@ -198,7 +202,9 @@ def stream_query(self, query_text: str) -> Subject:
"""
Creates an observable that processes a text query and emits the response.
"""
- return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+ return create(
+     lambda observer, _: self._observable_query(observer, incoming_query=query_text)
+ )


# endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation)
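Every `stream_query` variant in these files wraps `_observable_query` in `reactivex.create` the same way. A self-contained sketch of that wrapping pattern, with a stand-in worker in place of the agent's real `_observable_query`:

```python
from reactivex import create

def _observable_query(observer, incoming_query: str) -> None:
    # Stand-in for the agent's real query logic: emit one result, then complete.
    observer.on_next(f"echo: {incoming_query}")
    observer.on_completed()

def stream_query(query_text: str):
    # create() hands the subscribe function an (observer, scheduler) pair;
    # the scheduler slot is unused here, hence the `_`.
    return create(lambda observer, _: _observable_query(observer, incoming_query=query_text))

stream_query("hello").subscribe(print)
```

Because `create` defers the work to subscription time, each subscriber triggers its own query run; the shared `get_response_observable` stream shown earlier is the multicast alternative.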
42 changes: 32 additions & 10 deletions dimos/agents/agent_huggingface_local.py
@@ -98,10 +98,14 @@ def __init__(

self.tokenizer = tokenizer or HuggingFaceTokenizer(self.model_name)

- self.prompt_builder = prompt_builder or PromptBuilder(self.model_name, tokenizer=self.tokenizer)
+ self.prompt_builder = prompt_builder or PromptBuilder(
+     self.model_name, tokenizer=self.tokenizer
+ )

self.model = AutoModelForCausalLM.from_pretrained(
- model_name, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, device_map=self.device
+ model_name,
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
+ device_map=self.device,
)

self.max_output_tokens_per_request = max_output_tokens_per_request
@@ -113,7 +117,9 @@ def __init__(

# Ensure only one input stream is provided.
if self.input_video_stream is not None and self.input_query_stream is not None:
raise ValueError("More than one input stream provided. Please provide only one input stream.")
raise ValueError(
"More than one input stream provided. Please provide only one input stream."
)

if self.input_video_stream is not None:
logger.info("Subscribing to input video stream...")
@@ -142,21 +148,28 @@ def _send_query(self, messages: list) -> Any:

# Tokenize the prompt
print("Preparing model inputs...")
- model_inputs = self.tokenizer.tokenizer([prompt_text], return_tensors="pt").to(self.model.device)
+ model_inputs = self.tokenizer.tokenizer([prompt_text], return_tensors="pt").to(
+     self.model.device
+ )
print("Model inputs prepared.")

# Generate the response
print("Generating response...")
- generated_ids = self.model.generate(**model_inputs, max_new_tokens=self.max_output_tokens_per_request)
+ generated_ids = self.model.generate(
+     **model_inputs, max_new_tokens=self.max_output_tokens_per_request
+ )

# Extract the generated tokens (excluding the input prompt tokens)
print("Processing generated output...")
generated_ids = [
- output_ids[len(input_ids) :] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+ output_ids[len(input_ids) :]
+ for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

# Convert tokens back to text
- response = self.tokenizer.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+ response = self.tokenizer.tokenizer.batch_decode(
+     generated_ids, skip_special_tokens=True
+ )[0]
print("Response successfully generated.")

return response
@@ -168,14 +181,21 @@ def _send_query(self, messages: list) -> Any:

except Exception as e:
# Log any other errors but continue execution
logger.warning(f"Error in chat template processing: {e}. Falling back to simple format.")
logger.warning(
f"Error in chat template processing: {e}. Falling back to simple format."
)

# Fallback approach for models without chat template support
# This code runs if the try block above raises an exception
print("Using simple prompt format...")

# Convert messages to a simple text format
- if isinstance(messages, list) and messages and isinstance(messages[0], dict) and "content" in messages[0]:
+ if (
+     isinstance(messages, list)
+     and messages
+     and isinstance(messages[0], dict)
+     and "content" in messages[0]
+ ):
prompt_text = messages[0]["content"]
else:
prompt_text = str(messages)
@@ -207,7 +227,9 @@ def stream_query(self, query_text: str) -> Subject:
"""
Creates an observable that processes a text query and emits the response.
"""
- return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+ return create(
+     lambda observer, _: self._observable_query(observer, incoming_query=query_text)
+ )


# endregion HuggingFaceLLMAgent Subclass (HuggingFace-Specific Implementation)
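The `_send_query` body being rewrapped above follows the standard Hugging Face tokenize → generate → slice → decode pattern. As a standalone sketch of that pattern (the model name is a small placeholder chosen for illustration, not what this file loads):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

model_inputs = tokenizer(["Hello, robot."], return_tensors="pt")
generated_ids = model.generate(**model_inputs, max_new_tokens=32)

# Slice off the prompt tokens so only newly generated text is decoded,
# exactly as the list comprehension in _send_query does.
generated_ids = [
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```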
8 changes: 6 additions & 2 deletions dimos/agents/agent_huggingface_remote.py
@@ -110,7 +110,9 @@ def __init__(

# Ensure only one input stream is provided.
if self.input_video_stream is not None and self.input_query_stream is not None:
raise ValueError("More than one input stream provided. Please provide only one input stream.")
raise ValueError(
"More than one input stream provided. Please provide only one input stream."
)

if self.input_video_stream is not None:
logger.info("Subscribing to input video stream...")
@@ -136,4 +138,6 @@ def stream_query(self, query_text: str) -> Subject:
"""
Creates an observable that processes a text query and emits the response.
"""
- return create(lambda observer, _: self._observable_query(observer, incoming_query=query_text))
+ return create(
+     lambda observer, _: self._observable_query(observer, incoming_query=query_text)
+ )