Skip to content

[Stream lowering] collective based pipelines#4387

Merged
samnordmann merged 2 commits into main from
host_irs/stream_lowering/collective_based_algos_pr
May 28, 2025
Merged

[Stream lowering] collective based pipelines#4387
samnordmann merged 2 commits into main from
host_irs/stream_lowering/collective_based_algos_pr

Conversation

@samnordmann
Copy link
Collaborator

@samnordmann samnordmann commented May 7, 2025

Stacked on top of:

Implements stream lowering to collective based pipelines.

Test dumps:

generated with the command line:

mpirun -x NVFUSER_DUMP=host_ir -np 8 $BUILD_DIRECTORY/test_multidevice --gtest_filter=*MultiDeviceStreamParallelTypeTest.*
MultiDeviceStreamParallelTypeTest.Allgather

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx )
    T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i2, zero_init=false, resets_to_zero=false)
    Communication 39 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 39
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.Allreduce

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i3 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx3{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx3{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx3{i0}, index = StreamIdx )
    T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i3, zero_init=false, resets_to_zero=false)
    Communication 39 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 39
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.ReduceScatter

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx4{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx4{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx4{i0}, index = StreamIdx )
    T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i4 ), zero_init=false, resets_to_zero=false)
    Communication 48 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 48
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.AG_matmul

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
  T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx11{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx11{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx11{i0}, index = StreamIdx )
    T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i2 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
    Communication 59 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 59
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx6{i0}, index = StreamIdx )
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.matmul_AR

%HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i2 * i0 ) * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
  T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx7{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx7{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx7{i0}, index = StreamIdx )
    T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx12{i0}, index = StreamIdx )
    T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i7 ), zero_init=false, resets_to_zero=false)
    Communication 64 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 64
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer


MultiDeviceStreamParallelTypeTest.matmul_RS_through_bcast

%HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  PostOnStream (HostUnit0, Inputs:{T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }, Outputs:{T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), })
  T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( ( i2 * i0 ) * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
  T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx13{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx13{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = bS8{1}, index = StreamIdx )
    T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx13{i0}, index = StreamIdx )
    T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx19{i0}, index = StreamIdx )
    T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i3 * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
    Communication 80 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 80
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )

HostUnit0: Inputs={T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), } -> Outputs={T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }Inputs:
  T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
Outputs:
  T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})

%kernel {
T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
   = broadcast( T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), flags = {true, false, true, false, false} )
} // %kernel
} // %HostIrContainer

@github-actions
Copy link

github-actions bot commented May 7, 2025

Review updated until commit 8a94bdb

Description

  • Implement stream lowering for collective-based pipelines

  • Remove specific handling for MatmulOp and LinearOp in HostIrLower

  • Add new tests for Allgather, Allreduce, ReduceScatter, and Matmul operations

  • Update CMakeLists.txt to include new test file


Changes walkthrough 📝

Relevant files
Enhancement
8 files
executor.cpp
Adjust index calculation in HirAliasSelect                             
+14/-3   
lower.cpp
Remove specific handling for MatmulOp and LinearOp             
+0/-24   
lower_to_communication.cpp
Remove lowerToCollectiveBasedPipelinedGemmComm function   
+0/-172 
stream_parallel_type.cpp
Set device mesh for output in TensorSlicingCache                 
+1/-0     
fusion_definition.cpp
Remove temporary disabling of StreamParallelType pass       
+0/-4     
test_multidevice_host_ir.cpp
Remove tests for OverlapDistributedMatmulTest                       
+0/-131 
test_multidevice_pipeline.cpp
Remove optimization guard initialization                                 
+1/-5     
test_multidevice_stream_parallel_type.cpp
Add new tests for stream parallel type operations               
+337/-0 
Configuration changes
1 files
CMakeLists.txt
Include new test file in build                                                     
+1/-0     

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Issue

The new code in handle(HirAliasSelect* hir_alias_select) does not handle the case where hir_alias_select->in()->getLogicalDomain().at(hir_alias_select->axis()) might be null or invalid, which could lead to a runtime error.

    info.shape_info.logical_strides,
    info.type,
    c10::nullopt,
    device,
    c10::nullopt);
expr_evaluator_.bind(tv, tensor);
Code Removal

The removal of support for MatmulOp and LinearOp in canLower might affect existing functionalities that rely on these operations. Ensure that this change is intentional and does not break any existing use cases.

    return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
  } else if (auto* ldst = dynamic_cast<LoadStoreOp*>(expr)) {
    if (!ignore_inner_resharding && isInnerResharding(expr)) {
      return false;
    }
    return ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
  }
  return false;
}
Test Removal

