From 829cc936ddf4df10bdefbd10385297f9f50afaa6 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 13:41:13 -0400 Subject: [PATCH 1/6] =?UTF-8?q?feat:=20add=20OCI=20Generative=20AI=20provi?= =?UTF-8?q?der=20=E2=80=94=20basic=20text=20completion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add native OCI Generative AI support to CrewAI with basic text completion for generic (Meta, Google, OpenAI, xAI) and Cohere model families. This is the first in a series of PRs to incrementally build out full OCI support (streaming, tool calling, structured output, embeddings, and multimodal in follow-up PRs). Tracking issue: #4944 Supersedes: #4885 --- lib/crewai/pyproject.toml | 3 + lib/crewai/src/crewai/llm.py | 13 + .../src/crewai/llms/providers/oci/__init__.py | 5 + .../crewai/llms/providers/oci/completion.py | 505 ++++++++++++++++++ lib/crewai/src/crewai/utilities/oci.py | 72 +++ lib/crewai/tests/llms/oci/__init__.py | 0 lib/crewai/tests/llms/oci/conftest.py | 189 +++++++ lib/crewai/tests/llms/oci/test_oci.py | 269 ++++++++++ .../llms/oci/test_oci_integration_basic.py | 33 ++ 9 files changed, 1089 insertions(+) create mode 100644 lib/crewai/src/crewai/llms/providers/oci/__init__.py create mode 100644 lib/crewai/src/crewai/llms/providers/oci/completion.py create mode 100644 lib/crewai/src/crewai/utilities/oci.py create mode 100644 lib/crewai/tests/llms/oci/__init__.py create mode 100644 lib/crewai/tests/llms/oci/conftest.py create mode 100644 lib/crewai/tests/llms/oci/test_oci.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_basic.py diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index a40484f0487..e3fce3fb46d 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -97,6 +97,9 @@ azure-ai-inference = [ anthropic = [ "anthropic~=0.73.0", ] +oci = [ + "oci>=2.168.0", +] a2a = [ "a2a-sdk~=0.3.10", "httpx-auth~=0.23.1", diff --git 
a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 75b1f654689..fcd0f7fa926 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -317,6 +317,7 @@ def writable(self) -> bool: "hosted_vllm", "cerebras", "dashscope", + "oci", ] @@ -384,6 +385,7 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: "hosted_vllm": "hosted_vllm", "cerebras": "cerebras", "dashscope": "dashscope", + "oci": "oci", } canonical_provider = provider_mapping.get(prefix.lower()) @@ -506,6 +508,9 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: # OpenRouter uses org/model format but accepts anything return True + if provider == "oci": + return model_lower.startswith("ocid1.generativeaiendpoint") or "." in model + return False @classmethod @@ -541,6 +546,9 @@ def _validate_model_in_constants(cls, model: str, provider: str) -> bool: # azure does not provide a list of available models, determine a better way to handle this return True + if provider == "oci": + return cls._matches_provider_pattern(model, provider) + # Fallback to pattern matching for models not in constants return cls._matches_provider_pattern(model, provider) @@ -622,6 +630,11 @@ def _get_native_provider(cls, provider: str) -> type | None: return OpenAICompatibleCompletion + if provider == "oci": + from crewai.llms.providers.oci.completion import OCICompletion + + return OCICompletion + return None def __init__( diff --git a/lib/crewai/src/crewai/llms/providers/oci/__init__.py b/lib/crewai/src/crewai/llms/providers/oci/__init__.py new file mode 100644 index 00000000000..0c397558bd3 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/__init__.py @@ -0,0 +1,5 @@ +from crewai.llms.providers.oci.completion import OCICompletion + +__all__ = [ + "OCICompletion", +] diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py new file mode 100644 index 
from __future__ import annotations

import asyncio
from collections.abc import Mapping
from contextlib import contextmanager
import json
import logging
import os
import re
import threading
from typing import TYPE_CHECKING, Any, Literal, cast

from pydantic import BaseModel

from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM, llm_call_context
from crewai.utilities.oci import create_oci_client_kwargs, get_oci_module
from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
    from crewai.agent.core import Agent
    from crewai.task import Task
    from crewai.tools.base_tool import BaseTool


# Dedicated (fine-tuned / hosted) endpoints are addressed by their endpoint OCID.
CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint"
DEFAULT_OCI_REGION = "us-chicago-1"
# NOTE(review): currently unused — presumably reserved for structured-output
# schema-name sanitization in a follow-up PR; confirm before removing.
_OCI_SCHEMA_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9_-]")


def _get_oci_module() -> Any:
    """Backward-compatible module-local alias used by tests and patches."""
    return get_oci_module()


