Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def lang_to_cmd(lang: str) -> str:
elif shutil.which("powershell") is not None:
return "powershell"
else:
raise ValueError(f"Powershell or pwsh is not installed. Please install one of them.")
raise ValueError("Powershell or pwsh is not installed. Please install one of them.")
else:
raise ValueError(f"Unsupported language: {lang}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,86 @@ def create_vector_search(
finally:
_allow_private_constructor.reset(token)

async def _get_embedding(self, query: str) -> List[float]:
"""Generate embedding vector for the query text.

This method handles generating embeddings for vector search functionality.
The embedding provider and model should be specified in the tool configuration.

Args:
query (str): The text to generate embeddings for.

Returns:
List[float]: The embedding vector as a list of floats.

Raises:
ValueError: If the embedding configuration is missing or invalid.
"""
embedding_provider = getattr(self.search_config, "embedding_provider", None)
embedding_model = getattr(self.search_config, "embedding_model", None)

if not embedding_provider or not embedding_model:
raise ValueError(
"To use vector search, you must provide embedding_provider and embedding_model in the configuration."
) from None

if embedding_provider.lower() == "azure_openai":
try:
from azure.identity import DefaultAzureCredential
from openai import AsyncAzureOpenAI
except ImportError:
raise ImportError(
"Azure OpenAI SDK is required for embedding generation. "
"Please install it with: uv add openai azure-identity"
) from None

api_key = None
if hasattr(self.search_config, "openai_api_key"):
api_key = self.search_config.openai_api_key

api_version = getattr(self.search_config, "openai_api_version", "2023-05-15")
endpoint = getattr(self.search_config, "openai_endpoint", None)

if not endpoint:
raise ValueError("OpenAI endpoint must be provided for Azure OpenAI embeddings") from None

if api_key:
azure_client = AsyncAzureOpenAI(api_key=api_key, api_version=api_version, azure_endpoint=endpoint)
else:

def get_token() -> str:
credential = DefaultAzureCredential()
return credential.get_token("https://cognitiveservices.azure.com/.default").token

azure_client = AsyncAzureOpenAI(
azure_ad_token_provider=get_token, api_version=api_version, azure_endpoint=endpoint
)

response = await azure_client.embeddings.create(model=embedding_model, input=query)
return response.data[0].embedding

elif embedding_provider.lower() == "openai":
try:
from openai import AsyncOpenAI
except ImportError:
raise ImportError(
"OpenAI SDK is required for embedding generation. " "Please install it with: uv add openai"
) from None

api_key = None
if hasattr(self.search_config, "openai_api_key"):
api_key = self.search_config.openai_api_key

openai_client = AsyncOpenAI(api_key=api_key)

response = await openai_client.embeddings.create(model=embedding_model, input=query)
return response.data[0].embedding
else:
raise ValueError(
f"Unsupported embedding provider: {embedding_provider}. "
"Currently supported providers are 'azure_openai' and 'openai'."
) from None

@classmethod
def create_hybrid_search(
cls,
Expand Down
Loading
Loading