The removal of tests for OverlapDistributedMatmulTest might lead to a loss of coverage for important functionalities. Ensure that these tests are no longer needed or that equivalent tests are added elsewhere.

  // validate the obtained results
  at::Tensor ref_output = send_buffer_aten + (recv_peer - my_device_index);
  EXPECT_TRUE(torch::allclose(ref_output, outputs.back().as<at::Tensor>()));
}

TEST_F(MultiDeviceTest, ShareIpcMemHandles) {
  static constexpr int kTensorSize = 4;
  static constexpr int kNumRepetitions = 10;

  if (communicator_->size() < 2 || torch::cuda::device_count() < 2) {
    GTEST_SKIP() << "This test needs at least 2 GPUs and 2 ranks.";
  }

  const DeviceIdxType my_rank = communicator_->deviceId();
  const int64_t size = communicator_->size();
  const DeviceIdxType send_peer = (my_rank + 1) % size;
  const DeviceIdxType recv_peer = (size + my_rank - 1) % size;

  auto container = std::make_unique<hir::HostIrContainer>();
  FusionGuard fg(container.get());

  auto* send_tv = makeContigTensor(1, DataType::Int32);
  auto* recv_tv = makeContigTensor(1, DataType::Int32);

  auto send = IrBuilder::create<P2PCommunication>(
      P2PCommunicationType::SEND,
      send_tv,
      IrBuilder::create<Val>(send_peer),
      CommunicatorBackend::kNccl);
  auto recv = IrBuilder::create<P2PCommunication>(
      P2PCommunicationType::RECV,
      recv_tv,
      IrBuilder::create<Val>(recv_peer),
      CommunicatorBackend::kNccl);
  std::vector<P2PCommunication*> grouped_communications = {send, recv};

  ExpressionEvaluator expr_evaluator;
  IpcHandleCache ipc_handle_cache(expr_evaluator);

  auto options =
      at::TensorOptions().dtype(at::kInt).device(communicator_->device());
  auto generate_tensor = [options](int repetition, int rank) {
    return at::arange(kTensorSize, options) + (repetition + 1) * 10 +
        100 * rank;
  };
  at::Tensor recv_tensor = at::empty({kTensorSize}, options);
  at::Tensor send_tensor = at::empty({kTensorSize}, options);

  expr_evaluator.bind(send_tv, send_tensor);
  expr_evaluator.bind(recv_tv, recv_tensor);

  for (auto repetition : c10::irange(kNumRepetitions)) {
    // all ranks set `send_tensor`
    send_tensor.copy_(generate_tensor(repetition, my_rank));

    // Exchange IpcHandle on the first iteration
    ipc_handle_cache.exchangeHandles(grouped_communications);

    // RDMA put-zcopy
    const P2pIpcHandle& send_ipc_handles = ipc_handle_cache.get(send);
    NVFUSER_CUDA_RT_SAFE_CALL(cudaMemcpy(
        send_ipc_handles.peer().ptr(),
        send_ipc_handles.local().ptr(),
        send_tensor.numel() * send_tensor.element_size(),
        cudaMemcpyDeviceToDevice));

    torch::cuda::synchronize();
    communicator_->barrier();
    at::Tensor ref_recv_tensor = generate_tensor(repetition, recv_peer);
    EXPECT_TRUE(torch::allclose(recv_tensor, ref_recv_tensor))
        << "Rank " << my_rank << " failed at repetition " << repetition
        << " with recv tensor " << recv_tensor << " and ref_recv_tensor "
        << ref_recv_tensor;
  }
}

} // namespace hir

} // namespace nvfuser

@samnordmann
Copy link
Collaborator Author

!test

Base automatically changed from separate_initial_stream_sync to main May 13, 2025 16:14
samnordmann added a commit that referenced this pull request May 14, 2025
# What 

This PR is only about cleaning and refactoring, no change in the
behavior.

- We create a new base class for HIR lowering optimization pass
`hir_pass::OptimizationPass` in `host_ir/pass/optimization_pass.h`.
- We create an option `hir_lowering_logging` to control debug logging
through NVFUSER_DUMP
- We create a guard to enable/disable a pass
- We make the existing pass `StreamLowering` and `InsertDeallocations`
inherit from this pass
- We factor out converting a resharding op to a communication into a
separate pass

# Why

preparation for #4387
@samnordmann samnordmann force-pushed the host_irs/stream_lowering/collective_based_algos_pr branch from 2b38755 to 3c250c6 Compare May 15, 2025 12:17
@samnordmann
Copy link
Collaborator Author

