Skip to content
85 changes: 85 additions & 0 deletions src/forge/actors/vllm/v1/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,15 @@
from torchstore.api import _controller as get_torchstore_controller
from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.llm import UsageContext
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ExtractedToolCallInformation,
ToolCall,
)
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tool_parsers import ToolParserManager
from vllm.v1.engine.async_llm import AsyncLLM

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -78,6 +85,7 @@ class Generator(ForgeActor):
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
prefetch_weights_to_shm: bool = True
n_fetcher_procs: int = 8
tool_call_parser: str | None = None

def __post_init__(self):
super().__init__()
Expand All @@ -91,6 +99,8 @@ def __post_init__(self):
self.engine_args = EngineArgs(**self.engine_args)
self.vllm_config = self.engine_args.create_engine_config(UsageContext.LLM_CLASS)

self._tool_parser = None # Will hold ToolParser instance if configured

if isinstance(self.sampling_params, Mapping):
self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
self.sampling_params.output_kind = RequestOutputKind.FINAL_ONLY
Expand Down Expand Up @@ -273,9 +283,42 @@ async def setup(self, host_mesh, worker_registry, gpu_ids: list[str]):
)
logger.info(f"Retrieved workers from registry: {self.workers}")

if self.tool_call_parser is not None:
self._tool_parser = self._init_tool_parser()

if self.prefetch_weights_to_shm:
self._spawn_fetchers()

def _init_tool_parser(self, tokenizer=None):  # type: ignore[no-untyped-def]
    """Build and return the configured ToolParser instance.

    Args:
        tokenizer: Optional tokenizer (with encode/decode methods). When omitted,
            one is derived from ``vllm_config``; injecting one explicitly is
            useful for testing.

    Returns:
        A ready-to-use ToolParser instance, or None when the named parser is
        unknown or construction fails (both cases are logged, never raised).
    """
    try:
        # Tokenizer creation stays inside the try: a bad model config should
        # degrade to "no tool parsing", not crash generator setup.
        tok = (
            cached_tokenizer_from_config(
                model_config=self.vllm_config.model_config,
            )
            if tokenizer is None
            else tokenizer
        )
        parser = ToolParserManager.get_tool_parser(self.tool_call_parser)(tok)  # type: ignore[union-attr]
    except KeyError:
        # get_tool_parser raises KeyError for an unregistered parser name.
        known = list(ToolParserManager.tool_parsers.keys())
        logger.error(
            f"Unknown tool parser: '{self.tool_call_parser}'. "
            f"Available parsers: {known}"
        )
        return None
    except Exception as e:
        logger.error(f"Failed to initialize tool parser: {e}")
        return None
    logger.info(f"Initialized tool parser: {self.tool_call_parser}")
    return parser

def _spawn_fetchers(self):
"""Spawn weight fetchers that prefetch weights from torchstore to shared memory.

Expand Down Expand Up @@ -545,6 +588,38 @@ def _extract_logprobs(self, output) -> torch.Tensor | None:
)
return None

def _extract_tool_calls(self, model_output: str) -> ExtractedToolCallInformation:
    """Run the configured tool parser over raw model text.

    Args:
        model_output: Raw text output from the model.

    Returns:
        ExtractedToolCallInformation with parsed tool calls and remaining
        content. When no parser is configured, or parsing fails, the whole
        output is returned as plain content with no tool calls.
    """
    if self._tool_parser is None:
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    try:
        # The parser API requires a ChatCompletionRequest; only the model
        # field matters here, so fabricate a minimal one.
        request = ChatCompletionRequest(
            model=self.vllm_config.model_config.model,
            messages=[{"role": "user", "content": ""}],
            seed=42,  # to calm the linter
        )
        return self._tool_parser.extract_tool_calls(model_output, request)
    except Exception as e:
        # Malformed tool syntax should not kill generation; fall back to text.
        logger.warning(f"Failed to parse tool calls: {e}")
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

def _to_completions(
self, request_output: RequestOutput, prompt: str
) -> list[Completion]:
Expand All @@ -560,6 +635,14 @@ def _to_completions(
completions = []

for output in request_output.outputs:
tool_calls: list[ToolCall] = []
content: str | None = None

if self._tool_parser is not None:
extracted = self._extract_tool_calls(output.text)
tool_calls = extracted.tool_calls
content = extracted.content

completion = Completion(
prompt=to_prompt(prompt),
text=output.text,
Expand All @@ -575,6 +658,8 @@ def _to_completions(
stop_reason=output.finish_reason,
generator_version=self.generator_version,
metadata={"num_cached_tokens": request_output.num_cached_tokens},
tool_calls=tool_calls,
content=content,
)
completions.append(completion)

Expand Down
14 changes: 13 additions & 1 deletion src/forge/data_models/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any

import torch
from forge.data_models.prompt import Prompt
from vllm.entrypoints.openai.protocol import ToolCall


@dataclass
Expand Down Expand Up @@ -38,3 +39,14 @@ class Completion:

# extra information that might be useful for debugging
metadata: dict[str, Any] | None = None

tool_calls: list[ToolCall] = field(default_factory=list)

# When tool parsing is enabled, this contains content outside of tool tags
# i.e. content before the tool calls
content: str | None = None

@property
def has_tool_calls(self) -> bool:
    """True when at least one tool call was parsed from this completion."""
    return bool(self.tool_calls)
198 changes: 198 additions & 0 deletions tests/integration_tests/test_tool_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Integration tests for vLLM tool parsing in forge.

Tests the full tool-calling workflow: model generates tool call -> parse -> execute -> return result.

Requires GPU access.

Run:
pytest tests/integration_tests/test_tool_parsing.py -v -s
"""

