10 changes: 9 additions & 1 deletion api/core/agent/cot_agent_runner.py
@@ -202,7 +202,6 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
except TypeError:
final_answer = f"{scratchpad.action.action_input}"
else:
function_call_state = True
# action is tool call, invoke tool
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
action=scratchpad.action,
@@ -213,6 +212,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
scratchpad.observation = tool_invoke_response
scratchpad.agent_response = tool_invoke_response

# detect direct return
direct_flag = (tool_invoke_meta.extra or {}).get("return_direct", False)

self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=scratchpad.action.action_name,
@@ -229,6 +231,12 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)

if direct_flag:
final_answer = str(tool_invoke_response or "")
# keep function_call_state as False to end iterations
else:
function_call_state = True

# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
73 changes: 60 additions & 13 deletions api/core/agent/fc_agent_runner.py
@@ -265,7 +265,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
}

tool_responses.append(tool_response)
if tool_response["tool_response"] is not None:
# check direct return flag
direct_flag = (tool_invoke_meta.extra or {}).get("return_direct", False)

if tool_response["tool_response"] is not None and not direct_flag:
self._current_thoughts.append(
ToolPromptMessage(
content=str(tool_response["tool_response"]),
Expand All @@ -274,6 +277,28 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
)
)

if direct_flag:
# save agent thought for this tool call
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=tool_call_name,
tool_input=tool_call_args,
thought=llm_result.message.content or "",
tool_invoke_meta={tool_call_name: tool_invoke_meta.to_dict()},
observation={tool_call_name: tool_invoke_response},
answer=str(tool_invoke_response or ""),
messages_ids=message_file_ids,
)
Cursx marked this conversation as resolved.
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)

# publish end event immediately and return
final_answer = str(tool_invoke_response or "")
llm_final_usage = llm_usage.get("usage") or LLMUsage.empty_usage()
yield from self._yield_final_answer(prompt_messages, final_answer, llm_final_usage)
return
Comment on lines +280 to +300
medium

This if direct_flag: block is quite large and contains complex logic for handling the direct return. This makes the for loop harder to read and understand. To improve readability and maintainability, I suggest extracting this block of code into a separate private helper method, for example, _handle_direct_return(...). This helper method could encapsulate the logic for saving the agent thought, publishing events, and yielding the final result chunk.
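A sketch of the proposed extraction, reusing only names already present in this diff; the helper name _handle_direct_return and its keyword-only signature are the reviewer's suggestion, not code in this PR:

    def _handle_direct_return(
        self,
        *,
        agent_thought_id: str,
        tool_call_name: str,
        tool_call_args: dict,
        thought: str,
        tool_invoke_meta: ToolInvokeMeta,
        tool_invoke_response: str,
        message_file_ids: list[str],
        prompt_messages: list[PromptMessage],
        llm_usage: dict,
    ) -> Generator[LLMResultChunk, None, None]:
        # Persist the agent thought for the short-circuiting tool call.
        self.save_agent_thought(
            agent_thought_id=agent_thought_id,
            tool_name=tool_call_name,
            tool_input=tool_call_args,
            thought=thought,
            tool_invoke_meta={tool_call_name: tool_invoke_meta.to_dict()},
            observation={tool_call_name: tool_invoke_response},
            answer=str(tool_invoke_response or ""),
            messages_ids=message_file_ids,
        )
        self.queue_manager.publish(
            QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
        )
        # Delegate end-event publishing and the final chunk to the shared helper.
        final_answer = str(tool_invoke_response or "")
        usage = llm_usage.get("usage") or LLMUsage.empty_usage()
        yield from self._yield_final_answer(prompt_messages, final_answer, usage)

The call site inside the loop would then collapse to yield from self._handle_direct_return(...) followed by return.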

Comment on lines +280 to +300
high

This block introduces a large amount of duplicated logic for finalizing the agent run. The logic for publishing the QueueMessageEndEvent and yielding the final LLMResultChunk is very similar to what should happen at the end of the run method.

Currently, the original code path at the end of the method (lines 352-363) seems to be missing a yield for the final answer, which is a bug. This new direct_flag path correctly yields the result.

To improve maintainability and fix the bug in the original path, I recommend extracting this finalization logic into a private helper method. This helper could then be called here, and also at the end of the run method to ensure consistent behavior and a single source of truth for ending the agent execution.

A refactoring would look something like this:

  1. Create a private helper method _yield_final_answer(...) that contains the logic from lines 296-321.
  2. Call it here: yield from self._yield_final_answer(...) and then return.
  3. Replace lines 352-363 with a call to the same helper method: yield from self._yield_final_answer(...).

This would fix the bug and remove the code duplication.
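In sketch form, the two call sites then share one exit path, which is what this diff implements:

    # direct-return branch inside the tool-call loop
    final_answer = str(tool_invoke_response or "")
    yield from self._yield_final_answer(
        prompt_messages, final_answer, llm_usage.get("usage") or LLMUsage.empty_usage()
    )
    return

    # end of run(), replacing the old inline QueueMessageEndEvent publish
    yield from self._yield_final_answer(
        prompt_messages, final_answer, llm_usage["usage"] or LLMUsage.empty_usage()
    )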


if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
@@ -301,18 +326,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU

iteration_step += 1

# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
yield from self._yield_final_answer(
prompt_messages,
final_answer,
llm_usage["usage"] or LLMUsage.empty_usage(),
)

def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
@@ -377,6 +394,36 @@ def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str,

return tool_calls

def _yield_final_answer(
self,
prompt_messages: list,
medium

For better type safety and code clarity, it's recommended to use a more specific type hint for prompt_messages. Instead of list, please use list[PromptMessage].

Suggested change
prompt_messages: list,
prompt_messages: list[PromptMessage],

Cursx marked this conversation as resolved.
final_answer: str,
usage: LLMUsage,
) -> Generator[LLMResultChunk, None, None]:
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)

yield LLMResultChunk(
model=self.model_instance.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
),
)

def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
2 changes: 2 additions & 0 deletions api/core/tools/entities/tool_entities.py
@@ -427,6 +427,7 @@ class ToolInvokeMeta(BaseModel):
time_cost: float = Field(..., description="The time cost of the tool invoke")
error: str | None = None
tool_config: dict | None = None
extra: dict | None = None

@classmethod
def empty(cls) -> "ToolInvokeMeta":
@@ -447,6 +448,7 @@ def to_dict(self):
"time_cost": self.time_cost,
"error": self.error,
"tool_config": self.tool_config,
"extra": self.extra,
}


13 changes: 13 additions & 0 deletions api/core/tools/tool_engine.py
@@ -107,9 +107,19 @@ def message_callback(
tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
)

# detect return_direct signal from variable messages (short-circuit)
return_direct = False
for m in message_list:
if m.type == ToolInvokeMessage.MessageType.VARIABLE:
variable = cast(ToolInvokeMessage.VariableMessage, m.message)
if variable.variable_name == "return_direct" and bool(variable.variable_value):
return_direct = True
break

plain_text = ToolEngine._convert_tool_response_to_str(message_list)

meta = invocation_meta_dict["meta"]
meta.extra = {"return_direct": return_direct}

# hit the callback handler
agent_tool_callback.on_tool_end(
@@ -254,6 +264,9 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str
ensure_ascii=False,
)
)
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
# internal variable messages should not be surfaced into plain text
continue
else:
parts.append(str(response.message))

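For context on the producing side: the detection above only fires if a tool emits a VARIABLE message named return_direct with a truthy value. A minimal sketch of such a tool — the class name MyLookupTool, the base class BuiltinTool, the helpers create_text_message and create_variable_message, and the do_lookup call are assumptions for illustration, not part of this diff:

    class MyLookupTool(BuiltinTool):  # hypothetical tool
        def _invoke(self, user_id: str, tool_parameters: dict) -> list[ToolInvokeMessage]:
            answer = do_lookup(tool_parameters["query"])  # hypothetical backend call
            return [
                self.create_text_message(text=answer),
                # Ask the agent runner to return this response verbatim,
                # ending the iteration loop instead of feeding the LLM again.
                self.create_variable_message("return_direct", True),
            ]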