[HostIr] Refactor hir optimization pass#4401
Review updated until commit f76fc10
Force-pushed 7511c7c to fdd56a1
const HostIrLowerParams& params = HostIrLowerParams())
    : params_(params) {}

static std::vector<Expr*> ConvertSingleOpToCommunication(
For the main stack, implemented in fusion_kernel_runtime.cpp, I consider this to be part of the fusion-IR-to-host-IR lowering (aka host IR lowering) instead of a host-IR-to-host-IR transformation pass (aka host IR pass).
Therefore, I prefer leaving this function in lower.h so it can be used in both host IR lowering and host IR passes.
> For the main stack, implemented in fusion_kernel_runtime.cpp, I consider this to be part of the fusion-IR-to-host-IR lowering (aka host IR lowering) instead of a host-IR-to-host-IR transformation pass (aka host IR pass)
Let me motivate why I moved it:
- I personally find it cleaner and more logical to keep this function separate, close to the pass that uses it.
- The main stack should eventually use the pass, not the function. The function should eventually not even be exposed.
- It is going to be a host-IR-to-host-IR pass once stream lowering is integrated. Seeing it as part of lowering rather than as a host IR pass only reflects the current implementation; it is not a necessity.

> Therefore, I prefer leaving this function in lower.h so it can be used in both host IR lowering and host IR passes.

Would you be OK with at least putting it in a separate file, host_ir/lower_to_communication.h|cpp? (as it was named in the past)
Otherwise it would mean that HostIrLower::lower from host_ir/lower.h calls the pass host_ir/pass/convert_op_to_communication.h, which in turn calls HostIrLower::ConvertSingleOpToCommunication from host_ir/lower.h, which is somewhat cyclical and therefore misleading.
> Would you be OK with at least putting it in a separate file, host_ir/lower_to_communication.h|cpp? (as it was named in the past)
SGTM!
virtual ~OptimizationPass() = default;

 protected:
  virtual void passImplementation(Fusion* fusion) {
This defeats the purpose of the "curiously recurring template pattern" (CRTP). I realized that might have misled you and it is probably (I'll give that a quick shot) unnecessary. See the linked example for how to avoid virtual methods.
> This defeats the purpose of "recurring template pattern"
You're right, thanks. My problem is that I don't know how to deal with the fact that the ConvertOpToCommunication pass (and others in the future) has a non-trivial constructor. Maybe I am just missing something.
What about not using the recurring template pattern, and instead keeping runPass in the base class and passImplementation as a private member of the derived class?
Let me know what you suggest.
Would the following work for non-trivial constructors?
#include <iostream>
#include <string>
#include <utility>

template <typename Derived>
class Base {
 public:
  void do_something() {
    // Static dispatch: the derived type is known at compile time, no vtable.
    static_cast<Derived*>(this)->implementation();
  }
};

class Derived : public Base<Derived> {
 public:
  Derived(int x, std::string name) : x_(x), name_(std::move(name)) {}

  void implementation() {
    std::cout << "x: " << x_ << ", name: " << name_ << "\n";
  }

 private:
  int x_;
  std::string name_;
};
(Again, it uses no virtual methods here)
Alternatively, if you don't plan to use CRTP for your host IR passes, please build hir::OptimizationPass as a non-template class and simplify. If you choose to do this, keep insertDeallocations as a helper function and call it from your hir::OptimizationPass. This way, I can wrap it differently if/when I need to.
Thanks! I was not so familiar with CRTP; it is clearer now. I have removed the virtual method. Let me know.
Stacked on top of: - #4401 - #4402 Implements stream lowering to collective based pipelines. # Test dumps: generated with the command line: ``` mpirun -x NVFUSER_DUMP=host_ir -np 8 $BUILD_DIRECTORY/test_multidevice --gtest_filter=*MultiDeviceStreamParallelTypeTest.* ``` ``` MultiDeviceStreamParallelTypeTest.Allgather %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx2{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx2{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx ) T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i2, zero_init=false, resets_to_zero=false) Communication 39 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 39 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.Allreduce %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx3{i0}, 
rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i3 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx3{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx3{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx3{i0}, index = StreamIdx ) T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i3, zero_init=false, resets_to_zero=false) Communication 39 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 39 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.ReduceScatter %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i4 ), 
zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx4{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx4{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx4{i0}, index = StreamIdx ) T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i4 ), zero_init=false, resets_to_zero=false) Communication 48 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 48 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.AG_matmul %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=false) T2_g_float[iStreamIdx6{i0}, 
iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx11{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx11{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx11{i0}, index = StreamIdx ) T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i2 * i3 ) * i4 ), zero_init=false, resets_to_zero=false) Communication 59 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 59 T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx6{i0}, index = StreamIdx ) T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = matmul(T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % 
numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.matmul_AR %HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i2 * i0 ) * i3 ) * i7 ), zero_init=false, resets_to_zero=false) T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i7 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx7{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx7{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx7{i0}, index = StreamIdx ) T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = matmul(T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] 
(DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx12{i0}, index = StreamIdx ) T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i7 ), zero_init=false, resets_to_zero=false) Communication 64 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 64 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.matmul_RS_through_bcast %HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) : PostOnStream (HostUnit0, Inputs:{T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }, Outputs:{T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }) T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( ( i2 * i0 ) * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false) T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, 
size=( ( ( i0 * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx13{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx13{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = bS8{1}, index = StreamIdx ) T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx13{i0}, index = StreamIdx ) T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = matmul(T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx19{i0}, index = StreamIdx ) T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i3 * i4 ) * i8 ), zero_init=false, resets_to_zero=false) Communication 80 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), 
input=T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 80 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) HostUnit0: Inputs={T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), } -> Outputs={T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }Inputs: T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) Outputs: T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) %kernel { T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = broadcast( T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), flags = {true, false, true, false, false} ) } // %kernel } // %HostIrContainer ```
What

This PR is only about cleaning and refactoring; there is no change in behavior.

- `hir_pass::OptimizationPass` in `host_ir/pass/optimization_pass.h`
- `hir_lowering_logging` to control debug logging through `NVFUSER_DUMP`
- `StreamLowering` and `InsertDeallocations` inherit from this pass

Why

Preparation for #4387.