Skip to content

add amd support for rope2d kernel#48

Open
ZJLi2013 wants to merge 1 commit into naver:master from ZJLi2013:amd_support
Open

add amd support for rope2d kernel#48
ZJLi2013 wants to merge 1 commit into naver:master from ZJLi2013:amd_support

Conversation

@ZJLi2013
Copy link
Copy Markdown

@ZJLi2013 ZJLi2013 commented Feb 25, 2026

Previously the rope2d kernel could not run on AMD GPUs; with this change it can.

test_rope.py

# Copyright (C) 2022-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).

import unittest
import torch
import os
import sys

# Make the project root importable when this file is run directly.
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))

# The compiled extension may be absent (e.g. not built yet); record its
# availability so the tests can skip gracefully instead of crashing on import.
MODULE_AVAILABLE = True
try:
    from models.curope.curope2d import cuRoPE2D
except ImportError:
    MODULE_AVAILABLE = False
    print(
        "Could not import models.curope.curope2d. Make sure curope extension is compiled."
    )


def rope_2d_ref(tokens, positions, base, F0=1.0):
    """
    Pure-PyTorch reference implementation of 2D rotary position embedding.

    Args:
        tokens: (B, H, N, D) tensor; D must be divisible by 4.
        positions: (B, N, 2) integer tensor holding (y, x) positions.
        base: frequency base of the rotary embedding.
        F0: global frequency scale factor.

    Returns:
        (B, H, N, D) tensor with RoPE applied; the input is left untouched.
    """
    # Work in (B, N, H, D) to mirror the kernel's memory layout.
    rotated = tokens.transpose(1, 2).clone()
    dim = rotated.shape[-1]
    assert dim % 4 == 0
    quarter = dim // 4

    # inv_freq[i] = F0 / base**(i / quarter)
    exponents = torch.arange(0, quarter, dtype=torch.float32, device=tokens.device)
    inv_freq = F0 / (base ** (exponents / quarter))

    def _trig(pos):
        # pos: (B, N) -> cos/sin of shape (B, N, 1, Q), broadcast over heads.
        angles = pos.float().unsqueeze(-1) * inv_freq
        return torch.cos(angles).unsqueeze(2), torch.sin(angles).unsqueeze(2)

    cos_y, sin_y = _trig(positions[..., 0])
    cos_x, sin_x = _trig(positions[..., 1])

    # Channel layout is [Y1, Y2, X1, X2], each of width Q.
    y1, y2, x1, x2 = torch.split(rotated, quarter, dim=-1)

    # Standard 2D rotation per pair:
    #   u' = u cos - v sin
    #   v' = v cos + u sin
    out = torch.cat(
        [
            y1 * cos_y - y2 * sin_y,
            y2 * cos_y + y1 * sin_y,
            x1 * cos_x - x2 * sin_x,
            x2 * cos_x + x1 * sin_x,
        ],
        dim=-1,
    )

    # Back to the caller's (B, H, N, D) layout.
    return out.transpose(1, 2)


class TestCuRoPE2D(unittest.TestCase):
    """Checks the compiled cuRoPE2D kernel against the pure-PyTorch reference."""

    def setUp(self):
        # Both the compiled extension and a CUDA-capable device are required.
        if not MODULE_AVAILABLE:
            self.skipTest("curope module not available")
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")

    def test_forward(self):
        B, H, N, D = 2, 4, 16, 32
        base, F0 = 100.0, 1.0

        torch.manual_seed(42)
        # The kernel requires data contiguous in (B, N, H, D) layout, while
        # cuRoPE2D takes (B, H, N, D) and transposes internally. Allocate the
        # storage in (B, N, H, D) and hand the kernel a transposed view.
        storage = torch.randn(B, N, H, D, device="cuda", dtype=torch.float32)
        tokens = storage.transpose(1, 2)  # (B, H, N, D)

        positions = torch.randint(0, 100, (B, N, 2), device="cuda", dtype=torch.int64)

        rope = cuRoPE2D(freq=base, F0=F0).cuda()

        expected = rope_2d_ref(tokens, positions, base, F0)

        # The kernel works in-place, so give it a fresh copy that preserves
        # the stride structure of `tokens`.
        kernel_in = storage.clone().transpose(1, 2)
        actual = rope(kernel_in, positions)

        max_err = (expected - actual).abs().max().item()
        self.assertTrue(
            torch.allclose(expected, actual, atol=1e-4),
            f"Forward pass mismatch. Max diff: {max_err}",
        )

        # The kernel mutates its input; input and output must match exactly.
        self.assertTrue(
            torch.allclose(kernel_in, actual, atol=0),
            "Output should be same object/value as input (in-place)",
        )

    def test_backward(self):
        B, H, N, D = 2, 2, 8, 16
        base, F0 = 10.0, 1.0

        torch.manual_seed(42)
        # Same (B, N, H, D) stride requirement as the forward test.
        storage = torch.randn(
            B, N, H, D, device="cuda", dtype=torch.float32, requires_grad=True
        )
        tokens = storage.transpose(1, 2)  # (B, H, N, D)

        positions = torch.randint(0, 20, (B, N, 2), device="cuda", dtype=torch.int64)

        # Reference gradient: a plain contiguous copy is fine here since the
        # reference implementation never touches the kernel.
        ref_in = tokens.clone().detach().contiguous()
        ref_in.requires_grad = True
        rope_2d_ref(ref_in, positions, base, F0).sum().backward()
        grad_ref = ref_in.grad

        rope = cuRoPE2D(freq=base, F0=F0).cuda()
        # Autograd forbids in-place mutation of leaf tensors, and the kernel
        # modifies its input in place. Multiplying the leaf by 1.0 yields a
        # non-leaf tensor that is safe to hand to the kernel.
        leaf = storage.clone().detach().requires_grad_(True)
        kernel_in = (leaf * 1.0).transpose(1, 2)

        rope(kernel_in, positions).sum().backward()
        # Non-leaf tensors do not retain grad by default, so read the
        # gradient from the leaf. The leaf holds the (B, N, H, D) storage,
        # hence the transpose to align with grad_ref's (B, H, N, D) layout.
        grad_cuda = leaf.grad.transpose(1, 2)

        max_err = (grad_ref - grad_cuda).abs().max().item()
        self.assertTrue(
            torch.allclose(grad_ref, grad_cuda, atol=1e-4),
            f"Backward pass mismatch. Max diff: {max_err}",
        )


if __name__ == "__main__":
    # Run the full test suite when this file is executed directly.
    unittest.main()

verify on AMD MI300x

export PYTORCH_ROCM_ARCH="gfx942"
python3  setup.py install 
python3 test_rope.py

output

..
----------------------------------------------------------------------
Ran 2 tests in 0.813s

OK

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant