From ff0c8d249086bd2b79051954b418bb7411d773b1 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 13:41:13 -0400 Subject: [PATCH 1/8] =?UTF-8?q?feat:=20add=20OCI=20Generative=20AI=20provi?= =?UTF-8?q?der=20=E2=80=94=20basic=20text=20completion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add native OCI Generative AI support to CrewAI with basic text completion for generic (Meta, Google, OpenAI, xAI) and Cohere model families. This is the first in a series of PRs to incrementally build out full OCI support (streaming, tool calling, structured output, embeddings, and multimodal in follow-up PRs). Tracking issue: #4944 Supersedes: #4885 --- lib/crewai/pyproject.toml | 3 + lib/crewai/src/crewai/llm.py | 13 + .../src/crewai/llms/providers/oci/__init__.py | 5 + .../crewai/llms/providers/oci/completion.py | 505 ++++++++++++++++++ lib/crewai/src/crewai/utilities/oci.py | 72 +++ lib/crewai/tests/llms/oci/__init__.py | 0 lib/crewai/tests/llms/oci/conftest.py | 189 +++++++ lib/crewai/tests/llms/oci/test_oci.py | 269 ++++++++++ .../llms/oci/test_oci_integration_basic.py | 33 ++ 9 files changed, 1089 insertions(+) create mode 100644 lib/crewai/src/crewai/llms/providers/oci/__init__.py create mode 100644 lib/crewai/src/crewai/llms/providers/oci/completion.py create mode 100644 lib/crewai/src/crewai/utilities/oci.py create mode 100644 lib/crewai/tests/llms/oci/__init__.py create mode 100644 lib/crewai/tests/llms/oci/conftest.py create mode 100644 lib/crewai/tests/llms/oci/test_oci.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_basic.py diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index 528f4ea2c51..d5060e3a5be 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -98,6 +98,9 @@ azure-ai-inference = [ anthropic = [ "anthropic~=0.73.0", ] +oci = [ + "oci>=2.168.0", +] a2a = [ "a2a-sdk~=0.3.10", "httpx-auth~=0.23.1", diff --git 
a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index e6f5cc68b06..61b033994c7 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -249,6 +249,7 @@ "hosted_vllm", "cerebras", "dashscope", + "oci", ] @@ -338,6 +339,7 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: "hosted_vllm": "hosted_vllm", "cerebras": "cerebras", "dashscope": "dashscope", + "oci": "oci", } canonical_provider = provider_mapping.get(prefix.lower()) @@ -457,6 +459,9 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: # OpenRouter uses org/model format but accepts anything return True + if provider == "oci": + return model_lower.startswith("ocid1.generativeaiendpoint") or "." in model + return False @classmethod @@ -492,6 +497,9 @@ def _validate_model_in_constants(cls, model: str, provider: str) -> bool: # azure does not provide a list of available models, determine a better way to handle this return True + if provider == "oci": + return cls._matches_provider_pattern(model, provider) + # Fallback to pattern matching for models not in constants return cls._matches_provider_pattern(model, provider) @@ -573,6 +581,11 @@ def _get_native_provider(cls, provider: str) -> type | None: return OpenAICompatibleCompletion + if provider == "oci": + from crewai.llms.providers.oci.completion import OCICompletion + + return OCICompletion + return None @model_validator(mode="before") diff --git a/lib/crewai/src/crewai/llms/providers/oci/__init__.py b/lib/crewai/src/crewai/llms/providers/oci/__init__.py new file mode 100644 index 00000000000..0c397558bd3 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/__init__.py @@ -0,0 +1,5 @@ +from crewai.llms.providers.oci.completion import OCICompletion + +__all__ = [ + "OCICompletion", +] diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py new file mode 100644 index 00000000000..8c05e8caae9 --- 
/dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -0,0 +1,505 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from contextlib import contextmanager +import json +import logging +import os +import re +import threading +from typing import TYPE_CHECKING, Any, Literal, cast + +from pydantic import BaseModel + +from crewai.events.types.llm_events import LLMCallType +from crewai.llms.base_llm import BaseLLM, llm_call_context +from crewai.utilities.oci import create_oci_client_kwargs, get_oci_module +from crewai.utilities.types import LLMMessage + + +if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task + from crewai.tools.base_tool import BaseTool + + +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" +DEFAULT_OCI_REGION = "us-chicago-1" +_OCI_SCHEMA_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") + + +def _get_oci_module() -> Any: + """Backward-compatible module-local alias used by tests and patches.""" + return get_oci_module() + + +class OCICompletion(BaseLLM): + """OCI Generative AI native provider for CrewAI. + + Supports basic text completions for generic (Meta, Google, OpenAI, xAI) + and Cohere model families hosted on the OCI Generative AI service. 
+ """ + + def __init__( + self, + model: str, + *, + compartment_id: str | None = None, + service_endpoint: str | None = None, + auth_type: Literal[ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + | str = "API_KEY", + auth_profile: str | None = None, + auth_file_location: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + top_p: float | None = None, + top_k: int | None = None, + oci_provider: str | None = None, + client: Any | None = None, + **kwargs: Any, + ) -> None: + kwargs.pop("provider", None) + super().__init__( + model=model, + temperature=temperature, + provider="oci", + **kwargs, + ) + + self.compartment_id = compartment_id or os.getenv("OCI_COMPARTMENT_ID") + if not self.compartment_id: + raise ValueError( + "OCI compartment_id is required. Set compartment_id or OCI_COMPARTMENT_ID." + ) + + self.service_endpoint = service_endpoint or os.getenv("OCI_SERVICE_ENDPOINT") + if self.service_endpoint is None: + region = os.getenv("OCI_REGION", DEFAULT_OCI_REGION) + self.service_endpoint = ( + f"https://inference.generativeai.{region}.oci.oraclecloud.com" + ) + + self.auth_type = str(auth_type).upper() + self.auth_profile = cast( + str, auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + ) + self.auth_file_location = cast( + str, + auth_file_location + or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + ) + self.max_tokens = max_tokens + self.top_p = top_p + self.top_k = top_k + self.oci_provider = oci_provider or self._infer_provider(model) + self._oci = _get_oci_module() + + if client is not None: + self.client = client + else: + client_kwargs = create_oci_client_kwargs( + auth_type=self.auth_type, + service_endpoint=self.service_endpoint, + auth_file_location=self.auth_file_location, + auth_profile=self.auth_profile, + timeout=(10, 240), + oci_module=self._oci, + ) + self.client = self._oci.generative_ai_inference.GenerativeAiInferenceClient( + **client_kwargs + ) + 
self._client_condition = threading.Condition() + self._next_client_ticket = 0 + self._active_client_ticket = 0 + self.last_response_metadata = None + + # ------------------------------------------------------------------ + # Provider inference + # ------------------------------------------------------------------ + + def _infer_provider(self, model: str) -> str: + if model.startswith(CUSTOM_ENDPOINT_PREFIX): + return "generic" + if model.startswith("cohere."): + return "cohere" + return "generic" + + def _is_openai_gpt5_family(self) -> bool: + return self.model.startswith("openai.gpt-5") + + def _build_serving_mode(self) -> Any: + models = self._oci.generative_ai_inference.models + if self.model.startswith(CUSTOM_ENDPOINT_PREFIX): + return models.DedicatedServingMode(endpoint_id=self.model) + return models.OnDemandServingMode(model_id=self.model) + + # ------------------------------------------------------------------ + # Message helpers + # ------------------------------------------------------------------ + + def _normalize_messages( + self, messages: str | list[LLMMessage] + ) -> list[LLMMessage]: + return self._format_messages(messages) + + def _coerce_text(self, content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, Mapping): + if item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif "text" in item: + parts.append(str(item["text"])) + return "\n".join(part for part in parts if part) + return str(content) + + def _build_generic_content(self, content: Any) -> list[Any]: + """Translate CrewAI message content into OCI generic content objects.""" + models = self._oci.generative_ai_inference.models + if isinstance(content, str): + return [models.TextContent(text=content or ".")] + + if not isinstance(content, list): + return 
[models.TextContent(text=self._coerce_text(content) or ".")] + + processed: list[Any] = [] + for item in content: + if isinstance(item, str): + processed.append(models.TextContent(text=item)) + elif isinstance(item, Mapping) and item.get("type") == "text": + processed.append( + models.TextContent(text=str(item.get("text", "")) or ".") + ) + else: + processed.append( + models.TextContent(text=self._coerce_text(item) or ".") + ) + return processed or [models.TextContent(text=".")] + + def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: + """Map CrewAI conversation messages into OCI generic chat messages.""" + models = self._oci.generative_ai_inference.models + role_map = { + "user": models.UserMessage, + "assistant": models.AssistantMessage, + "system": models.SystemMessage, + } + oci_messages: list[Any] = [] + + for message in messages: + role = str(message.get("role", "user")).lower() + message_cls = role_map.get(role) + if message_cls is None: + logging.debug("Skipping unsupported OCI message role: %s", role) + continue + oci_messages.append( + message_cls( + content=self._build_generic_content(message.get("content", "")), + ) + ) + + return oci_messages + + def _build_cohere_chat_history( + self, messages: list[LLMMessage] + ) -> tuple[list[Any], str]: + """Translate CrewAI messages into Cohere's split history + message shape.""" + models = self._oci.generative_ai_inference.models + chat_history: list[Any] = [] + + for message in messages[:-1]: + role = str(message.get("role", "user")).lower() + content = message.get("content", "") + + if role in ("user", "system"): + message_cls = ( + models.CohereUserMessage + if role == "user" + else models.CohereSystemMessage + ) + chat_history.append(message_cls(message=self._coerce_text(content))) + elif role == "assistant": + chat_history.append( + models.CohereChatBotMessage( + message=self._coerce_text(content) or " ", + ) + ) + + last_message = messages[-1] if messages else {"role": "user", 
"content": ""} + message_text = self._coerce_text(last_message.get("content", "")) + return chat_history, message_text + + # ------------------------------------------------------------------ + # Request building + # ------------------------------------------------------------------ + + def _build_chat_request( + self, + messages: list[LLMMessage], + ) -> Any: + """Build the provider-specific OCI chat request for the current model.""" + models = self._oci.generative_ai_inference.models + + if self.oci_provider == "cohere": + chat_history, message_text = self._build_cohere_chat_history(messages) + request_kwargs: dict[str, Any] = { + "message": message_text, + "chat_history": chat_history, + "api_format": models.BaseChatRequest.API_FORMAT_COHERE, + } + else: + request_kwargs = { + "messages": self._build_generic_messages(messages), + "api_format": models.BaseChatRequest.API_FORMAT_GENERIC, + } + + if self.temperature is not None and not self._is_openai_gpt5_family(): + request_kwargs["temperature"] = self.temperature + if self.max_tokens is not None: + if self.oci_provider == "generic" and self.model.startswith("openai."): + request_kwargs["max_completion_tokens"] = self.max_tokens + else: + request_kwargs["max_tokens"] = self.max_tokens + if self.top_p is not None: + request_kwargs["top_p"] = self.top_p + if self.top_k is not None: + request_kwargs["top_k"] = self.top_k + + if self.stop and not self._is_openai_gpt5_family(): + stop_key = "stop_sequences" if self.oci_provider == "cohere" else "stop" + request_kwargs[stop_key] = list(self.stop) + + if self.oci_provider == "cohere": + return models.CohereChatRequest(**request_kwargs) + return models.GenericChatRequest(**request_kwargs) + + # ------------------------------------------------------------------ + # Response extraction + # ------------------------------------------------------------------ + + def _extract_text(self, response: Any) -> str: + chat_response = response.data.chat_response + if self.oci_provider 
== "cohere": + if getattr(chat_response, "text", None): + return chat_response.text or "" + message = getattr(chat_response, "message", None) + if message is not None: + content = getattr(message, "content", None) or [] + return "".join( + part.text for part in content if getattr(part, "text", None) + ) + return "" + + choices = getattr(chat_response, "choices", None) or [] + if not choices: + return "" + message = getattr(choices[0], "message", None) + if message is None: + return "" + content = getattr(message, "content", None) or [] + return "".join(part.text for part in content if getattr(part, "text", None)) + + def _extract_usage(self, response: Any) -> dict[str, int]: + chat_response = response.data.chat_response + usage = getattr(chat_response, "usage", None) + if usage is None: + return {} + return { + "prompt_tokens": getattr(usage, "prompt_tokens", 0), + "completion_tokens": getattr(usage, "completion_tokens", 0), + "total_tokens": getattr(usage, "total_tokens", 0), + } + + def _extract_response_metadata(self, response: Any) -> dict[str, Any]: + chat_response = response.data.chat_response + metadata: dict[str, Any] = {} + + finish_reason = getattr(chat_response, "finish_reason", None) + if finish_reason is None: + choices = getattr(chat_response, "choices", None) or [] + if choices: + finish_reason = getattr(choices[0], "finish_reason", None) + + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + + usage = self._extract_usage(response) + if usage: + metadata["usage"] = usage + + return metadata + + # ------------------------------------------------------------------ + # Call paths + # ------------------------------------------------------------------ + + def _finalize_text_response( + self, + *, + content: str, + messages: list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + content = self._apply_stop_words(content) + content = self._invoke_after_llm_call_hooks(messages, content, from_agent) + 
self._emit_call_completed_event( + response=content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=messages, + ) + return content + + def _call_impl( + self, + *, + messages: str | list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + normalized_messages = ( + messages if isinstance(messages, list) else self._normalize_messages(messages) + ) + chat_request = self._build_chat_request(normalized_messages) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + response = self._chat(chat_details) + usage = self._extract_usage(response) + if usage: + self._track_token_usage_internal(usage) + self.last_response_metadata = self._extract_response_metadata(response) or None + + content = self._extract_text(response) + return self._finalize_text_response( + content=content, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + def call( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | BaseModel | list[dict[str, Any]]: + normalized_messages = self._normalize_messages(messages) + + with llm_call_context(): + try: + self._emit_call_started_event( + messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + if not self._invoke_before_llm_call_hooks( + normalized_messages, from_agent + ): + raise ValueError("LLM call blocked by before_llm_call hook") + + return self._call_impl( + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + except 
Exception as error: + error_message = f"OCI Generative AI call failed: {error!s}" + self._emit_call_failed_event( + error=error_message, + from_task=from_task, + from_agent=from_agent, + ) + raise + + async def acall( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + return await asyncio.to_thread( + self.call, + messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + # ------------------------------------------------------------------ + # Client serialization + # ------------------------------------------------------------------ + + def _chat(self, chat_details: Any) -> Any: + with self._ordered_client_access(): + return self.client.chat(chat_details) + + @contextmanager + def _ordered_client_access(self) -> Any: + """Serialize shared OCI client access in call-arrival order.""" + with self._client_condition: + ticket = self._next_client_ticket + self._next_client_ticket += 1 + while ticket != self._active_client_ticket: + self._client_condition.wait() + + try: + yield + finally: + with self._client_condition: + self._active_client_ticket += 1 + self._client_condition.notify_all() + + # ------------------------------------------------------------------ + # Capability declarations + # ------------------------------------------------------------------ + + def supports_function_calling(self) -> bool: + return True + + def supports_stop_words(self) -> bool: + return True + + def get_context_window_size(self) -> int: + model_lower = self.model.lower() + if model_lower.startswith("google.gemini"): + return 1048576 + if model_lower.startswith("openai."): + return 200000 + if 
model_lower.startswith("cohere."): + return 128000 + if model_lower.startswith("meta."): + return 131072 + return 131072 diff --git a/lib/crewai/src/crewai/utilities/oci.py b/lib/crewai/src/crewai/utilities/oci.py new file mode 100644 index 00000000000..7935530690c --- /dev/null +++ b/lib/crewai/src/crewai/utilities/oci.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any + + +def get_oci_module() -> Any: + """Import the OCI SDK lazily for optional CrewAI OCI integrations.""" + try: + import oci # type: ignore[import-untyped] + except ImportError: + raise ImportError( + 'OCI support is not available, to install: uv add "crewai[oci]"' + ) from None + return oci + + +def create_oci_client_kwargs( + *, + auth_type: str, + service_endpoint: str | None, + auth_file_location: str, + auth_profile: str, + timeout: tuple[int, int], + oci_module: Any | None = None, +) -> dict[str, Any]: + """Build OCI SDK client kwargs for the supported auth modes.""" + oci = oci_module or get_oci_module() + client_kwargs: dict[str, Any] = { + "config": {}, + "service_endpoint": service_endpoint, + "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, + "timeout": timeout, + } + + auth_type_upper = auth_type.upper() + if auth_type_upper == "API_KEY": + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + elif auth_type_upper == "SECURITY_TOKEN": + config = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + key_file = config["key_file"] + security_token_file = config["security_token_file"] + private_key = oci.signer.load_private_key_from_file(key_file, None) + with open(security_token_file, encoding="utf-8") as file: + security_token = file.read() + client_kwargs["config"] = config + client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner( + security_token, private_key + ) + elif auth_type_upper == "INSTANCE_PRINCIPAL": + client_kwargs["signer"] = ( + 
oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + ) + elif auth_type_upper == "RESOURCE_PRINCIPAL": + client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() + else: + valid_types = [ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + raise ValueError( + f"Invalid OCI auth_type '{auth_type}'. Valid values: {valid_types}" + ) + + return client_kwargs diff --git a/lib/crewai/tests/llms/oci/__init__.py b/lib/crewai/tests/llms/oci/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/lib/crewai/tests/llms/oci/conftest.py b/lib/crewai/tests/llms/oci/conftest.py new file mode 100644 index 00000000000..164f53060f6 --- /dev/null +++ b/lib/crewai/tests/llms/oci/conftest.py @@ -0,0 +1,189 @@ +"""Fixtures for OCI provider unit and integration tests.""" + +from __future__ import annotations + +import os +from typing import Any +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Fake OCI SDK module (replaces `import oci` in unit tests) +# --------------------------------------------------------------------------- + + +def _make_fake_oci_module() -> MagicMock: + """Build a lightweight mock of the OCI SDK surface used by OCICompletion.""" + oci = MagicMock() + + # Models namespace + models = oci.generative_ai_inference.models + + # Serving modes + models.OnDemandServingMode = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.DedicatedServingMode = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Content types + models.TextContent = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Message types + for cls_name in ( + "UserMessage", + "AssistantMessage", + "SystemMessage", + "CohereUserMessage", + "CohereSystemMessage", + "CohereChatBotMessage", + ): + setattr(models, cls_name, MagicMock(side_effect=lambda **kw: MagicMock(**kw))) + + # Request types + models.GenericChatRequest = 
MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.CohereChatRequest = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.BaseChatRequest = MagicMock() + models.BaseChatRequest.API_FORMAT_GENERIC = "GENERIC" + models.BaseChatRequest.API_FORMAT_COHERE = "COHERE" + + # ChatDetails + models.ChatDetails = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Auth helpers + oci.config.from_file = MagicMock(return_value={"key_file": "/tmp/k", "security_token_file": "/tmp/t"}) + oci.signer.load_private_key_from_file = MagicMock(return_value="pk") + oci.auth.signers.SecurityTokenSigner = MagicMock() + oci.auth.signers.InstancePrincipalsSecurityTokenSigner = MagicMock() + oci.auth.signers.get_resource_principals_signer = MagicMock() + oci.retry.DEFAULT_RETRY_STRATEGY = "default_retry" + + # Client constructor + oci.generative_ai_inference.GenerativeAiInferenceClient = MagicMock() + + return oci + + +def _make_fake_chat_response(text: str = "Hello from OCI") -> MagicMock: + """Build a minimal OCI chat response for generic models.""" + text_part = MagicMock() + text_part.text = text + + message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 5 + usage.total_tokens = 15 + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_fake_cohere_chat_response(text: str = "Hello from Cohere") -> MagicMock: + """Build a minimal OCI chat response for Cohere models.""" + chat_response = MagicMock() + chat_response.text = text + chat_response.finish_reason = "COMPLETE" + chat_response.tool_calls = None + + usage = MagicMock() + usage.prompt_tokens = 8 + usage.completion_tokens = 4 + usage.total_tokens 
= 12 + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +@pytest.fixture() +def oci_fake_module() -> MagicMock: + return _make_fake_oci_module() + + +@pytest.fixture() +def patch_oci_module(monkeypatch: pytest.MonkeyPatch, oci_fake_module: MagicMock) -> MagicMock: + """Patch the OCI module import so no real SDK is needed.""" + monkeypatch.setattr( + "crewai.llms.providers.oci.completion._get_oci_module", + lambda: oci_fake_module, + ) + return oci_fake_module + + +@pytest.fixture() +def oci_response_factories() -> dict[str, Any]: + return { + "chat": _make_fake_chat_response, + "cohere_chat": _make_fake_cohere_chat_response, + } + + +# --------------------------------------------------------------------------- +# Unit test defaults +# --------------------------------------------------------------------------- + +@pytest.fixture() +def oci_unit_values() -> dict[str, str]: + return { + "compartment_id": "ocid1.compartment.oc1..test", + "model": "meta.llama-3.3-70b-instruct", + "cohere_model": "cohere.command-r-plus-08-2024", + } + + +# --------------------------------------------------------------------------- +# Integration test fixtures (live OCI calls) +# --------------------------------------------------------------------------- + +def _env_models(env_var: str, fallback_var: str, default: str) -> list[str]: + """Read model list from env, supporting comma-separated values.""" + raw = os.getenv(env_var) or os.getenv(fallback_var) or default + return [m.strip() for m in raw.split(",") if m.strip()] + + +def _skip_unless_live_config() -> dict[str, str]: + """Return live config dict or skip the test.""" + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set — skipping live test") + region = os.getenv("OCI_REGION") + endpoint = os.getenv("OCI_SERVICE_ENDPOINT") + if not region and not endpoint: + pytest.skip("Set OCI_REGION or 
OCI_SERVICE_ENDPOINT for live tests") + config: dict[str, str] = {"compartment_id": compartment} + if endpoint: + config["service_endpoint"] = endpoint + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + if os.getenv("OCI_AUTH_FILE_LOCATION"): + config["auth_file_location"] = os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config") + return config + + +@pytest.fixture( + params=_env_models("OCI_TEST_MODELS", "OCI_TEST_MODEL", "meta.llama-3.3-70b-instruct"), + ids=lambda m: m, +) +def oci_chat_model(request: pytest.FixtureRequest) -> str: + return request.param + + +@pytest.fixture() +def oci_live_config() -> dict[str, str]: + return _skip_unless_live_config() diff --git a/lib/crewai/tests/llms/oci/test_oci.py b/lib/crewai/tests/llms/oci/test_oci.py new file mode 100644 index 00000000000..a5f7efb1b66 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci.py @@ -0,0 +1,269 @@ +"""Unit tests for the OCI Generative AI provider (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Provider routing +# --------------------------------------------------------------------------- + + +def test_oci_completion_is_used_when_oci_provider(patch_oci_module): + """LLM(model='oci/...') should resolve to OCICompletion.""" + from crewai.llm import LLM + + fake_client = MagicMock() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + llm = LLM( + model="oci/meta.llama-3.3-70b-instruct", + compartment_id="ocid1.compartment.oc1..test", + ) + from crewai.llms.providers.oci.completion import OCICompletion + + # LLM.__new__ returns the native provider instance directly + assert isinstance(llm, OCICompletion) + + +@pytest.mark.parametrize( + "model_id, 
expected_provider", + [ + ("meta.llama-3.3-70b-instruct", "generic"), + ("google.gemini-2.5-flash", "generic"), + ("openai.gpt-4o", "generic"), + ("xai.grok-3", "generic"), + ("cohere.command-r-plus-08-2024", "cohere"), + ], +) +def test_oci_completion_infers_provider_family( + patch_oci_module, oci_unit_values, model_id, expected_provider +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=model_id, + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm.oci_provider == expected_provider + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +def test_oci_completion_initialization_parameters(patch_oci_module, oci_unit_values): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + temperature=0.7, + max_tokens=512, + top_p=0.9, + top_k=40, + ) + assert llm.temperature == 0.7 + assert llm.max_tokens == 512 + assert llm.top_p == 0.9 + assert llm.top_k == 40 + assert llm.compartment_id == oci_unit_values["compartment_id"] + + +def test_oci_completion_uses_region_to_build_endpoint(patch_oci_module, oci_unit_values, monkeypatch): + from crewai.llms.providers.oci.completion import OCICompletion + + monkeypatch.delenv("OCI_SERVICE_ENDPOINT", raising=False) + monkeypatch.setenv("OCI_REGION", "us-ashburn-1") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert "us-ashburn-1" in llm.service_endpoint + + +# 
--------------------------------------------------------------------------- +# Basic call +# --------------------------------------------------------------------------- + + +def test_oci_completion_call_uses_chat_api( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]("test response") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call(messages=[{"role": "user", "content": "Say hello"}]) + + assert "test response" in result + fake_client.chat.assert_called_once() + + +def test_oci_completion_cohere_call( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["cohere_chat"]("cohere reply") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call(messages=[{"role": "user", "content": "Hi"}]) + + assert "cohere reply" in result + fake_client.chat.assert_called_once() + + +# --------------------------------------------------------------------------- +# Message normalization +# --------------------------------------------------------------------------- + + +def test_oci_completion_treats_none_content_as_empty_text( + patch_oci_module, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + 
model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm._coerce_text(None) == "" + + +def test_oci_completion_call_normalizes_messages_once( + patch_oci_module, oci_response_factories, oci_unit_values +): + """Ensure normalize is not called twice when _call_impl receives a list.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + call_count = 0 + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + original_normalize = llm._normalize_messages + + def counting_normalize(msgs): + nonlocal call_count + call_count += 1 + return original_normalize(msgs) + + llm._normalize_messages = counting_normalize + + llm.call(messages=[{"role": "user", "content": "hi"}]) + # call() normalizes once, _call_impl should not normalize again + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# OpenAI model quirks +# --------------------------------------------------------------------------- + + +def test_oci_openai_models_use_max_completion_tokens( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model="openai.gpt-4o", + compartment_id=oci_unit_values["compartment_id"], + max_tokens=1024, + ) + request = llm._build_chat_request([{"role": "user", "content": "test"}]) + + models = patch_oci_module.generative_ai_inference.models + call_kwargs = models.GenericChatRequest.call_args + assert call_kwargs is not 
None + kwargs = call_kwargs[1] if call_kwargs[1] else {} + assert kwargs.get("max_completion_tokens") == 1024 + assert "max_tokens" not in kwargs + + +def test_oci_openai_gpt5_omits_unsupported_temperature_and_stop( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model="openai.gpt-5", + compartment_id=oci_unit_values["compartment_id"], + temperature=0.5, + ) + llm.stop = ["END"] + llm._build_chat_request([{"role": "user", "content": "test"}]) + + models = patch_oci_module.generative_ai_inference.models + call_kwargs = models.GenericChatRequest.call_args[1] + assert "temperature" not in call_kwargs + assert "stop" not in call_kwargs + + +# --------------------------------------------------------------------------- +# Async +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_oci_completion_acall_delegates_to_call( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]("async result") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = await llm.acall(messages=[{"role": "user", "content": "async test"}]) + + assert "async result" in result diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_basic.py b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py new file mode 100644 index 00000000000..adeb0f2ccef --- /dev/null +++ 
b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py @@ -0,0 +1,33 @@ +"""Live integration tests for OCI Generative AI basic text completion. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID=<your-compartment-ocid> OCI_REGION=us-chicago-1 \ + OCI_TEST_MODELS="meta.llama-3.3-70b-instruct,cohere.command-r-plus-08-2024,google.gemini-2.5-flash" \ + uv run pytest tests/llms/oci/test_oci_integration_basic.py -v +""" + +from __future__ import annotations + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def test_oci_live_basic_call(oci_chat_model: str, oci_live_config: dict): + """Synchronous text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = llm.call(messages=[{"role": "user", "content": "Say 'hello world' in one sentence."}]) + + assert isinstance(result, str) + assert len(result) > 0 + + +@pytest.mark.asyncio +async def test_oci_live_async_call(oci_chat_model: str, oci_live_config: dict): + """Async text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = await llm.acall(messages=[{"role": "user", "content": "What is 2+2? Answer in one word."}]) + + assert isinstance(result, str) + assert len(result) > 0 From 0689b9f7cf9f146bf5fa3f95423d2077454b2ed5 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 13:49:49 -0400 Subject: [PATCH 2/8] fix: return False from supports_function_calling until tool PR Tool calling is not implemented in this PR. Returning True would cause CrewAI to choose the native tools path, silently dropping tools from agents. Flagged by Cursor Bugbot review.
--- lib/crewai/src/crewai/llms/providers/oci/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 8c05e8caae9..b03ffa7990d 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -487,7 +487,7 @@ def _ordered_client_access(self) -> Any: # ------------------------------------------------------------------ def supports_function_calling(self) -> bool: - return True + return False # Tool calling support will be added in a follow-up PR def supports_stop_words(self) -> bool: return True From bcbc529c2a2477b5030db302e19c2f37ce853825 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 13:56:05 -0400 Subject: [PATCH 3/8] refactor: remove supports_function_calling and supports_stop_words Both methods are unnecessary in this PR. The base class and callers already default correctly when the methods are absent: - supports_function_calling: callers use getattr with False default - supports_stop_words: base class already returns True These will be added back in the tool calling follow-up PR. 
--- lib/crewai/src/crewai/llms/providers/oci/completion.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index b03ffa7990d..71abf20d516 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -486,12 +486,6 @@ def _ordered_client_access(self) -> Any: # Capability declarations # ------------------------------------------------------------------ - def supports_function_calling(self) -> bool: - return False # Tool calling support will be added in a follow-up PR - - def supports_stop_words(self) -> bool: - return True - def get_context_window_size(self) -> int: model_lower = self.model.lower() if model_lower.startswith("google.gemini"): From a25baa25bf5ba32046c07afda06e7061d3fb3cf0 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 14:00:30 -0400 Subject: [PATCH 4/8] cleanup: remove unused imports and dead code Remove json, re imports and _OCI_SCHEMA_NAME_PATTERN regex that are only needed for structured output (not in this PR scope). 
--- lib/crewai/src/crewai/llms/providers/oci/completion.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 71abf20d516..787f38d5995 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -3,10 +3,8 @@ import asyncio from collections.abc import Mapping from contextlib import contextmanager -import json import logging import os -import re import threading from typing import TYPE_CHECKING, Any, Literal, cast @@ -26,7 +24,6 @@ CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" DEFAULT_OCI_REGION = "us-chicago-1" -_OCI_SCHEMA_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") def _get_oci_module() -> Any: From 544f1ac6ceee21406e22328e6cc30fe9508818b5 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 14:02:09 -0400 Subject: [PATCH 5/8] fix: use model_lower consistently in OCI pattern check Use model_lower instead of model in the dot check to match the convention used by all other providers in _matches_provider_pattern. Flagged by Cursor Bugbot. --- lib/crewai/src/crewai/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 61b033994c7..13e8b28b067 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -460,7 +460,7 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: return True if provider == "oci": - return model_lower.startswith("ocid1.generativeaiendpoint") or "." in model + return model_lower.startswith("ocid1.generativeaiendpoint") or "." 
in model_lower return False From ab14159a0f1cd17792ab8b342e95d8b3fb1929bf Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 20 Mar 2026 06:32:40 -0400 Subject: [PATCH 6/8] cleanup: remove redundant OCI check in _validate_model_in_constants The explicit OCI branch returned the same _matches_provider_pattern call as the generic fallback. Removing it since it adds no distinct logic. Flagged by Cursor Bugbot. --- lib/crewai/src/crewai/llm.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 13e8b28b067..6ced0e6d469 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -497,10 +497,7 @@ def _validate_model_in_constants(cls, model: str, provider: str) -> bool: # azure does not provide a list of available models, determine a better way to handle this return True - if provider == "oci": - return cls._matches_provider_pattern(model, provider) - - # Fallback to pattern matching for models not in constants + # Fallback to pattern matching for models not in constants (includes OCI) return cls._matches_provider_pattern(model, provider) @classmethod From 3ff909a27513913cfc5b3f6810032e5f316bb2c9 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 20 Mar 2026 08:25:24 -0400 Subject: [PATCH 7/8] fix: address PR review feedback on OCI completion - Add missing empty-string fallback for bare strings in list content - Use case-insensitive model checks consistently across all methods - Replace over-engineered FIFO ticket queue with simple threading.Lock --- .../crewai/llms/providers/oci/completion.py | 29 ++++--------------- 1 file changed, 5 insertions(+), 24 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 787f38d5995..bd841ee1775 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ 
-2,7 +2,6 @@ import asyncio from collections.abc import Mapping -from contextlib import contextmanager import logging import os import threading @@ -111,9 +110,7 @@ def __init__( self.client = self._oci.generative_ai_inference.GenerativeAiInferenceClient( **client_kwargs ) - self._client_condition = threading.Condition() - self._next_client_ticket = 0 - self._active_client_ticket = 0 + self._client_lock = threading.Lock() self.last_response_metadata = None # ------------------------------------------------------------------ @@ -128,7 +125,7 @@ def _infer_provider(self, model: str) -> str: return "generic" def _is_openai_gpt5_family(self) -> bool: - return self.model.startswith("openai.gpt-5") + return self.model.lower().startswith("openai.gpt-5") def _build_serving_mode(self) -> Any: models = self._oci.generative_ai_inference.models @@ -175,7 +172,7 @@ def _build_generic_content(self, content: Any) -> list[Any]: processed: list[Any] = [] for item in content: if isinstance(item, str): - processed.append(models.TextContent(text=item)) + processed.append(models.TextContent(text=item or ".")) elif isinstance(item, Mapping) and item.get("type") == "text": processed.append( models.TextContent(text=str(item.get("text", "")) or ".") @@ -266,7 +263,7 @@ def _build_chat_request( if self.temperature is not None and not self._is_openai_gpt5_family(): request_kwargs["temperature"] = self.temperature if self.max_tokens is not None: - if self.oci_provider == "generic" and self.model.startswith("openai."): + if self.oci_provider == "generic" and self.model.lower().startswith("openai."): request_kwargs["max_completion_tokens"] = self.max_tokens else: request_kwargs["max_tokens"] = self.max_tokens @@ -460,25 +457,9 @@ async def acall( # ------------------------------------------------------------------ def _chat(self, chat_details: Any) -> Any: - with self._ordered_client_access(): + with self._client_lock: return self.client.chat(chat_details) - @contextmanager - def 
_ordered_client_access(self) -> Any: - """Serialize shared OCI client access in call-arrival order.""" - with self._client_condition: - ticket = self._next_client_ticket - self._next_client_ticket += 1 - while ticket != self._active_client_ticket: - self._client_condition.wait() - - try: - yield - finally: - with self._client_condition: - self._active_client_ticket += 1 - self._client_condition.notify_all() - # ------------------------------------------------------------------ # Capability declarations # ------------------------------------------------------------------ From 7ef6c024eb5bb91dcb144a920f5b2d32f4a4043a Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 20 Mar 2026 10:15:40 -0400 Subject: [PATCH 8/8] fix: move normalization inside error context, tighten _call_impl signature - Move _normalize_messages inside llm_call_context and try/except so validation errors emit call_failed events consistently - Narrow _call_impl to accept list[LLMMessage] only, removing unreachable str normalization path --- .../src/crewai/llms/providers/oci/completion.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index bd841ee1775..c44a7940823 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -362,14 +362,11 @@ def _finalize_text_response( def _call_impl( self, *, - messages: str | list[LLMMessage], + messages: list[LLMMessage], from_task: Task | None, from_agent: Agent | None, ) -> str: - normalized_messages = ( - messages if isinstance(messages, list) else self._normalize_messages(messages) - ) - chat_request = self._build_chat_request(normalized_messages) + chat_request = self._build_chat_request(messages) chat_details = self._oci.generative_ai_inference.models.ChatDetails( compartment_id=self.compartment_id, 
serving_mode=self._build_serving_mode(), @@ -384,7 +381,7 @@ def _call_impl( content = self._extract_text(response) return self._finalize_text_response( content=content, - messages=normalized_messages, + messages=messages, from_task=from_task, from_agent=from_agent, ) @@ -399,10 +396,9 @@ def call( from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | BaseModel | list[dict[str, Any]]: - normalized_messages = self._normalize_messages(messages) - with llm_call_context(): try: + normalized_messages = self._normalize_messages(messages) self._emit_call_started_event( messages=normalized_messages, tools=tools,