diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index 99ebd54ebb1..ac146cdb90b 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -97,6 +97,9 @@ azure-ai-inference = [ anthropic = [ "anthropic~=0.73.0", ] +oci = [ + "oci>=2.168.0", +] a2a = [ "a2a-sdk~=0.3.10", "httpx-auth~=0.23.1", diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 873c1b7dd75..d0ca4c15cb1 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -319,6 +319,7 @@ def writable(self) -> bool: "hosted_vllm", "cerebras", "dashscope", + "oci", ] @@ -407,6 +408,7 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: "hosted_vllm": "hosted_vllm", "cerebras": "cerebras", "dashscope": "dashscope", + "oci": "oci", } canonical_provider = provider_mapping.get(prefix.lower()) @@ -526,6 +528,9 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: # OpenRouter uses org/model format but accepts anything return True + if provider == "oci": + return model_lower.startswith("ocid1.generativeaiendpoint") or "." in model_lower + return False @classmethod @@ -561,6 +566,9 @@ def _validate_model_in_constants(cls, model: str, provider: str) -> bool: # azure does not provide a list of available models, determine a better way to handle this return True + if provider == "oci": + return cls._matches_provider_pattern(model, provider) + # Fallback to pattern matching for models not in constants return cls._matches_provider_pattern(model, provider) @@ -642,6 +650,11 @@ def _get_native_provider(cls, provider: str) -> type | None: return OpenAICompatibleCompletion + if provider == "oci": + from crewai.llms.providers.oci.completion import OCICompletion + + return OCICompletion + return None @model_validator(mode="before") diff --git a/lib/crewai/src/crewai/llms/providers/oci/__init__.py b/lib/crewai/src/crewai/llms/providers/oci/__init__.py new file mode 100644 index 00000000000..0c397558bd3 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/__init__.py @@ -0,0 +1,5 @@ +from crewai.llms.providers.oci.completion import OCICompletion + +__all__ = [ + "OCICompletion", +] diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py new file mode 100644 index 00000000000..1147a1c2ad8 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -0,0 +1,726 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from contextlib import contextmanager +import json +import logging +import os +import threading +from typing import TYPE_CHECKING, Any, Literal, cast +import uuid + +from pydantic import BaseModel + +from crewai.events.types.llm_events import LLMCallType +from crewai.llms.base_llm import BaseLLM, llm_call_context +from crewai.utilities.oci import create_oci_client_kwargs, get_oci_module +from crewai.utilities.types import LLMMessage + + +if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task + from crewai.tools.base_tool import BaseTool + + +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" +DEFAULT_OCI_REGION = "us-chicago-1" + + +def _get_oci_module() -> Any: + """Backward-compatible module-local alias used by tests and patches.""" + return get_oci_module() + + +class OCICompletion(BaseLLM): + """OCI Generative AI native provider for CrewAI. 
+ + Supports basic text completions for generic (Meta, Google, OpenAI, xAI) + and Cohere model families hosted on the OCI Generative AI service. + """ + + def __init__( + self, + model: str, + *, + compartment_id: str | None = None, + service_endpoint: str | None = None, + auth_type: Literal[ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + | str = "API_KEY", + auth_profile: str | None = None, + auth_file_location: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + top_p: float | None = None, + top_k: int | None = None, + stream: bool = False, + oci_provider: str | None = None, + client: Any | None = None, + **kwargs: Any, + ) -> None: + kwargs.pop("provider", None) + super().__init__( + model=model, + temperature=temperature, + provider="oci", + **kwargs, + ) + + self.compartment_id = compartment_id or os.getenv("OCI_COMPARTMENT_ID") + if not self.compartment_id: + raise ValueError( + "OCI compartment_id is required. Set compartment_id or OCI_COMPARTMENT_ID." + ) + + self.service_endpoint = service_endpoint or os.getenv("OCI_SERVICE_ENDPOINT") + if self.service_endpoint is None: + region = os.getenv("OCI_REGION", DEFAULT_OCI_REGION) + self.service_endpoint = ( + f"https://inference.generativeai.{region}.oci.oraclecloud.com" + ) + + self.auth_type = str(auth_type).upper() + self.auth_profile = cast( + str, auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + ) + self.auth_file_location = cast( + str, + auth_file_location + or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + ) + self.max_tokens = max_tokens + self.top_p = top_p + self.top_k = top_k + self.stream = stream + self.oci_provider = oci_provider or self._infer_provider(model) + self._oci = _get_oci_module() + + if client is not None: + self.client = client + else: + client_kwargs = create_oci_client_kwargs( + auth_type=self.auth_type, + service_endpoint=self.service_endpoint, + auth_file_location=self.auth_file_location, + auth_profile=self.auth_profile, + timeout=(10, 240), + oci_module=self._oci, + ) + self.client = self._oci.generative_ai_inference.GenerativeAiInferenceClient( + **client_kwargs + ) + self._client_condition = threading.Condition() + self._next_client_ticket = 0 + self._active_client_ticket = 0 + self.last_response_metadata = None + + # ------------------------------------------------------------------ + # Provider inference + # ------------------------------------------------------------------ + + def _infer_provider(self, model: str) -> str: + if model.startswith(CUSTOM_ENDPOINT_PREFIX): + return "generic" + if model.startswith("cohere."): + return "cohere" + return "generic" + + def _is_openai_gpt5_family(self) -> bool: + return self.model.startswith("openai.gpt-5") + + def _build_serving_mode(self) -> Any: + models = self._oci.generative_ai_inference.models + if self.model.startswith(CUSTOM_ENDPOINT_PREFIX): + return models.DedicatedServingMode(endpoint_id=self.model) + return models.OnDemandServingMode(model_id=self.model) + + # ------------------------------------------------------------------ + # Message helpers + # ------------------------------------------------------------------ + + def _normalize_messages( + self, messages: str | list[LLMMessage] + ) -> list[LLMMessage]: + return self._format_messages(messages) + + def _coerce_text(self, content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in 
content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, Mapping): + if item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif "text" in item: + parts.append(str(item["text"])) + return "\n".join(part for part in parts if part) + return str(content) + + def _build_generic_content(self, content: Any) -> list[Any]: + """Translate CrewAI message content into OCI generic content objects.""" + models = self._oci.generative_ai_inference.models + if isinstance(content, str): + return [models.TextContent(text=content or ".")] + + if not isinstance(content, list): + return [models.TextContent(text=self._coerce_text(content) or ".")] + + processed: list[Any] = [] + for item in content: + if isinstance(item, str): + processed.append(models.TextContent(text=item)) + elif isinstance(item, Mapping) and item.get("type") == "text": + processed.append( + models.TextContent(text=str(item.get("text", "")) or ".") + ) + else: + processed.append( + models.TextContent(text=self._coerce_text(item) or ".") + ) + return processed or [models.TextContent(text=".")] + + def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: + """Map CrewAI conversation messages into OCI generic chat messages.""" + models = self._oci.generative_ai_inference.models + role_map = { + "user": models.UserMessage, + "assistant": models.AssistantMessage, + "system": models.SystemMessage, + } + oci_messages: list[Any] = [] + + for message in messages: + role = str(message.get("role", "user")).lower() + message_cls = role_map.get(role) + if message_cls is None: + logging.debug("Skipping unsupported OCI message role: %s", role) + continue + oci_messages.append( + message_cls( + content=self._build_generic_content(message.get("content", "")), + ) + ) + + return oci_messages + + def _build_cohere_chat_history( + self, messages: list[LLMMessage] + ) -> tuple[list[Any], str]: + """Translate CrewAI messages into Cohere's split history + message shape.""" + models = self._oci.generative_ai_inference.models + chat_history: list[Any] = [] + + for message in messages[:-1]: + role = str(message.get("role", "user")).lower() + content = message.get("content", "") + + if role in ("user", "system"): + message_cls = ( + models.CohereUserMessage + if role == "user" + else models.CohereSystemMessage + ) + chat_history.append(message_cls(message=self._coerce_text(content))) + elif role == "assistant": + chat_history.append( + models.CohereChatBotMessage( + message=self._coerce_text(content) or " ", + ) + ) + + last_message = messages[-1] if messages else {"role": "user", "content": ""} + message_text = self._coerce_text(last_message.get("content", "")) + return chat_history, message_text + + # ------------------------------------------------------------------ + # Request building + # ------------------------------------------------------------------ + + def _build_chat_request( + self, + messages: list[LLMMessage], + *, + is_stream: bool = False, + ) -> Any: + """Build the provider-specific OCI chat request for the current model.""" + models = self._oci.generative_ai_inference.models + + if self.oci_provider == "cohere": + chat_history, message_text = self._build_cohere_chat_history(messages) + request_kwargs: dict[str, Any] = { + "message": message_text, + "chat_history": chat_history, + "api_format": models.BaseChatRequest.API_FORMAT_COHERE, + } + else: + request_kwargs = { + "messages": self._build_generic_messages(messages), + "api_format": models.BaseChatRequest.API_FORMAT_GENERIC, + } + + 
if self.temperature is not None and not self._is_openai_gpt5_family(): + request_kwargs["temperature"] = self.temperature + if self.max_tokens is not None: + if self.oci_provider == "generic" and self.model.startswith("openai."): + request_kwargs["max_completion_tokens"] = self.max_tokens + else: + request_kwargs["max_tokens"] = self.max_tokens + if self.top_p is not None: + request_kwargs["top_p"] = self.top_p + if self.top_k is not None: + request_kwargs["top_k"] = self.top_k + + if self.stop and not self._is_openai_gpt5_family(): + stop_key = "stop_sequences" if self.oci_provider == "cohere" else "stop" + request_kwargs[stop_key] = list(self.stop) + + if is_stream: + request_kwargs["is_stream"] = True + request_kwargs["stream_options"] = models.StreamOptions( + is_include_usage=True + ) + + if self.oci_provider == "cohere": + return models.CohereChatRequest(**request_kwargs) + return models.GenericChatRequest(**request_kwargs) + + # ------------------------------------------------------------------ + # Response extraction + # ------------------------------------------------------------------ + + def _extract_text(self, response: Any) -> str: + chat_response = response.data.chat_response + if self.oci_provider == "cohere": + if getattr(chat_response, "text", None): + return chat_response.text or "" + message = getattr(chat_response, "message", None) + if message is not None: + content = getattr(message, "content", None) or [] + return "".join( + part.text for part in content if getattr(part, "text", None) + ) + return "" + + choices = getattr(chat_response, "choices", None) or [] + if not choices: + return "" + message = getattr(choices[0], "message", None) + if message is None: + return "" + content = getattr(message, "content", None) or [] + return "".join(part.text for part in content if getattr(part, "text", None)) + + def _extract_usage(self, response: Any) -> dict[str, int]: + chat_response = response.data.chat_response + usage = getattr(chat_response, "usage", None) + if usage is None: + return {} + return { + "prompt_tokens": getattr(usage, "prompt_tokens", 0), + "completion_tokens": getattr(usage, "completion_tokens", 0), + "total_tokens": getattr(usage, "total_tokens", 0), + } + + def _extract_response_metadata(self, response: Any) -> dict[str, Any]: + chat_response = response.data.chat_response + metadata: dict[str, Any] = {} + + finish_reason = getattr(chat_response, "finish_reason", None) + if finish_reason is None: + choices = getattr(chat_response, "choices", None) or [] + if choices: + finish_reason = getattr(choices[0], "finish_reason", None) + + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + + usage = self._extract_usage(response) + if usage: + metadata["usage"] = usage + + return metadata + + # ------------------------------------------------------------------ + # Streaming extraction + # ------------------------------------------------------------------ + + def _parse_stream_event(self, event: Any) -> dict[str, Any]: + """Convert OCI SSE event payloads into plain dicts.""" + event_data = getattr(event, "data", None) + if not event_data: + return {} + if isinstance(event_data, str): + try: + parsed = json.loads(event_data) + if isinstance(parsed, Mapping): + return dict(parsed) + return {} + except json.JSONDecodeError: + logging.debug("Skipping invalid OCI SSE payload: %s", event_data) + return {} + if isinstance(event_data, Mapping): + return dict(event_data) + return {} + + def _extract_text_from_stream_event(self, event_data: dict[str, 
Any]) -> str: + if self.oci_provider == "cohere": + if "text" in event_data: + return str(event_data.get("text", "")) + message = event_data.get("message", {}) + if isinstance(message, Mapping): + content = message.get("content", []) + if isinstance(content, list): + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, Mapping) + ) + return "" + + message = event_data.get("message", {}) + if not isinstance(message, Mapping): + return "" + content = message.get("content", []) + if not isinstance(content, list): + return "" + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, Mapping) and part.get("text") + ) + + def _extract_usage_from_stream_event(self, event_data: dict[str, Any]) -> dict[str, int]: + usage = event_data.get("usage") + if not isinstance(usage, Mapping): + return {} + return { + "prompt_tokens": int(usage.get("promptTokens", 0) or 0), + "completion_tokens": int(usage.get("completionTokens", 0) or 0), + "total_tokens": int(usage.get("totalTokens", 0) or 0), + } + + def _extract_metadata_from_stream_event(self, event_data: dict[str, Any]) -> dict[str, Any]: + metadata: dict[str, Any] = {} + finish_reason = event_data.get("finishReason") + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + usage = self._extract_usage_from_stream_event(event_data) + if usage: + metadata["usage"] = usage + return metadata + + # ------------------------------------------------------------------ + # Call paths + # ------------------------------------------------------------------ + + def _finalize_text_response( + self, + *, + content: str, + messages: list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + content = self._apply_stop_words(content) + content = self._invoke_after_llm_call_hooks(messages, content, from_agent) + self._emit_call_completed_event( + response=content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=messages, + ) + return content + + def _call_impl( + self, + *, + messages: str | list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + normalized_messages = ( + messages if isinstance(messages, list) else self._normalize_messages(messages) + ) + chat_request = self._build_chat_request(normalized_messages) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + response = self._chat(chat_details) + usage = self._extract_usage(response) + if usage: + self._track_token_usage_internal(usage) + self.last_response_metadata = self._extract_response_metadata(response) or None + + content = self._extract_text(response) + return self._finalize_text_response( + content=content, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + def _stream_call_impl( + self, + *, + messages: str | list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + """Handle OCI streaming while reconstructing final text state.""" + normalized_messages = ( + messages if isinstance(messages, list) else self._normalize_messages(messages) + ) + chat_request = self._build_chat_request(normalized_messages, is_stream=True) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + full_response = "" + 
usage_data: dict[str, int] = {} + response_metadata: dict[str, Any] = {} + response_id = uuid.uuid4().hex + + for event in self._stream_chat_events(chat_details): + event_data = self._parse_stream_event(event) + if not event_data: + continue + + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: + full_response += text_chunk + self._emit_stream_chunk_event( + chunk=text_chunk, + from_task=from_task, + from_agent=from_agent, + call_type=LLMCallType.LLM_CALL, + response_id=response_id, + ) + + usage_chunk = self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) + + if usage_data: + self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None + + return self._finalize_text_response( + content=full_response, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + def iter_stream( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + ) -> Any: + """Yield raw text chunks from OCI without triggering tool recursion.""" + normalized_messages = self._normalize_messages(messages) + chat_request = self._build_chat_request(normalized_messages, is_stream=True) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + response = self._chat(chat_details) + usage_data: dict[str, int] = {} + response_metadata: dict[str, Any] = {} + + for event in response.data.events(): + event_data = self._parse_stream_event(event) + if not event_data: + continue + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: + yield text_chunk + usage_chunk = self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) + + if usage_data: + self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None + + async def astream( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + ) -> Any: + """Expose the sync OCI SSE stream through an async generator facade.""" + loop = asyncio.get_running_loop() + queue: asyncio.Queue[str | None] = asyncio.Queue() + error_holder: list[BaseException] = [] + + def _producer() -> None: + try: + for chunk in self.iter_stream( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ): + loop.call_soon_threadsafe(queue.put_nowait, chunk) + except BaseException as error: + error_holder.append(error) + finally: + loop.call_soon_threadsafe(queue.put_nowait, None) + + thread = threading.Thread(target=_producer, daemon=True) + thread.start() + + while True: + chunk = await queue.get() + if chunk is None: + break + yield chunk + + thread.join() + if error_holder: + raise error_holder[0] + + def call( + self, + 
messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | BaseModel | list[dict[str, Any]]: + normalized_messages = self._normalize_messages(messages) + + with llm_call_context(): + try: + self._emit_call_started_event( + messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + if not self._invoke_before_llm_call_hooks( + normalized_messages, from_agent + ): + raise ValueError("LLM call blocked by before_llm_call hook") + + if self.stream: + return self._stream_call_impl( + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + return self._call_impl( + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + except Exception as error: + error_message = f"OCI Generative AI call failed: {error!s}" + self._emit_call_failed_event( + error=error_message, + from_task=from_task, + from_agent=from_agent, + ) + raise + + async def acall( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + return await asyncio.to_thread( + self.call, + messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + # ------------------------------------------------------------------ + # Client serialization + # ------------------------------------------------------------------ + + def _chat(self, chat_details: Any) -> Any: + with self._ordered_client_access(): + return self.client.chat(chat_details) + + def _stream_chat_events(self, chat_details: Any) -> Any: + """Yield streaming events while holding the shared OCI client lock.""" + with self._ordered_client_access(): + response = self.client.chat(chat_details) + yield from response.data.events() + + @contextmanager + def _ordered_client_access(self) -> Any: + """Serialize shared OCI client access in call-arrival order.""" + with self._client_condition: + ticket = self._next_client_ticket + self._next_client_ticket += 1 + while ticket != self._active_client_ticket: + self._client_condition.wait() + + try: + yield + finally: + with self._client_condition: + self._active_client_ticket += 1 + self._client_condition.notify_all() + + # ------------------------------------------------------------------ + # Capability declarations + # ------------------------------------------------------------------ + + def get_context_window_size(self) -> int: + model_lower = self.model.lower() + if model_lower.startswith("google.gemini"): + return 1048576 + if model_lower.startswith("openai."): + return 200000 + if model_lower.startswith("cohere."): + return 128000 + if model_lower.startswith("meta."): + return 131072 + return 131072 diff --git a/lib/crewai/src/crewai/utilities/oci.py b/lib/crewai/src/crewai/utilities/oci.py new file mode 100644 index 00000000000..7935530690c --- /dev/null +++ b/lib/crewai/src/crewai/utilities/oci.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing 
import Any + + +def get_oci_module() -> Any: + """Import the OCI SDK lazily for optional CrewAI OCI integrations.""" + try: + import oci # type: ignore[import-untyped] + except ImportError: + raise ImportError( + 'OCI support is not available, to install: uv add "crewai[oci]"' + ) from None + return oci + + +def create_oci_client_kwargs( + *, + auth_type: str, + service_endpoint: str | None, + auth_file_location: str, + auth_profile: str, + timeout: tuple[int, int], + oci_module: Any | None = None, +) -> dict[str, Any]: + """Build OCI SDK client kwargs for the supported auth modes.""" + oci = oci_module or get_oci_module() + client_kwargs: dict[str, Any] = { + "config": {}, + "service_endpoint": service_endpoint, + "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, + "timeout": timeout, + } + + auth_type_upper = auth_type.upper() + if auth_type_upper == "API_KEY": + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + elif auth_type_upper == "SECURITY_TOKEN": + config = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + key_file = config["key_file"] + security_token_file = config["security_token_file"] + private_key = oci.signer.load_private_key_from_file(key_file, None) + with open(security_token_file, encoding="utf-8") as file: + security_token = file.read() + client_kwargs["config"] = config + client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner( + security_token, private_key + ) + elif auth_type_upper == "INSTANCE_PRINCIPAL": + client_kwargs["signer"] = ( + oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + ) + elif auth_type_upper == "RESOURCE_PRINCIPAL": + client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() + else: + valid_types = [ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + raise ValueError( + f"Invalid OCI auth_type '{auth_type}'. 
Valid values: {valid_types}" + ) + + return client_kwargs diff --git a/lib/crewai/tests/llms/oci/__init__.py b/lib/crewai/tests/llms/oci/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/crewai/tests/llms/oci/conftest.py b/lib/crewai/tests/llms/oci/conftest.py new file mode 100644 index 00000000000..164f53060f6 --- /dev/null +++ b/lib/crewai/tests/llms/oci/conftest.py @@ -0,0 +1,189 @@ +"""Fixtures for OCI provider unit and integration tests.""" + +from __future__ import annotations + +import os +from typing import Any +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Fake OCI SDK module (replaces `import oci` in unit tests) +# --------------------------------------------------------------------------- + + +def _make_fake_oci_module() -> MagicMock: + """Build a lightweight mock of the OCI SDK surface used by OCICompletion.""" + oci = MagicMock() + + # Models namespace + models = oci.generative_ai_inference.models + + # Serving modes + models.OnDemandServingMode = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.DedicatedServingMode = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Content types + models.TextContent = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Message types + for cls_name in ( + "UserMessage", + "AssistantMessage", + "SystemMessage", + "CohereUserMessage", + "CohereSystemMessage", + "CohereChatBotMessage", + ): + setattr(models, cls_name, MagicMock(side_effect=lambda **kw: MagicMock(**kw))) + + # Request types + models.GenericChatRequest = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.CohereChatRequest = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.BaseChatRequest = MagicMock() + models.BaseChatRequest.API_FORMAT_GENERIC = "GENERIC" + models.BaseChatRequest.API_FORMAT_COHERE = "COHERE" + + # ChatDetails + models.ChatDetails = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Auth helpers + oci.config.from_file = MagicMock(return_value={"key_file": "/tmp/k", "security_token_file": "/tmp/t"}) + oci.signer.load_private_key_from_file = MagicMock(return_value="pk") + oci.auth.signers.SecurityTokenSigner = MagicMock() + oci.auth.signers.InstancePrincipalsSecurityTokenSigner = MagicMock() + oci.auth.signers.get_resource_principals_signer = MagicMock() + oci.retry.DEFAULT_RETRY_STRATEGY = "default_retry" + + # Client constructor + oci.generative_ai_inference.GenerativeAiInferenceClient = MagicMock() + + return oci + + +def _make_fake_chat_response(text: str = "Hello from OCI") -> MagicMock: + """Build a minimal OCI chat response for generic models.""" + text_part = MagicMock() + text_part.text = text + + message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 5 + usage.total_tokens = 15 + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_fake_cohere_chat_response(text: str = "Hello from Cohere") -> MagicMock: + """Build a minimal OCI chat response for Cohere models.""" + chat_response = MagicMock() + chat_response.text = text + chat_response.finish_reason = "COMPLETE" + chat_response.tool_calls = None + + usage = 
MagicMock() + usage.prompt_tokens = 8 + usage.completion_tokens = 4 + usage.total_tokens = 12 + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +@pytest.fixture() +def oci_fake_module() -> MagicMock: + return _make_fake_oci_module() + + +@pytest.fixture() +def patch_oci_module(monkeypatch: pytest.MonkeyPatch, oci_fake_module: MagicMock) -> MagicMock: + """Patch the OCI module import so no real SDK is needed.""" + monkeypatch.setattr( + "crewai.llms.providers.oci.completion._get_oci_module", + lambda: oci_fake_module, + ) + return oci_fake_module + + +@pytest.fixture() +def oci_response_factories() -> dict[str, Any]: + return { + "chat": _make_fake_chat_response, + "cohere_chat": _make_fake_cohere_chat_response, + } + + +# --------------------------------------------------------------------------- +# Unit test defaults +# --------------------------------------------------------------------------- + +@pytest.fixture() +def oci_unit_values() -> dict[str, str]: + return { + "compartment_id": "ocid1.compartment.oc1..test", + "model": "meta.llama-3.3-70b-instruct", + "cohere_model": "cohere.command-r-plus-08-2024", + } + + +# --------------------------------------------------------------------------- +# Integration test fixtures (live OCI calls) +# --------------------------------------------------------------------------- + +def _env_models(env_var: str, fallback_var: str, default: str) -> list[str]: + """Read model list from env, supporting comma-separated values.""" + raw = os.getenv(env_var) or os.getenv(fallback_var) or default + return [m.strip() for m in raw.split(",") if m.strip()] + + +def _skip_unless_live_config() -> dict[str, str]: + """Return live config dict or skip the test.""" + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set — skipping live test") + region = os.getenv("OCI_REGION") + endpoint = os.getenv("OCI_SERVICE_ENDPOINT") + if not region and not endpoint: + pytest.skip("Set OCI_REGION or OCI_SERVICE_ENDPOINT for live tests") + config: dict[str, str] = {"compartment_id": compartment} + if endpoint: + config["service_endpoint"] = endpoint + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + if os.getenv("OCI_AUTH_FILE_LOCATION"): + config["auth_file_location"] = os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config") + return config + + +@pytest.fixture( + params=_env_models("OCI_TEST_MODELS", "OCI_TEST_MODEL", "meta.llama-3.3-70b-instruct"), + ids=lambda m: m, +) +def oci_chat_model(request: pytest.FixtureRequest) -> str: + return request.param + + +@pytest.fixture() +def oci_live_config() -> dict[str, str]: + return _skip_unless_live_config() diff --git a/lib/crewai/tests/llms/oci/test_oci.py b/lib/crewai/tests/llms/oci/test_oci.py new file mode 100644 index 00000000000..a5f7efb1b66 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci.py @@ -0,0 +1,269 @@ +"""Unit tests for the OCI Generative AI provider (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Provider routing +# --------------------------------------------------------------------------- + + +def test_oci_completion_is_used_when_oci_provider(patch_oci_module): + """LLM(model='oci/...') 
should resolve to OCICompletion.""" + from crewai.llm import LLM + + fake_client = MagicMock() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + llm = LLM( + model="oci/meta.llama-3.3-70b-instruct", + compartment_id="ocid1.compartment.oc1..test", + ) + from crewai.llms.providers.oci.completion import OCICompletion + + # LLM.__new__ returns the native provider instance directly + assert isinstance(llm, OCICompletion) + + +@pytest.mark.parametrize( + "model_id, expected_provider", + [ + ("meta.llama-3.3-70b-instruct", "generic"), + ("google.gemini-2.5-flash", "generic"), + ("openai.gpt-4o", "generic"), + ("xai.grok-3", "generic"), + ("cohere.command-r-plus-08-2024", "cohere"), + ], +) +def test_oci_completion_infers_provider_family( + patch_oci_module, oci_unit_values, model_id, expected_provider +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=model_id, + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm.oci_provider == expected_provider + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +def test_oci_completion_initialization_parameters(patch_oci_module, oci_unit_values): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + temperature=0.7, + max_tokens=512, + top_p=0.9, + top_k=40, + ) + assert llm.temperature == 0.7 + assert llm.max_tokens == 512 + assert llm.top_p == 0.9 + assert llm.top_k == 40 + assert llm.compartment_id == oci_unit_values["compartment_id"] + + +def test_oci_completion_uses_region_to_build_endpoint(patch_oci_module, oci_unit_values, monkeypatch): + from crewai.llms.providers.oci.completion import OCICompletion + + monkeypatch.delenv("OCI_SERVICE_ENDPOINT", raising=False) + monkeypatch.setenv("OCI_REGION", "us-ashburn-1") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert "us-ashburn-1" in llm.service_endpoint + + +# --------------------------------------------------------------------------- +# Basic call +# --------------------------------------------------------------------------- + + +def test_oci_completion_call_uses_chat_api( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]("test response") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call(messages=[{"role": "user", "content": "Say hello"}]) + + assert "test response" in result + fake_client.chat.assert_called_once() + + +def test_oci_completion_cohere_call( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + 
fake_client.chat.return_value = oci_response_factories["cohere_chat"]("cohere reply") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call(messages=[{"role": "user", "content": "Hi"}]) + + assert "cohere reply" in result + fake_client.chat.assert_called_once() + + +# --------------------------------------------------------------------------- +# Message normalization +# --------------------------------------------------------------------------- + + +def test_oci_completion_treats_none_content_as_empty_text( + patch_oci_module, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm._coerce_text(None) == "" + + +def test_oci_completion_call_normalizes_messages_once( + patch_oci_module, oci_response_factories, oci_unit_values +): + """Ensure normalize is not called twice when _call_impl receives a list.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + call_count = 0 + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + original_normalize = llm._normalize_messages + + def counting_normalize(msgs): + nonlocal call_count + call_count += 1 + return original_normalize(msgs) + + llm._normalize_messages = counting_normalize + + llm.call(messages=[{"role": "user", "content": "hi"}]) + # call() normalizes once, _call_impl should not normalize again + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# OpenAI model quirks +# --------------------------------------------------------------------------- + + +def test_oci_openai_models_use_max_completion_tokens( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model="openai.gpt-4o", + compartment_id=oci_unit_values["compartment_id"], + max_tokens=1024, + ) + request = llm._build_chat_request([{"role": "user", "content": "test"}]) + + models = patch_oci_module.generative_ai_inference.models + call_kwargs = models.GenericChatRequest.call_args + assert call_kwargs is not None + kwargs = call_kwargs[1] if call_kwargs[1] else {} + assert kwargs.get("max_completion_tokens") == 1024 + assert "max_tokens" not in kwargs + + +def test_oci_openai_gpt5_omits_unsupported_temperature_and_stop( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model="openai.gpt-5", + 
compartment_id=oci_unit_values["compartment_id"], + temperature=0.5, + ) + llm.stop = ["END"] + llm._build_chat_request([{"role": "user", "content": "test"}]) + + models = patch_oci_module.generative_ai_inference.models + call_kwargs = models.GenericChatRequest.call_args[1] + assert "temperature" not in call_kwargs + assert "stop" not in call_kwargs + + +# --------------------------------------------------------------------------- +# Async +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_oci_completion_acall_delegates_to_call( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]("async result") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = await llm.acall(messages=[{"role": "user", "content": "async test"}]) + + assert "async result" in result diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_basic.py b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py new file mode 100644 index 00000000000..adeb0f2ccef --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py @@ -0,0 +1,33 @@ +"""Live integration tests for OCI Generative AI basic text completion. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + OCI_TEST_MODELS="meta.llama-3.3-70b-instruct,cohere.command-r-plus-08-2024,google.gemini-2.5-flash" \ + uv run pytest tests/llms/oci/test_oci_integration_basic.py -v +""" + +from __future__ import annotations + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def test_oci_live_basic_call(oci_chat_model: str, oci_live_config: dict): + """Synchronous text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = llm.call(messages=[{"role": "user", "content": "Say 'hello world' in one sentence."}]) + + assert isinstance(result, str) + assert len(result) > 0 + + +@pytest.mark.asyncio +async def test_oci_live_async_call(oci_chat_model: str, oci_live_config: dict): + """Async text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = await llm.acall(messages=[{"role": "user", "content": "What is 2+2? Answer in one word."}]) + + assert isinstance(result, str) + assert len(result) > 0 diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py b/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py new file mode 100644 index 00000000000..db41b1f3a97 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py @@ -0,0 +1,40 @@ +"""Live integration tests for OCI Generative AI streaming. 
+ +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + OCI_TEST_MODELS="meta.llama-3.3-70b-instruct,cohere.command-r-plus-08-2024" \ + uv run pytest tests/llms/oci/test_oci_integration_streaming.py -v +""" + +from __future__ import annotations + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def test_oci_live_streaming_call(oci_chat_model: str, oci_live_config: dict): + """Streaming text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, stream=True, **oci_live_config) + result = llm.call( + messages=[{"role": "user", "content": "Count from 1 to 5, one per line."}] + ) + + assert isinstance(result, str) + assert len(result) > 0 + + +@pytest.mark.asyncio +async def test_oci_live_astream(oci_chat_model: str, oci_live_config: dict): + """Async streaming should yield text chunks from a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + chunks: list[str] = [] + async for chunk in llm.astream( + messages=[{"role": "user", "content": "Say hello in three words."}] + ): + chunks.append(chunk) + + assert len(chunks) > 0 + full_text = "".join(chunks) + assert len(full_text) > 0 diff --git a/lib/crewai/tests/llms/oci/test_oci_streaming.py b/lib/crewai/tests/llms/oci/test_oci_streaming.py new file mode 100644 index 00000000000..721fa3f671a --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_streaming.py @@ -0,0 +1,150 @@ +"""Unit tests for OCI provider streaming (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +def _make_fake_stream_event(text: str = "", finish_reason: str | None = None, usage: dict | None = None) -> MagicMock: + """Build a single SSE event with optional text, finish, and usage.""" + payload: dict = {} + if text: + payload["message"] = {"content": [{"text": text}]} + if finish_reason: + payload["finishReason"] = finish_reason + if usage: + payload["usage"] = usage + + import json + event = MagicMock() + event.data = json.dumps(payload) + return event + + +def _make_fake_stream_response(*events: MagicMock) -> MagicMock: + """Wrap events into a response.data.events() iterable.""" + response = MagicMock() + response.data.events.return_value = iter(events) + return response + + +def test_oci_completion_streams_generic_responses( + patch_oci_module, oci_unit_values +): + """Streaming call should accumulate text chunks and return full response.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="Hello "), + _make_fake_stream_event(text="world"), + _make_fake_stream_event( + finish_reason="stop", + usage={"promptTokens": 5, "completionTokens": 2, "totalTokens": 7}, + ), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + # StreamOptions mock + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + stream=True, + ) + result = llm.call(messages=[{"role": "user", "content": "Say hello"}]) + + assert "Hello " in result + assert "world" in result + assert llm.last_response_metadata is not None + assert llm.last_response_metadata.get("finish_reason") == "stop" + + +def 
test_oci_iter_stream_yields_text_chunks( + patch_oci_module, oci_unit_values +): + """iter_stream should yield individual text chunks.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="chunk1"), + _make_fake_stream_event(text="chunk2"), + _make_fake_stream_event( + usage={"promptTokens": 3, "completionTokens": 2, "totalTokens": 5}, + ), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + chunks = list(llm.iter_stream(messages=[{"role": "user", "content": "test"}])) + + assert chunks == ["chunk1", "chunk2"] + assert llm.last_response_metadata is not None + assert llm.last_response_metadata["usage"]["total_tokens"] == 5 + + +@pytest.mark.asyncio +async def test_oci_astream_yields_text_chunks( + patch_oci_module, oci_unit_values +): + """astream should yield chunks via async generator.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="async1"), + _make_fake_stream_event(text="async2"), + _make_fake_stream_event(), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + chunks = [] + async for chunk in llm.astream(messages=[{"role": "user", "content": "test"}]): + chunks.append(chunk) + + assert chunks == ["async1", "async2"] + + +def test_oci_stream_chat_events_holds_client_lock( + patch_oci_module, oci_unit_values +): + """_stream_chat_events should hold the client lock for the full iteration.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [_make_fake_stream_event(text="a"), _make_fake_stream_event(text="b")] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + # Before streaming, ticket should be 0 + assert llm._active_client_ticket == 0 + chat_details = MagicMock() + list(llm._stream_chat_events(chat_details)) + # After streaming completes, ticket should have advanced + assert llm._active_client_ticket == 1
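
Reviewer note: a minimal usage sketch of the new provider, mirroring what the unit and integration tests in this diff exercise. It assumes the optional dependency added by the new `oci` extra is installed (e.g. `uv add "crewai[oci]"`) and that a valid API_KEY profile exists in `~/.oci/config`; the model ID, compartment OCID, and region below are placeholders, not values from this PR.

# usage_sketch.py — illustrative only; values are placeholders.
import os

from crewai.llm import LLM

# The region is only used to derive the inference endpoint when
# OCI_SERVICE_ENDPOINT is not set (see OCICompletion.__init__).
os.environ.setdefault("OCI_REGION", "us-chicago-1")

# The "oci/" prefix routes through the provider table added in llm.py to the
# native OCICompletion backend; compartment_id may also come from
# OCI_COMPARTMENT_ID. Extra kwargs (temperature, max_tokens, stream, ...) are
# forwarded to the provider, as the unit tests show.
llm = LLM(
    model="oci/meta.llama-3.3-70b-instruct",  # placeholder on-demand model ID
    compartment_id="ocid1.compartment.oc1..example",  # placeholder OCID
    temperature=0.2,
    max_tokens=512,
    stream=False,  # set True to exercise the SSE streaming path
)

result = llm.call(
    messages=[{"role": "user", "content": "Say hello in one sentence."}]
)
print(result)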