From 7cfb5f7c6040dd971206aa2d81e165c3c5ac190e Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sat, 14 Mar 2026 23:03:06 -0400 Subject: [PATCH 1/9] Add OCI LLM and embedding support to crewai --- docs/en/concepts/llms.mdx | 221 +++ lib/crewai/pyproject.toml | 3 + lib/crewai/src/crewai/llm.py | 13 + lib/crewai/src/crewai/llms/base_llm.py | 43 +- .../src/crewai/llms/providers/oci/__init__.py | 20 + .../crewai/llms/providers/oci/completion.py | 1456 +++++++++++++++++ .../src/crewai/llms/providers/oci/vision.py | 57 + .../src/crewai/rag/embeddings/factory.py | 9 + .../rag/embeddings/providers/oci/__init__.py | 18 + .../providers/oci/embedding_callable.py | 255 +++ .../embeddings/providers/oci/oci_provider.py | 88 + .../rag/embeddings/providers/oci/types.py | 30 + lib/crewai/src/crewai/rag/embeddings/types.py | 3 + lib/crewai/tests/llms/oci/conftest.py | 447 +++++ .../tests/llms/oci/profile_oci_agent.py | 160 ++ lib/crewai/tests/llms/oci/test_oci.py | 601 +++++++ .../oci/test_oci_integration_async_batch.py | 55 + .../llms/oci/test_oci_integration_basic.py | 41 + .../oci/test_oci_integration_multimodal.py | 165 ++ .../oci/test_oci_integration_streaming.py | 21 + .../oci/test_oci_integration_structured.py | 31 + .../llms/oci/test_oci_integration_tools.py | 108 ++ .../tests/llms/oci/test_oci_sdk_surface.py | 155 ++ lib/crewai/tests/rag/embeddings/conftest.py | 73 + .../tests/rag/embeddings/test_factory_oci.py | 211 +++ .../test_oci_embedding_integration.py | 23 + .../test_oci_image_embedding_integration.py | 45 + 27 files changed, 4351 insertions(+), 1 deletion(-) create mode 100644 lib/crewai/src/crewai/llms/providers/oci/__init__.py create mode 100644 lib/crewai/src/crewai/llms/providers/oci/completion.py create mode 100644 lib/crewai/src/crewai/llms/providers/oci/vision.py create mode 100644 lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py create mode 100644 lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py create mode 100644 
lib/crewai/src/crewai/rag/embeddings/providers/oci/oci_provider.py create mode 100644 lib/crewai/src/crewai/rag/embeddings/providers/oci/types.py create mode 100644 lib/crewai/tests/llms/oci/conftest.py create mode 100644 lib/crewai/tests/llms/oci/profile_oci_agent.py create mode 100644 lib/crewai/tests/llms/oci/test_oci.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_async_batch.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_basic.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_multimodal.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_streaming.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_structured.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_tools.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_sdk_surface.py create mode 100644 lib/crewai/tests/rag/embeddings/conftest.py create mode 100644 lib/crewai/tests/rag/embeddings/test_factory_oci.py create mode 100644 lib/crewai/tests/rag/embeddings/test_oci_embedding_integration.py create mode 100644 lib/crewai/tests/rag/embeddings/test_oci_image_embedding_integration.py diff --git a/docs/en/concepts/llms.mdx b/docs/en/concepts/llms.mdx index 98bfbeb2349..40efc86605a 100644 --- a/docs/en/concepts/llms.mdx +++ b/docs/en/concepts/llms.mdx @@ -772,6 +772,227 @@ In this section, you'll find detailed examples that help you select, configure, ``` + + CrewAI provides native integration with OCI Generative AI for generic chat models, OpenAI-hosted OCI models, Google-hosted OCI models, Meta-hosted OCI models, and dedicated inference endpoints. 
+ + **Recommended Cohere Models (verified in OCI on March 10, 2026):** + - `cohere.command-a-reasoning` for reasoning-heavy text workflows + - `cohere.command-a-03-2025` for general text generation and streaming + - `cohere.command-a-vision` for the top Cohere multimodal tier once OCI Cohere vision formatting is enabled in CrewAI + + **Recommended Regions for Cohere in OCI:** + - `eu-frankfurt-1` + - `us-ashburn-1` + - `eu-paris-1` + - `uk-london-1` + - `ap-mumbai-1` + + ```toml Code + # Required + OCI_COMPARTMENT_ID=ocid1.compartment.oc1..exampleuniqueID + + # Optional when not passing service_endpoint directly + OCI_REGION=eu-frankfurt-1 + + # Authentication options + OCI_AUTH_TYPE=API_KEY + OCI_AUTH_PROFILE=DEFAULT + OCI_AUTH_FILE_LOCATION=~/.oci/config + + # Optional explicit endpoint override + OCI_SERVICE_ENDPOINT=https://inference.generativeai.eu-frankfurt-1.oci.oraclecloud.com + ``` + + **Basic Usage:** + ```python Code + from crewai import LLM + + llm = LLM( + model="oci/cohere.command-a-reasoning", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + auth_type="API_KEY", + auth_profile="DEFAULT", + temperature=0, + max_tokens=512, + ) + ``` + + **Provider Routing Examples:** + ```python Code + from crewai import LLM + + meta_llm = LLM( + model="oci/meta.llama-3.3-70b-instruct", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + ) + + gemini_llm = LLM( + model="oci/google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + ) + + openai_llm = LLM( + model="oci/openai.gpt-4o-mini", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + ) + ``` + + **Async Usage:** + ```python Code + import asyncio + from crewai import LLM + + llm = LLM( + model="oci/cohere.command-a-03-2025", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + ) + + async def main(): + result = await llm.acall("Summarize Oracle Cloud in one sentence.") + print(result) + + asyncio.run(main()) + ``` + + **Streaming Usage:** + 
```python Code + from crewai import LLM + + llm = LLM( + model="oci/cohere.command-a-03-2025", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + stream=True, + temperature=0, + ) + + result = llm.call("Reply with exactly three words about Oracle Cloud.") + print(result) + ``` + + **Multimodal Usage:** + ```python Code + from crewai import LLM + from crewai_files import ImageFile + + llm = LLM( + model="oci/google.gemini-2.5-flash", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + ) + + response = llm.call( + [ + { + "role": "user", + "content": "Describe this architecture diagram.", + "files": { + "diagram": ImageFile(source="./architecture.png"), + }, + } + ] + ) + + print(response) + ``` + + **Structured Outputs:** + ```python Code + from pydantic import BaseModel + from crewai import LLM + + class OCIAnswer(BaseModel): + topic: str + summary: str + + llm = LLM( + model="oci/cohere.command-a-reasoning", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + temperature=0, + ) + + response = llm.call( + "Return a short JSON summary about Oracle Cloud.", + response_model=OCIAnswer, + ) + print(response.summary) + ``` + + **Agent + Tool Usage:** + ```python Code + from crewai import Agent, LLM + from crewai.tools import tool + + @tool + def add_numbers(a: int, b: int) -> int: + """Add two numbers and return the sum.""" + return a + b + + agent = Agent( + role="OCI Calculator", + goal="Use tools to solve arithmetic problems", + backstory="You are a precise calculator assistant.", + llm=LLM( + model="oci/cohere.command-a-03-2025", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + temperature=0, + ), + tools=[add_numbers], + verbose=True, + ) + + result = agent.kickoff( + "Use the add_numbers tool to calculate 15 + 27. Return only the final result." 
+ ) + print(result.raw) + ``` + + **Dedicated Endpoint Usage:** + ```python Code + from crewai import LLM + + llm = LLM( + model="oci/ocid1.generativeaiendpoint.oc1.eu-frankfurt-1.exampleuniqueID", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + ) + ``` + + **Supported Authentication Types:** + - `API_KEY` + - `SECURITY_TOKEN` + - `INSTANCE_PRINCIPAL` + - `RESOURCE_PRINCIPAL` + + **Features:** + - Native function calling support + - Structured outputs with JSON schema response formats + - Async `acall()` support + - Streaming responses for OCI chat models + - Multimodal generic chat inputs for text, images, documents, video, and audio + - Stop sequences support + - Dedicated endpoint routing + - Generic chat support for OCI-hosted OpenAI, Google, and Meta models + - Cohere-specific OCI chat formatting for text models + - OCI embeddings in CrewAI RAG via the `oci` embedding provider + - CrewAI-managed OCI retrieval via `/en/tools/cloud-storage/ociknowledgebasetool` + + **Current Limitations:** + - Cohere multimodal and vision-specific formatting is not implemented yet + + **`langchain-oci` Sample Coverage in CrewAI:** + - `01-getting-started`: native OCI auth, provider routing, and streaming examples above + - `02-vision-and-multimodal`: supported through CrewAI file attachments on OCI multimodal models + - `03-building-ai-agents`: maps directly to CrewAI `Agent` usage with an OCI-backed `LLM` + - `04-tool-calling-mastery`: supported through CrewAI tools and OCI native function calling + - `05-structured-output`: supported through `response_model` + - `07-async-for-production`: supported through `acall()` + - `09-provider-deep-dive`: covered by provider-prefixed model routing and dedicated endpoint support + - `10-embeddings`: covered by the OCI embedding provider for text and image embeddings, plus [`OCIKnowledgeBaseTool`](/en/tools/cloud-storage/ociknowledgebasetool) for CrewAI-managed retrieval + + **Install:** + ```bash + uv add "crewai[oci]" + ``` 
+ + ```toml Code AWS_ACCESS_KEY_ID= diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index b0d70f3880c..8c1704ab99f 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -87,6 +87,9 @@ litellm = [ bedrock = [ "boto3~=1.40.45", ] +oci = [ + "oci>=2.161.0", +] google-genai = [ "google-genai~=1.65.0", ] diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 8a4ac2edde5..563230d1034 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -325,6 +325,7 @@ def writable(self) -> bool: "gemini", "bedrock", "aws", + "oci", ] @@ -384,6 +385,7 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: "gemini": "gemini", "bedrock": "bedrock", "aws": "bedrock", + "oci": "oci", } canonical_provider = provider_mapping.get(prefix.lower()) @@ -483,6 +485,9 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"] ) + if provider == "oci": + return model_lower.startswith("ocid1.generativeaiendpoint") or "." 
in model + return False @classmethod @@ -514,6 +519,9 @@ def _validate_model_in_constants(cls, model: str, provider: str) -> bool: if provider == "bedrock" and model in BEDROCK_MODELS: return True + if provider == "oci": + return cls._matches_provider_pattern(model, provider) + if provider == "azure": # azure does not provide a list of available models, determine a better way to handle this return True @@ -582,6 +590,11 @@ def _get_native_provider(cls, provider: str) -> type | None: return BedrockCompletion + if provider == "oci": + from crewai.llms.providers.oci.completion import OCICompletion + + return OCICompletion + return None def __init__( diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py index 1ab7107066d..c0499476320 100644 --- a/lib/crewai/src/crewai/llms/base_llm.py +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -7,7 +7,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from collections.abc import Generator +import asyncio +from collections.abc import AsyncGenerator, Generator from contextlib import contextmanager import contextvars from datetime import datetime @@ -151,6 +152,7 @@ def __init__( "successful_requests": 0, "cached_prompt_tokens": 0, } + self.last_response_metadata: dict[str, Any] | None = None @property def provider(self) -> str: @@ -238,6 +240,45 @@ async def acall( """ raise NotImplementedError + async def abatch( + self, + messages_batch: list[str | list[LLMMessage]], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> list[str | Any]: + """Execute multiple async LLM calls concurrently.""" + return await asyncio.gather( + *[ + self.acall( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + 
from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + for messages in messages_batch + ] + ) + + async def astream( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> AsyncGenerator[str, None]: + """Optional async chunked streaming interface for native providers.""" + raise NotImplementedError + def _convert_tools_for_interference( self, tools: list[dict[str, BaseTool]] ) -> list[dict[str, BaseTool]]: diff --git a/lib/crewai/src/crewai/llms/providers/oci/__init__.py b/lib/crewai/src/crewai/llms/providers/oci/__init__.py new file mode 100644 index 00000000000..1d2c170969c --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/__init__.py @@ -0,0 +1,20 @@ +from crewai.llms.providers.oci.completion import OCICompletion +from crewai.llms.providers.oci.vision import ( + IMAGE_EMBEDDING_MODELS, + VISION_MODELS, + encode_image, + is_vision_model, + load_image, + to_data_uri, +) + + +__all__ = [ + "IMAGE_EMBEDDING_MODELS", + "VISION_MODELS", + "OCICompletion", + "encode_image", + "is_vision_model", + "load_image", + "to_data_uri", +] diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py new file mode 100644 index 00000000000..f563e931023 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -0,0 +1,1456 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +import json +import logging +import os +import re +import threading +from typing import TYPE_CHECKING, Any, Literal, cast +import uuid + +from pydantic import BaseModel + +from crewai.events.types.llm_events import LLMCallType +from crewai.llms.base_llm import BaseLLM, llm_call_context +from 
crewai.utilities.pydantic_schema_utils import generate_model_description +from crewai.utilities.types import LLMMessage + + +if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task + from crewai.tools.base_tool import BaseTool + + +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" +DEFAULT_OCI_REGION = "eu-frankfurt-1" +_OCI_SCHEMA_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") +_OCI_TOOL_RESULT_GUIDANCE = ( + "You have received tool results above. Respond to the user with a helpful, " + "natural language answer that incorporates the tool results. Do not output " + "raw JSON or tool call syntax. If you need additional information, you may " + "call another tool." +) + + +def _get_oci_module() -> Any: + try: + import oci # type: ignore[import-untyped] + except ImportError: + raise ImportError( + 'OCI native provider not available, to install: uv add "crewai[oci]"' + ) from None + return oci + + +def create_oci_client_kwargs( + *, + auth_type: str, + service_endpoint: str | None, + auth_file_location: str, + auth_profile: str, +) -> dict[str, Any]: + """Create authenticated OCI client kwargs.""" + oci = _get_oci_module() + client_kwargs: dict[str, Any] = { + "config": {}, + "service_endpoint": service_endpoint, + "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, + "timeout": (10, 240), + } + + auth_type_upper = auth_type.upper() + + if auth_type_upper == "API_KEY": + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + elif auth_type_upper == "SECURITY_TOKEN": + config = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + key_file = config["key_file"] + security_token_file = config["security_token_file"] + private_key = oci.signer.load_private_key_from_file(key_file, None) + with open(security_token_file, encoding="utf-8") as file: + security_token = file.read() + client_kwargs["config"] = config + client_kwargs["signer"] = 
oci.auth.signers.SecurityTokenSigner( + security_token, private_key + ) + elif auth_type_upper == "INSTANCE_PRINCIPAL": + client_kwargs["signer"] = ( + oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + ) + elif auth_type_upper == "RESOURCE_PRINCIPAL": + client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() + else: + valid_types = [ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + raise ValueError( + f"Invalid OCI auth_type '{auth_type}'. Valid values: {valid_types}" + ) + + return client_kwargs + + +class OCICompletion(BaseLLM): + """OCI Generative AI native provider for CrewAI.""" + + def __init__( + self, + model: str, + *, + compartment_id: str | None = None, + service_endpoint: str | None = None, + auth_type: Literal[ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + | str = "API_KEY", + auth_profile: str | None = None, + auth_file_location: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + top_p: float | None = None, + top_k: int | None = None, + stream: bool = False, + oci_provider: str | None = None, + max_sequential_tool_calls: int = 8, + client: Any | None = None, + **kwargs: Any, + ) -> None: + kwargs.pop("provider", None) + super().__init__( + model=model, + temperature=temperature, + provider="oci", + **kwargs, + ) + + self.compartment_id = compartment_id or os.getenv("OCI_COMPARTMENT_ID") + if not self.compartment_id: + raise ValueError( + "OCI compartment_id is required. Set compartment_id or OCI_COMPARTMENT_ID." 
+ ) + + self.service_endpoint = service_endpoint or os.getenv("OCI_SERVICE_ENDPOINT") + if self.service_endpoint is None: + region = os.getenv("OCI_REGION", DEFAULT_OCI_REGION) + self.service_endpoint = ( + f"https://inference.generativeai.{region}.oci.oraclecloud.com" + ) + + self.auth_type = str(auth_type).upper() + self.auth_profile = cast( + str, auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + ) + self.auth_file_location = cast( + str, + auth_file_location + or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + ) + self.max_tokens = max_tokens + self.top_p = top_p + self.top_k = top_k + self.stream = stream + self.oci_provider = oci_provider or self._infer_provider(model) + self.max_sequential_tool_calls = max_sequential_tool_calls + self._oci = _get_oci_module() + + if client is not None: + self.client = client + else: + client_kwargs = create_oci_client_kwargs( + auth_type=self.auth_type, + service_endpoint=self.service_endpoint, + auth_file_location=self.auth_file_location, + auth_profile=self.auth_profile, + ) + self.client = self._oci.generative_ai_inference.GenerativeAiInferenceClient( + **client_kwargs + ) + self._client_lock = threading.Lock() + self.last_response_metadata = None + + def _infer_provider(self, model: str) -> str: + if model.startswith(CUSTOM_ENDPOINT_PREFIX): + return "generic" + if model.startswith("cohere."): + return "cohere" + return "generic" + + def _is_openai_gpt5_family(self) -> bool: + return self.model.startswith("openai.gpt-5") + + def _build_serving_mode(self) -> Any: + models = self._oci.generative_ai_inference.models + if self.model.startswith(CUSTOM_ENDPOINT_PREFIX): + return models.DedicatedServingMode(endpoint_id=self.model) + return models.OnDemandServingMode(model_id=self.model) + + def _normalize_messages( + self, messages: str | list[LLMMessage] + ) -> list[LLMMessage]: + return self._format_messages(messages) + + def _coerce_text(self, content: Any) -> str: + if isinstance(content, str): + return 
content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, Mapping): + if item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif "text" in item: + parts.append(str(item["text"])) + return "\n".join(part for part in parts if part) + return str(content) + + def _message_has_multimodal_content(self, content: Any) -> bool: + if not isinstance(content, list): + return False + for item in content: + if isinstance(item, Mapping) and item.get("type") not in (None, "text"): + return True + return False + + def _build_generic_content(self, content: Any) -> list[Any]: + models = self._oci.generative_ai_inference.models + if isinstance(content, str): + return [models.TextContent(text=content or ".")] + + if not isinstance(content, list): + return [models.TextContent(text=self._coerce_text(content) or ".")] + + processed_content: list[Any] = [] + for item in content: + if isinstance(item, str): + processed_content.append(models.TextContent(text=item)) + continue + if not isinstance(item, Mapping): + raise ValueError( + f"OCI message content items must be strings or dictionaries, got: {type(item)}" + ) + + content_type = item.get("type") + if content_type == "text": + processed_content.append( + models.TextContent(text=str(item.get("text", "")) or ".") + ) + elif content_type == "image_url": + image_url = item.get("image_url", {}) + url = image_url.get("url") if isinstance(image_url, Mapping) else None + if not url: + raise ValueError("OCI image_url content requires image_url.url") + processed_content.append( + models.ImageContent(image_url=models.ImageUrl(url=url)) + ) + elif content_type in ("document_url", "document", "file"): + document_data = ( + item.get("document_url") or item.get("document") or item.get("file") + ) + url = ( + document_data.get("url") + if isinstance(document_data, Mapping) + else item.get("url") + ) + if not url: + raise 
ValueError("OCI document content requires a url") + processed_content.append( + models.DocumentContent(document_url=models.DocumentUrl(url=url)) + ) + elif content_type in ("video_url", "video"): + video_data = item.get("video_url") or item.get("video") + url = ( + video_data.get("url") + if isinstance(video_data, Mapping) + else item.get("url") + ) + if not url: + raise ValueError("OCI video content requires a url") + processed_content.append( + models.VideoContent(video_url=models.VideoUrl(url=url)) + ) + elif content_type in ("audio_url", "audio"): + audio_data = item.get("audio_url") or item.get("audio") + url = ( + audio_data.get("url") + if isinstance(audio_data, Mapping) + else item.get("url") + ) + if not url: + raise ValueError("OCI audio content requires a url") + processed_content.append( + models.AudioContent(audio_url=models.AudioUrl(url=url)) + ) + else: + raise ValueError(f"Unsupported OCI content type: {content_type}") + + return processed_content or [models.TextContent(text=".")] + + def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: + models = self._oci.generative_ai_inference.models + role_map = { + "user": models.UserMessage, + "assistant": models.AssistantMessage, + "system": models.SystemMessage, + } + oci_messages: list[Any] = [] + + for message in messages: + role = str(message.get("role", "user")).lower() + if role == "tool": + tool_kwargs: dict[str, Any] = { + "content": self._build_generic_content(message.get("content", "")), + } + if message.get("tool_call_id"): + tool_kwargs["tool_call_id"] = message["tool_call_id"] + oci_messages.append(models.ToolMessage(**tool_kwargs)) + continue + + message_cls = role_map.get(role) + if message_cls is None: + logging.debug("Skipping unsupported OCI message role: %s", role) + continue + + message_kwargs: dict[str, Any] = { + "content": self._build_generic_content(message.get("content", "")), + } + if role == "assistant" and message.get("tool_calls"): + 
message_kwargs["tool_calls"] = [ + models.FunctionCall( + id=tool_call.get("id"), + name=tool_call.get("function", {}).get("name"), + arguments=tool_call.get("function", {}).get("arguments", "{}"), + ) + for tool_call in message.get("tool_calls", []) + if tool_call.get("function", {}).get("name") + ] + if not message_kwargs["content"]: + message_kwargs["content"] = [models.TextContent(text=".")] + + oci_messages.append(message_cls(**message_kwargs)) + + if ( + self._tool_result_guidance_enabled() + and any(str(message.get("role", "")).lower() == "tool" for message in messages) + ): + oci_messages.append( + models.SystemMessage( + content=[models.TextContent(text=_OCI_TOOL_RESULT_GUIDANCE)] + ) + ) + + return oci_messages + + def _build_cohere_chat_history(self, messages: list[LLMMessage]) -> tuple[list[Any], list[Any] | None, str]: + models = self._oci.generative_ai_inference.models + chat_history: list[Any] = [] + + for message in messages[:-1]: + role = str(message.get("role", "user")).lower() + content = message.get("content", "") + if self._message_has_multimodal_content(content): + raise ValueError( + "OCI Cohere models currently support text-only messages in CrewAI." 
+ ) + + if role in ("user", "system"): + message_cls = ( + models.CohereUserMessage + if role == "user" + else models.CohereSystemMessage + ) + chat_history.append(message_cls(message=self._coerce_text(content))) + elif role == "assistant": + tool_calls = None + if message.get("tool_calls"): + tool_calls = [] + for tool_call in message.get("tool_calls", []): + function_info = tool_call.get("function", {}) + function_name = function_info.get("name") + if not function_name: + continue + raw_arguments = function_info.get("arguments", "{}") + if isinstance(raw_arguments, str): + try: + parameters = json.loads(raw_arguments) + except json.JSONDecodeError: + parameters = {} + elif isinstance(raw_arguments, Mapping): + parameters = dict(raw_arguments) + else: + parameters = {} + tool_calls.append( + models.CohereToolCall( + name=function_name, + parameters=parameters, + ) + ) + chat_history.append( + models.CohereChatBotMessage( + message=self._coerce_text(content) or " ", + tool_calls=tool_calls, + ) + ) + elif role == "tool": + tool_result_kwargs: dict[str, Any] = { + "outputs": [{"output": self._coerce_text(content)}] + } + tool_name = message.get("name") or "tool" + tool_result_kwargs["call"] = models.CohereToolCall( + name=tool_name, + parameters={}, + ) + chat_history.append( + models.CohereToolMessage( + tool_results=[models.CohereToolResult(**tool_result_kwargs)] + ) + ) + + last_message = messages[-1] if messages else {"role": "user", "content": ""} + tool_results: list[Any] = [] + if str(last_message.get("role", "user")).lower() == "tool": + previous_tool_calls: dict[str, dict[str, Any]] = {} + for message in messages: + if str(message.get("role", "")).lower() != "assistant": + continue + for tool_call in message.get("tool_calls", []): + tool_call_id = tool_call.get("id") + if not tool_call_id: + continue + function_info = tool_call.get("function", {}) + raw_arguments = function_info.get("arguments", "{}") + if isinstance(raw_arguments, str): + try: + parameters 
= json.loads(raw_arguments) + except json.JSONDecodeError: + parameters = {} + elif isinstance(raw_arguments, Mapping): + parameters = dict(raw_arguments) + else: + parameters = {} + previous_tool_calls[tool_call_id] = { + "name": function_info.get("name", "tool"), + "parameters": parameters, + } + + for message in messages: + if str(message.get("role", "")).lower() != "tool": + continue + tool_call_id = message.get("tool_call_id") + if not isinstance(tool_call_id, str): + continue + previous_call = previous_tool_calls.get(tool_call_id, {}) + tool_results.append( + models.CohereToolResult( + call=models.CohereToolCall( + name=previous_call.get("name", message.get("name", "tool")), + parameters=previous_call.get("parameters", {}), + ), + outputs=[ + { + "output": self._coerce_text( + message.get("content", "") + ) + } + ], + ) + ) + + message_text = self._coerce_text(last_message.get("content", "")) + if tool_results: + message_text = "" + + return chat_history, tool_results or None, message_text + + def _format_tools(self, tools: list[dict[str, Any]] | None) -> list[Any]: + if not tools: + return [] + + models = self._oci.generative_ai_inference.models + formatted_tools: list[Any] = [] + for tool in tools: + if not isinstance(tool, Mapping): + continue + function_spec = tool.get("function", {}) + if not isinstance(function_spec, Mapping): + continue + name = function_spec.get("name") + if not name: + continue + + parameters = function_spec.get("parameters", {}) + if not isinstance(parameters, Mapping): + parameters = {} + + if self.oci_provider == "cohere": + parameter_definitions = {} + required = set(parameters.get("required", [])) + for param_name, param_schema in parameters.get("properties", {}).items(): + if not isinstance(param_schema, Mapping): + continue + parameter_definitions[param_name] = models.CohereParameterDefinition( + description=param_schema.get("description", ""), + type=param_schema.get("type", "object"), + is_required=param_name in required, + 
) + formatted_tools.append( + models.CohereTool( + name=name, + description=function_spec.get("description", name), + parameter_definitions=parameter_definitions, + ) + ) + else: + formatted_tools.append( + models.FunctionDefinition( + name=name, + description=function_spec.get("description", name), + parameters={ + "type": parameters.get("type", "object"), + "properties": parameters.get("properties", {}), + "required": parameters.get("required", []), + }, + ) + ) + return formatted_tools + + def _build_response_format( + self, response_model: type[BaseModel] | None + ) -> Any | None: + if response_model is None: + return None + + models = self._oci.generative_ai_inference.models + schema_description = generate_model_description(response_model)["json_schema"] + schema_name = _OCI_SCHEMA_NAME_PATTERN.sub("_", schema_description["name"]) + json_schema = models.ResponseJsonSchema( + name=schema_name, + description=(response_model.__doc__ or "").strip() or schema_name, + schema=schema_description["schema"], + is_strict=schema_description["strict"], + ) + + if self.oci_provider == "cohere": + return models.CohereResponseJsonFormat(schema=json_schema.schema) + + return models.JsonSchemaResponseFormat(json_schema=json_schema) + + def _tool_result_guidance_enabled(self) -> bool: + return bool(self.additional_params.get("tool_result_guidance")) + + def _parallel_tool_calls_enabled(self) -> bool: + return bool(self.additional_params.get("parallel_tool_calls")) + + def _build_tool_choice(self) -> Any | None: + tool_choice = self.additional_params.get("tool_choice") + if tool_choice is None: + return None + + models = self._oci.generative_ai_inference.models + if isinstance(tool_choice, str): + if tool_choice == "auto": + return models.ToolChoiceAuto() + if tool_choice == "none": + return models.ToolChoiceNone() + if tool_choice in ("any", "required"): + return models.ToolChoiceRequired() + return models.ToolChoiceFunction(name=tool_choice) + + if isinstance(tool_choice, 
bool): + return models.ToolChoiceRequired() if tool_choice else models.ToolChoiceNone() + + if isinstance(tool_choice, Mapping): + function_info = tool_choice.get("function") + if isinstance(function_info, Mapping): + function_name = function_info.get("name") + if function_name: + return models.ToolChoiceFunction(name=str(function_name)) + return models.ToolChoiceAuto() + + raise ValueError( + "Unrecognized OCI tool_choice. Expected str, bool, or function mapping." + ) + + def _build_chat_request( + self, + messages: list[LLMMessage], + tools: list[dict[str, Any]] | None = None, + response_model: type[BaseModel] | None = None, + *, + is_stream: bool = False, + ) -> Any: + models = self._oci.generative_ai_inference.models + + if self.oci_provider == "cohere": + if any(self._message_has_multimodal_content(msg.get("content")) for msg in messages): + raise ValueError( + "OCI Cohere models currently support text-only messages in CrewAI." + ) + + chat_history, tool_results, message_text = self._build_cohere_chat_history( + messages + ) + request_kwargs: dict[str, Any] = { + "message": message_text, + "chat_history": chat_history, + "api_format": models.BaseChatRequest.API_FORMAT_COHERE, + } + if tool_results: + request_kwargs["tool_results"] = tool_results + else: + request_kwargs = { + "messages": self._build_generic_messages(messages), + "api_format": models.BaseChatRequest.API_FORMAT_GENERIC, + } + + if self.temperature is not None and not self._is_openai_gpt5_family(): + request_kwargs["temperature"] = self.temperature + if self.max_tokens is not None: + if self.oci_provider == "generic" and self.model.startswith("openai."): + request_kwargs["max_completion_tokens"] = self.max_tokens + else: + request_kwargs["max_tokens"] = self.max_tokens + if self.top_p is not None: + request_kwargs["top_p"] = self.top_p + if self.top_k is not None: + request_kwargs["top_k"] = self.top_k + + if self.stop and not self._is_openai_gpt5_family(): + stop_key = "stop_sequences" if 
self.oci_provider == "cohere" else "stop" + request_kwargs[stop_key] = list(self.stop) + + formatted_tools = self._format_tools(tools) + if formatted_tools: + request_kwargs["tools"] = formatted_tools + if self.oci_provider == "cohere": + if self._parallel_tool_calls_enabled(): + raise ValueError( + "OCI Cohere models do not support parallel_tool_calls." + ) + request_kwargs.setdefault("is_force_single_step", False) + else: + tool_choice = self._build_tool_choice() + if tool_choice is not None: + request_kwargs["tool_choice"] = tool_choice + if self._parallel_tool_calls_enabled(): + request_kwargs["is_parallel_tool_calls"] = True + + response_format = self._build_response_format(response_model) + if response_format is not None: + request_kwargs["response_format"] = response_format + + if is_stream: + request_kwargs["is_stream"] = True + request_kwargs["stream_options"] = models.StreamOptions( + is_include_usage=True + ) + + passthrough_params = dict(self.additional_params) + passthrough_params.pop("tool_choice", None) + passthrough_params.pop("parallel_tool_calls", None) + passthrough_params.pop("tool_result_guidance", None) + request_kwargs.update(passthrough_params) + + if self.oci_provider == "cohere": + return models.CohereChatRequest(**request_kwargs) + return models.GenericChatRequest(**request_kwargs) + + def _extract_text(self, response: Any) -> str: + chat_response = response.data.chat_response + if self.oci_provider == "cohere": + if getattr(chat_response, "text", None): + return chat_response.text or "" + message = getattr(chat_response, "message", None) + if message is not None: + content = getattr(message, "content", None) or [] + return "".join( + part.text for part in content if getattr(part, "text", None) + ) + return "" + + choices = getattr(chat_response, "choices", None) or [] + if not choices: + return "" + message = getattr(choices[0], "message", None) + if message is None: + return "" + content = getattr(message, "content", None) or [] + 
return "".join(part.text for part in content if getattr(part, "text", None)) + + def _extract_tool_calls(self, response: Any) -> list[dict[str, Any]]: + chat_response = response.data.chat_response + raw_tool_calls: list[Any] = [] + if self.oci_provider == "cohere": + raw_tool_calls = getattr(chat_response, "tool_calls", None) or [] + else: + choices = getattr(chat_response, "choices", None) or [] + if choices: + message = getattr(choices[0], "message", None) + raw_tool_calls = getattr(message, "tool_calls", None) or [] + + if self.oci_provider == "cohere": + formatted: list[dict[str, Any]] = [] + for tool_call in raw_tool_calls: + parameters = getattr(tool_call, "parameters", {}) + formatted.append( + { + "id": uuid.uuid4().hex, + "type": "function", + "function": { + "name": getattr(tool_call, "name", ""), + "arguments": json.dumps(parameters or {}), + }, + } + ) + return formatted + + return [ + { + "id": getattr(tool_call, "id", None), + "type": "function", + "function": { + "name": getattr(tool_call, "name", ""), + "arguments": getattr(tool_call, "arguments", "{}"), + }, + } + for tool_call in raw_tool_calls + ] + + def _extract_usage(self, response: Any) -> dict[str, int]: + chat_response = response.data.chat_response + usage = getattr(chat_response, "usage", None) + if usage is None: + return {} + return { + "prompt_tokens": getattr(usage, "prompt_tokens", 0), + "completion_tokens": getattr(usage, "completion_tokens", 0), + "total_tokens": getattr(usage, "total_tokens", 0), + } + + def _extract_response_metadata(self, response: Any) -> dict[str, Any]: + chat_response = response.data.chat_response + metadata: dict[str, Any] = {} + + finish_reason = getattr(chat_response, "finish_reason", None) + if finish_reason is None: + choices = getattr(chat_response, "choices", None) or [] + if choices: + finish_reason = getattr(choices[0], "finish_reason", None) + message = getattr(choices[0], "message", None) + if message is not None: + reasoning = getattr(message, 
"reasoning_content", None) + if reasoning: + metadata["reasoning_content"] = reasoning + + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + + for field_name in ("documents", "citations", "search_queries", "is_search_required"): + value = getattr(chat_response, field_name, None) + if value: + metadata[field_name] = value + + message = getattr(chat_response, "message", None) + if message is not None: + citations = getattr(message, "citations", None) + if citations: + metadata["citations"] = citations + + usage = self._extract_usage(response) + if usage: + metadata["usage"] = usage + + return metadata + + def _parse_stream_event(self, event: Any) -> dict[str, Any]: + event_data = getattr(event, "data", None) + if not event_data: + return {} + if isinstance(event_data, str): + try: + parsed = json.loads(event_data) + if isinstance(parsed, Mapping): + return dict(parsed) + return {} + except json.JSONDecodeError: + logging.debug("Skipping invalid OCI SSE payload: %s", event_data) + return {} + if isinstance(event_data, Mapping): + return dict(event_data) + return {} + + def _extract_text_from_stream_event(self, event_data: dict[str, Any]) -> str: + if self.oci_provider == "cohere": + if "text" in event_data: + return str(event_data.get("text", "")) + message = event_data.get("message", {}) + if isinstance(message, Mapping): + content = message.get("content", []) + if isinstance(content, list): + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, Mapping) + ) + return "" + + message = event_data.get("message", {}) + if not isinstance(message, Mapping): + return "" + content = message.get("content", []) + if not isinstance(content, list): + return "" + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, Mapping) and part.get("text") + ) + + def _extract_tool_calls_from_stream_event( + self, event_data: dict[str, Any] + ) -> list[dict[str, Any]]: + message = 
event_data.get("message", {}) + if self.oci_provider == "cohere": + raw_tool_calls = event_data.get("toolCalls", []) + else: + raw_tool_calls = ( + message.get("toolCalls", []) if isinstance(message, Mapping) else [] + ) + + if not isinstance(raw_tool_calls, list): + return [] + + if self.oci_provider == "cohere": + return [ + { + "id": None, + "type": "function", + "function": { + "name": str(tool_call.get("name", "")), + "arguments": json.dumps(tool_call.get("parameters", {})), + }, + } + for tool_call in raw_tool_calls + if isinstance(tool_call, Mapping) + ] + + return [ + { + "id": tool_call.get("id"), + "type": "function", + "function": { + "name": tool_call.get("name"), + "arguments": tool_call.get("arguments"), + }, + } + for tool_call in raw_tool_calls + if isinstance(tool_call, Mapping) + ] + + def _extract_usage_from_stream_event(self, event_data: dict[str, Any]) -> dict[str, int]: + usage = event_data.get("usage") + if not isinstance(usage, Mapping): + return {} + return { + "prompt_tokens": int(usage.get("promptTokens", 0) or 0), + "completion_tokens": int(usage.get("completionTokens", 0) or 0), + "total_tokens": int(usage.get("totalTokens", 0) or 0), + } + + def _extract_metadata_from_stream_event(self, event_data: dict[str, Any]) -> dict[str, Any]: + metadata: dict[str, Any] = {} + finish_reason = event_data.get("finishReason") + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + + for field_name in ("documents", "citations", "searchQueries", "isSearchRequired"): + value = event_data.get(field_name) + if value is not None: + normalized_name = { + "searchQueries": "search_queries", + "isSearchRequired": "is_search_required", + }.get(field_name, field_name) + metadata[normalized_name] = value + + usage = self._extract_usage_from_stream_event(event_data) + if usage: + metadata["usage"] = usage + return metadata + + def _parse_structured_response( + self, + *, + content: str, + response_model: type[BaseModel], + messages: 
list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> BaseModel: + try: + structured_response = self._validate_structured_output( + content, response_model + ) + except Exception as error: + error_message = ( + f"Failed to validate OCI structured response with model " + f"{response_model.__name__}: {error}" + ) + raise ValueError(error_message) from error + + if not isinstance(structured_response, BaseModel): + raise ValueError( + f"OCI structured response parsing returned unexpected type: " + f"{type(structured_response)}" + ) + + self._emit_call_completed_event( + response=structured_response.model_dump_json(), + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=messages, + ) + return structured_response + + def _handle_tool_calls( + self, + *, + normalized_messages: list[LLMMessage], + tools: list[dict[str, BaseTool]] | None, + callbacks: list[Any] | None, + available_functions: dict[str, Any] | None, + from_task: Task | None, + from_agent: Agent | None, + tool_depth: int, + response_model: type[BaseModel] | None, + tool_calls: list[dict[str, Any]], + ) -> str | BaseModel | list[dict[str, Any]]: + if tool_calls and not available_functions: + self._emit_call_completed_event( + response=tool_calls, + call_type=LLMCallType.TOOL_CALL, + from_task=from_task, + from_agent=from_agent, + messages=normalized_messages, + ) + return tool_calls + + if tool_depth >= self.max_sequential_tool_calls: + raise RuntimeError( + "OCI native provider exceeded max_sequential_tool_calls while executing tools." 
+ ) + + next_messages = list(normalized_messages) + next_messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": tool_calls, + } + ) + + for tool_call in tool_calls: + function_info = tool_call.get("function", {}) + function_name = function_info.get("name", "") + raw_arguments = function_info.get("arguments", "{}") + if isinstance(raw_arguments, str): + try: + function_args = json.loads(raw_arguments) + except json.JSONDecodeError: + function_args = {} + elif isinstance(raw_arguments, Mapping): + function_args = dict(raw_arguments) + else: + function_args = {} + + tool_result = self._handle_tool_execution( + function_name=function_name, + function_args=function_args, + available_functions=available_functions or {}, + from_task=from_task, + from_agent=from_agent, + ) + if tool_result is None: + continue + + next_messages.append( + { + "role": "tool", + "tool_call_id": str(tool_call.get("id") or uuid.uuid4().hex), + "name": function_name, + "content": str(tool_result), + } + ) + + if self.stream: + return self._stream_call_impl( + messages=next_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=tool_depth + 1, + response_model=response_model, + ) + + return self._call_impl( + messages=next_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=tool_depth + 1, + response_model=response_model, + ) + + def _finalize_text_response( + self, + *, + content: str, + messages: list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + content = self._apply_stop_words(content) + content = self._invoke_after_llm_call_hooks(messages, content, from_agent) + self._emit_call_completed_event( + response=content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=messages, + ) + return content + + 
def _stream_call_impl(
    self,
    *,
    messages: str | list[LLMMessage],
    tools: list[dict[str, BaseTool]] | None,
    callbacks: list[Any] | None,
    available_functions: dict[str, Any] | None,
    from_task: Task | None,
    from_agent: Agent | None,
    tool_depth: int,
    response_model: type[BaseModel] | None,
) -> str | BaseModel | list[dict[str, Any]]:
    """Run one streaming chat round and post-process the assembled result.

    Text deltas are concatenated and re-emitted as stream chunk events; tool
    call deltas are accumulated per position within each event. Once the
    stream ends, token usage and metadata are recorded, and the result is
    routed to tool handling, structured parsing, or plain-text finalization,
    in that priority order.
    """
    conversation = self._normalize_messages(messages)
    request = self._build_chat_request(
        conversation,
        tools=tools,
        response_model=response_model,
        is_stream=True,
    )
    sdk_models = self._oci.generative_ai_inference.models
    details = sdk_models.ChatDetails(
        compartment_id=self.compartment_id,
        serving_mode=self._build_serving_mode(),
        chat_request=request,
    )
    stream = self._chat(details)

    text_parts: list[str] = []
    pending_calls: dict[int, dict[str, Any]] = {}
    latest_usage: dict[str, int] = {}
    metadata: dict[str, Any] = {}
    stream_id = uuid.uuid4().hex

    for event in stream.data.events():
        payload = self._parse_stream_event(event)
        if not payload:
            continue

        delta_text = self._extract_text_from_stream_event(payload)
        if delta_text:
            text_parts.append(delta_text)
            self._emit_stream_chunk_event(
                chunk=delta_text,
                from_task=from_task,
                from_agent=from_agent,
                call_type=LLMCallType.LLM_CALL,
                response_id=stream_id,
            )

        # NOTE(review): deltas are merged by their position inside a single
        # event; confirm this matches how the service streams several
        # parallel tool calls.
        delta_calls = self._extract_tool_calls_from_stream_event(payload)
        for position, delta_call in enumerate(delta_calls):
            if position not in pending_calls:
                pending_calls[position] = {
                    "id": None,
                    "type": "function",
                    "function": {"name": None, "arguments": ""},
                }
            state = pending_calls[position]
            accumulated = state["function"]

            delta_id = delta_call.get("id")
            if delta_id:
                state["id"] = delta_id
            delta_function = delta_call.get("function", {})
            delta_name = delta_function.get("name")
            if delta_name:
                accumulated["name"] = delta_name
            delta_arguments = delta_function.get("arguments")
            if delta_arguments:
                accumulated["arguments"] += str(delta_arguments)

            # Emitted for every delta, even one without an arguments fragment.
            emitted_arguments = str(delta_arguments or "")
            self._emit_stream_chunk_event(
                chunk=emitted_arguments,
                tool_call={
                    "id": state["id"],
                    "type": "function",
                    "function": {
                        "name": accumulated["name"],
                        "arguments": emitted_arguments,
                    },
                },
                from_task=from_task,
                from_agent=from_agent,
                call_type=LLMCallType.TOOL_CALL,
                response_id=stream_id,
            )

        usage_delta = self._extract_usage_from_stream_event(payload)
        if usage_delta:
            # The most recent usage report replaces any earlier one.
            latest_usage = usage_delta
        metadata.update(self._extract_metadata_from_stream_event(payload))

    if latest_usage:
        self._track_token_usage_internal(latest_usage)
        metadata["usage"] = latest_usage
    self.last_response_metadata = metadata or None

    # Keep only calls that received a function name; missing ids are minted.
    completed_calls: list[dict[str, Any]] = []
    for _, state in sorted(pending_calls.items()):
        call_function = state["function"]
        if not call_function.get("name"):
            continue
        completed_calls.append(
            {
                "id": state.get("id") or uuid.uuid4().hex,
                "type": "function",
                "function": {
                    "name": call_function.get("name", "") or "",
                    "arguments": call_function.get("arguments", "") or "",
                },
            }
        )

    if completed_calls:
        return self._handle_tool_calls(
            normalized_messages=conversation,
            tools=tools,
            callbacks=callbacks,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            tool_depth=tool_depth,
            response_model=response_model,
            tool_calls=completed_calls,
        )

    full_text = "".join(text_parts)
    if response_model is not None:
        return self._parse_structured_response(
            content=full_text,
            response_model=response_model,
            messages=conversation,
            from_task=from_task,
            from_agent=from_agent,
        )

    return self._finalize_text_response(
        content=full_text,
        messages=conversation,
        from_task=from_task,
        from_agent=from_agent,
    )
async def astream(
    self,
    messages: str | list[LLMMessage],
    tools: list[dict[str, BaseTool]] | None = None,
    callbacks: list[Any] | None = None,
    available_functions: dict[str, Any] | None = None,
    from_task: Task | None = None,
    from_agent: Agent | None = None,
    response_model: type[BaseModel] | None = None,
) -> Any:
    """Asynchronously yield the text chunks produced by ``iter_stream``.

    ``iter_stream`` is consumed on a daemon thread. Each chunk is handed to
    the running event loop through a queue, and ``None`` marks the end of
    the stream. An exception raised by the producer is re-raised here after
    all chunks queued before it have been yielded.
    """
    event_loop = asyncio.get_running_loop()
    chunk_queue: asyncio.Queue[str | None] = asyncio.Queue()
    failures: list[BaseException] = []
    stream_kwargs = {
        "messages": messages,
        "tools": tools,
        "callbacks": callbacks,
        "available_functions": available_functions,
        "from_task": from_task,
        "from_agent": from_agent,
        "response_model": response_model,
    }

    def _pump() -> None:
        publish = event_loop.call_soon_threadsafe
        try:
            for piece in self.iter_stream(**stream_kwargs):
                publish(chunk_queue.put_nowait, piece)
        except BaseException as error:
            failures.append(error)
        finally:
            # Always wake the consumer, whether the stream ended or failed.
            publish(chunk_queue.put_nowait, None)

    worker = threading.Thread(target=_pump, daemon=True)
    worker.start()

    while (piece := await chunk_queue.get()) is not None:
        yield piece

    worker.join()
    if failures:
        raise failures[0]
call( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | BaseModel | list[dict[str, Any]]: + normalized_messages = self._normalize_messages(messages) + + with llm_call_context(): + try: + self._emit_call_started_event( + messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + if not self._invoke_before_llm_call_hooks( + normalized_messages, from_agent + ): + raise ValueError("LLM call blocked by before_llm_call hook") + + if self.stream: + return self._stream_call_impl( + messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=0, + response_model=response_model, + ) + + return self._call_impl( + messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=0, + response_model=response_model, + ) + except Exception as error: + error_message = f"OCI Generative AI call failed: {error!s}" + self._emit_call_failed_event( + error=error_message, + from_task=from_task, + from_agent=from_agent, + ) + raise + + async def acall( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + return await asyncio.to_thread( + self.call, + messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + 
def get_context_window_size(self) -> int:
    """Return the context window size, in tokens, for ``self.model``.

    The size is chosen by the first matching case-insensitive model-name
    prefix; models matching no prefix get the fallback size.
    """
    fallback_window = 131072
    windows_by_prefix = (
        ("google.gemini", 1048576),
        ("openai.", 200000),
        ("cohere.", 128000),
        ("meta.", 131072),
    )
    normalized = self.model.lower()
    for prefix, window in windows_by_prefix:
        if normalized.startswith(prefix):
            return window
    return fallback_window
def to_data_uri(image: str | bytes | Path, mime_type: str = "image/png") -> str:
    """Return *image* as a base64 ``data:`` URI.

    Raw bytes are encoded with ``mime_type``. A value whose string form
    already starts with ``data:`` is returned unchanged. Anything else is
    treated as a file path: the file is read and its MIME type is guessed
    from the name, falling back to ``mime_type``.
    """
    if isinstance(image, bytes):
        payload = image
        resolved_mime = mime_type
    else:
        text = str(image)
        if text.startswith("data:"):
            return text
        source = Path(text)
        guessed, _ = mimetypes.guess_type(str(source))
        resolved_mime = guessed or mime_type
        payload = source.read_bytes()

    b64 = base64.standard_b64encode(payload).decode("utf-8")
    return f"data:{resolved_mime};base64,{b64}"
crewai.rag.embeddings.providers.oci.embedding_callable import ( + OCIEmbeddingFunction, + ) + from crewai.rag.embeddings.providers.oci.types import OCIProviderSpec from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec @@ -100,6 +104,7 @@ "jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider", "ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider", "onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider", + "oci": "crewai.rag.embeddings.providers.oci.oci_provider.OCIProvider", "openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider", "openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider", "roboflow": "crewai.rag.embeddings.providers.roboflow.roboflow_provider.RoboflowProvider", @@ -216,6 +221,10 @@ def build_embedder_from_dict( def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ... +@overload +def build_embedder_from_dict(spec: OCIProviderSpec) -> OCIEmbeddingFunction: ... + + @overload def build_embedder_from_dict(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ... 
diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py new file mode 100644 index 00000000000..a3c44e2b030 --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py @@ -0,0 +1,18 @@ +"""OCI embedding provider exports.""" + +from crewai.rag.embeddings.providers.oci.embedding_callable import ( + OCIEmbeddingFunction, +) +from crewai.rag.embeddings.providers.oci.oci_provider import OCIProvider +from crewai.rag.embeddings.providers.oci.types import ( + OCIProviderConfig, + OCIProviderSpec, +) + + +__all__ = [ + "OCIEmbeddingFunction", + "OCIProvider", + "OCIProviderConfig", + "OCIProviderSpec", +] diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py new file mode 100644 index 00000000000..e3b4c052a2c --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py @@ -0,0 +1,255 @@ +"""OCI embedding function implementation.""" + +from __future__ import annotations + +import base64 +from collections.abc import Iterator, Sequence +import mimetypes +import os +from pathlib import Path +from typing import Any, cast + +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from typing_extensions import Unpack + +from crewai.rag.embeddings.providers.oci.types import OCIProviderConfig + + +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" +DEFAULT_OCI_REGION = "eu-frankfurt-1" + + +def _get_oci_module() -> Any: + try: + import oci # type: ignore[import-untyped] + except ImportError as e: + raise ImportError( + "oci is required for OCI embeddings. 
def create_oci_client_kwargs(
    *,
    auth_type: str,
    service_endpoint: str | None,
    auth_file_location: str,
    auth_profile: str,
    timeout: tuple[int, int],
) -> dict[str, Any]:
    """Create authenticated OCI client kwargs.

    Args:
        auth_type: One of ``API_KEY``, ``SECURITY_TOKEN``,
            ``INSTANCE_PRINCIPAL`` or ``RESOURCE_PRINCIPAL``
            (case-insensitive).
        service_endpoint: Endpoint URL passed through to the client.
        auth_file_location: OCI config file path, used by the file-based
            auth types.
        auth_profile: Profile name inside the OCI config file.
        timeout: Timeout value passed through to the client.

    Returns:
        Keyword arguments for constructing an OCI client: ``config``,
        ``service_endpoint``, ``retry_strategy``, ``timeout`` and, for
        signer-based auth types, ``signer``.

    Raises:
        ValueError: If ``auth_type`` is not one of the supported values.
    """
    oci = _get_oci_module()
    client_kwargs: dict[str, Any] = {
        "config": {},
        "service_endpoint": service_endpoint,
        "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
        "timeout": timeout,
    }

    auth_type_upper = auth_type.upper()
    if auth_type_upper == "API_KEY":
        client_kwargs["config"] = oci.config.from_file(
            file_location=auth_file_location,
            profile_name=auth_profile,
        )
    elif auth_type_upper == "SECURITY_TOKEN":
        config = oci.config.from_file(
            file_location=auth_file_location,
            profile_name=auth_profile,
        )
        private_key = oci.signer.load_private_key_from_file(config["key_file"], None)
        # open() does not expand "~", so resolve a home-relative token path
        # taken from the config file before reading it.
        token_path = os.path.expanduser(config["security_token_file"])
        with open(token_path, encoding="utf-8") as file:
            security_token = file.read()
        client_kwargs["config"] = config
        client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner(
            security_token, private_key
        )
    elif auth_type_upper == "INSTANCE_PRINCIPAL":
        client_kwargs["signer"] = (
            oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
        )
    elif auth_type_upper == "RESOURCE_PRINCIPAL":
        client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer()
    else:
        valid_types = [
            "API_KEY",
            "SECURITY_TOKEN",
            "INSTANCE_PRINCIPAL",
            "RESOURCE_PRINCIPAL",
        ]
        raise ValueError(
            f"Invalid OCI auth_type '{auth_type}'. Valid values: {valid_types}"
        )

    return client_kwargs
model_name = self._config.get("model_name") + if not model_name: + raise ValueError("OCI embeddings require model_name") + + if model_name.startswith(CUSTOM_ENDPOINT_PREFIX): + return oci.generative_ai_inference.models.DedicatedServingMode( + endpoint_id=model_name + ) + + return oci.generative_ai_inference.models.OnDemandServingMode( + model_id=model_name + ) + + def _build_request( + self, inputs: list[str], *, input_type: str | None = None + ) -> Any: + oci = _get_oci_module() + compartment_id = self._config.get("compartment_id") or os.getenv( + "OCI_COMPARTMENT_ID" + ) + if not compartment_id: + raise ValueError( + "OCI embeddings require compartment_id. Set it explicitly or use OCI_COMPARTMENT_ID." + ) + + request_kwargs: dict[str, Any] = { + "serving_mode": self._get_serving_mode(), + "compartment_id": compartment_id, + "truncate": self._config.get("truncate", "END"), + "inputs": inputs, + } + + resolved_input_type = input_type or self._config.get("input_type") + if resolved_input_type: + request_kwargs["input_type"] = resolved_input_type + + output_dimensions = self._config.get("output_dimensions") + if output_dimensions is not None: + embed_text_details = oci.generative_ai_inference.models.EmbedTextDetails + if hasattr(embed_text_details, "output_dimensions"): + request_kwargs["output_dimensions"] = output_dimensions + else: + raise ValueError( + "output_dimensions requires a newer OCI SDK. Upgrade the oci package." 
+ ) + + return oci.generative_ai_inference.models.EmbedTextDetails(**request_kwargs) + + def _batch_inputs(self, input: list[str]) -> Iterator[list[str]]: + batch_size = self._config.get("batch_size", 96) + for index in range(0, len(input), batch_size): + yield input[index : index + batch_size] + + @staticmethod + def _to_data_uri(image: str | bytes | Path, mime_type: str = "image/png") -> str: + if isinstance(image, Path): + resolved_mime = mimetypes.guess_type(image.name)[0] or mime_type + data = image.read_bytes() + return ( + f"data:{resolved_mime};base64," + f"{base64.b64encode(data).decode('ascii')}" + ) + + if isinstance(image, bytes): + return f"data:{mime_type};base64,{base64.b64encode(image).decode('ascii')}" + + if image.startswith("data:"): + return image + + path = Path(image) + if path.exists(): + return OCIEmbeddingFunction._to_data_uri(path, mime_type=mime_type) + + raise ValueError( + "OCI image embeddings require a file path, raw bytes, or a data URI." + ) + + def __call__(self, input: Documents) -> Embeddings: + if isinstance(input, str): + input = [input] + + embeddings: Embeddings = [] + for chunk in self._batch_inputs(input): + response = self._require_client().embed_text(self._build_request(chunk)) + embeddings.extend(cast(Embeddings, response.data.embeddings)) + return embeddings + + def embed_image( + self, + image: str | bytes | Path, + *, + mime_type: str = "image/png", + ) -> list[float]: + return [float(value) for value in self.embed_image_batch([image], mime_type=mime_type)[0]] + + def embed_image_batch( + self, + images: Sequence[str | bytes | Path], + *, + mime_type: str = "image/png", + ) -> Embeddings: + embeddings: Embeddings = [] + for image in images: + data_uri = self._to_data_uri(image, mime_type=mime_type) + response = self._require_client().embed_text( + self._build_request([data_uri], input_type="IMAGE") + ) + embeddings.extend(cast(Embeddings, response.data.embeddings)) + return embeddings diff --git 
class OCIProvider(BaseEmbeddingsProvider[OCIEmbeddingFunction]):
    """OCI Generative AI embeddings provider.

    Each field lists the aliases it can be populated from; the defaults
    mirror the fallbacks used by ``OCIEmbeddingFunction``.
    """

    embedding_callable: type[OCIEmbeddingFunction] = Field(
        default=OCIEmbeddingFunction,
        description="OCI embedding function class",
    )
    model_name: str = Field(
        default="cohere.embed-english-v3.0",
        description="Model name to use for embeddings",
        validation_alias=AliasChoices(
            "EMBEDDINGS_OCI_MODEL_NAME",
            "OCI_EMBED_MODEL",
            "model",
            "model_name",
        ),
    )
    compartment_id: str = Field(
        description="OCI compartment ID",
        validation_alias=AliasChoices(
            "EMBEDDINGS_OCI_COMPARTMENT_ID",
            "OCI_COMPARTMENT_ID",
            "compartment_id",
        ),
    )
    service_endpoint: str | None = Field(
        default=None,
        description="OCI Generative AI inference endpoint",
        validation_alias=AliasChoices(
            "EMBEDDINGS_OCI_SERVICE_ENDPOINT",
            "OCI_SERVICE_ENDPOINT",
            "service_endpoint",
        ),
    )
    region: str | None = Field(
        default=None,
        description="OCI region used to derive the inference endpoint when service_endpoint is not provided",
        validation_alias=AliasChoices(
            "EMBEDDINGS_OCI_REGION",
            "OCI_REGION",
            "region",
        ),
    )
    auth_type: str = Field(
        default="API_KEY",
        description="OCI SDK auth type",
        validation_alias=AliasChoices("EMBEDDINGS_OCI_AUTH_TYPE", "OCI_AUTH_TYPE"),
    )
    auth_profile: str = Field(
        default="DEFAULT",
        description="OCI config profile name",
        validation_alias=AliasChoices(
            "EMBEDDINGS_OCI_AUTH_PROFILE",
            "OCI_AUTH_PROFILE",
        ),
    )
    auth_file_location: str = Field(
        default="~/.oci/config",
        description="OCI config file location",
        validation_alias=AliasChoices(
            "EMBEDDINGS_OCI_AUTH_FILE_LOCATION",
            "OCI_AUTH_FILE_LOCATION",
        ),
    )
    truncate: str = Field(default="END", description="OCI embedding truncate policy")
    input_type: str | None = Field(
        default=None,
        description="Optional OCI embedding input type such as SEARCH_DOCUMENT or SEARCH_QUERY",
    )
    output_dimensions: int | None = Field(
        default=None,
        description="Optional output dimensions for compatible OCI embedding models",
    )
    # The batch size is used as a slicing step when inputs are chunked, so a
    # value below 1 is rejected here rather than failing at embed time.
    batch_size: int = Field(default=96, gt=0, description="OCI embedding batch size")
    timeout: tuple[int, int] = Field(
        default=(10, 120), description="OCI SDK connect/read timeout"
    )
    client: Any | None = Field(default=None, description="Injected OCI client")


# Definitions below belong to crewai/rag/embeddings/providers/oci/types.py.


class OCIProviderConfig(TypedDict, total=False):
    """Configuration for OCI embedding provider.

    All keys are optional (``total=False``).
    """

    model_name: Annotated[str, "cohere.embed-english-v3.0"]
    compartment_id: str
    service_endpoint: str
    region: str
    auth_type: str
    auth_profile: str
    auth_file_location: str
    truncate: str
    input_type: str
    output_dimensions: int
    batch_size: int
    timeout: tuple[int, int]
    client: Any


class OCIProviderSpec(TypedDict, total=False):
    """OCI provider specification: required ``provider`` tag plus optional config."""

    provider: Required[Literal["oci"]]
    config: OCIProviderConfig
def _simple_init_class(name: str):
    """Return a class named *name* whose instances keep constructor kwargs as attributes."""

    class _Simple:
        def __init__(self, **kwargs):
            vars(self).update(kwargs)

    _Simple.__name__ = name
    return _Simple


class FakeOCI:
    """Stand-in for the pieces of the ``oci`` SDK that the unit tests touch."""

    def __init__(self) -> None:
        def from_file(file_location, profile_name):
            return {"file_location": file_location, "profile_name": profile_name}

        def load_private_key_from_file(*_args, **_kwargs):
            return "private-key"

        # Every SDK model is replaced by a kwargs-recording class of the same name.
        model_names = (
            "GenericChatRequest",
            "ChatDetails",
            "OnDemandServingMode",
            "DedicatedServingMode",
            "UserMessage",
            "AssistantMessage",
            "SystemMessage",
            "ToolMessage",
            "TextContent",
            "ImageContent",
            "ImageUrl",
            "DocumentContent",
            "DocumentUrl",
            "VideoContent",
            "VideoUrl",
            "AudioContent",
            "AudioUrl",
            "FunctionCall",
            "FunctionDefinition",
            "ToolChoiceAuto",
            "ToolChoiceFunction",
            "ToolChoiceNone",
            "ToolChoiceRequired",
            "StreamOptions",
            "CohereChatRequest",
            "CohereUserMessage",
            "CohereChatBotMessage",
            "CohereSystemMessage",
            "CohereToolMessage",
            "CohereTool",
            "CohereParameterDefinition",
            "CohereToolCall",
            "CohereToolResult",
            "CohereResponseJsonFormat",
            "ResponseJsonSchema",
            "JsonSchemaResponseFormat",
        )
        models = SimpleNamespace(
            BaseChatRequest=SimpleNamespace(
                API_FORMAT_GENERIC="GENERIC",
                API_FORMAT_COHERE="COHERE",
            ),
            **{model_name: _simple_init_class(model_name) for model_name in model_names},
        )
        signers = SimpleNamespace(
            SecurityTokenSigner=lambda token, key: (token, key),
            InstancePrincipalsSecurityTokenSigner=lambda: "instance-principal",
            get_resource_principals_signer=lambda: "resource-principal",
        )

        self.retry = SimpleNamespace(DEFAULT_RETRY_STRATEGY="retry")
        self.config = SimpleNamespace(from_file=from_file)
        self.signer = SimpleNamespace(
            load_private_key_from_file=load_private_key_from_file
        )
        self.auth = SimpleNamespace(signers=signers)
        self.generative_ai_inference = SimpleNamespace(
            GenerativeAiInferenceClient=MagicMock(),
            models=models,
        )


def fake_chat_response(text: str):
    """Build a generic chat response carrying one text part."""
    message = SimpleNamespace(content=[SimpleNamespace(text=text)])
    chat = SimpleNamespace(choices=[SimpleNamespace(message=message)])
    return SimpleNamespace(data=SimpleNamespace(chat_response=chat))


def fake_tool_call_response(name: str, arguments: str):
    """Build a generic chat response carrying a single tool call and no text."""
    call = SimpleNamespace(id="call_123", name=name, arguments=arguments)
    message = SimpleNamespace(content=[], tool_calls=[call])
    chat = SimpleNamespace(choices=[SimpleNamespace(message=message)])
    return SimpleNamespace(data=SimpleNamespace(chat_response=chat))


def fake_cohere_tool_call_response(name: str, parameters: dict[str, str]):
    """Build a Cohere-shaped response with empty text and one tool call."""
    chat = SimpleNamespace(
        text="",
        tool_calls=[SimpleNamespace(name=name, parameters=parameters)],
    )
    return SimpleNamespace(data=SimpleNamespace(chat_response=chat))


def fake_stream_response(*payloads: dict[str, object]):
    """Build a streaming response whose ``events()`` yields JSON-encoded payloads."""

    def events():
        return [SimpleNamespace(data=json.dumps(payload)) for payload in payloads]

    return SimpleNamespace(data=SimpleNamespace(events=events))


def _env(name: str, default: str) -> str:
    """Read environment variable *name*, falling back to *default*."""
    return os.getenv(name, default)


def _unit_model_defaults() -> dict[str, str]:
    """Return the model names used by unit tests, overridable via ``OCI_UNIT_*``."""
    table = (
        ("generic_model", "OCI_UNIT_GENERIC_MODEL", "meta.test-generic-model"),
        ("generic_tool_model", "OCI_UNIT_OPENAI_MODEL", "openai.test-chat-model"),
        (
            "generic_structured_model",
            "OCI_UNIT_STRUCTURED_MODEL",
            "openai.test-structured-model",
        ),
        ("gpt4_control_model", "OCI_UNIT_GPT4_MODEL", "openai.test-gpt4-model"),
        ("gpt5_model", "OCI_UNIT_GPT5_MODEL", "openai.gpt-5"),
        ("cohere_model", "OCI_UNIT_COHERE_MODEL", "cohere.test-chat-model"),
        ("cohere_chat_model", "OCI_UNIT_COHERE_CHAT_MODEL", "cohere.test-chat-model"),
        ("gemini_model", "OCI_UNIT_GEMINI_MODEL", "google.test-chat-model"),
        ("llama_model", "OCI_UNIT_LLAMA_MODEL", "meta.test-llama-model"),
        ("grok_model", "OCI_UNIT_GROK_MODEL", "xai.test-chat-model"),
    )
    return {key: _env(variable, fallback) for key, variable, fallback in table}


def _provider_family_cases() -> list[tuple[str, str]]:
    """Pair each representative model with the request family it should map to."""
    models = _unit_model_defaults()
    generic_keys = ("gpt5_model", "gemini_model", "llama_model", "grok_model")
    cases = [(models[key], "generic") for key in generic_keys]
    cases.append((models["cohere_chat_model"], "cohere"))
    return cases


def _unit_test_values() -> dict[str, object]:
    """Collect every constant the unit tests read: ids, models, and prompts."""
    models = _unit_model_defaults()
    values: dict[str, object] = {
        "compartment_id": _env("OCI_UNIT_COMPARTMENT_ID", "ocid1.compartment.oc1..test"),
        "service_endpoint": _env(
            "OCI_UNIT_SERVICE_ENDPOINT",
            "https://inference.generativeai.test-region-1.oci.oraclecloud.com",
        ),
        "region": _env("OCI_UNIT_REGION", "test-region-1"),
    }
    values.update(models)
    values["prefixed_model"] = f"oci/{models['generic_model']}"
    prompts = (
        ("chat_prompt", "OCI_UNIT_CHAT_PROMPT", "Tell me something about Oracle Cloud."),
        ("hello_prompt", "OCI_UNIT_HELLO_PROMPT", "Say hello"),
        ("json_prompt", "OCI_UNIT_JSON_PROMPT", "Summarize OCI in JSON."),
        ("search_prompt", "OCI_UNIT_SEARCH_PROMPT", "Search Oracle Cloud docs"),
        ("docs_prompt", "OCI_UNIT_DOCS_PROMPT", "Find docs about Oracle Cloud"),
        ("weather_prompt", "OCI_UNIT_WEATHER_PROMPT", "What is the weather in Paris?"),
        ("multimodal_prompt", "OCI_UNIT_MULTIMODAL_PROMPT", "Summarize these files"),
    )
    for key, variable, fallback in prompts:
        values[key] = _env(variable, fallback)
    values["provider_family_cases"] = _provider_family_cases()
    return values


def _prompt_defaults() -> dict[str, str]:
    """Return the prompts used by live tests, overridable via ``OCI_TEST_*_PROMPT``."""
    table = (
        (
            "basic",
            "OCI_TEST_BASIC_PROMPT",
            "Reply with exactly two words about Oracle Cloud.",
        ),
        (
            "stream",
            "OCI_TEST_STREAM_PROMPT",
            "Reply with exactly three words about Oracle Cloud.",
        ),
        (
            "async",
            "OCI_TEST_ASYNC_PROMPT",
            "Reply with exactly two words about Oracle Cloud.",
        ),
        (
            "structured",
            "OCI_TEST_STRUCTURED_PROMPT",
            "Return a short JSON summary about Oracle Cloud.",
        ),
        (
            "tool",
            "OCI_TEST_TOOL_PROMPT",
            "Use the add_numbers tool to calculate 15 + 27. Return only the final result.",
        ),
        (
            "tool_structured",
            "OCI_TEST_TOOL_STRUCTURED_PROMPT",
            "Calculate 15 + 27 using your add_numbers tool. Report the result.",
        ),
    )
    return {key: _env(variable, fallback) for key, variable, fallback in table}
OCI_ALLOWED_HOSTS = [r".*\.oci\.oraclecloud\.com"]

# (fixture name, single-model env var, model-list env var, id shown when unset)
_LIVE_MODEL_PARAMS = (
    ("oci_chat_model", "OCI_TEST_MODEL", "OCI_TEST_MODELS", "unconfigured-chat-model"),
    (
        "oci_tool_model",
        "OCI_TEST_TOOL_MODEL",
        "OCI_TEST_TOOL_MODELS",
        "unconfigured-tool-model",
    ),
    (
        "oci_multimodal_model",
        "OCI_TEST_MULTIMODAL_MODEL",
        "OCI_TEST_MULTIMODAL_MODELS",
        "unconfigured-multimodal-model",
    ),
)


def _resolve_model_matrix(single_var: str, list_var: str) -> list[str | None]:
    """Return the models to parametrize over.

    ``single_var`` wins when set; otherwise ``list_var`` is split on commas
    with blanks dropped. ``[None]`` means "nothing configured", so the
    dependent fixture can skip.
    """
    single_model = os.getenv(single_var)
    if single_model:
        return [single_model]

    models_env = os.getenv(list_var)
    if models_env:
        return [model.strip() for model in models_env.split(",") if model.strip()]

    return [None]


def _has_oci_sdk() -> bool:
    """Report whether the ``oci`` package can be imported."""
    try:
        import oci  # noqa: F401
    except ImportError:
        return False
    return True


def _region_endpoint(region: str) -> str:
    """Return the Generative AI inference endpoint URL for *region*."""
    return f"https://inference.generativeai.{region}.oci.oraclecloud.com"


def _token_budget(model: str, scenario: str) -> int:
    """Return the ``max_tokens`` budget for a model/scenario pair.

    Model-family entries take precedence over the scenario defaults;
    anything unmatched gets 128.
    """
    family_budgets = (
        (
            "openai.gpt-5",
            {"structured": 2048, "agent": 1536, "stream": 1536, "basic": 1024, "async": 1024},
        ),
        (
            "google.gemini",
            {"agent": 384, "basic": 256, "async": 256, "structured": 256},
        ),
    )
    for prefix, budgets in family_budgets:
        if model.startswith(prefix) and scenario in budgets:
            return budgets[scenario]
    return {"stream": 64, "agent": 256}.get(scenario, 128)


def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
    """Parametrize the provider-family and live-model fixtures a test requests."""
    if "oci_provider_family_case" in metafunc.fixturenames:
        provider_family_cases = _provider_family_cases()
        metafunc.parametrize(
            "oci_provider_family_case",
            provider_family_cases,
            ids=[case[0] for case in provider_family_cases],
        )

    for fixture_name, single_var, list_var, placeholder in _LIVE_MODEL_PARAMS:
        if fixture_name not in metafunc.fixturenames:
            continue
        models = _resolve_model_matrix(single_var, list_var)
        metafunc.parametrize(
            fixture_name,
            models,
            indirect=True,
            ids=[model or placeholder for model in models],
        )


@pytest.fixture
def vcr_config() -> dict[str, list[str]]:
    """Return a config dict listing the OCI host pattern under ``allowed_hosts``."""
    return {"allowed_hosts": OCI_ALLOWED_HOSTS}


@pytest.fixture
def allowed_hosts() -> list[str]:
    """Return a match-everything host pattern."""
    return [r".*"]


@pytest.fixture
def oci_unit_values() -> dict[str, object]:
    """Return a fresh copy of the unit-test constants."""
    return dict(_unit_test_values())


@pytest.fixture
def oci_prompts() -> dict[str, str]:
    """Return a fresh copy of the live-test prompts."""
    return dict(_prompt_defaults())


@pytest.fixture
def oci_fake_module() -> FakeOCI:
    """Return a new fake ``oci`` module."""
    return FakeOCI()


@pytest.fixture
def patch_oci_module(monkeypatch: pytest.MonkeyPatch, oci_fake_module: FakeOCI) -> FakeOCI:
    """Make the OCI completion provider load the fake module, and return it."""
    monkeypatch.setattr(
        "crewai.llms.providers.oci.completion._get_oci_module",
        lambda: oci_fake_module,
    )
    return oci_fake_module


@pytest.fixture
def oci_response_factories():
    """Return the fake response builders keyed by response kind."""
    return {
        "chat": fake_chat_response,
        "tool_call": fake_tool_call_response,
        "cohere_tool_call": fake_cohere_tool_call_response,
        "stream": fake_stream_response,
    }


@pytest.fixture
def oci_live_config() -> dict[str, str | None]:
    """Return connection settings for live tests, skipping when unconfigured.

    When only a region is configured, the endpoint is derived from it here;
    otherwise a region given through ``OCI_TEST_REGION`` would pass the skip
    check and then never reach the client.
    """
    if not _has_oci_sdk():
        pytest.skip("Requires OCI SDK")

    compartment_id = os.getenv("OCI_COMPARTMENT_ID")
    region = os.getenv("OCI_TEST_REGION") or os.getenv("OCI_REGION")
    service_endpoint = os.getenv("OCI_TEST_SERVICE_ENDPOINT") or os.getenv(
        "OCI_SERVICE_ENDPOINT"
    )
    if not compartment_id or not (region or service_endpoint):
        pytest.skip(
            "Requires OCI_COMPARTMENT_ID plus OCI_REGION/OCI_TEST_REGION or OCI_SERVICE_ENDPOINT/OCI_TEST_SERVICE_ENDPOINT"
        )

    if not service_endpoint and region:
        service_endpoint = _region_endpoint(region)

    return {
        "compartment_id": compartment_id,
        "service_endpoint": service_endpoint,
        "auth_type": os.getenv("OCI_AUTH_TYPE", "API_KEY"),
        "auth_profile": os.getenv("OCI_AUTH_PROFILE", "DEFAULT"),
        "auth_file_location": os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"),
    }


@pytest.fixture
def oci_live_llm_factory(oci_live_config: dict[str, str | None]):
    """Return a factory that builds an ``oci/``-prefixed LLM from the live config."""

    def _factory(model: str, **kwargs: object) -> LLM:
        return LLM(
            model=f"oci/{model}",
            compartment_id=oci_live_config["compartment_id"],
            service_endpoint=oci_live_config["service_endpoint"],
            auth_type=oci_live_config["auth_type"],
            auth_profile=oci_live_config["auth_profile"],
            auth_file_location=oci_live_config["auth_file_location"],
            **kwargs,
        )

    return _factory


@pytest.fixture
def oci_chat_model(request: pytest.FixtureRequest) -> str:
    """Return the parametrized chat model, skipping when none is configured."""
    model = request.param
    if not model:
        pytest.skip("Configure OCI_TEST_MODEL or OCI_TEST_MODELS for live chat tests")
    return model


@pytest.fixture
def oci_tool_model(request: pytest.FixtureRequest) -> str:
    """Return the parametrized tool model, skipping when none is configured."""
    model = request.param
    if not model:
        pytest.skip(
            "Configure OCI_TEST_TOOL_MODEL or OCI_TEST_TOOL_MODELS for live tool tests"
        )
    return model


@pytest.fixture
def oci_multimodal_model(request: pytest.FixtureRequest) -> str:
    """Return the parametrized multimodal model, skipping when none is configured."""
    model = request.param
    if not model:
        pytest.skip(
            "Configure OCI_TEST_MULTIMODAL_MODEL or OCI_TEST_MULTIMODAL_MODELS for live multimodal tests"
        )
    return model


@pytest.fixture
def oci_temperature_for_model():
    """Return a function mapping a model name to the temperature to send."""

    def _temperature(model: str) -> float | None:
        # Models starting with "openai.gpt-5" get no temperature at all.
        if model.startswith("openai.gpt-5"):
            return None
        return 0

    return _temperature


@pytest.fixture
def oci_token_budget():
    """Return the ``(model, scenario) -> max_tokens`` lookup function."""
    return _token_budget
os.environ.setdefault("OTEL_SDK_DISABLED", "true")


@tool
def add_numbers(a: int, b: int) -> int:
    """Add two numbers and return the sum."""
    return a + b


@dataclass
class ScenarioResult:
    """Timing and raw responses collected for one profiled scenario."""

    name: str
    elapsed_seconds: float
    responses: list[str]


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the profiling script."""
    parser = argparse.ArgumentParser(
        description="Profile the OCI tool-calling agent flow used by the OCI integration tests."
    )
    parser.add_argument(
        "--model",
        default=os.getenv("OCI_TEST_TOOL_MODEL", "openai.gpt-5.2-chat-latest"),
        help="OCI model name without the oci/ prefix.",
    )
    parser.add_argument(
        "--scenario",
        choices=("single", "multi", "both"),
        default="both",
        help="Which scenario to run.",
    )
    parser.add_argument(
        "--top",
        type=int,
        default=25,
        # Profiling always runs in main(), so the help no longer implies a toggle.
        help="Number of cProfile rows to print.",
    )
    parser.add_argument(
        "--profile-output",
        type=Path,
        help="Optional path to dump cProfile stats.",
    )
    return parser.parse_args()


def _require_env(name: str) -> str:
    """Return environment variable *name* or exit with an explanatory message."""
    value = os.getenv(name)
    if not value:
        raise SystemExit(f"Missing required environment variable: {name}")
    return value


def _temperature_for_model(model: str) -> float | None:
    """Return ``None`` for models starting with ``openai.gpt-5``, else ``0``."""
    if model.startswith("openai.gpt-5"):
        return None
    return 0


def _build_llm(model: str) -> LLM:
    """Build the OCI-backed LLM from environment configuration.

    Raises:
        SystemExit: If neither an endpoint nor a region is configured, or if
            ``OCI_COMPARTMENT_ID`` is missing.
    """
    service_endpoint = os.getenv("OCI_TEST_SERVICE_ENDPOINT") or os.getenv(
        "OCI_SERVICE_ENDPOINT"
    )
    region = os.getenv("OCI_TEST_REGION") or os.getenv("OCI_REGION")
    if not service_endpoint and not region:
        raise SystemExit(
            "Set OCI_TEST_SERVICE_ENDPOINT/OCI_SERVICE_ENDPOINT or OCI_TEST_REGION/OCI_REGION."
        )
    if not service_endpoint:
        # The region was previously validated but never used, so a region set
        # only through OCI_TEST_REGION had no effect. Derive the endpoint here.
        service_endpoint = (
            f"https://inference.generativeai.{region}.oci.oraclecloud.com"
        )

    kwargs: dict[str, Any] = {
        "model": f"oci/{model}",
        "compartment_id": _require_env("OCI_COMPARTMENT_ID"),
        "service_endpoint": service_endpoint,
        "auth_type": os.getenv("OCI_AUTH_TYPE", "API_KEY"),
        "auth_profile": os.getenv("OCI_AUTH_PROFILE", "DEFAULT"),
        "auth_file_location": os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"),
        "max_tokens": 1536,
        "temperature": _temperature_for_model(model),
    }
    return LLM(**kwargs)


def _build_agent(model: str) -> Agent:
    """Build the calculator agent that owns the ``add_numbers`` tool."""
    return Agent(
        role="Calculator",
        goal="Use tools to solve arithmetic problems accurately and consistently.",
        backstory="You are a precise calculator that must use the available tool.",
        llm=_build_llm(model),
        tools=[add_numbers],
        verbose=False,
    )


def _run_single(agent: Agent) -> ScenarioResult:
    """Time one kickoff and return its raw response."""
    started = time.perf_counter()
    result = agent.kickoff("Use add_numbers to calculate 20 + 22. Return only the final result.")
    elapsed_seconds = time.perf_counter() - started
    return ScenarioResult("single", elapsed_seconds, [result.raw])


def _run_multi(agent: Agent) -> ScenarioResult:
    """Time three consecutive kickoffs on the same agent."""
    prompts = [
        "Use add_numbers to calculate 2 + 5. Return only the final result.",
        "Use add_numbers to calculate 10 + 11. Return only the final result.",
        "Use add_numbers to calculate 20 + 22. Return only the final result.",
    ]
    started = time.perf_counter()
    responses = [agent.kickoff(prompt).raw for prompt in prompts]
    elapsed_seconds = time.perf_counter() - started
    return ScenarioResult("multi", elapsed_seconds, responses)


def _run_scenarios(model: str, scenario: str) -> list[ScenarioResult]:
    """Run the requested scenario(s), each with a freshly built agent."""
    results: list[ScenarioResult] = []
    if scenario in {"single", "both"}:
        results.append(_run_single(_build_agent(model)))
    if scenario in {"multi", "both"}:
        results.append(_run_multi(_build_agent(model)))
    return results


def _print_results(results: list[ScenarioResult]) -> None:
    """Print elapsed time and numbered responses for each scenario."""
    for result in results:
        print(f"[{result.name}] elapsed={result.elapsed_seconds:.3f}s")
        for index, response in enumerate(result.responses, start=1):
            print(f"  response_{index}: {response}")


def main() -> None:
    """Profile the selected scenarios and print results plus cProfile stats."""
    args = _parse_args()
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        results = _run_scenarios(args.model, args.scenario)
    finally:
        # Stop profiling even when a scenario raises.
        profiler.disable()

    _print_results(results)

    stats = pstats.Stats(profiler).sort_stats("cumulative")
    stats.print_stats(args.top)
    if args.profile_output:
        profiler.dump_stats(str(args.profile_output))
        print(f"cProfile stats written to {args.profile_output}")


if __name__ == "__main__":
    main()
def test_oci_completion_is_used_when_oci_provider(
    patch_oci_module,
    oci_response_factories,
    oci_unit_values: dict[str, object],
):
    """An ``oci/``-prefixed model resolves to ``OCICompletion`` with the prefix stripped."""
    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value.chat.return_value = oci_response_factories["chat"]("test")

    llm = LLM(
        model=str(oci_unit_values["prefixed_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
    )

    assert isinstance(llm, OCICompletion)
    assert llm.provider == "oci"
    assert llm.model == str(oci_unit_values["generic_model"])


def test_oci_completion_infers_provider_family(
    patch_oci_module,
    oci_response_factories,
    oci_provider_family_case: tuple[str, str],
    oci_unit_values: dict[str, object],
):
    """The request family is inferred from the model name."""
    model, expected_family = oci_provider_family_case
    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value.chat.return_value = oci_response_factories["chat"]("test")

    llm = OCICompletion(
        model=model,
        compartment_id=str(oci_unit_values["compartment_id"]),
    )

    assert llm.oci_provider == expected_family


def test_oci_completion_initialization_parameters(
    patch_oci_module, oci_unit_values: dict[str, object]
):
    """Sampling parameters are stored and exactly one client is constructed."""
    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value = MagicMock()

    llm = LLM(
        model=str(oci_unit_values["prefixed_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
        service_endpoint=str(oci_unit_values["service_endpoint"]),
        temperature=0.2,
        max_tokens=256,
        top_p=0.9,
        top_k=20,
    )

    assert isinstance(llm, OCICompletion)
    assert llm.temperature == 0.2
    assert llm.max_tokens == 256
    assert llm.top_p == 0.9
    assert llm.top_k == 20
    client_cls.assert_called_once()


def test_oci_completion_call_uses_chat_api(
    patch_oci_module,
    oci_response_factories,
    oci_unit_values: dict[str, object],
):
    """``call`` sends one chat request carrying compartment, model and prompt."""
    compartment = str(oci_unit_values["compartment_id"])
    model = str(oci_unit_values["generic_model"])
    prompt = str(oci_unit_values["chat_prompt"])
    fake_client = MagicMock()
    fake_client.chat.return_value = oci_response_factories["chat"]("Hello from OCI")
    patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = (
        fake_client
    )

    llm = OCICompletion(model=model, compartment_id=compartment)
    result = llm.call([{"role": "user", "content": prompt}])

    assert result == "Hello from OCI"
    request = fake_client.chat.call_args.args[0]
    assert request.compartment_id == compartment
    assert request.serving_mode.model_id == model
    assert request.chat_request.messages[0].content[0].text == prompt


def test_oci_completion_uses_region_to_build_endpoint(
    monkeypatch: pytest.MonkeyPatch,
    patch_oci_module,
    oci_response_factories,
    oci_unit_values: dict[str, object],
):
    """With no explicit endpoint, ``OCI_REGION`` determines the endpoint URL."""
    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value.chat.return_value = oci_response_factories["chat"]("test")
    monkeypatch.setenv("OCI_REGION", str(oci_unit_values["region"]))

    llm = OCICompletion(
        model=str(oci_unit_values["generic_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
    )

    expected = (
        f"https://inference.generativeai.{oci_unit_values['region']}.oci.oraclecloud.com"
    )
    assert llm.service_endpoint == expected


def test_oci_openai_models_use_max_completion_tokens(
    patch_oci_module,
    oci_response_factories,
    oci_unit_values: dict[str, object],
):
    """OpenAI-family requests carry ``max_completion_tokens`` instead of ``max_tokens``."""
    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value.chat.return_value = oci_response_factories["chat"]("test")

    llm = OCICompletion(
        model=str(oci_unit_values["generic_tool_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
        max_tokens=77,
    )
    request = llm._build_chat_request([{"role": "user", "content": "hello"}])

    assert not hasattr(request, "max_tokens")
    assert request.max_completion_tokens == 77


def test_oci_openai_gpt5_omits_unsupported_temperature_and_stop(
    patch_oci_module,
    oci_response_factories,
    oci_unit_values: dict[str, object],
):
    """GPT-5 requests leave out ``temperature`` and ``stop``."""
    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value.chat.return_value = oci_response_factories["chat"]("test")

    llm = OCICompletion(
        model=str(oci_unit_values["gpt5_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
        temperature=0,
        stop=["Observation:"],
    )
    request = llm._build_chat_request([{"role": "user", "content": "hello"}])

    assert not hasattr(request, "temperature")
    assert not hasattr(request, "stop")


def test_oci_completion_supports_structured_output(
    patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object]
):
    """A ``response_model`` yields a parsed instance and a strict JSON-schema request."""

    class OracleCloudSummary(BaseModel):
        summary: str
        confidence: int

    payload = '{"summary":"OCI is scalable","confidence":92}'
    fake_client = MagicMock()
    fake_client.chat.return_value = oci_response_factories["chat"](payload)
    patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = (
        fake_client
    )

    llm = OCICompletion(
        model=str(oci_unit_values["generic_structured_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
    )
    result = llm.call(
        [{"role": "user", "content": str(oci_unit_values["json_prompt"])}],
        response_model=OracleCloudSummary,
    )

    assert isinstance(result, OracleCloudSummary)
    assert result.summary == "OCI is scalable"
    assert result.confidence == 92
    schema = fake_client.chat.call_args.args[0].chat_request.response_format.json_schema
    assert schema.name == "OracleCloudSummary"
    assert schema.is_strict is True


def test_oci_completion_extracts_fenced_structured_output(
    patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object]
):
    """JSON wrapped in a Markdown code fence is still parsed into the model."""

    class OracleCloudSummary(BaseModel):
        summary: str
        topic: str

    payload = '```json\n{"summary":"OCI is scalable","topic":"Oracle Cloud"}\n```'
    fake_client = MagicMock()
    fake_client.chat.return_value = oci_response_factories["chat"](payload)
    patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = (
        fake_client
    )

    llm = OCICompletion(
        model=str(oci_unit_values["gpt4_control_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
    )
    result = llm.call(
        [{"role": "user", "content": str(oci_unit_values["json_prompt"])}],
        response_model=OracleCloudSummary,
    )

    assert isinstance(result, OracleCloudSummary)
    assert result.summary == "OCI is scalable"
    assert result.topic == "Oracle Cloud"


def test_oci_completion_streams_generic_responses(
    patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object]
):
    """Streaming joins the text chunks and emits one event per chunk."""
    stream_events = (
        {"message": {"content": [{"text": "Hello"}]}},
        {"message": {"content": [{"text": " from OCI"}]}},
        {"finishReason": "stop"},
        {"usage": {"promptTokens": 11, "completionTokens": 4, "totalTokens": 15}},
    )
    fake_client = MagicMock()
    fake_client.chat.return_value = oci_response_factories["stream"](*stream_events)
    patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = (
        fake_client
    )

    llm = OCICompletion(
        model=str(oci_unit_values["generic_tool_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
        stream=True,
    )
    llm._emit_stream_chunk_event = MagicMock()

    result = llm.call([{"role": "user", "content": str(oci_unit_values["hello_prompt"])}])

    assert result == "Hello from OCI"
    chat_request = fake_client.chat.call_args.args[0].chat_request
    assert chat_request.is_stream is True
    assert chat_request.stream_options.is_include_usage is True
    assert llm._emit_stream_chunk_event.call_count == 2


def test_oci_completion_builds_multimodal_generic_messages(
    patch_oci_module, oci_unit_values: dict[str, object]
):
    """Text, image, document, video and audio parts map to their OCI content types."""
    patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = (
        MagicMock()
    )
    prompt = str(oci_unit_values["multimodal_prompt"])
    parts = [
        {"type": "text", "text": prompt},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAA"}},
        {
            "type": "document_url",
            "document_url": {"url": "data:application/pdf;base64,BBB"},
        },
        {"type": "video_url", "video_url": {"url": "data:video/mp4;base64,CCC"}},
        {"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,DDD"}},
    ]

    llm = OCICompletion(
        model=str(oci_unit_values["generic_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
    )
    request = llm._build_chat_request([{"role": "user", "content": parts}])

    content = request.messages[0].content
    assert content[0].text == prompt
    assert content[1].image_url.url == "data:image/png;base64,AAA"
    assert content[2].document_url.url == "data:application/pdf;base64,BBB"
    assert content[3].video_url.url == "data:video/mp4;base64,CCC"
    assert content[4].audio_url.url == "data:audio/wav;base64,DDD"


def test_oci_cohere_models_use_cohere_request_format(
    patch_oci_module, oci_unit_values: dict[str, object]
):
    """Cohere models build a COHERE-format request with converted tool definitions."""
    patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = (
        MagicMock()
    )
    prompt = str(oci_unit_values["search_prompt"])
    search_tool = {
        "type": "function",
        "function": {
            "name": "search_docs",
            "description": "Search documentation",
            "parameters": {
                "type": "object",
                "properties": {"query": {"type": "string"}},
                "required": ["query"],
            },
        },
    }

    llm = OCICompletion(
        model=str(oci_unit_values["cohere_model"]),
        compartment_id=str(oci_unit_values["compartment_id"]),
    )
    request = llm._build_chat_request(
        [{"role": "user", "content": prompt}],
        tools=[search_tool],
    )

    assert request.api_format == "COHERE"
    assert request.message == prompt
    assert request.tools[0].name == "search_docs"
    assert request.tools[0].parameter_definitions["query"].is_required is True
fake_client.chat.return_value = oci_response_factories["cohere_tool_call"]( + "search_docs", {"query": "Oracle Cloud"} + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=str(oci_unit_values["cohere_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + result = llm.call( + [{"role": "user", "content": str(oci_unit_values["docs_prompt"])}], + tools=[ + { + "type": "function", + "function": { + "name": "search_docs", + "description": "Search documentation", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + } + ], + ) + + assert isinstance(result, list) + assert result[0]["function"]["name"] == "search_docs" + assert json.loads(result[0]["function"]["arguments"]) == { + "query": "Oracle Cloud" + } + + +def test_oci_completion_returns_tool_calls_for_executor( + patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object] +): + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["tool_call"]( + "get_weather", '{"city":"Paris"}' + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_tool_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + result = llm.call( + [{"role": "user", "content": str(oci_unit_values["weather_prompt"])}], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ], + ) + + assert isinstance(result, list) + assert result[0]["function"]["name"] == "get_weather" + assert result[0]["function"]["arguments"] == '{"city":"Paris"}' + + +def test_oci_completion_supports_generic_tool_controls( + patch_oci_module, 
oci_unit_values: dict[str, object] +): + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + MagicMock() + ) + + llm = OCICompletion( + model=str(oci_unit_values["llama_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + tool_choice="get_weather", + parallel_tool_calls=True, + tool_result_guidance=True, + ) + + request = llm._build_chat_request( + [ + {"role": "assistant", "content": None, "tool_calls": []}, + { + "role": "tool", + "tool_call_id": "call_123", + "name": "get_weather", + "content": "Weather for Paris: sunny", + }, + ], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ], + ) + + assert request.is_parallel_tool_calls is True + assert request.tool_choice.name == "get_weather" + assert request.messages[-1].content[0].text.startswith( + "You have received tool results above." 
+ ) + + +@pytest.mark.parametrize( + ("tool_choice", "expected_class_name", "expected_name"), + [ + ("auto", "ToolChoiceAuto", None), + ("none", "ToolChoiceNone", None), + ("required", "ToolChoiceRequired", None), + (True, "ToolChoiceRequired", None), + (False, "ToolChoiceNone", None), + ({"type": "function", "function": {"name": "search_docs"}}, "ToolChoiceFunction", "search_docs"), + ], +) +def test_oci_completion_formats_tool_choice_variants( + patch_oci_module, + oci_unit_values: dict[str, object], + tool_choice: object, + expected_class_name: str, + expected_name: str | None, +): + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + MagicMock() + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_tool_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + tool_choice=tool_choice, + ) + + formatted = llm._build_tool_choice() + + assert formatted.__class__.__name__ == expected_class_name + if expected_name is not None: + assert formatted.name == expected_name + + +def test_oci_completion_rejects_parallel_tools_for_cohere( + patch_oci_module, oci_unit_values: dict[str, object] +): + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + MagicMock() + ) + + llm = OCICompletion( + model=str(oci_unit_values["cohere_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + parallel_tool_calls=True, + ) + + with pytest.raises(ValueError, match="do not support parallel_tool_calls"): + llm._build_chat_request( + [{"role": "user", "content": str(oci_unit_values["search_prompt"])}], + tools=[ + { + "type": "function", + "function": { + "name": "search_docs", + "description": "Search documentation", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + }, + } + ], + ) + + +def test_oci_completion_executes_tool_calls_recursively( + patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object] 
+): + fake_client = MagicMock() + fake_client.chat.side_effect = [ + oci_response_factories["tool_call"]("get_weather", '{"city":"Paris"}'), + oci_response_factories["chat"]("Sunny in Paris"), + ] + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_tool_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + result = llm.call( + [{"role": "user", "content": str(oci_unit_values["weather_prompt"])}], + tools=[ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ], + available_functions={ + "get_weather": lambda city: f"Weather for {city}: sunny", + }, + ) + + assert result == "Sunny in Paris" + assert fake_client.chat.call_count == 2 + second_request = fake_client.chat.call_args_list[1].args[0] + assert second_request.chat_request.messages[1].tool_calls[0].name == "get_weather" + assert second_request.chat_request.messages[2].tool_call_id == "call_123" + + +@pytest.mark.asyncio +async def test_oci_completion_acall_delegates_to_call( + patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object] +): + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]( + "Hello from OCI async" + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_tool_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + result = await llm.acall( + [{"role": "user", "content": str(oci_unit_values["hello_prompt"])}] + ) + + assert result == "Hello from OCI async" + assert fake_client.chat.call_count == 1 diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_async_batch.py 
b/lib/crewai/tests/llms/oci/test_oci_integration_async_batch.py new file mode 100644 index 00000000000..6c763cbc909 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_async_batch.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import asyncio + + +async def _collect_stream(llm, prompt: str) -> str: + chunks: list[str] = [] + async for chunk in llm.astream(prompt): + chunks.append(chunk) + return "".join(chunks) + + +def test_oci_live_abatch( + oci_chat_model: str, + oci_live_llm_factory, + oci_temperature_for_model, + oci_token_budget, +): + llm = oci_live_llm_factory( + oci_chat_model, + max_tokens=oci_token_budget(oci_chat_model, "async"), + temperature=oci_temperature_for_model(oci_chat_model), + ) + + results = asyncio.run( + llm.abatch( + [ + "Reply with one short sentence about Oracle Cloud.", + "Reply with one short sentence about databases.", + ] + ) + ) + + assert len(results) == 2 + assert all(isinstance(result, str) and result.strip() for result in results) + + +def test_oci_live_astream( + oci_chat_model: str, + oci_live_llm_factory, + oci_temperature_for_model, + oci_token_budget, +): + llm = oci_live_llm_factory( + oci_chat_model, + max_tokens=oci_token_budget(oci_chat_model, "stream"), + temperature=oci_temperature_for_model(oci_chat_model), + ) + + result = asyncio.run( + _collect_stream(llm, "Reply with a short sentence about Oracle Cloud.") + ) + + assert isinstance(result, str) + assert result.strip() diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_basic.py b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py new file mode 100644 index 00000000000..53ddcddbfb8 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import asyncio + + +def test_oci_live_basic_call( + oci_chat_model: str, + oci_live_llm_factory, + oci_prompts: dict[str, str], + oci_temperature_for_model, + oci_token_budget, +): + llm = oci_live_llm_factory( + 
oci_chat_model, + max_tokens=oci_token_budget(oci_chat_model, "basic"), + temperature=oci_temperature_for_model(oci_chat_model), + ) + + result = llm.call(oci_prompts["basic"]) + + assert isinstance(result, str) + assert result.strip() + + +def test_oci_live_async_call( + oci_chat_model: str, + oci_live_llm_factory, + oci_prompts: dict[str, str], + oci_temperature_for_model, + oci_token_budget, +): + llm = oci_live_llm_factory( + oci_chat_model, + max_tokens=oci_token_budget(oci_chat_model, "async"), + temperature=oci_temperature_for_model(oci_chat_model), + ) + + result = asyncio.run(llm.acall(oci_prompts["async"])) + + assert isinstance(result, str) + assert result.strip() diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_multimodal.py b/lib/crewai/tests/llms/oci/test_oci_integration_multimodal.py new file mode 100644 index 00000000000..e913f3d82ff --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_multimodal.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +import base64 +import io +import os +import wave + +from PIL import Image +import pytest + + +def _png_data_uri() -> str: + buffer = io.BytesIO() + Image.new("RGB", (4, 4), color=(255, 255, 255)).save(buffer, format="PNG") + return ( + "data:image/png;base64," + f"{base64.b64encode(buffer.getvalue()).decode('ascii')}" + ) + + +def _pdf_data_uri() -> str: + buffer = io.BytesIO() + Image.new("RGB", (8, 8), color=(255, 255, 255)).save(buffer, format="PDF") + return ( + "data:application/pdf;base64," + f"{base64.b64encode(buffer.getvalue()).decode('ascii')}" + ) + + +def _wav_data_uri() -> str: + buffer = io.BytesIO() + with wave.open(buffer, "wb") as wav_file: + wav_file.setnchannels(1) + wav_file.setsampwidth(2) + wav_file.setframerate(8000) + wav_file.writeframes(b"\x00\x00" * 800) + return ( + "data:audio/wav;base64," + f"{base64.b64encode(buffer.getvalue()).decode('ascii')}" + ) + + +def test_oci_live_image_input( + oci_multimodal_model: str, + oci_live_llm_factory, + 
oci_temperature_for_model, + oci_token_budget, +): + llm = oci_live_llm_factory( + oci_multimodal_model, + max_tokens=oci_token_budget(oci_multimodal_model, "basic"), + temperature=oci_temperature_for_model(oci_multimodal_model), + ) + + result = llm.call( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Reply with a short sentence about this image."}, + {"type": "image_url", "image_url": {"url": _png_data_uri()}}, + ], + } + ] + ) + + assert isinstance(result, str) + assert result.strip() + + +def test_oci_live_pdf_input( + oci_multimodal_model: str, + oci_live_llm_factory, + oci_temperature_for_model, + oci_token_budget, +): + if not oci_multimodal_model.startswith("google.gemini"): + pytest.skip("PDF multimodal coverage currently requires a Gemini OCI model") + + llm = oci_live_llm_factory( + oci_multimodal_model, + max_tokens=oci_token_budget(oci_multimodal_model, "basic"), + temperature=oci_temperature_for_model(oci_multimodal_model), + ) + + result = llm.call( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Reply briefly after inspecting this PDF."}, + {"type": "document_url", "document_url": {"url": _pdf_data_uri()}}, + ], + } + ] + ) + + assert isinstance(result, str) + assert result.strip() + + +def test_oci_live_audio_input( + oci_multimodal_model: str, + oci_live_llm_factory, + oci_temperature_for_model, + oci_token_budget, +): + if not oci_multimodal_model.startswith("google.gemini"): + pytest.skip("Audio multimodal coverage currently requires a Gemini OCI model") + + llm = oci_live_llm_factory( + oci_multimodal_model, + max_tokens=oci_token_budget(oci_multimodal_model, "basic"), + temperature=oci_temperature_for_model(oci_multimodal_model), + ) + + result = llm.call( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Reply briefly after inspecting this audio."}, + {"type": "audio_url", "audio_url": {"url": _wav_data_uri()}}, + ], + } + ] + ) + + assert isinstance(result, str) + assert 
result.strip() + + +def test_oci_live_video_input( + oci_multimodal_model: str, + oci_live_llm_factory, + oci_temperature_for_model, + oci_token_budget, +): + if not oci_multimodal_model.startswith("google.gemini"): + pytest.skip("Video multimodal coverage currently requires a Gemini OCI model") + + video_data_uri = os.getenv("OCI_TEST_VIDEO_DATA_URI") + if not video_data_uri: + pytest.skip("Configure OCI_TEST_VIDEO_DATA_URI for OCI live video tests") + + llm = oci_live_llm_factory( + oci_multimodal_model, + max_tokens=oci_token_budget(oci_multimodal_model, "basic"), + temperature=oci_temperature_for_model(oci_multimodal_model), + ) + + result = llm.call( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Reply briefly after inspecting this video."}, + {"type": "video_url", "video_url": {"url": video_data_uri}}, + ], + } + ] + ) + + assert isinstance(result, str) + assert result.strip() diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py b/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py new file mode 100644 index 00000000000..c760980a800 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py @@ -0,0 +1,21 @@ +from __future__ import annotations + + +def test_oci_live_streaming_call( + oci_chat_model: str, + oci_live_llm_factory, + oci_prompts: dict[str, str], + oci_temperature_for_model, + oci_token_budget, +): + llm = oci_live_llm_factory( + oci_chat_model, + max_tokens=oci_token_budget(oci_chat_model, "stream"), + temperature=oci_temperature_for_model(oci_chat_model), + stream=True, + ) + + result = llm.call(oci_prompts["stream"]) + + assert isinstance(result, str) + assert result.strip() diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_structured.py b/lib/crewai/tests/llms/oci/test_oci_integration_structured.py new file mode 100644 index 00000000000..6fd7be5fa60 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_structured.py @@ -0,0 +1,31 @@ +from __future__ 
import annotations + +from pydantic import BaseModel + + +class OCIStructuredResponse(BaseModel): + summary: str + topic: str + + +def test_oci_live_structured_output( + oci_chat_model: str, + oci_live_llm_factory, + oci_prompts: dict[str, str], + oci_temperature_for_model, + oci_token_budget, +): + llm = oci_live_llm_factory( + oci_chat_model, + max_tokens=oci_token_budget(oci_chat_model, "structured"), + temperature=oci_temperature_for_model(oci_chat_model), + ) + + result = llm.call( + oci_prompts["structured"], + response_model=OCIStructuredResponse, + ) + + assert isinstance(result, OCIStructuredResponse) + assert result.summary.strip() + assert result.topic.strip() diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_tools.py b/lib/crewai/tests/llms/oci/test_oci_integration_tools.py new file mode 100644 index 00000000000..ccbcc4ea4c7 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_tools.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from pydantic import BaseModel + +from crewai import Agent +from crewai.tools import tool + + +class CalculationResult(BaseModel): + operation: str + result: int + explanation: str + + +@tool +def add_numbers(a: int, b: int) -> int: + """Add two numbers and return the sum.""" + return a + b + + +def test_oci_agent_uses_tool( + oci_tool_model: str, + oci_live_llm_factory, + oci_prompts: dict[str, str], + oci_temperature_for_model, + oci_token_budget, +): + agent = Agent( + role="Calculator", + goal="Use tools to solve arithmetic problems", + backstory="You are a precise calculator that must use the available tools.", + llm=oci_live_llm_factory( + oci_tool_model, + max_tokens=oci_token_budget(oci_tool_model, "agent"), + temperature=oci_temperature_for_model(oci_tool_model), + ), + tools=[add_numbers], + verbose=True, + ) + + result = agent.kickoff(oci_prompts["tool"]) + + assert "42" in result.raw + assert add_numbers.current_usage_count >= 1 + + +def 
test_oci_agent_kickoff_structured_output_with_tools( + oci_tool_model: str, + oci_live_llm_factory, + oci_prompts: dict[str, str], + oci_temperature_for_model, + oci_token_budget, +): + agent = Agent( + role="Calculator", + goal="Perform calculations using available tools", + backstory="You are a calculator assistant that uses tools to compute results.", + llm=oci_live_llm_factory( + oci_tool_model, + max_tokens=oci_token_budget(oci_tool_model, "agent"), + temperature=oci_temperature_for_model(oci_tool_model), + ), + tools=[add_numbers], + verbose=True, + ) + + result = agent.kickoff( + messages=oci_prompts["tool_structured"], + response_format=CalculationResult, + ) + + assert result.pydantic is not None + assert isinstance(result.pydantic, CalculationResult) + assert result.pydantic.result == 42 + assert result.pydantic.operation + assert result.pydantic.explanation + + +def test_oci_agent_handles_multiple_tool_asks_in_sequence( + oci_tool_model: str, + oci_live_llm_factory, + oci_temperature_for_model, + oci_token_budget, +): + agent = Agent( + role="Calculator", + goal="Use tools to solve arithmetic problems accurately across repeated asks", + backstory="You are a calculator assistant that must use the available tool every time.", + llm=oci_live_llm_factory( + oci_tool_model, + max_tokens=oci_token_budget(oci_tool_model, "agent"), + temperature=oci_temperature_for_model(oci_tool_model), + ), + tools=[add_numbers], + verbose=True, + ) + + prompts = [ + "Use add_numbers to calculate 2 + 5. Return only the final result.", + "Use add_numbers to calculate 10 + 11. Return only the final result.", + "Use add_numbers to calculate 20 + 22. 
Return only the final result.", + ] + + results = [agent.kickoff(prompt) for prompt in prompts] + + assert "7" in results[0].raw + assert "21" in results[1].raw + assert "42" in results[2].raw diff --git a/lib/crewai/tests/llms/oci/test_oci_sdk_surface.py b/lib/crewai/tests/llms/oci/test_oci_sdk_surface.py new file mode 100644 index 00000000000..9509cc3c9cb --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_sdk_surface.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import asyncio +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion +from crewai.llms.providers.oci.vision import ( + IMAGE_EMBEDDING_MODELS, + VISION_MODELS, + encode_image, + is_vision_model, + load_image, + to_data_uri, +) + + +def test_oci_iter_stream_yields_text_chunks_and_metadata( + patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object] +): + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["stream"]( + {"message": {"content": [{"text": "Hello"}]}}, + {"message": {"content": [{"text": " world"}]}}, + {"finishReason": "stop"}, + {"usage": {"promptTokens": 3, "completionTokens": 2, "totalTokens": 5}}, + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_tool_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + chunks = list( + llm.iter_stream([{"role": "user", "content": str(oci_unit_values["hello_prompt"])}]) + ) + + assert chunks == ["Hello", " world"] + assert llm.last_response_metadata == { + "finish_reason": "stop", + "usage": {"prompt_tokens": 3, "completion_tokens": 2, "total_tokens": 5}, + } + + +@pytest.mark.asyncio +async def test_oci_astream_yields_text_chunks( + patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object] +): + 
fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["stream"]( + {"message": {"content": [{"text": "Async"}]}}, + {"message": {"content": [{"text": " stream"}]}}, + {"finishReason": "stop"}, + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_tool_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + chunks = [] + async for chunk in llm.astream( + [{"role": "user", "content": str(oci_unit_values["hello_prompt"])}] + ): + chunks.append(chunk) + + assert chunks == ["Async", " stream"] + + +@pytest.mark.asyncio +async def test_oci_abatch_runs_multiple_calls( + patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object] +): + fake_client = MagicMock() + fake_client.chat.side_effect = [ + oci_response_factories["chat"]("first"), + oci_response_factories["chat"]("second"), + ] + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_tool_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + results = await llm.abatch(["prompt one", "prompt two"]) + + assert results == ["first", "second"] + assert fake_client.chat.call_count == 2 + + +def test_oci_extracts_response_metadata( + patch_oci_module, oci_unit_values: dict[str, object] +): + fake_client = MagicMock() + fake_client.chat.return_value = SimpleNamespace( + data=SimpleNamespace( + chat_response=SimpleNamespace( + finish_reason="stop", + citations=[{"start": 0, "end": 1}], + documents=[{"id": "doc_1"}], + search_queries=["oracle cloud"], + is_search_required=True, + usage=SimpleNamespace( + prompt_tokens=10, completion_tokens=5, total_tokens=15 + ), + ) + ) + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + 
model=str(oci_unit_values["cohere_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + llm.call("hello") + + assert llm.last_response_metadata == { + "finish_reason": "stop", + "documents": [{"id": "doc_1"}], + "citations": [{"start": 0, "end": 1}], + "search_queries": ["oracle cloud"], + "is_search_required": True, + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + +def test_oci_vision_utilities(tmp_path: Path): + image_path = tmp_path / "sample.png" + image_path.write_bytes(b"\x89PNG\r\n\x1a\n") + + assert "meta.llama-3.2-90b-vision-instruct" in VISION_MODELS + assert "cohere.embed-v4.0" in IMAGE_EMBEDDING_MODELS + assert is_vision_model("google.gemini-2.5-flash") is True + assert is_vision_model("meta.llama-3.3-70b-instruct") is False + assert to_data_uri(b"hello", mime_type="image/png").startswith("data:image/png;base64,") + assert to_data_uri("data:image/png;base64,abc") == "data:image/png;base64,abc" + assert load_image(image_path)["image_url"]["url"].startswith("data:image/png;base64,") + assert encode_image(b"hello", "image/png")["image_url"]["url"].startswith( + "data:image/png;base64," + ) diff --git a/lib/crewai/tests/rag/embeddings/conftest.py b/lib/crewai/tests/rag/embeddings/conftest.py new file mode 100644 index 00000000000..fc6b1a5c282 --- /dev/null +++ b/lib/crewai/tests/rag/embeddings/conftest.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import io +import os + +from PIL import Image +import pytest + + +def _valid_png_bytes() -> bytes: + buffer = io.BytesIO() + Image.new("RGB", (4, 4), color=(255, 255, 255)).save(buffer, format="PNG") + return buffer.getvalue() + + +def _has_oci_test_config() -> bool: + return bool( + os.getenv("OCI_COMPARTMENT_ID") + and (os.getenv("OCI_SERVICE_ENDPOINT") or os.getenv("OCI_REGION")) + ) + + +def _has_oci_sdk() -> bool: + try: + import oci # noqa: F401 + except ImportError: + return False + return True + + +def 
_embedding_provider_config(model_env_var: str, default_model: str) -> dict[str, object]: + config: dict[str, object] = { + "model_name": os.getenv(model_env_var, default_model), + "compartment_id": os.getenv("OCI_COMPARTMENT_ID"), + "auth_type": os.getenv("OCI_AUTH_TYPE", "API_KEY"), + "auth_profile": os.getenv("OCI_AUTH_PROFILE", "DEFAULT"), + "auth_file_location": os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + } + if os.getenv("OCI_REGION"): + config["region"] = os.getenv("OCI_REGION") + if os.getenv("OCI_SERVICE_ENDPOINT"): + config["service_endpoint"] = os.getenv("OCI_SERVICE_ENDPOINT") + return config + + +@pytest.fixture +def oci_embeddings_live_config() -> dict[str, object]: + if not _has_oci_sdk() or not _has_oci_test_config(): + pytest.skip( + "Requires OCI SDK plus OCI_COMPARTMENT_ID and OCI endpoint configuration" + ) + return { + "text_model_env": "OCI_EMBED_TEST_MODEL", + "text_model_default": "cohere.embed-english-v3.0", + "image_model_env": "OCI_IMAGE_EMBED_TEST_MODEL", + "image_model_default": "cohere.embed-v4.0", + "text_inputs": [ + "Oracle Cloud Infrastructure provides cloud services.", + "Autonomous Database is an Oracle managed database service.", + ], + "image_query": "OCI architecture diagram", + "image_bytes": _valid_png_bytes(), + } + + +@pytest.fixture +def oci_embedding_provider_config(): + return _embedding_provider_config + + +@pytest.fixture +def allowed_hosts() -> list[str]: + return [r".*"] diff --git a/lib/crewai/tests/rag/embeddings/test_factory_oci.py b/lib/crewai/tests/rag/embeddings/test_factory_oci.py new file mode 100644 index 00000000000..a47fab3f3bd --- /dev/null +++ b/lib/crewai/tests/rag/embeddings/test_factory_oci.py @@ -0,0 +1,211 @@ +"""Tests for OCI embedding provider wiring.""" + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.rag.embeddings.factory import build_embedder +from 
crewai.rag.embeddings.providers.oci.embedding_callable import OCIEmbeddingFunction + + +class _FakeOCI: + def __init__(self) -> None: + self.retry = SimpleNamespace(DEFAULT_RETRY_STRATEGY="retry") + self.config = SimpleNamespace( + from_file=lambda file_location, profile_name: { + "file_location": file_location, + "profile_name": profile_name, + } + ) + self.signer = SimpleNamespace( + load_private_key_from_file=lambda *_args, **_kwargs: "private-key" + ) + self.auth = SimpleNamespace( + signers=SimpleNamespace( + SecurityTokenSigner=lambda token, key: (token, key), + InstancePrincipalsSecurityTokenSigner=lambda: "instance-principal", + get_resource_principals_signer=lambda: "resource-principal", + ) + ) + self.generative_ai_inference = SimpleNamespace( + GenerativeAiInferenceClient=MagicMock(), + models=SimpleNamespace( + EmbedTextDetails=_simple_init_class("EmbedTextDetails"), + OnDemandServingMode=_simple_init_class("OnDemandServingMode"), + DedicatedServingMode=_simple_init_class("DedicatedServingMode"), + ), + ) + + +def _simple_init_class(name: str): + class _Simple: + output_dimensions = None + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + _Simple.__name__ = name + return _Simple + + +@patch("crewai.rag.embeddings.factory.import_and_validate_definition") +def test_build_embedder_oci(mock_import): + """Test building OCI embedder.""" + mock_provider_class = MagicMock() + mock_provider_instance = MagicMock() + mock_embedding_function = MagicMock() + + mock_import.return_value = mock_provider_class + mock_provider_class.return_value = mock_provider_instance + mock_provider_instance.embedding_callable.return_value = mock_embedding_function + + config = { + "provider": "oci", + "config": { + "model_name": "cohere.embed-english-v3.0", + "compartment_id": "ocid1.compartment.oc1..test", + "region": "eu-frankfurt-1", + "auth_profile": "DEFAULT", + }, + } + + build_embedder(config) + + mock_import.assert_called_once_with( + 
def test_oci_embedding_function_batches_requests(monkeypatch):
    """Three inputs with ``batch_size=2`` must produce two ``embed_text`` calls."""
    sdk_stub = _FakeOCI()
    inference_client = MagicMock()
    # One canned response per expected batch, in call order.
    batch_vectors = ([[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6]])
    inference_client.embed_text.side_effect = [
        SimpleNamespace(data=SimpleNamespace(embeddings=vectors))
        for vectors in batch_vectors
    ]
    sdk_stub.generative_ai_inference.GenerativeAiInferenceClient.return_value = (
        inference_client
    )
    monkeypatch.setattr(
        "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module",
        lambda: sdk_stub,
    )

    embedder = OCIEmbeddingFunction(
        model_name="cohere.embed-english-v3.0",
        compartment_id="ocid1.compartment.oc1..test",
        region="eu-frankfurt-1",
        batch_size=2,
    )
    embeddings = embedder(["a", "b", "c"])

    # The batches must be stitched back together in input order.
    actual_rows = [vector.tolist() for vector in embeddings]
    wanted_rows = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
    assert len(actual_rows) == len(wanted_rows)
    for got, wanted in zip(actual_rows, wanted_rows, strict=True):
        assert got == pytest.approx(wanted)

    embed_calls = inference_client.embed_text
    assert embed_calls.call_count == 2
    first_request = embed_calls.call_args_list[0].args[0]
    assert first_request.compartment_id == "ocid1.compartment.oc1..test"
    assert first_request.serving_mode.model_id == "cohere.embed-english-v3.0"
def test_oci_embedding_function_exposes_serializable_config(monkeypatch):
    """The exported config must be plain data and rebuild an equivalent embedder."""
    sdk_stub = _FakeOCI()
    sdk_stub.generative_ai_inference.GenerativeAiInferenceClient.return_value = (
        MagicMock()
    )
    monkeypatch.setattr(
        "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module",
        lambda: sdk_stub,
    )

    embedder = OCIEmbeddingFunction(
        model_name="cohere.embed-english-v3.0",
        compartment_id="ocid1.compartment.oc1..test",
        timeout=(5, 30),
    )

    # The tuple timeout is expected to come back as a list.
    expected_config = {
        "model_name": "cohere.embed-english-v3.0",
        "compartment_id": "ocid1.compartment.oc1..test",
        "timeout": [5, 30],
    }
    assert embedder.get_config() == expected_config

    # Round trip: config -> new embedder -> identical config.
    clone = OCIEmbeddingFunction.build_from_config(embedder.get_config())
    assert clone.get_config() == embedder.get_config()
def test_oci_live_embedding_call(
    oci_embeddings_live_config: dict[str, object],
    oci_embedding_provider_config,
) -> None:
    """Embed the fixture's sample texts through the ``oci`` provider.

    Args:
        oci_embeddings_live_config: Fixture dict supplying the text model
            environment variable name, its default model, and the sample inputs.
        oci_embedding_provider_config: Callable that turns an environment
            variable name and default model into a provider config dict.
    """
    text_inputs = list(oci_embeddings_live_config["text_inputs"])
    embedder = build_embedder(
        {
            "provider": "oci",
            "config": oci_embedding_provider_config(
                str(oci_embeddings_live_config["text_model_env"]),
                str(oci_embeddings_live_config["text_model_default"]),
            ),
        }
    )

    result = embedder(text_inputs)

    # Compare against the inputs actually sent instead of a hard-coded count,
    # so the test keeps working if the fixture's sample list changes.
    assert len(result) == len(text_inputs)
    assert all(len(embedding) > 0 for embedding in result)
str(oci_embeddings_live_config["text_model_default"]), + ), + } + ) + image_embedder = build_embedder( + { + "provider": "oci", + "config": oci_embedding_provider_config( + str(oci_embeddings_live_config["image_model_env"]), + str(oci_embeddings_live_config["image_model_default"]), + ), + } + ) + + text_vector = text_embedder([str(oci_embeddings_live_config["image_query"])])[0] + try: + image_vector = image_embedder.embed_image( + bytes(oci_embeddings_live_config["image_bytes"]), + mime_type="image/png", + ) + except Exception as exc: + if "Entity with key" in str(exc) and "not found" in str(exc): + pytest.skip( + "OCI image embedding model is listed in this tenancy but not invokable via embedText." + ) + raise + + assert len(text_vector) > 0 + assert len(image_vector) > 0 From d6ae8399fc8fd76574b28d739237823d4a625c42 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sat, 14 Mar 2026 23:03:16 -0400 Subject: [PATCH 2/9] Add OCI and Oracle DB tools integrations --- .../cloud-storage/ociknowledgebasetool.mdx | 111 +++++ .../ociobjectstoragereadertool.mdx | 38 ++ .../ociobjectstoragewritertool.mdx | 36 ++ docs/en/tools/cloud-storage/overview.mdx | 27 +- .../tools/integration/ociinvokeagenttool.mdx | 43 ++ docs/en/tools/integration/overview.mdx | 17 +- docs/en/tools/overview.mdx | 6 + docs/en/tools/tool-integrations/overview.mdx | 10 +- lib/crewai-tools/pyproject.toml | 6 + lib/crewai-tools/src/crewai_tools/__init__.py | 16 + .../src/crewai_tools/oci/__init__.py | 12 + .../src/crewai_tools/oci/agents/README.md | 19 + .../src/crewai_tools/oci/agents/__init__.py | 4 + .../oci/agents/invoke_agent_tool.py | 131 +++++ .../src/crewai_tools/oci/common.py | 99 ++++ .../oci/knowledge_base/__init__.py | 4 + .../oci/knowledge_base/retriever_tool.py | 108 +++++ .../crewai_tools/oci/object_storage/README.md | 23 + .../oci/object_storage/__init__.py | 8 + .../oci/object_storage/reader_tool.py | 104 ++++ .../oci/object_storage/writer_tool.py | 93 ++++ 
.../src/crewai_tools/oracle_db/__init__.py | 12 + .../src/crewai_tools/oracle_db/common.py | 85 ++++ .../oracle_db/knowledge_base/__init__.py | 12 + .../knowledge_base/retriever_tool.py | 384 +++++++++++++++ lib/crewai-tools/tests/tools/conftest.py | 1 + .../tests/tools/oci_tools_test.py | 88 ++++ .../tests/tools/oracle_db/conftest.py | 450 ++++++++++++++++++ .../tests/tools/oracle_db_tools_test.py | 140 ++++++ ...est_oci_knowledge_base_tool_integration.py | 57 +++ ...t_oracle_hybrid_search_tool_integration.py | 35 ++ ...est_oracle_text_search_tool_integration.py | 33 ++ ...t_oracle_vector_search_tool_integration.py | 34 ++ uv.lock | 74 ++- 34 files changed, 2311 insertions(+), 9 deletions(-) create mode 100644 docs/en/tools/cloud-storage/ociknowledgebasetool.mdx create mode 100644 docs/en/tools/cloud-storage/ociobjectstoragereadertool.mdx create mode 100644 docs/en/tools/cloud-storage/ociobjectstoragewritertool.mdx create mode 100644 docs/en/tools/integration/ociinvokeagenttool.mdx create mode 100644 lib/crewai-tools/src/crewai_tools/oci/__init__.py create mode 100644 lib/crewai-tools/src/crewai_tools/oci/agents/README.md create mode 100644 lib/crewai-tools/src/crewai_tools/oci/agents/__init__.py create mode 100644 lib/crewai-tools/src/crewai_tools/oci/agents/invoke_agent_tool.py create mode 100644 lib/crewai-tools/src/crewai_tools/oci/common.py create mode 100644 lib/crewai-tools/src/crewai_tools/oci/knowledge_base/__init__.py create mode 100644 lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py create mode 100644 lib/crewai-tools/src/crewai_tools/oci/object_storage/README.md create mode 100644 lib/crewai-tools/src/crewai_tools/oci/object_storage/__init__.py create mode 100644 lib/crewai-tools/src/crewai_tools/oci/object_storage/reader_tool.py create mode 100644 lib/crewai-tools/src/crewai_tools/oci/object_storage/writer_tool.py create mode 100644 lib/crewai-tools/src/crewai_tools/oracle_db/__init__.py create mode 100644 
lib/crewai-tools/src/crewai_tools/oracle_db/common.py create mode 100644 lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/__init__.py create mode 100644 lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/retriever_tool.py create mode 100644 lib/crewai-tools/tests/tools/conftest.py create mode 100644 lib/crewai-tools/tests/tools/oci_tools_test.py create mode 100644 lib/crewai-tools/tests/tools/oracle_db/conftest.py create mode 100644 lib/crewai-tools/tests/tools/oracle_db_tools_test.py create mode 100644 lib/crewai-tools/tests/tools/test_oci_knowledge_base_tool_integration.py create mode 100644 lib/crewai-tools/tests/tools/test_oracle_hybrid_search_tool_integration.py create mode 100644 lib/crewai-tools/tests/tools/test_oracle_text_search_tool_integration.py create mode 100644 lib/crewai-tools/tests/tools/test_oracle_vector_search_tool_integration.py diff --git a/docs/en/tools/cloud-storage/ociknowledgebasetool.mdx b/docs/en/tools/cloud-storage/ociknowledgebasetool.mdx new file mode 100644 index 00000000000..ca378b4afda --- /dev/null +++ b/docs/en/tools/cloud-storage/ociknowledgebasetool.mdx @@ -0,0 +1,111 @@ +--- +title: "OCI Knowledge Base Tool" +description: "Build and query a CrewAI-managed knowledge base powered by OCI embeddings" +icon: "database" +mode: "wide" +--- + +# `OCIKnowledgeBaseTool` + +The `OCIKnowledgeBaseTool` gives your agents a CrewAI-managed retrieval tool powered by OCI embedding models. It uses CrewAI's native `RagTool` stack and defaults to OCI embeddings, so you can load documents, directories, or URLs and query them semantically from your crews. + +This is the closest CrewAI-native OCI equivalent to the Bedrock knowledge-base workflow. Unlike Amazon Bedrock Knowledge Bases, the index is managed inside CrewAI's RAG system rather than an OCI managed retrieval service. 
+ +## Installation + +```bash +uv add "crewai[oci]" +uv add "crewai-tools[oci]" +``` + +## Example + +```python +from crewai import Agent +from crewai_tools import OCIKnowledgeBaseTool + +kb_tool = OCIKnowledgeBaseTool( + knowledge_source="./oracle-architecture.pdf", + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + region="eu-frankfurt-1", +) + +agent = Agent( + role="OCI Research Analyst", + goal="Answer architecture questions with the indexed OCI knowledge base", + tools=[kb_tool], + verbose=True, +) +``` + +## Add Sources Dynamically + +```python +from crewai_tools import OCIKnowledgeBaseTool + +kb_tool = OCIKnowledgeBaseTool( + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + region="eu-frankfurt-1", +) + +kb_tool.add("./runbooks/networking.md") +kb_tool.add("./runbooks/security/") +kb_tool.add("https://docs.oracle.com/en-us/iaas/Content/home.htm") +``` + +## Configuration + +The tool defaults to this embedding configuration: + +```python +{ + "embedding_model": { + "provider": "oci", + "config": { + "model_name": "cohere.embed-english-v3.0", + "compartment_id": "ocid1.compartment.oc1..exampleuniqueID", + "region": "eu-frankfurt-1", + "auth_type": "API_KEY", + "auth_profile": "DEFAULT", + "auth_file_location": "~/.oci/config", + } + } +} +``` + +You can override the vector database layer with standard `RagTool` config: + +```python +from crewai_tools import OCIKnowledgeBaseTool + +kb_tool = OCIKnowledgeBaseTool( + compartment_id="ocid1.compartment.oc1..exampleuniqueID", + config={ + "vectordb": { + "provider": "qdrant", + "config": { + "url": "http://localhost:6333", + "api_key": "qdrant-key", + }, + } + }, +) +``` + +## Environment Variables + +```bash +OCI_COMPARTMENT_ID=ocid1.compartment.oc1..exampleuniqueID +OCI_REGION=eu-frankfurt-1 +OCI_AUTH_TYPE=API_KEY +OCI_AUTH_PROFILE=DEFAULT +OCI_AUTH_FILE_LOCATION=~/.oci/config +OCI_EMBED_MODEL=cohere.embed-english-v3.0 +``` + +## Notes + +- Uses CrewAI's native RAG stack, not an OCI managed 
knowledge-base service +- Supports any source type that `RagTool` can ingest +- Defaults to OCI embeddings, but you can still override the `config` field for advanced vector store configuration +- For direct OCI text or image embedding workflows with `cohere.embed-v4.0`, use the CrewAI OCI embedding provider outside of `OCIKnowledgeBaseTool` diff --git a/docs/en/tools/cloud-storage/ociobjectstoragereadertool.mdx b/docs/en/tools/cloud-storage/ociobjectstoragereadertool.mdx new file mode 100644 index 00000000000..76d6f6659f7 --- /dev/null +++ b/docs/en/tools/cloud-storage/ociobjectstoragereadertool.mdx @@ -0,0 +1,38 @@ +--- +title: "OCI Object Storage Reader Tool" +description: "Read files from Oracle Cloud Infrastructure Object Storage" +icon: "cloud" +mode: "wide" +--- + +# `OCIObjectStorageReaderTool` + +Use `OCIObjectStorageReaderTool` to read text files from Oracle Cloud Infrastructure Object Storage inside a CrewAI workflow. + +## Installation + +```bash +uv pip install 'crewai-tools[oci]' +``` + +## Usage + +```python +from crewai import Agent +from crewai_tools import OCIObjectStorageReaderTool + +oci_reader = OCIObjectStorageReaderTool(namespace_name="my-namespace") + +agent = Agent( + role="Cloud Reader", + goal="Fetch cloud-hosted files", + tools=[oci_reader], +) +``` + +## Path Formats + +- `oci://bucket/path/to/file.txt` +- `oci://namespace@bucket/path/to/file.txt` + +If the namespace is omitted, the tool will use `namespace_name`, `OCI_OBJECT_STORAGE_NAMESPACE`, or fetch the namespace from OCI automatically. 
diff --git a/docs/en/tools/cloud-storage/ociobjectstoragewritertool.mdx b/docs/en/tools/cloud-storage/ociobjectstoragewritertool.mdx new file mode 100644 index 00000000000..5eeb2e11b9a --- /dev/null +++ b/docs/en/tools/cloud-storage/ociobjectstoragewritertool.mdx @@ -0,0 +1,36 @@ +--- +title: "OCI Object Storage Writer Tool" +description: "Write files to Oracle Cloud Infrastructure Object Storage" +icon: "cloud-arrow-up" +mode: "wide" +--- + +# `OCIObjectStorageWriterTool` + +Use `OCIObjectStorageWriterTool` to upload text content to Oracle Cloud Infrastructure Object Storage from a CrewAI workflow. + +## Installation + +```bash +uv pip install 'crewai-tools[oci]' +``` + +## Usage + +```python +from crewai import Agent +from crewai_tools import OCIObjectStorageWriterTool + +oci_writer = OCIObjectStorageWriterTool(namespace_name="my-namespace") + +agent = Agent( + role="Cloud Writer", + goal="Persist generated reports in OCI", + tools=[oci_writer], +) +``` + +## Path Formats + +- `oci://bucket/path/to/file.txt` +- `oci://namespace@bucket/path/to/file.txt` diff --git a/docs/en/tools/cloud-storage/overview.mdx b/docs/en/tools/cloud-storage/overview.mdx index 87e23ca390e..414066b18fd 100644 --- a/docs/en/tools/cloud-storage/overview.mdx +++ b/docs/en/tools/cloud-storage/overview.mdx @@ -18,6 +18,14 @@ These tools enable your agents to interact with cloud services, access cloud sto Write and upload files to Amazon S3 storage. + + Read files and data from Oracle Cloud Infrastructure Object Storage. + + + + Write files to Oracle Cloud Infrastructure Object Storage. + + Invoke Amazon Bedrock agents for AI-powered tasks. @@ -25,6 +33,10 @@ These tools enable your agents to interact with cloud services, access cloud sto Retrieve information from Amazon Bedrock knowledge bases. + + + Query a CrewAI-managed knowledge base powered by OCI embeddings. 
+ ## **Common Use Cases** @@ -36,16 +48,27 @@ These tools enable your agents to interact with cloud services, access cloud sto - **Scalable Operations**: Leverage cloud infrastructure for processing ```python -from crewai_tools import S3ReaderTool, S3WriterTool, BedrockInvokeAgentTool +from crewai_tools import ( + BedrockInvokeAgentTool, + OCIKnowledgeBaseTool, + OCIObjectStorageReaderTool, + OCIObjectStorageWriterTool, + S3ReaderTool, + S3WriterTool, +) # Create cloud tools s3_reader = S3ReaderTool() s3_writer = S3WriterTool() +oci_reader = OCIObjectStorageReaderTool(namespace_name="my-namespace") +oci_writer = OCIObjectStorageWriterTool(namespace_name="my-namespace") +oci_kb = OCIKnowledgeBaseTool(knowledge_source="./oracle-architecture.pdf") bedrock_agent = BedrockInvokeAgentTool() # Add to your agent agent = Agent( role="Cloud Operations Specialist", - tools=[s3_reader, s3_writer, bedrock_agent], + tools=[s3_reader, s3_writer, oci_reader, oci_writer, oci_kb, bedrock_agent], goal="Manage cloud resources and AI services" ) +``` diff --git a/docs/en/tools/integration/ociinvokeagenttool.mdx b/docs/en/tools/integration/ociinvokeagenttool.mdx new file mode 100644 index 00000000000..a58b3fd0675 --- /dev/null +++ b/docs/en/tools/integration/ociinvokeagenttool.mdx @@ -0,0 +1,43 @@ +--- +title: "OCI Generative AI Agent Tool" +description: "Invoke Oracle Cloud Infrastructure Generative AI agent endpoints from CrewAI" +icon: "cloud" +mode: "wide" +--- + +# `OCIGenAIInvokeAgentTool` + +`OCIGenAIInvokeAgentTool` lets CrewAI agents call Oracle Cloud Infrastructure Generative AI agent endpoints directly. 
+ +## Installation + +```bash +uv pip install 'crewai-tools[oci]' +``` + +## Usage + +```python +from crewai import Agent +from crewai_tools import OCIGenAIInvokeAgentTool + +oci_agent_tool = OCIGenAIInvokeAgentTool( + agent_endpoint_id="ocid1.genaiagentendpoint.oc1..exampleuniqueID" +) + +agent = Agent( + role="Oracle AI Specialist", + goal="Delegate managed Oracle AI tasks", + tools=[oci_agent_tool], +) +``` + +## Environment Variables + +```bash +OCI_AGENT_ENDPOINT_ID=ocid1.genaiagentendpoint.oc1..exampleuniqueID +OCI_AUTH_TYPE=API_KEY +OCI_AUTH_PROFILE=DEFAULT +OCI_AUTH_FILE_LOCATION=~/.oci/config +OCI_AGENT_RUNTIME_ENDPOINT=https://agent-runtime.generativeai.eu-frankfurt-1.oci.oraclecloud.com +``` diff --git a/docs/en/tools/integration/overview.mdx b/docs/en/tools/integration/overview.mdx index 001a07967bc..17d4eaba532 100644 --- a/docs/en/tools/integration/overview.mdx +++ b/docs/en/tools/integration/overview.mdx @@ -21,12 +21,16 @@ Integration tools let your agents hand off work to other automation platforms an Call Amazon Bedrock Agents from your crews, reuse AWS guardrails, and stream responses back into the workflow. + + + Call Oracle Cloud Infrastructure Generative AI agent endpoints directly from your crews. 
+ ## **Common Use Cases** - **Chain automations**: Kick off an existing CrewAI deployment from within another crew or flow -- **Enterprise hand-off**: Route tasks to Bedrock Agents that already encapsulate company logic and guardrails +- **Enterprise hand-off**: Route tasks to Bedrock or OCI managed agents that already encapsulate company logic and guardrails - **Hybrid workflows**: Combine CrewAI reasoning with downstream systems that expose their own agent APIs - **Long-running jobs**: Poll external automations and merge the final results back into the current run @@ -34,7 +38,7 @@ Integration tools let your agents hand off work to other automation platforms an ```python from crewai import Agent, Task, Crew -from crewai_tools import InvokeCrewAIAutomationTool +from crewai_tools import InvokeCrewAIAutomationTool, OCIGenAIInvokeAgentTool from crewai_tools.aws.bedrock.agents.invoke_agent_tool import BedrockInvokeAgentTool # External automation @@ -51,11 +55,16 @@ knowledge_router = BedrockInvokeAgentTool( agent_alias_id="prod", ) +# Managed agent on OCI +oracle_router = OCIGenAIInvokeAgentTool( + agent_endpoint_id="ocid1.genaiagentendpoint.oc1..exampleuniqueID", +) + automation_strategist = Agent( role="Automation Strategist", goal="Orchestrate external automations and summarise their output", backstory="You coordinate enterprise workflows and know when to delegate tasks to specialised services.", - tools=[analysis_automation, knowledge_router], + tools=[analysis_automation, knowledge_router, oracle_router], verbose=True, ) @@ -71,6 +80,6 @@ Crew(agents=[automation_strategist], tasks=[execute_playbook]).kickoff() - **Secure credentials**: Store API keys and bearer tokens in environment variables or a secrets manager - **Plan for latency**: External automations may take longer—set appropriate polling intervals and timeouts -- **Reuse sessions**: Bedrock Agents support session IDs so you can maintain context across multiple tool calls +- **Reuse sessions**: Bedrock 
and OCI agent tools support session IDs so you can maintain context across multiple tool calls - **Validate responses**: Normalise remote output (JSON, text, status codes) before forwarding it to downstream tasks - **Monitor usage**: Track audit logs in CrewAI Platform or AWS CloudWatch to stay ahead of quota limits and failures diff --git a/docs/en/tools/overview.mdx b/docs/en/tools/overview.mdx index fb9926f0b5d..e742c949dcd 100644 --- a/docs/en/tools/overview.mdx +++ b/docs/en/tools/overview.mdx @@ -106,6 +106,12 @@ Need a specific tool? Here are some popular choices: Access AWS S3 files + + Access OCI Object Storage files + + + Build OCI-powered semantic retrieval + ## **Getting Started** diff --git a/docs/en/tools/tool-integrations/overview.mdx b/docs/en/tools/tool-integrations/overview.mdx index a40afecc8c7..634f6047a0f 100644 --- a/docs/en/tools/tool-integrations/overview.mdx +++ b/docs/en/tools/tool-integrations/overview.mdx @@ -25,7 +25,15 @@ mode: "wide" > Automate deployment and operations by integrating CrewAI with external platforms and workflows. + + + Invoke Oracle Cloud Infrastructure Generative AI agent endpoints from CrewAI. + Use these integrations to connect CrewAI with your infrastructure and workflows. 
- diff --git a/lib/crewai-tools/pyproject.toml b/lib/crewai-tools/pyproject.toml index 17b7c71b5a7..acae1785007 100644 --- a/lib/crewai-tools/pyproject.toml +++ b/lib/crewai-tools/pyproject.toml @@ -140,6 +140,12 @@ contextual = [ "contextual-client>=0.1.0", "nest-asyncio>=1.6.0", ] +oci = [ + "oci>=2.161.0", +] +oracle = [ + "oracledb>=2.5.1", +] [build-system] diff --git a/lib/crewai-tools/src/crewai_tools/__init__.py b/lib/crewai-tools/src/crewai_tools/__init__.py index aab05fed6d6..bf53fffa501 100644 --- a/lib/crewai-tools/src/crewai_tools/__init__.py +++ b/lib/crewai-tools/src/crewai_tools/__init__.py @@ -7,6 +7,15 @@ ) from crewai_tools.aws.s3.reader_tool import S3ReaderTool from crewai_tools.aws.s3.writer_tool import S3WriterTool +from crewai_tools.oci.agents.invoke_agent_tool import OCIGenAIInvokeAgentTool +from crewai_tools.oci.knowledge_base.retriever_tool import OCIKnowledgeBaseTool +from crewai_tools.oci.object_storage.reader_tool import OCIObjectStorageReaderTool +from crewai_tools.oci.object_storage.writer_tool import OCIObjectStorageWriterTool +from crewai_tools.oracle_db.knowledge_base.retriever_tool import ( + OracleHybridSearchTool, + OracleTextSearchTool, + OracleVectorSearchTool, +) from crewai_tools.tools.ai_mind_tool.ai_mind_tool import AIMindTool from crewai_tools.tools.apify_actors_tool.apify_actors_tool import ApifyActorsTool from crewai_tools.tools.arxiv_paper_tool.arxiv_paper_tool import ArxivPaperTool @@ -262,7 +271,14 @@ "MultiOnTool", "MySQLSearchTool", "NL2SQLTool", + "OCIGenAIInvokeAgentTool", + "OCIKnowledgeBaseTool", + "OCIObjectStorageReaderTool", + "OCIObjectStorageWriterTool", "OCRTool", + "OracleHybridSearchTool", + "OracleTextSearchTool", + "OracleVectorSearchTool", "OxylabsAmazonProductScraperTool", "OxylabsAmazonSearchScraperTool", "OxylabsGoogleSearchScraperTool", diff --git a/lib/crewai-tools/src/crewai_tools/oci/__init__.py b/lib/crewai-tools/src/crewai_tools/oci/__init__.py new file mode 100644 index 
00000000000..1226f621d1e --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/__init__.py @@ -0,0 +1,12 @@ +from crewai_tools.oci.agents.invoke_agent_tool import OCIGenAIInvokeAgentTool +from crewai_tools.oci.knowledge_base.retriever_tool import OCIKnowledgeBaseTool +from crewai_tools.oci.object_storage.reader_tool import OCIObjectStorageReaderTool +from crewai_tools.oci.object_storage.writer_tool import OCIObjectStorageWriterTool + + +__all__ = [ + "OCIGenAIInvokeAgentTool", + "OCIKnowledgeBaseTool", + "OCIObjectStorageReaderTool", + "OCIObjectStorageWriterTool", +] diff --git a/lib/crewai-tools/src/crewai_tools/oci/agents/README.md b/lib/crewai-tools/src/crewai_tools/oci/agents/README.md new file mode 100644 index 00000000000..db946bcd14c --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/agents/README.md @@ -0,0 +1,19 @@ +# OCI Generative AI Agent Tool + +This tool lets CrewAI agents invoke an Oracle Cloud Infrastructure Generative AI agent endpoint. + +## Installation + +```bash +uv pip install 'crewai-tools[oci]' +``` + +## Usage + +```python +from crewai_tools import OCIGenAIInvokeAgentTool + +agent_tool = OCIGenAIInvokeAgentTool( + agent_endpoint_id="ocid1.genaiagentendpoint.oc1..exampleuniqueID" +) +``` diff --git a/lib/crewai-tools/src/crewai_tools/oci/agents/__init__.py b/lib/crewai-tools/src/crewai_tools/oci/agents/__init__.py new file mode 100644 index 00000000000..06013e876ea --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/agents/__init__.py @@ -0,0 +1,4 @@ +from crewai_tools.oci.agents.invoke_agent_tool import OCIGenAIInvokeAgentTool + + +__all__ = ["OCIGenAIInvokeAgentTool"] diff --git a/lib/crewai-tools/src/crewai_tools/oci/agents/invoke_agent_tool.py b/lib/crewai-tools/src/crewai_tools/oci/agents/invoke_agent_tool.py new file mode 100644 index 00000000000..774a8c97597 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/agents/invoke_agent_tool.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +import os +import 
class OCIGenAIInvokeAgentToolInput(BaseModel):
    """Input schema for OCIGenAIInvokeAgentTool."""

    query: str = Field(..., description="The query to send to the OCI Generative AI agent")