!test

@samnordmann samnordmann requested review from nsarka and wujingyue May 15, 2025 15:40
@samnordmann
Copy link
Collaborator Author

!test

@samnordmann samnordmann requested a review from wujingyue May 19, 2025 12:23
Copy link
Collaborator

@wujingyue wujingyue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM otherwise

<< "Output: " << output << "\nExpected: " << expected_output;
}

TEST_F(MultiDeviceStreamParallelTypeTest, AG_matmul) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, there are two other AG_matmul patterns to consider. They come from the backprop of Matmul_RS and have different layouts than this one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting, I'm not sure I was aware of that. Can you elaborate?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's a detailed explanation of the backpropagation of a matrix multiplication followed by a ReduceScatter operation, especially in the context of distributed training (e.g., tensor parallelism).


Setup

Let:

  • $A \in \mathbb{R}^{m \times k}$

  • $B \in \mathbb{R}^{k \times n}$

  • $C = A B \in \mathbb{R}^{m \times n}$

Now, assume ReduceScatter is applied to split and reduce C across processes along the output feature dimension (columns), e.g., split n into n/p per rank (with p processes).

So, define:

# forward
C = matmul(A, B)  # C shape: [m, n]
C_local = reduce_scatter(C)  # each rank gets C_local of shape [m, n/p]

Goal

Compute the gradients w.r.t. A and B, i.e., $\frac{\partial \text{loss}}{\partial A}$ and $\frac{\partial \text{loss}}{\partial B}$, where the backward starts from C_local.grad.


Step-by-Step Backward

1. ReduceScatter Backward

In reverse, ReduceScatter becomes AllGather:

grad_C = all_gather(grad_C_local)  # shape: [m, n]

So each rank reconstructs the full C gradient.

2. Matrix Multiplication Backward

Given:

$C = A B \Rightarrow \frac{\partial \text{loss}}{\partial A} = \frac{\partial \text{loss}}{\partial C} B^T, \quad \frac{\partial \text{loss}}{\partial B} = A^T \frac{\partial \text{loss}}{\partial C}$

Apply this to grad_C (shape [m, n]):

grad_A = grad_C @ B.T       # shape: [m, k]
grad_B = A.T @ grad_C       # shape: [k, n]

Notes on Distributed Context

  • If B is sharded column-wise (as in tensor parallelism), then:

    • Each rank holds a shard B_i of shape [k, n/p]

    • C_local = A @ B_i, and reduce-scatter sums over the p ranks

In this case:

  • Forward: sum over p partial products

  • Backward:

    • Each rank computes its local grad_C_local

    • All-gather to reconstruct full grad_C

    • Then compute grad_A and local grad_B_i

Optimization: In some frameworks (e.g., Megatron-LM, DeepSpeed), grad_B is computed locally, and then all-reduce is used to combine the result.


Summary

Operation Forward Backward
matmul $C = A B$ $\partial A = \partial C\, B^T$, $\partial B = A^T\, \partial C$
reduce_scatter Split+sum across ranks (on columns) all_gather gradients to get full tensor

Let me know your setup (e.g., column-wise sharded B, or row-wise A) and I can tailor the equations accordingly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll try to add more examples in

def test_column_parallel_linear(setup_default_process_group, multidevice_test):
.

Meanwhile, the above hopefully serves as a high-level description. In summary, they'll still be allgather+matmul but different dimensions get sharded and certain dimensions are row-major vs column-major.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added

def test_row_parallel_linear(setup_default_process_group, multidevice_test):
, the forward and backward of a linear+allreduce. It's not exactly linear+ReduceScatter but hopefully close enough for you to get a clue.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll take a look, thank you very much

@samnordmann samnordmann requested a review from wujingyue May 21, 2025 11:17
@samnordmann
Copy link
Collaborator Author

!test

@samnordmann
Copy link
Collaborator Author

I removed ReorderShardedAxisPass from MultiDeviceExecutor, so the PR doesn't need to touch utils.cpp anymore. It should be ready to merge (AFAIU the CI failures are unrelated)
cc @wujingyue

