diff --git a/lib/crewai/pyproject.toml b/lib/crewai/pyproject.toml index a40484f048..cd997ba2c7 100644 --- a/lib/crewai/pyproject.toml +++ b/lib/crewai/pyproject.toml @@ -97,6 +97,11 @@ azure-ai-inference = [ anthropic = [ "anthropic~=0.73.0", ] +oci = [ + "oci>=2.168.0", + "aiohttp>=3.9.0", + "certifi", +] a2a = [ "a2a-sdk~=0.3.10", "httpx-auth~=0.23.1", diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 75b1f65468..986010a0e8 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -317,6 +317,7 @@ def writable(self) -> bool: "hosted_vllm", "cerebras", "dashscope", + "oci", ] @@ -384,6 +385,7 @@ def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: "hosted_vllm": "hosted_vllm", "cerebras": "cerebras", "dashscope": "dashscope", + "oci": "oci", } canonical_provider = provider_mapping.get(prefix.lower()) @@ -506,6 +508,9 @@ def _matches_provider_pattern(cls, model: str, provider: str) -> bool: # OpenRouter uses org/model format but accepts anything return True + if provider == "oci": + return model_lower.startswith("ocid1.generativeaiendpoint") or "." 
in model_lower + return False @classmethod @@ -541,6 +546,9 @@ def _validate_model_in_constants(cls, model: str, provider: str) -> bool: # azure does not provide a list of available models, determine a better way to handle this return True + if provider == "oci": + return cls._matches_provider_pattern(model, provider) + # Fallback to pattern matching for models not in constants return cls._matches_provider_pattern(model, provider) @@ -622,6 +630,11 @@ def _get_native_provider(cls, provider: str) -> type | None: return OpenAICompatibleCompletion + if provider == "oci": + from crewai.llms.providers.oci.completion import OCICompletion + + return OCICompletion + return None def __init__( diff --git a/lib/crewai/src/crewai/llms/providers/oci/__init__.py b/lib/crewai/src/crewai/llms/providers/oci/__init__.py new file mode 100644 index 0000000000..ea459d687c --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/__init__.py @@ -0,0 +1,19 @@ +from crewai.llms.providers.oci.completion import OCICompletion +from crewai.llms.providers.oci.vision import ( + IMAGE_EMBEDDING_MODELS, + VISION_MODELS, + encode_image, + is_vision_model, + load_image, + to_data_uri, +) + +__all__ = [ + "IMAGE_EMBEDDING_MODELS", + "VISION_MODELS", + "OCICompletion", + "encode_image", + "is_vision_model", + "load_image", + "to_data_uri", +] diff --git a/lib/crewai/src/crewai/llms/providers/oci/completion.py b/lib/crewai/src/crewai/llms/providers/oci/completion.py new file mode 100644 index 0000000000..fe4d60112c --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/completion.py @@ -0,0 +1,1523 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Mapping +from contextlib import contextmanager +import inspect +import json +import logging +import os +import re +import threading +from typing import TYPE_CHECKING, Any, Literal, cast +import uuid + +from pydantic import BaseModel + +from crewai.events.types.llm_events import LLMCallType +from 
def _get_oci_module() -> Any:
    """Return the OCI SDK module obtained from the shared ``get_oci_module`` utility.

    The indirection is kept as a module-local name so tests and patches can
    target it without touching the shared utility.
    """
    oci_module = get_oci_module()
    return oci_module
+ """ + + def __init__( + self, + model: str, + *, + compartment_id: str | None = None, + service_endpoint: str | None = None, + auth_type: Literal[ + "API_KEY", + "SECURITY_TOKEN", + "INSTANCE_PRINCIPAL", + "RESOURCE_PRINCIPAL", + ] + | str = "API_KEY", + auth_profile: str | None = None, + auth_file_location: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + top_p: float | None = None, + top_k: int | None = None, + stream: bool = False, + oci_provider: str | None = None, + max_sequential_tool_calls: int = 8, + client: Any | None = None, + **kwargs: Any, + ) -> None: + kwargs.pop("provider", None) + super().__init__( + model=model, + temperature=temperature, + provider="oci", + **kwargs, + ) + + self.compartment_id = compartment_id or os.getenv("OCI_COMPARTMENT_ID") + if not self.compartment_id: + raise ValueError( + "OCI compartment_id is required. Set compartment_id or OCI_COMPARTMENT_ID." + ) + + self.service_endpoint = service_endpoint or os.getenv("OCI_SERVICE_ENDPOINT") + if self.service_endpoint is None: + region = os.getenv("OCI_REGION", DEFAULT_OCI_REGION) + self.service_endpoint = ( + f"https://inference.generativeai.{region}.oci.oraclecloud.com" + ) + + self.auth_type = str(auth_type).upper() + self.auth_profile = cast( + str, auth_profile or os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + ) + self.auth_file_location = cast( + str, + auth_file_location + or os.getenv("OCI_AUTH_FILE_LOCATION", "~/.oci/config"), + ) + self.max_tokens = max_tokens + self.top_p = top_p + self.top_k = top_k + self.stream = stream + self.oci_provider = oci_provider or self._infer_provider(model) + self.max_sequential_tool_calls = max_sequential_tool_calls + self._oci = _get_oci_module() + + if client is not None: + self.client = client + else: + client_kwargs = create_oci_client_kwargs( + auth_type=self.auth_type, + service_endpoint=self.service_endpoint, + auth_file_location=self.auth_file_location, + auth_profile=self.auth_profile, + 
def _coerce_text(self, content: Any) -> str:
    """Flatten CrewAI message content into a single plain-text string.

    Args:
        content: ``None``, a string, a list of strings and/or mapping parts,
            or any other object.

    Returns:
        ``""`` for ``None``; the string itself for ``str``; for a list, the
        non-empty text fragments joined with newlines (mapping parts
        contribute their ``"text"`` value, parts without text are ignored);
        ``str(content)`` for anything else.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)

    parts: list[str] = []
    for item in content:
        if isinstance(item, str):
            parts.append(item)
        elif isinstance(item, Mapping) and (
            item.get("type") == "text" or "text" in item
        ):
            text = item.get("text")
            # A part carrying ``"text": None`` must be skipped; str(None)
            # would otherwise inject the literal word "None" into the prompt.
            if text is not None:
                parts.append(str(text))
    return "\n".join(part for part in parts if part)
def _build_generic_content(self, content: Any) -> list[Any]:
    """Translate CrewAI message content into OCI generic content objects.

    Handles text, image_url, document_url, video_url, and audio_url types.

    Args:
        content: A string, a list of strings / mapping parts, or any other
            value (which is flattened with ``_coerce_text``).

    Returns:
        A non-empty list of OCI content objects. Empty text is replaced by
        the ``"."`` placeholder on every text path.

    Raises:
        ValueError: If a list item is neither a string nor a mapping, if a
            media part has no URL, or if the part type is not supported.
    """
    models = self._oci.generative_ai_inference.models

    if isinstance(content, str):
        return [models.TextContent(text=content or ".")]
    if not isinstance(content, list):
        return [models.TextContent(text=self._coerce_text(content) or ".")]

    # (label, accepted type names == payload keys, content class, url class)
    media_specs = (
        ("document", ("document_url", "document", "file"), "DocumentContent", "DocumentUrl"),
        ("video", ("video_url", "video"), "VideoContent", "VideoUrl"),
        ("audio", ("audio_url", "audio"), "AudioContent", "AudioUrl"),
    )

    processed: list[Any] = []
    for item in content:
        if isinstance(item, str):
            # Same "." placeholder as the other text paths, so an empty
            # string item never produces an empty TextContent.
            processed.append(models.TextContent(text=item or "."))
            continue
        if not isinstance(item, Mapping):
            raise ValueError(
                f"OCI message content items must be strings or dictionaries, got: {type(item)}"
            )

        content_type = item.get("type")
        if content_type == "text":
            raw_text = item.get("text")
            # ``"text": None`` becomes the placeholder, not the word "None".
            text = "" if raw_text is None else str(raw_text)
            processed.append(models.TextContent(text=text or "."))
            continue

        if content_type == "image_url":
            image_url = item.get("image_url", {})
            url = image_url.get("url") if isinstance(image_url, Mapping) else None
            if not url:
                raise ValueError("OCI image_url content requires image_url.url")
            processed.append(models.ImageContent(image_url=models.ImageUrl(url=url)))
            continue

        spec = next((s for s in media_specs if content_type in s[1]), None)
        if spec is None:
            raise ValueError(f"Unsupported OCI content type: {content_type}")

        label, payload_keys, content_cls_name, url_cls_name = spec
        # Mirrors ``a or b or c``: first truthy payload, else the last one.
        payload = None
        for key in payload_keys:
            payload = item.get(key)
            if payload:
                break
        url = payload.get("url") if isinstance(payload, Mapping) else item.get("url")
        if not url:
            raise ValueError(f"OCI {label} content requires a url")
        content_cls = getattr(models, content_cls_name)
        url_cls = getattr(models, url_cls_name)
        processed.append(content_cls(**{f"{label}_url": url_cls(url=url)}))

    return processed or [models.TextContent(text=".")]
def _build_cohere_chat_history(
    self, messages: list[LLMMessage]
) -> tuple[list[Any], list[Any] | None, str]:
    """Translate CrewAI messages into Cohere's split history/tool-results shape.

    The trailing run of ``tool`` messages (if any) becomes ``tool_results``;
    everything before it becomes ``chat_history``. Without trailing tool
    messages, the last message supplies the prompt text and the rest is
    history.

    Args:
        messages: Conversation in CrewAI's OpenAI-style message format.

    Returns:
        ``(chat_history, tool_results, message_text)``. ``tool_results`` is
        ``None`` when there are none; ``message_text`` is ``""`` whenever
        tool results are being returned.

    Raises:
        ValueError: If a history message carries multimodal content.
    """
    models = self._oci.generative_ai_inference.models

    def parse_parameters(raw_arguments: Any) -> dict[str, Any]:
        """Decode tool-call arguments; anything that is not an object becomes {}."""
        if isinstance(raw_arguments, str):
            try:
                raw_arguments = json.loads(raw_arguments)
            except json.JSONDecodeError:
                return {}
        return dict(raw_arguments) if isinstance(raw_arguments, Mapping) else {}

    trailing_tool_count = 0
    for message in reversed(messages):
        if str(message.get("role", "")).lower() != "tool":
            break
        trailing_tool_count += 1

    history_messages = (
        messages[:-trailing_tool_count] if trailing_tool_count else messages[:-1]
    )

    chat_history: list[Any] = []
    for message in history_messages:
        role = str(message.get("role", "user")).lower()
        content = message.get("content", "")
        if self._message_has_multimodal_content(content):
            raise ValueError(
                "OCI Cohere models currently support text-only messages in CrewAI."
            )

        if role == "user":
            chat_history.append(
                models.CohereUserMessage(message=self._coerce_text(content))
            )
        elif role == "system":
            chat_history.append(
                models.CohereSystemMessage(message=self._coerce_text(content))
            )
        elif role == "assistant":
            tool_calls = None
            if message.get("tool_calls"):
                tool_calls = []
                for tool_call in message["tool_calls"]:
                    function_info = tool_call.get("function", {})
                    function_name = function_info.get("name")
                    if not function_name:
                        continue
                    tool_calls.append(
                        models.CohereToolCall(
                            name=function_name,
                            parameters=parse_parameters(
                                function_info.get("arguments", "{}")
                            ),
                        )
                    )
            chat_history.append(
                models.CohereChatBotMessage(
                    # A single space stands in for empty assistant text.
                    message=self._coerce_text(content) or " ",
                    tool_calls=tool_calls,
                )
            )
        elif role == "tool":
            tool_name = message.get("name") or "tool"
            chat_history.append(
                models.CohereToolMessage(
                    tool_results=[
                        models.CohereToolResult(
                            call=models.CohereToolCall(name=tool_name, parameters={}),
                            outputs=[{"output": self._coerce_text(content)}],
                        )
                    ]
                )
            )

    last_message = messages[-1] if messages else {"role": "user", "content": ""}
    tool_results: list[Any] = []
    if trailing_tool_count:
        # Index every earlier assistant tool call by id so each trailing tool
        # message can be paired with the call (name + parameters) it answers.
        previous_tool_calls: dict[str, dict[str, Any]] = {}
        for message in messages:
            if str(message.get("role", "")).lower() != "assistant":
                continue
            # ``or []``: assistant messages commonly carry ``"tool_calls": None``,
            # which ``.get(key, [])`` would hand straight to the for-loop.
            for tool_call in message.get("tool_calls") or []:
                tool_call_id = tool_call.get("id")
                if not tool_call_id:
                    continue
                function_info = tool_call.get("function", {})
                previous_tool_calls[tool_call_id] = {
                    "name": function_info.get("name", "tool"),
                    "parameters": parse_parameters(
                        function_info.get("arguments", "{}")
                    ),
                }

        for message in messages[-trailing_tool_count:]:
            tool_call_id = message.get("tool_call_id")
            if not isinstance(tool_call_id, str):
                continue
            previous_call = previous_tool_calls.get(tool_call_id, {})
            tool_results.append(
                models.CohereToolResult(
                    call=models.CohereToolCall(
                        name=previous_call.get("name", message.get("name", "tool")),
                        parameters=previous_call.get("parameters", {}),
                    ),
                    outputs=[{"output": self._coerce_text(message.get("content", ""))}],
                )
            )

    message_text = "" if tool_results else self._coerce_text(
        last_message.get("content", "")
    )
    return chat_history, tool_results or None, message_text
