Skip to content

[Stream Lowering] Move initial stream sync to separate for-loop#4402

Merged
samnordmann merged 1 commit intomainfrom
separate_initial_stream_sync
May 13, 2025
Merged

[Stream Lowering] Move initial stream sync to separate for-loop#4402
samnordmann merged 1 commit intomainfrom
separate_initial_stream_sync

Conversation

@samnordmann
Copy link
Collaborator

port #4217 to stream lowering

@samnordmann
Copy link
Collaborator Author

!test

@github-actions
Copy link

github-actions bot commented May 9, 2025

Description

  • Moved initial stream sync to separate for-loop

  • Updated tests to reflect new structure

  • Ensured stream management is correctly handled


Changes walkthrough 📝

Relevant files
Enhancement
stream_parallel_type.cpp
Refactor stream sync to separate for-loop                               

csrc/host_ir/pass/stream_parallel_type.cpp

  • Created a new for-loop for getting the current stream
  • Moved initial stream sync operations to new for-loop
  • Updated return value to include new top-level expressions
  • +29/-7   
    Tests
    test_host_ir_stream_lowering.cpp
    Update tests for new stream sync structure                             

    tests/cpp/test_host_ir_stream_lowering.cpp

  • Updated test expectations to account for new for-loop structure
  • Increased expected number of top-level expressions
  • +83/-43 

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Stream Management

    The new for-loop for getting the current stream (for_loop_initial_sync) is created and added to the top-level expressions. Ensure that this new loop does not introduce unnecessary overhead or affect the performance negatively.

    auto* for_loop_initial_sync = IrBuilder::create<ForLoop>(
        for_loop->iterDomain(),
        for_loop->index(),
        for_loop->start(),
        for_loop->stop(),
        for_loop->step(),
        /*vectorize=*/false,
        /*vectorize_shift=*/nullptr,
        /*unroll_required=*/false,
        CircularBufferLoopStage::NotApplicable,
        /*circular_buffer_loop_stage_depth=*/0);
    new_top_level_exprs.push_back(for_loop_initial_sync);
    
    Test Coverage

    Verify that the new test cases cover all possible scenarios and edge cases, especially those involving the new GetCurrentStream and the additional for-loop.

    tv0->setMemoryType(MemoryType::Global);
    tv1->setMemoryType(MemoryType::Global);
    tv1->axis(0)->parallelize(ParallelType::Stream);
    
    preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());
    
    EXPECT_EQ(hic->topLevelExprs().size(), 4);
    EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
    EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<hir::GetCurrentStream>());
    EXPECT_TRUE(hic->topLevelExprs().at(2)->isA<ForLoop>());
    EXPECT_TRUE(hic->topLevelExprs().at(3)->isA<ForLoop>());
    
    Test Assertions

    Ensure that the test assertions accurately reflect the expected changes in the top-level expressions after the stream management modifications.

    EXPECT_EQ(hic->topLevelExprs().size(), 4);
    EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
    EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<hir::GetCurrentStream>());
    EXPECT_TRUE(hic->topLevelExprs().at(2)->isA<ForLoop>());
    EXPECT_TRUE(hic->topLevelExprs().at(3)->isA<ForLoop>());
    


    hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
    EXPECT_EQ(container->topLevelExprs().size(), 2);
    EXPECT_EQ(container->topLevelExprs().size(), 4);
    Copy link
    Collaborator

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    FYI, #4427 will allow you to simplify the checks here.

    Copy link
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Thanks! That's useful. I will leave it like this for now but I'll keep that in mind and might change it later.

    @samnordmann samnordmann merged commit 21a7095 into main May 13, 2025
    60 checks passed
    @samnordmann samnordmann deleted the separate_initial_stream_sync branch May 13, 2025 16:14
    samnordmann added a commit that referenced this pull request May 28, 2025
    Stacked on top of:
    - #4401
    - #4402 
    
    Implements stream lowering to collective based pipelines.
    
    
    # Test dumps:
    generated with the command line:
    ```
    mpirun -x NVFUSER_DUMP=host_ir -np 8 $BUILD_DIRECTORY/test_multidevice --gtest_filter=*MultiDeviceStreamParallelTypeTest.*
    ```
    
    ```
    MultiDeviceStreamParallelTypeTest.Allgather
    
    %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx2{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx2{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx )
        T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i2, zero_init=false, resets_to_zero=false)
        Communication 39 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 39
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    
    MultiDeviceStreamParallelTypeTest.Allreduce
    
    %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i3 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx3{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx3{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx3{i0}, index = StreamIdx )
        T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i3, zero_init=false, resets_to_zero=false)
        Communication 39 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 39
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    
    MultiDeviceStreamParallelTypeTest.ReduceScatter
    
    %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx4{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx4{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx4{i0}, index = StreamIdx )
        T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i4 ), zero_init=false, resets_to_zero=false)
        Communication 48 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 48
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    
    MultiDeviceStreamParallelTypeTest.AG_matmul
    
    %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
      T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx11{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx11{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx11{i0}, index = StreamIdx )
        T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i2 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
        Communication 59 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 59
        T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx6{i0}, index = StreamIdx )
        T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = matmul(T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                    T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}))
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    
    MultiDeviceStreamParallelTypeTest.matmul_AR
    
    %HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i2 * i0 ) * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
      T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx7{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx7{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx7{i0}, index = StreamIdx )
        T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = matmul(T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                    T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}))
        T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx12{i0}, index = StreamIdx )
        T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i7 ), zero_init=false, resets_to_zero=false)
        Communication 64 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 64
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    MultiDeviceStreamParallelTypeTest.matmul_RS_through_bcast
    
    %HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      PostOnStream (HostUnit0, Inputs:{T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }, Outputs:{T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), })
      T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( ( i2 * i0 ) * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
      T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx13{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx13{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = bS8{1}, index = StreamIdx )
        T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx13{i0}, index = StreamIdx )
        T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = matmul(T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                    T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}))
        T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx19{i0}, index = StreamIdx )
        T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i3 * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
        Communication 80 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 80
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    
    HostUnit0: Inputs={T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), } -> Outputs={T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }Inputs:
      T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
    Outputs:
      T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
    
    %kernel {
    T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = broadcast( T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), flags = {true, false, true, false, false} )
    } // %kernel
    } // %HostIrContainer
    ```
    nsarka pushed a commit to nsarka/Fuser that referenced this pull request Jul 28, 2025
    Stacked on top of:
    - NVIDIA#4401
    - NVIDIA#4402 
    
    Implements stream lowering to collective based pipelines.
    
    
    # Test dumps:
    generated with the command line:
    ```
    mpirun -x NVFUSER_DUMP=host_ir -np 8 $BUILD_DIRECTORY/test_multidevice --gtest_filter=*MultiDeviceStreamParallelTypeTest.*
    ```
    
    ```
    MultiDeviceStreamParallelTypeTest.Allgather
    
    %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx2{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx2{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx )
        T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i2, zero_init=false, resets_to_zero=false)
        Communication 39 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 39
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    
    MultiDeviceStreamParallelTypeTest.Allreduce
    
    %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i3 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx3{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx3{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx3{i0}, index = StreamIdx )
        T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i3, zero_init=false, resets_to_zero=false)
        Communication 39 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 39
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    
    MultiDeviceStreamParallelTypeTest.ReduceScatter
    
    %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx4{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx4{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx4{i0}, index = StreamIdx )
        T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i4 ), zero_init=false, resets_to_zero=false)
        Communication 48 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 48
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    
    MultiDeviceStreamParallelTypeTest.AG_matmul
    
    %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
      T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx11{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx11{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx11{i0}, index = StreamIdx )
        T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i2 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
        Communication 59 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 59
        T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx6{i0}, index = StreamIdx )
        T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = matmul(T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                    T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}))
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    
    MultiDeviceStreamParallelTypeTest.matmul_AR
    
    %HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i2 * i0 ) * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
      T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx7{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx7{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx7{i0}, index = StreamIdx )
        T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = matmul(T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                    T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}))
        T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx12{i0}, index = StreamIdx )
        T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i7 ), zero_init=false, resets_to_zero=false)
        Communication 64 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 64
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    } // %HostIrContainer
    
    
    MultiDeviceStreamParallelTypeTest.matmul_RS_through_bcast
    
    %HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
      PostOnStream (HostUnit0, Inputs:{T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }, Outputs:{T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), })
      T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( ( i2 * i0 ) * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
      T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
      GetCurrentStream into Stream 0
      FOR StreamIdx in iStreamIdx13{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        Synchronize Stream 0
      FOR StreamIdx in iStreamIdx13{i0}:
        SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
        T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
        T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = bS8{1}, index = StreamIdx )
        T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx13{i0}, index = StreamIdx )
        T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = matmul(T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                    T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}))
        T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
           = HirAliasSelect( T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx19{i0}, index = StreamIdx )
        T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i3 * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
        Communication 80 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
        Wait Communication 80
        SetCurrentStream to Stream 0
        Synchronize Stream ( StreamIdx % numberOfStreams )
    
    HostUnit0: Inputs={T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), } -> Outputs={T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }Inputs:
      T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
    Outputs:
      T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
    
    %kernel {
    T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = broadcast( T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), flags = {true, false, true, false, false} )
    } // %kernel
    } // %HostIrContainer
    ```
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants