diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 5b4ea7686a..ee9447c47d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1630,6 +1630,7 @@ class ChatProviderTemplate(TypedDict): "enable": True, "embedding_api_key": "", "embedding_api_base": "", + "embedding_api_base_mode": "auto", "embedding_model": "", "embedding_dimensions": 1024, "timeout": 20, @@ -1937,6 +1938,12 @@ class ChatProviderTemplate(TypedDict): "description": "API Base URL", "type": "string", }, + "embedding_api_base_mode": { + "description": "API Base URL Mode", + "type": "string", + "enum": ["auto", "force_v1", "as_is"], + "hint": "控制 API Base URL 的 /v1 后缀处理方式。auto: 自动判断(默认),如果 URL 不以 /v1 结尾则添加。force_v1: 强制添加 /v1 后缀。as_is: 保持原样(注意:只提供 base URL,不要包含 /embeddings 路径,SDK 会自动添加)。", + }, "volcengine_cluster": { "type": "string", "description": "火山引擎集群", diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 04397b1822..007332f034 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -6,6 +6,7 @@ from ..entities import ProviderType from ..provider import EmbeddingProvider from ..register import register_provider_adapter +from ..utils import resolve_openai_compatible_base_url @register_provider_adapter( @@ -23,13 +24,11 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: if proxy: logger.info(f"[OpenAI Embedding] 使用代理: {proxy}") http_client = httpx.AsyncClient(proxy=proxy) - api_base = provider_config.get("embedding_api_base", "").strip() - if not api_base: - api_base = "https://api.openai.com/v1" - else: - api_base = api_base.removesuffix("/") - if not api_base.endswith("/v1"): - api_base = f"{api_base}/v1" + api_base_mode = provider_config.get("embedding_api_base_mode", "auto") + api_base = resolve_openai_compatible_base_url( + 
provider_config.get("embedding_api_base", ""), + mode=api_base_mode, + ) self.client = AsyncOpenAI( api_key=provider_config.get("embedding_api_key"), base_url=api_base, diff --git a/astrbot/core/provider/utils.py b/astrbot/core/provider/utils.py new file mode 100644 index 0000000000..86a8d0a1d2 --- /dev/null +++ b/astrbot/core/provider/utils.py @@ -0,0 +1,38 @@ +from typing import Literal + + +def resolve_openai_compatible_base_url( + api_base: str, + mode: Literal["auto", "force_v1", "as_is"] = "auto", + default_base: str = "https://api.openai.com/v1", +) -> str: + """Resolve OpenAI-compatible API base URL with configurable /v1 suffix handling. + + Args: + api_base: The user-provided API base URL. + mode: How to handle the /v1 suffix: + - "auto": Add /v1 if not present (default). + - "force_v1": Always add /v1 suffix. + - "as_is": Keep the URL unchanged (including trailing slashes). + Note: With as_is mode, provide only the base URL without /embeddings path, + as the OpenAI SDK will append /embeddings automatically. + default_base: Default base URL to use if api_base is empty. + + Returns: + The resolved API base URL. 
+ """ + api_base = api_base.strip() + if not api_base: + return default_base + + if mode == "as_is": + # Return URL unchanged to preserve exact configuration + # Note: In this mode, users should provide base URL without /embeddings path + # as the OpenAI SDK will append /embeddings automatically + return api_base + + # Both "auto" and "force_v1" modes ensure URL ends with /v1 + api_base = api_base.removesuffix("/") + if not api_base.endswith("/v1"): + api_base = f"{api_base}/v1" + return api_base diff --git a/tests/test_openai_embedding_source.py b/tests/test_openai_embedding_source.py new file mode 100644 index 0000000000..2b71ad82cd --- /dev/null +++ b/tests/test_openai_embedding_source.py @@ -0,0 +1,169 @@ +import pytest + +from astrbot.core.provider.sources.openai_embedding_source import ( + OpenAIEmbeddingProvider, +) +from astrbot.core.provider.utils import resolve_openai_compatible_base_url + + +def _make_provider(overrides: dict | None = None) -> OpenAIEmbeddingProvider: + provider_config = { + "id": "test-openai-embedding", + "type": "openai_embedding", + "embedding_api_key": "test-key", + "embedding_model": "text-embedding-3-small", + } + if overrides: + provider_config.update(overrides) + return OpenAIEmbeddingProvider( + provider_config=provider_config, + provider_settings={}, + ) + + +class TestResolveOpenAICompatibleBaseUrl: + """Test the resolve_openai_compatible_base_url helper function.""" + + def test_auto_mode_adds_v1_when_missing(self): + """Test that auto mode adds /v1 suffix when not present.""" + result = resolve_openai_compatible_base_url( + "https://api.example.com", + mode="auto", + ) + assert result == "https://api.example.com/v1" + + def test_auto_mode_keeps_v1_when_present(self): + """Test that auto mode keeps /v1 suffix when already present.""" + result = resolve_openai_compatible_base_url( + "https://api.example.com/v1", + mode="auto", + ) + assert result == "https://api.example.com/v1" + + def test_force_v1_mode_always_adds_v1(self): 
class TestResolveOpenAICompatibleBaseUrl:
    """Unit tests for the resolve_openai_compatible_base_url helper."""

    def test_auto_mode_adds_v1_when_missing(self):
        """auto mode appends /v1 when the URL does not end with it."""
        result = resolve_openai_compatible_base_url(
            "https://api.example.com",
            mode="auto",
        )
        assert result == "https://api.example.com/v1"

    def test_auto_mode_keeps_v1_when_present(self):
        """auto mode leaves a URL that already ends with /v1 as it is."""
        result = resolve_openai_compatible_base_url(
            "https://api.example.com/v1",
            mode="auto",
        )
        assert result == "https://api.example.com/v1"

    def test_force_v1_mode_always_adds_v1(self):
        """force_v1 mode does not duplicate a /v1 suffix that is already there."""
        result = resolve_openai_compatible_base_url(
            "https://api.example.com/v1",
            mode="force_v1",
        )
        assert result == "https://api.example.com/v1"

    def test_force_v1_mode_adds_v1_when_missing(self):
        """force_v1 mode appends /v1 when the URL does not end with it."""
        result = resolve_openai_compatible_base_url(
            "https://api.example.com",
            mode="force_v1",
        )
        assert result == "https://api.example.com/v1"

    def test_as_is_mode_keeps_url_unchanged(self):
        """as_is mode returns a custom path without appending /v1."""
        result = resolve_openai_compatible_base_url(
            "https://api.example.com/custom/path",
            mode="as_is",
        )
        assert result == "https://api.example.com/custom/path"

    def test_as_is_mode_preserves_trailing_slash(self):
        """as_is mode keeps a trailing slash, since the URL is returned unchanged."""
        result = resolve_openai_compatible_base_url(
            "https://api.example.com/",
            mode="as_is",
        )
        # The helper returns the configured value verbatim in as_is mode,
        # so the trailing slash must survive.
        assert result == "https://api.example.com/"

    def test_empty_url_returns_default(self):
        """An empty URL falls back to the default OpenAI base URL."""
        result = resolve_openai_compatible_base_url(
            "",
            mode="auto",
        )
        assert result == "https://api.openai.com/v1"


@pytest.mark.asyncio
async def test_openai_embedding_provider_auto_mode():
    """The provider appends /v1 when auto mode is set explicitly."""
    provider = _make_provider(
        {
            "embedding_api_base": "https://api.example.com",
            "embedding_api_base_mode": "auto",
        }
    )
    try:
        # The client reports its base URL with a trailing slash.
        assert str(provider.client.base_url) == "https://api.example.com/v1/"
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_force_v1_mode():
    """The provider appends /v1 in force_v1 mode."""
    provider = _make_provider(
        {
            "embedding_api_base": "https://api.example.com",
            "embedding_api_base_mode": "force_v1",
        }
    )
    try:
        assert str(provider.client.base_url) == "https://api.example.com/v1/"
    finally:
        await provider.terminate()
@pytest.mark.asyncio
async def test_openai_embedding_provider_as_is_mode():
    """as_is mode passes a custom versioned base URL through without adding /v1."""
    provider = _make_provider(
        {
            # Base URL only: the configuration hint says the SDK appends
            # /embeddings itself, so the path must not be part of the base.
            "embedding_api_base": "https://api.example.com/v2",
            "embedding_api_base_mode": "as_is",
        }
    )
    try:
        # The client reports its base URL with a trailing slash.
        assert str(provider.client.base_url) == "https://api.example.com/v2/"
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_default_base_when_empty():
    """An empty base URL makes the provider use the default OpenAI endpoint."""
    provider = _make_provider({"embedding_api_base": ""})
    try:
        assert str(provider.client.base_url) == "https://api.openai.com/v1/"
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_with_v1_already_present():
    """auto mode does not add a second /v1 when the URL already ends with it."""
    provider = _make_provider(
        {
            "embedding_api_base": "https://api.example.com/v1",
            "embedding_api_base_mode": "auto",
        }
    )
    try:
        assert str(provider.client.base_url) == "https://api.example.com/v1/"
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_with_trailing_slash():
    """as_is mode keeps a URL that already ends with a slash unchanged."""
    provider = _make_provider(
        {
            "embedding_api_base": "https://api.example.com/",
            "embedding_api_base_mode": "as_is",
        }
    )
    try:
        # The provider returns the URL unchanged; the configured value already
        # ends with the slash the client reports.
        assert str(provider.client.base_url) == "https://api.example.com/"
    finally:
        await provider.terminate()


@pytest.mark.asyncio
async def test_openai_embedding_provider_auto_mode_default():
    """auto mode applies when embedding_api_base_mode is not configured."""
    provider = _make_provider({"embedding_api_base": "https://api.example.com"})
    try:
        # No mode key in the config, so the provider falls back to "auto".
        assert str(provider.client.base_url) == "https://api.example.com/v1/"
    finally:
        await provider.terminate()