def _tool_result_guidance_enabled(self) -> bool:
    """Report whether ``additional_params`` holds a truthy ``tool_result_guidance`` value."""
    guidance_flag = self.additional_params.get("tool_result_guidance")
    return True if guidance_flag else False
def _allowed_passthrough_request_keys(self, request_cls: type[Any]) -> set[str]:
    """Return request attributes that can safely be forwarded to the OCI SDK.

    Lookup order:
      1. ``attribute_map`` / ``swagger_types`` defined on the class.
      2. The same two maps on a default-constructed instance.
      3. The constructor's named parameters (``self`` and ``**kwargs`` excluded).

    Args:
        request_cls: The request model class (e.g. a chat request type).

    Returns:
        The set of attribute names the request model accepts.
    """

    def keys_from(holder: Any) -> set[str] | None:
        """Return the keys of the first mapping-valued spec attribute, if any."""
        for attr_name in ("attribute_map", "swagger_types"):
            spec = getattr(holder, attr_name, None)
            if isinstance(spec, Mapping):
                return {str(key) for key in spec}
        return None

    class_keys = keys_from(request_cls)
    if class_keys is not None:
        return class_keys

    # NOTE(review): OCI SDK models appear to assign attribute_map/swagger_types
    # inside ``__init__(self, **kwargs)`` rather than on the class; confirm
    # against the installed SDK. In that layout the class-level probe finds
    # nothing and the signature fallback yields an empty set, so every
    # passthrough parameter would be dropped. Probe an instance first.
    try:
        probe = request_cls()
    except Exception:  # best-effort probe: constructors needing arguments fall through
        probe = None
    if probe is not None:
        instance_keys = keys_from(probe)
        if instance_keys is not None:
            return instance_keys

    signature = inspect.signature(request_cls)
    return {
        name
        for name, param in signature.parameters.items()
        if name != "self" and param.kind is not inspect.Parameter.VAR_KEYWORD
    }
any(self._message_has_multimodal_content(msg.get("content")) for msg in messages): + raise ValueError( + "OCI Cohere models currently support text-only messages in CrewAI." + ) + chat_history, tool_results, message_text = self._build_cohere_chat_history( + messages + ) + request_kwargs: dict[str, Any] = { + "message": message_text, + "chat_history": chat_history, + "api_format": models.BaseChatRequest.API_FORMAT_COHERE, + } + if tool_results: + request_kwargs["tool_results"] = tool_results + else: + request_kwargs = { + "messages": self._build_generic_messages(messages), + "api_format": models.BaseChatRequest.API_FORMAT_GENERIC, + } + + if self.temperature is not None and not self._is_openai_gpt5_family(): + request_kwargs["temperature"] = self.temperature + if self.max_tokens is not None: + if self.oci_provider == "generic" and self.model.startswith("openai."): + request_kwargs["max_completion_tokens"] = self.max_tokens + else: + request_kwargs["max_tokens"] = self.max_tokens + if self.top_p is not None: + request_kwargs["top_p"] = self.top_p + if self.top_k is not None: + request_kwargs["top_k"] = self.top_k + + if self.stop and not self._is_openai_gpt5_family(): + stop_key = "stop_sequences" if self.oci_provider == "cohere" else "stop" + request_kwargs[stop_key] = list(self.stop) + + formatted_tools = self._format_tools(tools) + if formatted_tools: + request_kwargs["tools"] = formatted_tools + if self.oci_provider == "cohere": + if self._parallel_tool_calls_enabled(): + raise ValueError("OCI Cohere models do not support parallel_tool_calls.") + request_kwargs.setdefault("is_force_single_step", False) + else: + tool_choice = self._build_tool_choice() + if tool_choice is not None: + request_kwargs["tool_choice"] = tool_choice + if self._parallel_tool_calls_enabled(): + request_kwargs["is_parallel_tool_calls"] = True + + response_format = self._build_response_format(response_model) + if response_format is not None: + request_kwargs["response_format"] = 
def _extract_text(self, response: Any) -> str:
    """Pull the assistant's text out of a non-streaming OCI chat response.

    For Cohere, a truthy ``chat_response.text`` wins; otherwise the text
    parts of ``chat_response.message`` are concatenated. For other
    providers the text parts of the first choice's message are used.
    Returns ``""`` when no message is available.
    """
    chat_response = response.data.chat_response

    if self.oci_provider == "cohere":
        direct_text = getattr(chat_response, "text", None)
        if direct_text:
            return direct_text
        message = getattr(chat_response, "message", None)
    else:
        choices = getattr(chat_response, "choices", None) or []
        message = getattr(choices[0], "message", None) if choices else None

    if message is None:
        return ""

    fragments = [
        part.text
        for part in (getattr(message, "content", None) or [])
        if getattr(part, "text", None)
    ]
    return "".join(fragments)
def _extract_usage(self, response: Any) -> dict[str, int]:
    """Read token usage from a non-streaming OCI chat response.

    Args:
        response: SDK response exposing ``data.chat_response``.

    Returns:
        ``{}`` when the response carries no ``usage`` object; otherwise a
        dict with ``prompt_tokens``, ``completion_tokens`` and
        ``total_tokens`` as integers. Counters that are missing or ``None``
        are reported as ``0``, matching the streaming usage extractor.
    """
    usage = getattr(response.data.chat_response, "usage", None)
    if usage is None:
        return {}
    # ``or 0``: an attribute that exists but is None must not leak into the
    # int-typed usage dict that feeds token accounting.
    return {
        counter: int(getattr(usage, counter, 0) or 0)
        for counter in ("prompt_tokens", "completion_tokens", "total_tokens")
    }
def _extract_text_from_stream_event(self, event_data: dict[str, Any]) -> str:
    """Extract the text delta carried by one parsed streaming event.

    Args:
        event_data: Event payload already decoded into a dict.

    Returns:
        For Cohere events with a top-level ``"text"`` key, that text
        (``""`` when it is ``None``). Otherwise the concatenated non-empty
        ``"text"`` values of ``message.content`` parts, or ``""`` when the
        payload does not have that shape.
    """
    if self.oci_provider == "cohere" and "text" in event_data:
        text = event_data["text"]
        # A null text field must not be streamed out as the word "None".
        return "" if text is None else str(text)

    message = event_data.get("message", {})
    if not isinstance(message, Mapping):
        return ""
    content = message.get("content", [])
    if not isinstance(content, list):
        return ""
    # Both providers skip parts whose text is missing or empty.
    return "".join(
        str(part["text"])
        for part in content
        if isinstance(part, Mapping) and part.get("text")
    )
def _handle_tool_calls(
    self,
    *,
    normalized_messages: list[LLMMessage],
    tools: list[dict[str, BaseTool]] | None,
    callbacks: list[Any] | None,
    available_functions: dict[str, Any] | None,
    from_task: Task | None,
    from_agent: Agent | None,
    tool_depth: int,
    response_model: type[BaseModel] | None = None,
    tool_calls: list[dict[str, Any]],
) -> str | BaseModel | list[dict[str, Any]]:
    """Execute one round of tool calls and recurse until the model finishes.

    Args:
        normalized_messages: Conversation so far; not mutated.
        tools: Tool specs forwarded unchanged to the follow-up call.
        callbacks: Forwarded unchanged to the follow-up call.
        available_functions: Name -> callable map. When empty or ``None``
            the raw ``tool_calls`` are returned for the caller to execute.
        from_task: Originating task, passed through to events/execution.
        from_agent: Originating agent, passed through to events/execution.
        tool_depth: Number of tool rounds already performed.
        response_model: Optional structured-output model for the follow-up.
        tool_calls: Tool calls requested by the model in this round.

    Returns:
        The raw tool calls (no executable functions), or whatever the
        follow-up ``_call_impl`` / ``_stream_call_impl`` returns.

    Raises:
        RuntimeError: If ``tool_depth`` reached ``max_sequential_tool_calls``.
    """
    if tool_calls and not available_functions:
        self._emit_call_completed_event(
            response=tool_calls,
            call_type=LLMCallType.TOOL_CALL,
            from_task=from_task,
            from_agent=from_agent,
            messages=normalized_messages,
        )
        return tool_calls

    if tool_depth >= self.max_sequential_tool_calls:
        raise RuntimeError(
            "OCI native provider exceeded max_sequential_tool_calls."
        )

    # Give every call a concrete string id *before* it is recorded on the
    # assistant message. Previously a missing id stayed None there while the
    # tool message received a fresh uuid, so the pair could never be matched.
    # Copies are used so the caller's dicts are left untouched.
    identified_calls = [
        {**tool_call, "id": str(tool_call.get("id") or uuid.uuid4().hex)}
        for tool_call in tool_calls
    ]

    next_messages = [
        *normalized_messages,
        {"role": "assistant", "content": None, "tool_calls": identified_calls},
    ]

    for tool_call in identified_calls:
        fn = tool_call.get("function", {})
        fn_name = fn.get("name", "")
        raw_args = fn.get("arguments", "{}")
        if isinstance(raw_args, str):
            try:
                raw_args = json.loads(raw_args)
            except json.JSONDecodeError:
                raw_args = {}
        # Only a JSON object is a valid argument set; "null", lists, etc.
        # fall back to no arguments.
        fn_args = dict(raw_args) if isinstance(raw_args, Mapping) else {}

        result = self._handle_tool_execution(
            function_name=fn_name,
            function_args=fn_args,
            available_functions=available_functions or {},
            from_task=from_task,
            from_agent=from_agent,
        )
        if result is None:
            result = f"Tool '{fn_name}' failed or returned no result."

        next_messages.append({
            "role": "tool",
            "tool_call_id": tool_call["id"],
            "name": fn_name,
            "content": str(result),
        })

    # Continue on the same transport (streaming or not) the instance uses.
    follow_up = self._stream_call_impl if self.stream else self._call_impl
    return follow_up(
        messages=next_messages,
        tools=tools,
        callbacks=callbacks,
        available_functions=available_functions,
        from_task=from_task,
        from_agent=from_agent,
        tool_depth=tool_depth + 1,
        response_model=response_model,
    )
compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + response = self._chat(chat_details) + usage = self._extract_usage(response) + if usage: + self._track_token_usage_internal(usage) + self.last_response_metadata = self._extract_response_metadata(response) or None + + content = self._extract_text(response) + tool_calls = self._extract_tool_calls(response) + if tool_calls: + return self._handle_tool_calls( + normalized_messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=tool_depth, + response_model=response_model, + tool_calls=tool_calls, + ) + + if response_model is not None: + return self._parse_structured_response( + content=content, + response_model=response_model, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + return self._finalize_text_response( + content=content, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + def _stream_call_impl( + self, + *, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + tool_depth: int = 0, + response_model: type[BaseModel] | None = None, + ) -> str | BaseModel | list[dict[str, Any]]: + """Handle OCI streaming while reconstructing final text/tool state.""" + normalized_messages = ( + messages if isinstance(messages, list) else self._normalize_messages(messages) + ) + chat_request = self._build_chat_request( + normalized_messages, tools=tools, response_model=response_model, + is_stream=True, + ) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + 
full_response = "" + tool_calls_by_index: dict[int, dict[str, Any]] = {} + usage_data: dict[str, int] = {} + response_metadata: dict[str, Any] = {} + response_id = uuid.uuid4().hex + + for event in self._stream_chat_events(chat_details): + event_data = self._parse_stream_event(event) + if not event_data: + continue + + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: + full_response += text_chunk + self._emit_stream_chunk_event( + chunk=text_chunk, + from_task=from_task, + from_agent=from_agent, + call_type=LLMCallType.LLM_CALL, + response_id=response_id, + ) + + stream_tool_calls = self._extract_tool_calls_from_stream_event(event_data) + for index, tc in enumerate(stream_tool_calls): + state = tool_calls_by_index.setdefault( + index, + {"id": None, "type": "function", "function": {"name": None, "arguments": ""}}, + ) + if tc.get("id"): + state["id"] = tc["id"] + fn = tc.get("function", {}) + if fn.get("name"): + state["function"]["name"] = fn["name"] + chunk_args = fn.get("arguments") + if chunk_args: + state["function"]["arguments"] += str(chunk_args) + self._emit_stream_chunk_event( + chunk=str(chunk_args or ""), + tool_call={"id": state["id"], "type": "function", "function": { + "name": state["function"]["name"], + "arguments": str(chunk_args or ""), + }}, + from_task=from_task, + from_agent=from_agent, + call_type=LLMCallType.TOOL_CALL, + response_id=response_id, + ) + + usage_chunk = self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) + + if usage_data: + self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None + + tool_calls = [ + { + "id": tc.get("id") or uuid.uuid4().hex, + "type": "function", + "function": { + "name": tc["function"].get("name", "") or "", + "arguments": tc["function"].get("arguments", "") or "", + }, 
+ } + for _, tc in sorted(tool_calls_by_index.items()) + if tc["function"].get("name") + ] + + if tool_calls: + return self._handle_tool_calls( + normalized_messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=tool_depth, + response_model=response_model, + tool_calls=tool_calls, + ) + + if response_model is not None: + return self._parse_structured_response( + content=full_response, + response_model=response_model, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + return self._finalize_text_response( + content=full_response, + messages=normalized_messages, + from_task=from_task, + from_agent=from_agent, + ) + + def iter_stream( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + ) -> Any: + """Yield raw text chunks from OCI without triggering tool recursion.""" + normalized_messages = self._normalize_messages(messages) + chat_request = self._build_chat_request(normalized_messages, is_stream=True) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + response = self._chat(chat_details) + usage_data: dict[str, int] = {} + response_metadata: dict[str, Any] = {} + + for event in response.data.events(): + event_data = self._parse_stream_event(event) + if not event_data: + continue + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: + yield text_chunk + usage_chunk = self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) + + if usage_data: 
+ self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None + + async def astream( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + ) -> Any: + """Async streaming — true async via aiohttp when available, thread fallback otherwise.""" + if self._async_client is None: + # Fallback: sync stream bridged via thread + loop = asyncio.get_running_loop() + queue: asyncio.Queue[str | None] = asyncio.Queue() + error_holder: list[BaseException] = [] + + def _producer() -> None: + try: + for chunk in self.iter_stream( + messages=messages, tools=tools, callbacks=callbacks, + available_functions=available_functions, from_task=from_task, + from_agent=from_agent, + ): + loop.call_soon_threadsafe(queue.put_nowait, chunk) + except BaseException as error: + error_holder.append(error) + finally: + loop.call_soon_threadsafe(queue.put_nowait, None) + + thread = threading.Thread(target=_producer, daemon=True) + thread.start() + while True: + chunk = await queue.get() + if chunk is None: + break + yield chunk + thread.join() + if error_holder: + raise error_holder[0] + return + + normalized_messages = self._normalize_messages(messages) + request_data = self._prepare_async_request( + normalized_messages, tools=tools, is_stream=True + ) + usage_data: dict[str, int] = {} + response_metadata: dict[str, Any] = {} + + async for event_data in self._async_client.chat_async( + compartment_id=request_data["compartment_id"], + chat_request_dict=request_data["chat_request_dict"], + serving_mode_dict=request_data["serving_mode_dict"], + stream=True, + ): + text_chunk = self._extract_text_from_stream_event(event_data) + if text_chunk: + yield text_chunk + + usage_chunk = 
self._extract_usage_from_stream_event(event_data) + if usage_chunk: + usage_data = usage_chunk + response_metadata.update(self._extract_metadata_from_stream_event(event_data)) + + if usage_data: + self._track_token_usage_internal(usage_data) + response_metadata["usage"] = usage_data + self.last_response_metadata = response_metadata or None + + def call( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] | None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | BaseModel | list[dict[str, Any]]: + normalized_messages = self._normalize_messages(messages) + + with llm_call_context(): + try: + self._emit_call_started_event( + messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + ) + + if not self._invoke_before_llm_call_hooks( + normalized_messages, from_agent + ): + raise ValueError("LLM call blocked by before_llm_call hook") + + if self.stream: + return self._stream_call_impl( + messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=0, + response_model=response_model, + ) + + return self._call_impl( + messages=normalized_messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + tool_depth=0, + response_model=response_model, + ) + except Exception as error: + error_message = f"OCI Generative AI call failed: {error!s}" + self._emit_call_failed_event( + error=error_message, + from_task=from_task, + from_agent=from_agent, + ) + raise + + async def acall( + self, + messages: str | list[LLMMessage], + tools: list[dict[str, BaseTool]] | None = None, + callbacks: list[Any] 
| None = None, + available_functions: dict[str, Any] | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, + response_model: type[BaseModel] | None = None, + ) -> str | Any: + """Async call — true async via aiohttp when available, thread fallback otherwise.""" + if self._async_client is None: + return await asyncio.to_thread( + self.call, messages, tools=tools, callbacks=callbacks, + available_functions=available_functions, from_task=from_task, + from_agent=from_agent, response_model=response_model, + ) + + normalized_messages = self._normalize_messages(messages) + request_data = self._prepare_async_request( + normalized_messages, tools=tools, response_model=response_model + ) + + response_data = None + async for data in self._async_client.chat_async( + compartment_id=request_data["compartment_id"], + chat_request_dict=request_data["chat_request_dict"], + serving_mode_dict=request_data["serving_mode_dict"], + stream=False, + ): + response_data = data + break + + if response_data is None: + raise RuntimeError("No response received from OCI GenAI async call") + + # Extract text from the raw JSON response + chat_response = response_data.get("chatResponse", {}) + content = self._extract_text_from_async_response(chat_response) + content = self._apply_stop_words(content) + + # Track usage + usage = chat_response.get("usage", {}) + if usage: + usage_dict = { + "prompt_tokens": usage.get("promptTokens", 0), + "completion_tokens": usage.get("completionTokens", 0), + "total_tokens": usage.get("totalTokens", 0), + } + self._track_token_usage_internal(usage_dict) + + self._emit_call_completed_event( + response=content, + call_type=LLMCallType.LLM_CALL, + from_task=from_task, + from_agent=from_agent, + messages=normalized_messages, + ) + return content + + def _extract_text_from_async_response(self, chat_response: dict[str, Any]) -> str: + """Extract text from a raw JSON chat response dict.""" + # Generic format: choices[0].message.content[].text + 
choices = chat_response.get("choices", []) + if choices: + message = choices[0].get("message", {}) + content = message.get("content", []) + if isinstance(content, list): + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, dict) and part.get("text") + ) + return str(content) if content else "" + + # Cohere format: text field or message.content + text = chat_response.get("text") + if text: + return str(text) + + message = chat_response.get("message", {}) + content = message.get("content", []) + if isinstance(content, list): + return "".join( + str(part.get("text", "")) + for part in content + if isinstance(part, dict) + ) + return "" + + # ------------------------------------------------------------------ + # Async support (true async via aiohttp, no thread wrappers) + # ------------------------------------------------------------------ + + @property + def _async_client(self) -> Any | None: + """Lazy-init true async client reusing the sync client's signer. + + Returns None if aiohttp is not installed or the client can't be created, + in which case callers fall back to asyncio.to_thread. 
+ """ + if not hasattr(self, "_async_client_instance"): + self._async_client_instance = None + try: + from crewai.utilities.oci_async import OCIAsyncClient + + oci = self._oci + expected_cls = oci.generative_ai_inference.GenerativeAiInferenceClient + if isinstance(self.client, expected_cls): + base_client = self.client.base_client + self._async_client_instance = OCIAsyncClient( + service_endpoint=self.service_endpoint, + signer=base_client.signer, + config=getattr(base_client, "config", {}), + ) + except (ImportError, Exception): + pass + return self._async_client_instance + + def _prepare_async_request( + self, + messages: list[LLMMessage], + tools: list[dict[str, Any]] | None = None, + response_model: type[BaseModel] | None = None, + *, + is_stream: bool = False, + ) -> dict[str, Any]: + """Build request dicts for the async client (SDK objects → JSON dicts).""" + from oci.util import to_dict + + chat_request = self._build_chat_request( + messages, tools=tools, response_model=response_model, is_stream=is_stream + ) + chat_details = self._oci.generative_ai_inference.models.ChatDetails( + compartment_id=self.compartment_id, + serving_mode=self._build_serving_mode(), + chat_request=chat_request, + ) + return { + "compartment_id": chat_details.compartment_id, + "chat_request_dict": to_dict(chat_details.chat_request), + "serving_mode_dict": to_dict(chat_details.serving_mode), + } + + # ------------------------------------------------------------------ + # Client serialization (sync) + # ------------------------------------------------------------------ + + def _chat(self, chat_details: Any) -> Any: + with self._ordered_client_access(): + return self.client.chat(chat_details) + + def _stream_chat_events(self, chat_details: Any) -> Any: + """Yield streaming events while holding the shared OCI client lock.""" + with self._ordered_client_access(): + response = self.client.chat(chat_details) + yield from response.data.events() + + @contextmanager + def 
_ordered_client_access(self) -> Any: + """Serialize shared OCI client access in call-arrival order.""" + with self._client_condition: + ticket = self._next_client_ticket + self._next_client_ticket += 1 + while ticket != self._active_client_ticket: + self._client_condition.wait() + + try: + yield + finally: + with self._client_condition: + self._active_client_ticket += 1 + self._client_condition.notify_all() + + # ------------------------------------------------------------------ + # Capability declarations + # ------------------------------------------------------------------ + + def supports_function_calling(self) -> bool: + return True + + def supports_multimodal(self) -> bool: + return True + + def get_context_window_size(self) -> int: + model_lower = self.model.lower() + if model_lower.startswith("google.gemini"): + return 1048576 + if model_lower.startswith("openai."): + return 200000 + if model_lower.startswith("cohere."): + return 128000 + if model_lower.startswith("meta."): + return 131072 + return 131072 diff --git a/lib/crewai/src/crewai/llms/providers/oci/vision.py b/lib/crewai/src/crewai/llms/providers/oci/vision.py new file mode 100644 index 0000000000..d0049f2632 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/oci/vision.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import base64 +import mimetypes +from pathlib import Path + + +VISION_MODELS: list[str] = [ + "meta.llama-3.2-90b-vision-instruct", + "meta.llama-3.2-11b-vision-instruct", + "meta.llama-4-scout-17b-16e-instruct", + "meta.llama-4-maverick-17b-128e-instruct-fp8", + "google.gemini-2.5-flash", + "google.gemini-2.5-pro", + "google.gemini-2.5-flash-lite", + "xai.grok-4", + "xai.grok-4-1-fast-reasoning", + "xai.grok-4-1-fast-non-reasoning", + "xai.grok-4-fast-reasoning", + "xai.grok-4-fast-non-reasoning", + "cohere.command-a-vision", +] + +IMAGE_EMBEDDING_MODELS: list[str] = [ + "cohere.embed-v4.0", + "cohere.embed-multilingual-image-v3.0", +] + + +def to_data_uri(image: str 
| bytes | Path, mime_type: str = "image/png") -> str: + """Convert bytes, file paths, or data URIs into a data URI.""" + if isinstance(image, bytes): + encoded = base64.standard_b64encode(image).decode("utf-8") + return f"data:{mime_type};base64,{encoded}" + + image_str = str(image) + if image_str.startswith("data:"): + return image_str + + path = Path(image_str) + detected_mime = mimetypes.guess_type(str(path))[0] or mime_type + encoded = base64.standard_b64encode(path.read_bytes()).decode("utf-8") + return f"data:{detected_mime};base64,{encoded}" + + +def load_image(file_path: str | Path) -> dict[str, dict[str, str] | str]: + return {"type": "image_url", "image_url": {"url": to_data_uri(file_path)}} + + +def encode_image( + image_bytes: bytes, mime_type: str = "image/png" +) -> dict[str, dict[str, str] | str]: + return {"type": "image_url", "image_url": {"url": to_data_uri(image_bytes, mime_type)}} + + +def is_vision_model(model_id: str) -> bool: + return model_id in VISION_MODELS diff --git a/lib/crewai/src/crewai/rag/embeddings/factory.py b/lib/crewai/src/crewai/rag/embeddings/factory.py index 8027793200..37c32e9062 100644 --- a/lib/crewai/src/crewai/rag/embeddings/factory.py +++ b/lib/crewai/src/crewai/rag/embeddings/factory.py @@ -70,6 +70,10 @@ from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec + from crewai.rag.embeddings.providers.oci.embedding_callable import ( + OCIEmbeddingFunction, + ) + from crewai.rag.embeddings.providers.oci.types import OCIProviderSpec from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec @@ -99,6 +103,7 @@ "instructor": 
"crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider", "jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider", "ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider", + "oci": "crewai.rag.embeddings.providers.oci.oci_provider.OCIProvider", "onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider", "openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider", "openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider", @@ -216,6 +221,10 @@ def build_embedder_from_dict( def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ... +@overload +def build_embedder_from_dict(spec: OCIProviderSpec) -> OCIEmbeddingFunction: ... + + @overload def build_embedder_from_dict(spec: dict[str, Any]) -> EmbeddingFunction[Any]: ... diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py new file mode 100644 index 0000000000..288602b1c8 --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/__init__.py @@ -0,0 +1,17 @@ +"""OCI embedding provider exports.""" + +from crewai.rag.embeddings.providers.oci.embedding_callable import ( + OCIEmbeddingFunction, +) +from crewai.rag.embeddings.providers.oci.oci_provider import OCIProvider +from crewai.rag.embeddings.providers.oci.types import ( + OCIProviderConfig, + OCIProviderSpec, +) + +__all__ = [ + "OCIEmbeddingFunction", + "OCIProvider", + "OCIProviderConfig", + "OCIProviderSpec", +] diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py new file mode 100644 index 0000000000..de65f1d58c --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/embedding_callable.py @@ -0,0 +1,181 @@ +"""OCI embedding function implementation.""" + +from __future__ import annotations + 
+import base64 +from collections.abc import Iterator, Sequence +import mimetypes +import os +from pathlib import Path +from typing import Any, cast + +from chromadb.api.types import Documents, EmbeddingFunction, Embeddings +from typing_extensions import Unpack + +from crewai.rag.embeddings.providers.oci.types import OCIProviderConfig +from crewai.utilities.oci import create_oci_client_kwargs, get_oci_module + + +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" +DEFAULT_OCI_REGION = "us-chicago-1" + + +def _get_oci_module() -> Any: + """Backward-compatible module-local alias used by tests and patches.""" + return get_oci_module() + + +class OCIEmbeddingFunction(EmbeddingFunction[Documents]): + """Embedding function for OCI Generative AI embedding models.""" + + def __init__(self, **kwargs: Unpack[OCIProviderConfig]) -> None: + self._config = kwargs + self._client: Any = kwargs.get("client") + if self._client is None: + service_endpoint = kwargs.get("service_endpoint") + region = kwargs.get("region") or os.getenv("OCI_REGION", DEFAULT_OCI_REGION) + if service_endpoint is None: + service_endpoint = ( + f"https://inference.generativeai.{region}.oci.oraclecloud.com" + ) + + client_kwargs = create_oci_client_kwargs( + auth_type=kwargs.get("auth_type", "API_KEY"), + service_endpoint=service_endpoint, + auth_file_location=kwargs.get("auth_file_location", "~/.oci/config"), + auth_profile=kwargs.get("auth_profile", "DEFAULT"), + timeout=kwargs.get("timeout", (10, 120)), + oci_module=_get_oci_module(), + ) + self._client = ( + _get_oci_module().generative_ai_inference.GenerativeAiInferenceClient( + **client_kwargs + ) + ) + + def _require_client(self) -> Any: + if self._client is None: + raise ValueError("OCI embedding client is not initialized.") + return self._client + + @staticmethod + def name() -> str: + return "oci" + + @staticmethod + def build_from_config(config: dict[str, Any]) -> OCIEmbeddingFunction: + timeout = config.get("timeout") + if isinstance(timeout, 
list): + config = dict(config) + config["timeout"] = tuple(timeout) + return OCIEmbeddingFunction(**config) + + def get_config(self) -> dict[str, Any]: + config = dict(self._config) + config.pop("client", None) + timeout = config.get("timeout") + if isinstance(timeout, tuple): + config["timeout"] = list(timeout) + return config + + def _get_serving_mode(self) -> Any: + oci = _get_oci_module() + model_name = self._config.get("model_name") + if not model_name: + raise ValueError("OCI embeddings require model_name") + if model_name.startswith(CUSTOM_ENDPOINT_PREFIX): + return oci.generative_ai_inference.models.DedicatedServingMode( + endpoint_id=model_name + ) + return oci.generative_ai_inference.models.OnDemandServingMode( + model_id=model_name + ) + + def _build_request( + self, inputs: list[str], *, input_type: str | None = None + ) -> Any: + oci = _get_oci_module() + compartment_id = self._config.get("compartment_id") or os.getenv( + "OCI_COMPARTMENT_ID" + ) + if not compartment_id: + raise ValueError( + "OCI embeddings require compartment_id. Set it explicitly or use OCI_COMPARTMENT_ID." + ) + + request_kwargs: dict[str, Any] = { + "serving_mode": self._get_serving_mode(), + "compartment_id": compartment_id, + "truncate": self._config.get("truncate", "END"), + "inputs": inputs, + } + + resolved_input_type = input_type or self._config.get("input_type") + if resolved_input_type: + request_kwargs["input_type"] = resolved_input_type + + output_dimensions = self._config.get("output_dimensions") + if output_dimensions is not None: + embed_text_details = oci.generative_ai_inference.models.EmbedTextDetails + if hasattr(embed_text_details, "output_dimensions"): + request_kwargs["output_dimensions"] = output_dimensions + else: + raise ValueError( + "output_dimensions requires a newer OCI SDK. Upgrade the oci package." 
+ ) + + return oci.generative_ai_inference.models.EmbedTextDetails(**request_kwargs) + + def _batch_inputs(self, input: list[str]) -> Iterator[list[str]]: + batch_size = self._config.get("batch_size", 96) + for index in range(0, len(input), batch_size): + yield input[index : index + batch_size] + + @staticmethod + def _to_data_uri(image: str | bytes | Path, mime_type: str = "image/png") -> str: + if isinstance(image, Path): + resolved_mime = mimetypes.guess_type(image.name)[0] or mime_type + data = image.read_bytes() + return ( + f"data:{resolved_mime};base64," + f"{base64.b64encode(data).decode('ascii')}" + ) + if isinstance(image, bytes): + return f"data:{mime_type};base64,{base64.b64encode(image).decode('ascii')}" + if image.startswith("data:"): + return image + path = Path(image) + if path.exists(): + return OCIEmbeddingFunction._to_data_uri(path, mime_type=mime_type) + raise ValueError( + "OCI image embeddings require a file path, raw bytes, or a data URI." + ) + + def __call__(self, input: Documents) -> Embeddings: + if isinstance(input, str): + input = [input] + embeddings: Embeddings = [] + for chunk in self._batch_inputs(input): + response = self._require_client().embed_text(self._build_request(chunk)) + embeddings.extend(cast(Embeddings, response.data.embeddings)) + return embeddings + + def embed_image( + self, image: str | bytes | Path, *, mime_type: str = "image/png" + ) -> list[float]: + return [ + float(value) + for value in self.embed_image_batch([image], mime_type=mime_type)[0] + ] + + def embed_image_batch( + self, images: Sequence[str | bytes | Path], *, mime_type: str = "image/png" + ) -> Embeddings: + embeddings: Embeddings = [] + for image in images: + data_uri = self._to_data_uri(image, mime_type=mime_type) + response = self._require_client().embed_text( + self._build_request([data_uri], input_type="IMAGE") + ) + embeddings.extend(cast(Embeddings, response.data.embeddings)) + return embeddings diff --git 
a/lib/crewai/src/crewai/rag/embeddings/providers/oci/oci_provider.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/oci_provider.py new file mode 100644 index 0000000000..d0604fe628 --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/oci_provider.py @@ -0,0 +1,75 @@ +"""OCI embeddings provider.""" + +from typing import Any + +from pydantic import AliasChoices, Field + +from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider +from crewai.rag.embeddings.providers.oci.embedding_callable import OCIEmbeddingFunction + + +class OCIProvider(BaseEmbeddingsProvider[OCIEmbeddingFunction]): + """OCI Generative AI embeddings provider.""" + + embedding_callable: type[OCIEmbeddingFunction] = Field( + default=OCIEmbeddingFunction, + description="OCI embedding function class", + ) + model_name: str = Field( + default="cohere.embed-english-v3.0", + description="Model name to use for embeddings", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_MODEL_NAME", "OCI_EMBED_MODEL", "model", "model_name", + ), + ) + compartment_id: str = Field( + description="OCI compartment ID", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_COMPARTMENT_ID", "OCI_COMPARTMENT_ID", "compartment_id", + ), + ) + service_endpoint: str | None = Field( + default=None, + description="OCI Generative AI inference endpoint", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_SERVICE_ENDPOINT", "OCI_SERVICE_ENDPOINT", "service_endpoint", + ), + ) + region: str | None = Field( + default=None, + description="OCI region for endpoint derivation", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_REGION", "OCI_REGION", "region", + ), + ) + auth_type: str = Field( + default="API_KEY", + description="OCI SDK auth type", + validation_alias=AliasChoices("EMBEDDINGS_OCI_AUTH_TYPE", "OCI_AUTH_TYPE"), + ) + auth_profile: str = Field( + default="DEFAULT", + description="OCI config profile name", + validation_alias=AliasChoices("EMBEDDINGS_OCI_AUTH_PROFILE", "OCI_AUTH_PROFILE"), 
+ ) + auth_file_location: str = Field( + default="~/.oci/config", + description="OCI config file location", + validation_alias=AliasChoices( + "EMBEDDINGS_OCI_AUTH_FILE_LOCATION", "OCI_AUTH_FILE_LOCATION", + ), + ) + truncate: str = Field(default="END", description="OCI embedding truncate policy") + input_type: str | None = Field( + default=None, + description="Optional OCI embedding input type such as SEARCH_DOCUMENT or SEARCH_QUERY", + ) + output_dimensions: int | None = Field( + default=None, + description="Optional output dimensions for compatible OCI embedding models", + ) + batch_size: int = Field(default=96, description="OCI embedding batch size") + timeout: tuple[int, int] = Field( + default=(10, 120), description="OCI SDK connect/read timeout" + ) + client: Any | None = Field(default=None, description="Injected OCI client") diff --git a/lib/crewai/src/crewai/rag/embeddings/providers/oci/types.py b/lib/crewai/src/crewai/rag/embeddings/providers/oci/types.py new file mode 100644 index 0000000000..757d0be6ce --- /dev/null +++ b/lib/crewai/src/crewai/rag/embeddings/providers/oci/types.py @@ -0,0 +1,30 @@ +"""Type definitions for OCI embedding providers.""" + +from typing import Annotated, Any, Literal + +from typing_extensions import Required, TypedDict + + +class OCIProviderConfig(TypedDict, total=False): + """Configuration for OCI embedding provider.""" + + model_name: Annotated[str, "cohere.embed-english-v3.0"] + compartment_id: str + service_endpoint: str + region: str + auth_type: str + auth_profile: str + auth_file_location: str + truncate: str + input_type: str + output_dimensions: int + batch_size: int + timeout: tuple[int, int] + client: Any + + +class OCIProviderSpec(TypedDict, total=False): + """OCI provider specification.""" + + provider: Required[Literal["oci"]] + config: OCIProviderConfig diff --git a/lib/crewai/src/crewai/rag/embeddings/types.py b/lib/crewai/src/crewai/rag/embeddings/types.py index 794f4c6f9a..e01ecae3c9 100644 --- 
def get_oci_module() -> Any:
    """Import and return the ``oci`` SDK package on first use.

    Returns:
        The imported ``oci`` module.

    Raises:
        ImportError: If the SDK is not installed; the message names the
            ``crewai[oci]`` extra to install.
    """
    try:
        import oci  # type: ignore[import-untyped]
    except ImportError:
        raise ImportError(
            'OCI support is not available, to install: uv add "crewai[oci]"'
        ) from None
    return oci


def create_oci_client_kwargs(
    *,
    auth_type: str,
    service_endpoint: str | None,
    auth_file_location: str,
    auth_profile: str,
    timeout: tuple[int, int],
    oci_module: Any | None = None,
) -> dict[str, Any]:
    """Build OCI SDK client kwargs for the supported auth modes.

    Args:
        auth_type: One of ``API_KEY``, ``SECURITY_TOKEN``,
            ``INSTANCE_PRINCIPAL`` or ``RESOURCE_PRINCIPAL`` (case-insensitive).
        service_endpoint: Endpoint passed through unchanged (may be ``None``).
        auth_file_location: OCI config file path, used by the file-based modes.
        auth_profile: Profile name inside that config file.
        timeout: Timeout tuple passed through unchanged.
        oci_module: Optional stand-in for the ``oci`` package; when omitted the
            real SDK is imported via ``get_oci_module``.

    Returns:
        A dict with ``config``, ``service_endpoint``, ``retry_strategy`` and
        ``timeout`` keys, plus ``signer`` for every mode except ``API_KEY``.

    Raises:
        ValueError: If ``auth_type`` is not one of the supported modes.
        KeyError: In ``SECURITY_TOKEN`` mode, if the profile lacks ``key_file``
            or ``security_token_file``.
    """
    # Local import: this module otherwise only imports ``typing``.
    import os

    oci = oci_module or get_oci_module()
    client_kwargs: dict[str, Any] = {
        "config": {},
        "service_endpoint": service_endpoint,
        "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY,
        "timeout": timeout,
    }

    auth_type_upper = auth_type.upper()
    if auth_type_upper == "API_KEY":
        client_kwargs["config"] = oci.config.from_file(
            file_location=auth_file_location,
            profile_name=auth_profile,
        )
    elif auth_type_upper == "SECURITY_TOKEN":
        config = oci.config.from_file(
            file_location=auth_file_location,
            profile_name=auth_profile,
        )
        key_file = config["key_file"]
        security_token_file = config["security_token_file"]
        private_key = oci.signer.load_private_key_from_file(key_file, None)
        # The token path is opened directly rather than through an SDK loader,
        # so a "~"-relative entry in the profile has to be expanded here.
        token_path = os.path.expanduser(security_token_file)
        with open(token_path, encoding="utf-8") as file:
            # Drop surrounding whitespace (e.g. a trailing newline added by an
            # editor) so it never ends up inside the signer's token string.
            security_token = file.read().strip()
        client_kwargs["config"] = config
        client_kwargs["signer"] = oci.auth.signers.SecurityTokenSigner(
            security_token, private_key
        )
    elif auth_type_upper == "INSTANCE_PRINCIPAL":
        client_kwargs["signer"] = (
            oci.auth.signers.InstancePrincipalsSecurityTokenSigner()
        )
    elif auth_type_upper == "RESOURCE_PRINCIPAL":
        client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer()
    else:
        valid_types = [
            "API_KEY",
            "SECURITY_TOKEN",
            "INSTANCE_PRINCIPAL",
            "RESOURCE_PRINCIPAL",
        ]
        raise ValueError(
            f"Invalid OCI auth_type '{auth_type}'. Valid values: {valid_types}"
        )

    return client_kwargs
def _get_oci_genai_api_version() -> str:
    """Return the API version segment used when building request URLs.

    Uses ``GenerativeAiInferenceClient.API_VERSION`` when the SDK is importable
    and exposes that attribute; otherwise falls back to ``"20231130"``.
    """
    try:
        from oci.generative_ai_inference import GenerativeAiInferenceClient

        if hasattr(GenerativeAiInferenceClient, "API_VERSION"):
            return GenerativeAiInferenceClient.API_VERSION
    except ImportError:
        # Deliberate best-effort: the SDK is optional here, use the fallback.
        pass
    return "20231130"


OCI_GENAI_API_VERSION = _get_oci_genai_api_version()


def _snake_to_camel(name: str) -> str:
    """Convert ``snake_case`` to ``camelCase`` (first component is kept as-is)."""
    components = name.split("_")
    return components[0] + "".join(x.title() for x in components[1:])


def _convert_keys_to_camel(obj: Any) -> Any:
    """Recursively convert dict keys from snake_case to camelCase.

    Lists are traversed element by element; any other value is returned
    unchanged.

    NOTE(review): every nested dict key is converted, so free-form payloads
    embedded in the request (for example user-defined schema property names
    containing underscores) would be renamed as well -- confirm that callers
    never pass such data through here.
    """
    if isinstance(obj, dict):
        return {_snake_to_camel(k): _convert_keys_to_camel(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_convert_keys_to_camel(item) for item in obj]
    return obj


class OCIAsyncClient:
    """Async HTTP client for OCI Generative AI services.

    Sends requests with ``aiohttp`` and signs them with an OCI signer. One
    ``aiohttp.ClientSession`` is created lazily and reused until ``close``.
    Usable as an async context manager, which closes the session on exit.
    """

    def __init__(
        self,
        service_endpoint: str,
        signer: Any,
        config: dict[str, Any] | None = None,
    ) -> None:
        """Store connection settings and derive a signer from config if needed.

        Args:
            service_endpoint: Base URL; trailing slashes are removed.
            signer: Callable applied to a prepared ``requests`` request, or
                ``None`` to derive one from ``config``.
            config: OCI config dict used only when ``signer`` is ``None``.

        Raises:
            ValueError: If ``signer`` is ``None`` and building one from
                ``config`` fails.
        """
        self.service_endpoint = service_endpoint.rstrip("/")
        self.signer = signer
        self.config = config or {}
        self._session: aiohttp.ClientSession | None = None
        self._ensure_signer()

    def _ensure_signer(self) -> None:
        """Build ``self.signer`` from ``self.config`` when none was supplied.

        Leaves ``self.signer`` as ``None`` when there is no config either;
        that case is reported by ``_sign_headers`` at request time.
        """
        if self.signer is not None:
            return
        if self.config:
            try:
                from oci.signer import Signer

                self.signer = Signer.from_config(self.config)
            except Exception as e:
                raise ValueError(f"Failed to create OCI signer from config: {e}") from e

    async def _get_session(self) -> aiohttp.ClientSession:
        """Return the shared session, creating it if missing or closed."""
        if self._session is None or self._session.closed:
            # Verify TLS against the certifi CA bundle.
            ssl_context = ssl.create_default_context(cafile=certifi.where())
            connector = aiohttp.TCPConnector(ssl=ssl_context)
            self._session = aiohttp.ClientSession(connector=connector)
        return self._session

    async def close(self) -> None:
        """Close the shared session if it is open."""
        if self._session is not None and not self._session.closed:
            await self._session.close()
            self._session = None

    async def __aenter__(self) -> OCIAsyncClient:
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        await self.close()

    def _sign_headers(
        self,
        method: str,
        url: str,
        body: dict[str, Any] | None = None,
        stream: bool = False,
    ) -> dict[str, str]:
        """Return signed request headers for ``method``/``url``/``body``.

        Args:
            method: HTTP method.
            url: Full request URL.
            body: JSON-serialisable body, if any.
            stream: When true, ``Accept: text/event-stream`` is added after
                signing.

        Raises:
            ValueError: If no signer is available.
        """
        if self.signer is None:
            # Without this guard the call below fails with an opaque
            # "'NoneType' object is not callable".
            raise ValueError(
                "OCI signer is not configured: pass a signer or a non-empty config"
            )
        # NOTE(review): headers are signed over the body as serialised by
        # ``requests``, while ``_arequest`` hands the dict to aiohttp's
        # ``json=`` for a second serialisation; presumably both encoders emit
        # identical bytes -- confirm, or send the prepared body bytes instead.
        req = requests.Request(method, url, json=body)
        prepared = req.prepare()
        signed = self.signer(prepared)
        headers = dict(signed.headers)
        if stream:
            headers["Accept"] = "text/event-stream"
        return headers

    @asynccontextmanager
    async def _arequest(
        self,
        method: str,
        url: str,
        headers: dict[str, str],
        json_body: dict[str, Any] | None = None,
        timeout: int = 300,
    ) -> AsyncGenerator[aiohttp.ClientResponse, None]:
        """Issue one request on the shared session and yield the open response.

        ``timeout`` is applied as the total aiohttp client timeout (seconds).
        """
        session = await self._get_session()
        client_timeout = aiohttp.ClientTimeout(total=timeout)
        async with session.request(
            method, url, headers=headers, json=json_body, timeout=client_timeout
        ) as response:
            yield response

    async def _parse_sse_async(
        self, content: aiohttp.StreamReader
    ) -> AsyncIterator[dict[str, Any]]:
        """Yield the decoded JSON payload of each ``data:`` line in the stream.

        Blank lines, non-``data:`` lines, payloads starting with ``[DONE]`` and
        payloads that are not valid JSON are skipped.
        """
        async for line in content:
            line = line.strip()
            if not line:
                continue
            decoded = line.decode("utf-8")
            if not decoded.lower().startswith("data:"):
                continue
            data = decoded[5:].strip()  # drop the "data:" prefix
            if not data or data.startswith("[DONE]"):
                continue
            try:
                event = json.loads(data)
            except json.JSONDecodeError:
                continue
            yield event

    async def chat_async(
        self,
        compartment_id: str,
        chat_request_dict: dict[str, Any],
        serving_mode_dict: dict[str, Any],
        stream: bool = False,
        timeout: int = 300,
    ) -> AsyncIterator[dict[str, Any]]:
        """Make async chat request to OCI GenAI.

        Yields SSE events for streaming, or a single response dict.

        Args:
            compartment_id: Sent as ``compartmentId``.
            chat_request_dict: snake_case request dict; keys are camelCased.
            serving_mode_dict: snake_case serving-mode dict; keys are camelCased.
            stream: Parse the response as server-sent events when true.
            timeout: Total request timeout in seconds.

        Raises:
            RuntimeError: If the response status is not 200.
            ValueError: If no signer is available.
        """
        url = f"{self.service_endpoint}/{OCI_GENAI_API_VERSION}/actions/chat"

        body = {
            "compartmentId": compartment_id,
            "servingMode": _convert_keys_to_camel(serving_mode_dict),
            "chatRequest": _convert_keys_to_camel(chat_request_dict),
        }

        headers = self._sign_headers("POST", url, body, stream=stream)

        async with self._arequest("POST", url, headers, body, timeout) as response:
            if response.status != 200:
                error_text = await response.text()
                raise RuntimeError(
                    f"OCI GenAI async request failed ({response.status}): {error_text}"
                )

            if stream:
                async for event in self._parse_sse_async(response.content):
                    yield event
            else:
                data = await response.json()
                yield data
def _make_fake_oci_module() -> MagicMock:
    """Return a ``MagicMock`` standing in for the ``oci`` package.

    Model classes under ``generative_ai_inference.models`` become factories
    that return a fresh mock configured with the keyword arguments they were
    called with; auth helpers and the client constructor are plain mocks with
    canned return values where the tests need them.
    """

    def _kwargs_factory() -> MagicMock:
        # Calling the result with keywords yields a mock exposing those
        # keywords as attributes.
        return MagicMock(side_effect=lambda **kw: MagicMock(**kw))

    fake = MagicMock()
    model_ns = fake.generative_ai_inference.models

    # Serving modes, content parts, messages, requests and ChatDetails all
    # share the same "echo kwargs" behaviour.
    echoing_classes = (
        "OnDemandServingMode",
        "DedicatedServingMode",
        "TextContent",
        "UserMessage",
        "AssistantMessage",
        "SystemMessage",
        "CohereUserMessage",
        "CohereSystemMessage",
        "CohereChatBotMessage",
        "GenericChatRequest",
        "CohereChatRequest",
        "ChatDetails",
    )
    for class_name in echoing_classes:
        setattr(model_ns, class_name, _kwargs_factory())

    # API format constants read off the base request class.
    base_request = MagicMock()
    base_request.API_FORMAT_GENERIC = "GENERIC"
    base_request.API_FORMAT_COHERE = "COHERE"
    model_ns.BaseChatRequest = base_request

    # Auth helpers
    fake.config.from_file = MagicMock(
        return_value={"key_file": "/tmp/k", "security_token_file": "/tmp/t"}
    )
    fake.signer.load_private_key_from_file = MagicMock(return_value="pk")
    signers = fake.auth.signers
    signers.SecurityTokenSigner = MagicMock()
    signers.InstancePrincipalsSecurityTokenSigner = MagicMock()
    signers.get_resource_principals_signer = MagicMock()
    fake.retry.DEFAULT_RETRY_STRATEGY = "default_retry"

    # Client constructor
    fake.generative_ai_inference.GenerativeAiInferenceClient = MagicMock()

    return fake
def _make_fake_chat_response(text: str = "Hello from OCI") -> MagicMock:
    """Return a mock SDK response shaped like a generic-model chat reply.

    The reply has one choice whose message holds a single text part, no tool
    calls, finish reason ``"stop"`` and fixed usage counts (10/5/15).
    """
    message = MagicMock(content=[MagicMock(text=text)], tool_calls=None)
    choice = MagicMock(message=message, finish_reason="stop")
    usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15)

    chat_response = MagicMock(choices=[choice], finish_reason=None, usage=usage)

    response = MagicMock()
    response.data.chat_response = chat_response
    return response


def _make_fake_cohere_chat_response(text: str = "Hello from Cohere") -> MagicMock:
    """Return a mock SDK response shaped like a Cohere chat reply.

    The reply carries ``text`` directly, finish reason ``"COMPLETE"``, no tool
    calls and fixed usage counts (8/4/12).
    """
    usage = MagicMock(prompt_tokens=8, completion_tokens=4, total_tokens=12)
    chat_response = MagicMock(
        text=text,
        finish_reason="COMPLETE",
        tool_calls=None,
        usage=usage,
    )

    response = MagicMock()
    response.data.chat_response = chat_response
    return response
def _env_models(env_var: str, fallback_var: str, default: str) -> list[str]:
    """Return model names from ``env_var``, else ``fallback_var``, else ``default``.

    The chosen value is split on commas; entries are stripped and empty
    entries dropped.
    """
    source = os.getenv(env_var) or os.getenv(fallback_var) or default
    stripped = (piece.strip() for piece in source.split(","))
    return [name for name in stripped if name]


def _skip_unless_live_config() -> dict[str, str]:
    """Return constructor kwargs for live tests, or skip the current test.

    Skips when ``OCI_COMPARTMENT_ID`` is unset, or when neither ``OCI_REGION``
    nor ``OCI_SERVICE_ENDPOINT`` is set. Optional endpoint and auth settings
    are included only when their environment variable is non-empty.
    """
    compartment = os.getenv("OCI_COMPARTMENT_ID")
    if not compartment:
        pytest.skip("OCI_COMPARTMENT_ID not set — skipping live test")

    region = os.getenv("OCI_REGION")
    endpoint = os.getenv("OCI_SERVICE_ENDPOINT")
    if not (region or endpoint):
        pytest.skip("Set OCI_REGION or OCI_SERVICE_ENDPOINT for live tests")

    settings: dict[str, str] = {"compartment_id": compartment}
    optional = (
        ("service_endpoint", "OCI_SERVICE_ENDPOINT"),
        ("auth_type", "OCI_AUTH_TYPE"),
        ("auth_profile", "OCI_AUTH_PROFILE"),
        ("auth_file_location", "OCI_AUTH_FILE_LOCATION"),
    )
    for key, env_name in optional:
        value = os.getenv(env_name)
        if value:
            settings[key] = value
    return settings
def test_oci_completion_is_used_when_oci_provider(patch_oci_module):
    """An ``oci/``-prefixed model passed to ``LLM`` yields an ``OCICompletion``."""
    from crewai.llm import LLM
    from crewai.llms.providers.oci.completion import OCICompletion

    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value = MagicMock()

    instance = LLM(
        model="oci/meta.llama-3.3-70b-instruct",
        compartment_id="ocid1.compartment.oc1..test",
    )

    # LLM.__new__ returns the native provider instance directly
    assert isinstance(instance, OCICompletion)


@pytest.mark.parametrize(
    "model_id, expected_provider",
    [
        ("meta.llama-3.3-70b-instruct", "generic"),
        ("google.gemini-2.5-flash", "generic"),
        ("openai.gpt-4o", "generic"),
        ("xai.grok-3", "generic"),
        ("cohere.command-r-plus-08-2024", "cohere"),
    ],
)
def test_oci_completion_infers_provider_family(
    patch_oci_module, oci_unit_values, model_id, expected_provider
):
    """Only ``cohere.*`` model ids map to the ``cohere`` family in this table."""
    from crewai.llms.providers.oci.completion import OCICompletion

    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value = MagicMock()

    completion = OCICompletion(
        model=model_id,
        compartment_id=oci_unit_values["compartment_id"],
    )

    assert completion.oci_provider == expected_provider
def test_oci_completion_uses_region_to_build_endpoint(patch_oci_module, oci_unit_values, monkeypatch):
    """With only ``OCI_REGION`` set, the service endpoint contains the region."""
    from crewai.llms.providers.oci.completion import OCICompletion

    # Make sure an explicit endpoint from the environment cannot win.
    monkeypatch.delenv("OCI_SERVICE_ENDPOINT", raising=False)
    monkeypatch.setenv("OCI_REGION", "us-ashburn-1")
    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value = MagicMock()

    completion = OCICompletion(
        model=oci_unit_values["model"],
        compartment_id=oci_unit_values["compartment_id"],
    )

    assert "us-ashburn-1" in completion.service_endpoint


# ---------------------------------------------------------------------------
# Basic call
# ---------------------------------------------------------------------------


def test_oci_completion_call_uses_chat_api(
    patch_oci_module, oci_response_factories, oci_unit_values
):
    """``call`` returns the mocked reply text and invokes ``client.chat`` once."""
    from crewai.llms.providers.oci.completion import OCICompletion

    sdk_client = MagicMock()
    sdk_client.chat.return_value = oci_response_factories["chat"]("test response")
    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value = sdk_client

    completion = OCICompletion(
        model=oci_unit_values["model"],
        compartment_id=oci_unit_values["compartment_id"],
    )
    reply = completion.call(messages=[{"role": "user", "content": "Say hello"}])

    assert "test response" in reply
    sdk_client.chat.assert_called_once()
def test_oci_completion_treats_none_content_as_empty_text(
    patch_oci_module, oci_unit_values
):
    """``_coerce_text(None)`` returns an empty string."""
    from crewai.llms.providers.oci.completion import OCICompletion

    client_cls = patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient
    client_cls.return_value = MagicMock()

    completion = OCICompletion(
        model=oci_unit_values["model"],
        compartment_id=oci_unit_values["compartment_id"],
    )

    assert completion._coerce_text(None) == ""
def test_oci_openai_models_use_max_completion_tokens(
    patch_oci_module, oci_response_factories, oci_unit_values
):
    """OpenAI-family requests carry ``max_completion_tokens``, not ``max_tokens``."""
    from crewai.llms.providers.oci.completion import OCICompletion

    fake_client = MagicMock()
    fake_client.chat.return_value = oci_response_factories["chat"]()
    patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = (
        fake_client
    )

    llm = OCICompletion(
        model="openai.gpt-4o",
        compartment_id=oci_unit_values["compartment_id"],
        max_tokens=1024,
    )
    # Only the constructor call recorded on the fake request class matters;
    # the built request object itself is not inspected.
    llm._build_chat_request([{"role": "user", "content": "test"}])

    request_cls = patch_oci_module.generative_ai_inference.models.GenericChatRequest
    assert request_cls.call_args is not None
    kwargs = request_cls.call_args.kwargs
    assert kwargs.get("max_completion_tokens") == 1024
    assert "max_tokens" not in kwargs
def _skip_unless_live() -> dict[str, str]:
    """Return constructor kwargs for live async tests, or skip the test.

    Both ``OCI_COMPARTMENT_ID`` and ``OCI_REGION`` must be set. ``auth_type``
    and ``auth_profile`` are included only when their variable is non-empty.
    """
    compartment = os.getenv("OCI_COMPARTMENT_ID")
    if not compartment:
        pytest.skip("OCI_COMPARTMENT_ID not set")
    if not os.getenv("OCI_REGION"):
        pytest.skip("OCI_REGION not set")

    settings: dict[str, str] = {"compartment_id": compartment}
    for key, env_name in (
        ("auth_type", "OCI_AUTH_TYPE"),
        ("auth_profile", "OCI_AUTH_PROFILE"),
    ):
        value = os.getenv(env_name)
        if value:
            settings[key] = value
    return settings
@pytest.mark.asyncio
async def test_oci_true_async_acall(oci_async_config: dict):
    """``acall`` against a live model returns a non-empty string."""
    completion = OCICompletion(
        model="meta.llama-3.3-70b-instruct",
        **oci_async_config,
    )
    prompt = [{"role": "user", "content": "Say hello in one word."}]

    reply = await completion.acall(messages=prompt)

    assert isinstance(reply, str)
    assert len(reply) > 0


@pytest.mark.asyncio
async def test_oci_true_async_astream(oci_async_config: dict):
    """``astream`` against a live model yields chunks forming non-empty text."""
    completion = OCICompletion(
        model="meta.llama-3.3-70b-instruct",
        **oci_async_config,
    )
    prompt = [{"role": "user", "content": "Count to 3."}]

    pieces: list[str] = []
    async for piece in completion.astream(messages=prompt):
        pieces.append(piece)

    assert len(pieces) > 0
    assert len("".join(pieces)) > 0


@pytest.mark.asyncio
async def test_oci_true_async_concurrent_calls(oci_async_config: dict):
    """Two ``acall`` coroutines gathered together both return non-empty text."""
    import asyncio

    completion = OCICompletion(
        model="meta.llama-3.3-70b-instruct",
        **oci_async_config,
    )
    pending = [
        completion.acall(messages=[{"role": "user", "content": "Say 'one'"}]),
        completion.acall(messages=[{"role": "user", "content": "Say 'two'"}]),
    ]

    replies = await asyncio.gather(*pending)

    assert len(replies) == 2
    assert all(isinstance(r, str) and len(r) > 0 for r in replies)
def test_oci_live_basic_call(oci_chat_model: str, oci_live_config: dict):
    """A synchronous ``call`` against a live model returns non-empty text."""
    completion = OCICompletion(model=oci_chat_model, **oci_live_config)
    prompt = [{"role": "user", "content": "Say 'hello world' in one sentence."}]

    reply = completion.call(messages=prompt)

    assert isinstance(reply, str)
    assert len(reply) > 0


@pytest.mark.asyncio
async def test_oci_live_async_call(oci_chat_model: str, oci_live_config: dict):
    """An awaited ``acall`` against a live model returns non-empty text."""
    completion = OCICompletion(model=oci_chat_model, **oci_live_config)
    prompt = [{"role": "user", "content": "What is 2+2? Answer in one word."}]

    reply = await completion.acall(messages=prompt)

    assert isinstance(reply, str)
    assert len(reply) > 0
def _make_red_png() -> bytes:
    """Build a 2x2 8-bit RGB PNG of pure red pixels entirely in memory."""
    side = 2
    # One scanline: filter byte 0 ("none") followed by `side` red RGB pixels.
    scanline = b"\x00" + b"\xff\x00\x00" * side
    pixel_rows = scanline * side

    def _chunk(tag: bytes, payload: bytes) -> bytes:
        # Chunk layout: big-endian payload length, tag + payload, CRC32 of
        # tag + payload.
        tagged = tag + payload
        checksum = zlib.crc32(tagged) & 0xFFFFFFFF
        return struct.pack(">I", len(payload)) + tagged + struct.pack(">I", checksum)

    # IHDR fields: width, height, bit depth 8, colour type 2, then three zeros.
    header = struct.pack(">IIBBBBB", side, side, 8, 2, 0, 0, 0)
    return b"".join(
        (
            b"\x89PNG\r\n\x1a\n",
            _chunk(b"IHDR", header),
            _chunk(b"IDAT", zlib.compress(pixel_rows)),
            _chunk(b"IEND", b""),
        )
    )
def test_oci_live_image_input(oci_multimodal_model: str, oci_multimodal_config: dict):
    """A live multimodal model names the colour of an image sent as a data URI."""
    completion = OCICompletion(model=oci_multimodal_model, **oci_multimodal_config)

    question = {"type": "text", "text": "What color is this image? Answer in one word."}
    picture = {"type": "image_url", "image_url": {"url": _png_data_uri()}}
    user_turn = {"role": "user", "content": [question, picture]}

    reply = completion.call(messages=[user_turn])

    assert isinstance(reply, str)
    assert len(reply) > 0
    assert "red" in reply.lower()
def test_oci_live_streaming_call(oci_chat_model: str, oci_live_config: dict):
    """``call`` on a ``stream=True`` completion returns non-empty text."""
    completion = OCICompletion(model=oci_chat_model, stream=True, **oci_live_config)
    prompt = [{"role": "user", "content": "Count from 1 to 5, one per line."}]

    reply = completion.call(messages=prompt)

    assert isinstance(reply, str)
    assert len(reply) > 0


@pytest.mark.asyncio
async def test_oci_live_astream(oci_chat_model: str, oci_live_config: dict):
    """``astream`` on a live model yields chunks that join to non-empty text."""
    completion = OCICompletion(model=oci_chat_model, **oci_live_config)
    prompt = [{"role": "user", "content": "Say hello in three words."}]

    pieces: list[str] = []
    async for piece in completion.astream(messages=prompt):
        pieces.append(piece)

    assert len(pieces) > 0
    assert len("".join(pieces)) > 0
+ +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + OCI_TEST_MODELS="meta.llama-3.3-70b-instruct" \ + uv run pytest tests/llms/oci/test_oci_integration_structured.py -v +""" + +from __future__ import annotations + +from pydantic import BaseModel + +from crewai.llms.providers.oci.completion import OCICompletion + + +class CapitalResponse(BaseModel): + """Response containing a country's capital city.""" + country: str + capital: str + + +def test_oci_live_structured_output(oci_chat_model: str, oci_live_config: dict): + """Structured output should return a validated Pydantic model.""" + llm = OCICompletion(model=oci_chat_model, **oci_live_config) + result = llm.call( + messages=[{"role": "user", "content": "What is the capital of France? Answer with country and capital fields."}], + response_model=CapitalResponse, + ) + + assert isinstance(result, CapitalResponse) + assert result.capital.lower() == "paris" + assert result.country.lower() == "france" diff --git a/lib/crewai/tests/llms/oci/test_oci_integration_tools.py b/lib/crewai/tests/llms/oci/test_oci_integration_tools.py new file mode 100644 index 0000000000..305dbc7cbb --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_integration_tools.py @@ -0,0 +1,100 @@ +"""Live integration tests for OCI Generative AI tool calling. 
+ +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + OCI_TEST_TOOL_MODELS="meta.llama-3.3-70b-instruct" \ + uv run pytest tests/llms/oci/test_oci_integration_tools.py -v +""" + +from __future__ import annotations + +import os + +import pytest + +from crewai.llms.providers.oci.completion import OCICompletion + + +def _env_models(env_var: str, fallback: str, default: str) -> list[str]: + raw = os.getenv(env_var) or os.getenv(fallback) or default + return [m.strip() for m in raw.split(",") if m.strip()] + + +def _skip_unless_live(): + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set") + region = os.getenv("OCI_REGION") + endpoint = os.getenv("OCI_SERVICE_ENDPOINT") + if not region and not endpoint: + pytest.skip("Set OCI_REGION or OCI_SERVICE_ENDPOINT") + config: dict[str, str] = {"compartment_id": compartment} + if endpoint: + config["service_endpoint"] = endpoint + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + return config + + +TOOL_SPEC = [ + { + "type": "function", + "function": { + "name": "add_numbers", + "description": "Add two numbers together", + "parameters": { + "type": "object", + "properties": { + "a": {"type": "number", "description": "First number"}, + "b": {"type": "number", "description": "Second number"}, + }, + "required": ["a", "b"], + }, + }, + } +] + + +@pytest.fixture( + params=_env_models("OCI_TEST_TOOL_MODELS", "OCI_TEST_TOOL_MODEL", "meta.llama-3.3-70b-instruct"), + ids=lambda m: m, +) +def oci_tool_model(request): + return request.param + + +@pytest.fixture() +def oci_tool_config(): + return _skip_unless_live() + + +def test_oci_live_tool_call_returns_raw(oci_tool_model: str, oci_tool_config: dict): + """Without available_functions, tool calls should be returned 
raw.""" + llm = OCICompletion(model=oci_tool_model, **oci_tool_config) + result = llm.call( + messages=[{"role": "user", "content": "What is 3 + 7? Use the add_numbers tool."}], + tools=TOOL_SPEC, + ) + + assert isinstance(result, list) + assert len(result) >= 1 + assert result[0]["function"]["name"] == "add_numbers" + + +def test_oci_live_tool_call_with_execution(oci_tool_model: str, oci_tool_config: dict): + """With available_functions, tools should execute and model should respond.""" + def add_numbers(a: float, b: float) -> str: + return str(float(a) + float(b)) + + llm = OCICompletion(model=oci_tool_model, **oci_tool_config) + result = llm.call( + messages=[{"role": "user", "content": "What is 3 + 7? Use the add_numbers tool."}], + tools=TOOL_SPEC, + available_functions={"add_numbers": add_numbers}, + ) + + assert isinstance(result, str) + assert "10" in result diff --git a/lib/crewai/tests/llms/oci/test_oci_multimodal.py b/lib/crewai/tests/llms/oci/test_oci_multimodal.py new file mode 100644 index 0000000000..b27d77e645 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_multimodal.py @@ -0,0 +1,195 @@ +"""Unit tests for OCI provider multimodal content (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +def test_oci_builds_image_content(patch_oci_module, oci_unit_values): + """image_url content should produce ImageContent with ImageUrl.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + # Mock multimodal content types + patch_oci_module.generative_ai_inference.models.ImageContent = MagicMock( + side_effect=lambda **kw: MagicMock(type="image", **kw) + ) + patch_oci_module.generative_ai_inference.models.ImageUrl = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + 
compartment_id=oci_unit_values["compartment_id"], + ) + content = [ + {"type": "text", "text": "What is in this image?"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc123"}}, + ] + result = llm._build_generic_content(content) + + assert len(result) == 2 + patch_oci_module.generative_ai_inference.models.ImageContent.assert_called_once() + patch_oci_module.generative_ai_inference.models.ImageUrl.assert_called_once() + + +def test_oci_builds_document_content(patch_oci_module, oci_unit_values): + """document_url content should produce DocumentContent.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + patch_oci_module.generative_ai_inference.models.DocumentContent = MagicMock( + side_effect=lambda **kw: MagicMock(type="document", **kw) + ) + patch_oci_module.generative_ai_inference.models.DocumentUrl = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + content = [{"type": "document_url", "document_url": {"url": "data:application/pdf;base64,xyz"}}] + result = llm._build_generic_content(content) + + assert len(result) == 1 + patch_oci_module.generative_ai_inference.models.DocumentContent.assert_called_once() + + +def test_oci_builds_video_content(patch_oci_module, oci_unit_values): + """video_url content should produce VideoContent.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + patch_oci_module.generative_ai_inference.models.VideoContent = MagicMock( + side_effect=lambda **kw: MagicMock(type="video", **kw) + ) + patch_oci_module.generative_ai_inference.models.VideoUrl = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], 
+ compartment_id=oci_unit_values["compartment_id"], + ) + content = [{"type": "video_url", "video_url": {"url": "https://example.com/video.mp4"}}] + result = llm._build_generic_content(content) + + assert len(result) == 1 + patch_oci_module.generative_ai_inference.models.VideoContent.assert_called_once() + + +def test_oci_builds_audio_content(patch_oci_module, oci_unit_values): + """audio_url content should produce AudioContent.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + patch_oci_module.generative_ai_inference.models.AudioContent = MagicMock( + side_effect=lambda **kw: MagicMock(type="audio", **kw) + ) + patch_oci_module.generative_ai_inference.models.AudioUrl = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + content = [{"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,wav123"}}] + result = llm._build_generic_content(content) + + assert len(result) == 1 + patch_oci_module.generative_ai_inference.models.AudioContent.assert_called_once() + + +def test_oci_rejects_unsupported_content_type(patch_oci_module, oci_unit_values): + """Unknown content types should raise ValueError.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + with pytest.raises(ValueError, match="Unsupported OCI content type"): + llm._build_generic_content([{"type": "hologram", "data": "xyz"}]) + + +def test_oci_cohere_rejects_multimodal(patch_oci_module, oci_unit_values): + """Cohere models should reject multimodal content in _build_chat_request.""" + from crewai.llms.providers.oci.completion import 
OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + messages = [ + {"role": "user", "content": [ + {"type": "text", "text": "Describe this"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,abc"}}, + ]} + ] + + with pytest.raises(ValueError, match="text-only"): + llm._build_chat_request(messages) + + +def test_oci_message_has_multimodal_content(patch_oci_module, oci_unit_values): + """_message_has_multimodal_content should detect non-text content types.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + assert llm._message_has_multimodal_content("just text") is False + assert llm._message_has_multimodal_content([{"type": "text", "text": "hi"}]) is False + assert llm._message_has_multimodal_content([{"type": "image_url", "image_url": {"url": "x"}}]) is True + assert llm._message_has_multimodal_content([{"type": "text"}, {"type": "audio_url"}]) is True + + +def test_oci_supports_multimodal(patch_oci_module, oci_unit_values): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm.supports_multimodal() is True + + +def test_vision_helpers(): + """Test vision.py utility functions.""" + from crewai.llms.providers.oci.vision import ( + VISION_MODELS, + encode_image, + is_vision_model, + to_data_uri, + ) + + # to_data_uri with bytes + uri = to_data_uri(b"\x89PNG", "image/png") + assert 
uri.startswith("data:image/png;base64,") + + # to_data_uri passthrough + existing = "data:image/jpeg;base64,abc" + assert to_data_uri(existing) == existing + + # encode_image + result = encode_image(b"\x89PNG") + assert result["type"] == "image_url" + assert result["image_url"]["url"].startswith("data:image/png;base64,") + + # is_vision_model + assert is_vision_model("google.gemini-2.5-flash") is True + assert is_vision_model("meta.llama-3.3-70b-instruct") is False + + assert len(VISION_MODELS) > 0 diff --git a/lib/crewai/tests/llms/oci/test_oci_streaming.py b/lib/crewai/tests/llms/oci/test_oci_streaming.py new file mode 100644 index 0000000000..721fa3f671 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_streaming.py @@ -0,0 +1,150 @@ +"""Unit tests for OCI provider streaming (mocked SDK).""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + + +def _make_fake_stream_event(text: str = "", finish_reason: str | None = None, usage: dict | None = None) -> MagicMock: + """Build a single SSE event with optional text, finish, and usage.""" + payload: dict = {} + if text: + payload["message"] = {"content": [{"text": text}]} + if finish_reason: + payload["finishReason"] = finish_reason + if usage: + payload["usage"] = usage + + import json + event = MagicMock() + event.data = json.dumps(payload) + return event + + +def _make_fake_stream_response(*events: MagicMock) -> MagicMock: + """Wrap events into a response.data.events() iterable.""" + response = MagicMock() + response.data.events.return_value = iter(events) + return response + + +def test_oci_completion_streams_generic_responses( + patch_oci_module, oci_unit_values +): + """Streaming call should accumulate text chunks and return full response.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="Hello "), + _make_fake_stream_event(text="world"), + _make_fake_stream_event( + finish_reason="stop", + 
usage={"promptTokens": 5, "completionTokens": 2, "totalTokens": 7}, + ), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + # StreamOptions mock + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + stream=True, + ) + result = llm.call(messages=[{"role": "user", "content": "Say hello"}]) + + assert "Hello " in result + assert "world" in result + assert llm.last_response_metadata is not None + assert llm.last_response_metadata.get("finish_reason") == "stop" + + +def test_oci_iter_stream_yields_text_chunks( + patch_oci_module, oci_unit_values +): + """iter_stream should yield individual text chunks.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="chunk1"), + _make_fake_stream_event(text="chunk2"), + _make_fake_stream_event( + usage={"promptTokens": 3, "completionTokens": 2, "totalTokens": 5}, + ), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + chunks = list(llm.iter_stream(messages=[{"role": "user", "content": "test"}])) + + assert chunks == ["chunk1", "chunk2"] + assert llm.last_response_metadata is not None + assert llm.last_response_metadata["usage"]["total_tokens"] == 5 + + +@pytest.mark.asyncio +async def test_oci_astream_yields_text_chunks( + patch_oci_module, oci_unit_values +): + 
"""astream should yield chunks via async generator.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [ + _make_fake_stream_event(text="async1"), + _make_fake_stream_event(text="async2"), + _make_fake_stream_event(), + ] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.StreamOptions = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + chunks = [] + async for chunk in llm.astream(messages=[{"role": "user", "content": "test"}]): + chunks.append(chunk) + + assert chunks == ["async1", "async2"] + + +def test_oci_stream_chat_events_holds_client_lock( + patch_oci_module, oci_unit_values +): + """_stream_chat_events should hold the client lock for the full iteration.""" + from crewai.llms.providers.oci.completion import OCICompletion + + events = [_make_fake_stream_event(text="a"), _make_fake_stream_event(text="b")] + fake_client = MagicMock() + fake_client.chat.return_value = _make_fake_stream_response(*events) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + + # Before streaming, ticket should be 0 + assert llm._active_client_ticket == 0 + chat_details = MagicMock() + list(llm._stream_chat_events(chat_details)) + # After streaming completes, ticket should have advanced + assert llm._active_client_ticket == 1 diff --git a/lib/crewai/tests/llms/oci/test_oci_structured.py b/lib/crewai/tests/llms/oci/test_oci_structured.py new file mode 100644 index 0000000000..caa4f330bb --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_structured.py @@ -0,0 +1,204 @@ +"""Unit 
tests for OCI provider structured output (mocked SDK).""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel + + +class WeatherResponse(BaseModel): + """Weather forecast response.""" + city: str + temperature: float + unit: str + + +def _make_json_response(data: dict) -> MagicMock: + """Build a fake OCI response returning JSON text.""" + text_part = MagicMock() + text_part.text = json.dumps(data) + + message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + chat_response.usage = MagicMock(prompt_tokens=10, completion_tokens=8, total_tokens=18) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_cohere_json_response(data: dict) -> MagicMock: + """Build a fake OCI Cohere response returning JSON text.""" + chat_response = MagicMock() + chat_response.text = json.dumps(data) + chat_response.tool_calls = None + chat_response.finish_reason = "COMPLETE" + chat_response.usage = MagicMock(prompt_tokens=8, completion_tokens=6, total_tokens=14) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def test_oci_completion_structured_output_generic( + patch_oci_module, oci_unit_values +): + """response_model should parse JSON response into a Pydantic model.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_json_response( + {"city": "NYC", "temperature": 72.0, "unit": "F"} + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + # Mock schema-related models + patch_oci_module.generative_ai_inference.models.ResponseJsonSchema = MagicMock( + 
side_effect=lambda **kw: MagicMock(**kw) + ) + patch_oci_module.generative_ai_inference.models.JsonSchemaResponseFormat = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather in NYC?"}], + response_model=WeatherResponse, + ) + + assert isinstance(result, WeatherResponse) + assert result.city == "NYC" + assert result.temperature == 72.0 + assert result.unit == "F" + + +def test_oci_completion_structured_output_cohere( + patch_oci_module, oci_unit_values +): + """Cohere models should use CohereResponseJsonFormat for structured output.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_cohere_json_response( + {"city": "London", "temperature": 15.0, "unit": "C"} + ) + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.ResponseJsonSchema = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + patch_oci_module.generative_ai_inference.models.CohereResponseJsonFormat = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather in London?"}], + response_model=WeatherResponse, + ) + + assert isinstance(result, WeatherResponse) + assert result.city == "London" + + +def test_oci_completion_structured_output_with_fenced_json( + patch_oci_module, oci_unit_values +): + """Should handle JSON wrapped in markdown fences.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fenced = '```json\n{"city": "Tokyo", "temperature": 25.0, "unit": "C"}\n```' + text_part = MagicMock() + text_part.text = fenced + + 
message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + chat_response.usage = MagicMock(prompt_tokens=10, completion_tokens=8, total_tokens=18) + + response = MagicMock() + response.data.chat_response = chat_response + + fake_client = MagicMock() + fake_client.chat.return_value = response + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + patch_oci_module.generative_ai_inference.models.ResponseJsonSchema = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + patch_oci_module.generative_ai_inference.models.JsonSchemaResponseFormat = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather in Tokyo?"}], + response_model=WeatherResponse, + ) + + assert isinstance(result, WeatherResponse) + assert result.city == "Tokyo" + + +def test_oci_build_response_format_returns_none_without_model( + patch_oci_module, oci_unit_values +): + """_build_response_format should return None when no response_model.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm._build_response_format(None) is None + + +def test_oci_build_response_format_creates_json_schema( + patch_oci_module, oci_unit_values +): + """_build_response_format should create a JsonSchemaResponseFormat for generic models.""" + from crewai.llms.providers.oci.completion import OCICompletion + + 
patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + patch_oci_module.generative_ai_inference.models.ResponseJsonSchema = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + patch_oci_module.generative_ai_inference.models.JsonSchemaResponseFormat = MagicMock( + side_effect=lambda **kw: MagicMock(**kw) + ) + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm._build_response_format(WeatherResponse) + + assert result is not None + patch_oci_module.generative_ai_inference.models.JsonSchemaResponseFormat.assert_called_once() diff --git a/lib/crewai/tests/llms/oci/test_oci_tools.py b/lib/crewai/tests/llms/oci/test_oci_tools.py new file mode 100644 index 0000000000..bfe08eb912 --- /dev/null +++ b/lib/crewai/tests/llms/oci/test_oci_tools.py @@ -0,0 +1,291 @@ +"""Unit tests for OCI provider tool calling (mocked SDK).""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +import pytest + + +def _make_tool_call_response(tool_name: str = "get_weather", args: dict | None = None) -> MagicMock: + """Build a fake OCI response with a generic tool call.""" + tc = MagicMock() + tc.id = "tc_001" + tc.name = tool_name + tc.arguments = json.dumps(args or {"city": "NYC"}) + + message = MagicMock() + message.content = [MagicMock(text="")] + message.tool_calls = [tc] + + choice = MagicMock() + choice.message = message + choice.finish_reason = "tool_calls" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + chat_response.usage = MagicMock(prompt_tokens=10, completion_tokens=5, total_tokens=15) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_text_response(text: str = "The weather is sunny.") -> MagicMock: + """Build a fake OCI text response (used after tool execution).""" + text_part = MagicMock() + text_part.text 
= text + + message = MagicMock() + message.content = [text_part] + message.tool_calls = None + + choice = MagicMock() + choice.message = message + choice.finish_reason = "stop" + + chat_response = MagicMock() + chat_response.choices = [choice] + chat_response.finish_reason = None + chat_response.usage = MagicMock(prompt_tokens=20, completion_tokens=10, total_tokens=30) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +def _make_cohere_tool_call_response(tool_name: str = "get_weather") -> MagicMock: + """Build a fake OCI Cohere response with a tool call.""" + tc = MagicMock() + tc.name = tool_name + tc.parameters = {"city": "NYC"} + + chat_response = MagicMock() + chat_response.text = "" + chat_response.tool_calls = [tc] + chat_response.finish_reason = "COMPLETE" + chat_response.usage = MagicMock(prompt_tokens=8, completion_tokens=4, total_tokens=12) + + response = MagicMock() + response.data.chat_response = chat_response + return response + + +SAMPLE_TOOLS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the weather for a city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string", "description": "The city name"}}, + "required": ["city"], + }, + }, + } +] + + +def test_oci_completion_returns_tool_calls_for_executor( + patch_oci_module, oci_unit_values +): + """When no available_functions, tool calls should be returned raw.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_tool_call_response() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "What is the weather?"}], + tools=SAMPLE_TOOLS, + ) + + assert isinstance(result, list) + assert 
len(result) == 1 + assert result[0]["function"]["name"] == "get_weather" + + +def test_oci_completion_executes_tool_calls_recursively( + patch_oci_module, oci_unit_values +): + """With available_functions, tool should be executed and model re-called.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + # First call returns tool call, second call returns text + fake_client.chat.side_effect = [ + _make_tool_call_response(), + _make_text_response("It is sunny in NYC."), + ] + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + def mock_get_weather(city: str) -> str: + return f"Sunny in {city}" + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather in NYC?"}], + tools=SAMPLE_TOOLS, + available_functions={"get_weather": mock_get_weather}, + ) + + assert isinstance(result, str) + assert "sunny" in result.lower() or "NYC" in result + assert fake_client.chat.call_count == 2 + + +def test_oci_completion_formats_generic_tools( + patch_oci_module, oci_unit_values +): + """_format_tools should produce FunctionDefinition for generic models.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + formatted = llm._format_tools(SAMPLE_TOOLS) + + assert len(formatted) == 1 + models = patch_oci_module.generative_ai_inference.models + models.FunctionDefinition.assert_called_once() + + +def test_oci_completion_formats_cohere_tools( + patch_oci_module, oci_unit_values +): + """_format_tools should produce CohereTool for Cohere models.""" + from crewai.llms.providers.oci.completion import OCICompletion + + 
patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + formatted = llm._format_tools(SAMPLE_TOOLS) + + assert len(formatted) == 1 + models = patch_oci_module.generative_ai_inference.models + models.CohereTool.assert_called_once() + + +def test_oci_completion_cohere_extracts_tool_calls( + patch_oci_module, oci_unit_values +): + """Cohere tool calls should be normalized to CrewAI shape with generated IDs.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_cohere_tool_call_response() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + ) + result = llm.call( + messages=[{"role": "user", "content": "Weather?"}], + tools=SAMPLE_TOOLS, + ) + + assert isinstance(result, list) + assert result[0]["function"]["name"] == "get_weather" + assert result[0]["id"] # Should have a generated UUID + + +def test_oci_completion_rejects_parallel_tools_for_cohere( + patch_oci_module, oci_unit_values +): + """Cohere models should raise if parallel_tool_calls is enabled.""" + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["cohere_model"], + compartment_id=oci_unit_values["compartment_id"], + parallel_tool_calls=True, + ) + + with pytest.raises(ValueError, match="parallel_tool_calls"): + llm._build_chat_request( + [{"role": "user", "content": "test"}], + tools=SAMPLE_TOOLS, + ) + + +def test_oci_completion_respects_max_sequential_tool_calls( + patch_oci_module, oci_unit_values +): + """Should raise RuntimeError when tool 
depth exceeds max.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + # Always return tool calls to force recursion + fake_client.chat.return_value = _make_tool_call_response() + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + + def mock_tool(city: str) -> str: + return "result" + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + max_sequential_tool_calls=2, + ) + + with pytest.raises(RuntimeError, match="max_sequential_tool_calls"): + llm.call( + messages=[{"role": "user", "content": "test"}], + tools=SAMPLE_TOOLS, + available_functions={"get_weather": mock_tool}, + ) + + +def test_oci_completion_supports_function_calling( + patch_oci_module, oci_unit_values +): + from crewai.llms.providers.oci.completion import OCICompletion + + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = MagicMock() + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + ) + assert llm.supports_function_calling() is True + + +def test_oci_completion_filters_unknown_passthrough_params( + patch_oci_module, oci_unit_values +): + """Unknown additional_params should not crash the OCI SDK request.""" + from crewai.llms.providers.oci.completion import OCICompletion + + fake_client = MagicMock() + fake_client.chat.return_value = _make_text_response("ok") + patch_oci_module.generative_ai_inference.GenerativeAiInferenceClient.return_value = fake_client + # Make GenericChatRequest only accept known keys + patch_oci_module.generative_ai_inference.models.GenericChatRequest.attribute_map = { + "messages": "messages", + "api_format": "apiFormat", + } + + llm = OCICompletion( + model=oci_unit_values["model"], + compartment_id=oci_unit_values["compartment_id"], + additional_params={"bogus_param": "should_be_filtered"}, + ) + # Should not raise + result = 
llm.call(messages=[{"role": "user", "content": "test"}]) + assert isinstance(result, str) diff --git a/lib/crewai/tests/rag/embeddings/test_factory_oci.py b/lib/crewai/tests/rag/embeddings/test_factory_oci.py new file mode 100644 index 0000000000..9cc125b02b --- /dev/null +++ b/lib/crewai/tests/rag/embeddings/test_factory_oci.py @@ -0,0 +1,207 @@ +"""Tests for OCI embedding provider wiring.""" + +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from crewai.rag.embeddings.factory import build_embedder +from crewai.rag.embeddings.providers.oci.embedding_callable import OCIEmbeddingFunction + + +class _FakeOCI: + def __init__(self) -> None: + self.retry = SimpleNamespace(DEFAULT_RETRY_STRATEGY="retry") + self.config = SimpleNamespace( + from_file=lambda file_location, profile_name: { + "file_location": file_location, + "profile_name": profile_name, + } + ) + self.signer = SimpleNamespace( + load_private_key_from_file=lambda *_args, **_kwargs: "private-key" + ) + self.auth = SimpleNamespace( + signers=SimpleNamespace( + SecurityTokenSigner=lambda token, key: (token, key), + InstancePrincipalsSecurityTokenSigner=lambda: "instance-principal", + get_resource_principals_signer=lambda: "resource-principal", + ) + ) + self.generative_ai_inference = SimpleNamespace( + GenerativeAiInferenceClient=MagicMock(), + models=SimpleNamespace( + EmbedTextDetails=_simple_init_class("EmbedTextDetails"), + OnDemandServingMode=_simple_init_class("OnDemandServingMode"), + DedicatedServingMode=_simple_init_class("DedicatedServingMode"), + ), + ) + + +def _simple_init_class(name: str): + class _Simple: + output_dimensions = None + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + _Simple.__name__ = name + return _Simple + + +@patch("crewai.rag.embeddings.factory.import_and_validate_definition") +def test_build_embedder_oci(mock_import): + """Test building OCI embedder.""" + mock_provider_class = 
MagicMock() + mock_provider_instance = MagicMock() + mock_embedding_function = MagicMock() + + mock_import.return_value = mock_provider_class + mock_provider_class.return_value = mock_provider_instance + mock_provider_instance.embedding_callable.return_value = mock_embedding_function + + config = { + "provider": "oci", + "config": { + "model_name": "cohere.embed-english-v3.0", + "compartment_id": "ocid1.compartment.oc1..test", + "region": "us-chicago-1", + "auth_profile": "DEFAULT", + }, + } + + build_embedder(config) + + mock_import.assert_called_once_with( + "crewai.rag.embeddings.providers.oci.oci_provider.OCIProvider" + ) + call_kwargs = mock_provider_class.call_args.kwargs + assert call_kwargs["model_name"] == "cohere.embed-english-v3.0" + assert call_kwargs["compartment_id"] == "ocid1.compartment.oc1..test" + assert call_kwargs["region"] == "us-chicago-1" + + +def test_oci_embedding_function_batches_requests(monkeypatch): + """Test OCI embedding batching and request construction.""" + fake_oci = _FakeOCI() + fake_client = MagicMock() + fake_client.embed_text.side_effect = [ + SimpleNamespace(data=SimpleNamespace(embeddings=[[0.1, 0.2], [0.3, 0.4]])), + SimpleNamespace(data=SimpleNamespace(embeddings=[[0.5, 0.6]])), + ] + fake_oci.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + monkeypatch.setattr( + "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module", + lambda: fake_oci, + ) + + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-english-v3.0", + compartment_id="ocid1.compartment.oc1..test", + region="us-chicago-1", + batch_size=2, + ) + + result = embedder(["a", "b", "c"]) + + result_rows = [embedding.tolist() for embedding in result] + expected_rows = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]] + assert len(result_rows) == len(expected_rows) + for actual, expected in zip(result_rows, expected_rows, strict=True): + assert actual == pytest.approx(expected) + assert 
fake_client.embed_text.call_count == 2 + first_request = fake_client.embed_text.call_args_list[0].args[0] + assert first_request.compartment_id == "ocid1.compartment.oc1..test" + assert first_request.serving_mode.model_id == "cohere.embed-english-v3.0" + + +def test_oci_embedding_function_supports_output_dimensions(monkeypatch): + """Test OCI output_dimensions mapping.""" + fake_oci = _FakeOCI() + fake_client = MagicMock() + fake_client.embed_text.return_value = SimpleNamespace( + data=SimpleNamespace(embeddings=[[0.1, 0.2]]) + ) + fake_oci.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + monkeypatch.setattr( + "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module", + lambda: fake_oci, + ) + + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-v4.0", + compartment_id="ocid1.compartment.oc1..test", + output_dimensions=512, + ) + + embedder(["hello"]) + + request = fake_client.embed_text.call_args.args[0] + assert request.output_dimensions == 512 + + +def test_oci_embedding_function_exposes_serializable_config(monkeypatch): + """Test OCI embedding config serialization for ChromaDB compatibility.""" + fake_oci = _FakeOCI() + fake_oci.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + MagicMock() + ) + + monkeypatch.setattr( + "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module", + lambda: fake_oci, + ) + + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-english-v3.0", + compartment_id="ocid1.compartment.oc1..test", + timeout=(5, 30), + ) + + assert embedder.get_config() == { + "model_name": "cohere.embed-english-v3.0", + "compartment_id": "ocid1.compartment.oc1..test", + "timeout": [5, 30], + } + + rebuilt = OCIEmbeddingFunction.build_from_config(embedder.get_config()) + assert rebuilt.get_config() == embedder.get_config() + + +def test_oci_embedding_function_supports_image_embeddings(monkeypatch, tmp_path: Path): + """Test OCI image embedding 
request construction.""" + fake_oci = _FakeOCI() + fake_client = MagicMock() + fake_client.embed_text.return_value = SimpleNamespace( + data=SimpleNamespace(embeddings=[[0.7, 0.8, 0.9]]) + ) + fake_oci.generative_ai_inference.GenerativeAiInferenceClient.return_value = ( + fake_client + ) + + monkeypatch.setattr( + "crewai.rag.embeddings.providers.oci.embedding_callable._get_oci_module", + lambda: fake_oci, + ) + + image_path = tmp_path / "diagram.png" + image_path.write_bytes(b"\x89PNG\r\n\x1a\n") + + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-v4.0", + compartment_id="ocid1.compartment.oc1..test", + ) + + result = embedder.embed_image(image_path) + + assert result == pytest.approx([0.7, 0.8, 0.9]) + request = fake_client.embed_text.call_args.args[0] + assert request.input_type == "IMAGE" + assert request.inputs[0].startswith("data:image/png;base64,") diff --git a/lib/crewai/tests/rag/embeddings/test_oci_embedding_integration.py b/lib/crewai/tests/rag/embeddings/test_oci_embedding_integration.py new file mode 100644 index 0000000000..6d5c12a80c --- /dev/null +++ b/lib/crewai/tests/rag/embeddings/test_oci_embedding_integration.py @@ -0,0 +1,66 @@ +"""Live integration tests for OCI embedding provider. 
+ +Run with: + OCI_AUTH_TYPE=API_KEY OCI_AUTH_PROFILE=API_KEY_AUTH \ + OCI_COMPARTMENT_ID= OCI_REGION=us-chicago-1 \ + uv run pytest tests/rag/embeddings/test_oci_embedding_integration.py -v +""" + +from __future__ import annotations + +import os + +import pytest + +from crewai.rag.embeddings.providers.oci.embedding_callable import OCIEmbeddingFunction + + +def _skip_unless_live() -> dict[str, str]: + compartment = os.getenv("OCI_COMPARTMENT_ID") + if not compartment: + pytest.skip("OCI_COMPARTMENT_ID not set") + region = os.getenv("OCI_REGION") + if not region: + pytest.skip("OCI_REGION not set") + config: dict[str, str] = { + "compartment_id": compartment, + "region": region, + } + if os.getenv("OCI_AUTH_TYPE"): + config["auth_type"] = os.getenv("OCI_AUTH_TYPE", "API_KEY") + if os.getenv("OCI_AUTH_PROFILE"): + config["auth_profile"] = os.getenv("OCI_AUTH_PROFILE", "DEFAULT") + return config + + +@pytest.fixture() +def oci_embed_config(): + return _skip_unless_live() + + +def test_oci_live_text_embedding(oci_embed_config: dict): + """Embed a text string and verify we get a non-empty vector.""" + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-english-v3.0", + input_type="SEARCH_DOCUMENT", + **oci_embed_config, + ) + result = embedder(["Hello world"]) + + assert len(result) == 1 + assert len(result[0]) > 0 + + +def test_oci_live_batch_embedding(oci_embed_config: dict): + """Batch embed multiple texts.""" + embedder = OCIEmbeddingFunction( + model_name="cohere.embed-english-v3.0", + input_type="SEARCH_DOCUMENT", + **oci_embed_config, + ) + texts = ["The cat sat on the mat", "The dog ran in the park", "Python is great"] + result = embedder(texts) + + assert len(result) == 3 + for vec in result: + assert len(vec) > 0