From 092303d94b6f7c955304d7dcbaa4ca151a200e70 Mon Sep 17 00:00:00 2001
From: Yufeng He <40085740+universeplayer@users.noreply.github.com>
Date: Thu, 19 Mar 2026 10:54:58 +0800
Subject: [PATCH] =?UTF-8?q?feat:=20=E4=B8=8A=E4=B8=8B=E6=96=87=20token=20?=
 =?UTF-8?q?=E8=AE=A1=E6=95=B0=E6=94=AF=E6=8C=81=E5=A4=9A=E6=A8=A1=E6=80=81?=
 =?UTF-8?q?=E5=86=85=E5=AE=B9=EF=BC=88=E5=9B=BE=E7=89=87/=E9=9F=B3?=
 =?UTF-8?q?=E9=A2=91/=E6=80=9D=E8=80=83=E9=93=BE=EF=BC=89?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

EstimateTokenCounter 之前只计算 TextPart,完全忽略 ImageURLPart、
AudioURLPart 和 ThinkPart。多模态对话中图片占 500-2000 token,
不被计入会导致 context 压缩触发过晚,API 先报 context_length_exceeded。

改动:
- ImageURLPart 按 765 token 估算(OpenAI vision 低/高分辨率中位数)
- AudioURLPart 按 500 token 估算
- ThinkPart 的文本内容正常计算
- 10 个新测试覆盖各类型单独和混合场景
---
 astrbot/core/agent/context/token_counter.py |  20 +++-
 tests/agent/test_token_counter.py           | 103 ++++++++++++++++++++
 2 files changed, 120 insertions(+), 3 deletions(-)
 create mode 100644 tests/agent/test_token_counter.py

diff --git a/astrbot/core/agent/context/token_counter.py b/astrbot/core/agent/context/token_counter.py
index 1d4efbe8d5..7c60cb23ec 100644
--- a/astrbot/core/agent/context/token_counter.py
+++ b/astrbot/core/agent/context/token_counter.py
@@ -1,7 +1,7 @@
 import json
 from typing import Protocol, runtime_checkable
 
-from ..message import Message, TextPart
+from ..message import AudioURLPart, ImageURLPart, Message, TextPart, ThinkPart
 
 
 @runtime_checkable
@@ -28,9 +28,19 @@ def count_tokens(
         ...
 
 
+# Estimated token cost of image/audio parts, based on OpenAI vision pricing:
+# low-res ~85 tokens, high-res ~170 per 512px tile (hundreds to 1000+ total).
+# Deliberately conservative: compressing early beats an API length error.
+IMAGE_TOKEN_ESTIMATE = 765
+AUDIO_TOKEN_ESTIMATE = 500
+
+
 class EstimateTokenCounter:
     """Estimate token counter implementation.
     Provides a simple estimation of token count based on character types.
+
+    Supports multimodal content: images, audio, and thinking parts
+    are all counted so that the context compressor can trigger in time.
     """
 
     def count_tokens(
@@ -45,12 +55,16 @@ def count_tokens(
             if isinstance(content, str):
                 total += self._estimate_tokens(content)
             elif isinstance(content, list):
-                # 处理多模态内容
                 for part in content:
                     if isinstance(part, TextPart):
                         total += self._estimate_tokens(part.text)
+                    elif isinstance(part, ThinkPart):
+                        total += self._estimate_tokens(part.think)
+                    elif isinstance(part, ImageURLPart):
+                        total += IMAGE_TOKEN_ESTIMATE
+                    elif isinstance(part, AudioURLPart):
+                        total += AUDIO_TOKEN_ESTIMATE
 
-            # 处理 Tool Calls
             if msg.tool_calls:
                 for tc in msg.tool_calls:
                     tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
diff --git a/tests/agent/test_token_counter.py b/tests/agent/test_token_counter.py
new file mode 100644
index 0000000000..c68b056e66
--- /dev/null
+++ b/tests/agent/test_token_counter.py
@@ -0,0 +1,103 @@
+"""Tests for EstimateTokenCounter multimodal support."""
+
+from astrbot.core.agent.context.token_counter import (
+    AUDIO_TOKEN_ESTIMATE,
+    IMAGE_TOKEN_ESTIMATE,
+    EstimateTokenCounter,
+)
+from astrbot.core.agent.message import (
+    AudioURLPart,
+    ImageURLPart,
+    Message,
+    TextPart,
+    ThinkPart,
+)
+
+
+counter = EstimateTokenCounter()
+
+
+def _msg(role: str, content) -> Message:
+    return Message(role=role, content=content)
+
+
+class TestTextCounting:
+    def test_plain_string(self):
+        tokens = counter.count_tokens([_msg("user", "hello world")])
+        assert tokens > 0
+
+    def test_chinese(self):
+        # Chinese text is expected to yield more tokens than ASCII of equal length
+        en = counter.count_tokens([_msg("user", "abc")])
+        zh = counter.count_tokens([_msg("user", "你好啊")])
+        assert zh > en
+
+    def test_text_part(self):
+        msg = _msg("user", [TextPart(text="hello")])
+        assert counter.count_tokens([msg]) > 0
+
+
+class TestMultimodalCounting:
+    def test_image_counted(self):
+        msg = _msg("user", [
+            ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,abc")),
+        ])
+        tokens = counter.count_tokens([msg])
+        assert tokens == IMAGE_TOKEN_ESTIMATE
+
+    def test_audio_counted(self):
+        msg = _msg("user", [
+            AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://x.com/a.mp3")),
+        ])
+        tokens = counter.count_tokens([msg])
+        assert tokens == AUDIO_TOKEN_ESTIMATE
+
+    def test_think_counted(self):
+        msg = _msg("assistant", [ThinkPart(think="let me think about this")])
+        tokens = counter.count_tokens([msg])
+        assert tokens > 0
+
+    def test_mixed_content(self):
+        """Text + image message: token count = text tokens + image estimate."""
+        text_only = _msg("user", [TextPart(text="describe this image")])
+        mixed = _msg("user", [
+            TextPart(text="describe this image"),
+            ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,x")),
+        ])
+        text_tokens = counter.count_tokens([text_only])
+        mixed_tokens = counter.count_tokens([mixed])
+        assert mixed_tokens == text_tokens + IMAGE_TOKEN_ESTIMATE
+
+    def test_multiple_images(self):
+        """Each image in a message should be counted separately."""
+        msg = _msg("user", [
+            ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,a")),
+            ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,b")),
+            ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,c")),
+        ])
+        tokens = counter.count_tokens([msg])
+        assert tokens == IMAGE_TOKEN_ESTIMATE * 3
+
+
+class TestTrustedUsage:
+    def test_trusted_overrides(self):
+        """A token count reported by the API is used as-is, skipping estimation."""
+        msg = _msg("user", [
+            TextPart(text="hello"),
+            ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,x")),
+        ])
+        tokens = counter.count_tokens([msg], trusted_token_usage=42)
+        assert tokens == 42
+
+
+class TestToolCalls:
+    def test_tool_calls_counted(self):
+        msg = Message(
+            role="assistant",
+            content="calling tool",
+            tool_calls=[{"type": "function", "id": "1", "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'}}],
+        )
+        tokens = counter.count_tokens([msg])
+        # Both the text and the tool-call JSON should be counted
+        text_only = counter.count_tokens([_msg("assistant", "calling tool")])
+        assert tokens > text_only