Fix int32 overflow deadlock and non-power-of-2 crash in Triton AlltoAllv (#2133)

Closed

snarayankh wants to merge 1 commit into meta-pytorch:main
Conversation
Contributor

@snarayankh has exported this pull request. If you are a Meta employee, you can view the originating Diff in D101236430.
Fix int32 overflow deadlock and non-power-of-2 crash in Triton AlltoAllv (meta-pytorch#2133)

Summary:
This diff fixes two bugs in the Triton AlltoAllv kernel (`comms/pipes/collectives/triton/`):

**int32 iteration/completion counter overflow causes deadlock**

The iteration counter tensor and the completion counters tensor were allocated as `torch.int32`. The kernel computes signal values as `sender_bpp * (iteration + 1)` in int32 arithmetic before passing the result to `wait_signal_from`, which sign-extends it to int64 and then reinterprets it as uint64 for the C++ comparison. With `blocks_per_peer=16`, the int32 product overflows after ~134M iterations. The overflowed negative int32 becomes a huge uint64 (~18.4 quintillion) after sign-extension, making the unsigned comparison `actual >= expected` permanently false — all recv blocks spin forever in a **silent deadlock**.

The full overflow chain:

1. `torch.zeros(1, dtype=torch.int32)` — iteration tensor created as int32
2. `tl.load(iteration_ptr)` — loaded as Triton i32 in the kernel
3. `sender_bpp * (iteration + 1)` — i32 × i32 = i32 **OVERFLOW** at ~134M iterations
4. `tl.full([], expected_value, tl.int64)` — sign-extends the negative i32 to i64 in the `wait_signal_from` wrapper
5. C++ `torchcomms_wait_signal_from(unsigned long long expected_value)` — reinterprets it as uint64 = ~18.4 quintillion
6. `cmp_op(CmpOp::GE, actual_uint64, expected_uint64)` — unsigned comparison, permanently false → **DEADLOCK**

Time to failure: ~77 days at 20 iter/s with bpp=16; ~1.5 days at 1000 iter/s (CUDA graph replay).

The sender side is fine — `signal_block` promotes the small value to int64 via `tl.zeros([], dtype=tl.int64) + value` *before* any arithmetic, so the cumulative uint64 signal is always correct. The mismatch is only on the receiver side, where the expected-value computation overflows in int32 before widening.

Fix: Change both `_ITERATION_TENSOR_CACHE` and `_COMPLETION_COUNTERS_CACHE` from `dtype=torch.int32` to `dtype=torch.int64`.
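The overflow chain can be reproduced in plain Python. This is a minimal sketch, not the kernel itself: `wrap_i32` is an illustrative helper that simulates Triton's two's-complement i32 wraparound, and the modulo reinterprets the i64 bit pattern as uint64 the way the C++ comparison does.

```python
def wrap_i32(x: int) -> int:
    """Simulate two's-complement int32 wraparound (what Triton i32 math does)."""
    x &= 0xFFFFFFFF
    return x - 0x100000000 if x > 0x7FFFFFFF else x

sender_bpp = 16            # blocks_per_peer from the description
iteration = 134_217_728    # first iteration where 16 * (iteration + 1) > 2**31 - 1

# Step 3 of the chain: the i32 x i32 product wraps to a negative value.
expected_i32 = wrap_i32(sender_bpp * (iteration + 1))
print(expected_i32)        # -2147483632

# Step 4: sign-extension to i64 preserves the negative value.
expected_i64 = expected_i32

# Step 5: the C++ side reinterprets the i64 bit pattern as uint64.
expected_u64 = expected_i64 % 2**64
print(expected_u64)        # 18446744071562067984 (~1.8e19)

# The sender's cumulative signal grows correctly in uint64 and can never reach
# ~1.8e19, so `actual >= expected` stays false and the receivers deadlock.
assert expected_u64 > 2**63
```

With the caches allocated as `torch.int64` instead, `tl.load` yields i64 and the product `sender_bpp * (iteration + 1)` is computed in i64, so it never wraps for any realistic iteration count.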
The dtype change makes all kernel arithmetic (iteration loads, bpp multiplications, counter atomics) operate in int64, matching the uint64 signal infrastructure. Zero performance cost — these are a single scalar tensor and a small per-rank array.

**`_sum_int64_kernel` crashes for non-power-of-2 world sizes**

`_sum_int64_kernel` passes `N=self.world_size` directly to `tl.arange(0, N)`. Triton requires the `tl.arange` range to be a power of 2 — verified in `third-party/tp2/triton/2.0.x/triton/python/triton/language/semantic.py:603-614`: `if (range & (range - 1)) != 0: raise ValueError("arange's range must be a power of 2")`. Non-power-of-2 world sizes (3, 5, 6, 7, ...) cause a compile-time `ValueError` crash.

The sibling kernel `_prepare_alltoallv_kernel` in the same file correctly uses `BLOCK_SIZE=triton.next_power_of_2(self.world_size)` with masking (`mask = offsets < W`), but `_sum_int64_kernel` was not updated to match.

Fix: Apply the same power-of-2 rounding + masking pattern that `_prepare_alltoallv_kernel` already uses. The `_prep_block_size` is already computed and available.

Reviewed By: srinathb-meta

Differential Revision: D101236430
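The power-of-2 rounding + masking pattern described above can be sketched in pure Python. This is an illustrative stand-in, not the Triton kernel: `next_power_of_2` mirrors `triton.next_power_of_2`, the `range` over `offsets` stands in for `tl.arange(0, BLOCK_SIZE)`, and the conditional stands in for a masked `tl.load(..., mask=offsets < W, other=0)`.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (mirrors triton.next_power_of_2)."""
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def masked_sum(values):
    """Pad the lane range to a power of two and mask off the extra lanes,
    as _prepare_alltoallv_kernel does."""
    W = len(values)                    # world_size
    BLOCK_SIZE = next_power_of_2(W)    # tl.arange requires a power of 2
    offsets = range(BLOCK_SIZE)        # stand-in for tl.arange(0, BLOCK_SIZE)
    # Masked load: out-of-range lanes contribute 0 instead of reading past W.
    return sum(values[o] if o < W else 0 for o in offsets)

# world_size=3 (non-power-of-2) no longer crashes: lane 3 is masked to 0.
print(masked_sum([3, 1, 4]))   # 8
```

The same shape applies inside the kernel: `tl.arange(0, BLOCK_SIZE)` with a power-of-2 `BLOCK_SIZE` always compiles, and the mask guarantees the sum over the first `W` lanes is unchanged.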
Force-pushed from 5b7cda8 to 7385504
Contributor

This pull request has been merged in 59a4ca4.