class OCIGenAIInvokeAgentTool(BaseTool):
    """Tool that sends a query to an OCI Generative AI agent endpoint.

    A chat session is created on the first call when no ``session_id`` was
    supplied and ``create_session_if_missing`` is true; the created id is kept
    on the tool and reused by later calls.
    """

    name: str = "OCI Generative AI Agent Invoke Tool"
    description: str = (
        "Invokes an Oracle Cloud Infrastructure Generative AI agent endpoint."
    )
    args_schema: type[BaseModel] = OCIGenAIInvokeAgentToolInput
    package_dependencies: list[str] = Field(default_factory=lambda: ["oci"])
    agent_endpoint_id: str | None = None
    session_id: str | None = None
    create_session_if_missing: bool = True
    client: Any | None = None

    def __init__(
        self,
        agent_endpoint_id: str | None = None,
        session_id: str | None = None,
        create_session_if_missing: bool = True,
        description: str | None = None,
        *,
        auth_type: str | None = None,
        auth_profile: str | None = None,
        auth_file_location: str | None = None,
        service_endpoint: str | None = None,
        client: Any | None = None,
        **kwargs: Any,
    ) -> None:
        """Configure the tool and, unless a client is injected, build one.

        Args:
            agent_endpoint_id: Agent endpoint OCID; falls back to
                ``OCI_AGENT_ENDPOINT_ID``.
            session_id: Existing agent session to reuse.
            create_session_if_missing: Create a session on first use when
                ``session_id`` is not set.
            description: Optional override for the tool description.
            auth_type: OCI auth mode; falls back to ``OCI_AUTH_TYPE`` and then
                ``"API_KEY"``.
            auth_profile: Config profile; falls back to ``OCI_AUTH_PROFILE``
                and then ``"DEFAULT"``.
            auth_file_location: Config file path; falls back to
                ``OCI_AUTH_FILE_LOCATION`` and then ``"~/.oci/config"``.
            service_endpoint: Runtime endpoint URL; falls back to
                ``OCI_AGENT_RUNTIME_ENDPOINT``.
            client: Pre-built runtime client; when given, no OCI client is
                constructed here.
            **kwargs: Forwarded to ``BaseTool.__init__``.

        Raises:
            ValueError: If no agent endpoint id can be resolved.
        """
        super().__init__(**kwargs)
        self.agent_endpoint_id = agent_endpoint_id or os.getenv(
            "OCI_AGENT_ENDPOINT_ID"
        )
        self.session_id = session_id
        self.create_session_if_missing = create_session_if_missing
        self.client = client

        if description:
            self.description = description

        if not self.agent_endpoint_id:
            raise ValueError(
                "agent_endpoint_id is required. Set it explicitly or use OCI_AGENT_ENDPOINT_ID."
            )

        if self.client is None:
            oci = get_oci_module()
            # OCI_AUTH_TYPE is documented alongside the other OCI_AUTH_* variables
            # for this tool, so honour it when auth_type is not passed explicitly.
            resolved_auth_type = cast(
                str, auth_type or os.getenv("OCI_AUTH_TYPE", "API_KEY")
            )
            resolved_auth_profile = cast(
                str, auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT")
            )
            resolved_auth_file_location = cast(
                str,
                auth_file_location or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"),
            )
            client_kwargs = create_oci_client_kwargs(
                auth_type=resolved_auth_type,
                auth_profile=resolved_auth_profile,
                auth_file_location=resolved_auth_file_location,
                service_endpoint=service_endpoint
                or os.getenv("OCI_AGENT_RUNTIME_ENDPOINT"),
                timeout=(10, 180),
            )
            self.client = oci.generative_ai_agent_runtime.GenerativeAiAgentRuntimeClient(
                **client_kwargs
            )

    def _require_client(self) -> Any:
        """Return the runtime client, raising ``ValueError`` if it is unset."""
        if self.client is None:
            raise ValueError("OCI Generative AI agent client is not initialized.")
        return self.client

    def _ensure_session_id(self) -> str | None:
        """Return the session id, creating a session first when allowed.

        Returns ``None`` only when no session id is set and
        ``create_session_if_missing`` is false.
        """
        if self.session_id or not self.create_session_if_missing:
            return self.session_id

        oci = get_oci_module()
        response = self._require_client().create_session(
            agent_endpoint_id=self.agent_endpoint_id,
            create_session_details=oci.generative_ai_agent_runtime.models.CreateSessionDetails(
                display_name=f"crewai-tools-{int(time.time())}",
                description="Created by CrewAI OCI agent tool",
            ),
        )
        # Cache the id so later calls continue the same conversation.
        self.session_id = str(response.data.id)
        return self.session_id

    def _extract_text(self, response: Any) -> str:
        """Pull the reply text out of a chat response.

        Falls back to a fixed notice when the response lists required actions,
        and finally to ``str(response.data)``.
        """
        message = getattr(response.data, "message", None)
        content = getattr(message, "content", None)
        if content is not None and getattr(content, "text", None):
            return str(content.text)

        required_actions = getattr(response.data, "required_actions", None) or []
        if required_actions:
            return (
                "OCI agent requires follow-up actions before it can complete the response."
            )

        return str(response.data)

    def _run(self, query: str) -> str:
        """Send ``query`` to the agent and return its reply text.

        Any exception is reported as an error string instead of being raised,
        so the calling agent receives a tool result either way.
        """
        try:
            oci = get_oci_module()
            session_id = self._ensure_session_id()
            chat_details = oci.generative_ai_agent_runtime.models.ChatDetails(
                user_message=query,
                session_id=session_id,
                should_stream=False,
            )
            response = self._require_client().chat(
                agent_endpoint_id=self.agent_endpoint_id,
                chat_details=chat_details,
            )
            return self._extract_text(response)
        except Exception as error:
            return f"Error invoking OCI Generative AI agent: {error!s}"
