-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
feat: context token counting support for multimodal content (images, audio, and chain-of-thought) #6596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: context token counting support for multimodal content (images, audio, and chain-of-thought) #6596
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| """Tests for EstimateTokenCounter multimodal support.""" | ||
|
|
||
| from astrbot.core.agent.context.token_counter import ( | ||
| AUDIO_TOKEN_ESTIMATE, | ||
| IMAGE_TOKEN_ESTIMATE, | ||
| EstimateTokenCounter, | ||
| ) | ||
| from astrbot.core.agent.message import ( | ||
| AudioURLPart, | ||
| ImageURLPart, | ||
| Message, | ||
| TextPart, | ||
| ThinkPart, | ||
| ) | ||
|
|
||
|
|
||
| counter = EstimateTokenCounter() | ||
|
|
||
|
|
||
| def _msg(role: str, content) -> Message: | ||
| return Message(role=role, content=content) | ||
|
|
||
|
|
||
| class TestTextCounting: | ||
| def test_plain_string(self): | ||
| tokens = counter.count_tokens([_msg("user", "hello world")]) | ||
| assert tokens > 0 | ||
|
|
||
| def test_chinese(self): | ||
| # 中文字符权重更高 | ||
| en = counter.count_tokens([_msg("user", "abc")]) | ||
| zh = counter.count_tokens([_msg("user", "你好啊")]) | ||
| assert zh > en | ||
|
|
||
| def test_text_part(self): | ||
| msg = _msg("user", [TextPart(text="hello")]) | ||
| assert counter.count_tokens([msg]) > 0 | ||
|
|
||
|
|
||
| class TestMultimodalCounting: | ||
| def test_image_counted(self): | ||
| msg = _msg("user", [ | ||
| ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,abc")), | ||
| ]) | ||
| tokens = counter.count_tokens([msg]) | ||
| assert tokens == IMAGE_TOKEN_ESTIMATE | ||
|
|
||
| def test_audio_counted(self): | ||
| msg = _msg("user", [ | ||
| AudioURLPart(audio_url=AudioURLPart.AudioURL(url="https://x.com/a.mp3")), | ||
| ]) | ||
| tokens = counter.count_tokens([msg]) | ||
| assert tokens == AUDIO_TOKEN_ESTIMATE | ||
|
|
||
| def test_think_counted(self): | ||
| msg = _msg("assistant", [ThinkPart(think="let me think about this")]) | ||
| tokens = counter.count_tokens([msg]) | ||
| assert tokens > 0 | ||
|
|
||
| def test_mixed_content(self): | ||
| """文本 + 图片的多模态消息,token 数 = 文本 token + 图片估算。""" | ||
| text_only = _msg("user", [TextPart(text="describe this image")]) | ||
| mixed = _msg("user", [ | ||
| TextPart(text="describe this image"), | ||
| ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,x")), | ||
| ]) | ||
| text_tokens = counter.count_tokens([text_only]) | ||
| mixed_tokens = counter.count_tokens([mixed]) | ||
| assert mixed_tokens == text_tokens + IMAGE_TOKEN_ESTIMATE | ||
|
|
||
| def test_multiple_images(self): | ||
| """多张图片应该各自计算。""" | ||
| msg = _msg("user", [ | ||
| ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,a")), | ||
| ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,b")), | ||
| ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,c")), | ||
| ]) | ||
| tokens = counter.count_tokens([msg]) | ||
| assert tokens == IMAGE_TOKEN_ESTIMATE * 3 | ||
|
|
||
|
|
||
| class TestTrustedUsage: | ||
| def test_trusted_overrides(self): | ||
| """如果 API 返回了 token 数,直接用它不做估算。""" | ||
| msg = _msg("user", [ | ||
| TextPart(text="hello"), | ||
| ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,x")), | ||
| ]) | ||
| tokens = counter.count_tokens([msg], trusted_token_usage=42) | ||
| assert tokens == 42 | ||
|
Comment on lines +82 to +90
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. suggestion (testing): Consider adding a `trusted_token_usage` test that includes `tool_calls` to verify the override in a more complex message. Since tool calls are otherwise included in the estimate, such a test confirms that the trusted value still takes precedence. Suggested implementation: class TestTrustedUsage:
def test_trusted_overrides(self):
"""如果 API 返回了 token 数,直接用它不做估算。"""
msg = _msg("user", [
TextPart(text="hello"),
ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,x")),
])
tokens = counter.count_tokens([msg], trusted_token_usage=42)
assert tokens == 42
def test_trusted_overrides_with_tool_calls(self):
"""如果 API 返回了 token 数,即使包含 tool_calls 也应直接使用该值。"""
msg = _msg(
"assistant",
[
TextPart(text="调用工具中……"),
ImageURLPart(
image_url=ImageURLPart.ImageURL(
url="data:image/png;base64,y"
)
),
],
tool_calls=[
{
"id": "call_1",
"type": "function",
"function": {
"name": "dummy_tool",
"arguments": '{"foo": "bar"}',
},
},
],
)
tokens = counter.count_tokens([msg], trusted_token_usage=42)
assert tokens == 42
|
||
|
|
||
|
|
||
| class TestToolCalls: | ||
| def test_tool_calls_counted(self): | ||
| msg = Message( | ||
| role="assistant", | ||
| content="calling tool", | ||
| tool_calls=[{"type": "function", "id": "1", "function": {"name": "get_weather", "arguments": '{"city": "Beijing"}'}}], | ||
| ) | ||
| tokens = counter.count_tokens([msg]) | ||
| # 文本 + tool call JSON 都应被计算 | ||
| text_only = counter.count_tokens([_msg("assistant", "calling tool")]) | ||
| assert tokens > text_only | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分使用 `if/elif` 链来处理不同类型的 `ContentPart` 是可行的,但当未来需要支持更多类型的内容时,需要不断修改这个函数。为了提高代码的可扩展性并遵循开闭原则,可以考虑使用更动态的分发机制,例如使用字典将内容类型映射到处理逻辑。这样可以使 `count_tokens` 方法更稳定,添加新类型时只需更新映射字典,而无需修改核心循环逻辑。例如,可以在类级别定义一个映射:
由于这只是一个改进建议,当前实现也是完全可以接受的。