CMakeLists.txt Outdated
${CMAKE_SOURCE_DIR}/third_party/benchmark/include
${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include
${CMAKE_SOURCE_DIR}/third_party/googletest/googletest/include
${CMAKE_SOURCE_DIR}/third_party/googletest/googlemock/include
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this change is here. Seems unrelated to this PR.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is kind of related. This was required because I added the #include <tests/cpp/validator.h> header in tests/cpp/multidevice.h, otherwise leading to the compilation error

  [1/5] Building CXX object CMakeFiles/nvfuser_multidevice_bench.dir/benchmarks/cpp/transformer.cpp.o
  FAILED: CMakeFiles/nvfuser_multidevice_bench.dir/benchmarks/cpp/transformer.cpp.o
[...]
  In file included from /opt/pytorch/Fuser/tests/cpp/multidevice.h:15,
                   from /opt/pytorch/Fuser/benchmarks/cpp/transformer.cpp:15:                                                                                                                                           /opt/pytorch/Fuser/tests/cpp/validator.h:10:10: fatal error: gmock/gmock-matchers.h: No such file or directory
     10 | #include <gmock/gmock-matchers.h>
        |          ^~~~~~~~~~~~~~~~~~~~~~~~
  compilation terminated.

In the meantime, I moved the header inclusion to tests/cpp/test_multidevice_stream_parallel_type.cpp only. This allows to remove the CMakeList patch -- but if we want to use those gmock matchers more extensively we should reconsider.

CMakeLists.txt Outdated
target_include_directories(nvfuser_multidevice_bench PUBLIC ${NVFUSER_ROOT})
target_link_libraries(nvfuser_multidevice_bench PRIVATE
GTest::gtest
GTest::gmock
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CMakeLists.txt Outdated

if(NOT MSVC)
target_compile_options(nvfuser_bench PRIVATE
target_compile_options(nvfuser_multidevice_bench PRIVATE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is unrelated, but a bug IIUC. I revert it here to merge the current PR asap.

auto index =
expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();
auto indexed_id =
hir_alias_select->in()->getLogicalDomain().at(hir_alias_select->axis());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A logical domain may have reduction dimensions, which don't exist in the corresponding at::Tensor. So it's likely problematic to use the same int axis on both in()->getLogicalDomain() and input, unless in() is guaranteed to be reduction free.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right, fixed. Thanks

@samnordmann samnordmann force-pushed the host_irs/stream_lowering/collective_based_algos_pr branch from ae66b1a to c96ef5d Compare May 27, 2025 12:32
@samnordmann samnordmann force-pushed the host_irs/stream_lowering/collective_based_algos_pr branch from c96ef5d to c7b7606 Compare May 27, 2025 12:33
@samnordmann
Copy link
Collaborator Author

!test

@samnordmann
Copy link
Collaborator Author

!test

@samnordmann samnordmann merged commit 16ab176 into main May 28, 2025
42 of 47 checks passed
@samnordmann samnordmann deleted the host_irs/stream_lowering/collective_based_algos_pr branch May 28, 2025 13:47
samnordmann added a commit that referenced this pull request Jun 3, 2025
on top of
- #4387

# What
Add Stream lowering to Allgather p2p linear, with NCCL backend

For example: `MultiDeviceStreamParallelTypeTest.AllgatherP2p` from
`tests/cpp/test_multidevice_stream_parallel_type.cpp`:
```
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  fusion->addInput(tv0);
  fusion->addOutput(tv1);

  const DeviceMesh mesh =
      DeviceMesh::createForNumDevices(communicator_->size());
  tv0->setDeviceMesh(mesh);
  tv1->setDeviceMesh(mesh);
  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->axis(0)->parallelize(ParallelType::Stream);
```
is lowered to:
```
%HostIrContainer { (T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx )
    IF Manual ( StreamIdx == deviceIdx.x ):
      T2_l_float[iS4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = HirAliasSelect( T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{i0}, index = 0 )
      T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = Set( T2_l_float[iS4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming )
    ELSE:
      StartCoalescing
      P2PCommunication 30 (type=recv, buffer=T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      P2PCommunication 31 (type=send, buffer=T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      EndCoalescing 32
      Wait Communication 32
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer
```


A test with an overlapped matmul is also proposed in `AG_matmul_P2p`,
which generates the following host program:
```
%HostIrContainer { (T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i2 ) * i3 ), zero_init=false, resets_to_zero=false)
  T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i2 ) * i5 ), zero_init=false, resets_to_zero=false
)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx9{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx9{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx9{i0}, index = StreamIdx )
    IF Manual ( StreamIdx == deviceIdx.x ):
      T4_l_float[iS12{i2}, iS13{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = HirAliasSelect( T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{i0}, index = 0 )
      T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = Set( T4_l_float[iS12{i2}, iS13{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming )
    ELSE:
      StartCoalescing
      P2PCommunication 41 (type=recv, buffer=T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      P2PCommunication 42 (type=send, buffer=T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      EndCoalescing 43
      Wait Communication 43
    T6_l_float[iS16{i2}, iS17{i5}, rS18{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx5{i0}, index = StreamIdx )
    T6_l_float[iS16{i2}, iS17{i5}, rS18{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer
```
nsarka pushed a commit to nsarka/Fuser that referenced this pull request Jul 28, 2025
Stacked on top of:
- NVIDIA#4401
- NVIDIA#4402 

Implements stream lowering to collective based pipelines.


# Test dumps:
generated with the command line:
```
mpirun -x NVFUSER_DUMP=host_ir -np 8 $BUILD_DIRECTORY/test_multidevice --gtest_filter=*MultiDeviceStreamParallelTypeTest.*
```

```
MultiDeviceStreamParallelTypeTest.Allgather

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx )
    T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i2, zero_init=false, resets_to_zero=false)
    Communication 39 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 39
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.Allreduce

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i3 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx3{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx3{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx3{i0}, index = StreamIdx )
    T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i3, zero_init=false, resets_to_zero=false)
    Communication 39 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 39
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.ReduceScatter

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx4{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx4{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx4{i0}, index = StreamIdx )
    T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i4 ), zero_init=false, resets_to_zero=false)
    Communication 48 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 48
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.AG_matmul

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
  T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx11{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx11{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx11{i0}, index = StreamIdx )
    T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i2 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
    Communication 59 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 59
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx6{i0}, index = StreamIdx )
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.matmul_AR

%HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i2 * i0 ) * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
  T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx7{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx7{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx7{i0}, index = StreamIdx )
    T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx12{i0}, index = StreamIdx )
    T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i7 ), zero_init=false, resets_to_zero=false)
    Communication 64 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 64
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer


MultiDeviceStreamParallelTypeTest.matmul_RS_through_bcast

%HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  PostOnStream (HostUnit0, Inputs:{T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }, Outputs:{T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), })
  T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( ( i2 * i0 ) * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
  T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx13{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx13{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = bS8{1}, index = StreamIdx )
    T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx13{i0}, index = StreamIdx )
    T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx19{i0}, index = StreamIdx )
    T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i3 * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
    Communication 80 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 80
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )

HostUnit0: Inputs={T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), } -> Outputs={T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }Inputs:
  T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
Outputs:
  T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})

%kernel {
T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
   = broadcast( T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), flags = {true, false, true, false, false} )
} // %kernel
} // %HostIrContainer
```
nsarka pushed a commit to nsarka/Fuser that referenced this pull request Jul 28, 2025
on top of
- NVIDIA#4387

# What
Add Stream lowering to Allgather p2p linear, with NCCL backend

For example: `MultiDeviceStreamParallelTypeTest.AllgatherP2p` from
`tests/cpp/test_multidevice_stream_parallel_type.cpp`:
```
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  fusion->addInput(tv0);
  fusion->addOutput(tv1);

  const DeviceMesh mesh =
      DeviceMesh::createForNumDevices(communicator_->size());
  tv0->setDeviceMesh(mesh);
  tv1->setDeviceMesh(mesh);
  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->axis(0)->parallelize(ParallelType::Stream);
```
is lowered to:
```
%HostIrContainer { (T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx )
    IF Manual ( StreamIdx == deviceIdx.x ):
      T2_l_float[iS4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = HirAliasSelect( T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{i0}, index = 0 )
      T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = Set( T2_l_float[iS4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming )
    ELSE:
      StartCoalescing
      P2PCommunication 30 (type=recv, buffer=T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      P2PCommunication 31 (type=send, buffer=T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      EndCoalescing 32
      Wait Communication 32
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer
```


A test with an overlapped matmul is also proposed in `AG_matmul_P2p`,
which generates the following host program:
```
%HostIrContainer { (T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i2 ) * i3 ), zero_init=false, resets_to_zero=false)
  T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i2 ) * i5 ), zero_init=false, resets_to_zero=false
)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx9{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx9{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx9{i0}, index = StreamIdx )
    IF Manual ( StreamIdx == deviceIdx.x ):
      T4_l_float[iS12{i2}, iS13{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = HirAliasSelect( T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{i0}, index = 0 )
      T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
         = Set( T4_l_float[iS12{i2}, iS13{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming )
    ELSE:
      StartCoalescing
      P2PCommunication 41 (type=recv, buffer=T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      P2PCommunication 42 (type=send, buffer=T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL)
      EndCoalescing 43
      Wait Communication 43
    T6_l_float[iS16{i2}, iS17{i5}, rS18{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx5{i0}, index = StreamIdx )
    T6_l_float[iS16{i2}, iS17{i5}, rS18{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants