Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions lib/crewai/src/crewai/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ def _handle_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return structured_response

Expand All @@ -1030,7 +1030,7 @@ def _handle_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return full_response

Expand All @@ -1045,7 +1045,7 @@ def _handle_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return full_response

Expand All @@ -1066,7 +1066,7 @@ def _handle_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return full_response

Expand Down Expand Up @@ -1217,7 +1217,7 @@ def _handle_non_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=messages,
)
return structured_response

Expand Down Expand Up @@ -1258,7 +1258,7 @@ def _handle_non_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return structured_response

Expand Down Expand Up @@ -1289,7 +1289,7 @@ def _handle_non_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return text_response

Expand All @@ -1312,7 +1312,7 @@ def _handle_non_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return text_response

Expand Down Expand Up @@ -1361,7 +1361,7 @@ async def _ahandle_non_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=messages,
)
return structured_response

Expand Down Expand Up @@ -1396,7 +1396,7 @@ async def _ahandle_non_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return structured_response

Expand Down Expand Up @@ -1425,7 +1425,7 @@ async def _ahandle_non_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return text_response

Expand All @@ -1447,7 +1447,7 @@ async def _ahandle_non_streaming_response(
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
messages=params.get("messages", []),
)
return text_response

Expand Down
106 changes: 106 additions & 0 deletions lib/crewai/tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1024,3 +1024,109 @@ async def test_usage_info_streaming_with_acall():
assert llm._token_usage["total_tokens"] > 0

assert len(result) > 0


def test_non_streaming_response_no_keyerror_when_messages_missing_from_params():
    """Regression test for issue #5164.

    _handle_non_streaming_response must tolerate a params dict without a
    'messages' key instead of raising KeyError when emitting call events.
    """
    llm = LLM(model="gpt-4o-mini", is_litellm=True)

    # Fake litellm completion payload: one choice wrapping a plain message.
    fake_message = MagicMock(content="Test response", tool_calls=[])
    fake_choice = MagicMock(message=fake_message)
    fake_response = MagicMock(choices=[fake_choice], usage=MagicMock())

    with patch("litellm.completion", return_value=fake_response):
        # Deliberately omit the "messages" key — before the fix this raised KeyError.
        result = llm._handle_non_streaming_response(params={"model": "gpt-4o-mini"})

    assert result == "Test response"


def test_non_streaming_response_uses_validated_messages_for_litellm_response_model():
    """Regression test for issue #5164 (response_model + is_litellm branch).

    _handle_non_streaming_response must rely on the locally validated
    'messages' variable rather than reading params['messages'] again.
    """
    llm = LLM(model="gpt-4o-mini", is_litellm=True)

    class DummyModel(BaseModel):
        answer: str

    user_messages = [{"role": "user", "content": "test"}]
    call_params = {"model": "gpt-4o-mini", "messages": user_messages}

    # Structured-output stub returned by the mocked InternalInstructor.
    structured = MagicMock()
    structured.model_dump_json.return_value = '{"answer": "ok"}'

    with patch(
        "crewai.utilities.internal_instructor.InternalInstructor"
    ) as instructor_cls:
        instructor_cls.return_value.to_pydantic.return_value = structured
        result = llm._handle_non_streaming_response(
            params=call_params, response_model=DummyModel
        )

    assert result == '{"answer": "ok"}'


def test_non_streaming_response_with_response_model_no_keyerror():
    """Regression test for issue #5164 (response_model + is_litellm branch).

    When the 'messages' key is absent, the method should fail fast with a
    clear ValueError rather than a KeyError at the event-emission call.
    """
    llm = LLM(model="gpt-4o-mini", is_litellm=True)

    class DummyModel(BaseModel):
        answer: str

    # No "messages" key in params — should raise ValueError, not KeyError.
    with pytest.raises(ValueError, match="Messages are required"):
        llm._handle_non_streaming_response(
            params={"model": "gpt-4o-mini"}, response_model=DummyModel
        )


@pytest.mark.asyncio
async def test_async_non_streaming_response_no_keyerror_when_messages_missing():
    """Async regression test for issue #5164.

    _ahandle_non_streaming_response must tolerate a params dict without a
    'messages' key instead of raising KeyError.
    """
    llm = LLM(model="gpt-4o-mini", is_litellm=True)

    # Fake litellm acompletion payload: one choice wrapping a plain message.
    fake_message = MagicMock(content="Async response", tool_calls=[])
    fake_response = MagicMock(
        choices=[MagicMock(message=fake_message)], usage=MagicMock()
    )

    with patch("litellm.acompletion", return_value=fake_response):
        result = await llm._ahandle_non_streaming_response(
            params={"model": "gpt-4o-mini"}
        )

    assert result == "Async response"


@pytest.mark.asyncio
async def test_async_non_streaming_response_with_response_model_no_keyerror():
    """Async regression test for issue #5164 (response_model + is_litellm branch).

    When the 'messages' key is absent, the coroutine should raise a clear
    ValueError rather than a KeyError.
    """
    llm = LLM(model="gpt-4o-mini", is_litellm=True)

    class DummyModel(BaseModel):
        answer: str

    # Missing "messages" key — expect the explicit validation error.
    with pytest.raises(ValueError, match="Messages are required"):
        await llm._ahandle_non_streaming_response(
            params={"model": "gpt-4o-mini"}, response_model=DummyModel
        )
Loading