class OCICompletion(BaseLLM):
    """OCI Generative AI native provider for CrewAI.

    Supports basic text completions for generic (Meta, Google, OpenAI, xAI)
    and Cohere model families hosted on the OCI Generative AI service.
    """

    def __init__(
        self,
        model: str,
        *,
        compartment_id: str | None = None,
        service_endpoint: str | None = None,
        auth_type: Literal[
            "API_KEY",
            "SECURITY_TOKEN",
            "INSTANCE_PRINCIPAL",
            "RESOURCE_PRINCIPAL",
        ]
        | str = "API_KEY",
        auth_profile: str | None = None,
        auth_file_location: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        top_k: int | None = None,
        oci_provider: str | None = None,
        client: Any | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the OCI Generative AI provider.

        Args:
            model: OCI model id (e.g. ``meta.llama-3.3-70b-instruct``) or a
                dedicated endpoint OCID (``ocid1.generativeaiendpoint...``).
            compartment_id: OCI compartment OCID; falls back to
                ``OCI_COMPARTMENT_ID``. Required.
            service_endpoint: Full inference endpoint URL; falls back to
                ``OCI_SERVICE_ENDPOINT`` or is derived from ``OCI_REGION``.
            auth_type: One of API_KEY, SECURITY_TOKEN, INSTANCE_PRINCIPAL,
                RESOURCE_PRINCIPAL (case-insensitive).
            auth_profile: Profile name in the OCI config file.
            auth_file_location: Path to the OCI config file.
            temperature: Sampling temperature (omitted for gpt-5 models).
            max_tokens: Output token cap (mapped to ``max_completion_tokens``
                for OpenAI-family models).
            top_p: Nucleus sampling parameter.
            top_k: Top-k sampling parameter.
            oci_provider: Force ``"generic"`` or ``"cohere"`` request shape;
                inferred from the model id when omitted.
            client: Pre-built ``GenerativeAiInferenceClient`` (mainly tests).

        Raises:
            ValueError: If no compartment id can be resolved.
            ImportError: If the optional ``oci`` SDK is not installed.
        """
        kwargs.pop("provider", None)
        super().__init__(
            model=model,
            temperature=temperature,
            provider="oci",
            **kwargs,
        )

        self.compartment_id = compartment_id or os.getenv("OCI_COMPARTMENT_ID")
        if not self.compartment_id:
            raise ValueError(
                "OCI compartment_id is required. Set compartment_id or OCI_COMPARTMENT_ID."
            )

        self.service_endpoint = service_endpoint or os.getenv("OCI_SERVICE_ENDPOINT")
        if self.service_endpoint is None:
            region = os.getenv("OCI_REGION", DEFAULT_OCI_REGION)
            self.service_endpoint = (
                f"https://inference.generativeai.{region}.oci.oraclecloud.com"
            )

        self.auth_type = str(auth_type).upper()
        self.auth_profile = cast(
            str, auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT")
        )
        self.auth_file_location = cast(
            str,
            auth_file_location
            or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"),
        )
        self.max_tokens = max_tokens
        self.top_p = top_p
        self.top_k = top_k
        self.oci_provider = oci_provider or self._infer_provider(model)
        self._oci = _get_oci_module()

        # Initialize bookkeeping state BEFORE client construction so the
        # instance is never left half-initialized if client creation raises.
        self._client_condition = threading.Condition()
        self._next_client_ticket = 0
        self._active_client_ticket = 0
        # Metadata (finish_reason, usage) of the most recent response, if any.
        self.last_response_metadata = None

        if client is not None:
            self.client = client
        else:
            client_kwargs = create_oci_client_kwargs(
                auth_type=self.auth_type,
                service_endpoint=self.service_endpoint,
                auth_file_location=self.auth_file_location,
                auth_profile=self.auth_profile,
                timeout=(10, 240),
                oci_module=self._oci,
            )
            self.client = self._oci.generative_ai_inference.GenerativeAiInferenceClient(
                **client_kwargs
            )

    # ------------------------------------------------------------------
    # Provider inference
    # ------------------------------------------------------------------

    def _infer_provider(self, model: str) -> str:
        """Infer the request family ("cohere" or "generic") from the model id."""
        if model.startswith(CUSTOM_ENDPOINT_PREFIX):
            # Dedicated endpoints use the generic request shape.
            return "generic"
        if model.startswith("cohere."):
            return "cohere"
        return "generic"

    def _is_openai_gpt5_family(self) -> bool:
        """Return True for OpenAI gpt-5 models, which reject temperature/stop."""
        return self.model.startswith("openai.gpt-5")

    def _build_serving_mode(self) -> Any:
        """Return dedicated serving mode for endpoint OCIDs, on-demand otherwise."""
        models = self._oci.generative_ai_inference.models
        if self.model.startswith(CUSTOM_ENDPOINT_PREFIX):
            return models.DedicatedServingMode(endpoint_id=self.model)
        return models.OnDemandServingMode(model_id=self.model)

    # ------------------------------------------------------------------
    # Message helpers
    # ------------------------------------------------------------------

    def _normalize_messages(
        self, messages: str | list[LLMMessage]
    ) -> list[LLMMessage]:
        """Normalize a raw prompt or message list via the BaseLLM formatter."""
        return self._format_messages(messages)

    def _coerce_text(self, content: Any) -> str:
        """Flatten arbitrary message content into plain text.

        Handles ``None``, strings, and lists of strings / ``{"type": "text"}``
        parts; anything else is stringified.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for item in content:
                if isinstance(item, str):
                    parts.append(item)
                elif isinstance(item, Mapping):
                    if item.get("type") == "text":
                        parts.append(str(item.get("text", "")))
                    elif "text" in item:
                        parts.append(str(item["text"]))
            return "\n".join(part for part in parts if part)
        return str(content)

    def _build_generic_content(self, content: Any) -> list[Any]:
        """Translate CrewAI message content into OCI generic content objects.

        Empty text is replaced with "." because the service rejects empty
        content parts.
        """
        models = self._oci.generative_ai_inference.models
        if isinstance(content, str):
            return [models.TextContent(text=content or ".")]

        if not isinstance(content, list):
            return [models.TextContent(text=self._coerce_text(content) or ".")]

        processed: list[Any] = []
        for item in content:
            if isinstance(item, str):
                processed.append(models.TextContent(text=item))
            elif isinstance(item, Mapping) and item.get("type") == "text":
                processed.append(
                    models.TextContent(text=str(item.get("text", "")) or ".")
                )
            else:
                processed.append(
                    models.TextContent(text=self._coerce_text(item) or ".")
                )
        return processed or [models.TextContent(text=".")]

    def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]:
        """Map CrewAI conversation messages into OCI generic chat messages."""
        models = self._oci.generative_ai_inference.models
        role_map = {
            "user": models.UserMessage,
            "assistant": models.AssistantMessage,
            "system": models.SystemMessage,
        }
        oci_messages: list[Any] = []

        for message in messages:
            role = str(message.get("role", "user")).lower()
            message_cls = role_map.get(role)
            if message_cls is None:
                # Roles the generic API doesn't model (e.g. "tool") are dropped.
                logging.debug("Skipping unsupported OCI message role: %s", role)
                continue
            oci_messages.append(
                message_cls(
                    content=self._build_generic_content(message.get("content", "")),
                )
            )

        return oci_messages

    def _build_cohere_chat_history(
        self, messages: list[LLMMessage]
    ) -> tuple[list[Any], str]:
        """Translate CrewAI messages into Cohere's split history + message shape.

        Cohere takes all but the last message as ``chat_history`` and the last
        message's text as ``message``.
        """
        models = self._oci.generative_ai_inference.models
        chat_history: list[Any] = []

        for message in messages[:-1]:
            role = str(message.get("role", "user")).lower()
            content = message.get("content", "")

            if role in ("user", "system"):
                message_cls = (
                    models.CohereUserMessage
                    if role == "user"
                    else models.CohereSystemMessage
                )
                chat_history.append(message_cls(message=self._coerce_text(content)))
            elif role == "assistant":
                chat_history.append(
                    models.CohereChatBotMessage(
                        # Cohere rejects empty chatbot messages; use a space.
                        message=self._coerce_text(content) or " ",
                    )
                )

        last_message = messages[-1] if messages else {"role": "user", "content": ""}
        message_text = self._coerce_text(last_message.get("content", ""))
        return chat_history, message_text

    # ------------------------------------------------------------------
    # Request building
    # ------------------------------------------------------------------

    def _build_chat_request(
        self,
        messages: list[LLMMessage],
    ) -> Any:
        """Build the provider-specific OCI chat request for the current model."""
        models = self._oci.generative_ai_inference.models

        if self.oci_provider == "cohere":
            chat_history, message_text = self._build_cohere_chat_history(messages)
            request_kwargs: dict[str, Any] = {
                "message": message_text,
                "chat_history": chat_history,
                "api_format": models.BaseChatRequest.API_FORMAT_COHERE,
            }
        else:
            request_kwargs = {
                "messages": self._build_generic_messages(messages),
                "api_format": models.BaseChatRequest.API_FORMAT_GENERIC,
            }

        # gpt-5 models reject temperature and stop sequences.
        if self.temperature is not None and not self._is_openai_gpt5_family():
            request_kwargs["temperature"] = self.temperature
        if self.max_tokens is not None:
            # OpenAI-family models use the newer max_completion_tokens name.
            if self.oci_provider == "generic" and self.model.startswith("openai."):
                request_kwargs["max_completion_tokens"] = self.max_tokens
            else:
                request_kwargs["max_tokens"] = self.max_tokens
        if self.top_p is not None:
            request_kwargs["top_p"] = self.top_p
        if self.top_k is not None:
            request_kwargs["top_k"] = self.top_k

        if self.stop and not self._is_openai_gpt5_family():
            stop_key = "stop_sequences" if self.oci_provider == "cohere" else "stop"
            request_kwargs[stop_key] = list(self.stop)

        if self.oci_provider == "cohere":
            return models.CohereChatRequest(**request_kwargs)
        return models.GenericChatRequest(**request_kwargs)

    # ------------------------------------------------------------------
    # Response extraction
    # ------------------------------------------------------------------

    def _extract_text(self, response: Any) -> str:
        """Extract the assistant text from either response family."""
        chat_response = response.data.chat_response
        if self.oci_provider == "cohere":
            if getattr(chat_response, "text", None):
                return chat_response.text or ""
            message = getattr(chat_response, "message", None)
            if message is not None:
                content = getattr(message, "content", None) or []
                return "".join(
                    part.text for part in content if getattr(part, "text", None)
                )
            return ""

        choices = getattr(chat_response, "choices", None) or []
        if not choices:
            return ""
        message = getattr(choices[0], "message", None)
        if message is None:
            return ""
        content = getattr(message, "content", None) or []
        return "".join(part.text for part in content if getattr(part, "text", None))

    def _extract_usage(self, response: Any) -> dict[str, int]:
        """Return token usage counts, or an empty dict when unavailable."""
        chat_response = response.data.chat_response
        usage = getattr(chat_response, "usage", None)
        if usage is None:
            return {}
        return {
            "prompt_tokens": getattr(usage, "prompt_tokens", 0),
            "completion_tokens": getattr(usage, "completion_tokens", 0),
            "total_tokens": getattr(usage, "total_tokens", 0),
        }

    def _extract_response_metadata(self, response: Any) -> dict[str, Any]:
        """Collect finish reason and usage into a metadata dict."""
        chat_response = response.data.chat_response
        metadata: dict[str, Any] = {}

        # Cohere puts finish_reason on the response; generic puts it on choices.
        finish_reason = getattr(chat_response, "finish_reason", None)
        if finish_reason is None:
            choices = getattr(chat_response, "choices", None) or []
            if choices:
                finish_reason = getattr(choices[0], "finish_reason", None)

        if finish_reason is not None:
            metadata["finish_reason"] = finish_reason

        usage = self._extract_usage(response)
        if usage:
            metadata["usage"] = usage

        return metadata

    # ------------------------------------------------------------------
    # Call paths
    # ------------------------------------------------------------------

    def _finalize_text_response(
        self,
        *,
        content: str,
        messages: list[LLMMessage],
        from_task: Task | None,
        from_agent: Agent | None,
    ) -> str:
        """Apply stop words, post-call hooks, and emit the completed event."""
        content = self._apply_stop_words(content)
        content = self._invoke_after_llm_call_hooks(messages, content, from_agent)
        self._emit_call_completed_event(
            response=content,
            call_type=LLMCallType.LLM_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=messages,
        )
        return content

    def _call_impl(
        self,
        *,
        messages: str | list[LLMMessage],
        from_task: Task | None,
        from_agent: Agent | None,
    ) -> str:
        """Perform one synchronous chat round trip and return the text."""
        # call() already normalizes; only normalize here for direct str input.
        normalized_messages = (
            messages if isinstance(messages, list) else self._normalize_messages(messages)
        )
        chat_request = self._build_chat_request(normalized_messages)
        chat_details = self._oci.generative_ai_inference.models.ChatDetails(
            compartment_id=self.compartment_id,
            serving_mode=self._build_serving_mode(),
            chat_request=chat_request,
        )
        response = self._chat(chat_details)
        usage = self._extract_usage(response)
        if usage:
            self._track_token_usage_internal(usage)
        self.last_response_metadata = self._extract_response_metadata(response) or None

        content = self._extract_text(response)
        return self._finalize_text_response(
            content=content,
            messages=normalized_messages,
            from_task=from_task,
            from_agent=from_agent,
        )

    def call(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | BaseModel | list[dict[str, Any]]:
        """Execute a chat completion against OCI Generative AI.

        ``tools``, ``available_functions``, and ``response_model`` are accepted
        for interface compatibility but ignored in this initial PR (tool calling
        and structured output land in follow-ups).

        Raises:
            ValueError: If a before_llm_call hook blocks the call.
        """
        normalized_messages = self._normalize_messages(messages)

        with llm_call_context():
            try:
                self._emit_call_started_event(
                    messages=normalized_messages,
                    tools=tools,
                    callbacks=callbacks,
                    available_functions=available_functions,
                    from_task=from_task,
                    from_agent=from_agent,
                )

                if not self._invoke_before_llm_call_hooks(
                    normalized_messages, from_agent
                ):
                    raise ValueError("LLM call blocked by before_llm_call hook")

                return self._call_impl(
                    messages=normalized_messages,
                    from_task=from_task,
                    from_agent=from_agent,
                )
            except Exception as error:
                error_message = f"OCI Generative AI call failed: {error!s}"
                self._emit_call_failed_event(
                    error=error_message,
                    from_task=from_task,
                    from_agent=from_agent,
                )
                raise

    async def acall(
        self,
        messages: str | list[LLMMessage],
        tools: list[dict[str, Any]] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        response_model: type[BaseModel] | None = None,
    ) -> str | Any:
        """Async wrapper: run the blocking OCI client call in a worker thread."""
        return await asyncio.to_thread(
            self.call,
            messages,
            tools=tools,
            callbacks=callbacks,
            available_functions=available_functions,
            from_task=from_task,
            from_agent=from_agent,
            response_model=response_model,
        )

    # ------------------------------------------------------------------
    # Client serialization
    # ------------------------------------------------------------------

    def _chat(self, chat_details: Any) -> Any:
        """Invoke the shared OCI client under the arrival-order gate."""
        with self._ordered_client_access():
            return self.client.chat(chat_details)

    @contextmanager
    def _ordered_client_access(self) -> Any:
        """Serialize shared OCI client access in call-arrival order.

        Each caller takes a monotonically increasing ticket and waits until the
        active ticket matches; the finally block advances the ticket even when
        the client call raises, so waiters are never stranded.
        """
        with self._client_condition:
            ticket = self._next_client_ticket
            self._next_client_ticket += 1
            while ticket != self._active_client_ticket:
                self._client_condition.wait()

        try:
            yield
        finally:
            with self._client_condition:
                self._active_client_ticket += 1
                self._client_condition.notify_all()

    # ------------------------------------------------------------------
    # Capability declarations
    # ------------------------------------------------------------------

    def supports_function_calling(self) -> bool:
        # NOTE(review): call() currently ignores tool schemas (tool calling
        # lands in a follow-up PR) — confirm agents degrade gracefully when
        # this returns True before tools are actually wired up.
        return True

    def supports_stop_words(self) -> bool:
        return True

    def get_context_window_size(self) -> int:
        """Return a conservative context window size by model family."""
        model_lower = self.model.lower()
        if model_lower.startswith("google.gemini"):
            return 1048576
        if model_lower.startswith("openai."):
            return 200000
        if model_lower.startswith("cohere."):
            return 128000
        if model_lower.startswith("meta."):
            return 131072
        # Unknown families fall back to a safe default.
        return 131072
model_lower.startswith("cohere."): + return 128000 + if model_lower.startswith("meta."): + return 131072 + return 131072 diff --git a/lib/crewai/src/crewai/utilities/oci.py b/lib/crewai/src/crewai/utilities/oci.py new file mode 100644 index 00000000000..7935530690c --- /dev/null +++ b/lib/crewai/src/crewai/utilities/oci.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any + + +def get_oci_module() -> Any: + """Import the OCI SDK lazily for optional CrewAI OCI integrations.""" + try: + import oci # type: ignore[import-untyped] + except ImportError: + raise ImportError( + 'OCI support is not available, to install: uv add "crewai[oci]"' + ) from None + return oci + + +def create_oci_client_kwargs( + *, + auth_type: str, + service_endpoint: str | None, + auth_file_location: str, + auth_profile: str, + timeout: tuple[int, int], + oci_module: Any | None = None, +) -> dict[str, Any]: + """Build OCI SDK client kwargs for the supported auth modes.""" + oci = oci_module or get_oci_module() + client_kwargs: dict[str, Any] = { + "config": {}, + "service_endpoint": service_endpoint, + "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, + "timeout": timeout, + } + + auth_type_upper = auth_type.upper() + if auth_type_upper == "API_KEY": + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + elif auth_type_upper == "SECURITY_TOKEN": + config = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + key_file = config["key_file"] + security_token_file = config["security_token_file"] + private_key = oci.signer.load_private_key_from_file(key_file, None) + with open(security_token_file, encoding="utf-8") as file: + security_token = file.read() + client_kwargs["config"] = config + client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner( + security_token, private_key + ) + elif auth_type_upper == "INSTANCE_PRINCIPAL": + client_kwargs["signer"] = ( + 
oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + ) + elif auth_type_upper == "RESOURCE_PRINCIPAL": + client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() + else: + valid_types = [ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + raise ValueError( + f"Invalid OCI auth_type '{auth_type}'. Valid values: {valid_types}" + ) + + return client_kwargs diff --git a/lib/crewai/tests/llms/oci/__init__.py b/lib/crewai/tests/llms/oci/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/crewai/tests/llms/oci/conftest.py b/lib/crewai/tests/llms/oci/conftest.py new file mode 100644 index 00000000000..164f53060f6 --- /dev/null +++ b/lib/crewai/tests/llms/oci/conftest.py @@ -0,0 +1,189 @@ +"""Fixtures for OCI provider unit and integration tests.""" + +from __future__ import annotations + +import os +from typing import Any +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Fake OCI SDK module (replaces `import oci` in unit tests) +# --------------------------------------------------------------------------- + + +def _make_fake_oci_module() -> MagicMock: + """Build a lightweight mock of the OCI SDK surface used by OCICompletion.""" + oci = MagicMock() + + # Models namespace + models = oci.generative_ai_inference.models + + # Serving modes + models.OnDemandServingMode = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.DedicatedServingMode = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Content types + models.TextContent = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Message types + for cls_name in ( + "UserMessage", + "AssistantMessage", + "SystemMessage", + "CohereUserMessage", + "CohereSystemMessage", + "CohereChatBotMessage", + ): + setattr(models, cls_name, MagicMock(side_effect=lambda **kw: MagicMock(**kw))) + + # Request types + models.GenericChatRequest = 
MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.CohereChatRequest = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.BaseChatRequest = MagicMock() + models.BaseChatRequest.API_FORMAT_GENERIC = "GENERIC" + models.BaseChatRequest.API_FORMAT_COHERE = "COHERE" + + # ChatDetails + models.ChatDetails = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Auth helpers + oci.config.from_file = MagicMock(return_value={"key_file": "/tmp/k", "security_token_file": "/tmp/t"}) + oci.signer.load_private_key_from_file = MagicMock(return_value="pk") + oci.auth.signers.SecurityTokenSigner = MagicMock() + oci.auth.signers.InstancePrincipalsSecurityTokenSigner = MagicMock() + oci.auth.signers.get_resource_principals_signer = MagicMock() + oci.retry.DEFAULT_RETRY_STRATEGY = "default_retry" + + # Client constructor + oci.generative_ai_inference.GenerativeAiInferenceClient = MagicMock() + + return oci + + +def _make_fake_chat_response(text: str = "Hello from OCI") -> MagicMock: + """Build a minimal OCI chat response for generic models.""" + text_part = MagicMock() + text_part.text = text + + message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 5 + usage.total_tokens = 15 + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_fake_cohere_chat_response(text: str = "Hello from Cohere") -> MagicMock: + """Build a minimal OCI chat response for Cohere models.""" + chat_response = MagicMock() + chat_response.text = text + chat_response.finish_reason = "COMPLETE" + chat_response.tool_calls = None + + usage = MagicMock() + usage.prompt_tokens = 8 + usage.completion_tokens = 4 + usage.total_tokens 
= 12 + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +@pytest.fixture() +def oci_fake_module() -> MagicMock: + return _make_fake_oci_module() + + +@pytest.fixture() +def patch_oci_module(monkeypatch: pytest.MonkeyPatch, oci_fake_module: MagicMock) -> MagicMock: + """Patch the OCI module import so no real SDK is needed.""" + monkeypatch.setattr( + "crewai.llms.providers.oci.completion._get_oci_module", + lambda: oci_fake_module, + ) + return oci_fake_module + + +@pytest.fixture() +def oci_response_factories() -> dict[str, Any]: + return { + "chat": _make_fake_chat_response, + "cohere_chat": _make_fake_cohere_chat_response, + } + + +# --------------------------------------------------------------------------- +# Unit test defaults +# --------------------------------------------------------------------------- + +@pytest.fixture() +def oci_unit_values() -> dict[str, str]: + return { + "compartment_id": "ocid1.compartment.oc1..test", + "model": "meta.llama-3.3-70b-instruct", + "cohere_model": "cohere.command-r-plus-08-2024", + } + + +# --------------------------------------------------------------------------- +# Integration test fixtures (live OCI calls) +# --------------------------------------------------------------------------- + +def _env_models(env_var: str, fallback_var: str, default: str) -> list[str]: + """Read model list from env, supporting comma-separated values.""" + raw = os.getenv(env_var) or os.getenv(fallback_var) or default + return [m.strip() for m in raw.split(",") if m.strip()] + + +def _skip_unless_live_config() -> dict[str, str]: + """Return live config dict or skip the test.""" + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set — skipping live test") + region = os.getenv("OCI_REGION") + endpoint = os.getenv("OCI_SERVICE_ENDPOINT") + if not region and not endpoint: + pytest.skip("Set OCI_REGION or 
OCI_SERVICE_ENDPOINT for live tests") + config: dict[str, str] = {"compartment_id": compartment} + if endpoint: + config["service_endpoint"] = endpoint + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + if os.getenv("OCI_AUTH_FILE_LOCATION"): + config["auth_file_location"] = os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config") + return config + + +@pytest.fixture( + params=_env_models("OCI_TEST_MODELS", "OCI_TEST_MODEL", "meta.llama-3.3-70b-instruct"), + ids=lambda m: m, +) +def oci_chat_model(request: pytest.FixtureRequest) -> str: + return request.param + + +@pytest.fixture() +def oci_live_config() -> dict[str, str]: + return _skip_unless_live_config() diff --git a/lib/crewai/tests/llms/oci/test_oci.py b/lib/crewai/tests/llms/oci/test_oci.py new file mode 100644 index 00000000000..a5f7efb1b66 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci.py @@ -0,0 +1,269 @@ +"""Unit tests for the OCI Generative AI provider (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Provider routing +# --------------------------------------------------------------------------- + + +def test_oci_completion_is_used_when_oci_provider(patch_oci_module): + """LLM(model='oci/...') should resolve to OCICompletion.""" + from crewai.llm import LLM + + fake_client = MagicMock() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + llm = LLM( + model="oci/meta.llama-3.3-70b-instruct", + compartment_id="ocid1.compartment.oc1..test", + ) + from crewai.llms.providers.oci.completion import OCICompletion + + # LLM.__new__ returns the native provider instance directly + assert isinstance(llm, OCICompletion) + + +@pytest.mark.parametrize( + "model_id, 
expected_provider", + [ + ("meta.llama-3.3-70b-instruct", "generic"), + ("google.gemini-2.5-flash", "generic"), + ("openai.gpt-4o", "generic"), + ("xai.grok-3", "generic"), + ("cohere.command-r-plus-08-2024", "cohere"), + ], +) +def test_oci_completion_infers_provider_family( + patch_oci_module, oci_unit_values, model_id, expected_provider +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=model_id, + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm.oci_provider == expected_provider + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +def test_oci_completion_initialization_parameters(patch_oci_module, oci_unit_values): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + temperature=0.7, + max_tokens=512, + top_p=0.9, + top_k=40, + ) + assert llm.temperature == 0.7 + assert llm.max_tokens == 512 + assert llm.top_p == 0.9 + assert llm.top_k == 40 + assert llm.compartment_id == oci_unit_values["compartment_id"] + + +def test_oci_completion_uses_region_to_build_endpoint(patch_oci_module, oci_unit_values, monkeypatch): + from crewai.llms.providers.oci.completion import OCICompletion + + monkeypatch.delenv("OCI_SERVICE_ENDPOINT", raising=False) + monkeypatch.setenv("OCI_REGION", "us-ashburn-1") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert "us-ashburn-1" in llm.service_endpoint + + +# 
--------------------------------------------------------------------------- +# Basic call +# --------------------------------------------------------------------------- + + +def test_oci_completion_call_uses_chat_api( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]("test response") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call(messages=[{"role": "user", "content": "Say hello"}]) + + assert "test response" in result + fake_client.chat.assert_called_once() + + +def test_oci_completion_cohere_call( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["cohere_chat"]("cohere reply") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call(messages=[{"role": "user", "content": "Hi"}]) + + assert "cohere reply" in result + fake_client.chat.assert_called_once() + + +# --------------------------------------------------------------------------- +# Message normalization +# --------------------------------------------------------------------------- + + +def test_oci_completion_treats_none_content_as_empty_text( + patch_oci_module, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + 
model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm._coerce_text(None) == "" + + +def test_oci_completion_call_normalizes_messages_once( + patch_oci_module, oci_response_factories, oci_unit_values +): + """Ensure normalize is not called twice when _call_impl receives a list.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + call_count = 0 + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + original_normalize = llm._normalize_messages + + def counting_normalize(msgs): + nonlocal call_count + call_count += 1 + return original_normalize(msgs) + + llm._normalize_messages = counting_normalize + + llm.call(messages=[{"role": "user", "content": "hi"}]) + # call() normalizes once, _call_impl should not normalize again + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# OpenAI model quirks +# --------------------------------------------------------------------------- + + +def test_oci_openai_models_use_max_completion_tokens( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model="openai.gpt-4o", + compartment_id=oci_unit_values["compartment_id"], + max_tokens=1024, + ) + request = llm._build_chat_request([{"role": "user", "content": "test"}]) + + models = patch_oci_module.generative_ai_inference.models + call_kwargs = models.GenericChatRequest.call_args + assert call_kwargs is not 
None + kwargs = call_kwargs[1] if call_kwargs[1] else {} + assert kwargs.get("max_completion_tokens") == 1024 + assert "max_tokens" not in kwargs + + +def test_oci_openai_gpt5_omits_unsupported_temperature_and_stop( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model="openai.gpt-5", + compartment_id=oci_unit_values["compartment_id"], + temperature=0.5, + ) + llm.stop = ["END"] + llm._build_chat_request([{"role": "user", "content": "test"}]) + + models = patch_oci_module.generative_ai_inference.models + call_kwargs = models.GenericChatRequest.call_args[1] + assert "temperature" not in call_kwargs + assert "stop" not in call_kwargs + + +# --------------------------------------------------------------------------- +# Async +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_oci_completion_acall_delegates_to_call( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]("async result") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = await llm.acall(messages=[{"role": "user", "content": "async test"}]) + + assert "async result" in result diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_basic.py b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py new file mode 100644 index 00000000000..adeb0f2ccef --- /dev/null +++ 
b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py @@ -0,0 +1,33 @@ +"""Live integration tests for OCI Generative AI basic text completion. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID=your_compartment_ocid OCI_REGION=us-chicago-1 \ + OCI_TEST_MODELS="meta.llama-3.3-70b-instruct,cohere.command-r-plus-08-2024,google.gemini-2.5-flash" \ + uv run pytest tests/llms/oci/test_oci_integration_basic.py -v +""" + +from __future__ import annotations + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def test_oci_live_basic_call(oci_chat_model: str, oci_live_config: dict): + """Synchronous text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = llm.call(messages=[{"role": "user", "content": "Say 'hello world' in one sentence."}]) + + assert isinstance(result, str) + assert len(result) > 0 + + +@pytest.mark.asyncio +async def test_oci_live_async_call(oci_chat_model: str, oci_live_config: dict): + """Async text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = await llm.acall(messages=[{"role": "user", "content": "What is 2+2? Answer in one word."}]) + + assert isinstance(result, str) + assert len(result) > 0 From 621b79cbe3724b77f68cb86b50d21ac03afcff95 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 13:49:49 -0400 Subject: [PATCH 2/6] fix: return False from supports_function_calling until tool PR Tool calling is not implemented in this PR. Returning True would cause CrewAI to choose the native tools path, silently dropping tools from agents. Flagged by Cursor Bugbot review. 
--- lib/crewai/src/crewai/llms/providers/oci/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 8c05e8caae9..b03ffa7990d 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -487,7 +487,7 @@ def _ordered_client_access(self) -> Any: # ------------------------------------------------------------------ def supports_function_calling(self) -> bool: - return True + return False # Tool calling support will be added in a follow-up PR def supports_stop_words(self) -> bool: return True From cb1aa10183b99c995df3c32684016313ad1fc2be Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 13:56:05 -0400 Subject: [PATCH 3/6] refactor: remove supports_function_calling and supports_stop_words Both methods are unnecessary in this PR. The base class and callers already default correctly when the methods are absent: - supports_function_calling: callers use getattr with False default - supports_stop_words: base class already returns True These will be added back in the tool calling follow-up PR. 
--- lib/crewai/src/crewai/llms/providers/oci/completion.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index b03ffa7990d..71abf20d516 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -486,12 +486,6 @@ def _ordered_client_access(self) -> Any: # Capability declarations # ------------------------------------------------------------------ - def supports_function_calling(self) -> bool: - return False # Tool calling support will be added in a follow-up PR - - def supports_stop_words(self) -> bool: - return True - def get_context_window_size(self) -> int: model_lower = self.model.lower() if model_lower.startswith("google.gemini"): From 459ecca06847e908fbf1a699bef675bd3b466805 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 14:00:30 -0400 Subject: [PATCH 4/6] cleanup: remove unused imports and dead code Remove json, re imports and _OCI_SCHEMA_NAME_PATTERN regex that are only needed for structured output (not in this PR scope). 
--- lib/crewai/src/crewai/llms/providers/oci/completion.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 71abf20d516..787f38d5995 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -3,10 +3,8 @@ import asyncio from collections.abc import Mapping from contextlib import contextmanager -import json import logging import os -import re import threading from typing import TYPE_CHECKING, Any, Literal, cast @@ -26,7 +24,6 @@ CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" DEFAULT_OCI_REGION = "us-chicago-1" -_OCI_SCHEMA_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") def _get_oci_module() -> Any: From 778b952f39962519f37b1f8099741c170376f645 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 14:02:09 -0400 Subject: [PATCH 5/6] fix: use model_lower consistently in OCI pattern check Use model_lower instead of model in the dot check to match the convention used by all other providers in _matches_provider_pattern. Flagged by Cursor Bugbot. --- lib/crewai/src/crewai/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index fcd0f7fa926..986010a0e80 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -509,7 +509,7 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: return True if provider == "oci": - return model_lower.startswith("ocid1.generativeaiendpoint") or "." in model + return model_lower.startswith("ocid1.generativeaiendpoint") or "." 
in model_lower return False From 1562c651c5b9f16d61f2c9127d08ddcfcc1903e3 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 14:40:23 -0400 Subject: [PATCH 6/6] feat: add streaming support to OCI Generative AI provider Add streaming text completion via OCI SSE events: - stream=True in call() routes to _stream_call_impl with chunk events - iter_stream() yields raw text chunks (sync generator) - astream() wraps iter_stream via thread+queue for async callers - _stream_chat_events holds client lock for full stream duration - SSE event parsing handles both string and mapping payloads Tested live against meta.llama-3.3-70b-instruct, cohere.command-r-plus-08-2024, google.gemini-2.5-flash, and openai.gpt-5.2-chat-latest. Depends on: #4959 Tracking issue: #4944 --- .../crewai/llms/providers/oci/completion.py | 230 ++++++++++++++++++ .../oci/test_oci_integration_streaming.py | 40 +++ .../tests/llms/oci/test_oci_streaming.py | 150 ++++++++++++ 3 files changed, 420 insertions(+) create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_streaming.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_streaming.py diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 787f38d5995..1147a1c2ad8 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -3,10 +3,12 @@ import asyncio from collections.abc import Mapping from contextlib import contextmanager +import json import logging import os import threading from typing import TYPE_CHECKING, Any, Literal, cast +import uuid from pydantic import BaseModel @@ -57,6 +59,7 @@ def __init__( max_tokens: int | None = None, top_p: float | None = None, top_k: int | None = None, + stream: bool = False, oci_provider: str | None = None, client: Any | None = None, **kwargs: Any, @@ -94,6 +97,7 @@ def __init__( self.max_tokens = max_tokens self.top_p = top_p self.top_k 
= top_k + self.stream = stream self.oci_provider = oci_provider or self._infer_provider(model) self._oci = _get_oci_module() @@ -246,6 +250,8 @@ def _build_cohere_chat_history( def _build_chat_request( self, messages: list[LLMMessage], + *, + is_stream: bool = False, ) -> Any: """Build the provider-specific OCI chat request for the current model.""" models = self._oci.generative_ai_inference.models @@ -279,6 +285,12 @@ def _build_chat_request( stop_key = "stop_sequences" if self.oci_provider == "cohere" else "stop" request_kwargs[stop_key] = list(self.stop) + if is_stream: + request_kwargs["is_stream"] = True + request_kwargs["stream_options"] = models.StreamOptions( + is_include_usage=True + ) + if self.oci_provider == "cohere": return models.CohereChatRequest(**request_kwargs) return models.GenericChatRequest(**request_kwargs) @@ -339,6 +351,75 @@ def _extract_response_metadata(self, response: Any) -> dict[str, Any]: return metadata + # ------------------------------------------------------------------ + # Streaming extraction + # ------------------------------------------------------------------ + + def _parse_stream_event(self, event: Any) -> dict[str, Any]: + """Convert OCI SSE event payloads into plain dicts.""" + event_data = getattr(event, "data", None) + if not event_data: + return {} + if isinstance(event_data, str): + try: + parsed = json.loads(event_data) + if isinstance(parsed, Mapping): + return dict(parsed) + return {} + except json.JSONDecodeError: + logging.debug("Skipping invalid OCI SSE payload: %s", event_data) + return {} + if isinstance(event_data, Mapping): + return dict(event_data) + return {} + + def _extract_text_from_stream_event(self, event_data: dict[str, Any]) -> str: + if self.oci_provider == "cohere": + if "text" in event_data: + return str(event_data.get("text", "")) + message = event_data.get("message", {}) + if isinstance(message, Mapping): + content = message.get("content", []) + if isinstance(content, list): + return "".join( + 
str(part.get("text", "")) + for part in content + if isinstance(part, Mapping) + ) + return "" + + message = event_data.get("message", {}) + if not isinstance(message, Mapping): + return "" + content = message.get("content", []) + if not isinstance(content, list): + return "" + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, Mapping) and part.get("text") + ) + + def _extract_usage_from_stream_event(self, event_data: dict[str, Any]) -> dict[str, int]: + usage = event_data.get("usage") + if not isinstance(usage, Mapping): + return {} + return { + "prompt_tokens": int(usage.get("promptTokens", 0) or 0), + "completion_tokens": int(usage.get("completionTokens", 0) or 0), + "total_tokens": int(usage.get("totalTokens", 0) or 0), + } + + def _extract_metadata_from_stream_event(self, event_data: dict[str, Any]) -> dict[str, Any]: + metadata: dict[str, Any] = {} + finish_reason = event_data.get("finishReason") + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + usage = self._extract_usage_from_stream_event(event_data) + if usage: + metadata["usage"] = usage + return metadata + # ------------------------------------------------------------------ # Call paths # ------------------------------------------------------------------ @@ -392,6 +473,142 @@ def _call_impl( from_agent=from_agent, ) + def _stream_call_impl( + self, + *, + messages: str | list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + """Handle OCI streaming while reconstructing final text state.""" + normalized_messages = ( + messages if isinstance(messages, list) else self._normalize_messages(messages) + ) + chat_request = self._build_chat_request(normalized_messages, is_stream=True) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + full_response = "" + usage_data: dict[str, int] = 
{} + response_metadata: dict[str, Any] = {} + response_id = uuid.uuid4().hex + + for event in self._stream_chat_events(chat_details): + event_data = self._parse_stream_event(event) + if not event_data: + continue + + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: + full_response += text_chunk + self._emit_stream_chunk_event( + chunk=text_chunk, + from_task=from_task, + from_agent=from_agent, + call_type=LLMCallType.LLM_CALL, + response_id=response_id, + ) + + usage_chunk = self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) + + if usage_data: + self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None + + return self._finalize_text_response( + content=full_response, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + def iter_stream( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + ) -> Any: + """Yield raw text chunks from OCI without triggering tool recursion.""" + normalized_messages = self._normalize_messages(messages) + chat_request = self._build_chat_request(normalized_messages, is_stream=True) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + response = self._chat(chat_details) + usage_data: dict[str, int] = {} + response_metadata: dict[str, Any] = {} + + for event in response.data.events(): + event_data = self._parse_stream_event(event) + if not event_data: + continue + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: 
+ yield text_chunk + usage_chunk = self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) + + if usage_data: + self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None + + async def astream( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + ) -> Any: + """Expose the sync OCI SSE stream through an async generator facade.""" + loop = asyncio.get_running_loop() + queue: asyncio.Queue[str | None] = asyncio.Queue() + error_holder: list[BaseException] = [] + + def _producer() -> None: + try: + for chunk in self.iter_stream( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ): + loop.call_soon_threadsafe(queue.put_nowait, chunk) + except BaseException as error: + error_holder.append(error) + finally: + loop.call_soon_threadsafe(queue.put_nowait, None) + + thread = threading.Thread(target=_producer, daemon=True) + thread.start() + + while True: + chunk = await queue.get() + if chunk is None: + break + yield chunk + + thread.join() + if error_holder: + raise error_holder[0] + def call( self, messages: str | list[LLMMessage], @@ -420,6 +637,13 @@ def call( ): raise ValueError("LLM call blocked by before_llm_call hook") + if self.stream: + return self._stream_call_impl( + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + return self._call_impl( messages=normalized_messages, from_task=from_task, @@ -463,6 +687,12 @@ def _chat(self, chat_details: Any) -> Any: with self._ordered_client_access(): return self.client.chat(chat_details) + def 
_stream_chat_events(self, chat_details: Any) -> Any: + """Yield streaming events while holding the shared OCI client lock.""" + with self._ordered_client_access(): + response = self.client.chat(chat_details) + yield from response.data.events() + @contextmanager def _ordered_client_access(self) -> Any: """Serialize shared OCI client access in call-arrival order.""" diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py b/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py new file mode 100644 index 00000000000..db41b1f3a97 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py @@ -0,0 +1,40 @@ +"""Live integration tests for OCI Generative AI streaming. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID=your_compartment_ocid OCI_REGION=us-chicago-1 \ + OCI_TEST_MODELS="meta.llama-3.3-70b-instruct,cohere.command-r-plus-08-2024" \ + uv run pytest tests/llms/oci/test_oci_integration_streaming.py -v +""" + +from __future__ import annotations + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def test_oci_live_streaming_call(oci_chat_model: str, oci_live_config: dict): + """Streaming text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, stream=True, **oci_live_config) + result = llm.call( + messages=[{"role": "user", "content": "Count from 1 to 5, one per line."}] + ) + + assert isinstance(result, str) + assert len(result) > 0 + + +@pytest.mark.asyncio +async def test_oci_live_astream(oci_chat_model: str, oci_live_config: dict): + """Async streaming should yield text chunks from a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + chunks: list[str] = [] + async for chunk in llm.astream( + messages=[{"role": "user", "content": "Say hello in three words."}] + ): + chunks.append(chunk) + + assert len(chunks) > 0 + full_text = "".join(chunks) + assert len(full_text) > 0 diff --git 
a/lib/crewai/tests/llms/oci/test_oci_streaming.py b/lib/crewai/tests/llms/oci/test_oci_streaming.py new file mode 100644 index 00000000000..721fa3f671a --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_streaming.py @@ -0,0 +1,150 @@ +"""Unit tests for OCI provider streaming (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +def _make_fake_stream_event(text: str = "", finish_reason: str | None = None, usage: dict | None = None) -> MagicMock: + """Build a single SSE event with optional text, finish, and usage.""" + payload: dict = {} + if text: + payload["message"] = {"content": [{"text": text}]} + if finish_reason: + payload["finishReason"] = finish_reason + if usage: + payload["usage"] = usage + + import json + event = MagicMock() + event.data = json.dumps(payload) + return event + + +def _make_fake_stream_response(*events: MagicMock) -> MagicMock: + """Wrap events into a response.data.events() iterable.""" + response = MagicMock() + response.data.events.return_value = iter(events) + return response + + +def test_oci_completion_streams_generic_responses( + patch_oci_module, oci_unit_values +): + """Streaming call should accumulate text chunks and return full response.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="Hello "), + _make_fake_stream_event(text="world"), + _make_fake_stream_event( + finish_reason="stop", + usage={"promptTokens": 5, "completionTokens": 2, "totalTokens": 7}, + ), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + # StreamOptions mock + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + 
compartment_id=oci_unit_values["compartment_id"], + stream=True, + ) + result = llm.call(messages=[{"role": "user", "content": "Say hello"}]) + + assert "Hello " in result + assert "world" in result + assert llm.last_response_metadata is not None + assert llm.last_response_metadata.get("finish_reason") == "stop" + + +def test_oci_iter_stream_yields_text_chunks( + patch_oci_module, oci_unit_values +): + """iter_stream should yield individual text chunks.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="chunk1"), + _make_fake_stream_event(text="chunk2"), + _make_fake_stream_event( + usage={"promptTokens": 3, "completionTokens": 2, "totalTokens": 5}, + ), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + chunks = list(llm.iter_stream(messages=[{"role": "user", "content": "test"}])) + + assert chunks == ["chunk1", "chunk2"] + assert llm.last_response_metadata is not None + assert llm.last_response_metadata["usage"]["total_tokens"] == 5 + + +@pytest.mark.asyncio +async def test_oci_astream_yields_text_chunks( + patch_oci_module, oci_unit_values +): + """astream should yield chunks via async generator.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="async1"), + _make_fake_stream_event(text="async2"), + _make_fake_stream_event(), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + 
patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + chunks = [] + async for chunk in llm.astream(messages=[{"role": "user", "content": "test"}]): + chunks.append(chunk) + + assert chunks == ["async1", "async2"] + + +def test_oci_stream_chat_events_holds_client_lock( + patch_oci_module, oci_unit_values +): + """_stream_chat_events should hold the client lock for the full iteration.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [_make_fake_stream_event(text="a"), _make_fake_stream_event(text="b")] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + # Before streaming, ticket should be 0 + assert llm._active_client_ticket == 0 + chat_details = MagicMock() + list(llm._stream_chat_events(chat_details)) + # After streaming completes, ticket should have advanced + assert llm._active_client_ticket == 1