From bcc274316281e7cd891f58940416b7b9535d9dcd Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 13:41:13 -0400 Subject: [PATCH 01/12] =?UTF-8?q?feat:=20add=20OCI=20Generative=20AI=20pro?= =?UTF-8?q?vider=20=E2=80=94=20basic=20text=20completion?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add native OCI Generative AI support to CrewAI with basic text completion for generic (Meta, Google, OpenAI, xAI) and Cohere model families. This is the first in a series of PRs to incrementally build out full OCI support (streaming, tool calling, structured output, embeddings, and multimodal in follow-up PRs). Tracking issue: #4944 Supersedes: #4885 --- lib/crewai/pyproject.toml | 3 + lib/crewai/src/crewai/llm.py | 13 + .../src/crewai/llms/providers/oci/__init__.py | 5 + .../crewai/llms/providers/oci/completion.py | 505 ++++++++++++++++++ lib/crewai/src/crewai/utilities/oci.py | 72 +++ lib/crewai/tests/llms/oci/__init__.py | 0 lib/crewai/tests/llms/oci/conftest.py | 189 +++++++ lib/crewai/tests/llms/oci/test_oci.py | 269 ++++++++++ .../llms/oci/test_oci_integration_basic.py | 33 ++ 9 files changed, 1089 insertions(+) create mode 100644 lib/crewai/src/crewai/llms/providers/oci/__init__.py create mode 100644 lib/crewai/src/crewai/llms/providers/oci/completion.py create mode 100644 lib/crewai/src/crewai/utilities/oci.py create mode 100644 lib/crewai/tests/llms/oci/__init__.py create mode 100644 lib/crewai/tests/llms/oci/conftest.py create mode 100644 lib/crewai/tests/llms/oci/test_oci.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_basic.py diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index a40484f048..e3fce3fb46 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -97,6 +97,9 @@ azure-ai-inference = [ anthropic = [ "anthropic~=0.73.0", ] +oci = [ + "oci>=2.168.0", +] a2a = [ "a2a-sdk~=0.3.10", "httpx-auth~=0.23.1", diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 75b1f65468..fcd0f7fa92 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -317,6 +317,7 @@ def writable(self) -> bool: "hosted_vllm", "cerebras", "dashscope", + "oci", ] @@ -384,6 +385,7 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: "hosted_vllm": "hosted_vllm", "cerebras": "cerebras", "dashscope": "dashscope", + "oci": "oci", } canonical_provider = provider_mapping.get(prefix.lower()) @@ -506,6 +508,9 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: # OpenRouter uses org/model format but accepts anything return True + if provider == "oci": + return model_lower.startswith("ocid1.generativeaiendpoint") or "." in model + return False @classmethod @@ -541,6 +546,9 @@ def _validate_model_in_constants(cls, model: str, provider: str) -> bool: # azure does not provide a list of available models, determine a better way to handle this return True + if provider == "oci": + return cls._matches_provider_pattern(model, provider) + # Fallback to pattern matching for models not in constants return cls._matches_provider_pattern(model, provider) @@ -622,6 +630,11 @@ def _get_native_provider(cls, provider: str) -> type | None: return OpenAICompatibleCompletion + if provider == "oci": + from crewai.llms.providers.oci.completion import OCICompletion + + return OCICompletion + return None def __init__( diff --git a/lib/crewai/src/crewai/llms/providers/oci/__init__.py b/lib/crewai/src/crewai/llms/providers/oci/__init__.py new file mode 100644 index 0000000000..0c397558bd --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/__init__.py @@ -0,0 +1,5 @@ +from crewai.llms.providers.oci.completion import OCICompletion + +__all__ = [ + "OCICompletion", +] diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py new file mode 100644 index 0000000000..8c05e8caae --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -0,0 +1,505 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from contextlib import contextmanager +import json +import logging +import os +import re +import threading +from typing import TYPE_CHECKING, Any, Literal, cast + +from pydantic import BaseModel + +from crewai.events.types.llm_events import LLMCallType +from crewai.llms.base_llm import BaseLLM, llm_call_context +from crewai.utilities.oci import create_oci_client_kwargs, get_oci_module +from crewai.utilities.types import LLMMessage + + +if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task + from crewai.tools.base_tool import BaseTool + + +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" +DEFAULT_OCI_REGION = "us-chicago-1" +_OCI_SCHEMA_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") + + +def _get_oci_module() -> Any: + """Backward-compatible module-local alias used by tests and patches.""" + return get_oci_module() + + +class OCICompletion(BaseLLM): + """OCI Generative AI native provider for CrewAI. + + Supports basic text completions for generic (Meta, Google, OpenAI, xAI) + and Cohere model families hosted on the OCI Generative AI service. + """ + + def __init__( + self, + model: str, + *, + compartment_id: str | None = None, + service_endpoint: str | None = None, + auth_type: Literal[ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + | str = "API_KEY", + auth_profile: str | None = None, + auth_file_location: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + top_p: float | None = None, + top_k: int | None = None, + oci_provider: str | None = None, + client: Any | None = None, + **kwargs: Any, + ) -> None: + kwargs.pop("provider", None) + super().__init__( + model=model, + temperature=temperature, + provider="oci", + **kwargs, + ) + + self.compartment_id = compartment_id or os.getenv("OCI_COMPARTMENT_ID") + if not self.compartment_id: + raise ValueError( + "OCI compartment_id is required. Set compartment_id or OCI_COMPARTMENT_ID." + ) + + self.service_endpoint = service_endpoint or os.getenv("OCI_SERVICE_ENDPOINT") + if self.service_endpoint is None: + region = os.getenv("OCI_REGION", DEFAULT_OCI_REGION) + self.service_endpoint = ( + f"https://inference.generativeai.{region}.oci.oraclecloud.com" + ) + + self.auth_type = str(auth_type).upper() + self.auth_profile = cast( + str, auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + ) + self.auth_file_location = cast( + str, + auth_file_location + or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + ) + self.max_tokens = max_tokens + self.top_p = top_p + self.top_k = top_k + self.oci_provider = oci_provider or self._infer_provider(model) + self._oci = _get_oci_module() + + if client is not None: + self.client = client + else: + client_kwargs = create_oci_client_kwargs( + auth_type=self.auth_type, + service_endpoint=self.service_endpoint, + auth_file_location=self.auth_file_location, + auth_profile=self.auth_profile, + timeout=(10, 240), + oci_module=self._oci, + ) + self.client = self._oci.generative_ai_inference.GenerativeAiInferenceClient( + **client_kwargs + ) + self._client_condition = threading.Condition() + self._next_client_ticket = 0 + self._active_client_ticket = 0 + self.last_response_metadata = None + + # ------------------------------------------------------------------ + # Provider inference + # ------------------------------------------------------------------ + + def _infer_provider(self, model: str) -> str: + if model.startswith(CUSTOM_ENDPOINT_PREFIX): + return "generic" + if model.startswith("cohere."): + return "cohere" + return "generic" + + def _is_openai_gpt5_family(self) -> bool: + return self.model.startswith("openai.gpt-5") + + def _build_serving_mode(self) -> Any: + models = self._oci.generative_ai_inference.models + if self.model.startswith(CUSTOM_ENDPOINT_PREFIX): + return models.DedicatedServingMode(endpoint_id=self.model) + return models.OnDemandServingMode(model_id=self.model) + + # ------------------------------------------------------------------ + # Message helpers + # ------------------------------------------------------------------ + + def _normalize_messages( + self, messages: str | list[LLMMessage] + ) -> list[LLMMessage]: + return self._format_messages(messages) + + def _coerce_text(self, content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, Mapping): + if item.get("type") == "text": + parts.append(str(item.get("text", ""))) + elif "text" in item: + parts.append(str(item["text"])) + return "\n".join(part for part in parts if part) + return str(content) + + def _build_generic_content(self, content: Any) -> list[Any]: + """Translate CrewAI message content into OCI generic content objects.""" + models = self._oci.generative_ai_inference.models + if isinstance(content, str): + return [models.TextContent(text=content or ".")] + + if not isinstance(content, list): + return [models.TextContent(text=self._coerce_text(content) or ".")] + + processed: list[Any] = [] + for item in content: + if isinstance(item, str): + processed.append(models.TextContent(text=item)) + elif isinstance(item, Mapping) and item.get("type") == "text": + processed.append( + models.TextContent(text=str(item.get("text", "")) or ".") + ) + else: + processed.append( + models.TextContent(text=self._coerce_text(item) or ".") + ) + return processed or [models.TextContent(text=".")] + + def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: + """Map CrewAI conversation messages into OCI generic chat messages.""" + models = self._oci.generative_ai_inference.models + role_map = { + "user": models.UserMessage, + "assistant": models.AssistantMessage, + "system": models.SystemMessage, + } + oci_messages: list[Any] = [] + + for message in messages: + role = str(message.get("role", "user")).lower() + message_cls = role_map.get(role) + if message_cls is None: + logging.debug("Skipping unsupported OCI message role: %s", role) + continue + oci_messages.append( + message_cls( + content=self._build_generic_content(message.get("content", "")), + ) + ) + + return oci_messages + + def _build_cohere_chat_history( + self, messages: list[LLMMessage] + ) -> tuple[list[Any], str]: + """Translate CrewAI messages into Cohere's split history + message shape.""" + models = self._oci.generative_ai_inference.models + chat_history: list[Any] = [] + + for message in messages[:-1]: + role = str(message.get("role", "user")).lower() + content = message.get("content", "") + + if role in ("user", "system"): + message_cls = ( + models.CohereUserMessage + if role == "user" + else models.CohereSystemMessage + ) + chat_history.append(message_cls(message=self._coerce_text(content))) + elif role == "assistant": + chat_history.append( + models.CohereChatBotMessage( + message=self._coerce_text(content) or " ", + ) + ) + + last_message = messages[-1] if messages else {"role": "user", "content": ""} + message_text = self._coerce_text(last_message.get("content", "")) + return chat_history, message_text + + # ------------------------------------------------------------------ + # Request building + # ------------------------------------------------------------------ + + def _build_chat_request( + self, + messages: list[LLMMessage], + ) -> Any: + """Build the provider-specific OCI chat request for the current model.""" + models = self._oci.generative_ai_inference.models + + if self.oci_provider == "cohere": + chat_history, message_text = self._build_cohere_chat_history(messages) + request_kwargs: dict[str, Any] = { + "message": message_text, + "chat_history": chat_history, + "api_format": models.BaseChatRequest.API_FORMAT_COHERE, + } + else: + request_kwargs = { + "messages": self._build_generic_messages(messages), + "api_format": models.BaseChatRequest.API_FORMAT_GENERIC, + } + + if self.temperature is not None and not self._is_openai_gpt5_family(): + request_kwargs["temperature"] = self.temperature + if self.max_tokens is not None: + if self.oci_provider == "generic" and self.model.startswith("openai."): + request_kwargs["max_completion_tokens"] = self.max_tokens + else: + request_kwargs["max_tokens"] = self.max_tokens + if self.top_p is not None: + request_kwargs["top_p"] = self.top_p + if self.top_k is not None: + request_kwargs["top_k"] = self.top_k + + if self.stop and not self._is_openai_gpt5_family(): + stop_key = "stop_sequences" if self.oci_provider == "cohere" else "stop" + request_kwargs[stop_key] = list(self.stop) + + if self.oci_provider == "cohere": + return models.CohereChatRequest(**request_kwargs) + return models.GenericChatRequest(**request_kwargs) + + # ------------------------------------------------------------------ + # Response extraction + # ------------------------------------------------------------------ + + def _extract_text(self, response: Any) -> str: + chat_response = response.data.chat_response + if self.oci_provider == "cohere": + if getattr(chat_response, "text", None): + return chat_response.text or "" + message = getattr(chat_response, "message", None) + if message is not None: + content = getattr(message, "content", None) or [] + return "".join( + part.text for part in content if getattr(part, "text", None) + ) + return "" + + choices = getattr(chat_response, "choices", None) or [] + if not choices: + return "" + message = getattr(choices[0], "message", None) + if message is None: + return "" + content = getattr(message, "content", None) or [] + return "".join(part.text for part in content if getattr(part, "text", None)) + + def _extract_usage(self, response: Any) -> dict[str, int]: + chat_response = response.data.chat_response + usage = getattr(chat_response, "usage", None) + if usage is None: + return {} + return { + "prompt_tokens": getattr(usage, "prompt_tokens", 0), + "completion_tokens": getattr(usage, "completion_tokens", 0), + "total_tokens": getattr(usage, "total_tokens", 0), + } + + def _extract_response_metadata(self, response: Any) -> dict[str, Any]: + chat_response = response.data.chat_response + metadata: dict[str, Any] = {} + + finish_reason = getattr(chat_response, "finish_reason", None) + if finish_reason is None: + choices = getattr(chat_response, "choices", None) or [] + if choices: + finish_reason = getattr(choices[0], "finish_reason", None) + + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + + usage = self._extract_usage(response) + if usage: + metadata["usage"] = usage + + return metadata + + # ------------------------------------------------------------------ + # Call paths + # ------------------------------------------------------------------ + + def _finalize_text_response( + self, + *, + content: str, + messages: list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + content = self._apply_stop_words(content) + content = self._invoke_after_llm_call_hooks(messages, content, from_agent) + self._emit_call_completed_event( + response=content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=messages, + ) + return content + + def _call_impl( + self, + *, + messages: str | list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + normalized_messages = ( + messages if isinstance(messages, list) else self._normalize_messages(messages) + ) + chat_request = self._build_chat_request(normalized_messages) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + response = self._chat(chat_details) + usage = self._extract_usage(response) + if usage: + self._track_token_usage_internal(usage) + self.last_response_metadata = self._extract_response_metadata(response) or None + + content = self._extract_text(response) + return self._finalize_text_response( + content=content, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + def call( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | BaseModel | list[dict[str, Any]]: + normalized_messages = self._normalize_messages(messages) + + with llm_call_context(): + try: + self._emit_call_started_event( + messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + if not self._invoke_before_llm_call_hooks( + normalized_messages, from_agent + ): + raise ValueError("LLM call blocked by before_llm_call hook") + + return self._call_impl( + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + except Exception as error: + error_message = f"OCI Generative AI call failed: {error!s}" + self._emit_call_failed_event( + error=error_message, + from_task=from_task, + from_agent=from_agent, + ) + raise + + async def acall( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + return await asyncio.to_thread( + self.call, + messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) + + # ------------------------------------------------------------------ + # Client serialization + # ------------------------------------------------------------------ + + def _chat(self, chat_details: Any) -> Any: + with self._ordered_client_access(): + return self.client.chat(chat_details) + + @contextmanager + def _ordered_client_access(self) -> Any: + """Serialize shared OCI client access in call-arrival order.""" + with self._client_condition: + ticket = self._next_client_ticket + self._next_client_ticket += 1 + while ticket != self._active_client_ticket: + self._client_condition.wait() + + try: + yield + finally: + with self._client_condition: + self._active_client_ticket += 1 + self._client_condition.notify_all() + + # ------------------------------------------------------------------ + # Capability declarations + # ------------------------------------------------------------------ + + def supports_function_calling(self) -> bool: + return True + + def supports_stop_words(self) -> bool: + return True + + def get_context_window_size(self) -> int: + model_lower = self.model.lower() + if model_lower.startswith("google.gemini"): + return 1048576 + if model_lower.startswith("openai."): + return 200000 + if model_lower.startswith("cohere."): + return 128000 + if model_lower.startswith("meta."): + return 131072 + return 131072 diff --git a/lib/crewai/src/crewai/utilities/oci.py b/lib/crewai/src/crewai/utilities/oci.py new file mode 100644 index 0000000000..7935530690 --- /dev/null +++ b/lib/crewai/src/crewai/utilities/oci.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any + + +def get_oci_module() -> Any: + """Import the OCI SDK lazily for optional CrewAI OCI integrations.""" + try: + import oci # type: ignore[import-untyped] + except ImportError: + raise ImportError( + 'OCI support is not available, to install: uv add "crewai[oci]"' + ) from None + return oci + + +def create_oci_client_kwargs( + *, + auth_type: str, + service_endpoint: str | None, + auth_file_location: str, + auth_profile: str, + timeout: tuple[int, int], + oci_module: Any | None = None, +) -> dict[str, Any]: + """Build OCI SDK client kwargs for the supported auth modes.""" + oci = oci_module or get_oci_module() + client_kwargs: dict[str, Any] = { + "config": {}, + "service_endpoint": service_endpoint, + "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, + "timeout": timeout, + } + + auth_type_upper = auth_type.upper() + if auth_type_upper == "API_KEY": + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + elif auth_type_upper == "SECURITY_TOKEN": + config = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + key_file = config["key_file"] + security_token_file = config["security_token_file"] + private_key = oci.signer.load_private_key_from_file(key_file, None) + with open(security_token_file, encoding="utf-8") as file: + security_token = file.read() + client_kwargs["config"] = config + client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner( + security_token, private_key + ) + elif auth_type_upper == "INSTANCE_PRINCIPAL": + client_kwargs["signer"] = ( + oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + ) + elif auth_type_upper == "RESOURCE_PRINCIPAL": + client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() + else: + valid_types = [ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + raise ValueError( + f"Invalid OCI auth_type '{auth_type}'. Valid values: {valid_types}" + ) + + return client_kwargs diff --git a/lib/crewai/tests/llms/oci/__init__.py b/lib/crewai/tests/llms/oci/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/crewai/tests/llms/oci/conftest.py b/lib/crewai/tests/llms/oci/conftest.py new file mode 100644 index 0000000000..164f53060f --- /dev/null +++ b/lib/crewai/tests/llms/oci/conftest.py @@ -0,0 +1,189 @@ +"""Fixtures for OCI provider unit and integration tests.""" + +from __future__ import annotations + +import os +from typing import Any +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Fake OCI SDK module (replaces `import oci` in unit tests) +# --------------------------------------------------------------------------- + + +def _make_fake_oci_module() -> MagicMock: + """Build a lightweight mock of the OCI SDK surface used by OCICompletion.""" + oci = MagicMock() + + # Models namespace + models = oci.generative_ai_inference.models + + # Serving modes + models.OnDemandServingMode = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.DedicatedServingMode = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Content types + models.TextContent = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Message types + for cls_name in ( + "UserMessage", + "AssistantMessage", + "SystemMessage", + "CohereUserMessage", + "CohereSystemMessage", + "CohereChatBotMessage", + ): + setattr(models, cls_name, MagicMock(side_effect=lambda **kw: MagicMock(**kw))) + + # Request types + models.GenericChatRequest = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.CohereChatRequest = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + models.BaseChatRequest = MagicMock() + models.BaseChatRequest.API_FORMAT_GENERIC = "GENERIC" + models.BaseChatRequest.API_FORMAT_COHERE = "COHERE" + + # ChatDetails + models.ChatDetails = MagicMock(side_effect=lambda **kw: MagicMock(**kw)) + + # Auth helpers + oci.config.from_file = MagicMock(return_value={"key_file": "/tmp/k", "security_token_file": "/tmp/t"}) + oci.signer.load_private_key_from_file = MagicMock(return_value="pk") + oci.auth.signers.SecurityTokenSigner = MagicMock() + oci.auth.signers.InstancePrincipalsSecurityTokenSigner = MagicMock() + oci.auth.signers.get_resource_principals_signer = MagicMock() + oci.retry.DEFAULT_RETRY_STRATEGY = "default_retry" + + # Client constructor + oci.generative_ai_inference.GenerativeAiInferenceClient = MagicMock() + + return oci + + +def _make_fake_chat_response(text: str = "Hello from OCI") -> MagicMock: + """Build a minimal OCI chat response for generic models.""" + text_part = MagicMock() + text_part.text = text + + message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + + usage = MagicMock() + usage.prompt_tokens = 10 + usage.completion_tokens = 5 + usage.total_tokens = 15 + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_fake_cohere_chat_response(text: str = "Hello from Cohere") -> MagicMock: + """Build a minimal OCI chat response for Cohere models.""" + chat_response = MagicMock() + chat_response.text = text + chat_response.finish_reason = "COMPLETE" + chat_response.tool_calls = None + + usage = MagicMock() + usage.prompt_tokens = 8 + usage.completion_tokens = 4 + usage.total_tokens = 12 + chat_response.usage = usage + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +@pytest.fixture() +def oci_fake_module() -> MagicMock: + return _make_fake_oci_module() + + +@pytest.fixture() +def patch_oci_module(monkeypatch: pytest.MonkeyPatch, oci_fake_module: MagicMock) -> MagicMock: + """Patch the OCI module import so no real SDK is needed.""" + monkeypatch.setattr( + "crewai.llms.providers.oci.completion._get_oci_module", + lambda: oci_fake_module, + ) + return oci_fake_module + + +@pytest.fixture() +def oci_response_factories() -> dict[str, Any]: + return { + "chat": _make_fake_chat_response, + "cohere_chat": _make_fake_cohere_chat_response, + } + + +# --------------------------------------------------------------------------- +# Unit test defaults +# --------------------------------------------------------------------------- + +@pytest.fixture() +def oci_unit_values() -> dict[str, str]: + return { + "compartment_id": "ocid1.compartment.oc1..test", + "model": "meta.llama-3.3-70b-instruct", + "cohere_model": "cohere.command-r-plus-08-2024", + } + + +# --------------------------------------------------------------------------- +# Integration test fixtures (live OCI calls) +# --------------------------------------------------------------------------- + +def _env_models(env_var: str, fallback_var: str, default: str) -> list[str]: + """Read model list from env, supporting comma-separated values.""" + raw = os.getenv(env_var) or os.getenv(fallback_var) or default + return [m.strip() for m in raw.split(",") if m.strip()] + + +def _skip_unless_live_config() -> dict[str, str]: + """Return live config dict or skip the test.""" + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set — skipping live test") + region = os.getenv("OCI_REGION") + endpoint = os.getenv("OCI_SERVICE_ENDPOINT") + if not region and not endpoint: + pytest.skip("Set OCI_REGION or OCI_SERVICE_ENDPOINT for live tests") + config: dict[str, str] = {"compartment_id": compartment} + if endpoint: + config["service_endpoint"] = endpoint + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + if os.getenv("OCI_AUTH_FILE_LOCATION"): + config["auth_file_location"] = os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config") + return config + + +@pytest.fixture( + params=_env_models("OCI_TEST_MODELS", "OCI_TEST_MODEL", "meta.llama-3.3-70b-instruct"), + ids=lambda m: m, +) +def oci_chat_model(request: pytest.FixtureRequest) -> str: + return request.param + + +@pytest.fixture() +def oci_live_config() -> dict[str, str]: + return _skip_unless_live_config() diff --git a/lib/crewai/tests/llms/oci/test_oci.py b/lib/crewai/tests/llms/oci/test_oci.py new file mode 100644 index 0000000000..a5f7efb1b6 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci.py @@ -0,0 +1,269 @@ +"""Unit tests for the OCI Generative AI provider (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +# --------------------------------------------------------------------------- +# Provider routing +# --------------------------------------------------------------------------- + + +def test_oci_completion_is_used_when_oci_provider(patch_oci_module): + """LLM(model='oci/...') should resolve to OCICompletion.""" + from crewai.llm import LLM + + fake_client = MagicMock() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + llm = LLM( + model="oci/meta.llama-3.3-70b-instruct", + compartment_id="ocid1.compartment.oc1..test", + ) + from crewai.llms.providers.oci.completion import OCICompletion + + # LLM.__new__ returns the native provider instance directly + assert isinstance(llm, OCICompletion) + + +@pytest.mark.parametrize( + "model_id, expected_provider", + [ + ("meta.llama-3.3-70b-instruct", "generic"), + ("google.gemini-2.5-flash", "generic"), + ("openai.gpt-4o", "generic"), + ("xai.grok-3", "generic"), + ("cohere.command-r-plus-08-2024", "cohere"), + ], +) +def test_oci_completion_infers_provider_family( + patch_oci_module, oci_unit_values, model_id, expected_provider +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=model_id, + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm.oci_provider == expected_provider + + +# --------------------------------------------------------------------------- +# Initialization +# --------------------------------------------------------------------------- + + +def test_oci_completion_initialization_parameters(patch_oci_module, oci_unit_values): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + temperature=0.7, + max_tokens=512, + top_p=0.9, + top_k=40, + ) + assert llm.temperature == 0.7 + assert llm.max_tokens == 512 + assert llm.top_p == 0.9 + assert llm.top_k == 40 + assert llm.compartment_id == oci_unit_values["compartment_id"] + + +def test_oci_completion_uses_region_to_build_endpoint(patch_oci_module, oci_unit_values, monkeypatch): + from crewai.llms.providers.oci.completion import OCICompletion + + monkeypatch.delenv("OCI_SERVICE_ENDPOINT", raising=False) + monkeypatch.setenv("OCI_REGION", "us-ashburn-1") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert "us-ashburn-1" in llm.service_endpoint + + +# --------------------------------------------------------------------------- +# Basic call +# --------------------------------------------------------------------------- + + +def test_oci_completion_call_uses_chat_api( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]("test response") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call(messages=[{"role": "user", "content": "Say hello"}]) + + assert "test response" in result + fake_client.chat.assert_called_once() + + +def test_oci_completion_cohere_call( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["cohere_chat"]("cohere reply") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call(messages=[{"role": "user", "content": "Hi"}]) + + assert "cohere reply" in result + fake_client.chat.assert_called_once() + + +# --------------------------------------------------------------------------- +# Message normalization +# --------------------------------------------------------------------------- + + +def test_oci_completion_treats_none_content_as_empty_text( + patch_oci_module, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm._coerce_text(None) == "" + + +def test_oci_completion_call_normalizes_messages_once( + patch_oci_module, oci_response_factories, oci_unit_values +): + """Ensure normalize is not called twice when _call_impl receives a list.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + call_count = 0 + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + original_normalize = llm._normalize_messages + + def counting_normalize(msgs): + nonlocal call_count + call_count += 1 + return original_normalize(msgs) + + llm._normalize_messages = counting_normalize + + llm.call(messages=[{"role": "user", "content": "hi"}]) + # call() normalizes once, _call_impl should not normalize again + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# OpenAI model quirks +# --------------------------------------------------------------------------- + + +def test_oci_openai_models_use_max_completion_tokens( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model="openai.gpt-4o", + compartment_id=oci_unit_values["compartment_id"], + max_tokens=1024, + ) + request = llm._build_chat_request([{"role": "user", "content": "test"}]) + + models = patch_oci_module.generative_ai_inference.models + call_kwargs = models.GenericChatRequest.call_args + assert call_kwargs is not None + kwargs = call_kwargs[1] if call_kwargs[1] else {} + assert kwargs.get("max_completion_tokens") == 1024 + assert "max_tokens" not in kwargs + + +def test_oci_openai_gpt5_omits_unsupported_temperature_and_stop( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model="openai.gpt-5", + compartment_id=oci_unit_values["compartment_id"], + temperature=0.5, + ) + llm.stop = ["END"] + llm._build_chat_request([{"role": "user", "content": "test"}]) + + models = patch_oci_module.generative_ai_inference.models + call_kwargs = models.GenericChatRequest.call_args[1] + assert "temperature" not in call_kwargs + assert "stop" not in call_kwargs + + +# --------------------------------------------------------------------------- +# Async +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_oci_completion_acall_delegates_to_call( + patch_oci_module, oci_response_factories, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = oci_response_factories["chat"]("async result") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = await llm.acall(messages=[{"role": "user", "content": "async test"}]) + + assert "async result" in result diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_basic.py b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py new file mode 100644 index 0000000000..adeb0f2cce --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_basic.py @@ -0,0 +1,33 @@ +"""Live integration tests for OCI Generative AI basic text completion. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + OCI_TEST_MODELS="meta.llama-3.3-70b-instruct,cohere.command-r-plus-08-2024,google.gemini-2.5-flash" \ + uv run pytest tests/llms/oci/test_oci_integration_basic.py -v +""" + +from __future__ import annotations + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def test_oci_live_basic_call(oci_chat_model: str, oci_live_config: dict): + """Synchronous text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = llm.call(messages=[{"role": "user", "content": "Say 'hello world' in one sentence."}]) + + assert isinstance(result, str) + assert len(result) > 0 + + +@pytest.mark.asyncio +async def test_oci_live_async_call(oci_chat_model: str, oci_live_config: dict): + """Async text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = await llm.acall(messages=[{"role": "user", "content": "What is 2+2? Answer in one word."}]) + + assert isinstance(result, str) + assert len(result) > 0 From f4d1b5cc329a5457dafe4c2e27821411b4776f52 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 13:49:49 -0400 Subject: [PATCH 02/12] fix: return False from supports_function_calling until tool PR Tool calling is not implemented in this PR. Returning True would cause CrewAI to choose the native tools path, silently dropping tools from agents. Flagged by Cursor Bugbot review. --- lib/crewai/src/crewai/llms/providers/oci/completion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 8c05e8caae..b03ffa7990 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -487,7 +487,7 @@ def _ordered_client_access(self) -> Any: # ------------------------------------------------------------------ def supports_function_calling(self) -> bool: - return True + return False # Tool calling support will be added in a follow-up PR def supports_stop_words(self) -> bool: return True From 0ba50115cdbf4ac06630442b7db26c6a53d06a3d Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 13:56:05 -0400 Subject: [PATCH 03/12] refactor: remove supports_function_calling and supports_stop_words Both methods are unnecessary in this PR. The base class and callers already default correctly when the methods are absent: - supports_function_calling: callers use getattr with False default - supports_stop_words: base class already returns True These will be added back in the tool calling follow-up PR. --- lib/crewai/src/crewai/llms/providers/oci/completion.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index b03ffa7990..71abf20d51 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -486,12 +486,6 @@ def _ordered_client_access(self) -> Any: # Capability declarations # ------------------------------------------------------------------ - def supports_function_calling(self) -> bool: - return False # Tool calling support will be added in a follow-up PR - - def supports_stop_words(self) -> bool: - return True - def get_context_window_size(self) -> int: model_lower = self.model.lower() if model_lower.startswith("google.gemini"): From af22c1099cfef2dc29c74f0ab494418b04cb38df Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 14:00:30 -0400 Subject: [PATCH 04/12] cleanup: remove unused imports and dead code Remove json, re imports and _OCI_SCHEMA_NAME_PATTERN regex that are only needed for structured output (not in this PR scope). --- lib/crewai/src/crewai/llms/providers/oci/completion.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 71abf20d51..787f38d599 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -3,10 +3,8 @@ import asyncio from collections.abc import Mapping from contextlib import contextmanager -import json import logging import os -import re import threading from typing import TYPE_CHECKING, Any, Literal, cast @@ -26,7 +24,6 @@ CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" DEFAULT_OCI_REGION = "us-chicago-1" -_OCI_SCHEMA_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") def _get_oci_module() -> Any: From d9bf9a40d95f8e785a01a41862950d5bbe66e8a3 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 14:02:09 -0400 Subject: [PATCH 05/12] fix: use model_lower consistently in OCI pattern check Use model_lower instead of model in the dot check to match the convention used by all other providers in _matches_provider_pattern. Flagged by Cursor Bugbot. --- lib/crewai/src/crewai/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index fcd0f7fa92..986010a0e8 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -509,7 +509,7 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: return True if provider == "oci": - return model_lower.startswith("ocid1.generativeaiendpoint") or "." in model + return model_lower.startswith("ocid1.generativeaiendpoint") or "." in model_lower return False From 58d69104428757620aa9b610b75def18f01ed5b8 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 14:40:23 -0400 Subject: [PATCH 06/12] feat: add streaming support to OCI Generative AI provider Add streaming text completion via OCI SSE events: - stream=True in call() routes to _stream_call_impl with chunk events - iter_stream() yields raw text chunks (sync generator) - astream() wraps iter_stream via thread+queue for async callers - _stream_chat_events holds client lock for full stream duration - SSE event parsing handles both string and mapping payloads Tested live against meta.llama-3.3-70b-instruct, cohere.command-r-plus-08-2024, google.gemini-2.5-flash, and openai.gpt-5.2-chat-latest. Depends on: #4959 Tracking issue: #4944 --- .../crewai/llms/providers/oci/completion.py | 230 ++++++++++++++++++ .../oci/test_oci_integration_streaming.py | 40 +++ .../tests/llms/oci/test_oci_streaming.py | 150 ++++++++++++ 3 files changed, 420 insertions(+) create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_streaming.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_streaming.py diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 787f38d599..1147a1c2ad 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -3,10 +3,12 @@ import asyncio from collections.abc import Mapping from contextlib import contextmanager +import json import logging import os import threading from typing import TYPE_CHECKING, Any, Literal, cast +import uuid from pydantic import BaseModel @@ -57,6 +59,7 @@ def __init__( max_tokens: int | None = None, top_p: float | None = None, top_k: int | None = None, + stream: bool = False, oci_provider: str | None = None, client: Any | None = None, **kwargs: Any, @@ -94,6 +97,7 @@ def __init__( self.max_tokens = max_tokens self.top_p = top_p self.top_k = top_k + self.stream = stream self.oci_provider = oci_provider or self._infer_provider(model) self._oci = _get_oci_module() @@ -246,6 +250,8 @@ def _build_cohere_chat_history( def _build_chat_request( self, messages: list[LLMMessage], + *, + is_stream: bool = False, ) -> Any: """Build the provider-specific OCI chat request for the current model.""" models = self._oci.generative_ai_inference.models @@ -279,6 +285,12 @@ def _build_chat_request( stop_key = "stop_sequences" if self.oci_provider == "cohere" else "stop" request_kwargs[stop_key] = list(self.stop) + if is_stream: + request_kwargs["is_stream"] = True + request_kwargs["stream_options"] = models.StreamOptions( + is_include_usage=True + ) + if self.oci_provider == "cohere": return models.CohereChatRequest(**request_kwargs) return models.GenericChatRequest(**request_kwargs) @@ -339,6 +351,75 @@ def _extract_response_metadata(self, response: Any) -> dict[str, Any]: return metadata + # ------------------------------------------------------------------ + # Streaming extraction + # ------------------------------------------------------------------ + + def _parse_stream_event(self, event: Any) -> dict[str, Any]: + """Convert OCI SSE event payloads into plain dicts.""" + event_data = getattr(event, "data", None) + if not event_data: + return {} + if isinstance(event_data, str): + try: + parsed = json.loads(event_data) + if isinstance(parsed, Mapping): + return dict(parsed) + return {} + except json.JSONDecodeError: + logging.debug("Skipping invalid OCI SSE payload: %s", event_data) + return {} + if isinstance(event_data, Mapping): + return dict(event_data) + return {} + + def _extract_text_from_stream_event(self, event_data: dict[str, Any]) -> str: + if self.oci_provider == "cohere": + if "text" in event_data: + return str(event_data.get("text", "")) + message = event_data.get("message", {}) + if isinstance(message, Mapping): + content = message.get("content", []) + if isinstance(content, list): + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, Mapping) + ) + return "" + + message = event_data.get("message", {}) + if not isinstance(message, Mapping): + return "" + content = message.get("content", []) + if not isinstance(content, list): + return "" + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, Mapping) and part.get("text") + ) + + def _extract_usage_from_stream_event(self, event_data: dict[str, Any]) -> dict[str, int]: + usage = event_data.get("usage") + if not isinstance(usage, Mapping): + return {} + return { + "prompt_tokens": int(usage.get("promptTokens", 0) or 0), + "completion_tokens": int(usage.get("completionTokens", 0) or 0), + "total_tokens": int(usage.get("totalTokens", 0) or 0), + } + + def _extract_metadata_from_stream_event(self, event_data: dict[str, Any]) -> dict[str, Any]: + metadata: dict[str, Any] = {} + finish_reason = event_data.get("finishReason") + if finish_reason is not None: + metadata["finish_reason"] = finish_reason + usage = self._extract_usage_from_stream_event(event_data) + if usage: + metadata["usage"] = usage + return metadata + # ------------------------------------------------------------------ # Call paths # ------------------------------------------------------------------ @@ -392,6 +473,142 @@ def _call_impl( from_agent=from_agent, ) + def _stream_call_impl( + self, + *, + messages: str | list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> str: + """Handle OCI streaming while reconstructing final text state.""" + normalized_messages = ( + messages if isinstance(messages, list) else self._normalize_messages(messages) + ) + chat_request = self._build_chat_request(normalized_messages, is_stream=True) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + full_response = "" + usage_data: dict[str, int] = {} + response_metadata: dict[str, Any] = {} + response_id = uuid.uuid4().hex + + for event in self._stream_chat_events(chat_details): + event_data = self._parse_stream_event(event) + if not event_data: + continue + + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: + full_response += text_chunk + self._emit_stream_chunk_event( + chunk=text_chunk, + from_task=from_task, + from_agent=from_agent, + call_type=LLMCallType.LLM_CALL, + response_id=response_id, + ) + + usage_chunk = self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) + + if usage_data: + self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None + + return self._finalize_text_response( + content=full_response, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + def iter_stream( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + ) -> Any: + """Yield raw text chunks from OCI without triggering tool recursion.""" + normalized_messages = self._normalize_messages(messages) + chat_request = self._build_chat_request(normalized_messages, is_stream=True) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + response = self._chat(chat_details) + usage_data: dict[str, int] = {} + response_metadata: dict[str, Any] = {} + + for event in response.data.events(): + event_data = self._parse_stream_event(event) + if not event_data: + continue + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: + yield text_chunk + usage_chunk = self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) + + if usage_data: + self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None + + async def astream( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + ) -> Any: + """Expose the sync OCI SSE stream through an async generator facade.""" + loop = asyncio.get_running_loop() + queue: asyncio.Queue[str | None] = asyncio.Queue() + error_holder: list[BaseException] = [] + + def _producer() -> None: + try: + for chunk in self.iter_stream( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ): + loop.call_soon_threadsafe(queue.put_nowait, chunk) + except BaseException as error: + error_holder.append(error) + finally: + loop.call_soon_threadsafe(queue.put_nowait, None) + + thread = threading.Thread(target=_producer, daemon=True) + thread.start() + + while True: + chunk = await queue.get() + if chunk is None: + break + yield chunk + + thread.join() + if error_holder: + raise error_holder[0] + def call( self, messages: str | list[LLMMessage], @@ -420,6 +637,13 @@ def call( ): raise ValueError("LLM call blocked by before_llm_call hook") + if self.stream: + return self._stream_call_impl( + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + return self._call_impl( messages=normalized_messages, from_task=from_task, @@ -463,6 +687,12 @@ def _chat(self, chat_details: Any) -> Any: with self._ordered_client_access(): return self.client.chat(chat_details) + def _stream_chat_events(self, chat_details: Any) -> Any: + """Yield streaming events while holding the shared OCI client lock.""" + with self._ordered_client_access(): + response = self.client.chat(chat_details) + yield from response.data.events() + @contextmanager def _ordered_client_access(self) -> Any: """Serialize shared OCI client access in call-arrival order.""" diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py b/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py new file mode 100644 index 0000000000..db41b1f3a9 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_streaming.py @@ -0,0 +1,40 @@ +"""Live integration tests for OCI Generative AI streaming. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + OCI_TEST_MODELS="meta.llama-3.3-70b-instruct,cohere.command-r-plus-08-2024" \ + uv run pytest tests/llms/oci/test_oci_integration_streaming.py -v +""" + +from __future__ import annotations + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def test_oci_live_streaming_call(oci_chat_model: str, oci_live_config: dict): + """Streaming text completion with a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, stream=True, **oci_live_config) + result = llm.call( + messages=[{"role": "user", "content": "Count from 1 to 5, one per line."}] + ) + + assert isinstance(result, str) + assert len(result) > 0 + + +@pytest.mark.asyncio +async def test_oci_live_astream(oci_chat_model: str, oci_live_config: dict): + """Async streaming should yield text chunks from a live OCI model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + chunks: list[str] = [] + async for chunk in llm.astream( + messages=[{"role": "user", "content": "Say hello in three words."}] + ): + chunks.append(chunk) + + assert len(chunks) > 0 + full_text = "".join(chunks) + assert len(full_text) > 0 diff --git a/lib/crewai/tests/llms/oci/test_oci_streaming.py b/lib/crewai/tests/llms/oci/test_oci_streaming.py new file mode 100644 index 0000000000..721fa3f671 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_streaming.py @@ -0,0 +1,150 @@ +"""Unit tests for OCI provider streaming (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +def _make_fake_stream_event(text: str = "", finish_reason: str | None = None, usage: dict | None = None) -> MagicMock: + """Build a single SSE event with optional text, finish, and usage.""" + payload: dict = {} + if text: + payload["message"] = {"content": [{"text": text}]} + if finish_reason: + payload["finishReason"] = finish_reason + if usage: + payload["usage"] = usage + + import json + event = MagicMock() + event.data = json.dumps(payload) + return event + + +def _make_fake_stream_response(*events: MagicMock) -> MagicMock: + """Wrap events into a response.data.events() iterable.""" + response = MagicMock() + response.data.events.return_value = iter(events) + return response + + +def test_oci_completion_streams_generic_responses( + patch_oci_module, oci_unit_values +): + """Streaming call should accumulate text chunks and return full response.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="Hello "), + _make_fake_stream_event(text="world"), + _make_fake_stream_event( + finish_reason="stop", + usage={"promptTokens": 5, "completionTokens": 2, "totalTokens": 7}, + ), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + # StreamOptions mock + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + stream=True, + ) + result = llm.call(messages=[{"role": "user", "content": "Say hello"}]) + + assert "Hello " in result + assert "world" in result + assert llm.last_response_metadata is not None + assert llm.last_response_metadata.get("finish_reason") == "stop" + + +def test_oci_iter_stream_yields_text_chunks( + patch_oci_module, oci_unit_values +): + """iter_stream should yield individual text chunks.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="chunk1"), + _make_fake_stream_event(text="chunk2"), + _make_fake_stream_event( + usage={"promptTokens": 3, "completionTokens": 2, "totalTokens": 5}, + ), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + chunks = list(llm.iter_stream(messages=[{"role": "user", "content": "test"}])) + + assert chunks == ["chunk1", "chunk2"] + assert llm.last_response_metadata is not None + assert llm.last_response_metadata["usage"]["total_tokens"] == 5 + + +@pytest.mark.asyncio +async def test_oci_astream_yields_text_chunks( + patch_oci_module, oci_unit_values +): + """astream should yield chunks via async generator.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="async1"), + _make_fake_stream_event(text="async2"), + _make_fake_stream_event(), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + chunks = [] + async for chunk in llm.astream(messages=[{"role": "user", "content": "test"}]): + chunks.append(chunk) + + assert chunks == ["async1", "async2"] + + +def test_oci_stream_chat_events_holds_client_lock( + patch_oci_module, oci_unit_values +): + """_stream_chat_events should hold the client lock for the full iteration.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [_make_fake_stream_event(text="a"), _make_fake_stream_event(text="b")] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + # Before streaming, ticket should be 0 + assert llm._active_client_ticket == 0 + chat_details = MagicMock() + list(llm._stream_chat_events(chat_details)) + # After streaming completes, ticket should have advanced + assert llm._active_client_ticket == 1 From 69532733cad893de74527377f2b7131f2fa30e08 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 15:17:15 -0400 Subject: [PATCH 07/12] feat: add tool calling support to OCI Generative AI provider Add native function calling for generic and Cohere model families: - _format_tools converts CrewAI tool specs to OCI SDK format - _extract_tool_calls normalizes responses back to CrewAI shape - _handle_tool_calls executes tools and recurses until model finishes - Cohere tool message handling with trailing tool results - Tool choice control (auto/none/required/function) - Passthrough parameter filtering via SDK introspection - Streaming tool call accumulation from SSE fragments - supports_function_calling() returns True Tested live against meta.llama-3.3-70b-instruct with raw tool call return and recursive tool execution. Depends on: #4961 (streaming), #4959 (basic text) Tracking issue: #4944 --- .../crewai/llms/providers/oci/completion.py | 527 +++++++++++++++++- .../llms/oci/test_oci_integration_tools.py | 100 ++++ lib/crewai/tests/llms/oci/test_oci_tools.py | 291 ++++++++++ 3 files changed, 902 insertions(+), 16 deletions(-) create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_tools.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_tools.py diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 1147a1c2ad..253a311068 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -3,6 +3,7 @@ import asyncio from collections.abc import Mapping from contextlib import contextmanager +import inspect import json import logging import os @@ -26,6 +27,17 @@ CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" DEFAULT_OCI_REGION = "us-chicago-1" +_OCI_TOOL_RESULT_GUIDANCE = ( + "You have received tool results above. Respond to the user with a helpful, " + "natural language answer that incorporates the tool results. Do not output " + "raw JSON or tool call syntax. If you need additional information, you may " + "call another tool." +) +_OCI_RESERVED_REQUEST_KWARGS = { + "tool_choice", + "parallel_tool_calls", + "tool_result_guidance", +} def _get_oci_module() -> Any: @@ -61,6 +73,7 @@ def __init__( top_k: int | None = None, stream: bool = False, oci_provider: str | None = None, + max_sequential_tool_calls: int = 8, client: Any | None = None, **kwargs: Any, ) -> None: @@ -99,6 +112,7 @@ def __init__( self.top_k = top_k self.stream = stream self.oci_provider = oci_provider or self._infer_provider(model) + self.max_sequential_tool_calls = max_sequential_tool_calls self._oci = _get_oci_module() if client is not None: @@ -202,13 +216,45 @@ def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: for message in messages: role = str(message.get("role", "user")).lower() + if role == "tool": + tool_kwargs: dict[str, Any] = { + "content": self._build_generic_content(message.get("content", "")), + } + if message.get("tool_call_id"): + tool_kwargs["tool_call_id"] = message["tool_call_id"] + oci_messages.append(models.ToolMessage(**tool_kwargs)) + continue + message_cls = role_map.get(role) if message_cls is None: logging.debug("Skipping unsupported OCI message role: %s", role) continue + + message_kwargs: dict[str, Any] = { + "content": self._build_generic_content(message.get("content", "")), + } + if role == "assistant" and message.get("tool_calls"): + message_kwargs["tool_calls"] = [ + models.FunctionCall( + id=tool_call.get("id"), + name=tool_call.get("function", {}).get("name"), + arguments=tool_call.get("function", {}).get("arguments", "{}"), + ) + for tool_call in message.get("tool_calls", []) + if tool_call.get("function", {}).get("name") + ] + if not message_kwargs["content"]: + message_kwargs["content"] = [models.TextContent(text=".")] + + oci_messages.append(message_cls(**message_kwargs)) + + if ( + self._tool_result_guidance_enabled() + and any(str(message.get("role", "")).lower() == "tool" for message in messages) + ): oci_messages.append( - message_cls( - content=self._build_generic_content(message.get("content", "")), + models.SystemMessage( + content=[models.TextContent(text=_OCI_TOOL_RESULT_GUIDANCE)] ) ) @@ -216,12 +262,21 @@ def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: def _build_cohere_chat_history( self, messages: list[LLMMessage] - ) -> tuple[list[Any], str]: - """Translate CrewAI messages into Cohere's split history + message shape.""" + ) -> tuple[list[Any], list[Any] | None, str]: + """Translate CrewAI messages into Cohere's split history/tool-results shape.""" models = self._oci.generative_ai_inference.models chat_history: list[Any] = [] + trailing_tool_count = 0 + for message in reversed(messages): + if str(message.get("role", "")).lower() != "tool": + break + trailing_tool_count += 1 + + history_messages = ( + messages[:-trailing_tool_count] if trailing_tool_count else messages[:-1] + ) - for message in messages[:-1]: + for message in history_messages: role = str(message.get("role", "user")).lower() content = message.get("content", "") @@ -233,15 +288,188 @@ def _build_cohere_chat_history( ) chat_history.append(message_cls(message=self._coerce_text(content))) elif role == "assistant": + tool_calls = None + if message.get("tool_calls"): + tool_calls = [] + for tool_call in message.get("tool_calls", []): + function_info = tool_call.get("function", {}) + function_name = function_info.get("name") + if not function_name: + continue + raw_arguments = function_info.get("arguments", "{}") + if isinstance(raw_arguments, str): + try: + parameters = json.loads(raw_arguments) + except json.JSONDecodeError: + parameters = {} + elif isinstance(raw_arguments, Mapping): + parameters = dict(raw_arguments) + else: + parameters = {} + tool_calls.append( + models.CohereToolCall(name=function_name, parameters=parameters) + ) chat_history.append( models.CohereChatBotMessage( message=self._coerce_text(content) or " ", + tool_calls=tool_calls, + ) + ) + elif role == "tool": + tool_name = message.get("name") or "tool" + chat_history.append( + models.CohereToolMessage( + tool_results=[ + models.CohereToolResult( + call=models.CohereToolCall(name=tool_name, parameters={}), + outputs=[{"output": self._coerce_text(content)}], + ) + ] ) ) last_message = messages[-1] if messages else {"role": "user", "content": ""} + tool_results: list[Any] = [] + if str(last_message.get("role", "user")).lower() == "tool": + previous_tool_calls: dict[str, dict[str, Any]] = {} + for message in messages: + if str(message.get("role", "")).lower() != "assistant": + continue + for tool_call in message.get("tool_calls", []): + tool_call_id = tool_call.get("id") + if not tool_call_id: + continue + function_info = tool_call.get("function", {}) + raw_arguments = function_info.get("arguments", "{}") + if isinstance(raw_arguments, str): + try: + parameters = json.loads(raw_arguments) + except json.JSONDecodeError: + parameters = {} + elif isinstance(raw_arguments, Mapping): + parameters = dict(raw_arguments) + else: + parameters = {} + previous_tool_calls[tool_call_id] = { + "name": function_info.get("name", "tool"), + "parameters": parameters, + } + + for message in messages[-trailing_tool_count:]: + if str(message.get("role", "")).lower() != "tool": + continue + tool_call_id = message.get("tool_call_id") + if not isinstance(tool_call_id, str): + continue + previous_call = previous_tool_calls.get(tool_call_id, {}) + tool_results.append( + models.CohereToolResult( + call=models.CohereToolCall( + name=previous_call.get("name", message.get("name", "tool")), + parameters=previous_call.get("parameters", {}), + ), + outputs=[{"output": self._coerce_text(message.get("content", ""))}], + ) + ) + message_text = self._coerce_text(last_message.get("content", "")) - return chat_history, message_text + if tool_results: + message_text = "" + + return chat_history, tool_results or None, message_text + + # ------------------------------------------------------------------ + # Tool formatting + # ------------------------------------------------------------------ + + def _format_tools(self, tools: list[dict[str, Any]] | None) -> list[Any]: + if not tools: + return [] + models = self._oci.generative_ai_inference.models + formatted: list[Any] = [] + for tool in tools: + if not isinstance(tool, Mapping): + continue + function_spec = tool.get("function", {}) + if not isinstance(function_spec, Mapping): + continue + name = function_spec.get("name") + if not name: + continue + parameters = function_spec.get("parameters", {}) + if not isinstance(parameters, Mapping): + parameters = {} + + if self.oci_provider == "cohere": + param_defs = {} + required = set(parameters.get("required", [])) + for pname, pschema in parameters.get("properties", {}).items(): + if not isinstance(pschema, Mapping): + continue + param_defs[pname] = models.CohereParameterDefinition( + description=pschema.get("description", ""), + type=pschema.get("type", "object"), + is_required=pname in required, + ) + formatted.append(models.CohereTool( + name=name, + description=function_spec.get("description", name), + parameter_definitions=param_defs, + )) + else: + formatted.append(models.FunctionDefinition( + name=name, + description=function_spec.get("description", name), + parameters={ + "type": parameters.get("type", "object"), + "properties": parameters.get("properties", {}), + "required": parameters.get("required", []), + }, + )) + return formatted + + def _tool_result_guidance_enabled(self) -> bool: + return bool(self.additional_params.get("tool_result_guidance")) + + def _parallel_tool_calls_enabled(self) -> bool: + return bool(self.additional_params.get("parallel_tool_calls")) + + def _build_tool_choice(self) -> Any | None: + tool_choice = self.additional_params.get("tool_choice") + if tool_choice is None: + return None + models = self._oci.generative_ai_inference.models + if isinstance(tool_choice, str): + if tool_choice == "auto": + return models.ToolChoiceAuto() + if tool_choice == "none": + return models.ToolChoiceNone() + if tool_choice in ("any", "required"): + return models.ToolChoiceRequired() + return models.ToolChoiceFunction(name=tool_choice) + if isinstance(tool_choice, bool): + return models.ToolChoiceRequired() if tool_choice else models.ToolChoiceNone() + if isinstance(tool_choice, Mapping): + fn = tool_choice.get("function") + if isinstance(fn, Mapping) and fn.get("name"): + return models.ToolChoiceFunction(name=str(fn["name"])) + return models.ToolChoiceAuto() + raise ValueError("Unrecognized OCI tool_choice. Expected str, bool, or function mapping.") + + def _allowed_passthrough_request_keys(self, request_cls: type[Any]) -> set[str]: + """Return request attributes that can safely be forwarded to the OCI SDK.""" + attribute_map = getattr(request_cls, "attribute_map", None) + if isinstance(attribute_map, Mapping): + return {str(key) for key in attribute_map} + swagger_types = getattr(request_cls, "swagger_types", None) + if isinstance(swagger_types, Mapping): + return {str(key) for key in swagger_types} + signature = inspect.signature(request_cls) + return { + name + for name, param in signature.parameters.items() + if name != "self" and param.kind is not inspect.Parameter.VAR_KEYWORD + } # ------------------------------------------------------------------ # Request building @@ -250,6 +478,7 @@ def _build_cohere_chat_history( def _build_chat_request( self, messages: list[LLMMessage], + tools: list[dict[str, Any]] | None = None, *, is_stream: bool = False, ) -> Any: @@ -257,12 +486,16 @@ def _build_chat_request( models = self._oci.generative_ai_inference.models if self.oci_provider == "cohere": - chat_history, message_text = self._build_cohere_chat_history(messages) + chat_history, tool_results, message_text = self._build_cohere_chat_history( + messages + ) request_kwargs: dict[str, Any] = { "message": message_text, "chat_history": chat_history, "api_format": models.BaseChatRequest.API_FORMAT_COHERE, } + if tool_results: + request_kwargs["tool_results"] = tool_results else: request_kwargs = { "messages": self._build_generic_messages(messages), @@ -285,6 +518,20 @@ def _build_chat_request( stop_key = "stop_sequences" if self.oci_provider == "cohere" else "stop" request_kwargs[stop_key] = list(self.stop) + formatted_tools = self._format_tools(tools) + if formatted_tools: + request_kwargs["tools"] = formatted_tools + if self.oci_provider == "cohere": + if self._parallel_tool_calls_enabled(): + raise ValueError("OCI Cohere models do not support parallel_tool_calls.") + request_kwargs.setdefault("is_force_single_step", False) + else: + tool_choice = self._build_tool_choice() + if tool_choice is not None: + request_kwargs["tool_choice"] = tool_choice + if self._parallel_tool_calls_enabled(): + request_kwargs["is_parallel_tool_calls"] = True + if is_stream: request_kwargs["is_stream"] = True request_kwargs["stream_options"] = models.StreamOptions( @@ -292,7 +539,20 @@ def _build_chat_request( ) if self.oci_provider == "cohere": + allowed = self._allowed_passthrough_request_keys(models.CohereChatRequest) + passthrough = { + k: v for k, v in self.additional_params.items() + if k not in _OCI_RESERVED_REQUEST_KWARGS and k in allowed + } + request_kwargs.update(passthrough) return models.CohereChatRequest(**request_kwargs) + + allowed = self._allowed_passthrough_request_keys(models.GenericChatRequest) + passthrough = { + k: v for k, v in self.additional_params.items() + if k not in _OCI_RESERVED_REQUEST_KWARGS and k in allowed + } + request_kwargs.update(passthrough) return models.GenericChatRequest(**request_kwargs) # ------------------------------------------------------------------ @@ -321,6 +581,42 @@ def _extract_text(self, response: Any) -> str: content = getattr(message, "content", None) or [] return "".join(part.text for part in content if getattr(part, "text", None)) + def _extract_tool_calls(self, response: Any) -> list[dict[str, Any]]: + """Normalize provider-specific tool calls back into CrewAI's shape.""" + chat_response = response.data.chat_response + raw: list[Any] = [] + if self.oci_provider == "cohere": + raw = getattr(chat_response, "tool_calls", None) or [] + else: + choices = getattr(chat_response, "choices", None) or [] + if choices: + msg = getattr(choices[0], "message", None) + raw = getattr(msg, "tool_calls", None) or [] + + if self.oci_provider == "cohere": + return [ + { + "id": uuid.uuid4().hex, + "type": "function", + "function": { + "name": getattr(tc, "name", ""), + "arguments": json.dumps(getattr(tc, "parameters", {}) or {}), + }, + } + for tc in raw + ] + return [ + { + "id": getattr(tc, "id", None), + "type": "function", + "function": { + "name": getattr(tc, "name", ""), + "arguments": getattr(tc, "arguments", "{}"), + }, + } + for tc in raw + ] + def _extract_usage(self, response: Any) -> dict[str, int]: chat_response = response.data.chat_response usage = getattr(chat_response, "usage", None) @@ -400,6 +696,31 @@ def _extract_text_from_stream_event(self, event_data: dict[str, Any]) -> str: if isinstance(part, Mapping) and part.get("text") ) + def _extract_tool_calls_from_stream_event( + self, event_data: dict[str, Any] + ) -> list[dict[str, Any]]: + message = event_data.get("message", {}) + if self.oci_provider == "cohere": + raw = event_data.get("toolCalls", []) + else: + raw = message.get("toolCalls", []) if isinstance(message, Mapping) else [] + if not isinstance(raw, list): + return [] + if self.oci_provider == "cohere": + return [ + {"id": None, "type": "function", "function": { + "name": str(tc.get("name", "")), + "arguments": json.dumps(tc.get("parameters", {})), + }} + for tc in raw if isinstance(tc, Mapping) + ] + return [ + {"id": tc.get("id"), "type": "function", "function": { + "name": tc.get("name"), "arguments": tc.get("arguments"), + }} + for tc in raw if isinstance(tc, Mapping) + ] + def _extract_usage_from_stream_event(self, event_data: dict[str, Any]) -> dict[str, int]: usage = event_data.get("usage") if not isinstance(usage, Mapping): @@ -420,6 +741,94 @@ def _extract_metadata_from_stream_event(self, event_data: dict[str, Any]) -> dic metadata["usage"] = usage return metadata + # ------------------------------------------------------------------ + # Tool execution + # ------------------------------------------------------------------ + + def _handle_tool_calls( + self, + *, + normalized_messages: list[LLMMessage], + tools: list[dict[str, BaseTool]] | None, + callbacks: list[Any] | None, + available_functions: dict[str, Any] | None, + from_task: Task | None, + from_agent: Agent | None, + tool_depth: int, + tool_calls: list[dict[str, Any]], + ) -> str | BaseModel | list[dict[str, Any]]: + """Execute one round of tool calls and recurse until the model finishes.""" + if tool_calls and not available_functions: + self._emit_call_completed_event( + response=tool_calls, + call_type=LLMCallType.TOOL_CALL, + from_task=from_task, + from_agent=from_agent, + messages=normalized_messages, + ) + return tool_calls + + if tool_depth >= self.max_sequential_tool_calls: + raise RuntimeError( + "OCI native provider exceeded max_sequential_tool_calls." + ) + + next_messages = list(normalized_messages) + next_messages.append( + {"role": "assistant", "content": None, "tool_calls": tool_calls} + ) + + for tool_call in tool_calls: + fn = tool_call.get("function", {}) + fn_name = fn.get("name", "") + raw_args = fn.get("arguments", "{}") + if isinstance(raw_args, str): + try: + fn_args = json.loads(raw_args) + except json.JSONDecodeError: + fn_args = {} + elif isinstance(raw_args, Mapping): + fn_args = dict(raw_args) + else: + fn_args = {} + + result = self._handle_tool_execution( + function_name=fn_name, + function_args=fn_args, + available_functions=available_functions or {}, + from_task=from_task, + from_agent=from_agent, + ) + if result is None: + result = f"Tool '{fn_name}' failed or returned no result." + + next_messages.append({ + "role": "tool", + "tool_call_id": str(tool_call.get("id") or uuid.uuid4().hex), + "name": fn_name, + "content": str(result), + }) + + if self.stream: + return self._stream_call_impl( + messages=next_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=tool_depth + 1, + ) + return self._call_impl( + messages=next_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=tool_depth + 1, + ) + # ------------------------------------------------------------------ # Call paths # ------------------------------------------------------------------ @@ -447,13 +856,17 @@ def _call_impl( self, *, messages: str | list[LLMMessage], - from_task: Task | None, - from_agent: Agent | None, - ) -> str: + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + tool_depth: int = 0, + ) -> str | BaseModel | list[dict[str, Any]]: normalized_messages = ( messages if isinstance(messages, list) else self._normalize_messages(messages) ) - chat_request = self._build_chat_request(normalized_messages) + chat_request = self._build_chat_request(normalized_messages, tools=tools) chat_details = self._oci.generative_ai_inference.models.ChatDetails( compartment_id=self.compartment_id, serving_mode=self._build_serving_mode(), @@ -466,6 +879,19 @@ def _call_impl( self.last_response_metadata = self._extract_response_metadata(response) or None content = self._extract_text(response) + tool_calls = self._extract_tool_calls(response) + if tool_calls: + return self._handle_tool_calls( + normalized_messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=tool_depth, + tool_calls=tool_calls, + ) + return self._finalize_text_response( content=content, messages=normalized_messages, @@ -477,20 +903,27 @@ def _stream_call_impl( self, *, messages: str | list[LLMMessage], - from_task: Task | None, - from_agent: Agent | None, - ) -> str: - """Handle OCI streaming while reconstructing final text state.""" + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + tool_depth: int = 0, + ) -> str | BaseModel | list[dict[str, Any]]: + """Handle OCI streaming while reconstructing final text/tool state.""" normalized_messages = ( messages if isinstance(messages, list) else self._normalize_messages(messages) ) - chat_request = self._build_chat_request(normalized_messages, is_stream=True) + chat_request = self._build_chat_request( + normalized_messages, tools=tools, is_stream=True + ) chat_details = self._oci.generative_ai_inference.models.ChatDetails( compartment_id=self.compartment_id, serving_mode=self._build_serving_mode(), chat_request=chat_request, ) full_response = "" + tool_calls_by_index: dict[int, dict[str, Any]] = {} usage_data: dict[str, int] = {} response_metadata: dict[str, Any] = {} response_id = uuid.uuid4().hex @@ -511,6 +944,32 @@ def _stream_call_impl( response_id=response_id, ) + stream_tool_calls = self._extract_tool_calls_from_stream_event(event_data) + for index, tc in enumerate(stream_tool_calls): + state = tool_calls_by_index.setdefault( + index, + {"id": None, "type": "function", "function": {"name": None, "arguments": ""}}, + ) + if tc.get("id"): + state["id"] = tc["id"] + fn = tc.get("function", {}) + if fn.get("name"): + state["function"]["name"] = fn["name"] + chunk_args = fn.get("arguments") + if chunk_args: + state["function"]["arguments"] += str(chunk_args) + self._emit_stream_chunk_event( + chunk=str(chunk_args or ""), + tool_call={"id": state["id"], "type": "function", "function": { + "name": state["function"]["name"], + "arguments": str(chunk_args or ""), + }}, + from_task=from_task, + from_agent=from_agent, + call_type=LLMCallType.TOOL_CALL, + response_id=response_id, + ) + usage_chunk = self._extract_usage_from_stream_event(event_data) if usage_chunk: usage_data = usage_chunk @@ -521,6 +980,31 @@ def _stream_call_impl( response_metadata["usage"] = usage_data self.last_response_metadata = response_metadata or None + tool_calls = [ + { + "id": tc.get("id") or uuid.uuid4().hex, + "type": "function", + "function": { + "name": tc["function"].get("name", "") or "", + "arguments": tc["function"].get("arguments", "") or "", + }, + } + for _, tc in sorted(tool_calls_by_index.items()) + if tc["function"].get("name") + ] + + if tool_calls: + return self._handle_tool_calls( + normalized_messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=tool_depth, + tool_calls=tool_calls, + ) + return self._finalize_text_response( content=full_response, messages=normalized_messages, @@ -640,14 +1124,22 @@ def call( if self.stream: return self._stream_call_impl( messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, from_task=from_task, from_agent=from_agent, + tool_depth=0, ) return self._call_impl( messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, from_task=from_task, from_agent=from_agent, + tool_depth=0, ) except Exception as error: error_message = f"OCI Generative AI call failed: {error!s}" @@ -713,6 +1205,9 @@ def _ordered_client_access(self) -> Any: # Capability declarations # ------------------------------------------------------------------ + def supports_function_calling(self) -> bool: + return True + def get_context_window_size(self) -> int: model_lower = self.model.lower() if model_lower.startswith("google.gemini"): diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_tools.py b/lib/crewai/tests/llms/oci/test_oci_integration_tools.py new file mode 100644 index 0000000000..305dbc7cbb --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_tools.py @@ -0,0 +1,100 @@ +"""Live integration tests for OCI Generative AI tool calling. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + OCI_TEST_TOOL_MODELS="meta.llama-3.3-70b-instruct" \ + uv run pytest tests/llms/oci/test_oci_integration_tools.py -v +""" + +from __future__ import annotations + +import os + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def _env_models(env_var: str, fallback: str, default: str) -> list[str]: + raw = os.getenv(env_var) or os.getenv(fallback) or default + return [m.strip() for m in raw.split(",") if m.strip()] + + +def _skip_unless_live(): + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set") + region = os.getenv("OCI_REGION") + endpoint = os.getenv("OCI_SERVICE_ENDPOINT") + if not region and not endpoint: + pytest.skip("Set OCI_REGION or OCI_SERVICE_ENDPOINT") + config: dict[str, str] = {"compartment_id": compartment} + if endpoint: + config["service_endpoint"] = endpoint + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + return config + + +TOOL_SPEC = [ + { + "type": "function", + "function": { + "name": "add_numbers", + "description": "Add two numbers together", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + }, + } +] + + +@pytest.fixture( + params=_env_models("OCI_TEST_TOOL_MODELS", "OCI_TEST_TOOL_MODEL", "meta.llama-3.3-70b-instruct"), + ids=lambda m: m, +) +def oci_tool_model(request): + return request.param + + +@pytest.fixture() +def oci_tool_config(): + return _skip_unless_live() + + +def test_oci_live_tool_call_returns_raw(oci_tool_model: str, oci_tool_config: dict): + """Without available_functions, tool calls should be returned raw.""" + llm = OCICompletion(model=oci_tool_model, **oci_tool_config) + result = llm.call( + messages=[{"role": "user", "content": "What is 3 + 7? Use the add_numbers tool."}], + tools=TOOL_SPEC, + ) + + assert isinstance(result, list) + assert len(result) >= 1 + assert result[0]["function"]["name"] == "add_numbers" + + +def test_oci_live_tool_call_with_execution(oci_tool_model: str, oci_tool_config: dict): + """With available_functions, tools should execute and model should respond.""" + def add_numbers(a: float, b: float) -> str: + return str(float(a) + float(b)) + + llm = OCICompletion(model=oci_tool_model, **oci_tool_config) + result = llm.call( + messages=[{"role": "user", "content": "What is 3 + 7? Use the add_numbers tool."}], + tools=TOOL_SPEC, + available_functions={"add_numbers": add_numbers}, + ) + + assert isinstance(result, str) + assert "10" in result diff --git a/lib/crewai/tests/llms/oci/test_oci_tools.py b/lib/crewai/tests/llms/oci/test_oci_tools.py new file mode 100644 index 0000000000..bfe08eb912 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_tools.py @@ -0,0 +1,291 @@ +"""Unit tests for OCI provider tool calling (mocked SDK).""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import pytest + + +def _make_tool_call_response(tool_name: str = "get_weather", args: dict | None = None) -> MagicMock: + """Build a fake OCI response with a generic tool call.""" + tc = MagicMock() + tc.id = "tc_001" + tc.name = tool_name + tc.arguments = json.dumps(args or {"city": "NYC"}) + + message = MagicMock() + message.content = [MagicMock(text="")] + message.tool_calls = [tc] + + choice = MagicMock() + choice.message = message + choice.finish_reason = "tool_calls" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + chat_response.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_text_response(text: str = "The weather is sunny.") -> MagicMock: + """Build a fake OCI text response (used after tool execution).""" + text_part = MagicMock() + text_part.text = text + + message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + chat_response.usage = MagicMock(prompt_tokens=20, completion_tokens=10, total_tokens=30) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_cohere_tool_call_response(tool_name: str = "get_weather") -> MagicMock: + """Build a fake OCI Cohere response with a tool call.""" + tc = MagicMock() + tc.name = tool_name + tc.parameters = {"city": "NYC"} + + chat_response = MagicMock() + chat_response.text = "" + chat_response.tool_calls = [tc] + chat_response.finish_reason = "COMPLETE" + chat_response.usage = MagicMock(prompt_tokens=8, completion_tokens=4, total_tokens=12) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +SAMPLE_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string", "description": "The city name"}}, + "required": ["city"], + }, + }, + } +] + + +def test_oci_completion_returns_tool_calls_for_executor( + patch_oci_module, oci_unit_values +): + """When no available_functions, tool calls should be returned raw.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_tool_call_response() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "What is the weather?"}], + tools=SAMPLE_TOOLS, + ) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + + +def test_oci_completion_executes_tool_calls_recursively( + patch_oci_module, oci_unit_values +): + """With available_functions, tool should be executed and model re-called.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + # First call returns tool call, second call returns text + fake_client.chat.side_effect = [ + _make_tool_call_response(), + _make_text_response("It is sunny in NYC."), + ] + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + def mock_get_weather(city: str) -> str: + return f"Sunny in {city}" + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather in NYC?"}], + tools=SAMPLE_TOOLS, + available_functions={"get_weather": mock_get_weather}, + ) + + assert isinstance(result, str) + assert "sunny" in result.lower() or "NYC" in result + assert fake_client.chat.call_count == 2 + + +def test_oci_completion_formats_generic_tools( + patch_oci_module, oci_unit_values +): + """_format_tools should produce FunctionDefinition for generic models.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + formatted = llm._format_tools(SAMPLE_TOOLS) + + assert len(formatted) == 1 + models = patch_oci_module.generative_ai_inference.models + models.FunctionDefinition.assert_called_once() + + +def test_oci_completion_formats_cohere_tools( + patch_oci_module, oci_unit_values +): + """_format_tools should produce CohereTool for Cohere models.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + formatted = llm._format_tools(SAMPLE_TOOLS) + + assert len(formatted) == 1 + models = patch_oci_module.generative_ai_inference.models + models.CohereTool.assert_called_once() + + +def test_oci_completion_cohere_extracts_tool_calls( + patch_oci_module, oci_unit_values +): + """Cohere tool calls should be normalized to CrewAI shape with generated IDs.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_cohere_tool_call_response() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather?"}], + tools=SAMPLE_TOOLS, + ) + + assert isinstance(result, list) + assert result[0]["function"]["name"] == "get_weather" + assert result[0]["id"] # Should have a generated UUID + + +def test_oci_completion_rejects_parallel_tools_for_cohere( + patch_oci_module, oci_unit_values +): + """Cohere models should raise if parallel_tool_calls is enabled.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + parallel_tool_calls=True, + ) + + with pytest.raises(ValueError, match="parallel_tool_calls"): + llm._build_chat_request( + [{"role": "user", "content": "test"}], + tools=SAMPLE_TOOLS, + ) + + +def test_oci_completion_respects_max_sequential_tool_calls( + patch_oci_module, oci_unit_values +): + """Should raise RuntimeError when tool depth exceeds max.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + # Always return tool calls to force recursion + fake_client.chat.return_value = _make_tool_call_response() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + def mock_tool(city: str) -> str: + return "result" + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + max_sequential_tool_calls=2, + ) + + with pytest.raises(RuntimeError, match="max_sequential_tool_calls"): + llm.call( + messages=[{"role": "user", "content": "test"}], + tools=SAMPLE_TOOLS, + available_functions={"get_weather": mock_tool}, + ) + + +def test_oci_completion_supports_function_calling( + patch_oci_module, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm.supports_function_calling() is True + + +def test_oci_completion_filters_unknown_passthrough_params( + patch_oci_module, oci_unit_values +): + """Unknown additional_params should not crash the OCI SDK request.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_text_response("ok") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + # Make GenericChatRequest only accept known keys + patch_oci_module.generative_ai_inference.models.GenericChatRequest.attribute_map = { + "messages": "messages", + "api_format": "apiFormat", + } + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + additional_params={"bogus_param": "should_be_filtered"}, + ) + # Should not raise + result = llm.call(messages=[{"role": "user", "content": "test"}]) + assert isinstance(result, str) From 02a9d29c1ba51666c346ca08d9a0d9e3563f6bdb Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 15:36:24 -0400 Subject: [PATCH 08/12] feat: add structured output support to OCI Generative AI provider Add response_model (Pydantic) support for structured output: - _build_response_format converts Pydantic schema to OCI JsonSchemaResponseFormat (generic) or CohereResponseJsonFormat - _parse_structured_response validates and returns typed models - response_model threaded through call, _call_impl, _stream_call_impl, and _handle_tool_calls for full coverage - Handles JSON in markdown fences via base class _validate_structured_output Tested live against meta.llama-3.3-70b-instruct and google.gemini-2.5-flash. Depends on: #4962 (tool calling), #4961 (streaming), #4959 (basic text) Tracking issue: #4944 --- .../crewai/llms/providers/oci/completion.py | 98 ++++++++- .../oci/test_oci_integration_structured.py | 33 +++ .../tests/llms/oci/test_oci_structured.py | 204 ++++++++++++++++++ 3 files changed, 333 insertions(+), 2 deletions(-) create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_structured.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_structured.py diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 253a311068..cd996e42c8 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -7,6 +7,7 @@ import json import logging import os +import re import threading from typing import TYPE_CHECKING, Any, Literal, cast import uuid @@ -16,6 +17,7 @@ from crewai.events.types.llm_events import LLMCallType from crewai.llms.base_llm import BaseLLM, llm_call_context from crewai.utilities.oci import create_oci_client_kwargs, get_oci_module +from crewai.utilities.pydantic_schema_utils import generate_model_description from crewai.utilities.types import LLMMessage @@ -27,6 +29,7 @@ CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" DEFAULT_OCI_REGION = "us-chicago-1" +_OCI_SCHEMA_NAME_PATTERN = re.compile(r"[^a-zA-Z0-9_-]") _OCI_TOOL_RESULT_GUIDANCE = ( "You have received tool results above. Respond to the user with a helpful, " "natural language answer that incorporates the tool results. Do not output " @@ -475,10 +478,29 @@ def _allowed_passthrough_request_keys(self, request_cls: type[Any]) -> set[str]: # Request building # ------------------------------------------------------------------ + def _build_response_format( + self, response_model: type[BaseModel] | None + ) -> Any | None: + if response_model is None: + return None + models = self._oci.generative_ai_inference.models + schema_description = generate_model_description(response_model)["json_schema"] + schema_name = _OCI_SCHEMA_NAME_PATTERN.sub("_", schema_description["name"]) + json_schema = models.ResponseJsonSchema( + name=schema_name, + description=(response_model.__doc__ or "").strip() or schema_name, + schema=schema_description["schema"], + is_strict=schema_description["strict"], + ) + if self.oci_provider == "cohere": + return models.CohereResponseJsonFormat(schema=json_schema.schema) + return models.JsonSchemaResponseFormat(json_schema=json_schema) + def _build_chat_request( self, messages: list[LLMMessage], tools: list[dict[str, Any]] | None = None, + response_model: type[BaseModel] | None = None, *, is_stream: bool = False, ) -> Any: @@ -532,6 +554,10 @@ def _build_chat_request( if self._parallel_tool_calls_enabled(): request_kwargs["is_parallel_tool_calls"] = True + response_format = self._build_response_format(response_model) + if response_format is not None: + request_kwargs["response_format"] = response_format + if is_stream: request_kwargs["is_stream"] = True request_kwargs["stream_options"] = models.StreamOptions( @@ -741,6 +767,44 @@ def _extract_metadata_from_stream_event(self, event_data: dict[str, Any]) -> dic metadata["usage"] = usage return metadata + # ------------------------------------------------------------------ + # Structured output + # ------------------------------------------------------------------ + + def _parse_structured_response( + self, + *, + content: str, + response_model: type[BaseModel], + messages: list[LLMMessage], + from_task: Task | None, + from_agent: Agent | None, + ) -> BaseModel: + try: + structured_response = self._validate_structured_output( + content, response_model + ) + except Exception as error: + raise ValueError( + f"Failed to validate OCI structured response with model " + f"{response_model.__name__}: {error}" + ) from error + + if not isinstance(structured_response, BaseModel): + raise ValueError( + f"OCI structured response parsing returned unexpected type: " + f"{type(structured_response)}" + ) + + self._emit_call_completed_event( + response=structured_response.model_dump_json(), + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=messages, + ) + return structured_response + # ------------------------------------------------------------------ # Tool execution # ------------------------------------------------------------------ @@ -755,6 +819,7 @@ def _handle_tool_calls( from_task: Task | None, from_agent: Agent | None, tool_depth: int, + response_model: type[BaseModel] | None = None, tool_calls: list[dict[str, Any]], ) -> str | BaseModel | list[dict[str, Any]]: """Execute one round of tool calls and recurse until the model finishes.""" @@ -818,6 +883,7 @@ def _handle_tool_calls( from_task=from_task, from_agent=from_agent, tool_depth=tool_depth + 1, + response_model=response_model, ) return self._call_impl( messages=next_messages, @@ -827,6 +893,7 @@ def _handle_tool_calls( from_task=from_task, from_agent=from_agent, tool_depth=tool_depth + 1, + response_model=response_model, ) # ------------------------------------------------------------------ @@ -862,11 +929,14 @@ def _call_impl( from_task: Task | None = None, from_agent: Agent | None = None, tool_depth: int = 0, + response_model: type[BaseModel] | None = None, ) -> str | BaseModel | list[dict[str, Any]]: normalized_messages = ( messages if isinstance(messages, list) else self._normalize_messages(messages) ) - chat_request = self._build_chat_request(normalized_messages, tools=tools) + chat_request = self._build_chat_request( + normalized_messages, tools=tools, response_model=response_model + ) chat_details = self._oci.generative_ai_inference.models.ChatDetails( compartment_id=self.compartment_id, serving_mode=self._build_serving_mode(), @@ -889,9 +959,19 @@ def _call_impl( from_task=from_task, from_agent=from_agent, tool_depth=tool_depth, + response_model=response_model, tool_calls=tool_calls, ) + if response_model is not None: + return self._parse_structured_response( + content=content, + response_model=response_model, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + return self._finalize_text_response( content=content, messages=normalized_messages, @@ -909,13 +989,15 @@ def _stream_call_impl( from_task: Task | None = None, from_agent: Agent | None = None, tool_depth: int = 0, + response_model: type[BaseModel] | None = None, ) -> str | BaseModel | list[dict[str, Any]]: """Handle OCI streaming while reconstructing final text/tool state.""" normalized_messages = ( messages if isinstance(messages, list) else self._normalize_messages(messages) ) chat_request = self._build_chat_request( - normalized_messages, tools=tools, is_stream=True + normalized_messages, tools=tools, response_model=response_model, + is_stream=True, ) chat_details = self._oci.generative_ai_inference.models.ChatDetails( compartment_id=self.compartment_id, @@ -1002,9 +1084,19 @@ def _stream_call_impl( from_task=from_task, from_agent=from_agent, tool_depth=tool_depth, + response_model=response_model, tool_calls=tool_calls, ) + if response_model is not None: + return self._parse_structured_response( + content=full_response, + response_model=response_model, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + return self._finalize_text_response( content=full_response, messages=normalized_messages, @@ -1130,6 +1222,7 @@ def call( from_task=from_task, from_agent=from_agent, tool_depth=0, + response_model=response_model, ) return self._call_impl( @@ -1140,6 +1233,7 @@ def call( from_task=from_task, from_agent=from_agent, tool_depth=0, + response_model=response_model, ) except Exception as error: error_message = f"OCI Generative AI call failed: {error!s}" diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_structured.py b/lib/crewai/tests/llms/oci/test_oci_integration_structured.py new file mode 100644 index 0000000000..62ec951574 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_structured.py @@ -0,0 +1,33 @@ +"""Live integration tests for OCI Generative AI structured output. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + OCI_TEST_MODELS="meta.llama-3.3-70b-instruct" \ + uv run pytest tests/llms/oci/test_oci_integration_structured.py -v +""" + +from __future__ import annotations + +from pydantic import BaseModel + +from crewai.llms.providers.oci.completion import OCICompletion + + +class CapitalResponse(BaseModel): + """Response containing a country's capital city.""" + country: str + capital: str + + +def test_oci_live_structured_output(oci_chat_model: str, oci_live_config: dict): + """Structured output should return a validated Pydantic model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = llm.call( + messages=[{"role": "user", "content": "What is the capital of France? Answer with country and capital fields."}], + response_model=CapitalResponse, + ) + + assert isinstance(result, CapitalResponse) + assert result.capital.lower() == "paris" + assert result.country.lower() == "france" diff --git a/lib/crewai/tests/llms/oci/test_oci_structured.py b/lib/crewai/tests/llms/oci/test_oci_structured.py new file mode 100644 index 0000000000..caa4f330bb --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_structured.py @@ -0,0 +1,204 @@ +"""Unit tests for OCI provider structured output (mocked SDK).""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + + +class WeatherResponse(BaseModel): + """Weather forecast response.""" + city: str + temperature: float + unit: str + + +def _make_json_response(data: dict) -> MagicMock: + """Build a fake OCI response returning JSON text.""" + text_part = MagicMock() + text_part.text = json.dumps(data) + + message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + chat_response.usage = MagicMock(prompt_tokens=10, completion_tokens=8, total_tokens=18) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_cohere_json_response(data: dict) -> MagicMock: + """Build a fake OCI Cohere response returning JSON text.""" + chat_response = MagicMock() + chat_response.text = json.dumps(data) + chat_response.tool_calls = None + chat_response.finish_reason = "COMPLETE" + chat_response.usage = MagicMock(prompt_tokens=8, completion_tokens=6, total_tokens=14) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def test_oci_completion_structured_output_generic( + patch_oci_module, oci_unit_values +): + """response_model should parse JSON response into a Pydantic model.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_json_response( + {"city": "NYC", "temperature": 72.0, "unit": "F"} + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + # Mock schema-related models + patch_oci_module.generative_ai_inference.models.ResponseJsonSchema = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + patch_oci_module.generative_ai_inference.models.JsonSchemaResponseFormat = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather in NYC?"}], + response_model=WeatherResponse, + ) + + assert isinstance(result, WeatherResponse) + assert result.city == "NYC" + assert result.temperature == 72.0 + assert result.unit == "F" + + +def test_oci_completion_structured_output_cohere( + patch_oci_module, oci_unit_values +): + """Cohere models should use CohereResponseJsonFormat for structured output.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_cohere_json_response( + {"city": "London", "temperature": 15.0, "unit": "C"} + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.ResponseJsonSchema = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + patch_oci_module.generative_ai_inference.models.CohereResponseJsonFormat = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather in London?"}], + response_model=WeatherResponse, + ) + + assert isinstance(result, WeatherResponse) + assert result.city == "London" + + +def test_oci_completion_structured_output_with_fenced_json( + patch_oci_module, oci_unit_values +): + """Should handle JSON wrapped in markdown fences.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fenced = '```json\n{"city": "Tokyo", "temperature": 25.0, "unit": "C"}\n```' + text_part = MagicMock() + text_part.text = fenced + + message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + chat_response.usage = MagicMock(prompt_tokens=10, completion_tokens=8, total_tokens=18) + + response = MagicMock() + response.data.chat_response = chat_response + + fake_client = MagicMock() + fake_client.chat.return_value = response + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.ResponseJsonSchema = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + patch_oci_module.generative_ai_inference.models.JsonSchemaResponseFormat = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather in Tokyo?"}], + response_model=WeatherResponse, + ) + + assert isinstance(result, WeatherResponse) + assert result.city == "Tokyo" + + +def test_oci_build_response_format_returns_none_without_model( + patch_oci_module, oci_unit_values +): + """_build_response_format should return None when no response_model.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm._build_response_format(None) is None + + +def test_oci_build_response_format_creates_json_schema( + patch_oci_module, oci_unit_values +): + """_build_response_format should create a JsonSchemaResponseFormat for generic models.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + patch_oci_module.generative_ai_inference.models.ResponseJsonSchema = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + patch_oci_module.generative_ai_inference.models.JsonSchemaResponseFormat = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm._build_response_format(WeatherResponse) + + assert result is not None + patch_oci_module.generative_ai_inference.models.JsonSchemaResponseFormat.assert_called_once() From 1936c057c5a807370845e28a5a45e9b9779ba3ed Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 15:47:36 -0400 Subject: [PATCH 09/12] feat: add multimodal support to OCI Generative AI provider Add multimodal content handling for generic model families: - vision.py: model lists, data URI helpers, image encoding utilities - _build_generic_content handles image_url, document_url, video_url, audio_url content types mapped to OCI SDK content objects - _message_has_multimodal_content detects non-text payloads - Cohere models reject multimodal with clear error message - supports_multimodal() returns True Depends on: #4963, #4962, #4961, #4959 Tracking issue: #4944 --- .../src/crewai/llms/providers/oci/__init__.py | 14 ++ .../crewai/llms/providers/oci/completion.py | 68 +++++- .../src/crewai/llms/providers/oci/vision.py | 57 +++++ .../tests/llms/oci/test_oci_multimodal.py | 195 ++++++++++++++++++ 4 files changed, 330 insertions(+), 4 deletions(-) create mode 100644 lib/crewai/src/crewai/llms/providers/oci/vision.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_multimodal.py diff --git a/lib/crewai/src/crewai/llms/providers/oci/__init__.py b/lib/crewai/src/crewai/llms/providers/oci/__init__.py index 0c397558bd..ea459d687c 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/__init__.py +++ b/lib/crewai/src/crewai/llms/providers/oci/__init__.py @@ -1,5 +1,19 @@ from crewai.llms.providers.oci.completion import OCICompletion +from crewai.llms.providers.oci.vision import ( + IMAGE_EMBEDDING_MODELS, + VISION_MODELS, + encode_image, + is_vision_model, + load_image, + to_data_uri, +) __all__ = [ + "IMAGE_EMBEDDING_MODELS", + "VISION_MODELS", "OCICompletion", + "encode_image", + "is_vision_model", + "load_image", + "to_data_uri", ] diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index cd996e42c8..859aa542bd 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -184,8 +184,19 @@ def _coerce_text(self, content: Any) -> str: return "\n".join(part for part in parts if part) return str(content) + def _message_has_multimodal_content(self, content: Any) -> bool: + if not isinstance(content, list): + return False + for item in content: + if isinstance(item, Mapping) and item.get("type") not in (None, "text"): + return True + return False + def _build_generic_content(self, content: Any) -> list[Any]: - """Translate CrewAI message content into OCI generic content objects.""" + """Translate CrewAI message content into OCI generic content objects. + + Handles text, image_url, document_url, video_url, and audio_url types. + """ models = self._oci.generative_ai_inference.models if isinstance(content, str): return [models.TextContent(text=content or ".")] @@ -197,14 +208,52 @@ def _build_generic_content(self, content: Any) -> list[Any]: for item in content: if isinstance(item, str): processed.append(models.TextContent(text=item)) - elif isinstance(item, Mapping) and item.get("type") == "text": + continue + if not isinstance(item, Mapping): + raise ValueError( + f"OCI message content items must be strings or dictionaries, got: {type(item)}" + ) + + content_type = item.get("type") + if content_type == "text": processed.append( models.TextContent(text=str(item.get("text", "")) or ".") ) - else: + elif content_type == "image_url": + image_url = item.get("image_url", {}) + url = image_url.get("url") if isinstance(image_url, Mapping) else None + if not url: + raise ValueError("OCI image_url content requires image_url.url") + processed.append( + models.ImageContent(image_url=models.ImageUrl(url=url)) + ) + elif content_type in ("document_url", "document", "file"): + doc_data = item.get("document_url") or item.get("document") or item.get("file") + url = doc_data.get("url") if isinstance(doc_data, Mapping) else item.get("url") + if not url: + raise ValueError("OCI document content requires a url") processed.append( - models.TextContent(text=self._coerce_text(item) or ".") + models.DocumentContent(document_url=models.DocumentUrl(url=url)) ) + elif content_type in ("video_url", "video"): + video_data = item.get("video_url") or item.get("video") + url = video_data.get("url") if isinstance(video_data, Mapping) else item.get("url") + if not url: + raise ValueError("OCI video content requires a url") + processed.append( + models.VideoContent(video_url=models.VideoUrl(url=url)) + ) + elif content_type in ("audio_url", "audio"): + audio_data = item.get("audio_url") or item.get("audio") + url = audio_data.get("url") if isinstance(audio_data, Mapping) else item.get("url") + if not url: + raise ValueError("OCI audio content requires a url") + processed.append( + models.AudioContent(audio_url=models.AudioUrl(url=url)) + ) + else: + raise ValueError(f"Unsupported OCI content type: {content_type}") + return processed or [models.TextContent(text=".")] def _build_generic_messages(self, messages: list[LLMMessage]) -> list[Any]: @@ -282,6 +331,10 @@ def _build_cohere_chat_history( for message in history_messages: role = str(message.get("role", "user")).lower() content = message.get("content", "") + if self._message_has_multimodal_content(content): + raise ValueError( + "OCI Cohere models currently support text-only messages in CrewAI." + ) if role in ("user", "system"): message_cls = ( @@ -508,6 +561,10 @@ def _build_chat_request( models = self._oci.generative_ai_inference.models if self.oci_provider == "cohere": + if any(self._message_has_multimodal_content(msg.get("content")) for msg in messages): + raise ValueError( + "OCI Cohere models currently support text-only messages in CrewAI." + ) chat_history, tool_results, message_text = self._build_cohere_chat_history( messages ) @@ -1302,6 +1359,9 @@ def _ordered_client_access(self) -> Any: def supports_function_calling(self) -> bool: return True + def supports_multimodal(self) -> bool: + return True + def get_context_window_size(self) -> int: model_lower = self.model.lower() if model_lower.startswith("google.gemini"): diff --git a/lib/crewai/src/crewai/llms/providers/oci/vision.py b/lib/crewai/src/crewai/llms/providers/oci/vision.py new file mode 100644 index 0000000000..d0049f2632 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/vision.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import base64 +import mimetypes +from pathlib import Path + + +VISION_MODELS: list[str] = [ + "meta.llama-3.2-90b-vision-instruct", + "meta.llama-3.2-11b-vision-instruct", + "meta.llama-4-scout-17b-16e-instruct", + "meta.llama-4-maverick-17b-128e-instruct-fp8", + "google.gemini-2.5-flash", + "google.gemini-2.5-pro", + "google.gemini-2.5-flash-lite", + "xai.grok-4", + "xai.grok-4-1-fast-reasoning", + "xai.grok-4-1-fast-non-reasoning", + "xai.grok-4-fast-reasoning", + "xai.grok-4-fast-non-reasoning", + "cohere.command-a-vision", +] + +IMAGE_EMBEDDING_MODELS: list[str] = [ + "cohere.embed-v4.0", + "cohere.embed-multilingual-image-v3.0", +] + + +def to_data_uri(image: str | bytes | Path, mime_type: str = "image/png") -> str: + """Convert bytes, file paths, or data URIs into a data URI.""" + if isinstance(image, bytes): + encoded = base64.standard_b64encode(image).decode("utf-8") + return f"data:{mime_type};base64,{encoded}" + + image_str = str(image) + if image_str.startswith("data:"): + return image_str + + path = Path(image_str) + detected_mime = mimetypes.guess_type(str(path))[0] or mime_type + encoded = base64.standard_b64encode(path.read_bytes()).decode("utf-8") + return f"data:{detected_mime};base64,{encoded}" + + +def load_image(file_path: str | Path) -> dict[str, dict[str, str] | str]: + return {"type": "image_url", "image_url": {"url": to_data_uri(file_path)}} + + +def encode_image( + image_bytes: bytes, mime_type: str = "image/png" +) -> dict[str, dict[str, str] | str]: + return {"type": "image_url", "image_url": {"url": to_data_uri(image_bytes, mime_type)}} + + +def is_vision_model(model_id: str) -> bool: + return model_id in VISION_MODELS diff --git a/lib/crewai/tests/llms/oci/test_oci_multimodal.py b/lib/crewai/tests/llms/oci/test_oci_multimodal.py new file mode 100644 index 0000000000..b27d77e645 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_multimodal.py @@ -0,0 +1,195 @@ +"""Unit tests for OCI provider multimodal content (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +def test_oci_builds_image_content(patch_oci_module, oci_unit_values): + """image_url content should produce ImageContent with ImageUrl.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + # Mock multimodal content types + patch_oci_module.generative_ai_inference.models.ImageContent = MagicMock( + side_effect=lambda **kw: MagicMock(type="image", **kw) + ) + patch_oci_module.generative_ai_inference.models.ImageUrl = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + content = [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc123"}}, + ] + result = llm._build_generic_content(content) + + assert len(result) == 2 + patch_oci_module.generative_ai_inference.models.ImageContent.assert_called_once() + patch_oci_module.generative_ai_inference.models.ImageUrl.assert_called_once() + + +def test_oci_builds_document_content(patch_oci_module, oci_unit_values): + """document_url content should produce DocumentContent.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + patch_oci_module.generative_ai_inference.models.DocumentContent = MagicMock( + side_effect=lambda **kw: MagicMock(type="document", **kw) + ) + patch_oci_module.generative_ai_inference.models.DocumentUrl = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + content = [{"type": "document_url", "document_url": {"url": "data:application/pdf;base64,xyz"}}] + result = llm._build_generic_content(content) + + assert len(result) == 1 + patch_oci_module.generative_ai_inference.models.DocumentContent.assert_called_once() + + +def test_oci_builds_video_content(patch_oci_module, oci_unit_values): + """video_url content should produce VideoContent.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + patch_oci_module.generative_ai_inference.models.VideoContent = MagicMock( + side_effect=lambda **kw: MagicMock(type="video", **kw) + ) + patch_oci_module.generative_ai_inference.models.VideoUrl = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + content = [{"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}}] + result = llm._build_generic_content(content) + + assert len(result) == 1 + patch_oci_module.generative_ai_inference.models.VideoContent.assert_called_once() + + +def test_oci_builds_audio_content(patch_oci_module, oci_unit_values): + """audio_url content should produce AudioContent.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + patch_oci_module.generative_ai_inference.models.AudioContent = MagicMock( + side_effect=lambda **kw: MagicMock(type="audio", **kw) + ) + patch_oci_module.generative_ai_inference.models.AudioUrl = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + content = [{"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,wav123"}}] + result = llm._build_generic_content(content) + + assert len(result) == 1 + patch_oci_module.generative_ai_inference.models.AudioContent.assert_called_once() + + +def test_oci_rejects_unsupported_content_type(patch_oci_module, oci_unit_values): + """Unknown content types should raise ValueError.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + with pytest.raises(ValueError, match="Unsupported OCI content type"): + llm._build_generic_content([{"type": "hologram", "data": "xyz"}]) + + +def test_oci_cohere_rejects_multimodal(patch_oci_module, oci_unit_values): + """Cohere models should reject multimodal content in _build_chat_request.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + messages = [ + {"role": "user", "content": [ + {"type": "text", "text": "Describe this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ]} + ] + + with pytest.raises(ValueError, match="text-only"): + llm._build_chat_request(messages) + + +def test_oci_message_has_multimodal_content(patch_oci_module, oci_unit_values): + """_message_has_multimodal_content should detect non-text content types.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + assert llm._message_has_multimodal_content("just text") is False + assert llm._message_has_multimodal_content([{"type": "text", "text": "hi"}]) is False + assert llm._message_has_multimodal_content([{"type": "image_url", "image_url": {"url": "x"}}]) is True + assert llm._message_has_multimodal_content([{"type": "text"}, {"type": "audio_url"}]) is True + + +def test_oci_supports_multimodal(patch_oci_module, oci_unit_values): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm.supports_multimodal() is True + + +def test_vision_helpers(): + """Test vision.py utility functions.""" + from crewai.llms.providers.oci.vision import ( + VISION_MODELS, + encode_image, + is_vision_model, + to_data_uri, + ) + + # to_data_uri with bytes + uri = to_data_uri(b"\x89PNG", "image/png") + assert uri.startswith("data:image/png;base64,") + + # to_data_uri passthrough + existing = "data:image/jpeg;base64,abc" + assert to_data_uri(existing) == existing + + # encode_image + result = encode_image(b"\x89PNG") + assert result["type"] == "image_url" + assert result["image_url"]["url"].startswith("data:image/png;base64,") + + # is_vision_model + assert is_vision_model("google.gemini-2.5-flash") is True + assert is_vision_model("meta.llama-3.3-70b-instruct") is False + + assert len(VISION_MODELS) > 0 From 9cb670cebf543bcf3984fc55a352a85afbb3817c Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 16:11:06 -0400 Subject: [PATCH 10/12] test: add live multimodal integration test Send a 2x2 red PNG to google.gemini-2.5-flash via data URI and verify it identifies the color. Tests the full image_url content pipeline end-to-end against a live OCI vision model. --- .../oci/test_oci_integration_multimodal.py | 106 ++++++++++++++++++ 1 file changed, 106 insertions(+) create mode 100644 lib/crewai/tests/llms/oci/test_oci_integration_multimodal.py diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_multimodal.py b/lib/crewai/tests/llms/oci/test_oci_integration_multimodal.py new file mode 100644 index 0000000000..ae8cbc3d9a --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_multimodal.py @@ -0,0 +1,106 @@ +"""Live integration tests for OCI Generative AI multimodal content. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + OCI_TEST_MULTIMODAL_MODELS="google.gemini-2.5-flash" \ + uv run pytest tests/llms/oci/test_oci_integration_multimodal.py -v +""" + +from __future__ import annotations + +import base64 +import os +import struct +import zlib + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def _env_models(env_var: str, fallback: str, default: str) -> list[str]: + raw = os.getenv(env_var) or os.getenv(fallback) or default + return [m.strip() for m in raw.split(",") if m.strip()] + + +def _skip_unless_live(): + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set") + region = os.getenv("OCI_REGION") + endpoint = os.getenv("OCI_SERVICE_ENDPOINT") + if not region and not endpoint: + pytest.skip("Set OCI_REGION or OCI_SERVICE_ENDPOINT") + config: dict[str, str] = {"compartment_id": compartment} + if endpoint: + config["service_endpoint"] = endpoint + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + return config + + +def _make_red_png() -> bytes: + """Generate a minimal valid 2x2 red PNG image in memory.""" + width, height = 2, 2 + # Each row: filter byte (0) + RGB pixels + raw_data = b"" + for _ in range(height): + raw_data += b"\x00" # filter: none + for _ in range(width): + raw_data += b"\xff\x00\x00" # red pixel + + def _chunk(chunk_type: bytes, data: bytes) -> bytes: + c = chunk_type + data + return struct.pack(">I", len(data)) + c + struct.pack(">I", zlib.crc32(c) & 0xFFFFFFFF) + + ihdr = struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0) + png = b"\x89PNG\r\n\x1a\n" + png += _chunk(b"IHDR", ihdr) + png += _chunk(b"IDAT", zlib.compress(raw_data)) + png += _chunk(b"IEND", b"") + return png + + +def _png_data_uri() -> str: + """Return a data URI for a small red PNG.""" + png_bytes = _make_red_png() + encoded = base64.standard_b64encode(png_bytes).decode("utf-8") + return f"data:image/png;base64,{encoded}" + + +@pytest.fixture( + params=_env_models( + "OCI_TEST_MULTIMODAL_MODELS", "OCI_TEST_MULTIMODAL_MODEL", "google.gemini-2.5-flash" + ), + ids=lambda m: m, +) +def oci_multimodal_model(request): + return request.param + + +@pytest.fixture() +def oci_multimodal_config(): + return _skip_unless_live() + + +def test_oci_live_image_input(oci_multimodal_model: str, oci_multimodal_config: dict): + """Vision model should describe an image sent as a data URI.""" + llm = OCICompletion(model=oci_multimodal_model, **oci_multimodal_config) + result = llm.call( + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is this image? Answer in one word."}, + {"type": "image_url", "image_url": {"url": _png_data_uri()}}, + ], + } + ] + ) + + assert isinstance(result, str) + assert len(result) > 0 + assert "red" in result.lower() From 28ff59fbbd06de3a08257a521f976ab04dcb1c2f Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Thu, 19 Mar 2026 16:42:26 -0400 Subject: [PATCH 11/12] feat: add OCI Generative AI embeddings provider Add OCI embedding support integrated with CrewAI's RAG pipeline: - OCIEmbeddingFunction: ChromaDB-compatible embedding callable with batching, config serialization, image embedding support - OCIProvider: Pydantic-based provider with alias validation for env vars and config keys - Factory registration in embeddings/factory.py + types.py - Supports text and image embeddings, output dimensions, custom endpoints, all 4 OCI auth modes Tested live against cohere.embed-english-v3.0 with API_KEY auth. Depends on: #4964, #4963, #4962, #4961, #4959 Tracking issue: #4944 --- .../src/crewai/rag/embeddings/factory.py | 9 + .../rag/embeddings/providers/oci/__init__.py | 17 ++ .../providers/oci/embedding_callable.py | 181 +++++++++++++++ .../embeddings/providers/oci/oci_provider.py | 75 +++++++ .../rag/embeddings/providers/oci/types.py | 30 +++ lib/crewai/src/crewai/rag/embeddings/types.py | 3 + .../tests/rag/embeddings/test_factory_oci.py | 207 ++++++++++++++++++ .../test_oci_embedding_integration.py | 66 ++++++ 8 files changed, 588 insertions(+) create mode 100644 lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py create mode 100644 lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py create mode 100644 lib/crewai/src/crewai/rag/embeddings/providers/oci/oci_provider.py create mode 100644 lib/crewai/src/crewai/rag/embeddings/providers/oci/types.py create mode 100644 lib/crewai/tests/rag/embeddings/test_factory_oci.py create mode 100644 lib/crewai/tests/rag/embeddings/test_oci_embedding_integration.py diff --git a/lib/crewai/src/crewai/rag/embeddings/factory.py b/lib/crewai/src/crewai/rag/embeddings/factory.py index 8027793200..37c32e9062 100644 --- a/lib/crewai/src/crewai/rag/embeddings/factory.py +++ b/lib/crewai/src/crewai/rag/embeddings/factory.py @@ -70,6 +70,10 @@ from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec + from crewai.rag.embeddings.providers.oci.embedding_callable import ( + OCIEmbeddingFunction, + ) + from crewai.rag.embeddings.providers.oci.types import OCIProviderSpec from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec @@ -99,6 +103,7 @@ "instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider", "jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider", "ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider", + "oci": "crewai.rag.embeddings.providers.oci.oci_provider.OCIProvider", "onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider", "openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider", "openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider", @@ -216,6 +221,10 @@ def build_embedder_from_dict( def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ... +@overload +def build_embedder_from_dict(spec: OCIProviderSpec) -> OCIEmbeddingFunction: ... + + @overload def build_embedder_from_dict(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ... diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py new file mode 100644 index 0000000000..288602b1c8 --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py @@ -0,0 +1,17 @@ +"""OCI embedding provider exports.""" + +from crewai.rag.embeddings.providers.oci.embedding_callable import ( + OCIEmbeddingFunction, +) +from crewai.rag.embeddings.providers.oci.oci_provider import OCIProvider +from crewai.rag.embeddings.providers.oci.types import ( + OCIProviderConfig, + OCIProviderSpec, +) + +__all__ = [ + "OCIEmbeddingFunction", + "OCIProvider", + "OCIProviderConfig", + "OCIProviderSpec", +] diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py new file mode 100644 index 0000000000..de65f1d58c --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py @@ -0,0 +1,181 @@ +"""OCI embedding function implementation.""" + +from __future__ import annotations + +import base64 +from collections.abc import Iterator, Sequence +import mimetypes +import os +from pathlib import Path +from typing import Any, cast + +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from typing_extensions import Unpack + +from crewai.rag.embeddings.providers.oci.types import OCIProviderConfig +from crewai.utilities.oci import create_oci_client_kwargs, get_oci_module + + +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" +DEFAULT_OCI_REGION = "us-chicago-1" + + +def _get_oci_module() -> Any: + """Backward-compatible module-local alias used by tests and patches.""" + return get_oci_module() + + +class OCIEmbeddingFunction(EmbeddingFunction[Documents]): + """Embedding function for OCI Generative AI embedding models.""" + + def __init__(self, **kwargs: Unpack[OCIProviderConfig]) -> None: + self._config = kwargs + self._client: Any = kwargs.get("client") + if self._client is None: + service_endpoint = kwargs.get("service_endpoint") + region = kwargs.get("region") or os.getenv("OCI_REGION", DEFAULT_OCI_REGION) + if service_endpoint is None: + service_endpoint = ( + f"https://inference.generativeai.{region}.oci.oraclecloud.com" + ) + + client_kwargs = create_oci_client_kwargs( + auth_type=kwargs.get("auth_type", "API_KEY"), + service_endpoint=service_endpoint, + auth_file_location=kwargs.get("auth_file_location", "~/.oci/config"), + auth_profile=kwargs.get("auth_profile", "DEFAULT"), + timeout=kwargs.get("timeout", (10, 120)), + oci_module=_get_oci_module(), + ) + self._client = ( + _get_oci_module().generative_ai_inference.GenerativeAiInferenceClient( + **client_kwargs + ) + ) + + def _require_client(self) -> Any: + if self._client is None: + raise ValueError("OCI embedding client is not initialized.") + return self._client + + @staticmethod + def name() -> str: + return "oci" + + @staticmethod + def build_from_config(config: dict[str, Any]) -> OCIEmbeddingFunction: + timeout = config.get("timeout") + if isinstance(timeout, list): + config = dict(config) + config["timeout"] = tuple(timeout) + return OCIEmbeddingFunction(**config) + + def get_config(self) -> dict[str, Any]: + config = dict(self._config) + config.pop("client", None) + timeout = config.get("timeout") + if isinstance(timeout, tuple): + config["timeout"] = list(timeout) + return config + + def _get_serving_mode(self) -> Any: + oci = _get_oci_module() + model_name = self._config.get("model_name") + if not model_name: + raise ValueError("OCI embeddings require model_name") + if model_name.startswith(CUSTOM_ENDPOINT_PREFIX): + return oci.generative_ai_inference.models.DedicatedServingMode( + endpoint_id=model_name + ) + return oci.generative_ai_inference.models.OnDemandServingMode( + model_id=model_name + ) + + def _build_request( + self, inputs: list[str], *, input_type: str | None = None + ) -> Any: + oci = _get_oci_module() + compartment_id = self._config.get("compartment_id") or os.getenv( + "OCI_COMPARTMENT_ID" + ) + if not compartment_id: + raise ValueError( + "OCI embeddings require compartment_id. Set it explicitly or use OCI_COMPARTMENT_ID." + ) + + request_kwargs: dict[str, Any] = { + "serving_mode": self._get_serving_mode(), + "compartment_id": compartment_id, + "truncate": self._config.get("truncate", "END"), + "inputs": inputs, + } + + resolved_input_type = input_type or self._config.get("input_type") + if resolved_input_type: + request_kwargs["input_type"] = resolved_input_type + + output_dimensions = self._config.get("output_dimensions") + if output_dimensions is not None: + embed_text_details = oci.generative_ai_inference.models.EmbedTextDetails + if hasattr(embed_text_details, "output_dimensions"): + request_kwargs["output_dimensions"] = output_dimensions + else: + raise ValueError( + "output_dimensions requires a newer OCI SDK. Upgrade the oci package." + ) + + return oci.generative_ai_inference.models.EmbedTextDetails(**request_kwargs) + + def _batch_inputs(self, input: list[str]) -> Iterator[list[str]]: + batch_size = self._config.get("batch_size", 96) + for index in range(0, len(input), batch_size): + yield input[index : index + batch_size] + + @staticmethod + def _to_data_uri(image: str | bytes | Path, mime_type: str = "image/png") -> str: + if isinstance(image, Path): + resolved_mime = mimetypes.guess_type(image.name)[0] or mime_type + data = image.read_bytes() + return ( + f"data:{resolved_mime};base64," + f"{base64.b64encode(data).decode('ascii')}" + ) + if isinstance(image, bytes): + return f"data:{mime_type};base64,{base64.b64encode(image).decode('ascii')}" + if image.startswith("data:"): + return image + path = Path(image) + if path.exists(): + return OCIEmbeddingFunction._to_data_uri(path, mime_type=mime_type) + raise ValueError( + "OCI image embeddings require a file path, raw bytes, or a data URI." + ) + + def __call__(self, input: Documents) -> Embeddings: + if isinstance(input, str): + input = [input] + embeddings: Embeddings = [] + for chunk in self._batch_inputs(input): + response = self._require_client().embed_text(self._build_request(chunk)) + embeddings.extend(cast(Embeddings, response.data.embeddings)) + return embeddings + + def embed_image( + self, image: str | bytes | Path, *, mime_type: str = "image/png" + ) -> list[float]: + return [ + float(value) + for value in self.embed_image_batch([image], mime_type=mime_type)[0] + ] + + def embed_image_batch( + self, images: Sequence[str | bytes | Path], *, mime_type: str = "image/png" + ) -> Embeddings: + embeddings: Embeddings = [] + for image in images: + data_uri = self._to_data_uri(image, mime_type=mime_type) + response = self._require_client().embed_text( + self._build_request([data_uri], input_type="IMAGE") + ) + embeddings.extend(cast(Embeddings, response.data.embeddings)) + return embeddings diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/oci_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/oci_provider.py new file mode 100644 index 0000000000..d0604fe628 --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/oci_provider.py @@ -0,0 +1,75 @@ +"""OCI embeddings provider.""" + +from typing import Any + +from pydantic import AliasChoices, Field + +from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider +from crewai.rag.embeddings.providers.oci.embedding_callable import OCIEmbeddingFunction + + +class OCIProvider(BaseEmbeddingsProvider[OCIEmbeddingFunction]): + """OCI Generative AI embeddings provider.""" + + embedding_callable: type[OCIEmbeddingFunction] = Field( + default=OCIEmbeddingFunction, + description="OCI embedding function class", + ) + model_name: str = Field( + default="cohere.embed-english-v3.0", + description="Model name to use for embeddings", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_MODEL_NAME", "OCI_EMBED_MODEL", "model", "model_name", + ), + ) + compartment_id: str = Field( + description="OCI compartment ID", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_COMPARTMENT_ID", "OCI_COMPARTMENT_ID", "compartment_id", + ), + ) + service_endpoint: str | None = Field( + default=None, + description="OCI Generative AI inference endpoint", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_SERVICE_ENDPOINT", "OCI_SERVICE_ENDPOINT", "service_endpoint", + ), + ) + region: str | None = Field( + default=None, + description="OCI region for endpoint derivation", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_REGION", "OCI_REGION", "region", + ), + ) + auth_type: str = Field( + default="API_KEY", + description="OCI SDK auth type", + validation_alias=AliasChoices("EMBEDDINGS_OCI_AUTH_TYPE", "OCI_AUTH_TYPE"), + ) + auth_profile: str = Field( + default="DEFAULT", + description="OCI config profile name", + validation_alias=AliasChoices("EMBEDDINGS_OCI_AUTH_PROFILE", "OCI_AUTH_PROFILE"), + ) + auth_file_location: str = Field( + default="~/.oci/config", + description="OCI config file location", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_AUTH_FILE_LOCATION", "OCI_AUTH_FILE_LOCATION", + ), + ) + truncate: str = Field(default="END", description="OCI embedding truncate policy") + input_type: str | None = Field( + default=None, + description="Optional OCI embedding input type such as SEARCH_DOCUMENT or SEARCH_QUERY", + ) + output_dimensions: int | None = Field( + default=None, + description="Optional output dimensions for compatible OCI embedding models", + ) + batch_size: int = Field(default=96, description="OCI embedding batch size") + timeout: tuple[int, int] = Field( + default=(10, 120), description="OCI SDK connect/read timeout" + ) + client: Any | None = Field(default=None, description="Injected OCI client") diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/types.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/types.py new file mode 100644 index 0000000000..757d0be6ce --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/types.py @@ -0,0 +1,30 @@ +"""Type definitions for OCI embedding providers.""" + +from typing import Annotated, Any, Literal + +from typing_extensions import Required, TypedDict + + +class OCIProviderConfig(TypedDict, total=False): + """Configuration for OCI embedding provider.""" + + model_name: Annotated[str, "cohere.embed-english-v3.0"] + compartment_id: str + service_endpoint: str + region: str + auth_type: str + auth_profile: str + auth_file_location: str + truncate: str + input_type: str + output_dimensions: int + batch_size: int + timeout: tuple[int, int] + client: Any + + +class OCIProviderSpec(TypedDict, total=False): + """OCI provider specification.""" + + provider: Required[Literal["oci"]] + config: OCIProviderConfig diff --git a/lib/crewai/src/crewai/rag/embeddings/types.py b/lib/crewai/src/crewai/rag/embeddings/types.py index 794f4c6f9a..e01ecae3c9 100644 --- a/lib/crewai/src/crewai/rag/embeddings/types.py +++ b/lib/crewai/src/crewai/rag/embeddings/types.py @@ -17,6 +17,7 @@ from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec +from crewai.rag.embeddings.providers.oci.types import OCIProviderSpec from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec @@ -39,6 +40,7 @@ | InstructorProviderSpec | JinaProviderSpec | OllamaProviderSpec + | OCIProviderSpec | ONNXProviderSpec | OpenAIProviderSpec | OpenCLIPProviderSpec @@ -61,6 +63,7 @@ "instructor", "jina", "ollama", + "oci", "onnx", "openai", "openclip", diff --git a/lib/crewai/tests/rag/embeddings/test_factory_oci.py b/lib/crewai/tests/rag/embeddings/test_factory_oci.py new file mode 100644 index 0000000000..9cc125b02b --- /dev/null +++ b/lib/crewai/tests/rag/embeddings/test_factory_oci.py @@ -0,0 +1,207 @@ +"""Tests for OCI embedding provider wiring.""" + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.rag.embeddings.factory import build_embedder +from crewai.rag.embeddings.providers.oci.embedding_callable import OCIEmbeddingFunction + + +class _FakeOCI: + def __init__(self) -> None: + self.retry = SimpleNamespace(DEFAULT_RETRY_STRATEGY="retry") + self.config = SimpleNamespace( + from_file=lambda file_location, profile_name: { + "file_location": file_location, + "profile_name": profile_name, + } + ) + self.signer = SimpleNamespace( + load_private_key_from_file=lambda *_args, **_kwargs: "private-key" + ) + self.auth = SimpleNamespace( + signers=SimpleNamespace( + SecurityTokenSigner=lambda token, key: (token, key), + InstancePrincipalsSecurityTokenSigner=lambda: "instance-principal", + get_resource_principals_signer=lambda: "resource-principal", + ) + ) + self.generative_ai_inference = SimpleNamespace( + GenerativeAiInferenceClient=MagicMock(), + models=SimpleNamespace( + EmbedTextDetails=_simple_init_class("EmbedTextDetails"), + OnDemandServingMode=_simple_init_class("OnDemandServingMode"), + DedicatedServingMode=_simple_init_class("DedicatedServingMode"), + ), + ) + + +def _simple_init_class(name: str): + class _Simple: + output_dimensions = None + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + _Simple.__name__ = name + return _Simple + + +@patch("crewai.rag.embeddings.factory.import_and_validate_definition") +def test_build_embedder_oci(mock_import): + """Test building OCI embedder.""" + mock_provider_class = MagicMock() + mock_provider_instance = MagicMock() + mock_embedding_function = MagicMock() + + mock_import.return_value = mock_provider_class + mock_provider_class.return_value = mock_provider_instance + mock_provider_instance.embedding_callable.return_value = mock_embedding_function + + config = { + "provider": "oci", + "config": { + "model_name": "cohere.embed-english-v3.0", + "compartment_id": "ocid1.compartment.oc1..test", + "region": "us-chicago-1", + "auth_profile": "DEFAULT", + }, + } + + build_embedder(config) + + mock_import.assert_called_once_with( + "crewai.rag.embeddings.providers.oci.oci_provider.OCIProvider" + ) + call_kwargs = mock_provider_class.call_args.kwargs + assert call_kwargs["model_name"] == "cohere.embed-english-v3.0" + assert call_kwargs["compartment_id"] == "ocid1.compartment.oc1..test" + assert call_kwargs["region"] == "us-chicago-1" + + +def test_oci_embedding_function_batches_requests(monkeypatch): + """Test OCI embedding batching and request construction.""" + fake_oci = _FakeOCI() + fake_client = MagicMock() + fake_client.embed_text.side_effect = [ + SimpleNamespace(data=SimpleNamespace(embeddings=[[0.1, 0.2], [0.3, 0.4]])), + SimpleNamespace(data=SimpleNamespace(embeddings=[[0.5, 0.6]])), + ] + fake_oci.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + monkeypatch.setattr( + "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module", + lambda: fake_oci, + ) + + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-english-v3.0", + compartment_id="ocid1.compartment.oc1..test", + region="us-chicago-1", + batch_size=2, + ) + + result = embedder(["a", "b", "c"]) + + result_rows = [embedding.tolist() for embedding in result] + expected_rows = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + assert len(result_rows) == len(expected_rows) + for actual, expected in zip(result_rows, expected_rows, strict=True): + assert actual == pytest.approx(expected) + assert fake_client.embed_text.call_count == 2 + first_request = fake_client.embed_text.call_args_list[0].args[0] + assert first_request.compartment_id == "ocid1.compartment.oc1..test" + assert first_request.serving_mode.model_id == "cohere.embed-english-v3.0" + + +def test_oci_embedding_function_supports_output_dimensions(monkeypatch): + """Test OCI output_dimensions mapping.""" + fake_oci = _FakeOCI() + fake_client = MagicMock() + fake_client.embed_text.return_value = SimpleNamespace( + data=SimpleNamespace(embeddings=[[0.1, 0.2]]) + ) + fake_oci.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + monkeypatch.setattr( + "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module", + lambda: fake_oci, + ) + + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-v4.0", + compartment_id="ocid1.compartment.oc1..test", + output_dimensions=512, + ) + + embedder(["hello"]) + + request = fake_client.embed_text.call_args.args[0] + assert request.output_dimensions == 512 + + +def test_oci_embedding_function_exposes_serializable_config(monkeypatch): + """Test OCI embedding config serialization for ChromaDB compatibility.""" + fake_oci = _FakeOCI() + fake_oci.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + MagicMock() + ) + + monkeypatch.setattr( + "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module", + lambda: fake_oci, + ) + + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-english-v3.0", + compartment_id="ocid1.compartment.oc1..test", + timeout=(5, 30), + ) + + assert embedder.get_config() == { + "model_name": "cohere.embed-english-v3.0", + "compartment_id": "ocid1.compartment.oc1..test", + "timeout": [5, 30], + } + + rebuilt = OCIEmbeddingFunction.build_from_config(embedder.get_config()) + assert rebuilt.get_config() == embedder.get_config() + + +def test_oci_embedding_function_supports_image_embeddings(monkeypatch, tmp_path: Path): + """Test OCI image embedding request construction.""" + fake_oci = _FakeOCI() + fake_client = MagicMock() + fake_client.embed_text.return_value = SimpleNamespace( + data=SimpleNamespace(embeddings=[[0.7, 0.8, 0.9]]) + ) + fake_oci.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + monkeypatch.setattr( + "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module", + lambda: fake_oci, + ) + + image_path = tmp_path / "diagram.png" + image_path.write_bytes(b"\x89PNG\r\n\x1a\n") + + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-v4.0", + compartment_id="ocid1.compartment.oc1..test", + ) + + result = embedder.embed_image(image_path) + + assert result == pytest.approx([0.7, 0.8, 0.9]) + request = fake_client.embed_text.call_args.args[0] + assert request.input_type == "IMAGE" + assert request.inputs[0].startswith("data:image/png;base64,") diff --git a/lib/crewai/tests/rag/embeddings/test_oci_embedding_integration.py b/lib/crewai/tests/rag/embeddings/test_oci_embedding_integration.py new file mode 100644 index 0000000000..6d5c12a80c --- /dev/null +++ b/lib/crewai/tests/rag/embeddings/test_oci_embedding_integration.py @@ -0,0 +1,66 @@ +"""Live integration tests for OCI embedding provider. + +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + uv run pytest tests/rag/embeddings/test_oci_embedding_integration.py -v +""" + +from __future__ import annotations + +import os + +import pytest + +from crewai.rag.embeddings.providers.oci.embedding_callable import OCIEmbeddingFunction + + +def _skip_unless_live() -> dict[str, str]: + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set") + region = os.getenv("OCI_REGION") + if not region: + pytest.skip("OCI_REGION not set") + config: dict[str, str] = { + "compartment_id": compartment, + "region": region, + } + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + return config + + +@pytest.fixture() +def oci_embed_config(): + return _skip_unless_live() + + +def test_oci_live_text_embedding(oci_embed_config: dict): + """Embed a text string and verify we get a non-empty vector.""" + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-english-v3.0", + input_type="SEARCH_DOCUMENT", + **oci_embed_config, + ) + result = embedder(["Hello world"]) + + assert len(result) == 1 + assert len(result[0]) > 0 + + +def test_oci_live_batch_embedding(oci_embed_config: dict): + """Batch embed multiple texts.""" + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-english-v3.0", + input_type="SEARCH_DOCUMENT", + **oci_embed_config, + ) + texts = ["The cat sat on the mat", "The dog ran in the park", "Python is great"] + result = embedder(texts) + + assert len(result) == 3 + for vec in result: + assert len(vec) > 0 From 6d3a27bea709db0fbc5c990570696fdc3f82fbdf Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 20 Mar 2026 06:28:49 -0400 Subject: [PATCH 12/12] feat: add true async support to OCI provider via aiohttp Replace asyncio.to_thread wrappers with true async I/O using aiohttp for acall() and astream(). The OCI SDK is sync-only, so we bypass it for HTTP and use its signer for request authentication directly. - oci_async.py: OCIAsyncClient with aiohttp, OCI request signing, native SSE parsing, connection pooling - acall(): true async chat completion (no thread pool) - astream(): true async SSE streaming (no thread+queue bridge) - Graceful fallback to asyncio.to_thread when aiohttp unavailable or client is mocked (unit tests) - aiohttp + certifi added to crewai[oci] optional deps Temporary measure until OCI SDK ships native async support. Tested live: acall, astream, and concurrent acall against meta.llama-3.3-70b-instruct with API_KEY auth. Depends on: #4966, #4964, #4963, #4962, #4961, #4959 Tracking issue: #4944 --- lib/crewai/pyproject.toml | 2 + .../crewai/llms/providers/oci/completion.py | 226 +++++++++++++++--- lib/crewai/src/crewai/utilities/oci_async.py | 178 ++++++++++++++ lib/crewai/tests/llms/oci/test_oci_async.py | 93 +++++++ 4 files changed, 460 insertions(+), 39 deletions(-) create mode 100644 lib/crewai/src/crewai/utilities/oci_async.py create mode 100644 lib/crewai/tests/llms/oci/test_oci_async.py diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index e3fce3fb46..cd997ba2c7 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -99,6 +99,8 @@ anthropic = [ ] oci = [ "oci>=2.168.0", + "aiohttp>=3.9.0", + "certifi", ] a2a = [ "a2a-sdk~=0.3.10", diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py index 859aa542bd..fe4d60112c 100644 --- a/lib/crewai/src/crewai/llms/providers/oci/completion.py +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -1208,39 +1208,64 @@ async def astream( from_task: Task | None = None, from_agent: Agent | None = None, ) -> Any: - """Expose the sync OCI SSE stream through an async generator facade.""" - loop = asyncio.get_running_loop() - queue: asyncio.Queue[str | None] = asyncio.Queue() - error_holder: list[BaseException] = [] + """Async streaming — true async via aiohttp when available, thread fallback otherwise.""" + if self._async_client is None: + # Fallback: sync stream bridged via thread + loop = asyncio.get_running_loop() + queue: asyncio.Queue[str | None] = asyncio.Queue() + error_holder: list[BaseException] = [] + + def _producer() -> None: + try: + for chunk in self.iter_stream( + messages=messages, tools=tools, callbacks=callbacks, + available_functions=available_functions, from_task=from_task, + from_agent=from_agent, + ): + loop.call_soon_threadsafe(queue.put_nowait, chunk) + except BaseException as error: + error_holder.append(error) + finally: + loop.call_soon_threadsafe(queue.put_nowait, None) + + thread = threading.Thread(target=_producer, daemon=True) + thread.start() + while True: + chunk = await queue.get() + if chunk is None: + break + yield chunk + thread.join() + if error_holder: + raise error_holder[0] + return - def _producer() -> None: - try: - for chunk in self.iter_stream( - messages=messages, - tools=tools, - callbacks=callbacks, - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, - ): - loop.call_soon_threadsafe(queue.put_nowait, chunk) - except BaseException as error: - error_holder.append(error) - finally: - loop.call_soon_threadsafe(queue.put_nowait, None) - - thread = threading.Thread(target=_producer, daemon=True) - thread.start() - - while True: - chunk = await queue.get() - if chunk is None: - break - yield chunk + normalized_messages = self._normalize_messages(messages) + request_data = self._prepare_async_request( + normalized_messages, tools=tools, is_stream=True + ) + usage_data: dict[str, int] = {} + response_metadata: dict[str, Any] = {} + + async for event_data in self._async_client.chat_async( + compartment_id=request_data["compartment_id"], + chat_request_dict=request_data["chat_request_dict"], + serving_mode_dict=request_data["serving_mode_dict"], + stream=True, + ): + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: + yield text_chunk + + usage_chunk = self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) - thread.join() - if error_holder: - raise error_holder[0] + if usage_data: + self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None def call( self, @@ -1311,19 +1336,142 @@ async def acall( from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: - return await asyncio.to_thread( - self.call, - messages, - tools=tools, - callbacks=callbacks, - available_functions=available_functions, + """Async call — true async via aiohttp when available, thread fallback otherwise.""" + if self._async_client is None: + return await asyncio.to_thread( + self.call, messages, tools=tools, callbacks=callbacks, + available_functions=available_functions, from_task=from_task, + from_agent=from_agent, response_model=response_model, + ) + + normalized_messages = self._normalize_messages(messages) + request_data = self._prepare_async_request( + normalized_messages, tools=tools, response_model=response_model + ) + + response_data = None + async for data in self._async_client.chat_async( + compartment_id=request_data["compartment_id"], + chat_request_dict=request_data["chat_request_dict"], + serving_mode_dict=request_data["serving_mode_dict"], + stream=False, + ): + response_data = data + break + + if response_data is None: + raise RuntimeError("No response received from OCI GenAI async call") + + # Extract text from the raw JSON response + chat_response = response_data.get("chatResponse", {}) + content = self._extract_text_from_async_response(chat_response) + content = self._apply_stop_words(content) + + # Track usage + usage = chat_response.get("usage", {}) + if usage: + usage_dict = { + "prompt_tokens": usage.get("promptTokens", 0), + "completion_tokens": usage.get("completionTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + self._track_token_usage_internal(usage_dict) + + self._emit_call_completed_event( + response=content, + call_type=LLMCallType.LLM_CALL, from_task=from_task, from_agent=from_agent, - response_model=response_model, + messages=normalized_messages, + ) + return content + + def _extract_text_from_async_response(self, chat_response: dict[str, Any]) -> str: + """Extract text from a raw JSON chat response dict.""" + # Generic format: choices[0].message.content[].text + choices = chat_response.get("choices", []) + if choices: + message = choices[0].get("message", {}) + content = message.get("content", []) + if isinstance(content, list): + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, dict) and part.get("text") + ) + return str(content) if content else "" + + # Cohere format: text field or message.content + text = chat_response.get("text") + if text: + return str(text) + + message = chat_response.get("message", {}) + content = message.get("content", []) + if isinstance(content, list): + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, dict) + ) + return "" + + # ------------------------------------------------------------------ + # Async support (true async via aiohttp, no thread wrappers) + # ------------------------------------------------------------------ + + @property + def _async_client(self) -> Any | None: + """Lazy-init true async client reusing the sync client's signer. + + Returns None if aiohttp is not installed or the client can't be created, + in which case callers fall back to asyncio.to_thread. + """ + if not hasattr(self, "_async_client_instance"): + self._async_client_instance = None + try: + from crewai.utilities.oci_async import OCIAsyncClient + + oci = self._oci + expected_cls = oci.generative_ai_inference.GenerativeAiInferenceClient + if isinstance(self.client, expected_cls): + base_client = self.client.base_client + self._async_client_instance = OCIAsyncClient( + service_endpoint=self.service_endpoint, + signer=base_client.signer, + config=getattr(base_client, "config", {}), + ) + except (ImportError, Exception): + pass + return self._async_client_instance + + def _prepare_async_request( + self, + messages: list[LLMMessage], + tools: list[dict[str, Any]] | None = None, + response_model: type[BaseModel] | None = None, + *, + is_stream: bool = False, + ) -> dict[str, Any]: + """Build request dicts for the async client (SDK objects → JSON dicts).""" + from oci.util import to_dict + + chat_request = self._build_chat_request( + messages, tools=tools, response_model=response_model, is_stream=is_stream + ) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, ) + return { + "compartment_id": chat_details.compartment_id, + "chat_request_dict": to_dict(chat_details.chat_request), + "serving_mode_dict": to_dict(chat_details.serving_mode), + } # ------------------------------------------------------------------ - # Client serialization + # Client serialization (sync) # ------------------------------------------------------------------ def _chat(self, chat_details: Any) -> Any: diff --git a/lib/crewai/src/crewai/utilities/oci_async.py b/lib/crewai/src/crewai/utilities/oci_async.py new file mode 100644 index 0000000000..3a3d7d9ec9 --- /dev/null +++ b/lib/crewai/src/crewai/utilities/oci_async.py @@ -0,0 +1,178 @@ +"""True async HTTP client for OCI Generative AI. + +Bypasses the synchronous OCI SDK for HTTP, using aiohttp directly +with OCI request signing. This gives real async I/O instead of +thread-pool wrappers, until the OCI SDK ships native async support. +""" + +from __future__ import annotations + +import json +import ssl +from contextlib import asynccontextmanager +from typing import Any, AsyncGenerator, AsyncIterator + +import aiohttp +import certifi +import requests + + +def _get_oci_genai_api_version() -> str: + """Detect OCI GenAI API version from SDK, fallback to known version.""" + try: + from oci.generative_ai_inference import GenerativeAiInferenceClient + + if hasattr(GenerativeAiInferenceClient, "API_VERSION"): + return GenerativeAiInferenceClient.API_VERSION + except ImportError: + pass + return "20231130" + + +OCI_GENAI_API_VERSION = _get_oci_genai_api_version() + + +def _snake_to_camel(name: str) -> str: + components = name.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def _convert_keys_to_camel(obj: Any) -> Any: + """Recursively convert dict keys from snake_case to camelCase.""" + if isinstance(obj, dict): + return {_snake_to_camel(k): _convert_keys_to_camel(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [_convert_keys_to_camel(item) for item in obj] + return obj + + +class OCIAsyncClient: + """Async HTTP client for OCI Generative AI services. + + Uses aiohttp with OCI request signing for true async I/O. + Reuses aiohttp.ClientSession for connection pooling. + """ + + def __init__( + self, + service_endpoint: str, + signer: Any, + config: dict[str, Any] | None = None, + ) -> None: + self.service_endpoint = service_endpoint.rstrip("/") + self.signer = signer + self.config = config or {} + self._session: aiohttp.ClientSession | None = None + self._ensure_signer() + + def _ensure_signer(self) -> None: + if self.signer is not None: + return + if self.config: + try: + from oci.signer import Signer + + self.signer = Signer.from_config(self.config) + except Exception as e: + raise ValueError(f"Failed to create OCI signer from config: {e}") from e + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + ssl_context = ssl.create_default_context(cafile=certifi.where()) + connector = aiohttp.TCPConnector(ssl=ssl_context) + self._session = aiohttp.ClientSession(connector=connector) + return self._session + + async def close(self) -> None: + if self._session is not None and not self._session.closed: + await self._session.close() + self._session = None + + async def __aenter__(self) -> OCIAsyncClient: + return self + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + await self.close() + + def _sign_headers( + self, + method: str, + url: str, + body: dict[str, Any] | None = None, + stream: bool = False, + ) -> dict[str, str]: + req = requests.Request(method, url, json=body) + prepared = req.prepare() + signed = self.signer(prepared) + headers = dict(signed.headers) + if stream: + headers["Accept"] = "text/event-stream" + return headers + + @asynccontextmanager + async def _arequest( + self, + method: str, + url: str, + headers: dict[str, str], + json_body: dict[str, Any] | None = None, + timeout: int = 300, + ) -> AsyncGenerator[aiohttp.ClientResponse, None]: + session = await self._get_session() + client_timeout = aiohttp.ClientTimeout(total=timeout) + async with session.request( + method, url, headers=headers, json=json_body, timeout=client_timeout + ) as response: + yield response + + async def _parse_sse_async( + self, content: aiohttp.StreamReader + ) -> AsyncIterator[dict[str, Any]]: + async for line in content: + line = line.strip() + if not line: + continue + decoded = line.decode("utf-8") + if decoded.lower().startswith("data:"): + data = decoded[5:].strip() + if data and not data.startswith("[DONE]"): + try: + yield json.loads(data) + except json.JSONDecodeError: + continue + + async def chat_async( + self, + compartment_id: str, + chat_request_dict: dict[str, Any], + serving_mode_dict: dict[str, Any], + stream: bool = False, + timeout: int = 300, + ) -> AsyncIterator[dict[str, Any]]: + """Make async chat request to OCI GenAI. + + Yields SSE events for streaming, or a single response dict. + """ + url = f"{self.service_endpoint}/{OCI_GENAI_API_VERSION}/actions/chat" + + body = { + "compartmentId": compartment_id, + "servingMode": _convert_keys_to_camel(serving_mode_dict), + "chatRequest": _convert_keys_to_camel(chat_request_dict), + } + + headers = self._sign_headers("POST", url, body, stream=stream) + + async with self._arequest("POST", url, headers, body, timeout) as response: + if response.status != 200: + error_text = await response.text() + raise RuntimeError( + f"OCI GenAI async request failed ({response.status}): {error_text}" + ) + + if stream: + async for event in self._parse_sse_async(response.content): + yield event + else: + data = await response.json() + yield data diff --git a/lib/crewai/tests/llms/oci/test_oci_async.py b/lib/crewai/tests/llms/oci/test_oci_async.py new file mode 100644 index 0000000000..64edec2354 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_async.py @@ -0,0 +1,93 @@ +"""Tests for OCI true async support (aiohttp-based).""" + +from __future__ import annotations + +import os + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def _skip_unless_live() -> dict[str, str]: + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set") + region = os.getenv("OCI_REGION") + if not region: + pytest.skip("OCI_REGION not set") + config: dict[str, str] = {"compartment_id": compartment} + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + return config + + +@pytest.fixture() +def oci_async_config(): + return _skip_unless_live() + + +@pytest.mark.asyncio +async def test_oci_true_async_client_is_used(oci_async_config: dict): + """Verify the true async client is initialized with a real OCI SDK client.""" + from crewai.utilities.oci_async import OCIAsyncClient + + llm = OCICompletion( + model="meta.llama-3.3-70b-instruct", + **oci_async_config, + ) + assert llm._async_client is not None + assert isinstance(llm._async_client, OCIAsyncClient) + + +@pytest.mark.asyncio +async def test_oci_true_async_acall(oci_async_config: dict): + """True async acall should return a text response without blocking threads.""" + llm = OCICompletion( + model="meta.llama-3.3-70b-instruct", + **oci_async_config, + ) + result = await llm.acall( + messages=[{"role": "user", "content": "Say hello in one word."}] + ) + assert isinstance(result, str) + assert len(result) > 0 + + +@pytest.mark.asyncio +async def test_oci_true_async_astream(oci_async_config: dict): + """True async astream should yield chunks without thread bridges.""" + llm = OCICompletion( + model="meta.llama-3.3-70b-instruct", + **oci_async_config, + ) + chunks: list[str] = [] + async for chunk in llm.astream( + messages=[{"role": "user", "content": "Count to 3."}] + ): + chunks.append(chunk) + + assert len(chunks) > 0 + full = "".join(chunks) + assert len(full) > 0 + + +@pytest.mark.asyncio +async def test_oci_true_async_concurrent_calls(oci_async_config: dict): + """Multiple concurrent acall should run without blocking each other.""" + import asyncio + + llm = OCICompletion( + model="meta.llama-3.3-70b-instruct", + **oci_async_config, + ) + + results = await asyncio.gather( + llm.acall(messages=[{"role": "user", "content": "Say 'one'"}]), + llm.acall(messages=[{"role": "user", "content": "Say 'two'"}]), + ) + + assert len(results) == 2 + assert all(isinstance(r, str) and len(r) > 0 for r in results)