Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 39 additions & 20 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,42 @@ async def get_models(self):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")

@staticmethod
def _sanitize_assistant_messages(payloads: dict) -> None:
"""在请求发送前过滤/规范化空的 assistant 消息。

严格 API(Moonshot、DeepSeek Reasoner 等)会在 assistant 消息同时缺少
``content`` 和 ``tool_calls`` 时返回 400。把 ``""`` / ``None`` / ``[]``
都视作空内容:无 tool_calls 时整条过滤掉;有 tool_calls 时将 content
设为 ``None`` 以符合 OpenAI 规范。就地修改 ``payloads["messages"]``。
"""
messages = payloads.get("messages")
if not isinstance(messages, list):
return

def _is_empty(content: Any) -> bool:
return content is None or content == "" or content == []

cleaned: list[Any] = []
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or msg.get("role") != "assistant":
cleaned.append(msg)
continue

content = msg.get("content")
tool_calls = msg.get("tool_calls")

if _is_empty(content) and not tool_calls:
logger.warning(f"过滤第 {idx} 条空 assistant 消息 (无工具调用)")
continue

if _is_empty(content) and tool_calls:
msg["content"] = None

cleaned.append(msg)
Comment on lines +540 to +554
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic inside this loop can be simplified by inverting the initial if condition. This removes one level of indentation and a continue statement, making the main logic path for assistant messages clearer and improving readability.

            if isinstance(msg, dict) and msg.get("role") == "assistant":
                content = msg.get("content")
                tool_calls = msg.get("tool_calls")

                if _is_empty(content) and not tool_calls:
                    logger.warning(f"过滤第 {idx} 条空 assistant 消息 (无工具调用)")
                    continue

                if _is_empty(content) and tool_calls:
                    msg["content"] = None

            cleaned.append(msg)


payloads["messages"] = cleaned

async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
Expand Down Expand Up @@ -548,26 +584,7 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:

model = payloads.get("model", "").lower()

if "messages" in payloads and isinstance(payloads["messages"], list):
cleaned_messages = []
for idx, msg in enumerate(payloads["messages"]):
# 过滤空的 assistant 消息,防止严格 API(如 Moonshot)返回 400 错误
if msg.get("role") == "assistant":
content = msg.get("content")
tool_calls = msg.get("tool_calls")

# 情况1: 空/null content 且无 tool_calls -> 过滤掉
if not tool_calls and (content == "" or content is None):
logger.warning(f"过滤第 {idx} 条空 assistant 消息 (无工具调用)")
continue

# 情况2: 空 content 但有 tool_calls -> 设为 None (符合 OpenAI 规范)
if content == "" and tool_calls:
msg["content"] = None

cleaned_messages.append(msg)

payloads["messages"] = cleaned_messages
self._sanitize_assistant_messages(payloads)

completion = await self.client.chat.completions.create(
**payloads,
Expand Down Expand Up @@ -619,6 +636,8 @@ async def _query_stream(
del payloads[key]
self._apply_provider_specific_extra_body_overrides(extra_body)

self._sanitize_assistant_messages(payloads)

stream = await self.client.chat.completions.create(
**payloads,
stream=True,
Expand Down
106 changes: 106 additions & 0 deletions tests/test_openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -1618,3 +1618,109 @@ async def fake_create(**kwargs):
assert messages[2] == {"role": "user", "content": "hello"}
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_query_stream_filters_empty_assistant_message(monkeypatch):
"""Regression for #7721: streaming path must also filter empty assistant messages.

Previously only ``_query`` sanitized the payload; ``_query_stream`` forwarded
the raw history and strict providers (e.g. DeepSeek Reasoner) returned 400 on
the next turn after a tool call whose assistant entry had reasoning only.
"""
provider = _make_provider()
try:
captured_kwargs = {}

async def fake_stream():
yield ChatCompletionChunk.model_validate(
{
"id": "chatcmpl-stream",
"object": "chat.completion.chunk",
"created": 0,
"model": "deepseek-reasoner",
"choices": [
{
"index": 0,
"delta": {"role": "assistant", "content": "ok"},
"finish_reason": "stop",
}
],
}
)

async def fake_create(**kwargs):
captured_kwargs.update(kwargs)
return fake_stream()

monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)

payloads = {
"model": "deepseek-reasoner",
"messages": [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": ""}, # should be filtered
{"role": "user", "content": "world"},
],
}

async for _ in provider._query_stream(payloads=payloads, tools=None):
pass

messages = captured_kwargs["messages"]
assert len(messages) == 2
assert messages[0] == {"role": "user", "content": "hello"}
assert messages[1] == {"role": "user", "content": "world"}
finally:
await provider.terminate()


@pytest.mark.asyncio
async def test_query_filters_empty_list_content_assistant_message(monkeypatch):
"""Empty-list content (``content == []``) must also be filtered, not just ``""`` / ``None``."""
provider = _make_provider()
try:
captured_kwargs = {}

async def fake_create(**kwargs):
captured_kwargs.update(kwargs)
return ChatCompletion.model_validate(
{
"id": "chatcmpl-test",
"object": "chat.completion",
"created": 0,
"model": "gpt-4o-mini",
"choices": [
{
"index": 0,
"message": {"role": "assistant", "content": "ok"},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 1,
"completion_tokens": 1,
"total_tokens": 2,
},
}
)

monkeypatch.setattr(provider.client.chat.completions, "create", fake_create)

payloads = {
"model": "gpt-4o-mini",
"messages": [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": []}, # should be filtered
{"role": "user", "content": "again"},
],
}

await provider._query(payloads=payloads, tools=None)

messages = captured_kwargs["messages"]
assert len(messages) == 2
assert messages[0] == {"role": "user", "content": "hi"}
assert messages[1] == {"role": "user", "content": "again"}
finally:
await provider.terminate()
Loading