Conversation
Signed-off-by: Rohan Joshi <rohjoshi@nvidia.com>
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
📝 Walkthrough

This PR introduces support for 2:4 structured sparsity in attention mechanisms by adding new configuration options and implementing conditional sparsity mask computation. Two new pre-defined configurations enable 2:4 sparsity with skip-softmax, optionally combined with calibration.

Changes
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ Passed checks (4 passed)
Codecov Report: ❌ Patch coverage is
Additional details and impacted files

```diff
@@           Coverage Diff            @@
##             main    #1019      +/- ##
==========================================
- Coverage   70.25%   70.22%   -0.04%
==========================================
  Files         220      220
  Lines       25368    25391      +23
==========================================
+ Hits        17822    17830       +8
- Misses       7546     7561      +15
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 3
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/sparsity/attention_sparsity/config.py`:
- Around line 104-111: When apply_sparse24 is True, bc must be divisible by 4; add
a cross-field validation in the attention sparsity config (the model config
class in this file) to enforce that when apply_sparse24 is True then bc % 4 == 0
(and optionally bc >= 4). Implement this as a Pydantic root_validator (or the
class's post-init check) that inspects values['apply_sparse24'] and values['bc']
and raises a ValueError with a clear message if the condition fails; reference
the apply_sparse24 and bc fields in the validator so invalid configs are
rejected early.
- Around line 96-102: The config flag skip_diagonal_blocks is declared but never
used; wire it into FlashSkipSoftmax by reading it from method_config in
FlashSkipSoftmax.__init__ and storing it as self.skip_diagonal_blocks, then
propagate it into the mask construction functions (sparse24_mask_along_last_dim
and calc_correction_factor_and_p). Specifically, update
FlashSkipSoftmax.__init__ to extract
SparseAttentionAttributeConfig.skip_diagonal_blocks, pass that boolean into
sparse24_mask_along_last_dim, and in sparse24_mask_along_last_dim modify the
mask generation to force full (dense) tiles for diagonal tile indices (i.e.,
where query_tile_idx == key_tile_idx) when self.skip_diagonal_blocks is True;
finally adjust calc_correction_factor_and_p to compute correction factors and p
using the effective number of sparse elements after preserving diagonal tiles so
the probability/count math remains correct. If you prefer not to implement
behavior now, alternatively mark skip_diagonal_blocks as reserved by updating
SparseAttentionAttributeConfig docstring and removing any expectation that code
reads it.
In `@modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py`:
- Around line 217-225: The reported sparsity uses block_mask before the optional
2:4 pruning, so when apply_sparse24 is true the final mask (element_mask after
AND with sparse24_mask from sparse24_mask_along_last_dim) is sparser than
reported; update the stats computation to reflect post-AND rates by computing
element-level sparsity from the final element_mask (or add two fields, e.g.,
block_sparsity computed from block_mask and element_sparsity computed from
element_mask) and populate stats accordingly where stats["sparsity"] is set so
callers get accurate post-sparse24 metrics.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 1d7d1c0e-1264-43a9-a838-5e7a5c47cd32
📒 Files selected for processing (3)
- examples/llm_sparsity/attention_sparsity/hf_sa.py
- modelopt/torch/sparsity/attention_sparsity/config.py
- modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
```python
skip_diagonal_blocks: bool = ModeloptField(
    default=True,
    title="Skip diagonal blocks.",
    description=(
        "When True, keep diagonal tiles dense for 2:4 sparse attention. Defaults to True."
    ),
)
```
🧩 Analysis chain
🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Verify whether skip_diagonal_blocks is consumed beyond config declarations/presets.
rg -n --type=py '\bskip_diagonal_blocks\b' -C2
```

Repository: NVIDIA/Model-Optimizer
Length of output: 461

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Find the sparse24 mask implementation and related files
find modelopt/torch/sparsity/attention_sparsity -name "*.py" -type f | head -20
```

Repository: NVIDIA/Model-Optimizer
Length of output: 1121

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Look at the config.py file to understand the class structure
cat -n modelopt/torch/sparsity/attention_sparsity/config.py | head -120
```

Repository: NVIDIA/Model-Optimizer
Length of output: 5112

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Search for mask construction or application logic in the sparse attention module
rg -n 'def.*mask' --type=py modelopt/torch/sparsity/attention_sparsity/ -A 5
```

Repository: NVIDIA/Model-Optimizer
Length of output: 755

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Read the flash_skip_softmax.py file to see the complete implementation
wc -l modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 142

🏁 Script executed:

```shell
#!/bin/bash
set -euo pipefail
# Read the flash_skip_softmax.py file
cat -n modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
```

Repository: NVIDIA/Model-Optimizer
Length of output: 19968
skip_diagonal_blocks is defined in config but not implemented.
The field is exposed in SparseAttentionAttributeConfig (lines 96-102) with documented intent to "keep diagonal tiles dense for 2:4 sparse attention," but:

- It is never extracted from `method_config` in `FlashSkipSoftmax.__init__`
- It is not referenced in the sparse24 mask application logic (`calc_correction_factor_and_p`, lines 218–224, 274–278)
- `sparse24_mask_along_last_dim` applies uniform 2:4 sparsity without diagonal block preservation
Either wire the flag into the mask construction or document it as reserved for future use.
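If the flag is wired in, one way to keep diagonal tiles dense is to OR a diagonal-tile pattern into the element mask after the 2:4 mask has been applied. A minimal sketch; the function name, signature, and the idea of passing a tile size are illustrative assumptions, not the repo's API:

```python
import torch

def preserve_diagonal_tiles(mask: torch.Tensor, tile_size: int) -> torch.Tensor:
    """Force tiles on the main diagonal of a (..., seq_q, seq_k) mask to stay dense.

    A True entry means the position is kept (not pruned).
    """
    seq_q, seq_k = mask.shape[-2:]
    # Tile index of each query/key position along the last two dims
    q_tile = torch.arange(seq_q, device=mask.device).unsqueeze(-1) // tile_size
    k_tile = torch.arange(seq_k, device=mask.device).unsqueeze(0) // tile_size
    # OR in the diagonal-tile pattern so those tiles become fully dense
    return mask | (q_tile == k_tile)
```

As the review notes, the correction-factor math would then need to count only the off-diagonal positions as candidates for 2:4 pruning.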
```python
apply_sparse24: bool = ModeloptField(
    default=False,
    title="Apply 2:4 structured sparsity.",
    description=(
        "If True, additionally apply 2:4 structured sparsity (top-2 of every 4 elements "
        "along seq_k) on top of the skip-softmax block mask. Only used by flash_skip_softmax."
    ),
)
```
Add cross-field validation for apply_sparse24 and bc.
apply_sparse24=True currently accepts any positive bc, but sparse24 masking requires bc to be divisible by 4. Invalid bc values can pass config validation and then fail later at runtime.
Suggested fix
```diff
-from pydantic import Field, field_validator
+from pydantic import Field, field_validator, model_validator
@@
 class SparseAttentionAttributeConfig(ModeloptBaseConfig):
@@
     apply_sparse24: bool = ModeloptField(
         default=False,
@@
     )
+
+    @model_validator(mode="after")
+    def validate_sparse24_requirements(self):
+        if self.apply_sparse24 and self.bc % 4 != 0:
+            raise ValueError("bc must be divisible by 4 when apply_sparse24=True")
+        return self
```
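As a sanity check, the suggested validator pattern can be exercised standalone with Pydantic v2; the `SparseAttnConfig` class below is a hypothetical stand-in for the real config class, not the repository's code:

```python
from pydantic import BaseModel, ValidationError, model_validator

class SparseAttnConfig(BaseModel):
    """Hypothetical stand-in for SparseAttentionAttributeConfig."""

    bc: int = 64
    apply_sparse24: bool = False

    @model_validator(mode="after")
    def validate_sparse24_requirements(self):
        # 2:4 masking groups elements along seq_k in fours, so bc must divide evenly
        if self.apply_sparse24 and self.bc % 4 != 0:
            raise ValueError("bc must be divisible by 4 when apply_sparse24=True")
        return self

# A conforming config is accepted; a non-conforming one is rejected at construction.
SparseAttnConfig(bc=64, apply_sparse24=True)
try:
    SparseAttnConfig(bc=30, apply_sparse24=True)
except ValidationError as e:
    print("rejected:", e.errors()[0]["msg"])
```

Rejecting the config at construction time surfaces the error where the user wrote it, rather than deep inside the mask reshape at runtime.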
```python
# Step 7b: Apply 2:4 structured sparsity on top of block mask (optional)
if self.apply_sparse24:
    attn_padded = blocked_attn.reshape(
        batch_size, num_heads, padded_seq_q, padded_seq_k
    )
    sparse24_mask = sparse24_mask_along_last_dim(attn_padded)
    sparse24_mask = sparse24_mask[:, :, :seq_q, :seq_k]
    element_mask = element_mask & sparse24_mask
```
Reported sparsity no longer matches the final mask when sparse24 is enabled.
At Line 224 and Line 278, element_mask is further pruned with sparse24_mask, but the returned stats["sparsity"] is still computed from block_mask (pre-sparse24). This under-reports effective sparsity in sparse24 modes and can skew analysis/calibration interpretation.
Please compute and report a post-AND metric (or add separate block_sparsity and element_sparsity fields).
Also applies to: 273-279
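A sketch of the suggested fix: report both block-level and element-level rates, each computed from the mask actually produced at that stage. The function and field names are illustrative suggestions, not the repo's existing schema:

```python
import torch

def sparsity_stats(block_mask: torch.Tensor, element_mask: torch.Tensor) -> dict:
    """Sparsity = fraction of positions dropped (False) in each mask.

    block_mask: the skip-softmax block mask (pre-2:4 pruning).
    element_mask: the final mask after AND-ing in the optional 2:4 mask.
    """
    return {
        "block_sparsity": 1.0 - block_mask.float().mean().item(),
        "element_sparsity": 1.0 - element_mask.float().mean().item(),
    }
```

With apply_sparse24 enabled, `element_sparsity` is at least as high as `block_sparsity`, and it is the number calibration should consume.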
Summary
Adds an apply_sparse24: bool config option to the existing flash_skip_softmax method. When enabled, a 2:4 structured sparsity mask (top-2 of every 4 elements along seq_k) is AND-ed with the skip-softmax block mask in both prefill and decode phases.
This is a pure PyTorch-level feature for research and analysis — not a performance optimization. It allows studying the interaction between block-level and 2:4 structured sparsity patterns.
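The 2:4 masking step can be sketched in plain PyTorch. The function name matches the `sparse24_mask_along_last_dim` referenced in the review, but the body is an illustrative reconstruction, not the repository's implementation; it assumes the last dimension is divisible by 4:

```python
import torch

def sparse24_mask_along_last_dim(scores: torch.Tensor) -> torch.Tensor:
    """Boolean mask keeping the top-2 of every contiguous group of 4 along the last dim."""
    *lead, n = scores.shape
    groups = scores.reshape(*lead, n // 4, 4)
    # Indices of the two largest entries inside each group of 4
    top2 = groups.topk(k=2, dim=-1).indices
    mask = torch.zeros_like(groups, dtype=torch.bool)
    mask.scatter_(-1, top2, True)
    return mask.reshape(*lead, n)
```

AND-ing this with the skip-softmax element mask leaves at most 2 surviving positions per group of 4 along seq_k, which is what makes the block usable for studying block-level and 2:4 pattern interaction.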
Changes
Summary by CodeRabbit
- New configurations `sparse24_skip_softmax` and `sparse24_skip_softmax_calib` for enhanced sparsity patterns.