
[Overlap] Move initial stream sync into separate for loop#4217

Merged
samnordmann merged 3 commits into main from overlap_manual_separate_initial_for_loop
Apr 11, 2025

Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Apr 9, 2025

Context

Previous bug fix in comms/compute overlap

A recent PR #3913 fixed a bug in the comms/compute overlap algorithm by adding a stream synchronization when entering the host for-loop. The generated host program currently reads:

$ mpirun -x NVFUSER_DUMP=host_ir -np 2 python -m pytest tests/python/multidevice/test_overlap.py --only-mpi
%HostIrContainer { (T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1}), T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})) -> (T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1})) :
  GetCurrentStream into Stream 0
  T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  FOR i83 in iS0{8}:
    SetCurrentStream to Stream ( i83 % numberOfStreams )
    Synchronize Stream 0
    T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1})
       = select( T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), axis = iS0{8}, index = i83 )
    T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1})
       = select( T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), axis = iS12{8}, index = i83 )
    Communication 35 (type=Allgather, team=(0 1), input=T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1}), output=T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}), backend=NCCL)
    Wait Communication 35
    T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
       = select( T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), axis = iStream7{8}, index = i83 )
    T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
       = linear(T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}),
                T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1})      ,
          T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})      )
    SetCurrentStream to Stream 0
    Synchronize Stream ( i83 % numberOfStreams )
} // %HostIrContainer

Note the line Synchronize Stream 0 at the beginning of the for-loop body.

Degradation of performance

However, although this was not verified before merging, PR #3913 degraded the performance of the overlap algorithm. Since PR #3913, we no longer observe any overlapping, even when using UCC/TL/NCCL:
[Screenshot 2025-04-10 at 15:30:32 — nsight profile showing no overlap]

Performance fix in the present PR

What

The current PR fixes this performance degradation. The idea was found incidentally and reveals a probably interesting finding.

What needs to be done is to execute all the stream synchronizations in a separate host for-loop, placed before the host for-loop responsible for the comms/compute. The new generated program reads:

%HostIrContainer { (T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1}), T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})) -> (T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1})) :
  GetCurrentStream into Stream 0
  T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  FOR i83 in iS0{8}:
    SetCurrentStream to Stream ( i83 % numberOfStreams )
    Synchronize Stream 0
  FOR i83 in iS0{8}:
    SetCurrentStream to Stream ( i83 % numberOfStreams )
    T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1})
       = select( T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), axis = iS0{8}, index = i83 )
    T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1})
       = select( T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), axis = iS12{8}, index = i83 )
    Communication 36 (type=Allgather, team=(0 1), input=T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1}), output=T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}), backend=NCCL)
    Wait Communication 36
    T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
       = select( T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), axis = iStream7{8}, index = i83 )
    T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
       = linear(T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}),
                T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1})      ,
          T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})      )
    SetCurrentStream to Stream 0
    Synchronize Stream ( i83 % numberOfStreams )
} // %HostIrContainer

Performance fix

The resulting nsight profile shows that we achieve perfect overlap:
[Screenshot 2025-04-10 at 15:34:22 — nsight profile showing perfect overlap]

Further todo:

The current PR modifies the function lowerToCollectiveBasedPipelinedGemmComm, which will be removed and replaced in #4147. We will need to port the current patch there.

@samnordmann
Collaborator Author

!test

@github-actions

github-actions bot commented Apr 9, 2025

Review updated until commit d63e992

Description

  • Separate initial stream sync into a separate for-loop

  • Improve performance with comms/compute overlap


Changes walkthrough 📝

Relevant files
Enhancement
lower.cpp
Separate initial stream sync for-loop                                       

csrc/host_ir/lower.cpp

  • Created a separate for-loop for initial stream synchronization
  • Removed initial sync from the main for-loop body
  • Added comments explaining the performance improvement
  • +25/-3   

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review

    Performance Impact

    The PR introduces a separate for-loop for the initial stream synchronization. It is crucial to verify that this change does not introduce performance regressions and that the performance gains from comms/compute overlap are realized.

    auto* for_loop_initial_sync = IrBuilder::create<ForLoop>(
        stream_axis,
        /*index=*/j,
        start,
        stop,
        step,
        /*vectorize=*/false,
        /*vectorize_shift=*/nullptr,
        /*unroll_required=*/false,
        CircularBufferLoopStage::NotApplicable,
        /*circular_buffer_loop_stage_depth=*/0);
    
    auto* number_of_streams =
        IrBuilder::create<NamedScalar>("numberOfStreams", DataType::Int);
    auto* stream_index = mod(j, number_of_streams);
    auto* stream = IrBuilder::create<hir::Stream>(stream_index);
    auto* set_stream = IrBuilder::create<hir::SetCurrentStream>(stream);
    auto* initial_sync_stream =
        IrBuilder::create<hir::Synchronize>(original_stream);
    
    // the initial sync of the streams with the user's stream is done in a
    // separate for-loop for performance reasons with comms/compute overlap
    std::vector<Expr*> loop_body_initial_sync = {set_stream, initial_sync_stream};
    for (Expr* expr : loop_body_initial_sync) {
      for_loop_initial_sync->body().push_back(expr);
    }
    
    auto* for_loop = IrBuilder::create<ForLoop>(
        stream_axis,
        /*index=*/j,
        start,
        stop,
        step,
        /*vectorize=*/false,
        /*vectorize_shift=*/nullptr,
        /*unroll_required=*/false,
        CircularBufferLoopStage::NotApplicable,
        /*circular_buffer_loop_stage_depth=*/0);
    
    Code Duplication

    The code for creating the ForLoop object is duplicated. This could lead to maintenance issues. Consider refactoring to avoid duplication.

    (Same code excerpt as quoted under "Performance Impact" above.)
    
    Test Coverage

    Ensure that the PR includes tests that specifically validate the behavior of the new separate for-loop for initial stream synchronization. This will help confirm that the change works as intended and does not introduce bugs.

    (Same code excerpt as quoted under "Performance Impact" above.)

    @samnordmann samnordmann requested review from nsarka and wujingyue April 10, 2025 13:40
    @samnordmann
    Collaborator Author

    !test

    @samnordmann samnordmann merged commit f911dca into main Apr 11, 2025
    42 of 43 checks passed
    @samnordmann samnordmann deleted the overlap_manual_separate_initial_for_loop branch April 11, 2025 14:02
    @samnordmann
    Collaborator Author

    To better explain the paradox shown in this PR:

    Here, we observe that, depending on the order in which we emit device instructions, the behavior changes even though the programs are semantically the same. Indeed, let Stream 0 be the user's stream, and imagine the user requests a program like

    kernel A, kernel B = “overlapped comm+compute”, kernel C
    

    The generated host program will consist of the instructions:

    • On stream 0:
      • Kernel A
      • Wait on all Streams i
      • Kernel C
    • for all Stream i, where i is a running index ranging from 1 to S:
      1i. Stream i waits on stream 0 to complete kernel A
      2i. Stream i executes Kernel B over tensor's slice i

    Regarding the instructions for Stream i, we can consider the two following orders in which to emit them:

    1. "factorized": loop over i, emitting (1i + 2i) together (one single host for-loop)
    2. "developed": loop over i emitting 1i, then loop over i again emitting 2i (two separate host for-loops)

    While these two variants should be equivalent, we see in this PR that only the "developed" order exhibits the desired overlap.

    @kevinstephano @naoyam @wujingyue
    Let me know if you think this should be sent to a CUDA runtime expert, and whether you have a contact in mind.
