
Re-enable skipped test_normalization.py::test_instance_norm_multigpu #4833

Draft
jacobhinkle wants to merge 1 commit into main from jh/reenable_instance_norm_multigpu_test

Conversation

@jacobhinkle
Collaborator

This test was skipped in #1730, but it no longer fails.

Fixes #1728

@jacobhinkle jacobhinkle requested review from jjsjann123 and naoyam July 23, 2025 14:38
@github-actions

Description

  • Re-enabled skipped test for multi-GPU instance normalization

  • Removed skip decorator and condition


Changes walkthrough 📝

Relevant files
Tests
test_normalization.py
Re-enable multi-GPU instance normalization test                   

tests/python/test_normalization.py

  • Removed @pytest.mark.skip decorator
  • Kept @pytest.mark.skipif for multi-GPU requirement
  • +0/-3     
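
A minimal sketch of what the change amounts to (illustrative only; `device_count` below is a hypothetical stand-in for `torch.cuda.device_count()` so the snippet runs without a GPU stack installed):

```python
import pytest

def device_count() -> int:
    # Hypothetical stand-in for torch.cuda.device_count().
    return 1

# The PR deletes an unconditional skip along these lines...
# @pytest.mark.skip(reason="Flaky multi-GPU failure, see #1728")
# ...while keeping the hardware gate:
@pytest.mark.skipif(device_count() < 2, reason="test requires multiple GPUs")
def test_instance_norm_multigpu():
    ...  # test body unchanged
```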

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Test Stability

    Ensure that the test is stable and does not intermittently fail. Verify that the root cause of the previous failure has been addressed.

    def test_instance_norm_multigpu():
        class Model(nn.Module):
            def __init__(self):

    @jacobhinkle
    Collaborator Author

    !test

    @naoyam
    Collaborator

    naoyam commented Jul 23, 2025

    Reading #1728, I wonder what fixed the issue. Is the fusion scheduled differently?

    Collaborator

    @jjsjann123 jjsjann123 left a comment

    More than happy to enable the test again.
    But similar question as what @naoyam has, do we know that we can close the original issue now?

    @jacobhinkle
    Collaborator Author

    Reading #1728, I wonder what fixed the issue. Is the fusion scheduled differently?

    The test was fixed by #4748.

    Repro of failing segment

    # CUDA devices:
    #  0: NVIDIA H100 NVL
    #  1: NVIDIA H100 NVL
    # torch version: 2.8.0a0+34c6371d24.nvInternal
    # cuda version: 13.0
    # nvfuser version: 0.2.27+git8830aff
    import torch
    from nvfuser import FusionDefinition, DataType
    
    def nvfuser_fusion_id2(fd : FusionDefinition) -> None :
        T0 = fd.define_tensor(shape=[-1, -1, -1, -1, -1], contiguity=[True, None, None, None, None], dtype=DataType.Float, is_cpu=False)
        T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
        T2 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.Float, is_cpu=False)
        T3 = fd.define_tensor(shape=[-1, -1, -1, -1, -1], contiguity=[True, True, True, True, True], dtype=DataType.Float, is_cpu=False)
        S4 = fd.define_scalar(1, dtype=DataType.Int)
        S5 = fd.ops.size(T3, dim=2)
        S6 = fd.ops.mul(S4, S5)
        S7 = fd.ops.size(T3, dim=3)
        S8 = fd.ops.mul(S6, S7)
        S9 = fd.ops.size(T3, dim=4)
        S10 = fd.ops.mul(S8, S9)
        S11 = fd.ops.cast(S10, dtype=DataType.Float)
        S12 = fd.ops.reciprocal(S11)
        T13 = fd.ops.mul(T2, S12)
        T14 = fd.ops.mul(T1, T1)
        T15 = fd.ops.mul(T13, T14)
        T16 = fd.ops.broadcast(T15, is_broadcast_dim=[False, False, True, True, True])
        T17 = fd.ops.mul(T3, T16)
        T18 = fd.ops.sub(T0, T17)
        T19 = fd.ops.squeeze(T0, dims=[2, 3, 4], squeeze_expanded=True)
        S20 = fd.ops.size(T0, dim=2)
        S21 = fd.ops.size(T0, dim=3)
        S22 = fd.ops.mul(S20, S21)
        S23 = fd.ops.size(T0, dim=4)
        S24 = fd.ops.mul(S22, S23)
        S25 = fd.ops.cast(S24, dtype=DataType.Float)
        T26 = fd.ops.mul(T19, S25)
        T27 = fd.ops.mul(T26, S12)
        T28 = fd.ops.broadcast(T27, is_broadcast_dim=[False, False, True, True, True])
        T29 = fd.ops.sub(T18, T28)
        T30 = fd.ops.broadcast(T1, is_broadcast_dim=[False, False, True, True, True])
        T31 = fd.ops.mul(T29, T30)
        fd.add_output(T31)
    
    
    with FusionDefinition() as fd:
        nvfuser_fusion_id2(fd)
    
    inputs = [
        torch.randn(2, dtype=torch.float32, device='cuda:1').as_strided((2, 4, 128, 128, 128), (1, 0, 0, 0, 0)),
        torch.testing.make_tensor((2, 4), dtype=torch.float32, device='cuda:1'),
        torch.testing.make_tensor((2, 4), dtype=torch.float32, device='cuda:1'),
        torch.testing.make_tensor((2, 4, 128, 128, 128), dtype=torch.float32, device='cuda:1'),
    ]
    fd.execute(inputs)
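
    Note that the first input is a 2-element tensor expanded to a 5-D broadcast view via zero strides, so no data is materialized. The same trick in NumPy for illustration (smaller shape here; NumPy strides are in bytes, whereas the `torch.as_strided` call above uses element strides):

```python
import numpy as np

base = np.arange(2, dtype=np.float32)  # only 2 physical elements
# Zero strides along the trailing dims make every index there alias the
# same element; NumPy strides are in bytes (float32 -> 4 per element).
view = np.lib.stride_tricks.as_strided(
    base, shape=(2, 4, 8, 8, 8), strides=(4, 0, 0, 0, 0)
)
assert view.base is not None           # still a view, nothing copied
assert view[0, 3, 7, 7, 7] == base[0]
assert view[1, 0, 0, 0, 0] == base[1]
```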

    Params before #4748:

    ===== Pointwise Stats ========
    num_elems: 16777216
    elem_counts: 2 4 128 128 128
    max_dtype_size_for_vectorization: 4
    unroll_factor_inner: 1
    unroll_factor_outer: 1
    vectorize_factor: 4
    
    reorder_map:
    broadcast_byte_multiples: (0, 20), (20, 16), (20, 8), (20, 8), (20, 8),
    LHS elems: 8 RHS elems: 2097152
    
    
    ===== Pointwise Parameters ========
    Tag: Pointwise heuristics Pointwise Characteristics:
     Gridx: 1 BlckY: 1 BlckX: 128
    2D Schedule
      Bcast break point: 2
    vectorization_factor: 4
    unroll_factor_outer: 1
    unroll_factor_inner: 1
    ====================================
    

    Params after #4748:

    
    ===== Pointwise Stats ========
    num_elems: 16777216
    elem_counts: 2 4 128 128 128
    max_dtype_size_bit_for_vectorization: 32
    unroll_factor_inner: 1
    unroll_factor_outer: 1
    vectorize_factor: 4
    
    reorder_map:
    broadcast_byte_multiples: (0, 160), (160, 128), (160, 64), (160, 64), (160, 64),
    LHS elems: 0 RHS elems: 0
    
    
    ===== Pointwise Parameters ========
    Tag: Pointwise heuristics Pointwise Characteristics:
     Gridx: 1 BlckY: 1 BlckX: 128
    vectorization_factor: 4
    unroll_factor_outer: 1
    unroll_factor_inner: 1
    ====================================
    
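    A quick sanity check on the two dumps: the post-#4748 `broadcast_byte_multiples` entries are exactly 8x the earlier values, consistent with that PR changing the accounting from bytes to bits (note `max_dtype_size_for_vectorization: 4` becoming `max_dtype_size_bit_for_vectorization: 32`):

```python
before = [(0, 20), (20, 16), (20, 8), (20, 8), (20, 8)]           # bytes
after = [(0, 160), (160, 128), (160, 64), (160, 64), (160, 64)]   # bits
# Every pair scales by exactly 8, i.e. bytes -> bits.
assert all((8 * a, 8 * b) == (x, y)
           for (a, b), (x, y) in zip(before, after))
```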

    We are no longer using a 2D grid, but I'm not sure yet what part of that PR caused the change in heuristic. cc @zasdfgbnm
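
    For context, a hypothetical illustration (not nvFuser's actual code) of what the break point means: with shape [2, 4, 128, 128, 128] and `Bcast break point: 2`, the iteration domain splits into an outer LHS of 2*4 elements and an inner RHS of 128^3, matching the `LHS elems: 8 RHS elems: 2097152` line in the pre-#4748 stats:

```python
from math import prod

shape = [2, 4, 128, 128, 128]

def split_at(break_point: int):
    # In a 2D pointwise schedule the outer (LHS) dims and inner (RHS)
    # dims are parallelized separately; break_point 0 degenerates to 1D.
    return prod(shape[:break_point]), prod(shape[break_point:])

lhs, rhs = split_at(2)
assert (lhs, rhs) == (8, 2097152)  # matches the pre-#4748 stats
```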

    @naoyam
    Collaborator

    naoyam commented Jul 28, 2025

    We are no longer using a 2D grid, but I'm not sure yet what part of that PR caused the change in heuristic. cc @zasdfgbnm

    It seems to be just luck. That PR surfaced an unfixed bug in the break-point logic (#4773). The bug is still present, since the PR only changed the unit from bytes to bits, but that change did affect the break-point decision, which I believe is why this segment no longer uses the 2D scheduling. Since the bug remains unfixed, the segment may fall back to 2D scheduling once it is fixed.

    I wonder what the root cause of this test's failure was. If there is a case where 2D scheduling should not be used, we should incorporate that condition into the algorithm that determines the break point.



    Development

    Successfully merging this pull request may close these issues.

    python_tests.test_normalization.test_instance_norm_multigpu failure

    3 participants