diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py
index 25ad6dc06017db..caf61c4af83ec7 100644
--- a/api/core/agent/cot_agent_runner.py
+++ b/api/core/agent/cot_agent_runner.py
@@ -74,6 +74,7 @@ def run(
         final_answer = ""
         prompt_messages: list = []  # Initialize prompt_messages
         agent_thought_id = ""  # Initialize agent_thought_id
+        is_final_answer_from_tool = False

         def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
             if not final_llm_usage_dict["usage"]:
@@ -202,7 +203,6 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
                     except TypeError:
                         final_answer = f"{scratchpad.action.action_input}"
                 else:
-                    function_call_state = True
                     # action is tool call, invoke tool
                     tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
                         action=scratchpad.action,
@@ -213,6 +213,9 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
                     scratchpad.observation = tool_invoke_response
                     scratchpad.agent_response = tool_invoke_response

+                    # detect direct return
+                    direct_flag = (tool_invoke_meta.extra or {}).get("return_direct", False)
+
                     self.save_agent_thought(
                         agent_thought_id=agent_thought_id,
                         tool_name=scratchpad.action.action_name,
@@ -229,40 +232,69 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
                         QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
                     )

+                    if direct_flag:
+                        final_answer = str(tool_invoke_response or "")
+                        is_final_answer_from_tool = True
+
+                        yield from self._yield_final_answer(
+                            prompt_messages=prompt_messages,
+                            final_answer=final_answer,
+                            usage=llm_usage["usage"] or LLMUsage.empty_usage(),
+                        )
+                        return
+                    else:
+                        function_call_state = True
+
             # update prompt tool message
             for prompt_tool in self._prompt_messages_tools:
                 self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
             iteration_step += 1

+        yield from self._yield_final_answer(
+            prompt_messages=prompt_messages,
+            final_answer=final_answer,
+            usage=llm_usage["usage"] or LLMUsage.empty_usage(),
+        )
+
+        # save agent thought only when final answer is NOT directly from tool
+        if not is_final_answer_from_tool:
+            self.save_agent_thought(
+                agent_thought_id=agent_thought_id,
+                tool_name="",
+                tool_input={},
+                tool_invoke_meta={},
+                thought=final_answer,
+                observation={},
+                answer=final_answer,
+                messages_ids=[],
+            )
+
+    def _yield_final_answer(
+        self,
+        prompt_messages: list,
+        final_answer: str,
+        usage: LLMUsage | None,
+    ) -> Generator[LLMResultChunk, None, None]:
+        """Yields the final answer chunk and publishes the end event."""
         yield LLMResultChunk(
-            model=model_instance.model,
+            model=self.model_instance.model,
             prompt_messages=prompt_messages,
             delta=LLMResultChunkDelta(
-                index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
+                index=0,
+                message=AssistantPromptMessage(content=final_answer),
+                usage=usage,
             ),
             system_fingerprint="",
         )

-        # save agent thought
-        self.save_agent_thought(
-            agent_thought_id=agent_thought_id,
-            tool_name="",
-            tool_input={},
-            tool_invoke_meta={},
-            thought=final_answer,
-            observation={},
-            answer=final_answer,
-            messages_ids=[],
-        )
-
         # publish end event
         self.queue_manager.publish(
             QueueMessageEndEvent(
                 llm_result=LLMResult(
-                    model=model_instance.model,
+                    model=self.model_instance.model,
                     prompt_messages=prompt_messages,
                     message=AssistantPromptMessage(content=final_answer),
-                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
+                    usage=usage or LLMUsage.empty_usage(),
                     system_fingerprint="",
                 )
             ),
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index dcc1326b33bc34..a3a7a030dd06df 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -227,12 +227,17 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
             for tool_call_id, tool_call_name, tool_call_args in tool_calls:
                 tool_instance = tool_instances.get(tool_call_name)
                 if not tool_instance:
-                    tool_response = {
-                        "tool_call_id": tool_call_id,
-                        "tool_call_name": tool_call_name,
-                        "tool_response": f"there is not a tool named {tool_call_name}",
-                        "meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
-                    }
+                    error_message = f"there is not a tool named {tool_call_name}"
+                    tool_invoke_meta = ToolInvokeMeta.error_instance(error_message)
+                    tool_response = self._create_tool_response(
+                        tool_call_id,
+                        tool_call_name,
+                        tool_call_args,
+                        error_message,
+                        tool_invoke_meta,
+                        False,
+                    )
+                    tool_responses.append(tool_response)
                 else:
                     # invoke tool
                     tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
@@ -257,24 +262,40 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU
                         # add message file ids
                         message_file_ids.append(message_file_id)

-                    tool_response = {
-                        "tool_call_id": tool_call_id,
-                        "tool_call_name": tool_call_name,
-                        "tool_response": tool_invoke_response,
-                        "meta": tool_invoke_meta.to_dict(),
-                    }
-
-                tool_responses.append(tool_response)
-                if tool_response["tool_response"] is not None:
-                    self._current_thoughts.append(
-                        ToolPromptMessage(
-                            content=str(tool_response["tool_response"]),
-                            tool_call_id=tool_call_id,
-                            name=tool_call_name,
-                        )
+                    direct_flag = bool((tool_invoke_meta.extra or {}).get("return_direct", False))
+                    tool_response = self._create_tool_response(
+                        tool_call_id,
+                        tool_call_name,
+                        tool_call_args,
+                        tool_invoke_response,
+                        tool_invoke_meta,
+                        direct_flag,
                     )
+                    tool_responses.append(tool_response)

             if len(tool_responses) > 0:
+                all_direct = all(tr.get("direct_flag") is True for tr in tool_responses)
+                if all_direct:
+                    llm_final_usage = llm_usage.get("usage") or LLMUsage.empty_usage()
+                    yield from self._handle_direct_return(
+                        agent_thought_id,
+                        tool_responses,
+                        response or "",
+                        message_file_ids,
+                        prompt_messages,
+                        llm_final_usage,
+                    )
+                    return
+
+                for tr in tool_responses:
+                    if tr["tool_response"] is not None:
+                        self._current_thoughts.append(
+                            ToolPromptMessage(
+                                content=str(tr["tool_response"]),
+                                tool_call_id=tr["tool_call_id"],
+                                name=tr["tool_call_name"],
+                            )
+                        )
                 # save agent thought
                 self.save_agent_thought(
                     agent_thought_id=agent_thought_id,
@@ -301,18 +322,10 @@ def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMU

             iteration_step += 1

-        # publish end event
-        self.queue_manager.publish(
-            QueueMessageEndEvent(
-                llm_result=LLMResult(
-                    model=model_instance.model,
-                    prompt_messages=prompt_messages,
-                    message=AssistantPromptMessage(content=final_answer),
-                    usage=llm_usage["usage"] or LLMUsage.empty_usage(),
-                    system_fingerprint="",
-                )
-            ),
-            PublishFrom.APPLICATION_MANAGER,
+        yield from self._yield_final_answer(
+            prompt_messages,
+            final_answer,
+            llm_usage.get("usage") or LLMUsage.empty_usage(),
         )

     def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
@@ -377,6 +390,100 @@ def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str,

         return tool_calls

+    def _yield_final_answer(
+        self,
+        prompt_messages: list[PromptMessage],
+        final_answer: str,
+        usage: LLMUsage,
+    ) -> Generator[LLMResultChunk, None, None]:
+        yield LLMResultChunk(
+            model=self.model_instance.model,
+            prompt_messages=prompt_messages,
+            system_fingerprint="",
+            delta=LLMResultChunkDelta(
+                index=0,
+                message=AssistantPromptMessage(content=final_answer),
+                usage=usage,
+            ),
+        )
+
+        self.queue_manager.publish(
+            QueueMessageEndEvent(
+                llm_result=LLMResult(
+                    model=self.model_instance.model,
+                    prompt_messages=prompt_messages,
+                    message=AssistantPromptMessage(content=final_answer),
+                    usage=usage,
+                    system_fingerprint="",
+                )
+            ),
+            PublishFrom.APPLICATION_MANAGER,
+        )
+
+    def _create_tool_response(
+        self,
+        tool_call_id: str,
+        tool_call_name: str,
+        tool_call_args: dict[str, Any],
+        tool_invoke_response: str | None,
+        tool_invoke_meta: ToolInvokeMeta,
+        direct_flag: bool,
+    ) -> dict[str, Any]:
+        return {
+            "tool_call_id": tool_call_id,
+            "tool_call_name": tool_call_name,
+            "tool_response": tool_invoke_response,
+            "tool_call_args": tool_call_args,
+            "meta": tool_invoke_meta.to_dict(),
+            "direct_flag": direct_flag,
+        }
+
+    @staticmethod
+    def _flatten(agg_dict: dict[str, list[Any]]) -> dict[str, Any]:
+        """Flattens a dictionary of lists, keeping single-item lists as values."""
+        return {k: (v[0] if len(v) == 1 else v) for k, v in agg_dict.items()}
+
+    def _handle_direct_return(
+        self,
+        agent_thought_id: str,
+        tool_responses: list[dict[str, Any]],
+        thought: str,
+        message_file_ids: list[str],
+        prompt_messages: list[PromptMessage],
+        usage: LLMUsage,
+    ) -> Generator[LLMResultChunk, None, None]:
+        final_answer_parts = []
+        tool_invoke_meta_agg: dict[str, list[Any]] = {}
+        observation_agg: dict[str, list[Any]] = {}
+        tool_input_agg: dict[str, list[Any]] = {}
+
+        for tr in tool_responses:
+            if tr.get("tool_response") is not None:
+                final_answer_parts.append(str(tr["tool_response"]))
+            tool_invoke_meta_agg.setdefault(tr["tool_call_name"], []).append(tr["meta"])
+            observation_agg.setdefault(tr["tool_call_name"], []).append(tr["tool_response"])
+            tool_input_agg.setdefault(tr["tool_call_name"], []).append(tr.get("tool_call_args", {}))
+
+        final_answer = "\n".join(final_answer_parts)
+        tool_invoke_meta = self._flatten(tool_invoke_meta_agg)
+        observation = self._flatten(observation_agg)
+        tool_input = self._flatten(tool_input_agg)
+        tool_name = ";".join(sorted({tr["tool_call_name"] for tr in tool_responses}))
+        self.save_agent_thought(
+            agent_thought_id=agent_thought_id,
+            tool_name=tool_name,
+            tool_input=tool_input,
+            thought=thought,
+            tool_invoke_meta=tool_invoke_meta,
+            observation=observation,
+            answer=final_answer,
+            messages_ids=message_file_ids,
+        )
+        self.queue_manager.publish(
+            QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
+        )
+        yield from self._yield_final_answer(prompt_messages, final_answer, usage)
+
     def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
         """
         Initialize system message
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 353f3a646a9542..23e6c0c5e11140 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -427,6 +427,7 @@ class ToolInvokeMeta(BaseModel):
     time_cost: float = Field(..., description="The time cost of the tool invoke")
     error: str | None = None
     tool_config: dict | None = None
+    extra: dict | None = None

     @classmethod
     def empty(cls) -> "ToolInvokeMeta":
@@ -447,6 +448,7 @@ def to_dict(self):
             "time_cost": self.time_cost,
             "error": self.error,
             "tool_config": self.tool_config,
+            "extra": self.extra,
         }

diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 13fd579e20aa01..d73feb0cc0dffe 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -107,9 +107,20 @@ def message_callback(
             tool_messages=binary_files, agent_message=message, invoke_from=invoke_from, user_id=user_id
         )

+        # detect return_direct signal from variable messages (strict boolean short-circuit)
+        return_direct = any(
+            m.type == ToolInvokeMessage.MessageType.VARIABLE
+            and m.message is not None
+            and (variable := cast(ToolInvokeMessage.VariableMessage, m.message))
+            and variable.variable_name == "return_direct"
+            and variable.variable_value is True
+            for m in message_list
+        )
+
         plain_text = ToolEngine._convert_tool_response_to_str(message_list)

         meta = invocation_meta_dict["meta"]
+        meta.extra = {"return_direct": return_direct}

         # hit the callback handler
         agent_tool_callback.on_tool_end(
@@ -254,6 +265,9 @@ def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str
                         ensure_ascii=False,
                     )
                 )
+            elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
+                # internal variable messages should not be surfaced into plain text
+                continue
             else:
                 parts.append(str(response.message))

diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 5703c19c880221..9d350b42cdb7be 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -114,10 +114,24 @@ def _invoke(
         for file in files:
             yield self.create_file_message(file)  # type: ignore

+        return_direct_flag = isinstance(outputs, dict) and outputs.pop("return_direct", None) is True
+
         self._latest_usage = self._derive_usage_from_result(data)

-        yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
+        direct_text = None
+        if return_direct_flag:
+            string_values = [v for k, v in sorted(outputs.items()) if isinstance(v, str)]
+            if string_values:
+                direct_text = "\n".join(string_values)
+
+        if direct_text is not None:
+            yield self.create_text_message(direct_text)
+        else:
+            yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
+
         yield self.create_json_message(outputs, suppress_output=True)
+        if return_direct_flag:
+            yield self.create_variable_message("return_direct", True)

     @property
     def latest_usage(self) -> LLMUsage: