diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 25ad6dc06017db..eea40821282c28 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -202,7 +202,6 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
                     except TypeError:
                         final_answer = f"{scratchpad.action.action_input}"
                 else:
-                    function_call_state = True
                     # action is tool call, invoke tool
                     tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                         action=scratchpad.action,
@@ -213,6 +212,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
                     scratchpad.observation = tool_invoke_response
                     scratchpad.agent_response = tool_invoke_response

+                    # detect direct return
+                    direct_flag = (tool_invoke_meta.extra or {}).get("return_direct", False)
+
                     self.save_agent_thought(
                         agent_thought_id=agent_thought_id,
                         tool_name=scratchpad.action.action_name,
@@ -229,6 +231,12 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
                         QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                     )

+                    if direct_flag:
+                        final_answer = str(tool_invoke_response or "")
+                        # keep function_call_state as False to end iterations
+                    else:
+                        function_call_state = True
+
             # update prompt tool message
             for prompt_tool in self._prompt_messages_tools:
                 self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index dcc1326b33bc34..fe24a27491d39e 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -265,7 +265,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
                 }
                 tool_responses.append(tool_response)

-                if tool_response["tool_response"] is not None:
+                # check direct return flag
+                direct_flag = (tool_invoke_meta.extra or {}).get("return_direct", False)
+
+                if tool_response["tool_response"] is not None and not direct_flag:
                     self._current_thoughts.append(
                         ToolPromptMessage(
                             content=str(tool_response["tool_response"]),
@@ -274,6 +277,28 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
                         )
                     )

+                if direct_flag:
+                    # save agent thought for this tool call
+                    self.save_agent_thought(
+                        agent_thought_id=agent_thought_id,
+                        tool_name=tool_call_name,
+                        tool_input=tool_call_args,
+                        thought=llm_result.message.content or "",
+                        tool_invoke_meta={tool_call_name: tool_invoke_meta.to_dict()},
+                        observation={tool_call_name: tool_invoke_response},
+                        answer=str(tool_invoke_response or ""),
+                        messages_ids=message_file_ids,
+                    )
+                    self.queue_manager.publish(
+                        QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
+                    )
+
+                    # publish end event immediately and return
+                    final_answer = str(tool_invoke_response or "")
+                    llm_final_usage = llm_usage.get("usage") or LLMUsage.empty_usage()
+                    yield from self._yield_final_answer(prompt_messages, final_answer, llm_final_usage)
+                    return
+
             if len(tool_responses) > 0:
                 # save agent thought
                 self.save_agent_thought(
@@ -301,18 +326,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
             iteration_step += 1

-        # publish end event
-        self.queue_manager.publish(
-            QueueMessageEndEvent(
-                llm_result=LLMResult(
-                    model=model_instance.model,
-                    prompt_messages=prompt_messages,
-                    message=AssistantPromptMessage(content=final_answer),
-                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
-                    system_fingerprint="",
-                )
-            ),
-            PublishFrom.APPLICATION_MANAGER,
+        yield from self._yield_final_answer(
+            prompt_messages,
+            final_answer,
+            llm_usage["usage"] or LLMUsage.empty_usage(),
         )

     def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
@@ -377,6 +394,36 @@ def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str,

         return tool_calls

+    def _yield_final_answer(
+        self,
+        prompt_messages: list,
+        final_answer: str,
+        usage: LLMUsage,
+    ) -> Generator[LLMResultChunk, None, None]:
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=self.model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=usage,
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )
+
+        yield LLMResultChunk(
+            model=self.model_instance.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint="",
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=final_answer),
+                usage=usage,
+            ),
+        )
+
     def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         Initialize system message
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 353f3a646a9542..23e6c0c5e11140 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -427,6 +427,7 @@ class ToolInvokeMeta(BaseModel):
     time_cost: float = Field(..., description="The time cost of the tool invoke")
     error: str | None = None
     tool_config: dict | None = None
+    extra: dict | None = None

     @classmethod
     def empty(cls) -> "ToolInvokeMeta":
@@ -447,6 +448,7 @@ def to_dict(self):
             "time_cost": self.time_cost,
             "error": self.error,
             "tool_config": self.tool_config,
+            "extra": self.extra,
         }
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 13fd579e20aa01..0c49ffda2e8440 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -107,9 +107,19 @@ def message_callback(
             tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
         )

+        # detect return_direct signal from variable messages (short-circuit)
+        return_direct = False
+        for m in message_list:
+            if m.type == ToolInvokeMessage.MessageType.VARIABLE:
+                variable = cast(ToolInvokeMessage.VariableMessage, m.message)
+                if variable.variable_name == "return_direct" and bool(variable.variable_value):
+                    return_direct = True
+                    break
+
         plain_text = ToolEngine._convert_tool_response_to_str(message_list)

         meta = invocation_meta_dict["meta"]
+        meta.extra = {"return_direct": return_direct}

         # hit the callback handler
         agent_tool_callback.on_tool_end(
@@ -254,6 +264,9 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str
                         ensure_ascii=False,
                     )
                 )
+            elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
+                # internal variable messages should not be surfaced into plain text
+                continue
             else:
                 parts.append(str(response.message))
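
Usage note: the short-circuit only fires when a tool emits a VARIABLE message whose variable_name is "return_direct" with a truthy variable_value; the flag then travels out-of-band in ToolInvokeMeta.extra, so the signal never leaks into the plain-text observation (hence the new VARIABLE branch in _convert_tool_response_to_str). Below is a minimal sketch of a tool emitting that signal. It is not part of this patch: the attribute names (type, message, variable_name, variable_value) and the nested VariableMessage class are taken from the detection loop above, but the exact constructor shape of ToolInvokeMessage and the helper name make_return_direct_message are assumptions.

    # Hypothetical sketch, not part of this patch. Assumes ToolInvokeMessage
    # is a pydantic-style model that accepts its fields as keyword arguments.
    from core.tools.entities.tool_entities import ToolInvokeMessage

    def make_return_direct_message() -> ToolInvokeMessage:
        """Build the out-of-band signal picked up by the loop in tool_engine.py."""
        return ToolInvokeMessage(
            type=ToolInvokeMessage.MessageType.VARIABLE,
            message=ToolInvokeMessage.VariableMessage(
                variable_name="return_direct",  # name checked by the detection loop
                variable_value=True,  # any truthy value trips the short-circuit
            ),
        )

A tool would yield this message alongside its normal text output; the text becomes both the observation and the final answer, while the variable message is filtered out of the plain-text conversion.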