diff --git a/src/forge/actors/vllm/v1/generator.py b/src/forge/actors/vllm/v1/generator.py index ea7e0326a..65b0c92e0 100644 --- a/src/forge/actors/vllm/v1/generator.py +++ b/src/forge/actors/vllm/v1/generator.py @@ -38,8 +38,15 @@ from torchstore.api import _controller as get_torchstore_controller from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.llm import UsageContext +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ExtractedToolCallInformation, + ToolCall, +) from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams +from vllm.tokenizers import cached_tokenizer_from_config +from vllm.tool_parsers import ToolParserManager from vllm.v1.engine.async_llm import AsyncLLM logger = logging.getLogger(__name__) @@ -78,6 +85,7 @@ class Generator(ForgeActor): sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams) prefetch_weights_to_shm: bool = True n_fetcher_procs: int = 8 + tool_call_parser: str | None = None def __post_init__(self): super().__init__() @@ -91,6 +99,8 @@ def __post_init__(self): self.engine_args = EngineArgs(**self.engine_args) self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS) + self._tool_parser = None # Will hold ToolParser instance if configured + if isinstance(self.sampling_params, Mapping): self.sampling_params = SamplingParams.from_optional(**self.sampling_params) self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY @@ -273,9 +283,42 @@ async def setup(self, host_mesh, worker_registry, gpu_ids: list[str]): ) logger.info(f"Retrieved workers from registry: {self.workers}") + if self.tool_call_parser is not None: + self._tool_parser = self._init_tool_parser() + if self.prefetch_weights_to_shm: self._spawn_fetchers() + def _init_tool_parser(self, tokenizer=None): # type: ignore[no-untyped-def] + """Initialize the tool parser based on configuration. 
+ + Args: + tokenizer: Optional tokenizer (with encode/decode methods). If not provided, + one is created from vllm_config. Passing explicitly is useful for testing. + + Returns: + Initialized ToolParser instance, or None if tool parsing is not configured. + """ + try: + if tokenizer is None: + tokenizer = cached_tokenizer_from_config( + model_config=self.vllm_config.model_config, + ) + parser_cls = ToolParserManager.get_tool_parser(self.tool_call_parser) # type: ignore[union-attr] + parser = parser_cls(tokenizer) + logger.info(f"Initialized tool parser: {self.tool_call_parser}") + return parser + except KeyError: + available = list(ToolParserManager.tool_parsers.keys()) + logger.error( + f"Unknown tool parser: '{self.tool_call_parser}'. " + f"Available parsers: {available}" + ) + return None + except Exception as e: + logger.error(f"Failed to initialize tool parser: {e}") + return None + def _spawn_fetchers(self): """Spawn weight fetchers that prefetch weights from torchstore to shared memory. @@ -545,6 +588,38 @@ def _extract_logprobs(self, output) -> torch.Tensor | None: ) return None + def _extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation: + """Extract tool calls from model output using the configured tool parser. + + Args: + model_output: Raw text output from the model. + + Returns: + ExtractedToolCallInformation with parsed tool calls and remaining content. 
+ """ + if self._tool_parser is None: + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + + try: + dummy_request = ChatCompletionRequest( + model=self.vllm_config.model_config.model, + messages=[{"role": "user", "content": ""}], + seed=42, # to calm the linter + ) + + extracted = self._tool_parser.extract_tool_calls( + model_output, dummy_request + ) + + return extracted + except Exception as e: + logger.warning(f"Failed to parse tool calls: {e}") + return ExtractedToolCallInformation( + tools_called=False, tool_calls=[], content=model_output + ) + def _to_completions( self, request_output: RequestOutput, prompt: str ) -> list[Completion]: @@ -560,6 +635,14 @@ def _to_completions( completions = [] for output in request_output.outputs: + tool_calls: list[ToolCall] = [] + content: str | None = None + + if self._tool_parser is not None: + extracted = self._extract_tool_calls(output.text) + tool_calls = extracted.tool_calls + content = extracted.content + completion = Completion( prompt=to_prompt(prompt), text=output.text, @@ -575,6 +658,8 @@ def _to_completions( stop_reason=output.finish_reason, generator_version=self.generator_version, metadata={"num_cached_tokens": request_output.num_cached_tokens}, + tool_calls=tool_calls, + content=content, ) completions.append(completion) diff --git a/src/forge/data_models/completion.py b/src/forge/data_models/completion.py index 5f875a9dc..ddd2a7947 100644 --- a/src/forge/data_models/completion.py +++ b/src/forge/data_models/completion.py @@ -4,11 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any import torch from forge.data_models.prompt import Prompt +from vllm.entrypoints.openai.protocol import ToolCall @dataclass @@ -38,3 +39,14 @@ class Completion: # extra information that might be useful for debugging metadata: dict[str, Any] | None = None + + tool_calls: list[ToolCall] = field(default_factory=list) + + # When tool parsing is enabled, this contains content outside of tool tags + # i.e. content before the tool calls + content: str | None = None + + @property + def has_tool_calls(self) -> bool: + """Returns True if the completion contains tool calls.""" + return len(self.tool_calls) > 0 diff --git a/tests/integration_tests/test_tool_parsing.py b/tests/integration_tests/test_tool_parsing.py new file mode 100644 index 000000000..0947eab4f --- /dev/null +++ b/tests/integration_tests/test_tool_parsing.py @@ -0,0 +1,198 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Integration tests for vLLM tool parsing in forge. + +Tests the full tool-calling workflow: model generates tool call -> parse -> execute -> return result. + +Requires GPU access. + +Run: + pytest tests/integration_tests/test_tool_parsing.py -v -s +""" + +import json +import logging + +import pytest +import pytest_asyncio +import torch +from forge.actors.generator import Generator +from huggingface_hub import snapshot_download +from vllm.tokenizers import get_tokenizer + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + +# All async tests and fixtures in this module share a single module-scoped event loop. +# This is required because the `policy` fixture is scope="module" - without this, +# each test would get its own function-scoped loop and hang when awaiting the +# module-scoped fixture's objects. 
+pytestmark = pytest.mark.asyncio(loop_scope="module") + +requires_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) + +MODEL_NAME = "Qwen/Qwen3-0.6B" + +TOOLS = [ + { + "type": "function", + "function": { + "name": "calculator", + "description": "Evaluate a mathematical equation.", + "parameters": { + "type": "object", + "properties": { + "equation": { + "type": "string", + "description": "The mathematical equation to evaluate", + }, + }, + "required": ["equation"], + }, + }, + }, +] + + +def calculator(equation: str) -> str: + """Safely evaluate a mathematical equation.""" + try: + # Only allow safe math operations + allowed = set("0123456789+-*/().^ ") + if all(c in allowed for c in equation): + result = eval(equation.replace("^", "**")) + return str(result) + return "Error: Invalid characters in equation" + except Exception as e: + return f"Error: {e}" + + +@pytest.fixture(scope="module") +def model_path(): + """Download model once for all tests in this module.""" + logger.info(f"Downloading model checkpoint: {MODEL_NAME}") + cached_dir = snapshot_download(repo_id=MODEL_NAME) + logger.info(f"Model downloaded to: {cached_dir}") + return cached_dir + + +@pytest.fixture(scope="module") +def tokenizer(): + """Create tokenizer once for all tests in this module.""" + return get_tokenizer(MODEL_NAME) + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def policy(model_path): + """Create and teardown policy service once for all tests in this module.""" + logger.info("Setting up policy service...") + policy = await Generator.options( + procs=1, + num_replicas=1, + with_gpus=True, + ).as_service( + engine_args={"model": model_path}, + sampling_params={"n": 1, "max_tokens": 256}, + tool_call_parser="hermes", + ) + + yield policy + + # Teardown + logger.info("Shutting down policy service...") + await policy.shutdown() + + +@requires_cuda +async def test_tool_parsing_multi_turn(policy, tokenizer): + """ + 
Multi-turn conversation: tool call -> execute -> feed result back -> final answer. + """ + messages = [ + { + "role": "system", + "content": "/no_think Use the calculator tool for math.", + }, + {"role": "user", "content": "Calculate 123 + 456"}, + ] + + # First turn - get tool call + formatted = tokenizer.apply_chat_template( + messages, tools=TOOLS, tokenize=False, add_generation_prompt=True + ) + response = await policy.generate.route(formatted) + completion = response[0] + + assert completion.has_tool_calls, "Expected tool calls" + tool_call = completion.tool_calls[0] + args = json.loads(tool_call.function.arguments) + result = calculator(args["equation"]) + + # Add assistant response and tool result to conversation + messages.append( + { + "role": "assistant", + "content": completion.text, + } + ) + messages.append( + { + "role": "tool", + "tool_call_id": tool_call.id, + "content": result, + } + ) + + # Second turn - get final answer + formatted = tokenizer.apply_chat_template( + messages, tools=TOOLS, tokenize=False, add_generation_prompt=True + ) + response = await policy.generate.route(formatted) + final = response[0] + + logger.info(f"Final answer: {final.text}") + assert "579" in final.text, "Expected 123 + 456 = 579" + + +@requires_cuda +async def test_content_without_tool_calls(policy, tokenizer): + """ + Test that content equals text when no tool calls are made. + + When a request doesn't trigger tool usage, the completion's content + field should equal the raw text output. 
+ """ + # Ask a non-math question that won't trigger the calculator tool + messages = [ + { + "role": "system", + "content": "/no_think You are a helpful assistant.", + }, + {"role": "user", "content": "What is the capital of France?"}, + ] + + formatted_request = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + + response = await policy.generate.route(formatted_request) + completion = response[0] + + logger.info(f"Response text: {completion.text}") + logger.info(f"Response content: {completion.content}") + + assert completion.tool_calls == [], "Should have no tool calls" + assert completion.content is not None, "Should have content when no tools called" + assert ( + completion.content == completion.text + ), "Content should equal text when no tools" diff --git a/tests/unit_tests/test_generator.py b/tests/unit_tests/test_generator.py new file mode 100644 index 000000000..e1f806220 --- /dev/null +++ b/tests/unit_tests/test_generator.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for Generator's _to_completions and _extract_tool_calls logic.""" + +import json +from unittest.mock import MagicMock + +import pytest +from vllm.outputs import CompletionOutput, RequestOutput + + +def _import_error(): + """Check if there are import errors that would cause CI failures.""" + try: + import forge.actors.generator # noqa: F401 + + return False + except ImportError: + return True + + +class _StubTokenizer: + """Minimal stub tokenizer for initializing the Hermes tool parser in tests. 
+
+    The Hermes tool parser from vLLM requires:
+    - get_vocab(): Returns vocab dict mapping tokens to ids
+    - vocab: Direct vocab attribute
+    - eos_token_id: End of sequence token id
+    - encode(text, add_special_tokens=False): Encode text to token ids
+    - decode(token_ids): Decode token ids to text
+    - <tool_call> and </tool_call> tokens in vocab (for streaming support)
+    """
+
+    def __init__(self):
+        # Include tool call tokens that Hermes parser validates in __init__
+        # (needed for streaming, but validated even for non-streaming use)
+        self.vocab = {
+            "<tool_call>": 1,
+            "</tool_call>": 2,
+        }
+        self._id_to_token = {v: k for k, v in self.vocab.items()}
+        self.eos_token_id = 0
+
+    def get_vocab(self) -> dict[str, int]:
+        """Return vocabulary dict (required by Hermes tool parser)."""
+        return self.vocab
+
+    def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
+        """Encode text to token ids. Returns ids for known tokens, char ordinals otherwise."""
+        if text in self.vocab:
+            return [self.vocab[text]]
+        return [ord(c) for c in text]
+
+    def decode(self, token_ids: list[int]) -> str:
+        """Decode token ids to text."""
+        return "".join(self._id_to_token.get(tid, chr(tid)) for tid in token_ids)
+
+
+@pytest.fixture(scope="module")
+def stub_tokenizer():
+    """Create a stub tokenizer compatible with Hermes tool parser."""
+    return _StubTokenizer()
+
+
+@pytest.fixture
+def generator_with_hermes(stub_tokenizer):
+    """Create Generator with hermes parser properly initialized."""
+    from forge.actors.generator import Generator
+
+    generator = Generator(
+        engine_args={"model": "Qwen/Qwen3-0.6B"},
+        sampling_params={"max_tokens": 64},
+        tool_call_parser="hermes",
+    )
+    generator._tool_parser = generator._init_tool_parser(stub_tokenizer)
+    generator.generator_version = 1
+
+    return generator
+
+
+def make_mock_request_output(
+    prompt: str = "test prompt",
+    outputs: list[dict] | None = None,
+) -> RequestOutput:
+    """Create a mock vLLM RequestOutput for testing _to_completions."""
+    if outputs is None:
+        outputs
= [ + {"text": "test response", "token_ids": [1, 2, 3], "finish_reason": "stop"} + ] + + mock_outputs = [] + for out in outputs: + mock_output = MagicMock(spec=CompletionOutput) + mock_output.text = out.get("text", "") + mock_output.token_ids = out.get("token_ids", [1, 2, 3]) + mock_output.finish_reason = out.get("finish_reason", "stop") + mock_output.logprobs = out.get("logprobs", None) + mock_outputs.append(mock_output) + + mock_request_output = MagicMock(spec=RequestOutput) + mock_request_output.prompt = prompt + mock_request_output.prompt_token_ids = [100, 101, 102] + mock_request_output.outputs = mock_outputs + mock_request_output.num_cached_tokens = 0 + + return mock_request_output + + +@pytest.mark.skipif( + _import_error(), + reason="Import error, likely due to missing dependencies on CI.", +) +class TestInitToolParser: + """Test the _init_tool_parser method of Generator.""" + + def test_init_hermes_parser(self, stub_tokenizer): + """Test that passing tool_call_parser='hermes' initializes the parser.""" + from forge.actors.generator import Generator + + generator = Generator( + engine_args={"model": "Qwen/Qwen3-0.6B"}, + sampling_params={"max_tokens": 64}, + tool_call_parser="hermes", + ) + + parser = generator._init_tool_parser(stub_tokenizer) + + assert parser is not None + assert hasattr(parser, "extract_tool_calls") + + def test_init_parser_none_when_not_configured(self): + """Test that no parser is created when tool_call_parser is None.""" + from forge.actors.generator import Generator + + generator = Generator( + engine_args={"model": "Qwen/Qwen3-0.6B"}, + sampling_params={"max_tokens": 64}, + tool_call_parser=None, + ) + + assert generator.tool_call_parser is None + + def test_init_parser_invalid_parser_name(self, stub_tokenizer): + """Test that invalid parser name returns None.""" + from forge.actors.generator import Generator + + generator = Generator( + engine_args={"model": "Qwen/Qwen3-0.6B"}, + sampling_params={"max_tokens": 64}, + 
tool_call_parser="nonexistent_parser",
+        )
+
+        parser = generator._init_tool_parser(stub_tokenizer)
+        assert parser is None
+
+
+@pytest.mark.skipif(
+    _import_error(),
+    reason="Import error, likely due to missing dependencies on CI.",
+)
+class TestExtractToolCalls:
+    """Test _extract_tool_calls with real parser initialization."""
+
+    def test_extract_single_tool_call(self, generator_with_hermes):
+        """Test extracting a single tool call."""
+        generator = generator_with_hermes
+
+        model_output = """<tool_call>
+{"name": "calculator", "arguments": {"equation": "2 + 2"}}
+</tool_call>"""
+
+        result = generator._extract_tool_calls(model_output)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].function.name == "calculator"
+
+        args = json.loads(result.tool_calls[0].function.arguments)
+        assert args["equation"] == "2 + 2"
+
+    def test_extract_tool_call_with_content_prefix(self, generator_with_hermes):
+        """Test extracting tool call when there's content before it."""
+        generator = generator_with_hermes
+
+        model_output = """Let me calculate that for you.
+
+<tool_call>
+{"name": "calculator", "arguments": {"equation": "15 * 7"}}
+</tool_call>"""
+
+        result = generator._extract_tool_calls(model_output)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].function.name == "calculator"
+        assert "Let me calculate" in (result.content or "")
+
+    def test_extract_tool_call_with_think_prefix(self, generator_with_hermes):
+        """Test extracting tool call when there's <think> tags before it."""
+        generator = generator_with_hermes
+
+        model_output = """<think>
+The user is asking for a math calculation. I should use the calculator tool.
+Let me compute 2 + 2.
+</think>
+
+<tool_call>
+{"name": "calculator", "arguments": {"equation": "2 + 2"}}
+</tool_call>"""
+
+        result = generator._extract_tool_calls(model_output)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 1
+        assert result.tool_calls[0].function.name == "calculator"
+        # <think> content should be preserved in the content field
+        assert result.content is not None
+        assert """<think>
+The user is asking for a math calculation. I should use the calculator tool.
+Let me compute 2 + 2.
+</think>""" in (
+            result.content
+        )
+
+    def test_extract_multiple_tool_calls(self, generator_with_hermes):
+        """Test extracting multiple tool calls."""
+        generator = generator_with_hermes
+
+        model_output = """<tool_call>
+{"name": "calculator", "arguments": {"equation": "2 + 2"}}
+</tool_call>
+<tool_call>
+{"name": "calculator", "arguments": {"equation": "3 * 4"}}
+</tool_call>"""
+
+        result = generator._extract_tool_calls(model_output)
+
+        assert result.tools_called is True
+        assert len(result.tool_calls) == 2
+
+        equations = [
+            json.loads(tc.function.arguments)["equation"] for tc in result.tool_calls
+        ]
+        assert "2 + 2" in equations
+        assert "3 * 4" in equations
+
+    def test_no_tool_call_in_output(self, generator_with_hermes):
+        """Test when model output has no tool calls."""
+        generator = generator_with_hermes
+
+        model_output = "The capital of France is Paris."
+
+        result = generator._extract_tool_calls(model_output)
+
+        assert result.tools_called is False
+        assert result.tool_calls == []
+        assert result.content == model_output
+
+    def test_extract_tool_calls_no_parser(self):
+        """Test _extract_tool_calls returns content as-is when no parser."""
+        from forge.actors.generator import Generator
+
+        generator = Generator(
+            engine_args={"model": "Qwen/Qwen3-0.6B"},
+            sampling_params={"max_tokens": 64},
+            tool_call_parser=None,
+        )
+        generator._tool_parser = None
+
+        result = generator._extract_tool_calls("Hello, world!")
+
+        assert result.tools_called is False
+        assert result.tool_calls == []
+        assert result.content == "Hello, world!"
+ + +@pytest.mark.skipif( + _import_error(), + reason="Import error, likely due to missing dependencies on CI.", +) +class TestToCompletions: + """Test _to_completions with real parser initialization.""" + + def test_to_completions_without_tool_parser(self): + """Test _to_completions when no tool parser is configured.""" + from forge.actors.generator import Generator + + generator = Generator( + engine_args={"model": "Qwen/Qwen3-0.6B"}, + sampling_params={"max_tokens": 64}, + tool_call_parser=None, + ) + generator._tool_parser = None + generator.generator_version = 1 + + request_output = make_mock_request_output( + prompt="What is 2 + 2?", + outputs=[{"text": "The answer is 4.", "token_ids": [10, 20, 30]}], + ) + + completions = generator._to_completions(request_output, request_output.prompt) + + assert len(completions) == 1 + completion = completions[0] + + assert completion.tool_calls == [] + assert completion.content is None + assert completion.text == "The answer is 4." + assert not completion.has_tool_calls + + def test_to_completions_no_tool_call_with_parser(self, generator_with_hermes): + """Test _to_completions when parser finds no tool calls.""" + generator = generator_with_hermes + + request_output = make_mock_request_output( + prompt="What is the capital of France?", + outputs=[ + {"text": "Paris is the capital of France.", "token_ids": [10, 20]} + ], + ) + + completions = generator._to_completions(request_output, request_output.prompt) + + assert len(completions) == 1 + completion = completions[0] + + assert not completion.has_tool_calls + assert completion.tool_calls == [] + assert completion.content == "Paris is the capital of France." 
+
+    def test_to_completions_multiple_outputs(self, generator_with_hermes):
+        """Test _to_completions with multiple outputs (n > 1)."""
+        generator = generator_with_hermes
+
+        request_output = make_mock_request_output(
+            prompt="Calculate something",
+            outputs=[
+                {
+                    "text": """<tool_call>
+{"name": "calculator", "arguments": {"equation": "1 + 1"}}
+</tool_call>""",
+                    "token_ids": [1, 2],
+                },
+                {"text": "The answer is obviously 2.", "token_ids": [3, 4]},
+            ],
+        )
+
+        completions = generator._to_completions(request_output, request_output.prompt)
+
+        assert len(completions) == 2
+        # First completion has tool call
+        assert completions[0].has_tool_calls
+        assert completions[0].tool_calls[0].function.name == "calculator"
+        args = json.loads(completions[0].tool_calls[0].function.arguments)
+        assert args["equation"] == "1 + 1"
+        # Second completion has no tool call
+        assert not completions[1].has_tool_calls
+        assert completions[1].content == "The answer is obviously 2."