From fa22ca83f1b94307f1309838daf1209bdc839d7b Mon Sep 17 00:00:00 2001 From: ccsang Date: Wed, 18 Mar 2026 16:41:37 +0000 Subject: [PATCH 1/3] fix(embedding): remove forced /v1 suffix for embeddings Fixes #6558 Add embedding_api_base_mode config (auto|force_v1|as_is) to default.py, modify openai_embedding_source.py to use resolve_openai_compatible_base_url() helper with mode parameter, add 7 test cases in tests/test_openai_embedding_source.py. Co-Authored-By: Claude Sonnet 4.6 --- astrbot/core/config/default.py | 6 + .../sources/openai_embedding_source.py | 13 +- astrbot/core/provider/utils.py | 39 ++++ tests/test_openai_embedding_source.py | 169 ++++++++++++++++++ 4 files changed, 220 insertions(+), 7 deletions(-) create mode 100644 astrbot/core/provider/utils.py create mode 100644 tests/test_openai_embedding_source.py diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 5b4ea7686a..11b1f78664 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1630,6 +1630,7 @@ class ChatProviderTemplate(TypedDict): "enable": True, "embedding_api_key": "", "embedding_api_base": "", + "embedding_api_base_mode": "auto", "embedding_model": "", "embedding_dimensions": 1024, "timeout": 20, @@ -1937,6 +1938,11 @@ class ChatProviderTemplate(TypedDict): "description": "API Base URL", "type": "string", }, + "embedding_api_base_mode": { + "description": "API Base URL Mode", + "type": "string", + "hint": "控制 API Base URL 的 /v1 后缀处理方式。auto: 自动判断(默认),如果 URL 不以 /v1 结尾则添加。force_v1: 强制添加 /v1 后缀。as_is: 保持原样不做修改。", + }, "volcengine_cluster": { "type": "string", "description": "火山引擎集群", diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 04397b1822..007332f034 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -6,6 +6,7 @@ from ..entities import ProviderType from ..provider import EmbeddingProvider from ..register import register_provider_adapter +from ..utils import resolve_openai_compatible_base_url @register_provider_adapter( @@ -23,13 +24,11 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: if proxy: logger.info(f"[OpenAI Embedding] 使用代理: {proxy}") http_client = httpx.AsyncClient(proxy=proxy) - api_base = provider_config.get("embedding_api_base", "").strip() - if not api_base: - api_base = "https://api.openai.com/v1" - else: - api_base = api_base.removesuffix("/") - if not api_base.endswith("/v1"): - api_base = f"{api_base}/v1" + api_base_mode = provider_config.get("embedding_api_base_mode", "auto") + api_base = resolve_openai_compatible_base_url( + provider_config.get("embedding_api_base", ""), + mode=api_base_mode, + ) self.client = AsyncOpenAI( api_key=provider_config.get("embedding_api_key"), base_url=api_base, diff --git a/astrbot/core/provider/utils.py b/astrbot/core/provider/utils.py new file mode 100644 index 0000000000..2a6ea1d149 --- /dev/null +++ b/astrbot/core/provider/utils.py @@ -0,0 +1,39 @@ +from typing import Literal + + +def resolve_openai_compatible_base_url( + api_base: str, + mode: Literal["auto", "force_v1", "as_is"] = "auto", + default_base: str = "https://api.openai.com/v1", +) -> str: + """Resolve OpenAI-compatible API base URL with configurable /v1 suffix handling. + + Args: + api_base: The user-provided API base URL. + mode: How to handle the /v1 suffix: + - "auto": Add /v1 if not present (default). + - "force_v1": Always add /v1 suffix. + - "as_is": Keep the URL as-is without modification. + default_base: Default base URL to use if api_base is empty. + + Returns: + The resolved API base URL. + """ + api_base = api_base.strip() + if not api_base: + return default_base + + if mode == "as_is": + return api_base.removesuffix("/") + + if mode == "force_v1": + api_base = api_base.removesuffix("/") + if not api_base.endswith("/v1"): + api_base = f"{api_base}/v1" + return api_base + + # mode == "auto" + api_base = api_base.removesuffix("/") + if not api_base.endswith("/v1"): + api_base = f"{api_base}/v1" + return api_base diff --git a/tests/test_openai_embedding_source.py b/tests/test_openai_embedding_source.py new file mode 100644 index 0000000000..0c57502261 --- /dev/null +++ b/tests/test_openai_embedding_source.py @@ -0,0 +1,169 @@ +import pytest + +from astrbot.core.provider.sources.openai_embedding_source import ( + OpenAIEmbeddingProvider, +) +from astrbot.core.provider.utils import resolve_openai_compatible_base_url + + +def _make_provider(overrides: dict | None = None) -> OpenAIEmbeddingProvider: + provider_config = { + "id": "test-openai-embedding", + "type": "openai_embedding", + "embedding_api_key": "test-key", + "embedding_model": "text-embedding-3-small", + } + if overrides: + provider_config.update(overrides) + return OpenAIEmbeddingProvider( + provider_config=provider_config, + provider_settings={}, + ) + + +class TestResolveOpenAICompatibleBaseUrl: + """Test the resolve_openai_compatible_base_url helper function.""" + + def test_auto_mode_adds_v1_when_missing(self): + """Test that auto mode adds /v1 suffix when not present.""" + result = resolve_openai_compatible_base_url( + "https://api.example.com", + mode="auto", + ) + assert result == "https://api.example.com/v1" + + def test_auto_mode_keeps_v1_when_present(self): + """Test that auto mode keeps /v1 suffix when already present.""" + result = resolve_openai_compatible_base_url( + "https://api.example.com/v1", + mode="auto", + ) + assert result == "https://api.example.com/v1" + + def test_force_v1_mode_always_adds_v1(self): + """Test that force_v1 mode always adds /v1 suffix.""" + result = resolve_openai_compatible_base_url( + "https://api.example.com/v1", + mode="force_v1", + ) + assert result == "https://api.example.com/v1" + + def test_force_v1_mode_adds_v1_when_missing(self): + """Test that force_v1 mode adds /v1 suffix when missing.""" + result = resolve_openai_compatible_base_url( + "https://api.example.com", + mode="force_v1", + ) + assert result == "https://api.example.com/v1" + + def test_as_is_mode_keeps_url_unchanged(self): + """Test that as_is mode keeps URL unchanged.""" + result = resolve_openai_compatible_base_url( + "https://api.example.com/custom/path", + mode="as_is", + ) + assert result == "https://api.example.com/custom/path" + + def test_as_is_mode_removes_trailing_slash(self): + """Test that as_is mode removes trailing slash.""" + result = resolve_openai_compatible_base_url( + "https://api.example.com/", + mode="as_is", + ) + assert result == "https://api.example.com" + + def test_empty_url_returns_default(self): + """Test that empty URL returns the default base URL.""" + result = resolve_openai_compatible_base_url( + "", + mode="auto", + ) + assert result == "https://api.openai.com/v1" + + +@pytest.mark.asyncio +async def test_openai_embedding_provider_auto_mode(): + """Test OpenAI Embedding provider with auto mode (default).""" + provider = _make_provider( + {"embedding_api_base": "https://api.example.com", "embedding_api_base_mode": "auto"} + ) + try: + assert str(provider.client.base_url) == "https://api.example.com/v1/" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_embedding_provider_force_v1_mode(): + """Test OpenAI Embedding provider with force_v1 mode.""" + provider = _make_provider( + { + "embedding_api_base": "https://api.example.com", + "embedding_api_base_mode": "force_v1", + } + ) + try: + assert str(provider.client.base_url) == "https://api.example.com/v1/" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_embedding_provider_as_is_mode(): + """Test OpenAI Embedding provider with as_is mode.""" + provider = _make_provider( + { + "embedding_api_base": "https://api.example.com/v2/embeddings", + "embedding_api_base_mode": "as_is", + } + ) + try: + assert str(provider.client.base_url) == "https://api.example.com/v2/embeddings/" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_embedding_provider_default_base_when_empty(): + """Test OpenAI Embedding provider with empty base URL uses default.""" + provider = _make_provider({"embedding_api_base": ""}) + try: + assert str(provider.client.base_url) == "https://api.openai.com/v1/" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_embedding_provider_with_v1_already_present(): + """Test OpenAI Embedding provider when URL already has /v1.""" + provider = _make_provider( + {"embedding_api_base": "https://api.example.com/v1", "embedding_api_base_mode": "auto"} + ) + try: + assert str(provider.client.base_url) == "https://api.example.com/v1/" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_embedding_provider_with_trailing_slash(): + """Test OpenAI Embedding provider removes trailing slash in as_is mode.""" + provider = _make_provider( + {"embedding_api_base": "https://api.example.com/", "embedding_api_base_mode": "as_is"} + ) + try: + # The provider removes trailing slash, but OpenAI client adds it back + assert str(provider.client.base_url) == "https://api.example.com" + finally: + await provider.terminate() + + +@pytest.mark.asyncio +async def test_openai_embedding_provider_auto_mode_default(): + """Test OpenAI Embedding provider auto mode is default when not specified.""" + provider = _make_provider({"embedding_api_base": "https://api.example.com"}) + try: + # Should default to auto mode + assert str(provider.client.base_url) == "https://api.example.com/v1/" + finally: + await provider.terminate() From cad7d06903f51b9354b2221e66a7f927474e44fc Mon Sep 17 00:00:00 2001 From: ccsang Date: Thu, 19 Mar 2026 01:49:28 +0000 Subject: [PATCH 2/3] fix(embedding): address Sourcery review comments - Fix as_is mode to return URL unchanged (including trailing slashes) - Add enum constraint to embedding_api_base_mode config - Update docstring to clarify as_is behavior --- astrbot/core/config/default.py | 1 + astrbot/core/provider/utils.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 11b1f78664..7cab59254d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1941,6 +1941,7 @@ class ChatProviderTemplate(TypedDict): "embedding_api_base_mode": { "description": "API Base URL Mode", "type": "string", + "enum": ["auto", "force_v1", "as_is"], "hint": "控制 API Base URL 的 /v1 后缀处理方式。auto: 自动判断(默认),如果 URL 不以 /v1 结尾则添加。force_v1: 强制添加 /v1 后缀。as_is: 保持原样不做修改。", }, "volcengine_cluster": { diff --git a/astrbot/core/provider/utils.py b/astrbot/core/provider/utils.py index 2a6ea1d149..324b7cfed6 100644 --- a/astrbot/core/provider/utils.py +++ b/astrbot/core/provider/utils.py @@ -13,7 +13,7 @@ def resolve_openai_compatible_base_url( mode: How to handle the /v1 suffix: - "auto": Add /v1 if not present (default). - "force_v1": Always add /v1 suffix. - - "as_is": Keep the URL as-is without modification. + - "as_is": Keep the URL unchanged (including trailing slashes). default_base: Default base URL to use if api_base is empty. Returns: @@ -24,7 +24,8 @@ def resolve_openai_compatible_base_url( return default_base if mode == "as_is": - return api_base.removesuffix("/") + # Return URL unchanged to preserve exact configuration + return api_base if mode == "force_v1": api_base = api_base.removesuffix("/") From daf0fcc90620643cbc22d1ea6b100d3efead835f Mon Sep 17 00:00:00 2001 From: ccsang Date: Thu, 19 Mar 2026 02:34:52 +0000 Subject: [PATCH 3/3] refactor(embedding): address all review comments - Fix test assertion to expect trailing slash from OpenAI client - Simplify code by merging identical force_v1 and auto logic - Add documentation for as_is mode limitation (no /embeddings path) - Update config hint to clarify as_is mode usage --- astrbot/core/config/default.py | 2 +- astrbot/core/provider/utils.py | 12 +++++------- tests/test_openai_embedding_source.py | 6 +++--- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7cab59254d..ee9447c47d 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -1942,7 +1942,7 @@ class ChatProviderTemplate(TypedDict): "description": "API Base URL Mode", "type": "string", "enum": ["auto", "force_v1", "as_is"], - "hint": "控制 API Base URL 的 /v1 后缀处理方式。auto: 自动判断(默认),如果 URL 不以 /v1 结尾则添加。force_v1: 强制添加 /v1 后缀。as_is: 保持原样不做修改。", + "hint": "控制 API Base URL 的 /v1 后缀处理方式。auto: 自动判断(默认),如果 URL 不以 /v1 结尾则添加。force_v1: 强制添加 /v1 后缀。as_is: 保持原样(注意:只提供 base URL,不要包含 /embeddings 路径,SDK 会自动添加)。", }, "volcengine_cluster": { "type": "string", diff --git a/astrbot/core/provider/utils.py b/astrbot/core/provider/utils.py index 324b7cfed6..86a8d0a1d2 100644 --- a/astrbot/core/provider/utils.py +++ b/astrbot/core/provider/utils.py @@ -14,6 +14,8 @@ def resolve_openai_compatible_base_url( - "auto": Add /v1 if not present (default). - "force_v1": Always add /v1 suffix. - "as_is": Keep the URL unchanged (including trailing slashes). + Note: With as_is mode, provide only the base URL without /embeddings path, + as the OpenAI SDK will append /embeddings automatically. default_base: Default base URL to use if api_base is empty. Returns: @@ -25,15 +27,11 @@ def resolve_openai_compatible_base_url( if mode == "as_is": # Return URL unchanged to preserve exact configuration + # Note: In this mode, users should provide base URL without /embeddings path + # as the OpenAI SDK will append /embeddings automatically return api_base - if mode == "force_v1": - api_base = api_base.removesuffix("/") - if not api_base.endswith("/v1"): - api_base = f"{api_base}/v1" - return api_base - - # mode == "auto" + # Both "auto" and "force_v1" modes ensure URL ends with /v1 api_base = api_base.removesuffix("/") if not api_base.endswith("/v1"): api_base = f"{api_base}/v1" diff --git a/tests/test_openai_embedding_source.py b/tests/test_openai_embedding_source.py index 0c57502261..2b71ad82cd 100644 --- a/tests/test_openai_embedding_source.py +++ b/tests/test_openai_embedding_source.py @@ -147,13 +147,13 @@ async def test_openai_embedding_provider_with_v1_already_present(): @pytest.mark.asyncio async def test_openai_embedding_provider_with_trailing_slash(): - """Test OpenAI Embedding provider removes trailing slash in as_is mode.""" + """Test OpenAI Embedding provider keeps URL unchanged in as_is mode.""" provider = _make_provider( {"embedding_api_base": "https://api.example.com/", "embedding_api_base_mode": "as_is"} ) try: - # The provider removes trailing slash, but OpenAI client adds it back - assert str(provider.client.base_url) == "https://api.example.com" + # The provider returns URL unchanged, OpenAI client adds trailing slash + assert str(provider.client.base_url) == "https://api.example.com/" finally: await provider.terminate()