From 8df36a5051866fa7cf0aeafccfbb72195afd81ba Mon Sep 17 00:00:00 2001
From: Weikjssss <1007840024@qq.com>
Date: Mon, 16 Mar 2026 15:43:26 +0800
Subject: [PATCH] fix: tool call streaming output compatibility

# Conflicts:
#	astrbot/core/provider/sources/openai_source.py

Conflicts resolved by taking the [CherryPick] version.
---
 .../core/provider/sources/openai_source.py | 168 +++++++++++++++++-
 1 file changed, 160 insertions(+), 8 deletions(-)

diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py
index c40234ed47..9b538ecf9d 100644
--- a/astrbot/core/provider/sources/openai_source.py
+++ b/astrbot/core/provider/sources/openai_source.py
@@ -305,13 +305,37 @@ async def _query_stream(
         llm_response = LLMResponse("assistant", is_chunk=True)
         state = ChatCompletionStreamState()
+        state_ok = True
+        state_error_logged = False
+
+        # Fallback buffers are filled from raw chunks so we can still finalize
+        # a usable response even if SDK stream state aggregation fails.
+        fallback_text_parts: list[str] = []
+        fallback_reasoning_parts: list[str] = []
+        fallback_tool_calls: dict[str, dict[str, Any]] = {}
+        fallback_tool_call_idx_to_id: dict[int, str] = {}
+        fallback_usage: TokenUsage | None = None
+        fallback_id: str | None = None
         async for chunk in stream:
+            if chunk.id:
+                fallback_id = chunk.id
             try:
                 state.handle_chunk(chunk)
             except Exception as e:
-                logger.warning("Saving chunk state error: " + str(e))
-            if len(chunk.choices) == 0:
+                # Do not stop streaming on state aggregation failure. We still
+                # keep yielding chunk deltas and rely on fallback finalization.
+                state_ok = False
+                if not state_error_logged:
+                    logger.warning(
+                        f"Saving chunk state error for model {self.get_model()}: {e!r}",
+                    )
+                    state_error_logged = True
+                else:
+                    logger.debug(f"Saving chunk state error (suppressed): {e!r}")
+            if not chunk.choices:
+                if chunk.usage:
+                    fallback_usage = self._extract_usage(chunk.usage)
                 continue
             delta = chunk.choices[0].delta
             # logger.debug(f"chunk delta: {delta}")
@@ -321,23 +345,151 @@
             llm_response.id = chunk.id
             if reasoning:
                 llm_response.reasoning_content = reasoning
+                fallback_reasoning_parts.append(reasoning)
                 _y = True
-            if delta.content:
+            self._collect_stream_tool_calls(
+                delta=delta,
+                fallback_tool_calls=fallback_tool_calls,
+                fallback_tool_call_idx_to_id=fallback_tool_call_idx_to_id,
+            )
+            if delta and delta.content:
                 # Don't strip streaming chunks to preserve spaces between words
                 completion_text = self._normalize_content(delta.content, strip=False)
+                fallback_text_parts.append(completion_text)
                 llm_response.result_chain = MessageChain(
                     chain=[Comp.Plain(completion_text)],
                 )
                 _y = True
             if chunk.usage:
                 llm_response.usage = self._extract_usage(chunk.usage)
+                fallback_usage = llm_response.usage
             if _y:
                 yield llm_response
 
-        final_completion = state.get_final_completion()
-        llm_response = await self._parse_openai_completion(final_completion, tools)
+        final_response: LLMResponse | None = None
+        if state_ok:
+            try:
+                final_completion = state.get_final_completion()
+                final_response = await self._parse_openai_completion(
+                    final_completion, tools
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Parsing final streaming completion failed for model {self.get_model()}, using fallback: {e!r}",
+                )
+
+        if final_response is None:
+            final_response = self._build_stream_fallback_response(
+                fallback_id=fallback_id,
+                fallback_usage=fallback_usage,
+                fallback_text_parts=fallback_text_parts,
+                fallback_reasoning_parts=fallback_reasoning_parts,
+                fallback_tool_calls=fallback_tool_calls,
+            )
+            if state_ok:
+                # Stream state was healthy but finalization failed; the fallback is still valid.
+                logger.warning(
+                    f"Using fallback final response for model {self.get_model()} despite healthy stream state.",
+                )
+
+        yield final_response
+
+    @staticmethod
+    def _collect_stream_tool_calls(
+        delta: Any,
+        fallback_tool_calls: dict[str, dict[str, Any]],
+        fallback_tool_call_idx_to_id: dict[int, str],
+    ) -> None:
+        tool_calls = getattr(delta, "tool_calls", None)
+        if not tool_calls:
+            return
+        for tool_call in tool_calls:
+            idx = getattr(tool_call, "index", None)
+            tool_call_id = getattr(tool_call, "id", None)
+            if isinstance(idx, int):
+                # Some providers send the id only in the first fragment and
+                # identify later argument fragments by index alone.
+                if tool_call_id:
+                    fallback_tool_call_idx_to_id[idx] = str(tool_call_id)
+                else:
+                    tool_call_id = fallback_tool_call_idx_to_id.get(idx)
+            if not tool_call_id:
+                if isinstance(idx, int):
+                    tool_call_id = f"stream_tool_call_{idx}"
+                    fallback_tool_call_idx_to_id[idx] = tool_call_id
+                else:
+                    continue
+            entry = fallback_tool_calls.setdefault(
+                str(tool_call_id),
+                {"name": "", "arguments_raw": "", "extra_content": None},
+            )
+            function = getattr(tool_call, "function", None)
+            if function is not None:
+                name = getattr(function, "name", None)
+                if isinstance(name, str) and name:
+                    entry["name"] = name
+                arguments = getattr(function, "arguments", None)
+                if arguments is not None:
+                    entry["arguments_raw"] += str(arguments)
+            extra_content = getattr(tool_call, "extra_content", None)
+            if extra_content is not None:
+                entry["extra_content"] = extra_content
+
+    def _build_stream_fallback_response(
+        self,
+        fallback_id: str | None,
+        fallback_usage: TokenUsage | None,
+        fallback_text_parts: list[str],
+        fallback_reasoning_parts: list[str],
+        fallback_tool_calls: dict[str, dict[str, Any]],
+    ) -> LLMResponse:
+        reasoning_text = "".join(fallback_reasoning_parts)
+        if fallback_tool_calls:
+            tools_call_args: list[dict[str, Any]] = []
+            tools_call_name: list[str] = []
+            tools_call_ids: list[str] = []
+            tools_call_extra_content: dict[str, dict[str, Any]] = {}
+            for tool_call_id, info in fallback_tool_calls.items():
+                raw_args = str(info.get("arguments_raw", "")).strip()
+                if raw_args:
+                    try:
+                        parsed_args = json.loads(raw_args)
+                    except Exception:
+                        # Keep the raw payload to avoid dropping tool calls when a
+                        # provider emits non-JSON or incomplete argument chunks.
+                        parsed_args = {"_raw_arguments": raw_args}
+                else:
+                    parsed_args = {}
+                tools_call_ids.append(tool_call_id)
+                tools_call_name.append(str(info.get("name") or "unknown_tool"))
+                tools_call_args.append(parsed_args)
+                extra_content = info.get("extra_content")
+                if extra_content is not None and isinstance(extra_content, dict):
+                    tools_call_extra_content[tool_call_id] = extra_content
+            return LLMResponse(
+                role="tool",
+                tools_call_args=tools_call_args,
+                tools_call_name=tools_call_name,
+                tools_call_ids=tools_call_ids,
+                tools_call_extra_content=tools_call_extra_content,
+                reasoning_content=reasoning_text,
+                id=fallback_id,
+                usage=fallback_usage,
+            )
 
-        yield llm_response
+        completion_text = "".join(fallback_text_parts)
+        result_chain = (
+            MessageChain().message(completion_text)
+            if completion_text
+            else MessageChain()
+        )
+        return LLMResponse(
+            role="assistant",
+            result_chain=result_chain,
+            reasoning_content=reasoning_text,
+            id=fallback_id,
+            usage=fallback_usage,
+        )
 
     def _extract_reasoning_content(
         self,
@@ -345,7 +497,7 @@
     ) -> str:
         """Extract reasoning content from OpenAI ChatCompletion if available."""
         reasoning_text = ""
-        if len(completion.choices) == 0:
+        if not completion.choices:
             return reasoning_text
         if isinstance(completion, ChatCompletion):
             choice = completion.choices[0]
@@ -468,7 +620,7 @@ async def _parse_openai_completion(
         """Parse OpenAI ChatCompletion into LLMResponse"""
         llm_response = LLMResponse("assistant")
 
-        if len(completion.choices) == 0:
+        if not completion.choices:
             raise Exception("API 返回的 completion 为空。")
         choice = completion.choices[0]
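Note: the merging behavior of _collect_stream_tool_calls can be sanity-checked
without the OpenAI SDK. The sketch below is a minimal standalone reproduction,
not part of the patch: FakeFunction/FakeToolCall are hypothetical stand-ins for
the SDK's streamed tool-call delta objects, and collect() mirrors the id/index
bookkeeping the patch adds (id arrives only in the first fragment; later
fragments carry just the index and more argument text).

    from dataclasses import dataclass
    from typing import Any

    @dataclass
    class FakeFunction:
        name: str | None = None
        arguments: str | None = None

    @dataclass
    class FakeToolCall:
        index: int | None = None
        id: str | None = None
        function: FakeFunction | None = None

    def collect(
        tool_calls: list[FakeToolCall],
        merged: dict[str, dict[str, Any]],
        idx_to_id: dict[int, str],
    ) -> None:
        # Mirrors _collect_stream_tool_calls: remember id-by-index on the first
        # fragment, then resolve later index-only fragments back to that id.
        for tc in tool_calls:
            tool_call_id = tc.id
            if isinstance(tc.index, int):
                if tool_call_id:
                    idx_to_id[tc.index] = str(tool_call_id)
                else:
                    tool_call_id = idx_to_id.get(tc.index)
            if not tool_call_id:
                continue
            entry = merged.setdefault(
                str(tool_call_id), {"name": "", "arguments_raw": ""}
            )
            if tc.function:
                if tc.function.name:
                    entry["name"] = tc.function.name
                if tc.function.arguments is not None:
                    entry["arguments_raw"] += tc.function.arguments

    merged: dict[str, dict[str, Any]] = {}
    idx_to_id: dict[int, str] = {}
    # First fragment carries id and name; later fragments are index-only.
    collect([FakeToolCall(0, "call_1", FakeFunction("get_weather", ""))], merged, idx_to_id)
    collect([FakeToolCall(0, None, FakeFunction(None, '{"city": '))], merged, idx_to_id)
    collect([FakeToolCall(0, None, FakeFunction(None, '"Paris"}'))], merged, idx_to_id)
    assert merged["call_1"]["name"] == "get_weather"
    assert merged["call_1"]["arguments_raw"] == '{"city": "Paris"}'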
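Note: the arguments_raw handling in _build_stream_fallback_response can be
exercised the same way. parse_tool_args below is a hypothetical extraction of
that branch, not part of the patch; the trailing isinstance guard is an extra
suggestion for argument streams that decode to a non-dict JSON value, which the
patch currently passes through as-is.

    import json
    from typing import Any

    def parse_tool_args(raw_args: str) -> dict[str, Any]:
        # Mirrors the patch: empty -> {}, valid JSON -> parsed dict, anything
        # else -> preserved verbatim under "_raw_arguments" instead of dropped.
        raw_args = raw_args.strip()
        if not raw_args:
            return {}
        try:
            parsed = json.loads(raw_args)
        except Exception:
            return {"_raw_arguments": raw_args}
        return parsed if isinstance(parsed, dict) else {"_raw_arguments": raw_args}

    assert parse_tool_args('{"city": "Paris"}') == {"city": "Paris"}
    assert parse_tool_args('{"city": "Par') == {"_raw_arguments": '{"city": "Par'}
    assert parse_tool_args("") == {}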