-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
fix: tool call streaming output compatibility #6439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Weikjssss
wants to merge
2
commits into
AstrBotDevs:dev
from
Weikjssss:fix/tool-call-stream-compatibility
Closed
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -305,13 +305,37 @@ async def _query_stream( | |
| llm_response = LLMResponse("assistant", is_chunk=True) | ||
|
|
||
| state = ChatCompletionStreamState() | ||
| state_ok = True | ||
| state_error_logged = False | ||
|
|
||
| # Fallback buffers are filled from raw chunks so we can still finalize | ||
| # a usable response even if SDK stream state aggregation fails. | ||
| fallback_text_parts: list[str] = [] | ||
| fallback_reasoning_parts: list[str] = [] | ||
| fallback_tool_calls: dict[str, dict[str, Any]] = {} | ||
| fallback_tool_call_idx_to_id: dict[int, str] = {} | ||
| fallback_usage: TokenUsage | None = None | ||
| fallback_id: str | None = None | ||
|
|
||
| async for chunk in stream: | ||
| if chunk.id: | ||
| fallback_id = chunk.id | ||
| try: | ||
| state.handle_chunk(chunk) | ||
| except Exception as e: | ||
| logger.warning("Saving chunk state error: " + str(e)) | ||
| # Do not stop streaming on state aggregation failure. We still | ||
| # keep yielding chunk deltas and rely on fallback finalization. | ||
| state_ok = False | ||
| if not state_error_logged: | ||
| logger.warning( | ||
| f"Saving chunk state error for model {self.get_model()}: {e!r}", | ||
| ) | ||
| state_error_logged = True | ||
| else: | ||
| logger.debug(f"Saving chunk state error (suppressed): {e!r}") | ||
| if not chunk.choices: | ||
| if chunk.usage: | ||
| fallback_usage = self._extract_usage(chunk.usage) | ||
| continue | ||
| delta = chunk.choices[0].delta | ||
| # logger.debug(f"chunk delta: {delta}") | ||
|
|
@@ -321,23 +345,151 @@ async def _query_stream( | |
| llm_response.id = chunk.id | ||
| if reasoning: | ||
| llm_response.reasoning_content = reasoning | ||
| fallback_reasoning_parts.append(reasoning) | ||
| _y = True | ||
| self._collect_stream_tool_calls( | ||
| delta=delta, | ||
| fallback_tool_calls=fallback_tool_calls, | ||
| fallback_tool_call_idx_to_id=fallback_tool_call_idx_to_id, | ||
| ) | ||
| if delta and delta.content: | ||
| # Don't strip streaming chunks to preserve spaces between words | ||
| completion_text = self._normalize_content(delta.content, strip=False) | ||
| fallback_text_parts.append(completion_text) | ||
| llm_response.result_chain = MessageChain( | ||
| chain=[Comp.Plain(completion_text)], | ||
| ) | ||
| _y = True | ||
| if chunk.usage: | ||
| llm_response.usage = self._extract_usage(chunk.usage) | ||
| fallback_usage = llm_response.usage | ||
| if _y: | ||
| yield llm_response | ||
|
|
||
| final_completion = state.get_final_completion() | ||
| llm_response = await self._parse_openai_completion(final_completion, tools) | ||
| final_response: LLMResponse | None = None | ||
| if state_ok: | ||
| try: | ||
| final_completion = state.get_final_completion() | ||
| final_response = await self._parse_openai_completion( | ||
| final_completion, tools | ||
| ) | ||
| except Exception as e: | ||
| logger.warning( | ||
| f"Parsing final streaming completion failed for model {self.get_model()}, using fallback: {e!r}", | ||
| ) | ||
|
|
||
| if final_response is None: | ||
| final_response = self._build_stream_fallback_response( | ||
| fallback_id=fallback_id, | ||
| fallback_usage=fallback_usage, | ||
| fallback_text_parts=fallback_text_parts, | ||
| fallback_reasoning_parts=fallback_reasoning_parts, | ||
| fallback_tool_calls=fallback_tool_calls, | ||
| ) | ||
| if state_ok: | ||
| # state_ok=True but parse failed/get_final failed: fallback is still valid. | ||
| logger.warning( | ||
| f"Using fallback final response for model {self.get_model()} despite healthy stream state.", | ||
| ) | ||
|
|
||
| yield final_response | ||
|
|
||
| @staticmethod | ||
| def _collect_stream_tool_calls( | ||
| delta: Any, | ||
| fallback_tool_calls: dict[str, dict[str, Any]], | ||
| fallback_tool_call_idx_to_id: dict[int, str], | ||
| ) -> None: | ||
| tool_calls = getattr(delta, "tool_calls", None) | ||
| if not tool_calls: | ||
| return | ||
| for tool_call in tool_calls: | ||
| idx = getattr(tool_call, "index", None) | ||
| tool_call_id = getattr(tool_call, "id", None) | ||
| if isinstance(idx, int): | ||
| # Some providers only send id in the first fragment, then only | ||
| # send index for later argument fragments. | ||
| if tool_call_id: | ||
| fallback_tool_call_idx_to_id[idx] = str(tool_call_id) | ||
| else: | ||
| tool_call_id = fallback_tool_call_idx_to_id.get(idx) | ||
| if not tool_call_id: | ||
| if isinstance(idx, int): | ||
| tool_call_id = f"stream_tool_call_{idx}" | ||
| fallback_tool_call_idx_to_id[idx] = tool_call_id | ||
| else: | ||
| continue | ||
| entry = fallback_tool_calls.setdefault( | ||
| str(tool_call_id), | ||
| {"name": "", "arguments_raw": "", "extra_content": None}, | ||
| ) | ||
| function = getattr(tool_call, "function", None) | ||
| if function is not None: | ||
| name = getattr(function, "name", None) | ||
| if isinstance(name, str) and name: | ||
| entry["name"] = name | ||
| arguments = getattr(function, "arguments", None) | ||
| if arguments is not None: | ||
| entry["arguments_raw"] += str(arguments) | ||
| extra_content = getattr(tool_call, "extra_content", None) | ||
| if extra_content is not None: | ||
| entry["extra_content"] = extra_content | ||
|
|
||
| def _build_stream_fallback_response( | ||
| self, | ||
| fallback_id: str | None, | ||
| fallback_usage: TokenUsage | None, | ||
| fallback_text_parts: list[str], | ||
| fallback_reasoning_parts: list[str], | ||
| fallback_tool_calls: dict[str, dict[str, Any]], | ||
| ) -> LLMResponse: | ||
| reasoning_text = "".join(fallback_reasoning_parts) | ||
| if fallback_tool_calls: | ||
| tools_call_args: list[dict[str, Any]] = [] | ||
| tools_call_name: list[str] = [] | ||
| tools_call_ids: list[str] = [] | ||
| tools_call_extra_content: dict[str, dict[str, Any]] = {} | ||
| for tool_call_id, info in fallback_tool_calls.items(): | ||
| raw_args = str(info.get("arguments_raw", "")).strip() | ||
| if raw_args: | ||
| try: | ||
| parsed_args = json.loads(raw_args) | ||
| except Exception: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| # Keep raw payload to avoid dropping tool calls when | ||
| # provider emits non-JSON or incomplete argument chunks. | ||
| parsed_args = {"_raw_arguments": raw_args} | ||
| else: | ||
| parsed_args = {} | ||
| tools_call_ids.append(tool_call_id) | ||
| tools_call_name.append(str(info.get("name") or "unknown_tool")) | ||
| tools_call_args.append(parsed_args) | ||
| extra_content = info.get("extra_content") | ||
| if extra_content is not None and isinstance(extra_content, dict): | ||
| tools_call_extra_content[tool_call_id] = extra_content | ||
| return LLMResponse( | ||
| role="tool", | ||
| tools_call_args=tools_call_args, | ||
| tools_call_name=tools_call_name, | ||
| tools_call_ids=tools_call_ids, | ||
| tools_call_extra_content=tools_call_extra_content, | ||
| reasoning_content=reasoning_text, | ||
| id=fallback_id, | ||
| usage=fallback_usage, | ||
| ) | ||
|
|
||
| yield llm_response | ||
| completion_text = "".join(fallback_text_parts) | ||
| result_chain = ( | ||
| MessageChain().message(completion_text) | ||
| if completion_text | ||
| else MessageChain() | ||
| ) | ||
| return LLMResponse( | ||
| role="assistant", | ||
| result_chain=result_chain, | ||
| reasoning_content=reasoning_text, | ||
| id=fallback_id, | ||
| usage=fallback_usage, | ||
| ) | ||
|
|
||
| def _extract_reasoning_content( | ||
| self, | ||
|
|
||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion: Swallowing JSON parse errors for tool arguments without logging may hinder debugging of malformed tool calls.
When `json.loads(raw_args)` fails, the exception is swallowed and replaced with `{"_raw_arguments": raw_args}` without any logging. This preserves functionality but obscures upstream issues like malformed JSON or truncation. Please add at least a throttled debug or warning log with the model name and a truncated `raw_args` so operational issues with degraded tool call payloads are detectable.

Suggested implementation:

- The `try`/`except` block for `json.loads(raw_args)` exists exactly as in the SEARCH block; if its structure or local variable names differ (`raw_args` vs `raw_arguments`, etc.), adjust the SEARCH/REPLACE snippet to match the actual code.
- The `_should_log_tool_args_parse_error` helper assumes `logging` is already imported and `logger` is defined as shown; if your file defines `logger` differently or in another location, insert the helper and the `_TOOL_ARGS_PARSE_ERROR_*` globals near that definition instead.
- If the project already has a shared throttled-logging utility (`log_throttled`, `RateLimitedLogger`, etc.), replace the simple counter-based `_should_log_tool_args_parse_error` implementation with that shared utility to stay consistent with existing conventions.
try/exceptblock forjson.loads(raw_args)exists exactly as in the SEARCH block; if its structure or local variable names differ (raw_argsvsraw_arguments, etc.), adjust the SEARCH/REPLACE snippet to match the actual code._should_log_tool_args_parse_errorassumesloggingis already imported andloggeris defined as shown; if your file definesloggerdifferently or in another location, insert the helper and the_TOOL_ARGS_PARSE_ERROR_*globals near that definition instead.log_throttled,RateLimitedLogger, etc.), you should replace the simple counter-based_should_log_tool_args_parse_errorimplementation with that shared utility to stay consistent with existing conventions.