Make test_triangle_updates a multi-GPU test #5890

Merged: wujingyue merged 3 commits into main from wjy/sp on Feb 27, 2026

Conversation

@wujingyue (Collaborator) commented Jan 29, 2026

I got the "triangle updates" test passing with 3D sharding in this PR. Below are the key issues identified, their workarounds, and their current status:

1. Sharding Propagation Rework

2. Multi-Dimensional Sharding & getCommunicationInfo

  • Issue: Fix convertSingleOpToCommunication for 2D sharding #4604
  • Details: The commit updates getCommunicationInfo to support multi-dimensional sharding. It reuses haveDifferentShardings to identify inconsistencies between input and output TensorView objects. The commit needs cleanup and further test verification before it can be merged.
  • Technical Debt: Per Extend IdModel to map DIDs for certain patterns. #3987, haveDifferentShardings is currently bottlenecked by the expensive ExpressionSimplifier. We need to transition this to be IdModel-based in a future iteration.

3. Misaligned Memory Access in Transpose Kernels

  • Issue: Misaligned memory access with 3D sharding #5920
  • Details: A generated transpose kernel is hitting misaligned memory access errors. This occurs during the transposition between the local Einsum and the downstream ReduceScatter. For context, this transposition was introduced by ReorderShardedAxisPass to ensure the scattered axis of the ReduceScatter is allocated outermost.

4. High memory usage

  • Issue: Reduce memory usage in distributed triangle updates #5942
  • Details: The current naive AllGather preceding the Einsum is functional but consumes too much memory for AlphaFold3 workloads due to long sequence lengths.
  • Proposed Fix: We need to implement stream-parallelization to enable:
    • Ring-based AllGather (with Swizzle), or
    • Broadcast-based communication (without Swizzle). AFAICT, fast broadcast requires multicasting and therefore symmetric memory.
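The ring-based AllGather above is not implemented in this PR; purely as an illustration of the communication pattern it would follow, here is a hypothetical pure-Python simulation of the chunk rotation (`ring_allgather` is not an nvfuser API):

```python
def ring_allgather(chunks):
    """Simulate a ring AllGather among P ranks: each rank starts with one
    chunk and, after P-1 steps of forwarding a chunk to its right neighbor,
    every rank holds all P chunks in origin order. Only one chunk per rank
    is in flight per step, which is what bounds the extra memory."""
    p = len(chunks)
    # buffers[r] maps origin-rank -> chunk for everything rank r holds so far.
    buffers = [{r: chunks[r]} for r in range(p)]
    for step in range(p - 1):
        # In step s, rank r forwards the chunk that originated at rank
        # (r - s) % p (which it received in the previous step) to (r + 1) % p.
        # We copy from `chunks` directly since the values are identical.
        for r in range(p):
            src = (r - step) % p
            buffers[(r + 1) % p][src] = chunks[src]
    # Assemble each rank's gathered result in origin order.
    return [[buf[i] for i in range(p)] for buf in buffers]


if __name__ == "__main__":
    gathered = ring_allgather([[0, 1], [2, 3], [4, 5], [6, 7]])
    # Every rank ends up with all four chunks, in order.
    assert all(g == [[0, 1], [2, 3], [4, 5], [6, 7]] for g in gathered)
```

In the actual fix, each forwarded chunk would overlap with a local Einsum on the previously received chunk, which is where the stream-parallelization comes in.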

cc @DejunL

@github-actions bot commented Jan 29, 2026

Review updated until commit 8649ff1

Description

  • Move test_triangle_updates function from single-GPU direct tests to multi-GPU tests

  • Add MPI test decorator and multi-device test fixture integration

  • Implement 3D device mesh sharding configuration (dp_size × cp_size × cp_size)

  • Add manual tensor sharding for inputs and matmul output with proper parallelization

  • Update test data preparation to use sharded tensors across multiple devices

Changes walkthrough

Relevant files

Tests

tests/python/direct/test_alphafold3.py: Remove test_triangle_updates from direct tests

  • Remove test_triangle_updates function and gating helper function
  • Add comment noting the test has been moved to multidevice tests
  • +3/-127

tests/python/multidevice/test_alphafold3.py: Add multi-GPU triangle updates test with 3D sharding

  • Add test_triangle_updates function with multi-GPU support
  • Implement 3D device mesh configuration with dp_size, cp_size, cp_size dimensions
  • Add manual sharding configuration for input tensors and matmul output
  • Integrate MPI test decorator and multidevice_test fixture
  • Update test data preparation to use sharded tensors across devices
  • +241/-0
    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Device Mesh Validation

    The device mesh creation assumes d is divisible by cp_size * cp_size. While there's a skip condition for invalid configurations, consider adding more robust validation or clearer error messaging for edge cases.

    mesh = nvfuser.multidevice.DeviceMesh(
        torch.arange(d).reshape(dp_size, cp_size, cp_size)
    )
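One way the suggested validation could look is a small helper along these lines (`mesh_dims` is a hypothetical name, not part of the PR or of nvfuser):

```python
def mesh_dims(num_devices, cp_size):
    """Compute (dp_size, cp_size, cp_size) for a 3D (DIDz, DIDy, DIDx)
    device mesh, failing loudly when the device count doesn't factor."""
    cp_sq = cp_size * cp_size
    if num_devices % cp_sq != 0:
        raise ValueError(
            f"num_devices={num_devices} must be divisible by "
            f"cp_size**2={cp_sq} to build a (dp, cp, cp) mesh"
        )
    return num_devices // cp_sq, cp_size, cp_size
```

The test currently achieves the same effect with pytest.skip, which is arguably the right call for a test; a helper like this would matter more if the mesh construction moved into library code.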
    Manual Sharding Complexity

    The manual sharding logic for matmul_out is complex and scattered across multiple lines (182-193). Consider extracting this into a helper function or adding more detailed comments to explain the sharding strategy, especially given the TODO comment about sharding propagation rework.

    # TODO(#5901): this can be avoided with a better sharding propagation.
    #
    # matmul_out is of shape [b, c, i, j]. We shard `b` by `DIDz`, `i` by
    # `DIDy`, and `j` by `DIDx`.
    matmul_out.outer_split(-1, cp_size)
    match direction:
        case Direction.OUTGOING:
            matmul_out.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
        case Direction.INCOMING:
            matmul_out.axis(-2).parallelize(nvfuser.ParallelType.mesh_y)
    matmul_out.outer_split(3, cp_size)
    matmul_out.axis(3).parallelize(nvfuser.ParallelType.mesh_x)
    matmul_out.outer_split(2, cp_size)
    matmul_out.axis(2).parallelize(nvfuser.ParallelType.mesh_y)
    matmul_out.outer_split(0, dp_size)
    matmul_out.axis(0).parallelize(nvfuser.ParallelType.mesh_z)
    Performance Considerations

    The PR description mentions several performance issues including misaligned memory access (#5920) and high memory usage (#5942). While this test validates correctness, it may not catch these performance regressions. Consider if additional performance monitoring or smaller test cases would be beneficial.

    def test_triangle_updates(direction, multidevice_test):
        d = multidevice_test.size
        cp_size = 2
        if d % (cp_size * cp_size) != 0:
            pytest.skip(
                f"We only support even split, so {d} has to be divisible by {cp_size * cp_size} for {cp_size=}."
            )
        dp_size = d // (cp_size * cp_size)
    
        c_z = _DEFAULT_CONFIG.c_z
    
        with FusionDefinition() as fd:
            z_in_tv = fd.define_tensor(
                shape=[-1, -1, -1, c_z],
                dtype=DataType.BFloat16,
                contiguity=True,
            )  # [b, i, j, c_z]
            w_norm_in = fd.define_tensor(
                shape=[c_z], dtype=DataType.BFloat16, contiguity=True
            )
            b_norm_in = fd.define_tensor(
                shape=[c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_p_in = fd.define_tensor(
                shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_g_in = fd.define_tensor(
                shape=[c_z * 2, c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_norm_out = fd.define_tensor(
                shape=[c_z], dtype=DataType.BFloat16, contiguity=True
            )
            b_norm_out = fd.define_tensor(
                shape=[c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_p_out = fd.define_tensor(
                shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
            )
            w_g_out = fd.define_tensor(
                shape=[c_z, c_z], dtype=DataType.BFloat16, contiguity=True
            )
            # Masking is used in an internal implementation: http://nv/e-4
            mask_tv = fd.define_tensor(
                shape=[-1, -1, -1], dtype=DataType.Bool, contiguity=True
            )  # [b, i, j]
    
            batch_size = fd.ops.size(z_in_tv, 0)
            n_tokens = fd.ops.size(z_in_tv, 1)
    
            z_in = layer_norm(fd, z_in_tv, w_norm_in, b_norm_in)
            z = gating(fd, z_in, w_p_in, z_in, w_g_in)
            mask = fd.ops.broadcast_in_dim(
                mask_tv,
                shape=[batch_size, n_tokens, n_tokens, c_z],
                broadcast_dims=[0, 1, 2],
            )
            z = fd.ops.where(mask, z, 0.0)
            a = fd.ops.slice(z, [0, 0, 0, 0], [batch_size, n_tokens, n_tokens, c_z])
            b = fd.ops.slice(z, [0, 0, 0, c_z], [batch_size, n_tokens, n_tokens, c_z * 2])
    
            match direction:
                case Direction.OUTGOING:
                    # z_out = einsum("bikc,bjkc->bijc", a, b)
                    a = fd.ops.permute(a, [0, 3, 1, 2])  # [b, c, i, k]
                    b = fd.ops.permute(b, [0, 3, 2, 1])  # [b, c, k, j]
                case Direction.INCOMING:
                    # z_out = einsum("bkic,bkjc->bijc", a, b)
                    a = fd.ops.permute(a, [0, 3, 2, 1])  # [b, c, i, k]
                    b = fd.ops.permute(b, [0, 3, 1, 2])  # [b, c, k, j]
            z = fd.ops.matmul(a, b)  # [b, c, i, j]
            matmul_out = z
            z = fd.ops.permute(z, [0, 2, 3, 1])  # [b, i, j, c]
    
            z = layer_norm(fd, z, w_norm_out, b_norm_out)
            z = gating(fd, z, w_p_out, z_in, w_g_out)
            fd.add_output(z)
    
            mesh = nvfuser.multidevice.DeviceMesh(
                torch.arange(d).reshape(dp_size, cp_size, cp_size)
            )
            for tv in [
                z_in_tv,
                w_norm_in,
                b_norm_in,
                w_p_in,
                w_g_in,
                w_norm_out,
                b_norm_out,
                w_p_out,
                w_g_out,
                mask_tv,
                matmul_out,
            ]:
                tv.set_device_mesh(mesh)
    
            for tv in [z_in_tv, mask_tv]:
                tv.outer_split(2, cp_size)
                tv.axis(2).parallelize(nvfuser.ParallelType.mesh_x)
                tv.outer_split(1, cp_size)
                tv.axis(1).parallelize(nvfuser.ParallelType.mesh_y)
                tv.outer_split(0, dp_size)
                tv.axis(0).parallelize(nvfuser.ParallelType.mesh_z)
    
            # TODO(#5901): this can be avoided with a better sharding propagation.
            #
            # matmul_out is of shape [b, c, i, j]. We shard `b` by `DIDz`, `i` by
            # `DIDy`, and `j` by `DIDx`.
            matmul_out.outer_split(-1, cp_size)
            match direction:
                case Direction.OUTGOING:
                    matmul_out.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
                case Direction.INCOMING:
                    matmul_out.axis(-2).parallelize(nvfuser.ParallelType.mesh_y)
            matmul_out.outer_split(3, cp_size)
            matmul_out.axis(3).parallelize(nvfuser.ParallelType.mesh_x)
            matmul_out.outer_split(2, cp_size)
            matmul_out.axis(2).parallelize(nvfuser.ParallelType.mesh_y)
            matmul_out.outer_split(0, dp_size)
            matmul_out.axis(0).parallelize(nvfuser.ParallelType.mesh_z)
    
        batch_per_rank = 3
        n_tokens_per_rank = 5
        z_in_ref = torch.testing.make_tensor(
            batch_per_rank * dp_size,
            n_tokens_per_rank * cp_size,
            n_tokens_per_rank * cp_size,
            c_z,
            dtype=torch.bfloat16,
            device="cpu",
        )
        mask_ref = torch.testing.make_tensor(
            batch_per_rank * dp_size,
            n_tokens_per_rank * cp_size,
            n_tokens_per_rank * cp_size,
            dtype=torch.bool,
            device="cpu",
        )
    
        z_in = multidevice_test.shard_tensor(z_in_ref, z_in_tv)
        w_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
        b_norm_in = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
        w_p_in = torch.testing.make_tensor(
            c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
        )
        w_g_in = torch.testing.make_tensor(
            c_z * 2, c_z, dtype=torch.bfloat16, device="cuda"
        )
        w_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
        b_norm_out = torch.testing.make_tensor(c_z, dtype=torch.bfloat16, device="cuda")
        w_p_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
        w_g_out = torch.testing.make_tensor(c_z, c_z, dtype=torch.bfloat16, device="cuda")
        mask = multidevice_test.shard_tensor(mask_ref, mask_tv)
        (z_out,) = fd.execute(
            [
                z_in,
                w_norm_in,
                b_norm_in,
                w_p_in,
                w_g_in,
                w_norm_out,
                b_norm_out,
                w_p_out,
                w_g_out,
                mask,
            ]
        )
        assert z_out.shape == (batch_per_rank, n_tokens_per_rank, n_tokens_per_rank, c_z)
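As a quick sanity check on the permute-and-matmul rewrite of the outgoing einsum in the test above, here is a small NumPy verification at toy sizes (illustration only, not part of the PR):

```python
import numpy as np

rng = np.random.default_rng(0)
b_, n_, c_ = 2, 3, 4  # toy sizes; i == j == k == n_ (square token dims)
a = rng.standard_normal((b_, n_, n_, c_))   # [b, i, k, c]
bb = rng.standard_normal((b_, n_, n_, c_))  # [b, j, k, c]

# Reference: the outgoing triangle-update contraction.
ref = np.einsum("bikc,bjkc->bijc", a, bb)

# The test's rewrite: permute a to [b, c, i, k] and bb to [b, c, k, j],
# batched matmul over k (yielding [b, c, i, j]), then permute to [b, i, j, c].
out = np.matmul(
    a.transpose(0, 3, 1, 2), bb.transpose(0, 3, 2, 1)
).transpose(0, 2, 3, 1)

assert np.allclose(ref, out)
```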

    @wujingyue wujingyue changed the title 3D sharding for triangle updates Make test_triangle_updates a multi-GPU test Feb 23, 2026
    @wujingyue wujingyue marked this pull request as ready for review February 23, 2026 01:53
    @greptile-apps greptile-apps bot commented Feb 23, 2026

    Greptile Summary

    Migrates the test_triangle_updates test from single-GPU to multi-GPU setup with 3D sharding for AlphaFold3 triangle updates operation.

    Key Changes:

    • Moved test from tests/python/direct/test_alphafold3.py to tests/python/multidevice/test_alphafold3.py
    • Added 3D device mesh configuration with data parallelism (DIDz), and 2D context parallelism (DIDx, DIDy)
    • Manual sharding applied to matmul_out tensor due to sharding propagation limitations (tracked in #5901: "Improve sharding propagation for triangle updates outgoing")
    • Test now requires devices divisible by cp_size * cp_size (4 devices minimum for cp_size=2)
    • Input tensors sharded across batch and token dimensions using multidevice_test.shard_tensor
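For a concrete picture of the 3D mesh, with 8 devices and cp_size=2 the dp_size works out to 2, and torch.arange(8).reshape(2, 2, 2) assigns device ids to (DIDz, DIDy, DIDx) coordinates as in this plain-Python sketch of the same reshape (`mesh_layout` is only an illustration, not an nvfuser API):

```python
def mesh_layout(num_devices, dp_size, cp_size):
    """Map device ids 0..num_devices-1 onto (z, y, x) mesh coordinates,
    mirroring torch.arange(d).reshape(dp_size, cp_size, cp_size)."""
    assert num_devices == dp_size * cp_size * cp_size
    return {
        dev: (
            dev // (cp_size * cp_size),   # DIDz: data-parallel group
            (dev // cp_size) % cp_size,   # DIDy: context-parallel row
            dev % cp_size,                # DIDx: context-parallel column
        )
        for dev in range(num_devices)
    }


layout = mesh_layout(8, dp_size=2, cp_size=2)
assert layout[0] == (0, 0, 0)
assert layout[5] == (1, 0, 1)  # device 5: second dp group, row 0, column 1
assert layout[7] == (1, 1, 1)
```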

    Confidence Score: 4/5

    Important Files Changed

    Filename Overview
    tests/python/direct/test_alphafold3.py Removed test_triangle_updates function and gating helper, added comment indicating move to multidevice tests
    tests/python/multidevice/test_alphafold3.py New multi-GPU test with 3D sharding (dp_size, cp_size, cp_size) for triangle updates operation, includes manual sharding configuration

    Last reviewed commit: 8649ff1

    @greptile-apps greptile-apps bot left a comment

    2 files reviewed, 1 comment

    @chaserileyroberts chaserileyroberts (Collaborator) left a comment

    Small comments

    Comment on lines +223 to +236:

        (z_out,) = fd.execute(
            [
                z_in,
                w_norm_in,
                b_norm_in,
                w_p_in,
                w_g_in,
                w_norm_out,
                b_norm_out,
                w_p_out,
                w_g_out,
                mask,
            ]
        )
    Collaborator:

    Do we want to include a numerics execution test as well? Although I know those are hard to make for more complex ops.

    Collaborator Author (@wujingyue):

    > Do we want to include a numerics execution test as well?

    Yes, we should.

    > Although I know those are hard to make for more complex ops.

    Yes, it's hard to figure out a reasonable comparison threshold. We may have to go with toy sizes as in

    Collaborator:

    Are you aiming to do this in the PR?
    I do think we should have a numerics check to avoid silent errors within schedulers.

    Collaborator Author (@wujingyue), Feb 24, 2026:

    > Are you aiming to do this in the PR?

    No -- https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/getting-started/helping-others-review-your-changes#write-small-pull-requests

    > I do think we should have a numerics check to avoid silent errors within schedulers.

    Yes

    Comment on the manual sharding of matmul_out:

            case Direction.INCOMING:
                matmul_out.axis(-2).parallelize(nvfuser.ParallelType.mesh_y)
        matmul_out.outer_split(3, cp_size)
        matmul_out.axis(3).parallelize(nvfuser.ParallelType.mesh_x)
    Collaborator:

    Do the X, Y, and Z mesh axes represent real hardware communication links? I.e., would communicating over X always be local NVLink, Y InfiniBand, etc., or is that arbitrary and left to how the mesh gets constructed?

    Collaborator Author (@wujingyue):

    > is that arbitrary and left to how the mesh gets constructed?

    That

    @wujingyue (Collaborator Author):

    !test

    @greptile-apps greptile-apps bot left a comment

    2 files reviewed, no comments

    @wujingyue wujingyue requested a review from Priya2698 February 24, 2026 07:04
    @wujingyue (Collaborator Author):

    !test

    @greptile-apps greptile-apps bot left a comment

    2 files reviewed, no comments

    @wujingyue wujingyue merged commit bd913a9 into main Feb 27, 2026
    53 checks passed
    @wujingyue wujingyue deleted the wjy/sp branch February 27, 2026 23:18
    3 participants