
[HostIr Lowering] Add sync stream at start of for-loop#3913

Merged
samnordmann merged 1 commit into main from fix_stream_sync_start_forloop
Feb 18, 2025
Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Feb 17, 2025

Fixes a bug

@github-actions

Description

  • Add initial stream synchronization at for-loop start

  • Ensure proper stream management in collective-based pipelined GEMM communication


Changes walkthrough 📝

Relevant files
Enhancement

lower.cpp (csrc/host_ir/lower.cpp) — Add initial stream synchronization

  • Added initial stream synchronization at the start of the for-loop
  • Ensured proper stream management by setting and synchronizing streams
  • +3/-0
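
The change is small: a `hir::Synchronize` on the original stream is built via `IrBuilder::create<hir::Synchronize>(original_stream)` (quoted in the reviewer guide below) and placed at the top of the loop body. As a rough, self-contained model of that placement — host IR statements stand in as plain strings, and `addInitialSync` is a hypothetical name, not nvFuser's API:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for the lowering step: host IR statements are plain strings,
// and a "Synchronize Stream 0" statement is inserted right after the
// SetCurrentStream that opens the loop body. The real code instead creates
// an IR node with IrBuilder::create<hir::Synchronize>(original_stream).
std::vector<std::string> addInitialSync(std::vector<std::string> body) {
  for (auto it = body.begin(); it != body.end(); ++it) {
    if (it->rfind("SetCurrentStream", 0) == 0) {
      body.insert(it + 1, "Synchronize Stream 0");
      break;  // only the loop-opening SetCurrentStream gets the sync
    }
  }
  return body;
}
```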

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Possible Issue

    The addition of initial_sync_stream at the start of the for-loop might cause unnecessary synchronization, potentially impacting performance negatively.

    auto* initial_sync_stream =
        IrBuilder::create<hir::Synchronize>(original_stream);
    Code Clarity

    The purpose of initial_sync_stream is not clear from the PR description. Adding comments to explain why this synchronization is necessary would improve code readability.

    auto* initial_sync_stream =
        IrBuilder::create<hir::Synchronize>(original_stream);
    Performance Impact

    The impact of adding initial_sync_stream on performance should be evaluated and documented. Performance metrics before and after this change are needed to justify the change.

    auto* initial_sync_stream =
        IrBuilder::create<hir::Synchronize>(original_stream);

    @samnordmann
    Collaborator Author

    !test

    @samnordmann samnordmann merged commit 55a0ab2 into main Feb 18, 2025
    57 of 61 checks passed
    @samnordmann samnordmann deleted the fix_stream_sync_start_forloop branch February 18, 2025 09:07
    samnordmann added a commit that referenced this pull request Apr 11, 2025
    # Context
    
    ## Previous bug fix in comms/compute overlap
    
    A recent PR #3913 fixed a bug in the
    comms/compute overlap algorithm by adding a stream
    synchronization when entering the host for-loop. With that fix, the
    generated host program reads:
    ```
    $ mpirun -x NVFUSER_DUMP=host_ir -np 2 python -m pytest tests/python/multidevice/test_overlap.py --only-mpi
    ```
    ```
    %HostIrContainer { (T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1}), T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})) -> (T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1})) :
      GetCurrentStream into Stream 0
      T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
      T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
      FOR i83 in iS0{8}:
        SetCurrentStream to Stream ( i83 % numberOfStreams )
        Synchronize Stream 0
        T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1})
           = select( T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), axis = iS0{8}, index = i83 )
        T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1})
           = select( T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), axis = iS12{8}, index = i83 )
        Communication 35 (type=Allgather, team=(0 1), input=T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1}), output=T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}), backend=NCCL)
        Wait Communication 35
        T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
           = select( T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), axis = iStream7{8}, index = i83 )
        T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
           = linear(T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}),
                    T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1})      ,
              T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})      )
        SetCurrentStream to Stream 0
        Synchronize Stream ( i83 % numberOfStreams )
    } // %HostIrContainer
    ```
    Note the `Synchronize Stream 0` line at the beginning of the for-loop body.
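
The reason this synchronization is needed: work enqueued on Stream 0 before the loop (the ALLOCATEs, or preceding compute) is not ordered with respect to the per-iteration streams unless an explicit sync edge is created. A minimal model of that ordering rule — illustrative only; whether `hir::Synchronize` blocks the host thread or enqueues a stream-wait is not shown here, and all names are hypothetical:

```cpp
#include <cassert>
#include <map>

// Toy model of stream ordering: ops execute in issue order within a stream,
// but cross-stream ordering exists only via an explicit Synchronize edge.
struct Model {
  std::map<int, int> issued;  // stream -> number of ops issued so far
  // dst stream -> (src stream -> how many src ops dst has synchronized on)
  std::map<int, std::map<int, int>> synced;

  void launch(int stream) { ++issued[stream]; }

  // "Synchronize src" executed while dst is the current stream: dst now
  // observes everything issued on src so far.
  void synchronize(int dst, int src) { synced[dst][src] = issued[src]; }

  // Is everything issued on src guaranteed complete before the next op
  // issued on dst?
  bool ordered(int dst, int src) const {
    if (dst == src) return true;  // same-stream ops are always in order
    auto it = synced.find(dst);
    if (it == synced.end()) return false;
    auto jt = it->second.find(src);
    return jt != it->second.end() && jt->second >= issued.at(src);
  }
};
```

In this model, an iteration running on stream `i % numberOfStreams` may read buffers written on Stream 0 only after a `synchronize` edge; the sync added by this PR establishes exactly that edge at the top of each iteration.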
    
    
    ## Degradation of performance
    
    However, even though it was not verified before merging, PR
    #3913 degraded the performance of the
    overlapped algorithm: since PR #3913, we no longer observe
    overlapping, even when using UCC/TL/NCCL:
    <img width="1224" alt="Screenshot 2025-04-10 at 15 30 32"
    src="https://github.com/user-attachments/assets/e6e4fbcd-f5de-47b6-a05d-403dbc06ca74"
    />
    
    
    # Performance fix in the present PR
    
    ## What
    The current PR fixes this performance degradation. The fix was found
    incidentally and points to a probably interesting finding.
    
    What needs to be done is to execute all the stream synchronizations in a
    separate host for-loop, placed before the host for-loop responsible for
    the comms/compute. The new generated program reads:
    ```
    %HostIrContainer { (T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1}), T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})) -> (T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1})) :
      GetCurrentStream into Stream 0
      T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
      T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
      FOR i83 in iS0{8}:
        SetCurrentStream to Stream ( i83 % numberOfStreams )
        Synchronize Stream 0
      FOR i83 in iS0{8}:
        SetCurrentStream to Stream ( i83 % numberOfStreams )
        T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1})
           = select( T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), axis = iS0{8}, index = i83 )
        T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1})
           = select( T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), axis = iS12{8}, index = i83 )
        Communication 36 (type=Allgather, team=(0 1), input=T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1}), output=T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}), backend=NCCL)
        Wait Communication 36
        T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
           = select( T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), axis = iStream7{8}, index = i83 )
        T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
           = linear(T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}),
                    T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1})      ,
              T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})      )
        SetCurrentStream to Stream 0
        Synchronize Stream ( i83 % numberOfStreams )
    } // %HostIrContainer
    ```
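
The transformation amounts to fissioning the synchronization prologue out of the loop: a first loop performs only `SetCurrentStream` + `Synchronize` for every iteration, and the main loop then runs the comms/compute without blocking on Stream 0. A rough string-level sketch of that split — illustrative only, not nvFuser's actual lowering API, and `splitSyncLoop` is a hypothetical name:

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

using Loop = std::vector<std::string>;

// Toy sketch of the fix: hoist the per-iteration [SetCurrentStream,
// Synchronize Stream 0] prologue into its own loop, emitted before the main
// comms/compute loop. Assumes the body starts with exactly that prologue,
// as in the program generated by PR #3913.
std::pair<Loop, Loop> splitSyncLoop(const Loop& body) {
  assert(body.size() >= 2);
  Loop sync_loop{body[0], body[1]};  // FOR i: set stream + sync only
  Loop main_loop{body[0]};           // main loop keeps SetCurrentStream
  main_loop.insert(main_loop.end(), body.begin() + 2, body.end());
  return {sync_loop, main_loop};
}
```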
    
    ## Performance fix
    The resulting Nsight profile shows that we recover perfect overlap:
    <img width="1471" alt="Screenshot 2025-04-10 at 15 34 22"
    src="https://github.com/user-attachments/assets/fe1d6242-c518-42c5-93e9-b2d1e78945a4"
    />
    
    ## Further todo

    The current PR modifies the function
    `lowerToCollectiveBasedPipelinedGemmComm`, which will be removed and
    replaced in #4147.
    We need to port the current patch there.