diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 8bf92ef4d5..04397b1822 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -40,25 +40,46 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: async def get_embedding(self, text: str) -> list[float]: """获取文本的嵌入""" + kwargs = self._embedding_kwargs() embedding = await self.client.embeddings.create( input=text, model=self.model, - dimensions=self.get_dim(), + **kwargs, ) return embedding.data[0].embedding async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" + kwargs = self._embedding_kwargs() embeddings = await self.client.embeddings.create( input=text, model=self.model, - dimensions=self.get_dim(), + **kwargs, ) return [item.embedding for item in embeddings.data] + def _embedding_kwargs(self) -> dict: + """构建嵌入请求的可选参数""" + kwargs = {} + if "embedding_dimensions" in self.provider_config: + try: + kwargs["dimensions"] = int(self.provider_config["embedding_dimensions"]) + except (ValueError, TypeError): + logger.warning( + f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." + ) + return kwargs + def get_dim(self) -> int: """获取向量的维度""" - return int(self.provider_config.get("embedding_dimensions", 1024)) + if "embedding_dimensions" in self.provider_config: + try: + return int(self.provider_config["embedding_dimensions"]) + except (ValueError, TypeError): + logger.warning( + f"embedding_dimensions in embedding configs is not a valid integer: '{self.provider_config['embedding_dimensions']}', ignored." + ) + return 0 async def terminate(self): if self.client: