Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions examples/specdec_bench/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"SPECBENCH_MEDUSA": models.SpecBenchMedusaModel,
}
datasets_available = {
"humaneval": datasets.HumanEval,
"mtbench": datasets.MTBench,
"random": datasets.RandomToken,
"specbench": datasets.SpecBench,
Expand Down Expand Up @@ -157,6 +158,7 @@ def run_simple(args):
tensor_parallel_size=args.tp_size,
moe_expert_parallel_size=args.ep_size,
trust_remote_code=args.trust_remote_code,
parallel_drafting=args.parallel_drafting,
**engine_args,
)

Expand Down Expand Up @@ -286,6 +288,7 @@ def run_simple(args):
"--output_length", type=int, required=False, default=4096, help="Output length"
)
parser.add_argument("--draft_length", type=int, required=False, default=3, help="Draft length")
parser.add_argument("--parallel_drafting", action="store_true", help="Enable parallel drafting")
parser.add_argument(
"--tp_size", type=int, required=False, default=4, help="Tensor parallel size"
)
Expand Down
3 changes: 2 additions & 1 deletion examples/specdec_bench/specdec_bench/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .humaneval import HumanEval
from .mtbench import MTBench
from .random_token import RandomToken
from .specbench import SpecBench
from .speed import SPEEDBench

__all__ = ["MTBench", "RandomToken", "SPEEDBench", "SpecBench"]
__all__ = ["HumanEval", "MTBench", "RandomToken", "SPEEDBench", "SpecBench"]
34 changes: 34 additions & 0 deletions examples/specdec_bench/specdec_bench/datasets/humaneval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset
Comment thread
benchislett marked this conversation as resolved.
Outdated

from .base import Dataset, Request

def format_prompt(prompt: str) -> str:
    """Prepend the code-completion instruction to a raw HumanEval prompt."""
    instruction = "Complete the following Python function. Only output the code, no explanations."
    return f"{instruction}\n\n{prompt}"

class HumanEval(Dataset):
    """HumanEval code-completion dataset: one single-turn request per problem."""

    def __init__(self, path, num_samples=164, **kwargs):
        # Flat list of Request objects, one per HumanEval problem.
        self.data: list[Request] = []
        self.num_samples = num_samples
        self._preprocess(path)

    def _preprocess(self, path: str):
        """Load the 'test' split from *path* and keep the first num_samples prompts."""
        records = load_dataset(path, split='test')
        requests = [
            Request(system_prompt=None, turns=[format_prompt(record["prompt"])])
            for record in records
        ]
        self.data = requests[: self.num_samples]
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def process_final(self, text_outputs):
for request_id, turns in self.prompt_ar.items():
self.out["Request_AR"][request_id] = {}
for turn_id, turn in turns.items():
if len(turn) > 1 and turn[0] <= 1:
turn = turn[1:] # Skip prefill if it is 1 or less, indicating no specdec
if len(turn) > 1:
turn = turn[:-1] # Skip final acceptance due to EOS truncating speculation
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only skip if EOS is present?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's reasonable to skip anyways, since truncation for any reason (EOS, length, stop token) might misrepresent the AR since it is not aligned with the draft size per step

ar = sum(turn) / len(turn)
self.out["Request_AR"][request_id][turn_id] = ar
all_ar.append(ar)
Expand Down
3 changes: 3 additions & 0 deletions examples/specdec_bench/specdec_bench/models/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs
}
elif kwargs.get("speculative_algorithm") == "NONE":
specdec = None

if kwargs.get("parallel_drafting") and specdec is not None:
specdec["parallel_drafting"] = True
Comment thread
benchislett marked this conversation as resolved.

if specdec is None:
num_speculative_tokens = 1
Expand Down
Loading