diff --git a/src/forge/actors/vllm/v1/generator.py b/src/forge/actors/vllm/v1/generator.py
index ea7e0326a..65b0c92e0 100644
--- a/src/forge/actors/vllm/v1/generator.py
+++ b/src/forge/actors/vllm/v1/generator.py
@@ -38,8 +38,15 @@
from torchstore.api import _controller as get_torchstore_controller
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.llm import UsageContext
+from vllm.entrypoints.openai.protocol import (
+ ChatCompletionRequest,
+ ExtractedToolCallInformation,
+ ToolCall,
+)
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
+from vllm.tokenizers import cached_tokenizer_from_config
+from vllm.tool_parsers import ToolParserManager
from vllm.v1.engine.async_llm import AsyncLLM
logger = logging.getLogger(__name__)
@@ -78,6 +85,7 @@ class Generator(ForgeActor):
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
prefetch_weights_to_shm: bool = True
n_fetcher_procs: int = 8
+ tool_call_parser: str | None = None
def __post_init__(self):
super().__init__()
@@ -91,6 +99,8 @@ def __post_init__(self):
self.engine_args = EngineArgs(**self.engine_args)
self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)
+ self._tool_parser = None # Will hold ToolParser instance if configured
+
if isinstance(self.sampling_params, Mapping):
self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
@@ -273,9 +283,42 @@ async def setup(self, host_mesh, worker_registry, gpu_ids: list[str]):
)
logger.info(f"Retrieved workers from registry: {self.workers}")
+ if self.tool_call_parser is not None:
+ self._tool_parser = self._init_tool_parser()
+
if self.prefetch_weights_to_shm:
self._spawn_fetchers()
+ def _init_tool_parser(self, tokenizer=None): # type: ignore[no-untyped-def]
+ """Initialize the tool parser based on configuration.
+
+ Args:
+ tokenizer: Optional tokenizer (with encode/decode methods). If not provided,
+ one is created from vllm_config. Passing explicitly is useful for testing.
+
+ Returns:
+ Initialized ToolParser instance, or None if tool parsing is not configured.
+ """
+ try:
+ if tokenizer is None:
+ tokenizer = cached_tokenizer_from_config(
+ model_config=self.vllm_config.model_config,
+ )
+ parser_cls = ToolParserManager.get_tool_parser(self.tool_call_parser) # type: ignore[union-attr]
+ parser = parser_cls(tokenizer)
+ logger.info(f"Initialized tool parser: {self.tool_call_parser}")
+ return parser
+ except KeyError:
+ available = list(ToolParserManager.tool_parsers.keys())
+ logger.error(
+ f"Unknown tool parser: '{self.tool_call_parser}'. "
+ f"Available parsers: {available}"
+ )
+ return None
+ except Exception as e:
+ logger.error(f"Failed to initialize tool parser: {e}")
+ return None
+
def _spawn_fetchers(self):
"""Spawn weight fetchers that prefetch weights from torchstore to shared memory.
@@ -545,6 +588,38 @@ def _extract_logprobs(self, output) -> torch.Tensor | None:
)
return None
+ def _extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation:
+ """Extract tool calls from model output using the configured tool parser.
+
+ Args:
+ model_output: Raw text output from the model.
+
+ Returns:
+ ExtractedToolCallInformation with parsed tool calls and remaining content.
+ """
+ if self._tool_parser is None:
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+
+ try:
+ dummy_request = ChatCompletionRequest(
+ model=self.vllm_config.model_config.model,
+ messages=[{"role": "user", "content": ""}],
+ seed=42, # to calm the linter
+ )
+
+ extracted = self._tool_parser.extract_tool_calls(
+ model_output, dummy_request
+ )
+
+ return extracted
+ except Exception as e:
+ logger.warning(f"Failed to parse tool calls: {e}")
+ return ExtractedToolCallInformation(
+ tools_called=False, tool_calls=[], content=model_output
+ )
+
def _to_completions(
self, request_output: RequestOutput, prompt: str
) -> list[Completion]:
@@ -560,6 +635,14 @@ def _to_completions(
completions = []
for output in request_output.outputs:
+ tool_calls: list[ToolCall] = []
+ content: str | None = None
+
+ if self._tool_parser is not None:
+ extracted = self._extract_tool_calls(output.text)
+ tool_calls = extracted.tool_calls
+ content = extracted.content
+
completion = Completion(
prompt=to_prompt(prompt),
text=output.text,
@@ -575,6 +658,8 @@ def _to_completions(
stop_reason=output.finish_reason,
generator_version=self.generator_version,
metadata={"num_cached_tokens": request_output.num_cached_tokens},
+ tool_calls=tool_calls,
+ content=content,
)
completions.append(completion)
diff --git a/src/forge/data_models/completion.py b/src/forge/data_models/completion.py
index 5f875a9dc..ddd2a7947 100644
--- a/src/forge/data_models/completion.py
+++ b/src/forge/data_models/completion.py
@@ -4,11 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
-from dataclasses import dataclass
+from dataclasses import dataclass, field
from typing import Any
import torch
from forge.data_models.prompt import Prompt
+from vllm.entrypoints.openai.protocol import ToolCall
@dataclass
@@ -38,3 +39,14 @@ class Completion:
# extra information that might be useful for debugging
metadata: dict[str, Any] | None = None
+
+ tool_calls: list[ToolCall] = field(default_factory=list)
+
+ # When tool parsing is enabled, this contains content outside of tool tags
+ # i.e. content before the tool calls
+ content: str | None = None
+
+ @property
+ def has_tool_calls(self) -> bool:
+ """Returns True if the completion contains tool calls."""
+ return len(self.tool_calls) > 0
diff --git a/tests/integration_tests/test_tool_parsing.py b/tests/integration_tests/test_tool_parsing.py
new file mode 100644
index 000000000..0947eab4f
--- /dev/null
+++ b/tests/integration_tests/test_tool_parsing.py
@@ -0,0 +1,198 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Integration tests for vLLM tool parsing in forge.
+
+Tests the full tool-calling workflow: model generates tool call -> parse -> execute -> return result.
+
+Requires GPU access.
+
+Run:
+ pytest tests/integration_tests/test_tool_parsing.py -v -s
+"""
+
+import json
+import logging
+
+import pytest
+import pytest_asyncio
+import torch
+from forge.actors.generator import Generator
+from huggingface_hub import snapshot_download
+from vllm.tokenizers import get_tokenizer
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+# All async tests and fixtures in this module share a single module-scoped event loop.
+# This is required because the `policy` fixture is scope="module" - without this,
+# each test would get its own function-scoped loop and hang when awaiting the
+# module-scoped fixture's objects.
+pytestmark = pytest.mark.asyncio(loop_scope="module")
+
+requires_cuda = pytest.mark.skipif(
+ not torch.cuda.is_available(),
+ reason="CUDA not available",
+)
+
+MODEL_NAME = "Qwen/Qwen3-0.6B"
+
+TOOLS = [
+ {
+ "type": "function",
+ "function": {
+ "name": "calculator",
+ "description": "Evaluate a mathematical equation.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "equation": {
+ "type": "string",
+ "description": "The mathematical equation to evaluate",
+ },
+ },
+ "required": ["equation"],
+ },
+ },
+ },
+]
+
+
+def calculator(equation: str) -> str:
+ """Safely evaluate a mathematical equation."""
+ try:
+ # Only allow safe math operations
+ allowed = set("0123456789+-*/().^ ")
+ if all(c in allowed for c in equation):
+ result = eval(equation.replace("^", "**"))
+ return str(result)
+ return "Error: Invalid characters in equation"
+ except Exception as e:
+ return f"Error: {e}"
+
+
+@pytest.fixture(scope="module")
+def model_path():
+ """Download model once for all tests in this module."""
+ logger.info(f"Downloading model checkpoint: {MODEL_NAME}")
+ cached_dir = snapshot_download(repo_id=MODEL_NAME)
+ logger.info(f"Model downloaded to: {cached_dir}")
+ return cached_dir
+
+
+@pytest.fixture(scope="module")
+def tokenizer():
+ """Create tokenizer once for all tests in this module."""
+ return get_tokenizer(MODEL_NAME)
+
+
+@pytest_asyncio.fixture(scope="module", loop_scope="module")
+async def policy(model_path):
+ """Create and teardown policy service once for all tests in this module."""
+ logger.info("Setting up policy service...")
+ policy = await Generator.options(
+ procs=1,
+ num_replicas=1,
+ with_gpus=True,
+ ).as_service(
+ engine_args={"model": model_path},
+ sampling_params={"n": 1, "max_tokens": 256},
+ tool_call_parser="hermes",
+ )
+
+ yield policy
+
+ # Teardown
+ logger.info("Shutting down policy service...")
+ await policy.shutdown()
+
+
+@requires_cuda
+async def test_tool_parsing_multi_turn(policy, tokenizer):
+ """
+ Multi-turn conversation: tool call -> execute -> feed result back -> final answer.
+ """
+ messages = [
+ {
+ "role": "system",
+ "content": "/no_think Use the calculator tool for math.",
+ },
+ {"role": "user", "content": "Calculate 123 + 456"},
+ ]
+
+ # First turn - get tool call
+ formatted = tokenizer.apply_chat_template(
+ messages, tools=TOOLS, tokenize=False, add_generation_prompt=True
+ )
+ response = await policy.generate.route(formatted)
+ completion = response[0]
+
+ assert completion.has_tool_calls, "Expected tool calls"
+ tool_call = completion.tool_calls[0]
+ args = json.loads(tool_call.function.arguments)
+ result = calculator(args["equation"])
+
+ # Add assistant response and tool result to conversation
+ messages.append(
+ {
+ "role": "assistant",
+ "content": completion.text,
+ }
+ )
+ messages.append(
+ {
+ "role": "tool",
+ "tool_call_id": tool_call.id,
+ "content": result,
+ }
+ )
+
+ # Second turn - get final answer
+ formatted = tokenizer.apply_chat_template(
+ messages, tools=TOOLS, tokenize=False, add_generation_prompt=True
+ )
+ response = await policy.generate.route(formatted)
+ final = response[0]
+
+ logger.info(f"Final answer: {final.text}")
+ assert "579" in final.text, "Expected 123 + 456 = 579"
+
+
+@requires_cuda
+async def test_content_without_tool_calls(policy, tokenizer):
+ """
+ Test that content equals text when no tool calls are made.
+
+ When a request doesn't trigger tool usage, the completion's content
+ field should equal the raw text output.
+ """
+ # Ask a non-math question that won't trigger the calculator tool
+ messages = [
+ {
+ "role": "system",
+ "content": "/no_think You are a helpful assistant.",
+ },
+ {"role": "user", "content": "What is the capital of France?"},
+ ]
+
+ formatted_request = tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+
+ response = await policy.generate.route(formatted_request)
+ completion = response[0]
+
+ logger.info(f"Response text: {completion.text}")
+ logger.info(f"Response content: {completion.content}")
+
+ assert completion.tool_calls == [], "Should have no tool calls"
+ assert completion.content is not None, "Should have content when no tools called"
+ assert (
+ completion.content == completion.text
+ ), "Content should equal text when no tools"
diff --git a/tests/unit_tests/test_generator.py b/tests/unit_tests/test_generator.py
new file mode 100644
index 000000000..e1f806220
--- /dev/null
+++ b/tests/unit_tests/test_generator.py
@@ -0,0 +1,360 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Unit tests for Generator's _to_completions and _extract_tool_calls logic."""
+
+import json
+from unittest.mock import MagicMock
+
+import pytest
+from vllm.outputs import CompletionOutput, RequestOutput
+
+
+def _import_error():
+ """Check if there are import errors that would cause CI failures."""
+ try:
+ import forge.actors.generator # noqa: F401
+
+ return False
+ except ImportError:
+ return True
+
+
+class _StubTokenizer:
+ """Minimal stub tokenizer for initializing the Hermes tool parser in tests.
+
+ The Hermes tool parser from vLLM requires:
+ - get_vocab(): Returns vocab dict mapping tokens to ids
+ - vocab: Direct vocab attribute
+ - eos_token_id: End of sequence token id
+ - encode(text, add_special_tokens=False): Encode text to token ids
+ - decode(token_ids): Decode token ids to text
+    - <tool_call> and </tool_call> tokens in vocab (for streaming support)
+ """
+
+ def __init__(self):
+ # Include tool call tokens that Hermes parser validates in __init__
+ # (needed for streaming, but validated even for non-streaming use)
+        self.vocab = {
+            "<tool_call>": 1,
+            "</tool_call>": 2,
+        }
+ self._id_to_token = {v: k for k, v in self.vocab.items()}
+ self.eos_token_id = 0
+
+ def get_vocab(self) -> dict[str, int]:
+ """Return vocabulary dict (required by Hermes tool parser)."""
+ return self.vocab
+
+ def encode(self, text: str, add_special_tokens: bool = False) -> list[int]:
+ """Encode text to token ids. Returns ids for known tokens, empty otherwise."""
+ if text in self.vocab:
+ return [self.vocab[text]]
+ return [ord(c) for c in text]
+
+ def decode(self, token_ids: list[int]) -> str:
+ """Decode token ids to text."""
+ return "".join(self._id_to_token.get(tid, chr(tid)) for tid in token_ids)
+
+
+@pytest.fixture(scope="module")
+def stub_tokenizer():
+ """Create a stub tokenizer compatible with Hermes tool parser."""
+ return _StubTokenizer()
+
+
+@pytest.fixture
+def generator_with_hermes(stub_tokenizer):
+ """Create Generator with hermes parser properly initialized."""
+ from forge.actors.generator import Generator
+
+ generator = Generator(
+ engine_args={"model": "Qwen/Qwen3-0.6B"},
+ sampling_params={"max_tokens": 64},
+ tool_call_parser="hermes",
+ )
+ generator._tool_parser = generator._init_tool_parser(stub_tokenizer)
+ generator.generator_version = 1
+
+ return generator
+
+
+def make_mock_request_output(
+ prompt: str = "test prompt",
+ outputs: list[dict] | None = None,
+) -> RequestOutput:
+ """Create a mock vLLM RequestOutput for testing _to_completions."""
+ if outputs is None:
+ outputs = [
+ {"text": "test response", "token_ids": [1, 2, 3], "finish_reason": "stop"}
+ ]
+
+ mock_outputs = []
+ for out in outputs:
+ mock_output = MagicMock(spec=CompletionOutput)
+ mock_output.text = out.get("text", "")
+ mock_output.token_ids = out.get("token_ids", [1, 2, 3])
+ mock_output.finish_reason = out.get("finish_reason", "stop")
+ mock_output.logprobs = out.get("logprobs", None)
+ mock_outputs.append(mock_output)
+
+ mock_request_output = MagicMock(spec=RequestOutput)
+ mock_request_output.prompt = prompt
+ mock_request_output.prompt_token_ids = [100, 101, 102]
+ mock_request_output.outputs = mock_outputs
+ mock_request_output.num_cached_tokens = 0
+
+ return mock_request_output
+
+
+@pytest.mark.skipif(
+ _import_error(),
+ reason="Import error, likely due to missing dependencies on CI.",
+)
+class TestInitToolParser:
+ """Test the _init_tool_parser method of Generator."""
+
+ def test_init_hermes_parser(self, stub_tokenizer):
+ """Test that passing tool_call_parser='hermes' initializes the parser."""
+ from forge.actors.generator import Generator
+
+ generator = Generator(
+ engine_args={"model": "Qwen/Qwen3-0.6B"},
+ sampling_params={"max_tokens": 64},
+ tool_call_parser="hermes",
+ )
+
+ parser = generator._init_tool_parser(stub_tokenizer)
+
+ assert parser is not None
+ assert hasattr(parser, "extract_tool_calls")
+
+ def test_init_parser_none_when_not_configured(self):
+ """Test that no parser is created when tool_call_parser is None."""
+ from forge.actors.generator import Generator
+
+ generator = Generator(
+ engine_args={"model": "Qwen/Qwen3-0.6B"},
+ sampling_params={"max_tokens": 64},
+ tool_call_parser=None,
+ )
+
+ assert generator.tool_call_parser is None
+
+ def test_init_parser_invalid_parser_name(self, stub_tokenizer):
+ """Test that invalid parser name returns None."""
+ from forge.actors.generator import Generator
+
+ generator = Generator(
+ engine_args={"model": "Qwen/Qwen3-0.6B"},
+ sampling_params={"max_tokens": 64},
+ tool_call_parser="nonexistent_parser",
+ )
+
+ parser = generator._init_tool_parser(stub_tokenizer)
+ assert parser is None
+
+
+@pytest.mark.skipif(
+ _import_error(),
+ reason="Import error, likely due to missing dependencies on CI.",
+)
+class TestExtractToolCalls:
+ """Test _extract_tool_calls with real parser initialization."""
+
+ def test_extract_single_tool_call(self, generator_with_hermes):
+ """Test extracting a single tool call."""
+ generator = generator_with_hermes
+
+        model_output = """<tool_call>
+{"name": "calculator", "arguments": {"equation": "2 + 2"}}
+</tool_call>"""
+
+ result = generator._extract_tool_calls(model_output)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "calculator"
+
+ args = json.loads(result.tool_calls[0].function.arguments)
+ assert args["equation"] == "2 + 2"
+
+ def test_extract_tool_call_with_content_prefix(self, generator_with_hermes):
+ """Test extracting tool call when there's content before it."""
+ generator = generator_with_hermes
+
+        model_output = """Let me calculate that for you.
+<tool_call>
+{"name": "calculator", "arguments": {"equation": "15 * 7"}}
+</tool_call>"""
+
+ result = generator._extract_tool_calls(model_output)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "calculator"
+ assert "Let me calculate" in (result.content or "")
+
+ def test_extract_tool_call_with_think_prefix(self, generator_with_hermes):
+ """Test extracting tool call when there's tags before it."""
+ generator = generator_with_hermes
+
+        model_output = """<think>
+The user is asking for a math calculation. I should use the calculator tool.
+Let me compute 2 + 2.
+</think>
+<tool_call>
+{"name": "calculator", "arguments": {"equation": "2 + 2"}}
+</tool_call>"""
+
+ result = generator._extract_tool_calls(model_output)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 1
+ assert result.tool_calls[0].function.name == "calculator"
+ # content should be preserved in the content field
+ assert result.content is not None
+        assert """<think>
+The user is asking for a math calculation. I should use the calculator tool.
+Let me compute 2 + 2.
+</think>""" in (
+            result.content
+        )
+
+ def test_extract_multiple_tool_calls(self, generator_with_hermes):
+ """Test extracting multiple tool calls."""
+ generator = generator_with_hermes
+
+        model_output = """<tool_call>
+{"name": "calculator", "arguments": {"equation": "2 + 2"}}
+</tool_call>
+<tool_call>
+{"name": "calculator", "arguments": {"equation": "3 * 4"}}
+</tool_call>"""
+
+ result = generator._extract_tool_calls(model_output)
+
+ assert result.tools_called is True
+ assert len(result.tool_calls) == 2
+
+ equations = [
+ json.loads(tc.function.arguments)["equation"] for tc in result.tool_calls
+ ]
+ assert "2 + 2" in equations
+ assert "3 * 4" in equations
+
+ def test_no_tool_call_in_output(self, generator_with_hermes):
+ """Test when model output has no tool calls."""
+ generator = generator_with_hermes
+
+ model_output = "The capital of France is Paris."
+
+ result = generator._extract_tool_calls(model_output)
+
+ assert result.tools_called is False
+ assert result.tool_calls == []
+ assert result.content == model_output
+
+ def test_extract_tool_calls_no_parser(self):
+ """Test _extract_tool_calls returns content as-is when no parser."""
+ from forge.actors.generator import Generator
+
+ generator = Generator(
+ engine_args={"model": "Qwen/Qwen3-0.6B"},
+ sampling_params={"max_tokens": 64},
+ tool_call_parser=None,
+ )
+ generator._tool_parser = None
+
+ result = generator._extract_tool_calls("Hello, world!")
+
+ assert result.tools_called is False
+ assert result.tool_calls == []
+ assert result.content == "Hello, world!"
+
+
+@pytest.mark.skipif(
+ _import_error(),
+ reason="Import error, likely due to missing dependencies on CI.",
+)
+class TestToCompletions:
+ """Test _to_completions with real parser initialization."""
+
+ def test_to_completions_without_tool_parser(self):
+ """Test _to_completions when no tool parser is configured."""
+ from forge.actors.generator import Generator
+
+ generator = Generator(
+ engine_args={"model": "Qwen/Qwen3-0.6B"},
+ sampling_params={"max_tokens": 64},
+ tool_call_parser=None,
+ )
+ generator._tool_parser = None
+ generator.generator_version = 1
+
+ request_output = make_mock_request_output(
+ prompt="What is 2 + 2?",
+ outputs=[{"text": "The answer is 4.", "token_ids": [10, 20, 30]}],
+ )
+
+ completions = generator._to_completions(request_output, request_output.prompt)
+
+ assert len(completions) == 1
+ completion = completions[0]
+
+ assert completion.tool_calls == []
+ assert completion.content is None
+ assert completion.text == "The answer is 4."
+ assert not completion.has_tool_calls
+
+ def test_to_completions_no_tool_call_with_parser(self, generator_with_hermes):
+ """Test _to_completions when parser finds no tool calls."""
+ generator = generator_with_hermes
+
+ request_output = make_mock_request_output(
+ prompt="What is the capital of France?",
+ outputs=[
+ {"text": "Paris is the capital of France.", "token_ids": [10, 20]}
+ ],
+ )
+
+ completions = generator._to_completions(request_output, request_output.prompt)
+
+ assert len(completions) == 1
+ completion = completions[0]
+
+ assert not completion.has_tool_calls
+ assert completion.tool_calls == []
+ assert completion.content == "Paris is the capital of France."
+
+ def test_to_completions_multiple_outputs(self, generator_with_hermes):
+ """Test _to_completions with multiple outputs (n > 1)."""
+ generator = generator_with_hermes
+
+ request_output = make_mock_request_output(
+ prompt="Calculate something",
+ outputs=[
+ {
+                    "text": """<tool_call>
+{"name": "calculator", "arguments": {"equation": "1 + 1"}}
+</tool_call>""",
+ "token_ids": [1, 2],
+ },
+ {"text": "The answer is obviously 2.", "token_ids": [3, 4]},
+ ],
+ )
+
+ completions = generator._to_completions(request_output, request_output.prompt)
+
+ assert len(completions) == 2
+ # First completion has tool call
+ assert completions[0].has_tool_calls
+ assert completions[0].tool_calls[0].function.name == "calculator"
+ args = json.loads(completions[0].tool_calls[0].function.arguments)
+ assert args["equation"] == "1 + 1"
+ # Second completion has no tool call
+ assert not completions[1].has_tool_calls
+ assert completions[1].content == "The answer is obviously 2."