oci.signer.load_private_key_from_file(config["key_file"], None) + with open(config["security_token_file"], encoding="utf-8") as file: + security_token = file.read() + client_kwargs["config"] = config + client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner( + security_token, private_key + ) + elif auth_type_upper == "INSTANCE_PRINCIPAL": + client_kwargs["signer"] = ( + oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + ) + elif auth_type_upper == "RESOURCE_PRINCIPAL": + client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() + else: + valid_types = [ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + raise ValueError( + f"Invalid OCI auth_type '{auth_type}'. Valid values: {valid_types}" + ) + + return client_kwargs + + +def parse_object_storage_path(file_path: str) -> tuple[str | None, str, str]: + """Parse an OCI Object Storage path. + + Supported formats: + - `oci://bucket/path/to/object.txt` + - `oci://namespace@bucket/path/to/object.txt` + """ + normalized = file_path.removeprefix("oci://") + bucket_part, _, object_name = normalized.partition("/") + if not bucket_part or not object_name: + raise ValueError( + "OCI Object Storage paths must be in the form " + "`oci://bucket/path` or `oci://namespace@bucket/path`." 
+ ) + + if "@" in bucket_part: + namespace_name, bucket_name = bucket_part.split("@", 1) + return namespace_name, bucket_name, object_name + + return None, bucket_part, object_name + + +def get_region() -> str: + return os.getenv("OCI_REGION", DEFAULT_OCI_REGION) diff --git a/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/__init__.py b/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/__init__.py new file mode 100644 index 00000000000..63f3786398d --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/__init__.py @@ -0,0 +1,4 @@ +from crewai_tools.oci.knowledge_base.retriever_tool import OCIKnowledgeBaseTool + + +__all__ = ["OCIKnowledgeBaseTool"] diff --git a/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py b/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py new file mode 100644 index 00000000000..305dc72434b --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import os +from typing import Any, cast + +from pydantic import BaseModel, Field + +from crewai_tools.tools.rag.rag_tool import RagTool + + +class FixedOCIKnowledgeBaseToolSchema(BaseModel): + """Input for OCIKnowledgeBaseTool with a preconfigured source.""" + + query: str = Field( + ..., description="Mandatory query you want to use to search the knowledge base" + ) + + +class OCIKnowledgeBaseToolSchema(FixedOCIKnowledgeBaseToolSchema): + """Input for OCIKnowledgeBaseTool.""" + + knowledge_source: str = Field( + ..., + description=( + "File path, directory path, URL, or other CrewAI-supported source " + "to add to the OCI-backed knowledge base before querying" + ), + ) + + +class OCIKnowledgeBaseTool(RagTool): + name: str = "OCI Knowledge Base Tool" + description: str = ( + "A CrewAI-managed knowledge base tool powered by OCI embeddings." 
+ ) + args_schema: type[BaseModel] = OCIKnowledgeBaseToolSchema + + def __init__( + self, + knowledge_source: str | None = None, + *, + model_name: str | None = None, + compartment_id: str | None = None, + service_endpoint: str | None = None, + region: str | None = None, + auth_type: str = "API_KEY", + auth_profile: str | None = None, + auth_file_location: str | None = None, + config: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: + oci_embedding_config: dict[str, str] = { + "model_name": cast( + str, + model_name or os.getenv("OCI_EMBED_MODEL", "cohere.embed-english-v3.0"), + ), + "compartment_id": cast( + str, + compartment_id or os.getenv("OCI_COMPARTMENT_ID", ""), + ), + "region": cast(str, region or os.getenv("OCI_REGION", "eu-frankfurt-1")), + "auth_type": auth_type, + "auth_profile": cast( + str, + auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT"), + ), + "auth_file_location": cast( + str, + auth_file_location or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + ), + } + if service_endpoint or os.getenv("OCI_SERVICE_ENDPOINT"): + oci_embedding_config["service_endpoint"] = ( + service_endpoint or os.getenv("OCI_SERVICE_ENDPOINT") or "" + ) + + merged_config = dict(config or {}) + merged_config.setdefault( + "embedding_model", + { + "provider": "oci", + "config": oci_embedding_config, + }, + ) + + super().__init__(config=merged_config, **kwargs) + + if knowledge_source is not None: + self.add(knowledge_source) + self.description = ( + "A CrewAI-managed knowledge base tool powered by OCI embeddings " + f"and preloaded with {knowledge_source}." 
+ ) + self.args_schema = FixedOCIKnowledgeBaseToolSchema + self._generate_description() + + def _run( # type: ignore[override] + self, + query: str, + knowledge_source: str | None = None, + similarity_threshold: float | None = None, + limit: int | None = None, + ) -> str: + if knowledge_source is not None: + self.add(knowledge_source) + return super()._run( + query=query, similarity_threshold=similarity_threshold, limit=limit + ) diff --git a/lib/crewai-tools/src/crewai_tools/oci/object_storage/README.md b/lib/crewai-tools/src/crewai_tools/oci/object_storage/README.md new file mode 100644 index 00000000000..cf4365f4899 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/object_storage/README.md @@ -0,0 +1,23 @@ +# OCI Object Storage Tools + +These tools let CrewAI agents read from and write to Oracle Cloud Infrastructure Object Storage. + +## Installation + +```bash +uv pip install 'crewai-tools[oci]' +``` + +## Usage + +```python +from crewai_tools import OCIObjectStorageReaderTool, OCIObjectStorageWriterTool + +reader = OCIObjectStorageReaderTool(namespace_name="my-namespace") +writer = OCIObjectStorageWriterTool(namespace_name="my-namespace") +``` + +Supported path formats: + +- `oci://bucket/path/to/file.txt` +- `oci://namespace@bucket/path/to/file.txt` diff --git a/lib/crewai-tools/src/crewai_tools/oci/object_storage/__init__.py b/lib/crewai-tools/src/crewai_tools/oci/object_storage/__init__.py new file mode 100644 index 00000000000..fdd34ccd445 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/object_storage/__init__.py @@ -0,0 +1,8 @@ +from crewai_tools.oci.object_storage.reader_tool import OCIObjectStorageReaderTool +from crewai_tools.oci.object_storage.writer_tool import OCIObjectStorageWriterTool + + +__all__ = [ + "OCIObjectStorageReaderTool", + "OCIObjectStorageWriterTool", +] diff --git a/lib/crewai-tools/src/crewai_tools/oci/object_storage/reader_tool.py b/lib/crewai-tools/src/crewai_tools/oci/object_storage/reader_tool.py new file mode 
100644 index 00000000000..701a2134737 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/object_storage/reader_tool.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import os +from typing import Any, cast + +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + +from crewai_tools.oci.common import ( + create_oci_client_kwargs, + get_oci_module, + parse_object_storage_path, +) + + +class OCIObjectStorageReaderToolInput(BaseModel): + """Input schema for OCIObjectStorageReaderTool.""" + + file_path: str = Field( + ..., + description=( + "OCI Object Storage path in the form " + "`oci://bucket/path` or `oci://namespace@bucket/path`." + ), + ) + + +class OCIObjectStorageReaderTool(BaseTool): + name: str = "OCI Object Storage Reader Tool" + description: str = ( + "Reads a text file from Oracle Cloud Infrastructure Object Storage." + ) + args_schema: type[BaseModel] = OCIObjectStorageReaderToolInput + package_dependencies: list[str] = Field(default_factory=lambda: ["oci"]) + namespace_name: str | None = None + client: Any | None = None + + def __init__( + self, + namespace_name: str | None = None, + *, + auth_type: str = "API_KEY", + auth_profile: str | None = None, + auth_file_location: str | None = None, + client: Any | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.namespace_name = namespace_name or os.getenv("OCI_OBJECT_STORAGE_NAMESPACE") + self.client = client + + if self.client is None: + oci = get_oci_module() + resolved_auth_profile = cast( + str, auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + ) + resolved_auth_file_location = cast( + str, + auth_file_location or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + ) + client_kwargs = create_oci_client_kwargs( + auth_type=auth_type, + auth_profile=resolved_auth_profile, + auth_file_location=resolved_auth_file_location, + ) + self.client = oci.object_storage.ObjectStorageClient(**client_kwargs) + + def _require_client(self) -> 
Any: + if self.client is None: + raise ValueError("OCI Object Storage client is not initialized.") + return self.client + + def _resolve_namespace(self, path_namespace: str | None) -> str: + if path_namespace: + return path_namespace + if self.namespace_name: + return self.namespace_name + return str(self._require_client().get_namespace().data) + + def _run(self, file_path: str) -> str: + try: + path_namespace, bucket_name, object_name = parse_object_storage_path( + file_path + ) + namespace_name = self._resolve_namespace(path_namespace) + response = self._require_client().get_object( + namespace_name, bucket_name, object_name + ) + data = response.data + + if hasattr(data, "content"): + content = data.content + elif hasattr(data, "text"): + return str(data.text) + elif hasattr(data, "raw") and hasattr(data.raw, "data"): + content = data.raw.data + else: + content = data + + if isinstance(content, bytes): + return content.decode("utf-8") + return str(content) + except Exception as error: + return f"Error reading file from OCI Object Storage: {error!s}" diff --git a/lib/crewai-tools/src/crewai_tools/oci/object_storage/writer_tool.py b/lib/crewai-tools/src/crewai_tools/oci/object_storage/writer_tool.py new file mode 100644 index 00000000000..c24f7479836 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oci/object_storage/writer_tool.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import os +from typing import Any, cast + +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + +from crewai_tools.oci.common import ( + create_oci_client_kwargs, + get_oci_module, + parse_object_storage_path, +) + + +class OCIObjectStorageWriterToolInput(BaseModel): + """Input schema for OCIObjectStorageWriterTool.""" + + file_path: str = Field( + ..., + description=( + "OCI Object Storage path in the form " + "`oci://bucket/path` or `oci://namespace@bucket/path`." 
+ ), + ) + content: str = Field(..., description="Content to write to the object") + + +class OCIObjectStorageWriterTool(BaseTool): + name: str = "OCI Object Storage Writer Tool" + description: str = "Writes a text file to Oracle Cloud Infrastructure Object Storage." + args_schema: type[BaseModel] = OCIObjectStorageWriterToolInput + package_dependencies: list[str] = Field(default_factory=lambda: ["oci"]) + namespace_name: str | None = None + client: Any | None = None + + def __init__( + self, + namespace_name: str | None = None, + *, + auth_type: str = "API_KEY", + auth_profile: str | None = None, + auth_file_location: str | None = None, + client: Any | None = None, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.namespace_name = namespace_name or os.getenv("OCI_OBJECT_STORAGE_NAMESPACE") + self.client = client + + if self.client is None: + oci = get_oci_module() + resolved_auth_profile = cast( + str, auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + ) + resolved_auth_file_location = cast( + str, + auth_file_location or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + ) + client_kwargs = create_oci_client_kwargs( + auth_type=auth_type, + auth_profile=resolved_auth_profile, + auth_file_location=resolved_auth_file_location, + ) + self.client = oci.object_storage.ObjectStorageClient(**client_kwargs) + + def _require_client(self) -> Any: + if self.client is None: + raise ValueError("OCI Object Storage client is not initialized.") + return self.client + + def _resolve_namespace(self, path_namespace: str | None) -> str: + if path_namespace: + return path_namespace + if self.namespace_name: + return self.namespace_name + return str(self._require_client().get_namespace().data) + + def _run(self, file_path: str, content: str) -> str: + try: + path_namespace, bucket_name, object_name = parse_object_storage_path( + file_path + ) + namespace_name = self._resolve_namespace(path_namespace) + self._require_client().put_object( + namespace_name, + 
bucket_name, + object_name, + content.encode("utf-8"), + ) + return f"Successfully wrote content to {file_path}" + except Exception as error: + return f"Error writing file to OCI Object Storage: {error!s}" diff --git a/lib/crewai-tools/src/crewai_tools/oracle_db/__init__.py b/lib/crewai-tools/src/crewai_tools/oracle_db/__init__.py new file mode 100644 index 00000000000..44a1b348277 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oracle_db/__init__.py @@ -0,0 +1,12 @@ +from crewai_tools.oracle_db.knowledge_base.retriever_tool import ( + OracleHybridSearchTool, + OracleTextSearchTool, + OracleVectorSearchTool, +) + + +__all__ = [ + "OracleHybridSearchTool", + "OracleTextSearchTool", + "OracleVectorSearchTool", +] diff --git a/lib/crewai-tools/src/crewai_tools/oracle_db/common.py b/lib/crewai-tools/src/crewai_tools/oracle_db/common.py new file mode 100644 index 00000000000..11986e3ccd7 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oracle_db/common.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +import os +import re +from typing import Any + + +_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z][A-Za-z0-9_$#]*$") + + +def get_oracledb_module() -> Any: + try: + import oracledb + except ImportError: + raise ImportError( + "`oracledb` package not found, please install the optional dependency with " + "`uv add 'crewai-tools[oracle]'`" + ) from None + return oracledb + + +def validate_identifier(identifier: str, *, field_name: str) -> str: + if not identifier or not _IDENTIFIER_PATTERN.match(identifier): + raise ValueError( + f"{field_name} must be a simple Oracle identifier starting with a letter." 
+ ) + return identifier + + +def get_oracle_connection_kwargs( + *, + user: str | None, + password: str | None, + dsn: str | None, + config_dir: str | None = None, + wallet_location: str | None = None, + wallet_password: str | None = None, +) -> dict[str, Any]: + resolved_user = user or os.getenv("ORACLE_DB_USER") + resolved_password = password or os.getenv("ORACLE_DB_PASSWORD") + resolved_dsn = dsn or os.getenv("ORACLE_DB_DSN") + + if not resolved_user or not resolved_password or not resolved_dsn: + raise ValueError( + "Oracle DB connection requires user, password, and dsn. " + "Set them explicitly or via ORACLE_DB_USER, ORACLE_DB_PASSWORD, and " + "ORACLE_DB_DSN." + ) + + kwargs: dict[str, Any] = { + "user": resolved_user, + "password": resolved_password, + "dsn": resolved_dsn, + } + + resolved_config_dir = config_dir or os.getenv("ORACLE_DB_CONFIG_DIR") + resolved_wallet_location = wallet_location or os.getenv("ORACLE_DB_WALLET_LOCATION") + resolved_wallet_password = wallet_password or os.getenv("ORACLE_DB_WALLET_PASSWORD") + + if resolved_config_dir: + kwargs["config_dir"] = resolved_config_dir + if resolved_wallet_location: + kwargs["wallet_location"] = resolved_wallet_location + if resolved_wallet_password: + kwargs["wallet_password"] = resolved_wallet_password + + return kwargs + + +@contextmanager +def oracle_connection_context( + client: Any = None, **connect_kwargs: Any +) -> Iterator[Any]: + if client is not None: + yield client + return + + oracledb = get_oracledb_module() + connection = oracledb.connect(**connect_kwargs) + try: + yield connection + finally: + connection.close() diff --git a/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/__init__.py b/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/__init__.py new file mode 100644 index 00000000000..44a1b348277 --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/__init__.py @@ -0,0 +1,12 @@ +from crewai_tools.oracle_db.knowledge_base.retriever_tool 
import ( + OracleHybridSearchTool, + OracleTextSearchTool, + OracleVectorSearchTool, +) + + +__all__ = [ + "OracleHybridSearchTool", + "OracleTextSearchTool", + "OracleVectorSearchTool", +] diff --git a/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/retriever_tool.py b/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/retriever_tool.py new file mode 100644 index 00000000000..3b9f17f273c --- /dev/null +++ b/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/retriever_tool.py @@ -0,0 +1,384 @@ +from __future__ import annotations + +import array +import json +import re +from typing import Any, Literal + +from crewai.rag.embeddings.factory import build_embedder +from crewai.tools import BaseTool +from pydantic import BaseModel, ConfigDict, Field + +from crewai_tools.oracle_db.common import ( + get_oracle_connection_kwargs, + oracle_connection_context, + validate_identifier, +) + + +def _generate_accum_query(query: str, fuzzy: bool = False) -> str: + words = re.split(r"\W+", query) + tokens = [word for word in words if word] + if fuzzy: + return " ACCUM ".join(f'fuzzy("{token}")' for token in tokens) + return " ACCUM ".join(f'"{token}"' for token in tokens) + + +class OracleSearchToolInput(BaseModel): + query: str = Field(..., description="The query to retrieve information from Oracle.") + + +class OracleToolBase(BaseTool): + model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) + + client: Any | None = Field(default=None, exclude=True) + user: str | None = Field(default=None, description="Oracle DB username") + password: str | None = Field(default=None, description="Oracle DB password") + dsn: str | None = Field(default=None, description="Oracle DB DSN") + config_dir: str | None = Field(default=None, description="Optional Oracle config dir") + wallet_location: str | None = Field( + default=None, description="Optional wallet directory" + ) + wallet_password: str | None = Field( + default=None, description="Optional 
wallet password" + ) + package_dependencies: list[str] = Field(default_factory=lambda: ["oracledb"]) + + def _connection_kwargs(self) -> dict[str, Any]: + if self.client is not None: + return {} + return get_oracle_connection_kwargs( + user=self.user, + password=self.password, + dsn=self.dsn, + config_dir=self.config_dir, + wallet_location=self.wallet_location, + wallet_password=self.wallet_password, + ) + + def _result_json(self, results: list[dict[str, Any]]) -> str: + if results: + return json.dumps({"results": results}, indent=2) + return json.dumps({"message": "No results found for the given query."}, indent=2) + + +class OracleVectorSearchTool(OracleToolBase): + name: str = "Oracle Vector Search Tool" + description: str = ( + "Retrieves information from Oracle Database vector columns using " + "VECTOR_DISTANCE against an externally generated embedding." + ) + args_schema: type[BaseModel] = OracleSearchToolInput + table_name: str = Field(..., description="Oracle table that stores the documents") + text_column: str = Field(default="text", description="Text column to return") + embedding_column: str = Field( + default="embedding", description="Vector column to compare against" + ) + metadata_column: str | None = Field( + default="metadata", + description="Optional JSON metadata column to merge into each result", + ) + metadata_columns: list[str] = Field( + default_factory=list, + description="Additional scalar columns to return as metadata", + ) + number_of_results: int = Field(default=5, description="Maximum results to return") + distance_metric: Literal["COSINE", "EUCLIDEAN", "DOT"] = Field( + default="COSINE", description="Oracle VECTOR_DISTANCE metric" + ) + embedding_model: dict[str, Any] | None = Field( + default=None, + exclude=True, + description="Optional CrewAI embedder specification used to build query embeddings", + ) + embedder: Any | None = Field( + default=None, + exclude=True, + description="Optional prebuilt embedding callable", + ) + + def 
model_post_init(self, __context: Any) -> None: + super().model_post_init(__context) + if self.embedder is None and self.embedding_model is not None: + self.embedder = build_embedder(self.embedding_model) + + def _embed_query(self, query: str) -> array.array[float]: + if self.embedder is None: + raise ValueError( + "OracleVectorSearchTool requires either embedder or embedding_model." + ) + + embedding_response = self.embedder([query]) + if not embedding_response: + raise ValueError("Embedding model returned no vectors.") + + embedding = embedding_response[0] + return array.array("f", [float(value) for value in embedding]) + + def _run(self, query: str) -> str: + table_name = validate_identifier(self.table_name, field_name="table_name") + text_column = validate_identifier(self.text_column, field_name="text_column") + embedding_column = validate_identifier( + self.embedding_column, field_name="embedding_column" + ) + metadata_column = None + if self.metadata_column: + metadata_column = validate_identifier( + self.metadata_column, field_name="metadata_column" + ) + metadata_columns = [ + validate_identifier(column, field_name="metadata_columns") + for column in self.metadata_columns + if column.lower() != text_column.lower() + and column.lower() != embedding_column.lower() + and (metadata_column is None or column.lower() != metadata_column.lower()) + ] + + fetch_columns = [text_column] + if metadata_column: + fetch_columns.append(metadata_column) + fetch_columns.extend(metadata_columns) + fetch_columns_sql = ", ".join(fetch_columns) + number_of_results = max(1, self.number_of_results) + metric = self.distance_metric.upper() + sql = ( + f"SELECT {fetch_columns_sql}, " # noqa: S608 + f"VECTOR_DISTANCE({embedding_column}, :query_embedding, {metric}) distance " + f"FROM {table_name} ORDER BY distance ASC " + f"FETCH FIRST {number_of_results} ROWS ONLY" + ) + + with oracle_connection_context(self.client, **self._connection_kwargs()) as connection: + with connection.cursor() 
as cursor: + cursor.execute(sql, query_embedding=self._embed_query(query)) + columns = [column[0].lower() for column in cursor.description] + results = [] + for row in cursor.fetchall(): + row_dict = dict(zip(columns, row, strict=False)) + metadata: dict[str, Any] = {} + if metadata_column: + metadata_value = row_dict.get(metadata_column.lower()) + if isinstance(metadata_value, dict): + metadata.update(metadata_value) + elif isinstance(metadata_value, str): + try: + parsed_metadata = json.loads(metadata_value) + except json.JSONDecodeError: + parsed_metadata = None + if isinstance(parsed_metadata, dict): + metadata.update(parsed_metadata) + metadata.update( + { + column: row_dict.get(column.lower()) + for column in metadata_columns + } + ) + results.append( + { + "content": row_dict[text_column.lower()], + "metadata": metadata, + "distance": row_dict.get("distance"), + } + ) + + return self._result_json(results) + + +class OracleTextSearchTool(OracleToolBase): + name: str = "Oracle Text Search Tool" + description: str = ( + "Retrieves information from Oracle Database using Oracle Text CONTAINS search." 
+ ) + args_schema: type[BaseModel] = OracleSearchToolInput + table_name: str = Field(..., description="Oracle table that stores the documents") + text_column: str = Field(default="text", description="Text column to search") + metadata_columns: list[str] = Field( + default_factory=list, + description="Additional columns to return as metadata alongside the text", + ) + number_of_results: int = Field(default=5, description="Maximum results to return") + operator_search: bool = Field( + default=False, + description="Treat the query as a raw Oracle Text expression instead of ACCUM tokens", + ) + fuzzy: bool = Field( + default=False, + description="Apply Oracle Text FUZZY matching when operator_search is false", + ) + return_scores: bool = Field( + default=True, description="Include Oracle Text SCORE(1) in each result" + ) + + def _run(self, query: str) -> str: + table_name = validate_identifier(self.table_name, field_name="table_name") + text_column = validate_identifier(self.text_column, field_name="text_column") + metadata_columns = [ + validate_identifier(column, field_name="metadata_columns") + for column in self.metadata_columns + if column.lower() != text_column.lower() + ] + number_of_results = max(1, self.number_of_results) + + search_text = query if self.operator_search else _generate_accum_query(query, self.fuzzy) + if not search_text: + return self._result_json([]) + + select_columns = [text_column, *metadata_columns] + select_columns_sql = ", ".join(select_columns) + sql = ( + f"SELECT SCORE(1) score, {select_columns_sql} FROM {table_name} " # noqa: S608 + f"WHERE CONTAINS({text_column}, :query, 1) > 0 " + f"ORDER BY score DESC FETCH FIRST {number_of_results} ROWS ONLY" + ) + + with oracle_connection_context(self.client, **self._connection_kwargs()) as connection: + with connection.cursor() as cursor: + cursor.execute(sql, query=search_text) + columns = [column[0].lower() for column in cursor.description] + results = [] + for row in cursor.fetchall(): + 
row_dict = dict(zip(columns, row, strict=False)) + result = { + "content": row_dict[text_column.lower()], + "metadata": { + column: row_dict.get(column.lower()) for column in metadata_columns + }, + } + if self.return_scores: + result["score"] = row_dict.get("score") + results.append(result) + + return self._result_json(results) + + +class OracleHybridSearchTool(OracleToolBase): + name: str = "Oracle Hybrid Search Tool" + description: str = ( + "Retrieves information from Oracle Database hybrid vector indexes using " + "DBMS_HYBRID_VECTOR.SEARCH." + ) + args_schema: type[BaseModel] = OracleSearchToolInput + hybrid_index_name: str = Field(..., description="Hybrid index name to query") + table_name: str = Field(..., description="Oracle table that stores the documents") + text_column: str = Field(default="text", description="Text column to return") + metadata_column: str | None = Field( + default="metadata", + description="Optional JSON metadata column to merge into each result", + ) + metadata_columns: list[str] = Field( + default_factory=list, + description="Additional scalar columns to return as metadata", + ) + number_of_results: int = Field(default=5, description="Maximum results to return") + search_mode: Literal["keyword", "hybrid", "semantic"] = Field( + default="hybrid", description="Oracle hybrid search mode" + ) + return_scores: bool = Field( + default=True, description="Include hybrid, vector, and text scores" + ) + params: dict[str, Any] | None = Field( + default=None, description="Additional DBMS_HYBRID_VECTOR.SEARCH parameters" + ) + + def _build_search_params(self, query: str) -> dict[str, Any]: + search_params = dict(self.params or {}) + search_params["hybrid_index_name"] = validate_identifier( + self.hybrid_index_name, field_name="hybrid_index_name" + ) + + if "return" in search_params or "search_text" in search_params: + raise ValueError("Reserved hybrid search params cannot be supplied directly.") + + if self.search_mode in {"hybrid", "semantic"}: 
+ search_params["vector"] = dict(search_params.get("vector") or {}) + if "search_text" in search_params["vector"] or "search_vector" in search_params["vector"]: + raise ValueError("vector.search_text and vector.search_vector are managed internally.") + search_params["vector"]["search_text"] = query + + if self.search_mode in {"hybrid", "keyword"}: + search_params["text"] = dict(search_params.get("text") or {}) + if ( + "search_text" in search_params["text"] + or "search_vector" in search_params["text"] + or "contains" in search_params["text"] + ): + raise ValueError("text.search_text, text.search_vector, and text.contains are managed internally.") + search_params["text"]["search_text"] = query + + search_params["return"] = { + "topN": max(1, self.number_of_results), + "values": ["rowid", "score", "vector_score", "text_score"], + "format": "JSON", + } + return search_params + + def _run(self, query: str) -> str: + table_name = validate_identifier(self.table_name, field_name="table_name") + text_column = validate_identifier(self.text_column, field_name="text_column") + metadata_column = None + if self.metadata_column: + metadata_column = validate_identifier( + self.metadata_column, field_name="metadata_column" + ) + metadata_columns = [ + validate_identifier(column, field_name="metadata_columns") + for column in self.metadata_columns + if column.lower() != text_column.lower() + and (metadata_column is None or column.lower() != metadata_column.lower()) + ] + search_params = self._build_search_params(query) + + fetch_columns = [text_column] + if metadata_column: + fetch_columns.append(metadata_column) + fetch_columns.extend(metadata_columns) + fetch_columns_sql = ", ".join(fetch_columns) + row_sql = f"SELECT {fetch_columns_sql} FROM {table_name} WHERE rowid = :1" # noqa: S608 + + with oracle_connection_context(self.client, **self._connection_kwargs()) as connection: + with connection.cursor() as cursor: + cursor.execute( + "SELECT 
DBMS_HYBRID_VECTOR.SEARCH(json(:search_params))", + search_params=json.dumps(search_params), + ) + raw = cursor.fetchall() + if not raw: + return self._result_json([]) + raw_payload = raw[0][0] + if hasattr(raw_payload, "read"): + raw_payload = raw_payload.read() + rowids = json.loads(raw_payload) + + results = [] + for item in rowids: + cursor.execute(row_sql, [item["rowid"]]) + row = cursor.fetchone() + if row is None: + continue + + row_index = 0 + content = row[row_index] + row_index += 1 + metadata: dict[str, Any] = {} + if metadata_column: + metadata_value = row[row_index] + row_index += 1 + if isinstance(metadata_value, dict): + metadata.update(metadata_value) + elif metadata_value is not None: + metadata[metadata_column] = metadata_value + for column in metadata_columns: + metadata[column] = row[row_index] + row_index += 1 + + result = { + "content": content, + "metadata": metadata, + } + if self.return_scores: + result["score"] = item.get("score") + result["vector_score"] = item.get("vector_score") + result["text_score"] = item.get("text_score") + results.append(result) + + return self._result_json(results) diff --git a/lib/crewai-tools/tests/tools/conftest.py b/lib/crewai-tools/tests/tools/conftest.py new file mode 100644 index 00000000000..a26de8ce728 --- /dev/null +++ b/lib/crewai-tools/tests/tools/conftest.py @@ -0,0 +1 @@ +pytest_plugins = ("tests.tools.oracle_db.conftest",) diff --git a/lib/crewai-tools/tests/tools/oci_tools_test.py b/lib/crewai-tools/tests/tools/oci_tools_test.py new file mode 100644 index 00000000000..0bdb82eb236 --- /dev/null +++ b/lib/crewai-tools/tests/tools/oci_tools_test.py @@ -0,0 +1,88 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock +from unittest.mock import patch + +from crewai_tools import ( + OCIGenAIInvokeAgentTool, + OCIKnowledgeBaseTool, + OCIObjectStorageReaderTool, + OCIObjectStorageWriterTool, +) + + +def test_oci_object_storage_reader_tool_reads_text(): + client = MagicMock() + 
client.get_namespace.return_value = SimpleNamespace(data="testns") + client.get_object.return_value = SimpleNamespace( + data=SimpleNamespace(content=b"oracle cloud") + ) + + tool = OCIObjectStorageReaderTool(client=client) + + result = tool._run("oci://my-bucket/docs/intro.txt") + + assert result == "oracle cloud" + client.get_object.assert_called_once_with("testns", "my-bucket", "docs/intro.txt") + + +def test_oci_object_storage_writer_tool_writes_bytes(): + client = MagicMock() + client.get_namespace.return_value = SimpleNamespace(data="testns") + + tool = OCIObjectStorageWriterTool(client=client) + + result = tool._run("oci://testns@my-bucket/reports/out.txt", "hello oci") + + assert result == "Successfully wrote content to oci://testns@my-bucket/reports/out.txt" + client.put_object.assert_called_once_with( + "testns", + "my-bucket", + "reports/out.txt", + b"hello oci", + ) + + +def test_oci_genai_invoke_agent_tool_creates_session_and_chats(): + client = MagicMock() + client.create_session.return_value = SimpleNamespace( + data=SimpleNamespace(id="session-123") + ) + client.chat.return_value = SimpleNamespace( + data=SimpleNamespace( + message=SimpleNamespace( + content=SimpleNamespace(text="OCI agent response") + ) + ) + ) + + tool = OCIGenAIInvokeAgentTool( + agent_endpoint_id="ocid1.genaiagentendpoint.oc1..example", + client=client, + ) + + result = tool._run("Summarize Oracle Cloud Infrastructure") + + assert result == "OCI agent response" + client.create_session.assert_called_once() + client.chat.assert_called_once() + assert client.chat.call_args.kwargs["agent_endpoint_id"] == ( + "ocid1.genaiagentendpoint.oc1..example" + ) + + +@patch("crewai_tools.tools.rag.rag_tool.build_embedder", return_value=MagicMock()) +@patch("crewai_tools.adapters.crewai_rag_adapter.create_client") +def test_oci_knowledge_base_tool_uses_oci_embedding_config( + mock_create_client, _mock_build_embedder +): + mock_client = MagicMock() + mock_client.get_or_create_collection = 
MagicMock(return_value=None) + mock_create_client.return_value = mock_client + + tool = OCIKnowledgeBaseTool(config={"vectordb": {"provider": "chromadb", "config": {}}}) + + assert tool.config["embedding_model"]["provider"] == "oci" + assert ( + tool.config["embedding_model"]["config"]["model_name"] + == "cohere.embed-english-v3.0" + ) diff --git a/lib/crewai-tools/tests/tools/oracle_db/conftest.py b/lib/crewai-tools/tests/tools/oracle_db/conftest.py new file mode 100644 index 00000000000..8d5ad1f828f --- /dev/null +++ b/lib/crewai-tools/tests/tools/oracle_db/conftest.py @@ -0,0 +1,450 @@ +from __future__ import annotations + +import array +import json +import os +import signal +from typing import Any +from unittest.mock import MagicMock +import uuid + +import pytest + +from crewai.rag.embeddings.factory import build_embedder + + +class CursorStub: + def __init__(self, execute_side_effect=None): + self.execute_side_effect = execute_side_effect + self.description = [] + self._fetchall_result = [] + self._fetchone_result = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, sql: str, *args, **kwargs): + if self.execute_side_effect: + return self.execute_side_effect(self, sql, *args, **kwargs) + return None + + def fetchall(self): + return self._fetchall_result + + def fetchone(self): + return self._fetchone_result + + +@pytest.fixture +def oracle_connection_mock() -> MagicMock: + connection = MagicMock() + connection.__enter__.return_value = connection + connection.__exit__.return_value = False + return connection + + +@pytest.fixture +def oracle_live_config() -> dict[str, Any]: + default_wallet_dir = None + for candidate in ( + os.path.expanduser("~/.oracle-wallet/deepresearch"), + os.path.expanduser("~/.langchain-oracle-wallet"), + os.path.expanduser("~/.oracle-wallet"), + ): + if os.path.exists(os.path.join(candidate, "tnsnames.ora")): + default_wallet_dir = candidate + break + + default_dsn = 
None + if default_wallet_dir: + wallet_name = os.path.basename(default_wallet_dir) + if wallet_name == "deepresearch": + default_dsn = "deepresearch_high" + elif wallet_name == ".langchain-oracle-wallet": + default_dsn = "deepresearch_high" + elif wallet_name == ".oracle-wallet": + default_dsn = "locuscheck_high" + + password = os.getenv( + "ORACLE_DB_PASSWORD", + os.getenv("ORACLE_PASSWORD"), + ) + + return { + "user": os.getenv("ORACLE_DB_USER", os.getenv("ORACLE_USER", "ADMIN")), + "password": password, + "dsn": os.getenv("ORACLE_DB_DSN", os.getenv("ORACLE_DSN", default_dsn)), + "config_dir": os.getenv("ORACLE_DB_CONFIG_DIR", default_wallet_dir), + "wallet_location": os.getenv("ORACLE_DB_WALLET_LOCATION", default_wallet_dir), + "wallet_password": os.getenv( + "ORACLE_DB_WALLET_PASSWORD", + os.getenv("ORACLE_WALLET_PASSWORD"), + ) + or password, + } + + +@pytest.fixture +def oracle_live_text_tool_kwargs(oracle_live_config: dict[str, Any]) -> dict[str, Any]: + return { + **oracle_live_config, + "table_name": os.getenv("ORACLE_DB_TEXT_TABLE", ""), + "text_column": os.getenv("ORACLE_DB_TEXT_COLUMN", "text"), + "metadata_columns": [ + column.strip() + for column in os.getenv("ORACLE_DB_TEXT_METADATA_COLUMNS", "").split(",") + if column.strip() + ], + } + + +@pytest.fixture +def oracle_live_hybrid_tool_kwargs( + oracle_live_config: dict[str, Any], +) -> dict[str, Any]: + return { + **oracle_live_config, + "hybrid_index_name": os.getenv("ORACLE_DB_HYBRID_INDEX_NAME", ""), + "table_name": os.getenv("ORACLE_DB_HYBRID_TABLE", ""), + "text_column": os.getenv("ORACLE_DB_HYBRID_TEXT_COLUMN", "text"), + "metadata_column": os.getenv("ORACLE_DB_HYBRID_METADATA_COLUMN", "metadata"), + "metadata_columns": [ + column.strip() + for column in os.getenv("ORACLE_DB_HYBRID_METADATA_COLUMNS", "").split(",") + if column.strip() + ], + "search_mode": os.getenv("ORACLE_DB_HYBRID_SEARCH_MODE", "hybrid"), + } + + +def has_oracle_text_test_config() -> bool: + return bool( + 
oracle_live_defaults_available() + ) + + +def has_oracle_hybrid_test_config() -> bool: + return bool( + oracle_live_defaults_available() + ) + + +def has_oracle_vector_test_config() -> bool: + return bool(oracle_live_defaults_available()) and bool( + os.getenv("OCI_COMPARTMENT_ID") + and (os.getenv("OCI_REGION") or os.getenv("OCI_SERVICE_ENDPOINT")) + ) + + +def oracle_live_defaults_available() -> bool: + wallet_candidates = ( + os.path.expanduser("~/.oracle-wallet"), + os.path.expanduser("~/.langchain-oracle-wallet"), + ) + wallet_available = any( + os.path.exists(os.path.join(candidate, "tnsnames.ora")) + for candidate in wallet_candidates + ) + return wallet_available + + +@pytest.fixture +def oracle_live_vector_tool_kwargs(oracle_live_config: dict[str, Any]) -> dict[str, Any]: + embedding_model: dict[str, Any] = { + "provider": "oci", + "config": { + "model_name": os.getenv("OCI_EMBED_MODEL_NAME", "cohere.embed-v4.0"), + "compartment_id": os.getenv("OCI_COMPARTMENT_ID"), + "auth_profile": os.getenv("OCI_AUTH_PROFILE", "API_KEY_AUTH"), + "auth_file_location": os.getenv( + "OCI_AUTH_FILE_LOCATION", os.path.expanduser("~/.oci/config") + ), + "output_dimensions": int(os.getenv("OCI_EMBED_OUTPUT_DIMENSIONS", "1536")), + }, + } + if os.getenv("OCI_REGION"): + embedding_model["config"]["region"] = os.getenv("OCI_REGION") + if os.getenv("OCI_SERVICE_ENDPOINT"): + embedding_model["config"]["service_endpoint"] = os.getenv( + "OCI_SERVICE_ENDPOINT" + ) + + return { + **oracle_live_config, + "text_column": "text", + "embedding_column": "embedding", + "metadata_column": "metadata", + "metadata_columns": ["category"], + "embedding_model": embedding_model, + } + + +class _ConnectTimeout(Exception): + pass + + +def _timeout_handler(signum, frame): + raise _ConnectTimeout("Oracle DB connect timed out") + + +@pytest.fixture +def oracle_live_connection(oracle_live_config: dict[str, Any]): + try: + import oracledb + except ImportError: + pytest.skip("oracledb is not installed") + 
+ previous_handler = signal.getsignal(signal.SIGALRM) + signal.signal(signal.SIGALRM, _timeout_handler) + signal.alarm(20) + try: + connection = oracledb.connect(**oracle_live_config) + except _ConnectTimeout as exc: + pytest.skip(str(exc)) + except Exception as exc: + pytest.skip(f"Oracle DB connection failed: {type(exc).__name__}: {exc}") + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, previous_handler) + + try: + yield connection + finally: + try: + connection.close() + except Exception: + pass + + +@pytest.fixture +def oracle_text_live_resources(oracle_live_connection): + suffix = uuid.uuid4().hex[:8].upper() + table_name = f"CT_ORTXT_{suffix}" + index_name = f"CT_OIDX_{suffix}" + + with oracle_live_connection.cursor() as cursor: + cursor.execute( + f"CREATE TABLE {table_name} (text VARCHAR2(4000), category VARCHAR2(100))" + ) + cursor.execute( + f"INSERT INTO {table_name}(text, category) VALUES (:1, :2)", + [ + "Our refund policy for premium plans allows refunds within 30 days.", + "billing", + ], + ) + cursor.execute( + f"INSERT INTO {table_name}(text, category) VALUES (:1, :2)", + [ + "Autonomous Database is Oracle managed database infrastructure.", + "database", + ], + ) + cursor.execute(f"CREATE SEARCH INDEX {index_name} ON {table_name}(text)") + oracle_live_connection.commit() + + try: + yield { + "table_name": table_name, + "text_column": "text", + "metadata_columns": ["category"], + } + finally: + with oracle_live_connection.cursor() as cursor: + try: + cursor.execute(f"DROP INDEX {index_name}") + except Exception: + pass + try: + cursor.execute(f"DROP TABLE {table_name} PURGE") + except Exception: + pass + oracle_live_connection.commit() + + +@pytest.fixture +def oracle_hybrid_live_resources(oracle_live_connection): + suffix = uuid.uuid4().hex[:8].upper() + table_name = f"CT_ORHY_{suffix}" + index_name = f"CT_OHYIDX_{suffix}" + preference_name = f"CT_OPREF_{suffix}" + + with oracle_live_connection.cursor() as cursor: + cursor.execute( + 
f"CREATE TABLE {table_name} (text VARCHAR2(4000), metadata JSON, category VARCHAR2(100))" + ) + cursor.execute( + f"INSERT INTO {table_name}(text, metadata, category) VALUES (:1, :2, :3)", + [ + "Oracle Autonomous Database provides managed database capabilities.", + json.dumps({"source": "oracle-docs"}), + "database", + ], + ) + cursor.execute( + f"INSERT INTO {table_name}(text, metadata, category) VALUES (:1, :2, :3)", + [ + "Premium refund policy allows refunds within 30 days.", + json.dumps({"source": "policy-docs"}), + "billing", + ], + ) + cursor.execute( + """ + begin + DBMS_VECTOR_CHAIN.CREATE_PREFERENCE( + :pref_name, + dbms_vector_chain.vectorizer, + json(:pref_json) + ); + end; + """, + { + "pref_name": preference_name, + "pref_json": '{"model":"allminilm"}', + }, + ) + try: + # Oracle Hybrid Vector Index creation uses the database-side vectorizer + # registered in DBMS_VECTOR_CHAIN. That is distinct from OCI GenAI + # embedding models available via API_KEY_AUTH (for example + # cohere.embed-v4.0 in us-chicago-1). + # + # On Autonomous Database 26ai, this hybrid-index path depends on a + # supported in-database embedding/vectorizer model being installed in + # the database itself. If the requested model is unavailable, Oracle + # raises ORA-40284 during index creation. That is a database capability + # gap for this tenancy/database, not a CrewAI tool failure. 
+ # + # References: + # - https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/create-hybrid-vector-index.html + # - https://docs.oracle.com/en/database/oracle/oracle-database/23/vecse/guidelines-and-restrictions-hybrid-vector-indexes.html + # - https://docs.oracle.com/en/database/oracle/oracle-database/26/arpls/dbms_hybrid_vector1.html + cursor.execute( + f"CREATE HYBRID VECTOR INDEX {index_name} ON {table_name}(text) " + f"PARAMETERS ('vectorizer {preference_name}')" + ) + except Exception as exc: + oracle_live_connection.rollback() + if "ORA-40284" in str(exc): + pytest.skip( + "Oracle DB hybrid vectorizer model is not available in this 26ai database; " + "hybrid indexes here require a supported DB-local vectorizer model rather " + "than OCI GenAI embedding models" + ) + raise + oracle_live_connection.commit() + + try: + yield { + "hybrid_index_name": index_name, + "table_name": table_name, + "text_column": "text", + "metadata_column": "metadata", + "metadata_columns": ["category"], + "search_mode": "hybrid", + } + finally: + with oracle_live_connection.cursor() as cursor: + try: + cursor.execute(f"DROP INDEX {index_name}") + except Exception: + pass + try: + cursor.execute( + """ + begin + DBMS_VECTOR_CHAIN.DROP_PREFERENCE(:pref_name); + end; + """, + {"pref_name": preference_name}, + ) + except Exception: + pass + try: + cursor.execute(f"DROP TABLE {table_name} PURGE") + except Exception: + pass + oracle_live_connection.commit() + + +@pytest.fixture +def oracle_vector_live_resources( + oracle_live_connection, + oracle_live_vector_tool_kwargs: dict[str, Any], +): + suffix = uuid.uuid4().hex[:8].upper() + table_name = f"CT_ORVEC_{suffix}" + output_dimensions = oracle_live_vector_tool_kwargs["embedding_model"]["config"][ + "output_dimensions" + ] + + try: + embedder = build_embedder(oracle_live_vector_tool_kwargs["embedding_model"]) + documents = [ + ( + "Premium refund policy allows refunds within 30 days of purchase.", + {"source": "policy-docs", 
"topic": "billing"}, + "billing", + ), + ( + "Autonomous Database is Oracle managed database infrastructure.", + {"source": "oracle-docs", "topic": "database"}, + "database", + ), + ( + "Vector search compares embeddings using similarity distance.", + {"source": "vector-docs", "topic": "ai"}, + "ai", + ), + ] + embeddings = embedder([document[0] for document in documents]) + except Exception as exc: + pytest.skip(f"OCI embedding setup failed: {type(exc).__name__}: {exc}") + + with oracle_live_connection.cursor() as cursor: + cursor.execute( + f"CREATE TABLE {table_name} (" + f"id VARCHAR2(64), " + f"text VARCHAR2(4000), " + f"metadata JSON, " + f"category VARCHAR2(100), " + f"embedding VECTOR({output_dimensions}, FLOAT32)" + f")" + ) + for index, ((text, metadata, category), embedding) in enumerate( + zip(documents, embeddings, strict=True), + start=1, + ): + cursor.execute( + f"INSERT INTO {table_name}(id, text, metadata, category, embedding) " + f"VALUES (:1, :2, :3, :4, :5)", + [ + f"doc-{index}", + text, + json.dumps(metadata), + category, + array.array("f", [float(value) for value in embedding]), + ], + ) + oracle_live_connection.commit() + + try: + yield { + "table_name": table_name, + "text_column": "text", + "embedding_column": "embedding", + "metadata_column": "metadata", + "metadata_columns": ["category"], + "embedder": embedder, + } + finally: + with oracle_live_connection.cursor() as cursor: + try: + cursor.execute(f"DROP TABLE {table_name} PURGE") + except Exception: + pass + oracle_live_connection.commit() diff --git a/lib/crewai-tools/tests/tools/oracle_db_tools_test.py b/lib/crewai-tools/tests/tools/oracle_db_tools_test.py new file mode 100644 index 00000000000..cc33d2ac896 --- /dev/null +++ b/lib/crewai-tools/tests/tools/oracle_db_tools_test.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import json + +import pytest + +from crewai_tools import ( + OracleHybridSearchTool, + OracleTextSearchTool, + OracleVectorSearchTool, +) +from 
tests.tools.oracle_db.conftest import CursorStub + +pytest_plugins = ("tests.tools.oracle_db.conftest",) + + +def test_oracle_text_search_tool_formats_contains_query(oracle_connection_mock): + cursor = CursorStub() + + def execute_side_effect(cursor_obj, sql, *args, **kwargs): + assert "CONTAINS(text, :query, 1)" in sql + assert 'fuzzy("oracle") ACCUM fuzzy("database")' == kwargs["query"] + cursor_obj.description = [("SCORE",), ("TEXT",), ("CATEGORY",)] + cursor_obj._fetchall_result = [(92, "Oracle Database text result", "docs")] + + cursor.execute_side_effect = execute_side_effect + oracle_connection_mock.cursor.return_value = cursor + + tool = OracleTextSearchTool( + client=oracle_connection_mock, + table_name="docs_table", + text_column="text", + metadata_columns=["category"], + fuzzy=True, + ) + + result = json.loads(tool._run("oracle database")) + + assert result["results"][0]["content"] == "Oracle Database text result" + assert result["results"][0]["metadata"]["category"] == "docs" + assert result["results"][0]["score"] == 92 + + +def test_oracle_text_search_tool_validates_identifiers(): + tool = OracleTextSearchTool( + client=object(), + table_name="docs-table", + text_column="text", + ) + + with pytest.raises(ValueError, match="table_name"): + tool._run("oracle") + + +def test_oracle_hybrid_search_tool_builds_search_params_and_fetches_rows( + oracle_connection_mock, +): + cursor = CursorStub() + state = {"search_called": False} + + def execute_side_effect(cursor_obj, sql, *args, **kwargs): + if "DBMS_HYBRID_VECTOR.SEARCH" in sql: + state["search_called"] = True + search_params = json.loads(kwargs["search_params"]) + assert search_params["hybrid_index_name"] == "docsidx" + assert search_params["vector"]["search_text"] == "autonomous database" + assert search_params["text"]["search_text"] == "autonomous database" + cursor_obj._fetchall_result = [ + ( + '[{"rowid":"AAABBB","score":0.97,"vector_score":0.95,"text_score":0.88}]', + ) + ] + return None + + assert 
state["search_called"] + assert "WHERE rowid = :1" in sql + cursor_obj._fetchone_result = ( + "Autonomous Database is managed.", + {"source": "doc-1"}, + "oracle", + ) + + cursor.execute_side_effect = execute_side_effect + oracle_connection_mock.cursor.return_value = cursor + + tool = OracleHybridSearchTool( + client=oracle_connection_mock, + hybrid_index_name="docsidx", + table_name="docs_table", + metadata_columns=["category"], + ) + + result = json.loads(tool._run("autonomous database")) + + assert result["results"][0]["content"] == "Autonomous Database is managed." + assert result["results"][0]["metadata"]["source"] == "doc-1" + assert result["results"][0]["metadata"]["category"] == "oracle" + assert result["results"][0]["score"] == 0.97 + + +def test_oracle_hybrid_search_tool_rejects_reserved_params(): + tool = OracleHybridSearchTool( + client=object(), + hybrid_index_name="docsidx", + table_name="docs_table", + params={"return": {"topN": 1}}, + ) + + with pytest.raises(ValueError, match="Reserved hybrid search params"): + tool._build_search_params("oracle") + + +def test_oracle_vector_search_tool_embeds_query_and_fetches_rows(oracle_connection_mock): + cursor = CursorStub() + + def execute_side_effect(cursor_obj, sql, *args, **kwargs): + assert "VECTOR_DISTANCE(embedding, :query_embedding, COSINE)" in sql + query_embedding = kwargs["query_embedding"] + assert list(query_embedding) == pytest.approx([0.1, 0.2, 0.3]) + cursor_obj.description = [("TEXT",), ("METADATA",), ("CATEGORY",), ("DISTANCE",)] + cursor_obj._fetchall_result = [ + ("Oracle vector result", {"source": "doc-1"}, "docs", 0.0123) + ] + + cursor.execute_side_effect = execute_side_effect + oracle_connection_mock.cursor.return_value = cursor + + tool = OracleVectorSearchTool( + client=oracle_connection_mock, + table_name="docs_table", + metadata_columns=["category"], + embedder=lambda inputs: [[0.1, 0.2, 0.3] for _ in inputs], + ) + + result = json.loads(tool._run("oracle vector database")) + + 
assert result["results"][0]["content"] == "Oracle vector result" + assert result["results"][0]["metadata"]["source"] == "doc-1" + assert result["results"][0]["metadata"]["category"] == "docs" + assert result["results"][0]["distance"] == pytest.approx(0.0123) diff --git a/lib/crewai-tools/tests/tools/test_oci_knowledge_base_tool_integration.py b/lib/crewai-tools/tests/tools/test_oci_knowledge_base_tool_integration.py new file mode 100644 index 00000000000..22f0b57b99d --- /dev/null +++ b/lib/crewai-tools/tests/tools/test_oci_knowledge_base_tool_integration.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import os +import uuid + +import pytest + +from crewai_tools import OCIKnowledgeBaseTool + + +OCI_SDK_AVAILABLE = True +try: + import oci # noqa: F401 +except ImportError: + OCI_SDK_AVAILABLE = False + + +def _has_oci_test_config() -> bool: + return bool( + os.getenv("OCI_COMPARTMENT_ID") + and (os.getenv("OCI_SERVICE_ENDPOINT") or os.getenv("OCI_REGION")) + ) + + +@pytest.mark.skipif( + not OCI_SDK_AVAILABLE or not _has_oci_test_config(), + reason="Requires OCI SDK plus OCI_COMPARTMENT_ID and OCI endpoint configuration", +) +@pytest.mark.block_network(allowed_hosts=[r".*"]) +def test_oci_knowledge_base_tool_live_query(tmp_path) -> None: + document_path = tmp_path / "oracle_notes.txt" + document_path.write_text( + ( + "Oracle Cloud Infrastructure includes Autonomous Database. " + "Autonomous Database is a managed Oracle database service." 
+ ), + encoding="utf-8", + ) + + tool_kwargs = { + "knowledge_source": str(document_path), + "collection_name": f"oci_kb_live_test_{uuid.uuid4().hex[:8]}", + "compartment_id": os.getenv("OCI_COMPARTMENT_ID"), + "auth_type": os.getenv("OCI_AUTH_TYPE", "API_KEY"), + "auth_profile": os.getenv("OCI_AUTH_PROFILE", "DEFAULT"), + "auth_file_location": os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + } + if os.getenv("OCI_REGION"): + tool_kwargs["region"] = os.getenv("OCI_REGION") + if os.getenv("OCI_SERVICE_ENDPOINT"): + tool_kwargs["service_endpoint"] = os.getenv("OCI_SERVICE_ENDPOINT") + + tool = OCIKnowledgeBaseTool(**tool_kwargs) + + result = tool._run("Which Oracle service is described in the document?") + + assert "Autonomous Database" in result diff --git a/lib/crewai-tools/tests/tools/test_oracle_hybrid_search_tool_integration.py b/lib/crewai-tools/tests/tools/test_oracle_hybrid_search_tool_integration.py new file mode 100644 index 00000000000..fd940defe6a --- /dev/null +++ b/lib/crewai-tools/tests/tools/test_oracle_hybrid_search_tool_integration.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import json + +import pytest + +from crewai_tools import OracleHybridSearchTool +from tests.tools.oracle_db.conftest import has_oracle_hybrid_test_config + + +ORACLE_DB_AVAILABLE = True +try: + import oracledb # noqa: F401 +except ImportError: + ORACLE_DB_AVAILABLE = False + + +@pytest.mark.skipif( + not ORACLE_DB_AVAILABLE or not has_oracle_hybrid_test_config(), + reason="Requires oracledb plus a local Oracle wallet-backed DB configuration", +) +@pytest.mark.block_network(allowed_hosts=[r".*"]) +def test_oracle_hybrid_search_tool_live_query( + oracle_live_hybrid_tool_kwargs, + oracle_hybrid_live_resources, +) -> None: + tool = OracleHybridSearchTool( + **(oracle_live_hybrid_tool_kwargs | oracle_hybrid_live_resources) + ) + + result = json.loads(tool._run("managed database")) + + assert "results" in result + assert result["results"] + assert "database" in 
result["results"][0]["content"].lower() diff --git a/lib/crewai-tools/tests/tools/test_oracle_text_search_tool_integration.py b/lib/crewai-tools/tests/tools/test_oracle_text_search_tool_integration.py new file mode 100644 index 00000000000..abf745fa629 --- /dev/null +++ b/lib/crewai-tools/tests/tools/test_oracle_text_search_tool_integration.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import json + +import pytest + +from crewai_tools import OracleTextSearchTool +from tests.tools.oracle_db.conftest import has_oracle_text_test_config + + +ORACLE_DB_AVAILABLE = True +try: + import oracledb # noqa: F401 +except ImportError: + ORACLE_DB_AVAILABLE = False + + +@pytest.mark.skipif( + not ORACLE_DB_AVAILABLE or not has_oracle_text_test_config(), + reason="Requires oracledb plus a local Oracle wallet-backed DB configuration", +) +@pytest.mark.block_network(allowed_hosts=[r".*"]) +def test_oracle_text_search_tool_live_query( + oracle_live_text_tool_kwargs, + oracle_text_live_resources, +) -> None: + tool = OracleTextSearchTool(**(oracle_live_text_tool_kwargs | oracle_text_live_resources)) + + result = json.loads(tool._run("refund policy")) + + assert "results" in result + assert result["results"] + assert "refund" in result["results"][0]["content"].lower() diff --git a/lib/crewai-tools/tests/tools/test_oracle_vector_search_tool_integration.py b/lib/crewai-tools/tests/tools/test_oracle_vector_search_tool_integration.py new file mode 100644 index 00000000000..b579b8d7043 --- /dev/null +++ b/lib/crewai-tools/tests/tools/test_oracle_vector_search_tool_integration.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import json + +import pytest + +from crewai_tools import OracleVectorSearchTool +from tests.tools.oracle_db.conftest import has_oracle_vector_test_config + +pytestmark = pytest.mark.filterwarnings( + "ignore:datetime.datetime.utcnow\\(\\) is deprecated.*:DeprecationWarning:oci.base_client" +) + + +@pytest.mark.skipif( + not 
has_oracle_vector_test_config(), + reason="Oracle DB wallet or OCI live embedding config is not available", +) +def test_oracle_vector_search_tool_live( + oracle_live_vector_tool_kwargs, + oracle_vector_live_resources, +): + tool = OracleVectorSearchTool( + **(oracle_live_vector_tool_kwargs | oracle_vector_live_resources) + ) + + result = json.loads(tool._run("What is the refund policy?")) + + assert result["results"] + top_result = result["results"][0] + assert "refund policy" in top_result["content"].lower() + assert top_result["metadata"]["category"] == "billing" + assert top_result["metadata"]["topic"] == "billing" + assert top_result["distance"] >= 0 diff --git a/uv.lock b/uv.lock index 8fc9e56f5df..4ee3d366773 100644 --- a/uv.lock +++ b/uv.lock @@ -899,6 +899,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/6e/956e62975305a4e31daf6114a73b3b0683a8f36f8d70b20aabd466770edb/chromadb-1.1.1-cp39-abi3-win_amd64.whl", hash = "sha256:a77aa026a73a18181fd89bbbdb86191c9a82fd42aa0b549ff18d8cae56394c8b", size = 19844042, upload-time = "2025-10-05T02:49:16.925Z" }, ] +[[package]] +name = "circuitbreaker" +version = "2.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/df/ac/de7a92c4ed39cba31fe5ad9203b76a25ca67c530797f6bb420fff5f65ccb/circuitbreaker-2.1.3.tar.gz", hash = "sha256:1a4baee510f7bea3c91b194dcce7c07805fe96c4423ed5594b75af438531d084", size = 10787, upload-time = "2025-03-31T08:12:08.963Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/34/15f08edd4628f65217de1fc3c1a27c82e46fe357d60c217fc9881e12ebcc/circuitbreaker-2.1.3-py3-none-any.whl", hash = "sha256:87ba6a3ed03fdc7032bc175561c2b04d52ade9d5faf94ca2b035fbdc5e6b1dd1", size = 7737, upload-time = "2025-03-31T08:12:07.802Z" }, +] + [[package]] name = "click" version = "8.1.8" @@ -1161,6 +1170,9 @@ litellm = [ mem0 = [ { name = "mem0ai" }, ] +oci = [ + { name = "oci" }, +] openpyxl = [ { name = "openpyxl" }, ] @@ 
-1209,6 +1221,7 @@ requires-dist = [ { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.74.9,<3" }, { name = "mcp", specifier = "~=1.26.0" }, { name = "mem0ai", marker = "extra == 'mem0'", specifier = "~=0.1.94" }, + { name = "oci", marker = "extra == 'oci'", specifier = ">=2.161.0" }, { name = "openai", specifier = ">=1.83.0,<3" }, { name = "openpyxl", specifier = "~=3.1.5" }, { name = "openpyxl", marker = "extra == 'openpyxl'", specifier = "~=3.1.5" }, @@ -1232,7 +1245,7 @@ requires-dist = [ { name = "uv", specifier = "~=0.9.13" }, { name = "voyageai", marker = "extra == 'voyageai'", specifier = "~=0.3.5" }, ] -provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "openpyxl", "pandas", "qdrant", "tools", "voyageai", "watson"] +provides-extras = ["a2a", "anthropic", "aws", "azure-ai-inference", "bedrock", "docling", "embeddings", "file-processing", "google-genai", "litellm", "mem0", "oci", "openpyxl", "pandas", "qdrant", "tools", "voyageai", "watson"] [[package]] name = "crewai-devtools" @@ -1353,6 +1366,12 @@ multion = [ mysql = [ { name = "pymysql" }, ] +oci = [ + { name = "oci" }, +] +oracle = [ + { name = "oracledb" }, +] oxylabs = [ { name = "oxylabs" }, ] @@ -1436,6 +1455,8 @@ requires-dist = [ { name = "multion", marker = "extra == 'multion'", specifier = ">=1.1.0" }, { name = "nest-asyncio", marker = "extra == 'bedrock'", specifier = ">=1.6.0" }, { name = "nest-asyncio", marker = "extra == 'contextual'", specifier = ">=1.6.0" }, + { name = "oci", marker = "extra == 'oci'", specifier = ">=2.161.0" }, + { name = "oracledb", marker = "extra == 'oracle'", specifier = ">=2.5.1" }, { name = "oxylabs", marker = "extra == 'oxylabs'", specifier = "==2.0.0" }, { name = "patronus", marker = "extra == 'patronus'", specifier = ">=0.0.16" }, { name = "playwright", marker = "extra == 'bedrock'", specifier = ">=1.52.0" }, @@ -1466,7 +1487,7 @@ 
requires-dist = [ { name = "weaviate-client", marker = "extra == 'weaviate-client'", specifier = ">=4.10.2" }, { name = "youtube-transcript-api", specifier = "~=1.2.2" }, ] -provides-extras = ["apify", "beautifulsoup4", "bedrock", "browserbase", "composio-core", "contextual", "couchbase", "databricks-sdk", "exa-py", "firecrawl-py", "github", "hyperbrowser", "linkup-sdk", "mcp", "mongodb", "multion", "mysql", "oxylabs", "patronus", "postgresql", "qdrant-client", "rag", "scrapegraph-py", "scrapfly-sdk", "selenium", "serpapi", "singlestore", "snowflake", "spider-client", "sqlalchemy", "stagehand", "tavily-python", "weaviate-client", "xml"] +provides-extras = ["apify", "beautifulsoup4", "bedrock", "browserbase", "composio-core", "contextual", "couchbase", "databricks-sdk", "exa-py", "firecrawl-py", "github", "hyperbrowser", "linkup-sdk", "mcp", "mongodb", "multion", "mysql", "oci", "oracle", "oxylabs", "patronus", "postgresql", "qdrant-client", "rag", "scrapegraph-py", "scrapfly-sdk", "selenium", "serpapi", "singlestore", "snowflake", "spider-client", "sqlalchemy", "stagehand", "tavily-python", "weaviate-client", "xml"] [[package]] name = "cryptography" @@ -4481,6 +4502,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/be/9c/92789c596b8df838baa98fa71844d84283302f7604ed565dafe5a6b5041a/oauthlib-3.3.1-py3-none-any.whl", hash = "sha256:88119c938d2b8fb88561af5f6ee0eec8cc8d552b7bb1f712743136eb7523b7a1", size = 160065, upload-time = "2025-06-19T22:48:06.508Z" }, ] +[[package]] +name = "oci" +version = "2.168.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "circuitbreaker" }, + { name = "cryptography" }, + { name = "pyopenssl" }, + { name = "python-dateutil" }, + { name = "pytz" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/68/edf8ffbb42e97ad44d64fce85be00818d979b472dd4377dc948155f811e9/oci-2.168.1.tar.gz", hash = 
"sha256:b941674171b41e999b8e3adb38d4797d7b42d2bb5ff40d17c26e8ce2a7d4b605", size = 16751235, upload-time = "2026-03-10T10:50:16.244Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/3e/29e05b4f8bed3b4a89b52fc57e76ac86669fc43a59e128eb526e395eda7b/oci-2.168.1-py3-none-any.whl", hash = "sha256:d106cfffc9153b5c9de628877c967ed87bbbfbbc9d411c97feee0eba8f2e4eab", size = 34033119, upload-time = "2026-03-10T10:50:08.501Z" }, +] + [[package]] name = "ocrmac" version = "1.0.1" @@ -4752,6 +4790,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1a/89/267b0af1b1d0ba828f0e60642b6a5116ac1fd917cde7fc02821627029bd1/opentelemetry_semantic_conventions-0.55b1-py3-none-any.whl", hash = "sha256:5da81dfdf7d52e3d37f8fe88d5e771e191de924cfff5f550ab0b8f7b2409baed", size = 196223, upload-time = "2025-06-10T08:55:17.638Z" }, ] +[[package]] +name = "oracledb" +version = "3.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cryptography" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f7/02/70a872d1a4a739b4f7371ab8d3d5ed8c6e57e142e2503531aafcb220893c/oracledb-3.4.2.tar.gz", hash = "sha256:46e0f2278ff1fe83fbc33a3b93c72d429323ec7eed47bc9484e217776cd437e5", size = 855467, upload-time = "2026-01-28T17:25:39.91Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4c/5d/b8a0ca1c520fa43ae33260f6f8ca9bd468ade43da7986029bc214965df12/oracledb-3.4.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff3c89cecea62af8ca02aa33cab0f2edc0214c747eac7d3364ed6b2640cb55e4", size = 4243966, upload-time = "2026-01-28T17:25:45.05Z" }, + { url = "https://files.pythonhosted.org/packages/f6/43/26e2bbb2a6ee31392a339089e53cb2e386ca795ff4fbe2f673c167821bd6/oracledb-3.4.2-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e068ef844a327877bfefbef1bc6fb7284c727bb87af80095f08d95bcaf7b8bb2", size = 2426056, upload-time = 
"2026-01-28T17:25:47.176Z" }, + { url = "https://files.pythonhosted.org/packages/09/ba/11ee1d044295465a04ff45c6e3023d35400bb3f67bc5fed9408f0f2dc04c/oracledb-3.4.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9f434a739405557bd57cb39b62238142bb27855a524a70dc6d397a2a8c576c9d", size = 2603062, upload-time = "2026-01-28T17:25:49.817Z" }, + { url = "https://files.pythonhosted.org/packages/c5/bc/292f2f5f7b65a667787871e300889ab8f4a3b9cfd88c5d78f828a40f6d31/oracledb-3.4.2-cp310-cp310-win32.whl", hash = "sha256:00c79448017f367bb7ab6900efe0706658a53768abea2b4519a4c9b2d5743890", size = 1496639, upload-time = "2026-01-28T17:25:51.298Z" }, + { url = "https://files.pythonhosted.org/packages/21/23/81931c16663e771937c0161bb90460668d2a5f7982b5030ab7bef3b3a4f9/oracledb-3.4.2-cp310-cp310-win_amd64.whl", hash = "sha256:574c8280d49cbbe21dbe03fc28356d9b9a5b9e300ebcde6c6d106e51453a7e65", size = 1837314, upload-time = "2026-01-28T17:25:52.718Z" }, + { url = "https://files.pythonhosted.org/packages/64/80/be263b668ba32b258d07c85f7bfb6967a9677e016c299207b28734f04c4b/oracledb-3.4.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b8e4b8a852251cef09038b75f30fce1227010835f4e19cfbd436027acba2697c", size = 4228552, upload-time = "2026-01-28T17:25:54.844Z" }, + { url = "https://files.pythonhosted.org/packages/91/bc/e832a649529da7c60409a81be41f3213b4c7ffda4fe424222b2145e8d43c/oracledb-3.4.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1617a1db020346883455af005efbefd51be2c4d797e43b1b38455a19f8526b48", size = 2421924, upload-time = "2026-01-28T17:25:56.984Z" }, + { url = "https://files.pythonhosted.org/packages/86/21/d867c37e493a63b5521bd248110ad5b97b18253d64a30703e3e8f3d9631e/oracledb-3.4.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5ed78d7e7079a778062744ccf42141ce4806818c3f4dd6463e4a7edd561c9f86", size = 2599301, upload-time = 
"2026-01-28T17:25:58.529Z" }, + { url = "https://files.pythonhosted.org/packages/2a/de/9b1843ea27f7791449652d7f340f042c3053336d2c11caf29e59bab86189/oracledb-3.4.2-cp311-cp311-win32.whl", hash = "sha256:0e16fe3d057e0c41a23ad2ae95bfa002401690773376d476be608f79ac74bf05", size = 1492890, upload-time = "2026-01-28T17:26:00.662Z" }, + { url = "https://files.pythonhosted.org/packages/d6/10/cbc8afa2db0cec80530858d3e4574f9734fae8c0b7f1df261398aa026c5f/oracledb-3.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:f93cae08e8ed20f2d5b777a8602a71f9418389c661d2c937e84d94863e7e7011", size = 1843355, upload-time = "2026-01-28T17:26:02.637Z" }, + { url = "https://files.pythonhosted.org/packages/8f/81/2e6154f34b71cd93b4946c73ea13b69d54b8d45a5f6bbffe271793240d21/oracledb-3.4.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a7396664e592881225ba66385ee83ce339d864f39003d6e4ca31a894a7e7c552", size = 4220806, upload-time = "2026-01-28T17:26:04.322Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a9/a1d59aaac77d8f727156ec6a3b03399917c90b7da4f02d057f92e5601f56/oracledb-3.4.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f04a2d62073407672f114d02529921de0677c6883ed7c64d8d1a3c04caa3238", size = 2233795, upload-time = "2026-01-28T17:26:05.877Z" }, + { url = "https://files.pythonhosted.org/packages/94/ec/8c4a38020cd251572bd406ddcbde98ca052ec94b5684f9aa9ef1ddfcc68c/oracledb-3.4.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d8d75e4f879b908be66cce05ba6c05791a5dbb4a15e39abc01aa25c8a2492bd9", size = 2424756, upload-time = "2026-01-28T17:26:07.35Z" }, + { url = "https://files.pythonhosted.org/packages/fa/7d/c251c2a8567151ccfcfbe3467ea9a60fb5480dc4719342e2e6b7a9679e5d/oracledb-3.4.2-cp312-cp312-win32.whl", hash = "sha256:31b7ee83c23d0439778303de8a675717f805f7e8edb5556d48c4d8343bcf14f5", size = 1453486, upload-time = "2026-01-28T17:26:08.869Z" }, + { url = 
"https://files.pythonhosted.org/packages/4c/78/c939f3c16fb39400c4734d5a3340db5659ba4e9dce23032d7b33ccfd3fe5/oracledb-3.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:ac25a0448fc830fb7029ad50cd136cdbfcd06975d53967e269772cc5cb8c203a", size = 1794445, upload-time = "2026-01-28T17:26:10.66Z" }, + { url = "https://files.pythonhosted.org/packages/22/68/f7126f5d911c295b57720c6b1a0609a5a2667b4546946433552a4de46333/oracledb-3.4.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:643c25d301a289a371e37fcedb59e5fa5e54fb321708e5c12821c4b55bdd8a4d", size = 4205176, upload-time = "2026-01-28T17:26:12.463Z" }, + { url = "https://files.pythonhosted.org/packages/5d/93/2fced60f92dc82e66980a8a3ba5c1ea48110bf1dd81d030edb69d88f992e/oracledb-3.4.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:55397e7eb43bb7017c03a981c736c25724182f5210951181dfe3fab0e5d457fb", size = 2231298, upload-time = "2026-01-28T17:26:14.497Z" }, + { url = "https://files.pythonhosted.org/packages/75/a7/4dd286f3a6348d786fef9e6ab2e6c9b74ca9195d9a756f2a67e45743cdf0/oracledb-3.4.2-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b26a10f9c790bd141ffc8af68520803ed4a44a9258bf7d1eea9bfdd36bd6df7f", size = 2439430, upload-time = "2026-01-28T17:26:16.044Z" }, + { url = "https://files.pythonhosted.org/packages/19/28/94bc753e5e969c60ee5d9c914e2b4ef79999eaca8e91bcab2fbf0586b80b/oracledb-3.4.2-cp313-cp313-win32.whl", hash = "sha256:b974caec2c330c22bbe765705a5ac7d98ec3022811dec2042d561a3c65cb991b", size = 1458209, upload-time = "2026-01-28T17:26:17.652Z" }, + { url = "https://files.pythonhosted.org/packages/cb/2b/593a9b2d4c12c9de3289e67d84fe023336d99f36ba51442a5a0f5ce6acf7/oracledb-3.4.2-cp313-cp313-win_amd64.whl", hash = "sha256:3df8eee1410d25360599968b1625b000f10c5ae0e47274031a7842a9dc418890", size = 1793558, upload-time = "2026-01-28T17:26:19.914Z" }, +] + [[package]] name = "orjson" version = "3.11.7" From 
e1eaccf19611a681c7357d366a802985cebf28a2 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sat, 14 Mar 2026 23:13:46 -0400 Subject: [PATCH 3/9] Improve OCI and Oracle integration code clarity --- .../src/crewai_tools/oci/common.py | 4 +- .../oci/knowledge_base/retriever_tool.py | 4 ++ .../knowledge_base/retriever_tool.py | 26 +++++++-- .../crewai/llms/providers/oci/completion.py | 53 ++++++++++++++++++- .../providers/oci/embedding_callable.py | 11 +++- 5 files changed, 90 insertions(+), 8 deletions(-) diff --git a/lib/crewai-tools/src/crewai_tools/oci/common.py b/lib/crewai-tools/src/crewai_tools/oci/common.py index 156e30cbe0f..c26b6afdd4a 100644 --- a/lib/crewai-tools/src/crewai_tools/oci/common.py +++ b/lib/crewai-tools/src/crewai_tools/oci/common.py @@ -8,6 +8,7 @@ def get_oci_module() -> Any: + """Import the OCI SDK lazily so optional dependencies stay optional.""" try: import oci # type: ignore[import-untyped] except ImportError: @@ -26,7 +27,7 @@ def create_oci_client_kwargs( service_endpoint: str | None = None, timeout: tuple[int, int] = (10, 120), ) -> dict[str, Any]: - """Create standard OCI client kwargs for CrewAI tools.""" + """Build standard OCI SDK client kwargs shared by the tool integrations.""" oci = get_oci_module() client_kwargs: dict[str, Any] = { "config": {}, @@ -96,4 +97,5 @@ def parse_object_storage_path(file_path: str) -> tuple[str | None, str, str]: def get_region() -> str: + """Return the default OCI region for tools that support region fallbacks.""" return os.getenv("OCI_REGION", DEFAULT_OCI_REGION) diff --git a/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py b/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py index 305dc72434b..68f4d1f6015 100644 --- a/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py +++ b/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py @@ -29,6 +29,7 @@ class OCIKnowledgeBaseToolSchema(FixedOCIKnowledgeBaseToolSchema): class 
OCIKnowledgeBaseTool(RagTool): + """RAG tool preconfigured to use OCI embeddings as the backing embedder.""" name: str = "OCI Knowledge Base Tool" description: str = ( "A CrewAI-managed knowledge base tool powered by OCI embeddings." @@ -49,6 +50,8 @@ def __init__( config: dict[str, Any] | None = None, **kwargs: Any, ) -> None: + # Keep the OCI embedder config serializable so the underlying RagTool can + # build or rebuild the embedder through CrewAI's standard provider factory. oci_embedding_config: dict[str, str] = { "model_name": cast( str, @@ -101,6 +104,7 @@ def _run( # type: ignore[override] similarity_threshold: float | None = None, limit: int | None = None, ) -> str: + """Optionally add a source, then delegate retrieval to the base RagTool.""" if knowledge_source is not None: self.add(knowledge_source) return super()._run( diff --git a/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/retriever_tool.py b/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/retriever_tool.py index 3b9f17f273c..85e33cb83bf 100644 --- a/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/retriever_tool.py +++ b/lib/crewai-tools/src/crewai_tools/oracle_db/knowledge_base/retriever_tool.py @@ -17,6 +17,7 @@ def _generate_accum_query(query: str, fuzzy: bool = False) -> str: + """Translate plain text into a simple Oracle Text ACCUM expression.""" words = re.split(r"\W+", query) tokens = [word for word in words if word] if fuzzy: @@ -29,6 +30,7 @@ class OracleSearchToolInput(BaseModel): class OracleToolBase(BaseTool): + """Shared Oracle connection/result helpers for retrieval-style tools.""" model_config = ConfigDict(arbitrary_types_allowed=True, protected_namespaces=()) client: Any | None = Field(default=None, exclude=True) @@ -57,6 +59,7 @@ def _connection_kwargs(self) -> dict[str, Any]: ) def _result_json(self, results: list[dict[str, Any]]) -> str: + """Return a stable JSON envelope for tool responses.""" if results: return json.dumps({"results": results}, 
indent=2) return json.dumps({"message": "No results found for the given query."}, indent=2) @@ -103,6 +106,7 @@ def model_post_init(self, __context: Any) -> None: self.embedder = build_embedder(self.embedding_model) def _embed_query(self, query: str) -> array.array[float]: + """Build a float32 vector compatible with Oracle VECTOR columns.""" if self.embedder is None: raise ValueError( "OracleVectorSearchTool requires either embedder or embedding_model." @@ -116,6 +120,7 @@ def _embed_query(self, query: str) -> array.array[float]: return array.array("f", [float(value) for value in embedding]) def _run(self, query: str) -> str: + """Run plain vector similarity search against an Oracle VECTOR column.""" table_name = validate_identifier(self.table_name, field_name="table_name") text_column = validate_identifier(self.text_column, field_name="text_column") embedding_column = validate_identifier( @@ -148,7 +153,9 @@ def _run(self, query: str) -> str: f"FETCH FIRST {number_of_results} ROWS ONLY" ) - with oracle_connection_context(self.client, **self._connection_kwargs()) as connection: + with oracle_connection_context( + self.client, **self._connection_kwargs() + ) as connection: with connection.cursor() as cursor: cursor.execute(sql, query_embedding=self._embed_query(query)) columns = [column[0].lower() for column in cursor.description] @@ -210,6 +217,7 @@ class OracleTextSearchTool(OracleToolBase): ) def _run(self, query: str) -> str: + """Run Oracle Text retrieval using CONTAINS/SCORE against the table.""" table_name = validate_identifier(self.table_name, field_name="table_name") text_column = validate_identifier(self.text_column, field_name="text_column") metadata_columns = [ @@ -219,7 +227,9 @@ def _run(self, query: str) -> str: ] number_of_results = max(1, self.number_of_results) - search_text = query if self.operator_search else _generate_accum_query(query, self.fuzzy) + search_text = ( + query if self.operator_search else _generate_accum_query(query, self.fuzzy) + 
) if not search_text: return self._result_json([]) @@ -231,7 +241,9 @@ def _run(self, query: str) -> str: f"ORDER BY score DESC FETCH FIRST {number_of_results} ROWS ONLY" ) - with oracle_connection_context(self.client, **self._connection_kwargs()) as connection: + with oracle_connection_context( + self.client, **self._connection_kwargs() + ) as connection: with connection.cursor() as cursor: cursor.execute(sql, query=search_text) columns = [column[0].lower() for column in cursor.description] @@ -281,6 +293,7 @@ class OracleHybridSearchTool(OracleToolBase): ) def _build_search_params(self, query: str) -> dict[str, Any]: + """Assemble the JSON payload consumed by DBMS_HYBRID_VECTOR.SEARCH.""" search_params = dict(self.params or {}) search_params["hybrid_index_name"] = validate_identifier( self.hybrid_index_name, field_name="hybrid_index_name" @@ -313,6 +326,7 @@ def _build_search_params(self, query: str) -> dict[str, Any]: return search_params def _run(self, query: str) -> str: + """Run hybrid retrieval through a prebuilt Oracle hybrid vector index.""" table_name = validate_identifier(self.table_name, field_name="table_name") text_column = validate_identifier(self.text_column, field_name="text_column") metadata_column = None @@ -335,8 +349,11 @@ def _run(self, query: str) -> str: fetch_columns_sql = ", ".join(fetch_columns) row_sql = f"SELECT {fetch_columns_sql} FROM {table_name} WHERE rowid = :1" # noqa: S608 - with oracle_connection_context(self.client, **self._connection_kwargs()) as connection: + with oracle_connection_context( + self.client, **self._connection_kwargs() + ) as connection: with connection.cursor() as cursor: + # Oracle expects a JSON string here rather than a bound Python dict. 
cursor.execute( "SELECT DBMS_HYBRID_VECTOR.SEARCH(json(:search_params))", search_params=json.dumps(search_params), @@ -346,6 +363,7 @@ def _run(self, query: str) -> str: return self._result_json([]) raw_payload = raw[0][0] if hasattr(raw_payload, "read"): + # Some Oracle drivers return the JSON document as a LOB. raw_payload = raw_payload.read() rowids = json.loads(raw_payload) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index f563e931023..9a6894afd61 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -52,7 +52,12 @@ def create_oci_client_kwargs( auth_file_location: str, auth_profile: str, ) -> dict[str, Any]: - """Create authenticated OCI client kwargs.""" + """Build OCI SDK client kwargs for the supported auth modes. + + The native provider is used from both sync and thread-offloaded async paths, + so we centralize client construction here instead of duplicating auth logic + across `call`, `acall`, and streaming code paths. + """ oci = _get_oci_module() client_kwargs: dict[str, Any] = { "config": {}, @@ -229,6 +234,12 @@ def _message_has_multimodal_content(self, content: Any) -> bool: return False def _build_generic_content(self, content: Any) -> list[Any]: + """Translate CrewAI message content into OCI generic content objects. + + CrewAI accepts OpenAI-style multimodal payloads. OCI expects strongly + typed SDK content objects, so this method is the normalization boundary + between the two representations. 
+ """ models = self._oci.generative_ai_inference.models if isinstance(content, str): return [models.TextContent(text=content or ".")] @@ -303,6 +314,7 @@ def _build_generic_content(self, content: Any) -> list[Any]: return processed_content or [models.TextContent(text=".")] def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: + """Map CrewAI conversation messages into OCI generic chat messages.""" models = self._oci.generative_ai_inference.models role_map = { "user": models.UserMessage, @@ -349,6 +361,9 @@ def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: self._tool_result_guidance_enabled() and any(str(message.get("role", "")).lower() == "tool" for message in messages) ): + # OCI generic models do not automatically know that the tool phase has + # ended. Appending a final system hint keeps the model focused on + # synthesizing the tool results instead of trying to emit more tool JSON. oci_messages.append( models.SystemMessage( content=[models.TextContent(text=_OCI_TOOL_RESULT_GUIDANCE)] @@ -357,7 +372,15 @@ def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: return oci_messages - def _build_cohere_chat_history(self, messages: list[LLMMessage]) -> tuple[list[Any], list[Any] | None, str]: + def _build_cohere_chat_history( + self, messages: list[LLMMessage] + ) -> tuple[list[Any], list[Any] | None, str]: + """Translate CrewAI messages into Cohere's split history/tool-results shape. + + OCI's Cohere API does not accept the same unified message structure as + OCI's generic chat API. Tool outputs for the latest turn are provided via + `tool_results`, while older turns remain in `chat_history`. 
+ """ models = self._oci.generative_ai_inference.models chat_history: list[Any] = [] @@ -595,6 +618,7 @@ def _build_chat_request( *, is_stream: bool = False, ) -> Any: + """Build the provider-specific OCI chat request for the current model.""" models = self._oci.generative_ai_inference.models if self.oci_provider == "cohere": @@ -694,6 +718,7 @@ def _extract_text(self, response: Any) -> str: return "".join(part.text for part in content if getattr(part, "text", None)) def _extract_tool_calls(self, response: Any) -> list[dict[str, Any]]: + """Normalize provider-specific tool calls back into CrewAI's shape.""" chat_response = response.data.chat_response raw_tool_calls: list[Any] = [] if self.oci_provider == "cohere": @@ -779,6 +804,12 @@ def _extract_response_metadata(self, response: Any) -> dict[str, Any]: return metadata def _parse_stream_event(self, event: Any) -> dict[str, Any]: + """Convert OCI SSE event payloads into plain dicts. + + The SDK surfaces event payloads as strings or mapping-like objects + depending on provider/model family, so the streaming parser works against + a single normalized representation. + """ event_data = getattr(event, "data", None) if not event_data: return {} @@ -941,6 +972,12 @@ def _handle_tool_calls( response_model: type[BaseModel] | None, tool_calls: list[dict[str, Any]], ) -> str | BaseModel | list[dict[str, Any]]: + """Execute one round of tool calls and recurse until the model finishes. + + OCI returns native tool-call payloads, but CrewAI owns the actual tool + execution loop. We append assistant/tool messages back into the transcript + so the next OCI call sees the full conversation state. + """ if tool_calls and not available_functions: self._emit_call_completed_event( response=tool_calls, @@ -1112,6 +1149,11 @@ def _stream_call_impl( tool_depth: int, response_model: type[BaseModel] | None, ) -> str | BaseModel | list[dict[str, Any]]: + """Handle OCI streaming while reconstructing final text/tool state. 
+ + OCI streams partial tool-call fragments, so we accumulate them by index + and only hand them to CrewAI once the stream completes. + """ normalized_messages = self._normalize_messages(messages) chat_request = self._build_chat_request( normalized_messages, @@ -1246,6 +1288,12 @@ def iter_stream( from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> Any: + """Yield raw text chunks from OCI without triggering tool recursion. + + This is the lowest-level public streaming primitive for the provider. + `astream()` wraps it for async callers, while `call(stream=True)` uses the + higher-level `_stream_call_impl()` path that also handles tool calls. + """ normalized_messages = self._normalize_messages(messages) chat_request = self._build_chat_request( normalized_messages, @@ -1291,6 +1339,7 @@ async def astream( from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> Any: + """Expose the sync OCI SSE stream through an async generator facade.""" loop = asyncio.get_running_loop() queue: asyncio.Queue[str | None] = asyncio.Queue() error_holder: list[BaseException] = [] diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py index e3b4c052a2c..609e592126f 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py @@ -37,7 +37,7 @@ def create_oci_client_kwargs( auth_profile: str, timeout: tuple[int, int], ) -> dict[str, Any]: - """Create authenticated OCI client kwargs.""" + """Build OCI SDK client kwargs for embedding requests.""" oci = _get_oci_module() client_kwargs: dict[str, Any] = { "config": {}, @@ -143,6 +143,7 @@ def get_config(self) -> dict[str, Any]: return config def _get_serving_mode(self) -> Any: + """Resolve either an on-demand model id or a dedicated endpoint id.""" oci = _get_oci_module() 
model_name = self._config.get("model_name") if not model_name: @@ -160,6 +161,11 @@ def _get_serving_mode(self) -> Any: def _build_request( self, inputs: list[str], *, input_type: str | None = None ) -> Any: + """Build a single OCI embedding request payload. + + The same endpoint handles text and image embeddings. `input_type` lets + callers override the default text mode when building multimodal requests. + """ oci = _get_oci_module() compartment_id = self._config.get("compartment_id") or os.getenv( "OCI_COMPARTMENT_ID" @@ -193,12 +199,14 @@ def _build_request( return oci.generative_ai_inference.models.EmbedTextDetails(**request_kwargs) def _batch_inputs(self, input: list[str]) -> Iterator[list[str]]: + """Chunk large embedding jobs to stay within OCI request limits.""" batch_size = self._config.get("batch_size", 96) for index in range(0, len(input), batch_size): yield input[index : index + batch_size] @staticmethod def _to_data_uri(image: str | bytes | Path, mime_type: str = "image/png") -> str: + """Normalize image inputs into the data-URI form OCI expects.""" if isinstance(image, Path): resolved_mime = mimetypes.guess_type(image.name)[0] or mime_type data = image.read_bytes() @@ -245,6 +253,7 @@ def embed_image_batch( *, mime_type: str = "image/png", ) -> Embeddings: + """Embed one or more images through OCI's IMAGE input mode.""" embeddings: Embeddings = [] for image in images: data_uri = self._to_data_uri(image, mime_type=mime_type) From a33a7a4a41a7ed06e511e1ef21ef58fb1ee9f06e Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sat, 14 Mar 2026 23:19:49 -0400 Subject: [PATCH 4/9] Clean OCI examples and test defaults --- docs/en/concepts/llms.mdx | 13 +++++-------- .../cloud-storage/ociknowledgebasetool.mdx | 8 ++++---- .../tools/integration/ociinvokeagenttool.mdx | 2 +- .../src/crewai_tools/oci/common.py | 2 +- .../oci/knowledge_base/retriever_tool.py | 2 +- .../tests/tools/oracle_db/conftest.py | 18 +++--------------- 
.../crewai/llms/providers/oci/completion.py | 2 +- .../providers/oci/embedding_callable.py | 2 +- .../tests/rag/embeddings/test_factory_oci.py | 6 +++--- 9 files changed, 20 insertions(+), 35 deletions(-) diff --git a/docs/en/concepts/llms.mdx b/docs/en/concepts/llms.mdx index 40efc86605a..58f98fd016c 100644 --- a/docs/en/concepts/llms.mdx +++ b/docs/en/concepts/llms.mdx @@ -781,18 +781,15 @@ In this section, you'll find detailed examples that help you select, configure, - `cohere.command-a-vision` for the top Cohere multimodal tier once OCI Cohere vision formatting is enabled in CrewAI **Recommended Regions for Cohere in OCI:** - - `eu-frankfurt-1` - - `us-ashburn-1` - - `eu-paris-1` - - `uk-london-1` - - `ap-mumbai-1` + - choose any subscribed OCI region where the target model is available + - common examples include `us-chicago-1`, `us-ashburn-1`, `uk-london-1`, and `eu-paris-1` ```toml Code # Required OCI_COMPARTMENT_ID=ocid1.compartment.oc1..exampleuniqueID # Optional when not passing service_endpoint directly - OCI_REGION=eu-frankfurt-1 + OCI_REGION= # Authentication options OCI_AUTH_TYPE=API_KEY @@ -800,7 +797,7 @@ In this section, you'll find detailed examples that help you select, configure, OCI_AUTH_FILE_LOCATION=~/.oci/config # Optional explicit endpoint override - OCI_SERVICE_ENDPOINT=https://inference.generativeai.eu-frankfurt-1.oci.oraclecloud.com + OCI_SERVICE_ENDPOINT=https://inference.generativeai..oci.oraclecloud.com ``` **Basic Usage:** @@ -950,7 +947,7 @@ In this section, you'll find detailed examples that help you select, configure, from crewai import LLM llm = LLM( - model="oci/ocid1.generativeaiendpoint.oc1.eu-frankfurt-1.exampleuniqueID", + model="oci/ocid1.generativeaiendpoint.oc1..exampleuniqueID", compartment_id="ocid1.compartment.oc1..exampleuniqueID", ) ``` diff --git a/docs/en/tools/cloud-storage/ociknowledgebasetool.mdx b/docs/en/tools/cloud-storage/ociknowledgebasetool.mdx index ca378b4afda..fe4f96abe9e 100644 --- 
a/docs/en/tools/cloud-storage/ociknowledgebasetool.mdx +++ b/docs/en/tools/cloud-storage/ociknowledgebasetool.mdx @@ -27,7 +27,7 @@ from crewai_tools import OCIKnowledgeBaseTool kb_tool = OCIKnowledgeBaseTool( knowledge_source="./oracle-architecture.pdf", compartment_id="ocid1.compartment.oc1..exampleuniqueID", - region="eu-frankfurt-1", + region="", ) agent = Agent( @@ -45,7 +45,7 @@ from crewai_tools import OCIKnowledgeBaseTool kb_tool = OCIKnowledgeBaseTool( compartment_id="ocid1.compartment.oc1..exampleuniqueID", - region="eu-frankfurt-1", + region="", ) kb_tool.add("./runbooks/networking.md") @@ -64,7 +64,7 @@ The tool defaults to this embedding configuration: "config": { "model_name": "cohere.embed-english-v3.0", "compartment_id": "ocid1.compartment.oc1..exampleuniqueID", - "region": "eu-frankfurt-1", + "region": "", "auth_type": "API_KEY", "auth_profile": "DEFAULT", "auth_file_location": "~/.oci/config", @@ -96,7 +96,7 @@ kb_tool = OCIKnowledgeBaseTool( ```bash OCI_COMPARTMENT_ID=ocid1.compartment.oc1..exampleuniqueID -OCI_REGION=eu-frankfurt-1 +OCI_REGION= OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=DEFAULT OCI_AUTH_FILE_LOCATION=~/.oci/config diff --git a/docs/en/tools/integration/ociinvokeagenttool.mdx b/docs/en/tools/integration/ociinvokeagenttool.mdx index a58b3fd0675..a1ba9ee7560 100644 --- a/docs/en/tools/integration/ociinvokeagenttool.mdx +++ b/docs/en/tools/integration/ociinvokeagenttool.mdx @@ -39,5 +39,5 @@ OCI_AGENT_ENDPOINT_ID=ocid1.genaiagentendpoint.oc1..exampleuniqueID OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=DEFAULT OCI_AUTH_FILE_LOCATION=~/.oci/config -OCI_AGENT_RUNTIME_ENDPOINT=https://agent-runtime.generativeai.eu-frankfurt-1.oci.oraclecloud.com +OCI_AGENT_RUNTIME_ENDPOINT=https://agent-runtime.generativeai..oci.oraclecloud.com ``` diff --git a/lib/crewai-tools/src/crewai_tools/oci/common.py b/lib/crewai-tools/src/crewai_tools/oci/common.py index c26b6afdd4a..784086bdb9d 100644 --- a/lib/crewai-tools/src/crewai_tools/oci/common.py +++ 
b/lib/crewai-tools/src/crewai_tools/oci/common.py @@ -4,7 +4,7 @@ from typing import Any -DEFAULT_OCI_REGION = "eu-frankfurt-1" +DEFAULT_OCI_REGION = "us-chicago-1" def get_oci_module() -> Any: diff --git a/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py b/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py index 68f4d1f6015..f22923608b4 100644 --- a/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py +++ b/lib/crewai-tools/src/crewai_tools/oci/knowledge_base/retriever_tool.py @@ -61,7 +61,7 @@ def __init__( str, compartment_id or os.getenv("OCI_COMPARTMENT_ID", ""), ), - "region": cast(str, region or os.getenv("OCI_REGION", "eu-frankfurt-1")), + "region": cast(str, region or os.getenv("OCI_REGION", "us-chicago-1")), "auth_type": auth_type, "auth_profile": cast( str, diff --git a/lib/crewai-tools/tests/tools/oracle_db/conftest.py b/lib/crewai-tools/tests/tools/oracle_db/conftest.py index 8d5ad1f828f..5025321574e 100644 --- a/lib/crewai-tools/tests/tools/oracle_db/conftest.py +++ b/lib/crewai-tools/tests/tools/oracle_db/conftest.py @@ -50,7 +50,6 @@ def oracle_connection_mock() -> MagicMock: def oracle_live_config() -> dict[str, Any]: default_wallet_dir = None for candidate in ( - os.path.expanduser("~/.oracle-wallet/deepresearch"), os.path.expanduser("~/.langchain-oracle-wallet"), os.path.expanduser("~/.oracle-wallet"), ): @@ -58,16 +57,6 @@ def oracle_live_config() -> dict[str, Any]: default_wallet_dir = candidate break - default_dsn = None - if default_wallet_dir: - wallet_name = os.path.basename(default_wallet_dir) - if wallet_name == "deepresearch": - default_dsn = "deepresearch_high" - elif wallet_name == ".langchain-oracle-wallet": - default_dsn = "deepresearch_high" - elif wallet_name == ".oracle-wallet": - default_dsn = "locuscheck_high" - password = os.getenv( "ORACLE_DB_PASSWORD", os.getenv("ORACLE_PASSWORD"), @@ -76,7 +65,7 @@ def oracle_live_config() -> dict[str, Any]: return { "user": 
os.getenv("ORACLE_DB_USER", os.getenv("ORACLE_USER", "ADMIN")), "password": password, - "dsn": os.getenv("ORACLE_DB_DSN", os.getenv("ORACLE_DSN", default_dsn)), + "dsn": os.getenv("ORACLE_DB_DSN", os.getenv("ORACLE_DSN")), "config_dir": os.getenv("ORACLE_DB_CONFIG_DIR", default_wallet_dir), "wallet_location": os.getenv("ORACLE_DB_WALLET_LOCATION", default_wallet_dir), "wallet_password": os.getenv( @@ -158,7 +147,7 @@ def oracle_live_vector_tool_kwargs(oracle_live_config: dict[str, Any]) -> dict[s "config": { "model_name": os.getenv("OCI_EMBED_MODEL_NAME", "cohere.embed-v4.0"), "compartment_id": os.getenv("OCI_COMPARTMENT_ID"), - "auth_profile": os.getenv("OCI_AUTH_PROFILE", "API_KEY_AUTH"), + "auth_profile": os.getenv("OCI_AUTH_PROFILE", "DEFAULT"), "auth_file_location": os.getenv( "OCI_AUTH_FILE_LOCATION", os.path.expanduser("~/.oci/config") ), @@ -310,8 +299,7 @@ def oracle_hybrid_live_resources(oracle_live_connection): try: # Oracle Hybrid Vector Index creation uses the database-side vectorizer # registered in DBMS_VECTOR_CHAIN. That is distinct from OCI GenAI - # embedding models available via API_KEY_AUTH (for example - # cohere.embed-v4.0 in us-chicago-1). + # embedding models available through the configured OCI profile. # # On Autonomous Database 26ai, this hybrid-index path depends on a # supported in-database embedding/vectorizer model being installed in diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 9a6894afd61..cef6620d409 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -25,7 +25,7 @@ CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" -DEFAULT_OCI_REGION = "eu-frankfurt-1" +DEFAULT_OCI_REGION = "us-chicago-1" _OCI_SCHEMA_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") _OCI_TOOL_RESULT_GUIDANCE = ( "You have received tool results above. 
Respond to the user with a helpful, " diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py index 609e592126f..623d5844c6c 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py @@ -16,7 +16,7 @@ CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" -DEFAULT_OCI_REGION = "eu-frankfurt-1" +DEFAULT_OCI_REGION = "us-chicago-1" def _get_oci_module() -> Any: diff --git a/lib/crewai/tests/rag/embeddings/test_factory_oci.py b/lib/crewai/tests/rag/embeddings/test_factory_oci.py index a47fab3f3bd..fe10f56084b 100644 --- a/lib/crewai/tests/rag/embeddings/test_factory_oci.py +++ b/lib/crewai/tests/rag/embeddings/test_factory_oci.py @@ -66,7 +66,7 @@ def test_build_embedder_oci(mock_import): "config": { "model_name": "cohere.embed-english-v3.0", "compartment_id": "ocid1.compartment.oc1..test", - "region": "eu-frankfurt-1", + "region": "us-chicago-1", "auth_profile": "DEFAULT", }, } @@ -79,7 +79,7 @@ def test_build_embedder_oci(mock_import): call_kwargs = mock_provider_class.call_args.kwargs assert call_kwargs["model_name"] == "cohere.embed-english-v3.0" assert call_kwargs["compartment_id"] == "ocid1.compartment.oc1..test" - assert call_kwargs["region"] == "eu-frankfurt-1" + assert call_kwargs["region"] == "us-chicago-1" def test_oci_embedding_function_batches_requests(monkeypatch): @@ -102,7 +102,7 @@ def test_oci_embedding_function_batches_requests(monkeypatch): embedder = OCIEmbeddingFunction( model_name="cohere.embed-english-v3.0", compartment_id="ocid1.compartment.oc1..test", - region="eu-frankfurt-1", + region="us-chicago-1", batch_size=2, ) From 072c276b8d7fac5e0f6496ce15efc0f4b2fc624d Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sat, 14 Mar 2026 23:28:42 -0400 Subject: [PATCH 5/9] Resolve OCI provider review feedback --- 
.../crewai/llms/providers/oci/completion.py | 113 +++++------------- .../providers/oci/embedding_callable.py | 71 ++--------- lib/crewai/src/crewai/utilities/oci.py | 72 +++++++++++ lib/crewai/tests/llms/oci/test_oci.py | 44 +++++++ 4 files changed, 154 insertions(+), 146 deletions(-) create mode 100644 lib/crewai/src/crewai/utilities/oci.py diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index cef6620d409..3c94feca6e5 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -14,6 +14,7 @@ from crewai.events.types.llm_events import LLMCallType from crewai.llms.base_llm import BaseLLM, llm_call_context +from crewai.utilities.oci import create_oci_client_kwargs, get_oci_module from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.types import LLMMessage @@ -36,75 +37,8 @@ def _get_oci_module() -> Any: - try: - import oci # type: ignore[import-untyped] - except ImportError: - raise ImportError( - 'OCI native provider not available, to install: uv add "crewai[oci]"' - ) from None - return oci - - -def create_oci_client_kwargs( - *, - auth_type: str, - service_endpoint: str | None, - auth_file_location: str, - auth_profile: str, -) -> dict[str, Any]: - """Build OCI SDK client kwargs for the supported auth modes. - - The native provider is used from both sync and thread-offloaded async paths, - so we centralize client construction here instead of duplicating auth logic - across `call`, `acall`, and streaming code paths. 
- """ - oci = _get_oci_module() - client_kwargs: dict[str, Any] = { - "config": {}, - "service_endpoint": service_endpoint, - "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, - "timeout": (10, 240), - } - - auth_type_upper = auth_type.upper() - - if auth_type_upper == "API_KEY": - client_kwargs["config"] = oci.config.from_file( - file_location=auth_file_location, - profile_name=auth_profile, - ) - elif auth_type_upper == "SECURITY_TOKEN": - config = oci.config.from_file( - file_location=auth_file_location, - profile_name=auth_profile, - ) - key_file = config["key_file"] - security_token_file = config["security_token_file"] - private_key = oci.signer.load_private_key_from_file(key_file, None) - with open(security_token_file, encoding="utf-8") as file: - security_token = file.read() - client_kwargs["config"] = config - client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner( - security_token, private_key - ) - elif auth_type_upper == "INSTANCE_PRINCIPAL": - client_kwargs["signer"] = ( - oci.auth.signers.InstancePrincipalsSecurityTokenSigner() - ) - elif auth_type_upper == "RESOURCE_PRINCIPAL": - client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() - else: - valid_types = [ - "API_KEY", - "SECURITY_TOKEN", - "INSTANCE_PRINCIPAL", - "RESOURCE_PRINCIPAL", - ] - raise ValueError( - f"Invalid OCI auth_type '{auth_type}'. 
Valid values: {valid_types}" - ) - - return client_kwargs + """Backward-compatible module-local alias used by tests and patches.""" + return get_oci_module() class OCICompletion(BaseLLM): @@ -181,6 +115,8 @@ def __init__( service_endpoint=self.service_endpoint, auth_file_location=self.auth_file_location, auth_profile=self.auth_profile, + timeout=(10, 240), + oci_module=self._oci, ) self.client = self._oci.generative_ai_inference.GenerativeAiInferenceClient( **client_kwargs @@ -383,8 +319,17 @@ def _build_cohere_chat_history( """ models = self._oci.generative_ai_inference.models chat_history: list[Any] = [] + trailing_tool_count = 0 + for message in reversed(messages): + if str(message.get("role", "")).lower() != "tool": + break + trailing_tool_count += 1 - for message in messages[:-1]: + history_messages = ( + messages[:-trailing_tool_count] if trailing_tool_count else messages[:-1] + ) + + for message in history_messages: role = str(message.get("role", "user")).lower() content = message.get("content", "") if self._message_has_multimodal_content(content): @@ -472,7 +417,7 @@ def _build_cohere_chat_history( "parameters": parameters, } - for message in messages: + for message in messages[-trailing_tool_count:]: if str(message.get("role", "")).lower() != "tool": continue tool_call_id = message.get("tool_call_id") @@ -1464,18 +1409,20 @@ async def abatch( from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> list[str | Any]: - return [ - await self.acall( - messages=messages, - tools=tools, - callbacks=callbacks, - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, - response_model=response_model, - ) - for messages in messages_batch - ] + return await asyncio.gather( + *[ + self.acall( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + for messages in 
messages_batch + ] + ) def _chat(self, chat_details: Any) -> Any: # The OCI SDK client is shared across sync + thread-offloaded async calls. diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py index 623d5844c6c..8a348eb0224 100644 --- a/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py @@ -13,6 +13,7 @@ from typing_extensions import Unpack from crewai.rag.embeddings.providers.oci.types import OCIProviderConfig +from crewai.utilities.oci import create_oci_client_kwargs, get_oci_module CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" @@ -20,68 +21,8 @@ def _get_oci_module() -> Any: - try: - import oci # type: ignore[import-untyped] - except ImportError as e: - raise ImportError( - "oci is required for OCI embeddings. Install it with: uv add 'crewai[oci]'" - ) from e - return oci - - -def create_oci_client_kwargs( - *, - auth_type: str, - service_endpoint: str | None, - auth_file_location: str, - auth_profile: str, - timeout: tuple[int, int], -) -> dict[str, Any]: - """Build OCI SDK client kwargs for embedding requests.""" - oci = _get_oci_module() - client_kwargs: dict[str, Any] = { - "config": {}, - "service_endpoint": service_endpoint, - "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, - "timeout": timeout, - } - - auth_type_upper = auth_type.upper() - if auth_type_upper == "API_KEY": - client_kwargs["config"] = oci.config.from_file( - file_location=auth_file_location, - profile_name=auth_profile, - ) - elif auth_type_upper == "SECURITY_TOKEN": - config = oci.config.from_file( - file_location=auth_file_location, - profile_name=auth_profile, - ) - private_key = oci.signer.load_private_key_from_file(config["key_file"], None) - with open(config["security_token_file"], encoding="utf-8") as file: - security_token = file.read() - client_kwargs["config"] = 
config - client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner( - security_token, private_key - ) - elif auth_type_upper == "INSTANCE_PRINCIPAL": - client_kwargs["signer"] = ( - oci.auth.signers.InstancePrincipalsSecurityTokenSigner() - ) - elif auth_type_upper == "RESOURCE_PRINCIPAL": - client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() - else: - valid_types = [ - "API_KEY", - "SECURITY_TOKEN", - "INSTANCE_PRINCIPAL", - "RESOURCE_PRINCIPAL", - ] - raise ValueError( - f"Invalid OCI auth_type '{auth_type}'. Valid values: {valid_types}" - ) - - return client_kwargs + """Backward-compatible module-local alias used by tests and patches.""" + return get_oci_module() class OCIEmbeddingFunction(EmbeddingFunction[Documents]): @@ -104,6 +45,7 @@ def __init__(self, **kwargs: Unpack[OCIProviderConfig]) -> None: auth_file_location=kwargs.get("auth_file_location", "~/.oci/config"), auth_profile=kwargs.get("auth_profile", "DEFAULT"), timeout=kwargs.get("timeout", (10, 120)), + oci_module=_get_oci_module(), ) self._client = ( _get_oci_module().generative_ai_inference.GenerativeAiInferenceClient( @@ -245,7 +187,10 @@ def embed_image( *, mime_type: str = "image/png", ) -> list[float]: - return [float(value) for value in self.embed_image_batch([image], mime_type=mime_type)[0]] + return [ + float(value) + for value in self.embed_image_batch([image], mime_type=mime_type)[0] + ] def embed_image_batch( self, diff --git a/lib/crewai/src/crewai/utilities/oci.py b/lib/crewai/src/crewai/utilities/oci.py new file mode 100644 index 00000000000..7935530690c --- /dev/null +++ b/lib/crewai/src/crewai/utilities/oci.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any + + +def get_oci_module() -> Any: + """Import the OCI SDK lazily for optional CrewAI OCI integrations.""" + try: + import oci # type: ignore[import-untyped] + except ImportError: + raise ImportError( + 'OCI support is not available, to install: uv add "crewai[oci]"' + ) 
from None + return oci + + +def create_oci_client_kwargs( + *, + auth_type: str, + service_endpoint: str | None, + auth_file_location: str, + auth_profile: str, + timeout: tuple[int, int], + oci_module: Any | None = None, +) -> dict[str, Any]: + """Build OCI SDK client kwargs for the supported auth modes.""" + oci = oci_module or get_oci_module() + client_kwargs: dict[str, Any] = { + "config": {}, + "service_endpoint": service_endpoint, + "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, + "timeout": timeout, + } + + auth_type_upper = auth_type.upper() + if auth_type_upper == "API_KEY": + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + elif auth_type_upper == "SECURITY_TOKEN": + config = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + key_file = config["key_file"] + security_token_file = config["security_token_file"] + private_key = oci.signer.load_private_key_from_file(key_file, None) + with open(security_token_file, encoding="utf-8") as file: + security_token = file.read() + client_kwargs["config"] = config + client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner( + security_token, private_key + ) + elif auth_type_upper == "INSTANCE_PRINCIPAL": + client_kwargs["signer"] = ( + oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + ) + elif auth_type_upper == "RESOURCE_PRINCIPAL": + client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() + else: + valid_types = [ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + raise ValueError( + f"Invalid OCI auth_type '{auth_type}'. 
Valid values: {valid_types}" + ) + + return client_kwargs diff --git a/lib/crewai/tests/llms/oci/test_oci.py b/lib/crewai/tests/llms/oci/test_oci.py index 066616127a1..7bbec09d741 100644 --- a/lib/crewai/tests/llms/oci/test_oci.py +++ b/lib/crewai/tests/llms/oci/test_oci.py @@ -376,6 +376,50 @@ def test_oci_cohere_completion_formats_tool_calls( } +def test_oci_cohere_request_excludes_trailing_tool_messages_from_chat_history( + patch_oci_module, oci_unit_values: dict[str, object] +): + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + MagicMock() + ) + + llm = OCICompletion( + model=str(oci_unit_values["cohere_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + chat_history, tool_results, message_text = llm._build_cohere_chat_history( + [ + {"role": "user", "content": "First question"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "function": {"name": "lookup_a", "arguments": '{"query":"a"}'}, + }, + { + "id": "call_2", + "function": {"name": "lookup_b", "arguments": '{"query":"b"}'}, + }, + ], + }, + {"role": "tool", "tool_call_id": "call_1", "name": "lookup_a", "content": "A"}, + {"role": "tool", "tool_call_id": "call_2", "name": "lookup_b", "content": "B"}, + ] + ) + + assert len(chat_history) == 2 + assert chat_history[0].message == "First question" + assert len(chat_history[1].tool_calls) == 2 + assert tool_results is not None + assert len(tool_results) == 2 + assert tool_results[0].call.name == "lookup_a" + assert tool_results[1].call.name == "lookup_b" + assert message_text == "" + + def test_oci_completion_returns_tool_calls_for_executor( patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object] ): From 3a61c6495a5980d98290d2c4b44671784f32d84d Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sun, 15 Mar 2026 00:36:07 -0400 Subject: [PATCH 6/9] Address remaining OCI provider review comments --- 
.../crewai/llms/providers/oci/completion.py | 35 +++---------- lib/crewai/tests/llms/oci/test_oci.py | 52 +++++++++++++++++++ 2 files changed, 60 insertions(+), 27 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 3c94feca6e5..6c87ec734eb 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -146,6 +146,8 @@ def _normalize_messages( return self._format_messages(messages) def _coerce_text(self, content: Any) -> str: + if content is None: + return "" if isinstance(content, str): return content if isinstance(content, list): @@ -1034,7 +1036,9 @@ def _call_impl( tool_depth: int, response_model: type[BaseModel] | None, ) -> str | BaseModel | list[dict[str, Any]]: - normalized_messages = self._normalize_messages(messages) + normalized_messages = ( + messages if isinstance(messages, list) else self._normalize_messages(messages) + ) chat_request = self._build_chat_request( normalized_messages, tools=tools, @@ -1099,7 +1103,9 @@ def _stream_call_impl( OCI streams partial tool-call fragments, so we accumulate them by index and only hand them to CrewAI once the stream completes. 
""" - normalized_messages = self._normalize_messages(messages) + normalized_messages = ( + messages if isinstance(messages, list) else self._normalize_messages(messages) + ) chat_request = self._build_chat_request( normalized_messages, tools=tools, @@ -1399,31 +1405,6 @@ async def acall( response_model=response_model, ) - async def abatch( - self, - messages_batch: list[str | list[LLMMessage]], - tools: list[dict[str, BaseTool]] | None = None, - callbacks: list[Any] | None = None, - available_functions: dict[str, Any] | None = None, - from_task: Task | None = None, - from_agent: Agent | None = None, - response_model: type[BaseModel] | None = None, - ) -> list[str | Any]: - return await asyncio.gather( - *[ - self.acall( - messages=messages, - tools=tools, - callbacks=callbacks, - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, - response_model=response_model, - ) - for messages in messages_batch - ] - ) - def _chat(self, chat_details: Any) -> Any: # The OCI SDK client is shared across sync + thread-offloaded async calls. # Serialize access so sync/async calls cannot race on the same client. diff --git a/lib/crewai/tests/llms/oci/test_oci.py b/lib/crewai/tests/llms/oci/test_oci.py index 7bbec09d741..2377ba8e5c8 100644 --- a/lib/crewai/tests/llms/oci/test_oci.py +++ b/lib/crewai/tests/llms/oci/test_oci.py @@ -103,6 +103,58 @@ def test_oci_completion_call_uses_chat_api( ) +def test_oci_completion_treats_none_content_as_empty_text( + patch_oci_module, oci_unit_values: dict[str, object] +): + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + MagicMock() + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + content = llm._build_generic_content(None) + + assert content[0].text == "." 
+ + +def test_oci_completion_call_normalizes_messages_once( + patch_oci_module, + oci_response_factories, + oci_unit_values: dict[str, object], + monkeypatch: pytest.MonkeyPatch, +): + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]("Hello from OCI") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + normalize_call_count = 0 + original_normalize_messages = llm._normalize_messages + + def counting_normalize(messages): + nonlocal normalize_call_count + normalize_call_count += 1 + return original_normalize_messages(messages) + + monkeypatch.setattr(llm, "_normalize_messages", counting_normalize) + + result = llm.call( + [{"role": "user", "content": str(oci_unit_values["chat_prompt"])}] + ) + + assert result == "Hello from OCI" + assert normalize_call_count == 1 + + def test_oci_completion_uses_region_to_build_endpoint( monkeypatch: pytest.MonkeyPatch, patch_oci_module, From ba205495253996b39ea73f7753ad01a9889aa707 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sun, 15 Mar 2026 01:00:45 -0400 Subject: [PATCH 7/9] Address additional OCI review comments --- .../src/crewai_tools/oci/common.py | 60 +++-------- .../crewai/llms/providers/oci/completion.py | 64 +++++++++-- lib/crewai/tests/llms/oci/test_oci.py | 100 ++++++++++++++++++ 3 files changed, 168 insertions(+), 56 deletions(-) diff --git a/lib/crewai-tools/src/crewai_tools/oci/common.py b/lib/crewai-tools/src/crewai_tools/oci/common.py index 784086bdb9d..e95bcd8166b 100644 --- a/lib/crewai-tools/src/crewai_tools/oci/common.py +++ b/lib/crewai-tools/src/crewai_tools/oci/common.py @@ -3,6 +3,11 @@ import os from typing import Any +from crewai.utilities.oci import ( + create_oci_client_kwargs as shared_create_oci_client_kwargs, + get_oci_module as shared_get_oci_module, +) + 
DEFAULT_OCI_REGION = "us-chicago-1" @@ -10,13 +15,12 @@ def get_oci_module() -> Any: """Import the OCI SDK lazily so optional dependencies stay optional.""" try: - import oci # type: ignore[import-untyped] + return shared_get_oci_module() except ImportError: raise ImportError( "`oci` package not found, please install the optional dependency with " "`uv add 'crewai-tools[oci]'`" ) from None - return oci def create_oci_client_kwargs( @@ -28,50 +32,14 @@ def create_oci_client_kwargs( timeout: tuple[int, int] = (10, 120), ) -> dict[str, Any]: """Build standard OCI SDK client kwargs shared by the tool integrations.""" - oci = get_oci_module() - client_kwargs: dict[str, Any] = { - "config": {}, - "service_endpoint": service_endpoint, - "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, - "timeout": timeout, - } - - auth_type_upper = auth_type.upper() - if auth_type_upper == "API_KEY": - client_kwargs["config"] = oci.config.from_file( - file_location=auth_file_location, - profile_name=auth_profile, - ) - elif auth_type_upper == "SECURITY_TOKEN": - config = oci.config.from_file( - file_location=auth_file_location, - profile_name=auth_profile, - ) - private_key = oci.signer.load_private_key_from_file(config["key_file"], None) - with open(config["security_token_file"], encoding="utf-8") as file: - security_token = file.read() - client_kwargs["config"] = config - client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner( - security_token, private_key - ) - elif auth_type_upper == "INSTANCE_PRINCIPAL": - client_kwargs["signer"] = ( - oci.auth.signers.InstancePrincipalsSecurityTokenSigner() - ) - elif auth_type_upper == "RESOURCE_PRINCIPAL": - client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() - else: - valid_types = [ - "API_KEY", - "SECURITY_TOKEN", - "INSTANCE_PRINCIPAL", - "RESOURCE_PRINCIPAL", - ] - raise ValueError( - f"Invalid OCI auth_type '{auth_type}'. 
Valid values: {valid_types}" - ) - - return client_kwargs + return shared_create_oci_client_kwargs( + auth_type=auth_type, + auth_profile=auth_profile, + auth_file_location=auth_file_location, + service_endpoint=service_endpoint, + timeout=timeout, + oci_module=get_oci_module(), + ) def parse_object_storage_path(file_path: str) -> tuple[str | None, str, str]: diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 6c87ec734eb..bfc77f999d9 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -2,6 +2,7 @@ import asyncio from collections.abc import Mapping +import inspect import json import logging import os @@ -34,6 +35,11 @@ "raw JSON or tool call syntax. If you need additional information, you may " "call another tool." ) +_OCI_RESERVED_REQUEST_KWARGS = { + "tool_choice", + "parallel_tool_calls", + "tool_result_guidance", +} def _get_oci_module() -> Any: @@ -557,6 +563,23 @@ def _build_tool_choice(self) -> Any | None: "Unrecognized OCI tool_choice. Expected str, bool, or function mapping." 
) + def _allowed_passthrough_request_keys(self, request_cls: type[Any]) -> set[str]: + """Return request attributes that can safely be forwarded to the OCI SDK.""" + attribute_map = getattr(request_cls, "attribute_map", None) + if isinstance(attribute_map, Mapping): + return {str(key) for key in attribute_map} + + swagger_types = getattr(request_cls, "swagger_types", None) + if isinstance(swagger_types, Mapping): + return {str(key) for key in swagger_types} + + signature = inspect.signature(request_cls) + return { + name + for name, parameter in signature.parameters.items() + if name != "self" and parameter.kind is not inspect.Parameter.VAR_KEYWORD + } + def _build_chat_request( self, messages: list[LLMMessage], @@ -632,14 +655,29 @@ def _build_chat_request( is_include_usage=True ) - passthrough_params = dict(self.additional_params) - passthrough_params.pop("tool_choice", None) - passthrough_params.pop("parallel_tool_calls", None) - passthrough_params.pop("tool_result_guidance", None) - request_kwargs.update(passthrough_params) - if self.oci_provider == "cohere": + allowed_passthrough_keys = self._allowed_passthrough_request_keys( + models.CohereChatRequest + ) + passthrough_params = { + key: value + for key, value in self.additional_params.items() + if key not in _OCI_RESERVED_REQUEST_KWARGS + and key in allowed_passthrough_keys + } + request_kwargs.update(passthrough_params) return models.CohereChatRequest(**request_kwargs) + + allowed_passthrough_keys = self._allowed_passthrough_request_keys( + models.GenericChatRequest + ) + passthrough_params = { + key: value + for key, value in self.additional_params.items() + if key not in _OCI_RESERVED_REQUEST_KWARGS + and key in allowed_passthrough_keys + } + request_kwargs.update(passthrough_params) return models.GenericChatRequest(**request_kwargs) def _extract_text(self, response: Any) -> str: @@ -971,7 +1009,9 @@ def _handle_tool_calls( from_agent=from_agent, ) if tool_result is None: - continue + tool_result = ( + 
f"Tool '{function_name}' failed or returned no result." + ) next_messages.append( { @@ -1117,15 +1157,13 @@ def _stream_call_impl( serving_mode=self._build_serving_mode(), chat_request=chat_request, ) - response = self._chat(chat_details) - full_response = "" tool_calls_by_index: dict[int, dict[str, Any]] = {} usage_data: dict[str, int] = {} response_metadata: dict[str, Any] = {} response_id = uuid.uuid4().hex - for event in response.data.events(): + for event in self._stream_chat_events(chat_details): event_data = self._parse_stream_event(event) if not event_data: continue @@ -1411,6 +1449,12 @@ def _chat(self, chat_details: Any) -> Any: with self._client_lock: return self.client.chat(chat_details) + def _stream_chat_events(self, chat_details: Any) -> Any: + """Yield streaming events while holding the shared OCI client lock.""" + with self._client_lock: + response = self.client.chat(chat_details) + yield from response.data.events() + def supports_function_calling(self) -> bool: return True diff --git a/lib/crewai/tests/llms/oci/test_oci.py b/lib/crewai/tests/llms/oci/test_oci.py index 2377ba8e5c8..8ece6394823 100644 --- a/lib/crewai/tests/llms/oci/test_oci.py +++ b/lib/crewai/tests/llms/oci/test_oci.py @@ -155,6 +155,24 @@ def counting_normalize(messages): assert normalize_call_count == 1 +def test_oci_completion_filters_unknown_passthrough_params( + patch_oci_module, oci_unit_values: dict[str, object] +): + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + MagicMock() + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + unsupported_param="should-not-pass-through", + ) + + request = llm._build_chat_request([{"role": "user", "content": "hello"}]) + + assert not hasattr(request, "unsupported_param") + + def test_oci_completion_uses_region_to_build_endpoint( monkeypatch: pytest.MonkeyPatch, patch_oci_module, @@ -511,6 +529,54 @@ def 
test_oci_completion_returns_tool_calls_for_executor( assert result[0]["function"]["arguments"] == '{"city":"Paris"}' +def test_oci_completion_emits_tool_message_when_tool_execution_returns_none( + patch_oci_module, oci_unit_values: dict[str, object], monkeypatch: pytest.MonkeyPatch +): + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + MagicMock() + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_tool_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + + captured_messages: list[dict[str, object]] = [] + + def fake_call_impl(**kwargs): + captured_messages.extend(kwargs["messages"]) + return "done" + + monkeypatch.setattr(llm, "_call_impl", fake_call_impl) + monkeypatch.setattr(llm, "_handle_tool_execution", lambda **_kwargs: None) + + result = llm._handle_tool_calls( + normalized_messages=[{"role": "user", "content": "hi"}], + tools=None, + callbacks=None, + available_functions={"missing_tool": lambda: None}, + from_task=None, + from_agent=None, + tool_depth=0, + response_model=None, + tool_calls=[ + { + "id": "call_1", + "type": "function", + "function": {"name": "missing_tool", "arguments": "{}"}, + } + ], + ) + + assert result == "done" + assert captured_messages[-1] == { + "role": "tool", + "tool_call_id": "call_1", + "name": "missing_tool", + "content": "Tool 'missing_tool' failed or returned no result.", + } + + def test_oci_completion_supports_generic_tool_controls( patch_oci_module, oci_unit_values: dict[str, object] ): @@ -672,6 +738,40 @@ def test_oci_completion_executes_tool_calls_recursively( assert second_request.chat_request.messages[2].tool_call_id == "call_123" +def test_oci_stream_chat_events_holds_client_lock_while_iterating( + patch_oci_module, oci_unit_values: dict[str, object] +): + lock_states: list[bool] = [] + lock_ref: object | None = None + + def iter_events(): + assert lock_ref is not None + lock_states.append(lock_ref.locked()) + yield MagicMock() + 
lock_states.append(lock_ref.locked()) + yield MagicMock() + + fake_client = MagicMock() + fake_client.chat.return_value = MagicMock( + data=MagicMock(events=lambda: iter_events()) + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=str(oci_unit_values["generic_tool_model"]), + compartment_id=str(oci_unit_values["compartment_id"]), + ) + lock_ref = llm._client_lock + + events = list(llm._stream_chat_events(MagicMock())) + + assert len(events) == 2 + assert lock_states == [True, True] + assert llm._client_lock.locked() is False + + @pytest.mark.asyncio async def test_oci_completion_acall_delegates_to_call( patch_oci_module, oci_response_factories, oci_unit_values: dict[str, object] From f545ecca67661c6dd04dd7c64cca68ba9000e319 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sun, 15 Mar 2026 01:04:53 -0400 Subject: [PATCH 8/9] Preserve ordered OCI client access in batch and streams --- .../crewai/llms/providers/oci/completion.py | 25 +++++++++-- lib/crewai/tests/llms/oci/test_oci.py | 42 ++++++++++++++----- 2 files changed, 53 insertions(+), 14 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index bfc77f999d9..16678a8c6fb 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -2,6 +2,7 @@ import asyncio from collections.abc import Mapping +from contextlib import contextmanager import inspect import json import logging @@ -127,7 +128,9 @@ def __init__( self.client = self._oci.generative_ai_inference.GenerativeAiInferenceClient( **client_kwargs ) - self._client_lock = threading.Lock() + self._client_condition = threading.Condition() + self._next_client_ticket = 0 + self._active_client_ticket = 0 self.last_response_metadata = None def _infer_provider(self, model: str) -> str: @@ -1446,15 +1449,31 @@ async def 
acall( def _chat(self, chat_details: Any) -> Any: # The OCI SDK client is shared across sync + thread-offloaded async calls. # Serialize access so sync/async calls cannot race on the same client. - with self._client_lock: + with self._ordered_client_access(): return self.client.chat(chat_details) def _stream_chat_events(self, chat_details: Any) -> Any: """Yield streaming events while holding the shared OCI client lock.""" - with self._client_lock: + with self._ordered_client_access(): response = self.client.chat(chat_details) yield from response.data.events() + @contextmanager + def _ordered_client_access(self) -> Any: + """Serialize shared OCI client access in call-arrival order.""" + with self._client_condition: + ticket = self._next_client_ticket + self._next_client_ticket += 1 + while ticket != self._active_client_ticket: + self._client_condition.wait() + + try: + yield + finally: + with self._client_condition: + self._active_client_ticket += 1 + self._client_condition.notify_all() + def supports_function_calling(self) -> bool: return True diff --git a/lib/crewai/tests/llms/oci/test_oci.py b/lib/crewai/tests/llms/oci/test_oci.py index 8ece6394823..a3b44195830 100644 --- a/lib/crewai/tests/llms/oci/test_oci.py +++ b/lib/crewai/tests/llms/oci/test_oci.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import threading from unittest.mock import MagicMock import pytest @@ -741,20 +742,35 @@ def test_oci_completion_executes_tool_calls_recursively( def test_oci_stream_chat_events_holds_client_lock_while_iterating( patch_oci_module, oci_unit_values: dict[str, object] ): - lock_states: list[bool] = [] - lock_ref: object | None = None + call_sequence: list[str] = [] + second_call_attempted = threading.Event() + second_call_thread: threading.Thread | None = None def iter_events(): - assert lock_ref is not None - lock_states.append(lock_ref.locked()) + nonlocal second_call_thread + call_sequence.append("stream-event-1") + + def run_second_call() -> None: + 
second_call_attempted.set() + llm._chat(MagicMock()) + call_sequence.append("second-chat-complete") + + second_call_thread = threading.Thread(target=run_second_call) + second_call_thread.start() + yield MagicMock() - lock_states.append(lock_ref.locked()) + + assert second_call_attempted.wait(timeout=1) + assert "second-chat-complete" not in call_sequence + + call_sequence.append("stream-event-2") yield MagicMock() fake_client = MagicMock() - fake_client.chat.return_value = MagicMock( - data=MagicMock(events=lambda: iter_events()) - ) + fake_client.chat.side_effect = [ + MagicMock(data=MagicMock(events=lambda: iter_events())), + MagicMock(), + ] patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( fake_client ) @@ -763,13 +779,17 @@ def iter_events(): model=str(oci_unit_values["generic_tool_model"]), compartment_id=str(oci_unit_values["compartment_id"]), ) - lock_ref = llm._client_lock events = list(llm._stream_chat_events(MagicMock())) + assert second_call_thread is not None + second_call_thread.join(timeout=1) assert len(events) == 2 - assert lock_states == [True, True] - assert llm._client_lock.locked() is False + assert call_sequence == [ + "stream-event-1", + "stream-event-2", + "second-chat-complete", + ] @pytest.mark.asyncio From 9be59b9cea940e02e13ea3322cfa61b6d60020f7 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Sun, 15 Mar 2026 12:46:18 -0400 Subject: [PATCH 9/9] Remove deprecated RAG factory import helper --- lib/crewai/src/crewai/rag/factory.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/lib/crewai/src/crewai/rag/factory.py b/lib/crewai/src/crewai/rag/factory.py index 47fc6cb62d9..5ff54e56477 100644 --- a/lib/crewai/src/crewai/rag/factory.py +++ b/lib/crewai/src/crewai/rag/factory.py @@ -1,5 +1,6 @@ """Factory functions for creating RAG clients from configuration.""" +import importlib from typing import cast from crewai.rag.config.optional_imports.protocols import ( @@ 
-8,7 +9,18 @@ ) from crewai.rag.config.types import RagConfigType from crewai.rag.core.base_client import BaseClient -from crewai.utilities.import_utils import require + + +def _import_rag_factory(module_path: str, purpose: str) -> object: + """Import an optional RAG factory module with a clear install hint.""" + try: + return importlib.import_module(module_path) + except ImportError as exc: + package_name = module_path.split(".")[0] + raise ImportError( + f"{purpose} requires the optional dependency '{module_path}'.\n" + f"Install it with: uv add {package_name}" + ) from exc def create_client(config: RagConfigType) -> BaseClient: @@ -27,7 +39,7 @@ def create_client(config: RagConfigType) -> BaseClient: if config.provider == "chromadb": chromadb_mod = cast( ChromaFactoryModule, - require( + _import_rag_factory( "crewai.rag.chromadb.factory", purpose="The 'chromadb' provider", ), @@ -37,7 +49,7 @@ def create_client(config: RagConfigType) -> BaseClient: if config.provider == "qdrant": qdrant_mod = cast( QdrantFactoryModule, - require( + _import_rag_factory( "crewai.rag.qdrant.factory", purpose="The 'qdrant' provider", ),