Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions astrbot/core/agent/context/token_counter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing import Protocol, runtime_checkable

from ..message import Message, TextPart
from ..message import AudioURLPart, ImageURLPart, Message, TextPart, ThinkPart


@runtime_checkable
Expand All @@ -28,9 +28,19 @@ def count_tokens(
...


# Estimated token cost of an image/audio part, based on OpenAI vision pricing:
# low-res ~85 tokens, high-res ~170 per 512px tile, typically a few hundred
# up to a thousand or more.
# A conservative middle value is used here: it is better to overestimate and
# trigger compression than to underestimate and get an API error.
IMAGE_TOKEN_ESTIMATE = 765
AUDIO_TOKEN_ESTIMATE = 500


class EstimateTokenCounter:
"""Estimate token counter implementation.
Provides a simple estimation of token count based on character types.

Supports multimodal content: images, audio, and thinking parts
are all counted so that the context compressor can trigger in time.
"""

def count_tokens(
Expand All @@ -45,12 +55,16 @@ def count_tokens(
if isinstance(content, str):
total += self._estimate_tokens(content)
elif isinstance(content, list):
# 处理多模态内容
for part in content:
if isinstance(part, TextPart):
total += self._estimate_tokens(part.text)
elif isinstance(part, ThinkPart):
total += self._estimate_tokens(part.think)
elif isinstance(part, ImageURLPart):
total += IMAGE_TOKEN_ESTIMATE
elif isinstance(part, AudioURLPart):
total += AUDIO_TOKEN_ESTIMATE
Comment on lines 58 to +66
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

这部分使用 if/elif 链来处理不同类型的 ContentPart 是可行的,但当未来需要支持更多类型的内容时,需要不断修改这个函数。为了提高代码的可扩展性和遵循开闭原则,可以考虑使用更动态的分发机制,例如使用字典将内容类型映射到处理逻辑。

这样可以使 count_tokens 方法更稳定,添加新类型时只需更新映射字典,而无需修改核心循环逻辑。

例如,可以在类级别定义一个映射:

class EstimateTokenCounter:
    _PART_HANDLERS = {
        TextPart: lambda self, p: self._estimate_tokens(p.text),
        ThinkPart: lambda self, p: self._estimate_tokens(p.think),
        ImageURLPart: lambda self, p: IMAGE_TOKEN_ESTIMATE,
        AudioURLPart: lambda self, p: AUDIO_TOKEN_ESTIMATE,
    }

    def count_tokens(self, messages: list[Message], trusted_token_usage: int = 0) -> int:
        # ...
        elif isinstance(content, list):
            for part in content:
                # Note: using type() might not work with inheritance, isinstance is safer.
                # This is a conceptual example.
                handler = self._PART_HANDLERS.get(type(part))
                if handler:
                    total += handler(self, part)
        # ...

由于这只是一个改进建议,当前实现也是完全可以接受的。


# 处理 Tool Calls
if msg.tool_calls:
for tc in msg.tool_calls:
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
Expand Down
103 changes: 103 additions & 0 deletions tests/agent/test_token_counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
"""Tests for EstimateTokenCounter multimodal support."""

from astrbot.core.agent.context.token_counter import (
AUDIO_TOKEN_ESTIMATE,
IMAGE_TOKEN_ESTIMATE,
EstimateTokenCounter,
)
from astrbot.core.agent.message import (
AudioURLPart,
ImageURLPart,
Message,
TextPart,
ThinkPart,
)


# Module-level counter instance shared by the tests in this file.
counter = EstimateTokenCounter()


def _msg(role: str, content, **kwargs) -> Message:
    """Build a ``Message`` with the given role and content.

    Args:
        role: Value passed as the message's ``role``.
        content: Value passed as the message's ``content`` (the tests use
            either a plain string or a list of content parts).
        **kwargs: Extra keyword arguments (for example ``tool_calls``)
            forwarded unchanged to the ``Message`` constructor, so tests
            that need additional fields can still use this helper.

    Returns:
        The constructed ``Message``.
    """
    return Message(role=role, content=content, **kwargs)


class TestTextCounting:
    """Token estimation for text-only content."""

    def test_plain_string(self):
        """A plain string message produces a positive token count."""
        messages = [_msg("user", "hello world")]
        assert counter.count_tokens(messages) > 0

    def test_chinese(self):
        """Chinese characters are weighted higher than ASCII letters."""
        english_count = counter.count_tokens([_msg("user", "abc")])
        chinese_count = counter.count_tokens([_msg("user", "你好啊")])
        assert chinese_count > english_count

    def test_text_part(self):
        """A single ``TextPart`` in a content list is counted."""
        text_message = _msg("user", [TextPart(text="hello")])
        token_count = counter.count_tokens([text_message])
        assert token_count > 0


class TestMultimodalCounting:
    """Token estimation for image, audio, thinking and mixed content parts."""

    def test_image_counted(self):
        """A lone image part costs exactly the image estimate."""
        image = ImageURLPart(
            image_url=ImageURLPart.ImageURL(url="data:image/png;base64,abc"),
        )
        image_message = _msg("user", [image])
        assert counter.count_tokens([image_message]) == IMAGE_TOKEN_ESTIMATE

    def test_audio_counted(self):
        """A lone audio part costs exactly the audio estimate."""
        audio = AudioURLPart(
            audio_url=AudioURLPart.AudioURL(url="https://x.com/a.mp3"),
        )
        audio_message = _msg("user", [audio])
        assert counter.count_tokens([audio_message]) == AUDIO_TOKEN_ESTIMATE

    def test_think_counted(self):
        """A thinking part contributes a positive token count."""
        thinking = ThinkPart(think="let me think about this")
        think_message = _msg("assistant", [thinking])
        assert counter.count_tokens([think_message]) > 0

    def test_mixed_content(self):
        """Text + image message: tokens = text tokens + image estimate."""
        text_message = _msg("user", [TextPart(text="describe this image")])
        image = ImageURLPart(
            image_url=ImageURLPart.ImageURL(url="data:image/png;base64,x"),
        )
        mixed_message = _msg(
            "user",
            [TextPart(text="describe this image"), image],
        )
        text_count = counter.count_tokens([text_message])
        mixed_count = counter.count_tokens([mixed_message])
        assert mixed_count == text_count + IMAGE_TOKEN_ESTIMATE

    def test_multiple_images(self):
        """Each image in a message is counted separately."""
        first = ImageURLPart(
            image_url=ImageURLPart.ImageURL(url="data:image/png;base64,a"),
        )
        second = ImageURLPart(
            image_url=ImageURLPart.ImageURL(url="data:image/png;base64,b"),
        )
        third = ImageURLPart(
            image_url=ImageURLPart.ImageURL(url="data:image/png;base64,c"),
        )
        gallery_message = _msg("user", [first, second, third])
        assert counter.count_tokens([gallery_message]) == IMAGE_TOKEN_ESTIMATE * 3


class TestTrustedUsage:
    """Behaviour when a trusted token usage value is supplied."""

    def test_trusted_overrides(self):
        """If the API returned a token count, use it instead of estimating."""
        image = ImageURLPart(
            image_url=ImageURLPart.ImageURL(url="data:image/png;base64,x"),
        )
        multimodal_message = _msg("user", [TextPart(text="hello"), image])
        reported = counter.count_tokens(
            [multimodal_message],
            trusted_token_usage=42,
        )
        assert reported == 42
Comment on lines +82 to +90
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding a trusted_token_usage test that includes tool_calls to verify the override in a more complex message

Since trusted_token_usage also affects tool_calls, please add a companion test where the message includes both multimodal parts and non-empty tool_calls, and assert count_tokens(..., trusted_token_usage=42) == 42. This will verify the override still applies when tool calls are present and guard against regressions that reintroduce tool_call token summation on top of the trusted value.

Suggested implementation:

class TestTrustedUsage:
    def test_trusted_overrides(self):
        """If the API returned a token count, use it directly instead of estimating."""
        msg = _msg("user", [
            TextPart(text="hello"),
            ImageURLPart(image_url=ImageURLPart.ImageURL(url="data:image/png;base64,x")),
        ])
        tokens = counter.count_tokens([msg], trusted_token_usage=42)
        assert tokens == 42

    def test_trusted_overrides_with_tool_calls(self):
        """If the API returned a token count, use that value directly even when tool_calls are present."""
        msg = _msg(
            "assistant",
            [
                TextPart(text="调用工具中……"),
                ImageURLPart(
                    image_url=ImageURLPart.ImageURL(
                        url="data:image/png;base64,y"
                    )
                ),
            ],
            tool_calls=[
                {
                    "id": "call_1",
                    "type": "function",
                    "function": {
                        "name": "dummy_tool",
                        "arguments": '{"foo": "bar"}',
                    },
                },
            ],
        )
        tokens = counter.count_tokens([msg], trusted_token_usage=42)
        assert tokens == 42
  1. This change assumes the _msg helper accepts extra keyword arguments (like tool_calls=...) and forwards them into the Message constructor. If that is not the case, update _msg accordingly, or construct Message directly in this test with the correct tool_calls field.
  2. The exact structure of tool_calls (dict shape or specific classes) should match what astrbot.core.agent.message.Message expects elsewhere in your codebase. If you already have a canonical example of tool_calls in other tests, mirror that structure here for consistency.



class TestToolCalls:
    """Token estimation for messages that carry tool calls."""

    def test_tool_calls_counted(self):
        """A message with tool calls counts higher than its text alone."""
        weather_call = {
            "type": "function",
            "id": "1",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Beijing"}',
            },
        }
        call_message = Message(
            role="assistant",
            content="calling tool",
            tool_calls=[weather_call],
        )
        with_call_count = counter.count_tokens([call_message])
        # Both the text and the tool call JSON should be counted.
        baseline_count = counter.count_tokens([_msg("assistant", "calling tool")])
        assert with_call_count > baseline_count
Loading