Refactor: Clean up EAGLE training dataset preparation #684
benchislett wants to merge 3 commits into `main`
Conversation
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #684      +/-   ##
==========================================
- Coverage   70.10%   70.09%    -0.02%
==========================================
  Files         221      221
  Lines       25541    25541
==========================================
- Hits        17905    17902        -3
- Misses       7636     7639        +3
```

View full report in Codecov by Sentry.
```yaml
sources:
  - name: "sharegpt"
    splits:
      all: 0
```
0 means "no samples". This is just an example template file showing which datasets are supported, so include most of the options even if they aren't used by default
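For illustration, a fuller version of that template might look like the sketch below. Only the `sources`/`name`/`splits` keys and `global_limit` appear in this PR's discussion; the other values are made-up placeholders, so check `example_data_config.yaml` for the real schema:

```yaml
# Hypothetical sketch of a data config; see example_data_config.yaml for the real schema.
sources:
  - name: "sharegpt"
    splits:
      all: 0            # 0 = contribute no samples by default
  - name: "ultrachat"
    splits:
      train_sft: 1000   # take 1000 rows from this split
      test_sft: "all"   # "all" = no per-split cap
global_limit: 5000      # optional cap applied after merging all sources
```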
Can you rebase to main as there are some conflicts?
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
📝 Walkthrough

This change consolidates multiple dataset-specific preparation scripts into a single unified dataset builder. Individual scripts for Daring-Anteater, MTBench, ShareGPT, and UltraChat are removed and replaced with `make_dataset.py`.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant MakeDataset as make_dataset.py
    participant ConfigLoader as Config Loader
    participant AsyncLoaders as Async Loaders
    participant Dedup as Deduplication
    participant OutputWriter as Output Writer
    User->>MakeDataset: python make_dataset.py --config-file example_data_config.yaml
    MakeDataset->>ConfigLoader: load_conversations_for_split() for each source
    ConfigLoader->>AsyncLoaders: Initialize dataset-specific loaders<br/>(MTBench, ShareGPT, UltraChat, etc.)
    loop For each dataset source in config
        AsyncLoaders->>AsyncLoaders: Download & parse dataset
        AsyncLoaders->>AsyncLoaders: Filter/normalize records
        AsyncLoaders->>MakeDataset: Yield conversation objects
    end
    MakeDataset->>Dedup: Aggregate samples from all sources
    Dedup->>Dedup: Deduplicate by truncated ID
    Dedup->>Dedup: Apply per-split constraints
    Dedup->>Dedup: Apply global_limit subsampling
    Dedup->>OutputWriter: Shuffle and prepare batches
    OutputWriter->>OutputWriter: Generate conversation_id if missing
    OutputWriter->>User: Write JSONL to train.jsonl
```
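The flow above can be sketched in plain Python. This is a toy stand-in, not the real `make_dataset.py`: `load_source`, this `id_for_conversation`, and the inline source lists are all made up for illustration.

```python
import asyncio
import hashlib
import json
import random


def id_for_conversation(msgs):
    # Stable hash of the turns; stands in for the real helper in utils.
    return hashlib.sha256(json.dumps(msgs, sort_keys=True).encode()).hexdigest()


async def load_source(name, prompts):
    # Stand-in for a dataset-specific async loader: yields conversation dicts.
    for prompt in prompts:
        msgs = [{"role": "user", "content": prompt}]
        yield {
            "conversation_id": f"{name}-{id_for_conversation(msgs)}",
            "conversations": msgs,
        }


async def build_dataset(sources, out_path, seed=42):
    # Aggregate all sources, deduplicate on conversation content,
    # shuffle deterministically, and write one JSONL file.
    seen, samples = set(), []
    for name, prompts in sources.items():
        async for sample in load_source(name, prompts):
            dedup_id = id_for_conversation(sample["conversations"])
            if dedup_id in seen:
                continue
            seen.add(dedup_id)
            samples.append(sample)
    random.Random(seed).shuffle(samples)
    with open(out_path, "w") as f:
        for s in samples:
            f.write(json.dumps(s) + "\n")
    return len(samples)


# "Hi" appears in two sources but survives only once after dedup.
n = asyncio.run(
    build_dataset({"sharegpt": ["Hi", "Hello"], "mtbench": ["Hi"]}, "train.jsonl")
)
print(n)  # → 2
```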
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks: ❌ Failed checks (1 warning) | ✅ Passed checks (3 passed)
Actionable comments posted: 4
🧹 Nitpick comments (1)
examples/speculative_decoding/prepare_input_conversations/make_dataset.py (1)
75-82: Normalize `splits` into a dict in the type system too.

After `__post_init__`, the rest of the file only treats `splits` as a mapping, but the field is still annotated as `list[str] | dict[...]`. Line 468 calls `source.splits.items()` unconditionally, so the runtime normalization works while the static contract stays wrong. A factory/classmethod that converts list input before instantiation would keep this API honest and avoid a mypy false branch here.

As per coding guidelines, "Use mypy for type checking on Python code (configured in `pyproject.toml`)."
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/speculative_decoding/prepare_input_conversations/make_dataset.py`:
- Around line 427-457: The current deduplication_ids set is local to
load_conversations_for_split (created where deduplication_ids is assigned) so
identical conversations from different splits/datasets can bypass dedup when
later concatenated; to fix, move the seen-set to the per-output scope and pass
it into load_conversations_for_split (or deduplicate after gathering split
results) so id_for_conversation(truncated_conversations) is checked against a
shared deduplication_ids before adding to unique_samples and before applying
global_limit/max_samples_for_constraint; update callers to accept a shared
deduplication_ids and ensure the check/update occurs prior to enforcing
global_limit.
- Around line 114-121: The function check_row_constraint performs a numeric
comparison before validating the type, causing TypeError for string typos; fix
it by first handling the special string "all", then validate that constraint is
an instance of (int, float) before doing constraint < 0, and if it's not a
numeric type raise a ValueError with a clear message; ensure
check_row_constraint still returns None for "all", returns the numeric value for
valid ints/floats (after the negative check), and raises for any other types
instead of letting a TypeError bubble up.
- Around line 272-287: In _load_ultrachat_conversations, the function currently
ignores the dataset's "messages" field and always emits a single-turn
conversation from "prompt"; update it to check ds[i].get("messages") and when
present and when the caller requested full conversations (honor the
--full-conversations flag by reading the relevant boolean variable or
parameter), convert those messages into the same [{"role": "...", "content":
"..."}] structure (mapping dataset speaker names to "user"/"assistant" as
appropriate), use that list as msgs, and only fall back to using the single
prompt (as today) when "messages" is missing/empty or full-conversations is
false; keep using id_for_conversation to generate prompt_id when needed and
preserve the prompt_id naming convention f"ultrachat-{split_name}-{prompt_id}"
and the existing length/yield semantics and final logger call.
In `@examples/speculative_decoding/README.md`:
- Around line 204-205: Update the README sentence that points readers to
make_dataset.py so it instead directs them to the example_data_config.yaml (the
config used by the new CLI flow) and briefly note that dataset selection is
managed via that YAML config and the CLI rather than by editing the script;
change the reference text in examples/speculative_decoding/README.md to mention
example_data_config.yaml and the CLI workflow.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 382c764d-1729-4360-aea6-a00065c567ad
📒 Files selected for processing (10)
- examples/speculative_decoding/README.md
- examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
- examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
- examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py
- examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
- examples/speculative_decoding/prepare_input_conversations/example_data_config.yaml
- examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
- examples/speculative_decoding/prepare_input_conversations/make_dataset.py
- examples/speculative_decoding/prepare_input_conversations/utils.py
- examples/speculative_decoding/train_eagle3_and_export.sh
💤 Files with no reviewable changes (6)
- examples/speculative_decoding/prepare_input_conversations/add_ultrachat.py
- examples/speculative_decoding/prepare_input_conversations/add_sharegpt.py
- examples/speculative_decoding/prepare_input_conversations/add_mtbench.py
- examples/speculative_decoding/prepare_input_conversations/utils.py
- examples/speculative_decoding/prepare_input_conversations/add_daring_anteater.py
- examples/speculative_decoding/prepare_input_conversations/example_make_prompt_dataset.sh
```python
def check_row_constraint(constraint) -> int | float | None:
    if constraint == "all":
        return None
    if constraint < 0:
        raise ValueError("Number of samples to use for a split cannot be negative.")
    if isinstance(constraint, (float, int)):
        return constraint
    return 0
```
Validate the constraint type before comparing it.

`constraint < 0` runs before the type check, so a typo like `"50"` or `"half"` in the YAML raises `TypeError` instead of a clear config error. This is user-facing input, so fail fast with explicit validation here.

🛠️ Proposed fix

```diff
-def check_row_constraint(constraint) -> int | float | None:
+def check_row_constraint(constraint: int | float | str | None) -> int | float | None:
     if constraint == "all":
         return None
+    if isinstance(constraint, bool) or not isinstance(constraint, (int, float)):
+        raise ValueError("Split constraint must be a non-negative int, float, or 'all'.")
     if constraint < 0:
         raise ValueError("Number of samples to use for a split cannot be negative.")
-    if isinstance(constraint, (float, int)):
-        return constraint
-    return 0
+    return constraint
```
```python
async def _load_ultrachat_conversations(
    split_name: str,
) -> AsyncGenerator[int | dict[str, Any], None]:
    ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split_name)
    ds = ds.shuffle(seed=42)
    yield len(ds)
    for i in range(len(ds)):
        prompt = ds[i]["prompt"].strip()
        prompt_id = ds[i]["prompt_id"].strip()
        if prompt:
            msgs = [{"role": "user", "content": prompt}]
            if not prompt_id:
                prompt_id = id_for_conversation(msgs)
            prompt_id = f"ultrachat-{split_name}-{prompt_id}"
            yield {"conversation_id": prompt_id, "conversations": msgs}
    logger.info(f"Finished loading UltraChat {split_name} conversations.")
```
🧩 Analysis chain

The reviewer verified the finding by searching the repository (e.g. `rg "full_conversations"`, `rg "_load_ultrachat_conversations"`, `rg "def id_for_conversation"`) and reading the surrounding sections of `examples/speculative_decoding/prepare_input_conversations/make_dataset.py`, then confirmed the dataset schema:

🌐 Web query: HuggingFaceH4/ultrachat_200k dataset schema structure fields

💡 Result: HuggingFaceH4/ultrachat_200k entries are stored (Parquet) with this schema/field structure:

- `prompt`: string (the instruction / opening user prompt)
- `messages`: list of message objects, each with `content`: string and `role`: string (e.g., "user", "assistant")
- `prompt_id`: string (a 64-char id / hash)

The dataset provides 4 splits: train_sft, test_sft, train_gen, test_gen.
Use the `messages` field from the UltraChat dataset when available to honor the `--full-conversations` flag.

The `--full-conversations` flag controls whether to preserve full conversation history, but `_load_ultrachat_conversations` ignores the dataset's `messages` field and always emits only the initial prompt as a single-turn conversation. This makes the flag ineffective for UltraChat. Extract multi-turn conversations from the `messages` field when present, and fall back to prompt-only when `messages` is empty or unavailable.

Suggested fix

```diff
 ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split_name)
 ds = ds.shuffle(seed=42)
 yield len(ds)
 for i in range(len(ds)):
-    prompt = ds[i]["prompt"].strip()
-    prompt_id = ds[i]["prompt_id"].strip()
-    if prompt:
-        msgs = [{"role": "user", "content": prompt}]
-        if not prompt_id:
-            prompt_id = id_for_conversation(msgs)
-        prompt_id = f"ultrachat-{split_name}-{prompt_id}"
-        yield {"conversation_id": prompt_id, "conversations": msgs}
+    row = ds[i]
+    msgs = [
+        {"role": msg["role"], "content": msg["content"].strip()}
+        for msg in row.get("messages", [])
+        if msg.get("content", "").strip()
+    ]
+    if not msgs:
+        prompt = row["prompt"].strip()
+        if not prompt:
+            continue
+        msgs = [{"role": "user", "content": prompt}]
+
+    prompt_id = row.get("prompt_id")
+    if not prompt_id:
+        prompt_id = id_for_conversation(msgs)
+    prompt_id = f"ultrachat-{split_name}-{prompt_id}"
+    yield {"conversation_id": prompt_id, "conversations": msgs}
```
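As a standalone sketch, the messages-with-fallback logic can be exercised without downloading the dataset. `ultrachat_row_to_sample` and this `id_for_conversation` stand-in are illustrative, not the PR's code; the row shape follows the schema confirmed above (`prompt`, `prompt_id`, `messages`):

```python
import hashlib
import json


def id_for_conversation(msgs):
    # Stand-in for the helper in utils; any stable content hash works here.
    return hashlib.sha256(json.dumps(msgs, sort_keys=True).encode()).hexdigest()[:16]


def ultrachat_row_to_sample(row, split_name):
    # Prefer the multi-turn "messages" field, falling back to a
    # single-turn prompt when it is missing or empty.
    msgs = [
        {"role": m["role"], "content": m["content"].strip()}
        for m in row.get("messages", [])
        if m.get("content", "").strip()
    ]
    if not msgs:
        prompt = row.get("prompt", "").strip()
        if not prompt:
            return None  # nothing usable in this row
        msgs = [{"role": "user", "content": prompt}]
    prompt_id = row.get("prompt_id") or id_for_conversation(msgs)
    return {
        "conversation_id": f"ultrachat-{split_name}-{prompt_id}",
        "conversations": msgs,
    }


row = {
    "prompt": "Hi",
    "prompt_id": "abc",
    "messages": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
}
sample = ultrachat_row_to_sample(row, "train_sft")
print(sample["conversations"])  # both turns, not just the prompt
```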
```python
deduplication_ids = set()
unique_samples = []
max_num_samples = max_samples_for_constraint(num_samples, row_constraint)
async for sample in samples_it:
    assert isinstance(sample, dict) and "conversations" in sample, (
        "Each conversation sample must be a dict with a 'conversations' field."
    )

    # Strip the last turn of the conversation as long as it is an assistant completion,
    # since we want to use these conversations as prompts only.
    if strip_last_completion:
        while sample["conversations"] and sample["conversations"][-1]["role"] != "user":
            sample["conversations"].pop()

    if not sample["conversations"]:
        continue

    sample["source_dataset"] = dataset_name
    sample["source_split"] = split_name

    # Deduplicate based on the first 512 characters from each turn.
    # To avoid too many similar conversations with minor differences.
    truncated_conversations = [
        {"role": msg["role"], "content": msg["content"][0:512]}
        for msg in sample["conversations"]
    ]
    dedup_id = id_for_conversation(truncated_conversations)
    if dedup_id not in deduplication_ids:
        deduplication_ids.add(dedup_id)
        unique_samples.append(sample)
        if len(unique_samples) >= max_num_samples:
```
Deduplicate across the whole output, not just one split.

`deduplication_ids` is recreated on every `load_conversations_for_split()` call. Line 481 later concatenates many split results into one output file, so identical conversations from different splits or datasets can still survive into the same JSONL. If output uniqueness is a goal here, move the seen-set to the per-output scope and apply it before `global_limit`.
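A minimal sketch of the shared-seen-set approach (the function and helper names here are made up for illustration; the real deduplication lives inside `load_conversations_for_split`):

```python
import hashlib
import json


def id_for_conversation(msgs):
    # Stand-in for the real helper; stable hash of the turns.
    return hashlib.sha256(json.dumps(msgs, sort_keys=True).encode()).hexdigest()


def dedup_samples(samples, seen=None, truncate=512):
    # A shared `seen` set lets deduplication span every split/dataset that
    # feeds one output file, instead of being reset per split call.
    seen = set() if seen is None else seen
    unique = []
    for sample in samples:
        truncated = [
            {"role": m["role"], "content": m["content"][:truncate]}
            for m in sample["conversations"]
        ]
        dedup_id = id_for_conversation(truncated)
        if dedup_id not in seen:
            seen.add(dedup_id)
            unique.append(sample)
    return unique


seen = set()
split_a = [{"conversations": [{"role": "user", "content": "Hi"}]}]
split_b = [{"conversations": [{"role": "user", "content": "Hi"}]}]  # duplicate across splits
out = dedup_samples(split_a, seen) + dedup_samples(split_b, seen)
print(len(out))  # → 1, because the set survives across calls
```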
> In addition to the default dataset, we support adding several other commonly used datasets in `prepare_input_conversations/make_dataset.py`:
Point users to the YAML config here.
This sentence sends readers to make_dataset.py, but dataset selection in this workflow now lives in prepare_input_conversations/example_data_config.yaml. Pointing to the config file will match the new CLI flow and keep users out of the script.
📝 Suggested wording

```diff
-In addition to the default dataset, we support adding several other commonly used datasets in `prepare_input_conversations/make_dataset.py`:
+In addition to the default dataset, you can enable several other commonly used datasets in `prepare_input_conversations/example_data_config.yaml`:
```

📝 Committable suggestion

‼️ IMPORTANT: Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
What does this PR do?
Type of change: Refactor
Overview: consolidates the per-dataset preparation scripts into a single `make_dataset.py`.

Usage

See the README for a detailed example.
Testing
Ran it locally on all dataset modes, works successfully and output looks good. Checked shuffling, conversation IDs, and output contents were all unique and usable.
Summary by CodeRabbit

- Documentation
- New Features
- Refactor