
[1/n] Add a Triton attention kernel with HF integration (#1034)

Open
kaix-nv wants to merge 1 commit into `main` from `kaix/triton_kernel`

Conversation


@kaix-nv kaix-nv commented Mar 13, 2026

What does this PR do?

Type of change: ?

  • Adds a Triton flash attention kernel (triton_fa.py) with HF integration for use in sparse attention and quantization workflows. The kernel implements Flash Attention with varlen support, GQA, causal masking, and forward/backward passes.
  • Updates the sparse attention module to support backend="triton".

Key components:

  • modelopt/torch/kernels/triton_fa.py -- Core Triton kernel
  • modelopt/torch/kernels/hf_triton_attention.py -- HF adapter, registered as attn_implementation="modelopt_triton"
  • modelopt/torch/kernels/__init__.py -- Shared kernel registry
  • modelopt/torch/sparsity/attention_sparsity/conversion.py -- Backend selection (backend="triton" or "pytorch")
  • modelopt/torch/sparsity/attention_sparsity/config.py -- Added "triton" as valid backend option
  • examples/llm_sparsity/attention_sparsity/hf_sa.py -- Updated example to support --backend triton

Usage

# Direct kernel API (varlen packed format)
from modelopt.torch.kernels import attention

o = attention(
    q, k, v,  # [total_tokens, heads, head_dim]
    b_start_loc=b_start_loc,  # [batch] per-sequence start offsets
    b_seq_len=b_seq_len,      # [batch] per-sequence lengths
    max_input_len=max_seq_len,
    is_causal=True,
)

# HuggingFace integration (automatic via sparsify)
import modelopt.torch.sparsity.attention_sparsity as mtsa

config = {"sparse_cfg": {"*attn*": {"method": "flash_skip_softmax", "backend": "triton", "enable": True}}}
model = mtsa.sparsify(model, config=config)
# model now uses the Triton kernel for attention

# Or load directly with attn_implementation
model = AutoModelForCausalLM.from_pretrained(path, attn_implementation="modelopt_triton")
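
To illustrate the varlen packed format the direct kernel API expects, a right-padded HF-style batch can be converted roughly as follows. This is a hypothetical helper for illustration only (`pack_varlen` is not part of this PR):

```python
import torch

def pack_varlen(x: torch.Tensor, attention_mask: torch.Tensor):
    """Pack a right-padded batch [batch, seq, heads, dim] into the
    [total_tokens, heads, dim] layout plus per-sequence metadata."""
    b_seq_len = attention_mask.sum(dim=1).to(torch.int32)   # [batch] valid lengths
    b_start_loc = torch.zeros_like(b_seq_len)
    b_start_loc[1:] = torch.cumsum(b_seq_len, dim=0)[:-1]   # exclusive prefix sum
    # Concatenate only the valid (unpadded) tokens of each sequence
    packed = torch.cat([x[i, : int(b_seq_len[i])] for i in range(x.shape[0])], dim=0)
    return packed, b_start_loc, b_seq_len, int(b_seq_len.max())

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # two sequences, lengths 3 and 2
q = torch.randn(2, 4, 8, 64)                       # [batch, seq, heads, head_dim]
packed, starts, lens, max_len = pack_varlen(q, mask)
```

The resulting `packed`, `starts`, `lens`, and `max_len` map onto the kernel's `q`, `b_start_loc`, `b_seq_len`, and `max_input_len` arguments.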

Testing

tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ / ❌ / N/A
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ / ❌ / N/A
  • Did you write any new necessary tests?: ✅ / ❌ / N/A
  • Did you update Changelog?: ✅ / ❌ / N/A

Additional Information

@copy-pr-bot
Copy link

copy-pr-bot bot commented Mar 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 13, 2026

📝 Walkthrough

This PR introduces Triton-based flash attention support to the sparse attention module, with variable-length sequence kernels, GQA support, and autograd integration. It adds backend selection between PyTorch eager and Triton implementations, updates configuration and registration flows, and refactors sparse attention activation to use method-provided contexts.

Changes

Cohort / File(s) Summary
Triton Kernel Implementation
modelopt/torch/kernels/triton_fa.py
Implements variable-length flash attention kernels in Triton with forward pass (online softmax, Q @ K^T with causal/padding masking) and backward kernels (dQ, dK, dV computation with LSE-based gradient stabilization). Supports GQA configurations via head grouping. Exposes attention() API with autograd support.
HuggingFace Triton Integration
modelopt/torch/kernels/hf_triton_attention.py
Registers Triton attention as "modelopt_triton" backend for HuggingFace transformers. Converts between HF format [batch, heads, seq, dim] and kernel varlen format [total_tokens, heads, dim], derives sequence lengths from attention masks, handles prefill/decode modes, and provides triton_attention_forward() and register_triton_attention() public APIs.
Kernel Module Exposure
modelopt/torch/kernels/__init__.py, modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
Exposes Triton kernels at package level with lazy CUDA availability detection. Defines IS_AVAILABLE flag, imports attention kernel and registration hook. Includes backward-compatibility re-export module for downstream imports.
Sparse Attention Backend Selection
modelopt/torch/sparsity/attention_sparsity/config.py, modelopt/torch/sparsity/attention_sparsity/conversion.py
Updates configuration to accept both "pytorch" and "triton" backends; adds backend validation and documentation. Introduces _set_attn_implementation() helper to register Triton backend, validate registration, and set model.config._attn_implementation accordingly before applying sparse attention conversions.
Sparse Attention Method Framework
modelopt/torch/sparsity/attention_sparsity/methods/registry.py, modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
Adds get_sparse_context() contract to SparseAttentionMethod base class for per-method activation strategies. Implements context in FlashSkipSoftmax that patches F.softmax with sparse masking during non-calibration phases.
Sparse Attention Module Refactoring
modelopt/torch/sparsity/attention_sparsity/sparse_attention.py
Replaces hard-coded softmax patching with delegation to method-provided get_sparse_context(). Removes internal softmax wrapping logic and aligns statistics handling with method-provided context lifecycle.
Plugin & Logging Updates
modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
Adds structured logging via module-level logger and removes validate_eager_attention() helper function. Simplifies CUSTOM_MODEL_PLUGINS to contain only register_sparse_attention_on_the_fly.
Example & Documentation
examples/llm_sparsity/attention_sparsity/hf_sa.py, examples/llm_sparsity/attention_sparsity/README.md
Updates CLI to support "triton" backend selection, applies per-layer backend overrides via deep-copied sparse configs. Updates README to document both pytorch (eager) and triton backends and their corresponding attn_implementation settings.
Configuration
pyproject.toml
Adds Ruff ignore entry for Triton kernel naming conventions (N803, N806) in modelopt/torch/sparsity/attention_sparsity/kernels/*.
Test Suite
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py
Comprehensive GPU test module covering Triton attention correctness, GQA support, causal masking, and backward pass validation against SDPA reference. Includes integration tests for HF transformers (triton vs eager), gradient computation across variable-length batches, and long-sequence processing.
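
The "online softmax" that the forward kernel uses can be illustrated with a small PyTorch reference: process scores in tiles while maintaining a running max and normalizer, rescaling earlier tiles whenever the max grows. This is a simplified sketch of the technique, not the Triton code:

```python
import torch

def online_softmax(scores: torch.Tensor, block: int = 2) -> torch.Tensor:
    """Streaming, numerically stable softmax over 1D `scores` in tiles."""
    m = torch.tensor(float("-inf"))  # running max
    l = torch.tensor(0.0)            # running normalizer
    out = torch.zeros_like(scores)
    processed = 0
    for start in range(0, scores.numel(), block):
        s = scores[start : start + block]
        m_new = torch.maximum(m, s.max())
        # Rescale previously accumulated exponentials to the new max
        out[:processed] *= torch.exp(m - m_new)
        l = l * torch.exp(m - m_new) + torch.exp(s - m_new).sum()
        out[start : start + s.numel()] = torch.exp(s - m_new)
        m = m_new
        processed = start + s.numel()
    return out / l

x = torch.tensor([1.0, 3.0, 2.0, 0.5, 4.0])
streamed = online_softmax(x)
```

The same rescaling idea, applied to attention-weighted V accumulators per tile of K/V, is what lets flash attention avoid materializing the full score matrix.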

Sequence Diagram(s)

sequenceDiagram
    participant User as User / CLI
    participant Config as SparseAttentionConfig
    participant Conversion as convert_to_sparse_attention_model()
    participant Setup as _set_attn_implementation()
    participant Kernels as modelopt.torch.kernels
    participant HFAttn as HuggingFace Transformers
    
    User->>Config: Specify backend ("pytorch" or "triton")
    User->>Conversion: Call convert_to_sparse_attention_model()
    Conversion->>Setup: _set_attn_implementation(model, config)
    
    alt Backend is "triton"
        Setup->>Kernels: Check IS_AVAILABLE
        Kernels-->>Setup: True (CUDA + Triton present)
        Setup->>Kernels: Call register_triton_attention()
        Kernels->>HFAttn: Register "modelopt_triton" backend
        HFAttn-->>Kernels: Registration success
        Setup->>HFAttn: Set model.config._attn_implementation = "modelopt_triton"
    else Backend is "pytorch"
        Setup->>HFAttn: Set model.config._attn_implementation = "eager"
    end
    
    Conversion->>Conversion: Apply sparse attention plugins
    Conversion->>Conversion: Replace attention modules
sequenceDiagram
    participant Input as Model Forward Input
    participant SparseModule as SparseAttentionModule
    participant Method as SparseAttentionMethod
    participant Context as Sparse Context
    participant Softmax as F.softmax (patched)
    participant Output as Model Output
    
    Input->>SparseModule: query, key, value, attention_mask
    SparseModule->>Method: get_sparse_context(module)
    Method-->>SparseModule: Context manager with patched softmax
    
    SparseModule->>Context: __enter__()
    activate Context
    Context->>Softmax: Patch torch.nn.functional.softmax
    SparseModule->>SparseModule: Execute original attention forward
    SparseModule->>Softmax: Call softmax(logits)
    Softmax->>Softmax: Compute sparsity, apply sparse mask
    Softmax-->>SparseModule: Sparse-masked softmax output
    SparseModule->>SparseModule: Accumulate attention output
    SparseModule->>Context: __exit__()
    deactivate Context
    Context->>Softmax: Restore original softmax
    
    SparseModule-->>Output: Sparse attention result
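
The patched-softmax lifecycle in the diagram above can be sketched as a plain context manager. This is hypothetical illustration code, not the PR's `get_sparse_context()` implementation (the threshold-based masking is an assumed stand-in for the method's real sparsity logic):

```python
import contextlib
import torch
import torch.nn.functional as F

@contextlib.contextmanager
def sparse_softmax_context(threshold: float = 1e-3):
    """Temporarily replace F.softmax so probabilities below `threshold`
    are zeroed; the original function is restored on exit."""
    orig = F.softmax

    def sparse_softmax(x, *args, **kwargs):
        probs = orig(x, *args, **kwargs)
        return probs * (probs >= threshold)  # apply sparse mask

    F.softmax = sparse_softmax
    try:
        yield
    finally:
        F.softmax = orig  # restore original softmax

logits = torch.tensor([0.0, 10.0])
with sparse_softmax_context(threshold=1e-3):
    inside = F.softmax(logits, dim=-1)   # tiny probability gets zeroed
outside = F.softmax(logits, dim=-1)      # unpatched behavior restored
```

Any attention forward that calls `F.softmax` while the context is active picks up the sparse variant, which is the delegation pattern the refactored `SparseAttentionModule` relies on.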

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~55 minutes

🚥 Pre-merge checks | ✅ 4
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The pull request title accurately describes the primary changes: adding a Triton attention kernel and integrating it with HuggingFace models, which aligns with the substantial implementation across multiple files.
Docstring Coverage ✅ Passed Docstring coverage is 87.18% which is sufficient. The required threshold is 80.00%.
Security Anti-Patterns ✅ Passed No security anti-patterns detected: no torch.load with weights_only=False, numpy.load with allow_pickle=True, trust_remote_code=True, eval/exec calls, nosec comments, or unsafe dependencies found.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.




codecov bot commented Mar 13, 2026

Codecov Report

❌ Patch coverage is 41.37931% with 17 lines in your changes missing coverage. Please review.
✅ Project coverage is 70.09%. Comparing base (1070d89) to head (a16409e).

Files with missing lines Patch % Lines
...y/attention_sparsity/methods/flash_skip_softmax.py 27.27% 8 Missing ⚠️
...pt/torch/sparsity/attention_sparsity/conversion.py 53.33% 7 Missing ⚠️
...delopt/torch/sparsity/attention_sparsity/config.py 0.00% 1 Missing ⚠️
...ch/sparsity/attention_sparsity/sparse_attention.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1034      +/-   ##
==========================================
- Coverage   70.10%   70.09%   -0.01%     
==========================================
  Files         221      221              
  Lines       25541    25554      +13     
==========================================
+ Hits        17905    17912       +7     
- Misses       7636     7642       +6     

@kaix-nv kaix-nv force-pushed the kaix/triton_kernel branch 8 times, most recently from e53e990 to ea9c45c Compare March 15, 2026 05:50
Add a Triton Flash Attention kernel that supports variable-length
batching, GQA, causal/non-causal masking, and autograd-compatible
forward/backward. Register it as attn_implementation="modelopt_triton"
for HuggingFace models.

Signed-off-by: Kai Xu <kaix@nvidia.com>
@kaix-nv kaix-nv force-pushed the kaix/triton_kernel branch from ea9c45c to a16409e Compare March 15, 2026 05:54
@kaix-nv kaix-nv changed the title Add Triton unified attention kernel with HuggingFace integration [OMNIML-3519] Add Triton unified attention kernel with HuggingFace integration Mar 15, 2026
@kaix-nv kaix-nv changed the title [OMNIML-3519] Add Triton unified attention kernel with HuggingFace integration [1/n] Add a Triton attention kernel with HF integration Mar 15, 2026
@kaix-nv kaix-nv marked this pull request as ready for review March 15, 2026 05:57
@kaix-nv kaix-nv requested review from a team as code owners March 15, 2026 05:57

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/llm_sparsity/attention_sparsity/README.md (1)

108-108: ⚠️ Potential issue | 🟡 Minor

Documentation inconsistency: table says "only supported backend" but Triton is now supported.

Line 108 states "Backend: `pytorch` (only supported backend)", which contradicts the introduction (lines 3-6) that describes both `pytorch` and `triton` backends. Update the description to reflect both options.

📝 Suggested fix
-| `--backend` | `pytorch` | Backend: `pytorch` (only supported backend) |
+| `--backend` | `pytorch` | Backend: `pytorch` or `triton` |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_sparsity/attention_sparsity/README.md` at line 108, The table
row for the `--backend` flag currently reads "`--backend` | `pytorch` | Backend:
`pytorch` (only supported backend)`" but is outdated; update that cell to
reflect both supported backends (e.g., list `pytorch` and `triton` or say
"Backends: `pytorch`, `triton`") and adjust the description to remove "only
supported backend" and optionally indicate the default if one exists; locate the
row with the `--backend` entry in the README and modify its value/description
accordingly.
🧹 Nitpick comments (3)
examples/llm_sparsity/attention_sparsity/hf_sa.py (1)

147-154: Minor inconsistency between comment and code.

The comment on lines 147-148 states that mtsa.sparsify() automatically sets attn_implementation, but line 151 still hardcodes attn_implementation="eager". This works because the model is loaded before sparsification, but the comment could be clarified to indicate that the hardcoded value is for initial model loading and may be overridden by sparsify().

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/llm_sparsity/attention_sparsity/hf_sa.py` around lines 147 - 154,
The comment says mtsa.sparsify() sets attn_implementation automatically, but the
code still passes attn_implementation="eager" to
AutoModelForCausalLM.from_pretrained; update the comment (or remove the
hardcoded arg) to clarify that the explicit attn_implementation passed to
AutoModelForCausalLM.from_pretrained is only for the initial model load and may
be overridden later by mtsa.sparsify(), referencing
AutoModelForCausalLM.from_pretrained and mtsa.sparsify() so reviewers can find
the lines to change.
modelopt/torch/sparsity/attention_sparsity/methods/registry.py (1)

73-83: Consider using @abstractmethod for consistency.

The get_sparse_context method raises NotImplementedError but isn't decorated with @abstractmethod, unlike calculate_sparsity, apply_sparsity, and the name property in this class. If all subclasses must implement this method, using @abstractmethod would enforce this at class instantiation time rather than runtime.

♻️ Suggested change
+    `@abstractmethod`
     def get_sparse_context(self, module: torch.nn.Module):
         """Return a context manager that activates this method's sparsity during forward.

         Each method subclass implements its own activation mechanism:
         - Softmax-patching methods replace F.softmax during the forward pass.
         - Kernel-fused methods set flags on ``module`` that the kernel reads.

         Args:
             module: The SparseAttentionModule wrapping the attention layer.
         """
-        raise NotImplementedError(f"{type(self).__name__} must implement get_sparse_context()")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/sparsity/attention_sparsity/methods/registry.py` around lines
73 - 83, The get_sparse_context method currently raises NotImplementedError but
isn't decorated with `@abstractmethod` like calculate_sparsity, apply_sparsity and
the name property; add the `@abstractmethod` decorator to get_sparse_context in
the same abstract base class so subclasses are enforced to implement it at
instantiation time, keeping the existing signature def get_sparse_context(self,
module: torch.nn.Module): and leaving the NotImplementedError or replacing it
with a simple pass (no behavior change) — ensure the class still inherits from
abc.ABC or the existing base that provides `@abstractmethod` support.
tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py (1)

165-241: Add an end-to-end sparsify() integration test.

These tests prove direct attn_implementation="modelopt_triton" loading works, but they never execute modelopt/torch/sparsity/attention_sparsity/conversion.py::_set_attn_implementation(). A small mtsa.sparsify() test here would catch registration/order regressions in the new code path.

As per coding guidelines, "Write tests using pytest for all new features and examples; organize tests into tests/unit (fast CPU-based), tests/gpu (fast GPU-based), tests/gpu_megatron (Megatron-Core), tests/gpu_trtllm (TensorRT-LLM), and tests/examples (integration tests). All test coverage checks in PRs must pass for new features and examples."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py` around lines
165 - 241, Add an end-to-end test that calls mtsa.sparsify() to exercise the
sparsity registration and ensure conversion._set_attn_implementation() runs;
specifically, create a new GPU test (e.g., alongside
test_triton_matches_eager/test_triton_padded_batch) that loads the Tiny LLaMA
model, invokes mtsa.sparsify(model) (or the public API that triggers
conversion._set_attn_implementation()), then runs a forward/generate to assert
logits/tokens are valid/close (reuse attn_implementation="modelopt_triton" setup
and tokenizer logic); this will catch registration/order regressions without
changing existing eager/triton tests.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: a7669ec7-b23c-4129-8118-df83d1d5a1fc

📥 Commits

Reviewing files that changed from the base of the PR and between 1070d89 and a16409e.

📒 Files selected for processing (14)
  • examples/llm_sparsity/attention_sparsity/README.md
  • examples/llm_sparsity/attention_sparsity/hf_sa.py
  • modelopt/torch/kernels/__init__.py
  • modelopt/torch/kernels/hf_triton_attention.py
  • modelopt/torch/kernels/triton_fa.py
  • modelopt/torch/sparsity/attention_sparsity/config.py
  • modelopt/torch/sparsity/attention_sparsity/conversion.py
  • modelopt/torch/sparsity/attention_sparsity/kernels/__init__.py
  • modelopt/torch/sparsity/attention_sparsity/methods/flash_skip_softmax.py
  • modelopt/torch/sparsity/attention_sparsity/methods/registry.py
  • modelopt/torch/sparsity/attention_sparsity/plugins/huggingface.py
  • modelopt/torch/sparsity/attention_sparsity/sparse_attention.py
  • pyproject.toml
  • tests/gpu/torch/sparsity/attention_sparsity/test_triton_fa.py

Comment on lines +26 to +41
if torch.cuda.is_available():
with import_plugin(
"triton",
msg_if_missing=(
"Your device is potentially capable of using the triton attention "
"kernel. Try to install triton with `pip install triton`."
),
):
from .triton_fa import attention as _attention

attention = _attention
IS_AVAILABLE = True
with import_plugin("transformers"):
from .hf_triton_attention import register_triton_attention as _register_triton_attention

register_triton_attention = _register_triton_attention

⚠️ Potential issue | 🟠 Major

Load the Hugging Face registration hook independently of live CUDA detection.

modelopt/torch/sparsity/attention_sparsity/conversion.py::_set_attn_implementation() imports register_triton_attention from this module before any kernel runs. With the current torch.cuda.is_available() guard, CPU-side conversion always sees None and fails even if triton and transformers are installed. Gate the kernel entrypoint on CUDA, not the registry hook.

Comment on lines +31 to +46
def _seq_lens_from_mask(
attention_mask: torch.Tensor | None,
fallback: int,
device: torch.device,
) -> tuple[torch.Tensor | None, bool]:
"""Derive per-sequence lengths from attention mask.
Returns (b_seq_len, has_padding). If the mask is not a usable 2D format,
returns (None, False).
"""
if attention_mask is not None and attention_mask.dim() == 2:
mask = attention_mask.bool() if attention_mask.dtype != torch.bool else attention_mask
b_seq_len = mask.sum(dim=1).to(torch.int32).to(device)
has_padding = bool((b_seq_len != fallback).any())
return b_seq_len, has_padding
return None, False

⚠️ Potential issue | 🟠 Major

Reject left-padded 2D masks or repack them.

Collapsing the mask to lengths only works for right padding. A row like [0, 0, 1, 1] will make the kernel read positions 0..1 as the valid prefix, and the post-mask then zeros the real tokens at 2..3. Please either pack from the mask positions or fail fast on non-right-padded masks.

🐛 Minimal guard
     if attention_mask is not None and attention_mask.dim() == 2:
         mask = attention_mask.bool() if attention_mask.dtype != torch.bool else attention_mask
+        if bool((~mask[:, :-1] & mask[:, 1:]).any()):
+            raise NotImplementedError(
+                "modelopt_triton currently supports only right-padded 2D attention masks"
+            )
         b_seq_len = mask.sum(dim=1).to(torch.int32).to(device)

Also applies to: 112-117
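
The suggested guard can be exercised standalone. A sketch of the check (`assert_right_padded` is a hypothetical name; the PR's helper is `_seq_lens_from_mask`):

```python
import torch

def assert_right_padded(attention_mask: torch.Tensor) -> None:
    """Reject 2D masks where a 1 follows a 0 in any row, i.e. left or
    interior padding that length-collapsing cannot represent."""
    mask = attention_mask.bool()
    # A 0 -> 1 transition between adjacent positions means non-right padding
    if bool((~mask[:, :-1] & mask[:, 1:]).any()):
        raise NotImplementedError(
            "only right-padded 2D attention masks are supported"
        )

right = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])
left = torch.tensor([[0, 0, 1, 1]])
assert_right_padded(right)  # passes: no 0 -> 1 transitions
```

Left-padded rows like `[0, 0, 1, 1]` trip the guard, which is exactly the failure mode described above: the kernel would otherwise read positions 0..1 as the valid prefix.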

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/hf_triton_attention.py` around lines 31 - 46, The
current _seq_lens_from_mask collapses any 2D mask to lengths which only works
for right-padded masks; detect rows that are not right-padded (i.e., rows where
a zero appears before a later one) and either repack those sequences into a
contiguous prefix ordering or raise a clear error; implement the check by
validating for each row of attention_mask that once a 0 appears no subsequent 1
exists, and if the check fails either 1) build packed indices from the mask and
return packed b_seq_len/indicator for the Triton kernel or 2) raise ValueError
with a message referencing _seq_lens_from_mask so callers know they must provide
right-padded masks (also apply the same guard where similar logic occurs around
the other block referenced at lines ~112-117).

Comment on lines +49 to +57
def triton_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: torch.Tensor | None,
scaling: float,
dropout: float = 0.0,
**kwargs,

⚠️ Potential issue | 🟠 Major



Read the backend contract flags from kwargs instead of hardcoding causality via seq_len heuristic.

The current implementation determines causality by checking seq_len <= 1 (line 99: "is_causal": not is_decode) rather than respecting flags passed by Hugging Face. This violates the attention backend contract and will cause non-causal attention models (e.g., encoder-only transformers) to be incorrectly masked as causal during prefill. Additionally, the dropout parameter (line 56) is silently ignored despite being part of the interface.

The docstring correctly documents these limitations (lines 69–74), but the implementation should either:

  1. Extract and honor is_causal and dropout_p from kwargs to match HF's attention backend contract, or
  2. Validate unsupported configurations explicitly rather than silently applying incorrect behavior.

This affects lines 64–74 (docstring), 99 (causality logic), and the overall function contract.
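
Option 1 could look roughly like the following sketch (hypothetical; `resolve_backend_flags` is not an existing function, and the error-on-dropout choice is one possible policy for an unsupported feature):

```python
def resolve_backend_flags(dropout: float, **kwargs) -> tuple[bool, float]:
    """Read HF attention-backend contract flags from kwargs instead of
    inferring causality from seq_len."""
    if "is_causal" not in kwargs:
        raise ValueError("caller must pass is_causal explicitly")
    is_causal = bool(kwargs["is_causal"])
    # Prefer the contract's dropout_p; fall back to the positional arg
    dropout_p = float(kwargs.get("dropout_p", dropout))
    if dropout_p != 0.0:
        # Fail loudly rather than silently ignoring dropout
        raise NotImplementedError("dropout is not supported by the Triton kernel")
    return is_causal, dropout_p
```

The returned flags would then feed the kernel payload directly, so encoder-only (non-causal) models are no longer mis-masked during prefill.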

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/hf_triton_attention.py` around lines 49 - 57, The
triton_attention_forward implementation currently infers causality via a seq_len
heuristic and ignores the dropout parameter; update it to read HF backend flags
from kwargs instead: extract and honor kwargs.get("is_causal") (or
kwargs["is_causal"] with a clear validation) and kwargs.get("dropout_p") (or
fallback to the function's dropout parameter only if absent), remove the
seq_len<=1 causality heuristic, and ensure the attention backend payload passed
to the Triton kernel uses the explicit is_causal and dropout values;
alternatively, if those flags are unsupported, raise a clear error if they are
provided rather than silently applying wrong behavior.

Comment on lines +488 to +520
def forward(
ctx,
q,
k,
v,
b_start_loc,
b_seq_len,
max_input_len,
is_causal,
sm_scale,
b_start_loc_k,
b_seq_len_k,
max_input_len_k,
):
HEAD_DIM = q.shape[2]
num_q_heads = q.shape[1]
num_kv_heads = k.shape[1]
kv_group_num = num_q_heads // num_kv_heads
batch = b_seq_len.shape[0]

# Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V.
# Decode: K/V is a separate KV cache tensor, caller must pass explicit metadata.
if b_seq_len_k is None:
b_seq_len_k = b_seq_len
b_start_loc_k = b_start_loc
max_input_len_k = max_input_len

# Pre-multiply scale by log2(e) so the kernel can use exp2()
# exp(score * sm_scale) = exp2(score * sm_scale * log2(e))
qk_scale = sm_scale * LOG2E
# Triton tiles must be powers of 2; pad head dim
BLOCK_D = triton.next_power_of_2(HEAD_DIM)


⚠️ Potential issue | 🟠 Major

Validate GQA ratios and KV metadata before launching the kernel.

num_q_heads % num_kv_heads != 0 makes kv_head_idx walk past the available K/V heads, and providing b_seq_len_k without max_input_len_k leaves ctx.max_input_len_k=None, which blows up in backward at triton.cdiv(...). Please fail fast here instead of letting the Triton kernels hit undefined behavior.

🐛 Proposed validation
         HEAD_DIM = q.shape[2]
         num_q_heads = q.shape[1]
         num_kv_heads = k.shape[1]
+        if num_kv_heads == 0 or num_q_heads % num_kv_heads != 0:
+            raise ValueError(
+                f"num_q_heads ({num_q_heads}) must be a positive multiple of "
+                f"num_kv_heads ({num_kv_heads})"
+            )
         kv_group_num = num_q_heads // num_kv_heads
         batch = b_seq_len.shape[0]
 
         # Prefill: Q/K/V are the same packed tensor, reuse Q offsets for K/V.
         # Decode: K/V is a separate KV cache tensor, caller must pass explicit metadata.
         if b_seq_len_k is None:
             b_seq_len_k = b_seq_len
             b_start_loc_k = b_start_loc
             max_input_len_k = max_input_len
+        elif max_input_len_k is None:
+            raise ValueError("max_input_len_k is required when b_seq_len_k is provided")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/kernels/triton_fa.py` around lines 488 - 520, In the forward
function add fast-fail validations: check that num_q_heads % num_kv_heads == 0
(derived from num_q_heads and num_kv_heads / kv_group_num) and raise a clear
exception if not, and ensure when b_seq_len_k is provided that max_input_len_k
is not None (and likewise if b_seq_len_k is None we set
b_seq_len_k/b_start_loc_k/max_input_len_k as you already do); also validate
ctx.max_input_len_k (or the local max_input_len_k) is set before any triton.cdiv
calls so the backward pass won't see None — perform these checks at the top of
forward and raise ValueError with a descriptive message if they fail.

Comment on lines +47 to +55
sparse_cfg = config.sparse_cfg if hasattr(config, "sparse_cfg") else {}

# Collect backends only from layer configs (identified by having a "method" key).
# Other dict entries (e.g. "calibration") are not layer configs.
backends = {
v.get("backend", "pytorch")
for v in sparse_cfg.values()
if isinstance(v, dict) and "method" in v
}

⚠️ Potential issue | 🟡 Minor

Don’t force eager when no sparse backend was selected.

backends can be empty here because the comprehension intentionally skips non-layer entries like "calibration". In that case this branch still downgrades the whole model to eager even though no sparse-attention backend was configured.

🐛 Minimal fix
-    elif model_config is not None:
+    elif backends and model_config is not None:

Also applies to: 84-87


@kevalmorabia97 kevalmorabia97 left a comment


approving as pyproject.toml codeowner only
