
[PyT] Plumbing correct bias dims from TE to cudnn, while adding support for additional bias shapes#2537

Open
KshitijLakhani wants to merge 20 commits into NVIDIA:main from KshitijLakhani:klakhani/fix/bias-shape

Conversation

@KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented Dec 20, 2025

Description

TE common was not plumbing the attention bias dimensions correctly to cuDNN.
Instead of using the shape of the Bias tensor, i.e. [bias_sq, bias_skv], it was using [sq, skv], thereby passing larger-than-required dims. This PR correctly plumbs the bias shape from the TE PyTorch layer to cuDNN via TE common.

Additionally, this PR adds support for dbias, i.e. bias grad (fwd + bwd) calculation, for b1ss, bhss and 11ss (initially only 1hss was supported), for both CP and non-CP cases.
Support for bias without bias grad (fwd only) for 111s is also added for CP and non-CP cases
(bwd support to be added once cuDNN starts supporting it; TODOs are sprinkled in the code for this).

Lastly, tests are added covering all newly added functionality for both CP and non-CP cases.
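
A minimal sketch of the idea at the PyTorch level (illustrative only; the actual fix lives in TE common, and the helper below is hypothetical):

import torch

def bias_dims_for_cudnn(bias, s_q, s_kv):
    # Hypothetical helper: derive the dims handed to the backend from the bias
    # tensor itself rather than from the q/k sequence lengths (s_q, s_kv).
    if bias is None:
        return None
    bias_b, bias_h, bias_sq, bias_skv = bias.shape
    # Before this PR the backend effectively saw [bias_b, bias_h, s_q, s_kv];
    # after it, it sees the true [bias_b, bias_h, bias_sq, bias_skv].
    assert bias_sq in (1, s_q) and bias_skv in (1, s_kv), "bias must broadcast over the attention matrix"
    return bias_b, bias_h, bias_sq, bias_skv

# e.g. a broadcast "111s" bias against a 128x128 attention matrix:
print(bias_dims_for_cudnn(torch.randn(1, 1, 1, 128), 128, 128))  # (1, 1, 1, 128), not (1, 1, 128, 128)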

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Pass bias_sq and bias_skv to fused_attn_arbitrary_seqlen_fwd_impl() and fused_attn_arbitrary_seqlen_bwd_impl()
  • Add new entries for bias_sq and bias_skv in FADescriptor_v1
  • Correct the bias passed to the MHA cuDNN graph to use bias_sq and bias_skv instead of s_q and s_kv
  • Enable dbias calculation for all cuDNN-supported shapes: 1hss, 11ss, b1ss, bhss
  • Add TODOs for when cuDNN starts supporting dbias calculation for bias shape 111s

Testing:

  • Added tests (fwd only, no bias grad) for the 111s bias shape in both non-CP and CP fused attn tests
  • Added tests for 1hss, b1ss, bhss, 111s bias shapes in CP fused attn tests (non-CP already has tests for all other supported shapes)
  • Confirmed via NVTE_DEBUG and additional test logging that the test bias shape is passed unchanged from the PyT layer to cuDNN (this was necessary, as hard-coded shapes could produce a false positive and mask the actual behavior)

Supplementary testing:

Using the reproducer https://github.com/cyanguwa/TransformerEngine/tree/test_111s for bias [1,1,1,s], the cuDNN FE logs show that prior to this PR the bias dims passed from TE to cuDNN were:
{"data_type":null,"dim":[1,1,128,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[16384,16384,128,1],"uid":0,"uid_assigned":false},
and after this PR they are:
"bias":{"data_type":null,"dim":[1,1,1,128],"is_pass_by_value":false,"is_virtual":false,"name":"bias","pass_by_value":null,"reordering_type":"NONE","stride":[128,128,128,1],"uid":0,"uid_assigned":false},

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani
Collaborator Author

/te-ci pytorch L0 L1

@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/bias-shape branch from 200fd98 to 8da3252 on December 22, 2025 18:21
@KshitijLakhani KshitijLakhani marked this pull request as ready for review on December 22, 2025 18:24
@greptile-apps
Contributor

greptile-apps bot commented Dec 22, 2025

Greptile Overview

Greptile Summary

This PR fixes a critical bug where TE was passing incorrect bias dimensions [s_q, s_kv] to cuDNN instead of the actual bias tensor dimensions [bias_sq, bias_skv]. This resulted in cuDNN receiving larger-than-necessary dimensions when bias shapes like [1,1,1,s] were used.

Key changes:

  • Added bias_sq and bias_skv fields to FADescriptor_v1 to track actual bias dimensions
  • Correctly extracts bias dimensions from input_Bias in forward pass and output_dBias in backward pass
  • Fixed bias tensor creation in cuDNN graphs to use actual bias dimensions instead of sequence lengths
  • Extended dbias calculation support from just 1hss to all cuDNN-supported shapes: 1hss, 11ss, b1ss, bhss
  • Added forward-only support for 111s bias shape (dbias not supported by cuDNN 9.18 yet - TODOs added for future)
  • Fixed the elif issue in utils.py where the environment variable could be set incorrectly after disabling fused attention
  • Added comprehensive test coverage for all bias shapes in both CP and non-CP modes, including forward-only tests for 111s
  • Special handling in CP code for 111s where only s_kv dimension is split (not s_q)

All previously reported issues from review threads have been addressed.

Confidence Score: 5/5

  • This PR is safe to merge - it fixes a critical bug with correct dimension plumbing and adds proper test coverage
  • The implementation correctly addresses the bias dimension bug with clean architecture changes. All previous review concerns have been resolved (dimension extraction consistency, elif logic, None checks for bias gradients). The changes are well-tested with comprehensive coverage for all supported bias shapes in both CP and non-CP modes. The code properly handles edge cases like 111s shape where dbias is not supported, and includes appropriate TODOs for future enhancements.
  • No files require special attention

Important Files Changed

Filename: Overview
transformer_engine/common/fused_attn/utils.h: Added bias_sq and bias_skv fields to FADescriptor_v1 struct and comparison operator to track actual bias dimensions instead of using sequence lengths
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu: Correctly plumbs bias_sq and bias_skv from input tensors to cuDNN, fixes bias dimension extraction in backward pass, and adds support for dbias calculation for 1hss/11ss/b1ss/bhss shapes
transformer_engine/pytorch/attention/dot_product_attention/utils.py: Fixed elif issue for 111s bias shape handling - correctly disables fused attention when dbias calculation needed for 111s, uses max512 backend only when bias doesn't require grad
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: Added 111s bias shape support in CP - handles special case where only s_kv dimension is split (not s_q), correctly manages attn_dbias_ being None for this shape
tests/pytorch/attention/run_attention_with_cp.py: Added comprehensive test coverage for all bias shapes (1hss/11ss/b1ss/bhss/111s) with CP, handles forward-only mode, fixed None-checking for bias gradients, properly handles 111s bias partitioning

Flowchart

flowchart TD
    A[PyTorch Layer: Input Bias Tensor] -->|Extract shape dimensions| B[bias_b, bias_h, bias_sq, bias_skv]
    B -->|Pass to| C[fused_attn_arbitrary_seqlen_fwd]
    C -->|Forward to| D[fused_attn_arbitrary_seqlen_fwd_impl]
    D -->|Store in| E[FADescriptor_v1 struct]
    E -->|Use for cuDNN graph| F[MHA Graph Bias Tensor Creation]
    F -->|Correct dimensions| G[bias tensor: bias_b, bias_h, bias_sq, bias_skv]
    
    H[Output dBias Tensor] -->|Backward: Extract all dims from output_dBias| I[bias_b, bias_h, bias_sq, bias_skv]
    I -->|Pass to| J[fused_attn_arbitrary_seqlen_bwd_impl]
    J -->|Check shape| K{bias_b==1 && bias_h==1 && bias_sq==1?}
    K -->|Yes: 111s shape| L[Skip dbias calculation - not supported in cuDNN 9.18]
    K -->|No: 1hss/11ss/b1ss/bhss| M[Enable dbias calculation via set_dbias]
    
    style G fill:#90EE90
    style L fill:#FFB6C1
    style M fill:#90EE90

Last reviewed commit: 2133bd8

@greptile-apps
Contributor

greptile-apps bot commented Dec 22, 2025

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@KshitijLakhani KshitijLakhani changed the title Plumbing correct bias dims from TE to cudnn [PyT] Plumbing correct bias dims from TE to cudnn Dec 22, 2025
@KshitijLakhani KshitijLakhani added bug Something isn't working pytorch labels Dec 22, 2025
@cyanguwa
Collaborator

Looks good - please pick the 111s test from my branch as well. Thanks!

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Fixes bias dimension handling in fused attention by plumbing actual bias tensor dimensions (bias_sq, bias_skv) from input tensors through to cuDNN, replacing the previous incorrect usage of query/key sequence lengths (s_q, s_kv). This resolves dimension mismatches for broadcasted bias shapes like [1,1,1,s] where the bias dimensions are smaller than the attention matrix dimensions. The fix enables gradient computation for non-1hss bias shapes by removing the backward pass restriction in the Python layer.

Confidence Score: 4/5

  • Safe to merge after addressing minor consistency concern in backward pass dimension extraction
  • The core fix correctly addresses the bias dimension bug by extracting actual tensor shapes instead of using sequence lengths. The implementation is consistent across forward pass, backward pass, and FP8 paths. Test coverage has been expanded to validate the fix. One minor style issue: backward pass extracts bias_b/bias_h from output_dBias but bias_sq/bias_skv from input_Bias, creating potential inconsistency if shapes don't match, though this is unlikely in practice.
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu for dimension extraction consistency in backward pass

Important Files Changed

File Analysis

Filename (score): Overview
transformer_engine/common/fused_attn/utils.h (5/5): Adds bias_sq and bias_skv fields to FADescriptor_v1 struct and updates comparison operator
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (4/5): Updates fwd/bwd implementations to extract and use actual bias dimensions from input tensors instead of query/key sequence lengths
transformer_engine/pytorch/attention/dot_product_attention/utils.py (4/5): Removes restriction preventing bias gradient computation for non-1hss bias shapes, enabling backward pass support

Sequence Diagram

sequenceDiagram
    participant Py as Python Layer
    participant TE as TE Common (CUDA)
    participant cuDNN as cuDNN Backend
    
    Note over Py,cuDNN: Bias Dimension Propagation Fix
    
    Py->>TE: Pass bias tensor [b, h, bias_sq, bias_skv]
    Note over TE: Extract actual bias dims<br/>bias_sq = input_Bias->shape[2]<br/>bias_skv = input_Bias->shape[3]
    
    TE->>TE: Store in FADescriptor_v1<br/>(bias_sq, bias_skv)
    
    alt Before Fix
        Note over TE: Used s_q, s_kv incorrectly<br/>(e.g., [1,1,128,128] for [1,1,1,128])
    end
    
    alt After Fix
        Note over TE: Uses bias_sq, bias_skv correctly<br/>(e.g., [1,1,1,128] for [1,1,1,128])
    end
    
    TE->>cuDNN: Create bias tensor with<br/>dim={bias_b, bias_h, bias_sq, bias_skv}
    TE->>cuDNN: Create dBias tensor with same dims
    
    cuDNN->>TE: Compute attention + gradients
    TE->>Py: Return output with correct bias gradients

Comment on lines 1245 to 1248
bias_b = output_dBias->data.shape[0];
bias_h = output_dBias->data.shape[1];
bias_sq = input_Bias->data.shape[2];
bias_skv = input_Bias->data.shape[3];
Contributor

Bias dimensions are sourced from different tensors: bias_b and bias_h from output_dBias, while bias_sq and bias_skv from input_Bias. This assumes both tensors have matching shapes. Consider extracting all dimensions from the same tensor (preferably input_Bias for consistency with forward pass) or adding a validation check that shapes match.


Collaborator Author

Addressed in SHA 143ede5

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

Fixes bias dimension plumbing from TransformerEngine to cuDNN by passing actual bias tensor dimensions (bias_sq, bias_skv) instead of sequence dimensions (s_q, s_kv). This resolves incorrect bias shape information being sent to cuDNN, particularly noticeable for bias shapes like [1,1,1,s] where the bias sequence dimensions differ from query/key/value sequence lengths. The fix enables cuDNN backend support for bias gradient computation in previously unsupported shapes.

Confidence Score: 5/5

  • Safe to merge - correct bug fix with comprehensive test coverage and no breaking changes
  • This PR correctly fixes the bias dimension plumbing issue where TE was incorrectly passing sequence dimensions instead of actual bias dimensions to cuDNN. The fix is well-implemented across all affected code paths (F16 and FP8), properly extracts bias dimensions from input tensors, and includes comprehensive test coverage. No functional issues or edge cases were identified.
  • No files require special attention

Important Files Changed

File Analysis

Filename (score): Overview
transformer_engine/common/fused_attn/utils.h (5/5): Added bias_sq and bias_skv fields to FADescriptor_v1 struct and updated comparison operator
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (5/5): Updated forward and backward implementations to extract and pass correct bias dimensions from input tensors to cuDNN
transformer_engine/pytorch/attention/dot_product_attention/utils.py (5/5): Removed restriction that disabled FusedAttention for bias gradients in non-1hss shapes, enabling cuDNN backend for these cases

Sequence Diagram

sequenceDiagram
    participant PyTorch as PyTorch Layer
    participant Utils as utils.py
    participant F16Impl as fused_attn_f16<br/>arbitrary_seqlen.cu
    participant Descriptor as FADescriptor_v1
    participant cuDNN as cuDNN FE Graph

    Note over PyTorch,cuDNN: Forward Pass with Bias [1,1,1,s]
    
    PyTorch->>Utils: get_attention_backend()<br/>check bias support
    Utils->>Utils: Enable cuDNN for<br/>bias gradient
    PyTorch->>F16Impl: fused_attn_arbitrary_seqlen_fwd()<br/>with input_Bias tensor
    F16Impl->>F16Impl: Extract bias dimensions:<br/>bias_sq = input_Bias.shape[2]<br/>bias_skv = input_Bias.shape[3]
    F16Impl->>Descriptor: Create FADescriptor_v1<br/>with bias_sq, bias_skv
    F16Impl->>cuDNN: Create bias tensor with<br/>dim=[bias_b, bias_h, bias_sq, bias_skv]
    Note over cuDNN: Correct dims [1,1,1,s]<br/>instead of [1,1,s,s]
    cuDNN-->>F16Impl: Execute attention
    F16Impl-->>PyTorch: Return output

@KshitijLakhani
Collaborator Author

/te-ci pytorch L0 L1

Contributor

@greptile-apps greptile-apps bot left a comment

Greptile Overview

Greptile Summary

This PR fixes a bug where TransformerEngine was incorrectly passing attention bias dimensions to cuDNN. Instead of using the actual bias tensor dimensions [bias_sq, bias_skv], it was using the full sequence dimensions [s_q, s_kv], which could be larger than the bias tensor.

Major Changes

  • Core Fix: Extract and pass actual bias dimensions (bias_sq, bias_skv) from the bias tensor shape throughout the call chain to cuDNN
  • Struct Update: Added bias_sq and bias_skv fields to FADescriptor_v1 for proper caching
  • Test Enhancement: Added bias gradient tracking and comparison in context parallelism tests
  • Backend Selection: Removed incorrect logic that disabled FusedAttention for non-1hss bias shapes when gradients weren't required

Issues Found

  • Critical Bug in Tests: run_attention_with_cp.py attempts to access bias.grad when bias is None (lines 342, 438), causing AttributeError for "no_bias" and "alibi" test cases

Confidence Score: 3/5

  • This PR fixes an important bug in bias dimension handling but introduces critical test failures
  • The core fix correctly addresses the bias dimension bug and is well-implemented across the C++/CUDA codebase. However, the test changes contain logic errors that will cause AttributeError when running tests with "no_bias" or "alibi" configurations, preventing proper validation of the fix.
  • Pay close attention to tests/pytorch/attention/run_attention_with_cp.py which has critical bugs on lines 342 and 438

Important Files Changed

File Analysis

Filename (score): Overview
transformer_engine/common/fused_attn/utils.h (5/5): Added bias_sq and bias_skv fields to FADescriptor_v1 struct and updated the comparison operator. Changes are straightforward and correctly implemented.
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (5/5): Correctly extracts bias_sq and bias_skv from input_Bias->data.shape and passes them through the call chain to cuDNN. Bias tensor dimensions and strides are properly updated to use actual bias dimensions instead of sequence lengths.
tests/pytorch/attention/run_attention_with_cp.py (2/5): Adds bias gradient tracking and comparison logic for context parallelism tests. Contains critical bugs where bias.grad and bias_.grad are accessed when bias is None, causing AttributeError. Also adds proper reshaping logic for dbias comparison.

Sequence Diagram

sequenceDiagram
    participant Python as Python Layer<br/>(utils.py)
    participant ArbitraryFwd as fused_attn_arbitrary_seqlen_fwd<br/>(C++ wrapper)
    participant ArbitraryFwdImpl as fused_attn_arbitrary_seqlen_fwd_impl<br/>(C++ implementation)
    participant cuDNN as cuDNN Graph
    
    Note over Python,cuDNN: Forward Pass with Bias [1, 1, 1, s_kv]
    
    Python->>ArbitraryFwd: input_Bias tensor with shape [b, h, sq, skv]
    ArbitraryFwd->>ArbitraryFwd: Extract bias_b = input_Bias->shape[0]<br/>bias_h = input_Bias->shape[1]<br/>bias_sq = input_Bias->shape[2]<br/>bias_skv = input_Bias->shape[3]
    ArbitraryFwd->>ArbitraryFwdImpl: Pass bias_b, bias_h, bias_sq, bias_skv
    ArbitraryFwdImpl->>ArbitraryFwdImpl: Store in FADescriptor_v1 for caching
    ArbitraryFwdImpl->>cuDNN: Create bias tensor with dimensions<br/>[bias_b, bias_h, bias_sq, bias_skv]<br/>Previously used [bias_b, bias_h, s_q, s_kv] ❌
    Note over cuDNN: Now receives correct bias dimensions ✓
    
    Note over Python,cuDNN: Backward Pass
    ArbitraryFwd->>ArbitraryFwd: Extract from output_dBias->shape
    ArbitraryFwd->>ArbitraryFwdImpl: Pass bias_sq, bias_skv
    ArbitraryFwdImpl->>cuDNN: Set dBias dimensions to [bias_b, bias_h, bias_sq, bias_skv]
    Note over cuDNN: dBias only computed if (bias_b==1 && bias_h==h)

else:
out.backward(dout)
dq, dk, dv = q.grad, k.grad, v.grad
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad
Contributor

bias is None when attn_bias_type is "no_bias" or "alibi" (line 312), so bias.grad will raise AttributeError

Suggested change
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad
dq, dk, dv, dbias = q.grad, k.grad, v.grad, bias.grad if bias is not None else None

Collaborator Author

Fixed

else:
out_.backward(dout_)
dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad
Contributor

bias_ is None when bias is None (line 355), so bias_.grad will raise AttributeError

Suggested change
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad
dq_, dk_, dv_, dbias_ = q_.grad, k_.grad, v_.grad, bias_.grad if bias_ is not None else None

Collaborator Author

Fixed

dbias.shape[2] // (2 * world_size),
dbias.shape[3],
)
# bias has fixed axis (2) as dbias shape: (1, 1, max_seqlen_q, max_seqlen_kv)
Collaborator

I think our CP implementation (after your C changes) should support all bias shapes, not just 111s. I also think your reshaping here should work for all shapes. Could you run the tests to confirm?
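
For readers following the shape arithmetic in this thread, here is a standalone sketch of the 2*cp_size chunking the comparison relies on. It is not the PR's test code; the helper name, the seq_dim argument and the mirror-chunk assignment are assumptions based on TE's load-balanced CP scheme:

import torch

def chunks_for_rank(dbias, cp_size, rank, seq_dim=2):
    # Slice a full-sequence [b, h, sq, skv] bias gradient into the two chunks a
    # given CP rank holds under the usual 2*cp_size chunking of one sequence dim.
    seq = dbias.shape[seq_dim]
    assert seq % (2 * cp_size) == 0, "sequence dim must be divisible by 2*cp_size"
    chunked = dbias.unflatten(seq_dim, (2 * cp_size, seq // (2 * cp_size)))
    # load-balanced CP assigns chunk i and its mirror 2*cp_size-1-i to rank i
    idx = torch.tensor([rank, 2 * cp_size - 1 - rank])
    return chunked.index_select(seq_dim, idx)

dbias = torch.randn(2, 16, 128, 128)                     # full-sequence reference dbias
print(chunks_for_rank(dbias, cp_size=2, rank=0).shape)   # torch.Size([2, 16, 2, 32, 128])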

@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/bias-shape branch from 11c7107 to de3011e on January 21, 2026 19:41
@KshitijLakhani
Collaborator Author

/te-ci pytorch L0 L1

1 similar comment
@KshitijLakhani
Collaborator Author

/te-ci pytorch L0 L1

@KshitijLakhani KshitijLakhani changed the title [PyT] Plumbing correct bias dims from TE to cudnn [PyT] Plumbing correct bias dims from TE to cudnn, while adding support for additional bias shapes Feb 6, 2026
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/bias-shape branch 2 times, most recently from ab1d2a9 to 8147617 on February 6, 2026 17:57
Contributor

@greptile-apps greptile-apps bot left a comment

7 files reviewed, 1 comment


@greptile-apps
Contributor

greptile-apps bot commented Feb 6, 2026

Additional Comments (1)

transformer_engine/pytorch/attention/dot_product_attention/utils.py
Incorrect fused-attn gating

This block disables FusedAttention only for 111s when bias.requires_grad, but it no longer disables fused attention for other non-1hss bias shapes with requires_grad=True (e.g. 11ss, b1ss, bhss). However fused_attn_f16_arbitrary_seqlen.cu still skips wiring dBias for some shapes, so leaving fused attention enabled here can silently produce missing/incorrect bias.grad (or a null dBias binding) when users request dbias for those shapes.

Concrete fix: restore explicit gating that matches the C++ set_dbias(...) condition, i.e. disable fused attention whenever dbias won’t be computed for the given fu_core_attention_bias_shape but fu_core_attention_bias_requires_grad is true.
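
A hedged sketch of what such gating could look like, written as a standalone function (the variable meanings follow the surrounding utils.py code, but the function itself and the set of dbias-capable shapes, taken from this PR's description, are assumptions):

import logging
import os

logger = logging.getLogger("DotProductAttention")

DBIAS_SUPPORTED_SHAPES = {"1hss", "11ss", "b1ss", "bhss"}  # per the PR description; 111s is fwd-only

def gate_fused_attention(bias_shape, bias_requires_grad, use_fused_attention):
    # Disable fused attention whenever dbias would not be computed for this shape
    # but a bias gradient is requested; keep the env-var branch mutually exclusive.
    if bias_requires_grad and bias_shape not in DBIAS_SUPPORTED_SHAPES:
        logger.debug("Disabling FusedAttention: dbias not supported for %s", bias_shape)
        use_fused_attention = False
    elif not bias_requires_grad:
        # max512 backend will only support [1, h, s, s]
        os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
    return use_fused_attention

print(gate_fused_attention("111s", True, True))    # False: fused attention disabled
print(gate_fused_attention("b1ss", False, True))   # True: env var steers backend selection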

Contributor

@greptile-apps greptile-apps bot left a comment

4 files reviewed, 1 comment


# For 1hss, 11ss, b1ss, bhss
shape_cache = shape
shape = shape.replace("_s_s", "_sq_skv")
if shape==shape_cache:
Contributor

Missing spaces around == operator violates PEP 8 style

Suggested change
if shape==shape_cache:
if shape == shape_cache:


Contributor

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 1 comment


Comment on lines 965 to 967
if not fu_core_attention_bias_requires_grad:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
Contributor

Independent condition - executed regardless of whether line 962's condition was true. This means NVTE_FUSED_ATTN_BACKEND=1 is set for all non-requires_grad bias cases, not just when use_fused_attention is still true. If line 964 set use_fused_attention=False, this environment variable shouldn't be set.

Consider changing to elif:

Suggested change
if not fu_core_attention_bias_requires_grad:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
elif not fu_core_attention_bias_requires_grad:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

Collaborator

I think it should be "elif" here. Could you double check the logic here please?

Contributor

You're correct - it should be elif. Here's the logic issue:

Current code flow:

if fu_core_attention_bias_requires_grad and fu_core_attention_bias_shape == "111s":
    use_fused_attention = False  # Line 964 disables fused attention
if not fu_core_attention_bias_requires_grad:  # Line 965 - always evaluated independently
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

While the current code technically works (the conditions are mutually exclusive), using independent if statements obscures the intent. These are two branches handling the same scenario:

  • Branch 1: If bias requires grad AND shape is "111s" → disable fused attention
  • Branch 2: If bias doesn't require grad → use max512 backend

Using elif makes this relationship explicit and prevents future bugs if someone modifies the conditions.

Suggested fix:

Suggested change
if not fu_core_attention_bias_requires_grad:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
# dbias calculation is not supported for 111s as of cuDNN 9.18. So, use fused attention backend only if bias does not require grad.
if fu_core_attention_bias_requires_grad and fu_core_attention_bias_shape == "111s":
logger.warning("Disabling FusedAttention as dbias calculation is not supported for 111s")
use_fused_attention = False
elif not fu_core_attention_bias_requires_grad:

This ensures we only set the backend environment variable when we haven't already disabled fused attention, making the code's intent clearer.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
KshitijLakhani and others added 12 commits February 6, 2026 23:26
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…s, b1ss and bhss

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…hapes b1ss, bhss, 11ss in addition to the already supported 1hss

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…fused attn if dbias calculation for 111s is required, else enable

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…sts. Add bias shape 111s to test_dpa_bias_shapes

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…ed shape. Fix the comparison logic shapes for bias/dbias

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/fix/bias-shape branch from 7c77b7a to 126be03 on February 6, 2026 23:48
@KshitijLakhani
Collaborator Author

/te-ci L0 L1

Contributor

@greptile-apps greptile-apps bot left a comment

9 files reviewed, 1 comment


Comment on lines 966 to 978
and fu_core_attention_bias_type == "post_scale_bias"
and fu_core_attention_bias_shape != "1hss"
):
if fu_core_attention_bias_requires_grad:
# remove this line when cuDNN adds bwd support for
# [1, 1, s, s], [b, 1, s, s] and [b, h, s, s]
logger.debug("Disabling FusedAttention for dBias in [1, H, S, S] shape")
# dbias calculation is not supported for 111s as of cuDNN 9.18. So, use fused attention backend only if bias does not require grad.
if fu_core_attention_bias_requires_grad and fu_core_attention_bias_shape == "111s":
logger.warning(
"Disabling FusedAttention as dbias calculation is not supported for 111s"
)
use_fused_attention = False
else:
if not fu_core_attention_bias_requires_grad:
# max512 backend will only support [1, h, s, s]
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"

Contributor

Env var set after disable

This block can still set os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" even after use_fused_attention was set to False for 111s+requires_grad, because the second branch is now an independent if rather than an else/elif. That leaves process-level state inconsistent with the selected backend and can leak into subsequent tests.

Consider making this mutually exclusive (e.g. elif not fu_core_attention_bias_requires_grad:) or gate the env-var assignment on use_fused_attention being true.

Collaborator Author

Addressed

pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if "p2p" in cp_comm_type and config.attn_bias_type != "no_bias" and config.bias_shape == "111s":
pytest.skip(
f"CP implementation with KV P2P requires bias sequence dim to be divisible by 2"
Collaborator

I feel this skip message doesn't make sense. Should it be because we don't support dbias for 111s? But we can skip the bwd for it right?

Collaborator Author

@KshitijLakhani KshitijLakhani Feb 12, 2026

You are right about the skipping of bwd for the non-CP part.
For non-CP, this PR skips the testing of bwd and only test fwd for bias 111s. This works completely fine.
This is achieved by setting is_training=False and core_attention_bias_requires_grad=False

However, if the same is done for CP, i.e. the two flags shown above are set to False, the test asserts at:

), "Sequence length does not meet divisible requirements!"

which requires the bias seq dim to be divisible by 2, and that can never happen since the bias is 111s.
So my understanding is that the only CP comm type that supports bias, p2p, cannot actually support it (even in the fwd pass) due to the requirement in:
), "Sequence length does not meet divisible requirements!"

So that's why I just skip that test, using the same assert message as the skip message.

I can make the skip message more descriptive by directly attributing the unmet divisibility requirement to the 111s bias shape, saying something like:
f"CP implementation with KV P2P requires bias sequence dim to be divisible by 2, which is not possible with {config.bias_shape=}"

The reason I suggest we use the above skip message is:
i) The skip is fundamentally due to the bias seq dim not being divisible by 2.
ii) To your point, which I agree with, point i) is a consequence of the bias shape being 111s, not of the bias dim being explicitly set to a number that isn't divisible by 2.

Let me know your thoughts @cyanguwa

Collaborator

Ok. That makes sense. Could you make the message a bit more descriptive please, like you suggested? For 111s + CP, I wonder if we should be broadcasting the sq dimension instead of slicing it?

Collaborator Author

Removed this skip
Added support for 111s bias in CP fused attn
Added support in CP fused attn tests to set is_training to False and no grad

@cyanguwa
Collaborator

Please address the couple of questions I left, but otherwise, it looks good! Thanks!

KshitijLakhani and others added 4 commits February 13, 2026 13:44
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Contributor

@greptile-apps greptile-apps bot left a comment

10 files reviewed, no comments


@KshitijLakhani
Collaborator Author

/te-ci L0 L1

@KshitijLakhani KshitijLakhani removed the bug Something isn't working label Feb 13, 2026
"and [1,1,sq,skv], [1,h,sq,skv], [b,1,sq,skv], [b,h,sq,skv] for backward!"
)
# For all bias shapes except 111s, sq must be divisible by 2 and sk must be divisible by 2*cp_size
# For bias shape 111s, only sq must be divisible by 2
Collaborator

@cyanguwa cyanguwa Feb 14, 2026

Should it be:
For bias shape 111s, only skv must be divisible by 2*cp_size

)
else:
assert attn_bias.shape[-1] % (2 * cp_size) == 0, "Sequence length does not meet divisible requirements!"
# [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)]
Collaborator

Will you be missing a dimension, or will it be fine if no index_select is done?
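
Purely as a shape illustration of the 111s case under discussion (cp_size and skv are arbitrary here; this is not the PR's code and does not settle the index_select question):

import torch

cp_size, skv = 2, 128
bias = torch.zeros(1, 1, 1, skv)                                  # 111s bias: [1, 1, 1, skv]
# only the kv sequence axis can be chunked, since the q axis has length 1
chunked = bias.view(1, 1, 1, 2 * cp_size, skv // (2 * cp_size))   # [1, 1, 1, 4, 32]
rank = 0
idx = torch.tensor([rank, 2 * cp_size - 1 - rank])
print(chunked.index_select(3, idx).shape)                         # torch.Size([1, 1, 1, 2, 32])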
