160 changes: 156 additions & 4 deletions astrbot/core/provider/sources/openai_source.py
@@ -305,13 +305,37 @@ async def _query_stream(
        llm_response = LLMResponse("assistant", is_chunk=True)

        state = ChatCompletionStreamState()
        state_ok = True
        state_error_logged = False

        # Fallback buffers are filled from raw chunks so we can still finalize
        # a usable response even if SDK stream state aggregation fails.
        fallback_text_parts: list[str] = []
        fallback_reasoning_parts: list[str] = []
        fallback_tool_calls: dict[str, dict[str, Any]] = {}
        fallback_tool_call_idx_to_id: dict[int, str] = {}
        fallback_usage: TokenUsage | None = None
        fallback_id: str | None = None

        async for chunk in stream:
            if chunk.id:
                fallback_id = chunk.id
            try:
                state.handle_chunk(chunk)
            except Exception as e:
                logger.warning("Saving chunk state error: " + str(e))
                # Do not stop streaming on state aggregation failure. We still
                # keep yielding chunk deltas and rely on fallback finalization.
                state_ok = False
                if not state_error_logged:
                    logger.warning(
                        f"Saving chunk state error for model {self.get_model()}: {e!r}",
                    )
Comment on lines 323 to +332
Contributor

suggestion: Swallowing JSON parse errors for tool arguments without logging may hinder debugging of malformed tool calls.

When json.loads(raw_args) fails, the exception is swallowed and replaced with { "_raw_arguments": raw_args } without any logging. This preserves functionality but obscures upstream issues like malformed JSON or truncation. Please add at least a throttled debug or warning log with the model name and a truncated raw_args so operational issues with degraded tool call payloads are detectable.

Suggested implementation:

        async for chunk in stream:
            if chunk.id:
                fallback_id = chunk.id
            try:
                state.handle_chunk(chunk)
            except Exception as e:
                # Do not stop streaming on state aggregation failure. We still
                # keep yielding chunk deltas and rely on fallback finalization.
                state_ok = False
                if not state_error_logged:
                    logger.warning(
                        f"Saving chunk state error for model {self.get_model()}: {e!r}",
                    )
                    state_error_logged = True
        try:
            arguments = json.loads(raw_args)
        except Exception as e:
            # Log JSON parse failures for tool arguments in a throttled manner so that
            # malformed or truncated tool payloads are observable without flooding logs.
            if _should_log_tool_args_parse_error():
                logger.warning(
                    "Failed to parse tool call arguments for model %s: %r. "
                    "Raw arguments (truncated to 256 chars): %r",
                    self.get_model(),
                    e,
                    raw_args[:256],
                )
            arguments = {"_raw_arguments": raw_args}
logger = logging.getLogger(__name__)

# Throttle logging for noisy tool-argument JSON parse errors. This keeps
# operational visibility without flooding logs if a model starts emitting
# systematically malformed tool calls.
_TOOL_ARGS_PARSE_ERROR_COUNT = 0
_TOOL_ARGS_PARSE_ERROR_LOG_EVERY_N = 100


def _should_log_tool_args_parse_error() -> bool:
    """
    Return True when we should emit a log entry for a tool-argument JSON parse error.

    This uses a simple counter-based throttle (log every N errors). It is intentionally
    lightweight and process-local; if tighter guarantees are needed across workers,
    this can be replaced with a shared-rate-limiter implementation.
    """
    global _TOOL_ARGS_PARSE_ERROR_COUNT  # type: ignore[no-redef]
    _TOOL_ARGS_PARSE_ERROR_COUNT += 1
    return _TOOL_ARGS_PARSE_ERROR_COUNT % _TOOL_ARGS_PARSE_ERROR_LOG_EVERY_N == 1
  1. Ensure that the try/except block for json.loads(raw_args) exists exactly as in the SEARCH block; if its structure or local variable names differ (raw_args vs raw_arguments, etc.), adjust the SEARCH/REPLACE snippet to match the actual code.
  2. The helper _should_log_tool_args_parse_error assumes logging is already imported and logger is defined as shown; if your file defines logger differently or in another location, insert the helper and the _TOOL_ARGS_PARSE_ERROR_* globals near that definition instead.
  3. If your codebase has an existing throttled-logging utility (e.g. log_throttled, RateLimitedLogger, etc.), you should replace the simple counter-based _should_log_tool_args_parse_error implementation with that shared utility to stay consistent with existing conventions.
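
If a time-based throttle fits better than the counter, a minimal process-local sketch along the same lines could look like the following. The helper name is kept from above, but the 60-second interval and the time.monotonic approach are assumptions for illustration, not existing utilities in this codebase:

    import time

    # Hypothetical time-based alternative to the counter throttle above: emit
    # at most one log entry per interval, no matter how quickly errors arrive.
    _TOOL_ARGS_PARSE_ERROR_MIN_INTERVAL_S = 60.0
    _tool_args_parse_error_last_logged = 0.0


    def _should_log_tool_args_parse_error() -> bool:
        """Return True at most once per _TOOL_ARGS_PARSE_ERROR_MIN_INTERVAL_S seconds."""
        global _tool_args_parse_error_last_logged
        now = time.monotonic()
        if now - _tool_args_parse_error_last_logged >= _TOOL_ARGS_PARSE_ERROR_MIN_INTERVAL_S:
            _tool_args_parse_error_last_logged = now
            return True
        return False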

                    state_error_logged = True
                else:
                    logger.debug(f"Saving chunk state error (suppressed): {e!r}")
            if not chunk.choices:
                if chunk.usage:
                    fallback_usage = self._extract_usage(chunk.usage)
                continue
            delta = chunk.choices[0].delta
            # logger.debug(f"chunk delta: {delta}")
@@ -321,23 +345,151 @@ async def _query_stream(
            llm_response.id = chunk.id
            if reasoning:
                llm_response.reasoning_content = reasoning
                fallback_reasoning_parts.append(reasoning)
                _y = True
            self._collect_stream_tool_calls(
                delta=delta,
                fallback_tool_calls=fallback_tool_calls,
                fallback_tool_call_idx_to_id=fallback_tool_call_idx_to_id,
            )
            if delta and delta.content:
                # Don't strip streaming chunks to preserve spaces between words
                completion_text = self._normalize_content(delta.content, strip=False)
                fallback_text_parts.append(completion_text)
                llm_response.result_chain = MessageChain(
                    chain=[Comp.Plain(completion_text)],
                )
                _y = True
            if chunk.usage:
                llm_response.usage = self._extract_usage(chunk.usage)
                fallback_usage = llm_response.usage
            if _y:
                yield llm_response

        final_completion = state.get_final_completion()
        llm_response = await self._parse_openai_completion(final_completion, tools)
        final_response: LLMResponse | None = None
        if state_ok:
            try:
                final_completion = state.get_final_completion()
                final_response = await self._parse_openai_completion(
                    final_completion, tools
                )
            except Exception as e:
                logger.warning(
                    f"Parsing final streaming completion failed for model {self.get_model()}, using fallback: {e!r}",
                )

        if final_response is None:
            final_response = self._build_stream_fallback_response(
                fallback_id=fallback_id,
                fallback_usage=fallback_usage,
                fallback_text_parts=fallback_text_parts,
                fallback_reasoning_parts=fallback_reasoning_parts,
                fallback_tool_calls=fallback_tool_calls,
            )
            if state_ok:
                # state_ok=True but parse failed/get_final failed: fallback is still valid.
                logger.warning(
                    f"Using fallback final response for model {self.get_model()} despite healthy stream state.",
                )

        yield final_response

    @staticmethod
    def _collect_stream_tool_calls(
        delta: Any,
        fallback_tool_calls: dict[str, dict[str, Any]],
        fallback_tool_call_idx_to_id: dict[int, str],
    ) -> None:
        tool_calls = getattr(delta, "tool_calls", None)
        if not tool_calls:
            return
        for tool_call in tool_calls:
            idx = getattr(tool_call, "index", None)
            tool_call_id = getattr(tool_call, "id", None)
            if isinstance(idx, int):
                # Some providers only send id in the first fragment, then only
                # send index for later argument fragments.
                if tool_call_id:
                    fallback_tool_call_idx_to_id[idx] = str(tool_call_id)
                else:
                    tool_call_id = fallback_tool_call_idx_to_id.get(idx)
            if not tool_call_id:
                if isinstance(idx, int):
                    tool_call_id = f"stream_tool_call_{idx}"
                    fallback_tool_call_idx_to_id[idx] = tool_call_id
                else:
                    continue
            entry = fallback_tool_calls.setdefault(
                str(tool_call_id),
                {"name": "", "arguments_raw": "", "extra_content": None},
            )
            function = getattr(tool_call, "function", None)
            if function is not None:
                name = getattr(function, "name", None)
                if isinstance(name, str) and name:
                    entry["name"] = name
                arguments = getattr(function, "arguments", None)
                if arguments is not None:
                    entry["arguments_raw"] += str(arguments)
            extra_content = getattr(tool_call, "extra_content", None)
            if extra_content is not None:
                entry["extra_content"] = extra_content

    def _build_stream_fallback_response(
        self,
        fallback_id: str | None,
        fallback_usage: TokenUsage | None,
        fallback_text_parts: list[str],
        fallback_reasoning_parts: list[str],
        fallback_tool_calls: dict[str, dict[str, Any]],
    ) -> LLMResponse:
        reasoning_text = "".join(fallback_reasoning_parts)
        if fallback_tool_calls:
            tools_call_args: list[dict[str, Any]] = []
            tools_call_name: list[str] = []
            tools_call_ids: list[str] = []
            tools_call_extra_content: dict[str, dict[str, Any]] = {}
            for tool_call_id, info in fallback_tool_calls.items():
                raw_args = str(info.get("arguments_raw", "")).strip()
                if raw_args:
                    try:
                        parsed_args = json.loads(raw_args)
                    except Exception:
Contributor

medium

To make the code more precise and maintainable, consider replacing the broad except Exception: with the more specific except json.JSONDecodeError:. This makes the intent (handling JSON parse failures) explicit and avoids accidentally catching unrelated runtime errors, so the error handling becomes more robust and easier to debug.

Suggested change
except Exception:
except json.JSONDecodeError:
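
Since raw_args is always a str here, a json.loads parse failure surfaces as json.JSONDecodeError (a ValueError subclass), so the narrower clause should lose no coverage. A minimal illustration, with a hypothetical truncated raw_args value:

    import json

    raw_args = '{"query": "weather in'  # hypothetical truncated streaming fragment

    try:
        parsed_args = json.loads(raw_args)
    except json.JSONDecodeError:
        # json.loads raises json.JSONDecodeError for malformed str input, so
        # this narrower except clause still catches the parse failure.
        parsed_args = {"_raw_arguments": raw_args}

    print(parsed_args)  # {'_raw_arguments': '{"query": "weather in'}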

                        # Keep raw payload to avoid dropping tool calls when
                        # provider emits non-JSON or incomplete argument chunks.
                        parsed_args = {"_raw_arguments": raw_args}
                else:
                    parsed_args = {}
                tools_call_ids.append(tool_call_id)
                tools_call_name.append(str(info.get("name") or "unknown_tool"))
                tools_call_args.append(parsed_args)
                extra_content = info.get("extra_content")
                if extra_content is not None and isinstance(extra_content, dict):
                    tools_call_extra_content[tool_call_id] = extra_content
            return LLMResponse(
                role="tool",
                tools_call_args=tools_call_args,
                tools_call_name=tools_call_name,
                tools_call_ids=tools_call_ids,
                tools_call_extra_content=tools_call_extra_content,
                reasoning_content=reasoning_text,
                id=fallback_id,
                usage=fallback_usage,
            )

        yield llm_response
        completion_text = "".join(fallback_text_parts)
        result_chain = (
            MessageChain().message(completion_text)
            if completion_text
            else MessageChain()
        )
        return LLMResponse(
            role="assistant",
            result_chain=result_chain,
            reasoning_content=reasoning_text,
            id=fallback_id,
            usage=fallback_usage,
        )

    def _extract_reasoning_content(
        self,