Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions astrbot/core/agent/runners/tool_loop_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,9 @@ def _append_tool_call_result(tool_call_id: str, content: str) -> None:

valid_params = {} # 参数过滤:只传递函数实际需要的参数

if func_tool_args is None:
func_tool_args = {}

# 获取实际的 handler 函数
if func_tool.handler:
logger.debug(
Expand Down
82 changes: 81 additions & 1 deletion astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,78 @@
from ..register import register_provider_adapter


class _StreamToolCallFallback:
"""Fallback accumulator for streaming tool_call data.

Some OpenAI-compatible proxies (e.g. Gemini) return tool_call chunks
without the required ``index`` field, causing the openai SDK's
``ChatCompletionStreamState`` to reject them. This class manually
collects tool_call data from such rejected chunks so the arguments
can be restored after the stream completes.
"""

def __init__(self) -> None:
self._calls: dict[int, dict[str, str]] = {}

def collect_from_chunk(self, chunk: ChatCompletionChunk) -> None:
if not chunk.choices:
return
delta = chunk.choices[0].delta
if not delta or not hasattr(delta, "tool_calls") or not delta.tool_calls:
return
for tc in delta.tool_calls:
idx = getattr(tc, "index", None) or 0
if idx not in self._calls:
self._calls[idx] = {"id": "", "name": "", "arguments": ""}
if getattr(tc, "id", None):
self._calls[idx]["id"] = tc.id
if hasattr(tc, "function") and tc.function:
if getattr(tc.function, "name", None):
self._calls[idx]["name"] = tc.function.name
if getattr(tc.function, "arguments", None):
self._calls[idx]["arguments"] += tc.function.arguments

def apply_to(self, response: "LLMResponse") -> None:
if not self._calls:
return

# Case 1: tool_calls were parsed but args are empty — fill them in
if response.tools_call_args:
for i, args in enumerate(response.tools_call_args):
if (args is None or args == {}) and i in self._calls:
parsed = self._parse_args(self._calls[i]["arguments"])
if parsed is not None:
response.tools_call_args[i] = parsed
logger.info(
f"Stream fallback: restored args for {self._calls[i]['name']}"
)
return

# Case 2: no tool_calls were parsed at all — rebuild from fallback
for idx in sorted(self._calls):
fb = self._calls[idx]
if not fb["name"] or not fb["arguments"]:
continue
parsed = self._parse_args(fb["arguments"])
if parsed is None:
parsed = {}
response.tools_call_args = response.tools_call_args or []
response.tools_call_name = response.tools_call_name or []
response.tools_call_ids = response.tools_call_ids or []
response.tools_call_args.append(parsed)
response.tools_call_name.append(fb["name"])
response.tools_call_ids.append(fb["id"] or f"fallback_{idx}")
response.role = "tool"
logger.info(f"Stream fallback: rebuilt tool_call {fb['name']}({parsed})")

@staticmethod
def _parse_args(raw: str) -> dict | None:
try:
return json.loads(raw)
except (json.JSONDecodeError, TypeError, ValueError):
return None


@register_provider_adapter(
"openai_chat_completion",
"OpenAI API Chat Completion 提供商适配器",
Expand Down Expand Up @@ -305,12 +377,14 @@ async def _query_stream(
llm_response = LLMResponse("assistant", is_chunk=True)

state = ChatCompletionStreamState()
fallback_tc = _StreamToolCallFallback()

async for chunk in stream:
try:
state.handle_chunk(chunk)
except Exception as e:
logger.warning("Saving chunk state error: " + str(e))
fallback_tc.collect_from_chunk(chunk)
if not chunk.choices:
continue
choice = chunk.choices[0]
Expand Down Expand Up @@ -342,6 +416,7 @@ async def _query_stream(

final_completion = state.get_final_completion()
llm_response = await self._parse_openai_completion(final_completion, tools)
fallback_tc.apply_to(llm_response)

yield llm_response

Expand Down Expand Up @@ -517,9 +592,14 @@ async def _parse_openai_completion(
):
# workaround for #1454
if isinstance(tool_call.function.arguments, str):
args = json.loads(tool_call.function.arguments)
try:
args = json.loads(tool_call.function.arguments)
except (json.JSONDecodeError, TypeError):
args = {}
else:
args = tool_call.function.arguments
if args is None:
args = {}
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
tool_call_ids.append(tool_call.id)
Expand Down