import ast
import json
import logging
import operator

import pytest
import pytest_asyncio
import torch
from forge.actors.generator import Generator
from huggingface_hub import snapshot_download
from vllm.tokenizers import get_tokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# All async tests and fixtures in this module share a single module-scoped event loop.
# This is required because the `policy` fixture is scope="module" - without this,
# each test would get its own function-scoped loop and hang when awaiting the
# module-scoped fixture's objects.
pytestmark = pytest.mark.asyncio(loop_scope="module")

requires_cuda = pytest.mark.skipif(
not torch.cuda.is_available(),
reason="CUDA not available",
)

MODEL_NAME = "Qwen/Qwen3-0.6B"

TOOLS = [
{
"type": "function",
"function": {
"name": "calculator",
"description": "Evaluate a mathematical equation.",
"parameters": {
"type": "object",
"properties": {
"equation": {
"type": "string",
"description": "The mathematical equation to evaluate",
},
},
"required": ["equation"],
},
},
},
]


def calculator(equation: str) -> str:
    """Safely evaluate a basic arithmetic equation.

    Supports +, -, *, /, unary +/-, parentheses, and ^ (treated as
    exponentiation). Evaluation walks a restricted AST instead of calling
    eval(), so no names, calls, or attribute access can ever execute even
    if the character filter were bypassed.

    Args:
        equation: Expression such as "123 + 456" or "2^10".

    Returns:
        The numeric result as a string, or an "Error: ..." message for
        invalid characters, syntax errors, or math errors (e.g. divide by 0).
    """
    # Reject anything outside the arithmetic alphabet up front.
    allowed = set("0123456789+-*/().^ ")
    if not all(c in allowed for c in equation):
        return "Error: Invalid characters in equation"

    bin_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def _eval(node: ast.AST):
        # Only numeric literals and whitelisted operators are permitted.
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in bin_ops:
            return bin_ops[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](_eval(node.operand))
        raise ValueError("Unsupported expression")

    try:
        # ^ is the user-facing power operator; Python spells it **.
        tree = ast.parse(equation.replace("^", "**"), mode="eval")
        return str(_eval(tree))
    except Exception as e:
        return f"Error: {e}"


@pytest.fixture(scope="module")
def model_path():
"""Download model once for all tests in this module."""
logger.info(f"Downloading model checkpoint: {MODEL_NAME}")
cached_dir = snapshot_download(repo_id=MODEL_NAME)
logger.info(f"Model downloaded to: {cached_dir}")
return cached_dir


@pytest.fixture(scope="module")
def tokenizer():
"""Create tokenizer once for all tests in this module."""
return get_tokenizer(MODEL_NAME)


@pytest_asyncio.fixture(scope="module", loop_scope="module")
async def policy(model_path):
    """Spin up one policy service for the module, tearing it down at the end."""
    logger.info("Setting up policy service...")
    service = await Generator.options(
        procs=1,
        num_replicas=1,
        with_gpus=True,
    ).as_service(
        engine_args={"model": model_path},
        sampling_params={"n": 1, "max_tokens": 256},
        tool_call_parser="hermes",  # hermes parser matches Qwen's tool-tag format
    )

    yield service

    # Teardown runs after the last test in the module finishes.
    logger.info("Shutting down policy service...")
    await service.shutdown()


@requires_cuda
async def test_tool_parsing_multi_turn(policy, tokenizer):
    """
    Multi-turn conversation: tool call -> execute -> feed result back -> final answer.
    """
    conversation = [
        {
            "role": "system",
            "content": "/no_think Use the calculator tool for math.",
        },
        {"role": "user", "content": "Calculate 123 + 456"},
    ]

    # Turn 1: the model should respond with a calculator tool call.
    prompt = tokenizer.apply_chat_template(
        conversation, tools=TOOLS, tokenize=False, add_generation_prompt=True
    )
    first = (await policy.generate.route(prompt))[0]

    assert first.has_tool_calls, "Expected tool calls"
    call = first.tool_calls[0]
    parsed_args = json.loads(call.function.arguments)
    tool_result = calculator(parsed_args["equation"])

    # Record the assistant turn and the tool's answer in the transcript.
    conversation += [
        {
            "role": "assistant",
            "content": first.text,
        },
        {
            "role": "tool",
            "tool_call_id": call.id,
            "content": tool_result,
        },
    ]

    # Turn 2: with the tool result in context, expect the final numeric answer.
    prompt = tokenizer.apply_chat_template(
        conversation, tools=TOOLS, tokenize=False, add_generation_prompt=True
    )
    final = (await policy.generate.route(prompt))[0]

    logger.info(f"Final answer: {final.text}")
    assert "579" in final.text, "Expected 123 + 456 = 579"


@requires_cuda
async def test_content_without_tool_calls(policy, tokenizer):
    """
    Test that content equals text when no tool calls are made.

    When a request doesn't trigger tool usage, the completion's content
    field should equal the raw text output.
    """
    # A general-knowledge question that should not invoke the calculator tool.
    chat = [
        {
            "role": "system",
            "content": "/no_think You are a helpful assistant.",
        },
        {"role": "user", "content": "What is the capital of France?"},
    ]

    prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )

    result = (await policy.generate.route(prompt))[0]

    logger.info(f"Response text: {result.text}")
    logger.info(f"Response content: {result.content}")

    assert result.tool_calls == [], "Should have no tool calls"
    assert result.content is not None, "Should have content when no tools called"
    assert (
        result.content == result.text
    ), "Content should equal text when no tools"
Loading
Loading