[Stream lowering] collective based pipelines #4387
!test
# What

This PR is only about cleaning and refactoring; no change in behavior.

- We create a new base class for HIR lowering optimization passes, `hir_pass::OptimizationPass`, in `host_ir/pass/optimization_pass.h`.
- We create an option `hir_lowering_logging` to control debug logging through `NVFUSER_DUMP`.
- We create a guard to enable/disable a pass.
- We make the existing passes `StreamLowering` and `InsertDeallocations` inherit from this base class.
- We factor out converting a resharding op to a communication into a separate pass.

# Why

Preparation for #4387.
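The guard-plus-logging pattern described above can be sketched as follows. This is a minimal Python illustration of the design (the real passes are C++); only the names `OptimizationPass`, `StreamLowering`, and `InsertDeallocations` come from the PR, everything else is hypothetical:

```python
# Sketch of the OptimizationPass pattern: a common base class that wraps each
# HIR lowering pass with an enable/disable guard and optional debug logging.

class OptimizationPass:
    enabled = True    # toggled per pass by a guard (cf. the pass guard in the PR)
    logging = False   # toggled by NVFUSER_DUMP=hir_lowering_logging

    def run(self, container):
        if not type(self).enabled:
            return
        if type(self).logging:
            print(f"[{type(self).__name__}] before: {container}")
        self.pass_implementation(container)
        if type(self).logging:
            print(f"[{type(self).__name__}] after: {container}")

    def pass_implementation(self, container):
        raise NotImplementedError

class StreamLowering(OptimizationPass):
    def pass_implementation(self, container):
        container.append("stream_lowered")

class InsertDeallocations(OptimizationPass):
    def pass_implementation(self, container):
        container.append("deallocations")

program = ["top_level_exprs"]
for p in (StreamLowering(), InsertDeallocations()):
    p.run(program)

# A guard can temporarily disable a pass:
InsertDeallocations.enabled = False
program2 = ["top_level_exprs"]
InsertDeallocations().run(program2)  # no-op: the pass is disabled
```

The point of the base class is that the guard and logging logic live in `run`, so each concrete pass only implements `pass_implementation`.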
!test

!test
```
    << "Output: " << output << "\nExpected: " << expected_output;
}

TEST_F(MultiDeviceStreamParallelTypeTest, AG_matmul) {
```
FYI, there are two other AG_matmul patterns to consider. They come from the backprop of Matmul_RS and have different layouts than this one.
Interesting, I wasn't aware of that. Can you elaborate?
Here's a detailed explanation of the backpropagation of a matrix multiplication followed by a ReduceScatter operation, especially in the context of distributed training (e.g., tensor parallelism).
Setup

Let:

- $A \in \mathbb{R}^{m \times k}$
- $B \in \mathbb{R}^{k \times n}$
- $C = AB \in \mathbb{R}^{m \times n}$
Now, assume ReduceScatter is applied to split and reduce C across processes along the output feature dimension (columns), e.g., split n into n/p per rank (with p processes).
So, define:

```
# forward
C = matmul(A, B)             # C shape: [m, n]
C_local = reduce_scatter(C)  # each rank gets C_local of shape [m, n/p]
```
Goal

Compute the gradients w.r.t. A and B, i.e., $\frac{\partial \text{loss}}{\partial A}$ and $\frac{\partial \text{loss}}{\partial B}$, where the backward starts from `C_local.grad`.
Step-by-Step Backward
1. ReduceScatter Backward
In reverse, ReduceScatter becomes AllGather:

```
grad_C = all_gather(grad_C_local)  # shape: [m, n]
```
So each rank reconstructs the full C gradient.
2. Matrix Multiplication Backward
Given:

$$C = AB \Rightarrow \frac{\partial \text{loss}}{\partial A} = \frac{\partial \text{loss}}{\partial C} B^T, \qquad \frac{\partial \text{loss}}{\partial B} = A^T \frac{\partial \text{loss}}{\partial C}$$

Apply this to `grad_C` (shape `[m, n]`):

```
grad_A = grad_C @ B.T  # shape: [m, k]
grad_B = A.T @ grad_C  # shape: [k, n]
```
Notes on Distributed Context

- If `B` is sharded column-wise (as in tensor parallelism), then:
  - Each rank holds a shard `B_i` of shape `[k, n/p]`
  - `C_local = A @ B_i`, and reduce-scatter sums over the `p` ranks
- In this case:
  - Forward: sum over `p` partial products
  - Backward:
    - Each rank computes its local `grad_C_local`
    - All-gather to reconstruct the full `grad_C`
    - Then compute `grad_A` and the local `grad_B_i`
- Optimization: in some frameworks (e.g., Megatron-LM, DeepSpeed), `grad_B` is computed locally, and an all-reduce is then used to combine the result.
Summary

| Operation | Forward | Backward |
|---|---|---|
| matmul | $C = AB$ | $\partial A = \partial C\, B^T$, $\partial B = A^T\, \partial C$ |
| reduce_scatter | split + sum across ranks (on columns) | all_gather gradients to get the full tensor |
Let me know your setup (e.g., column-wise sharded B, or row-wise A) and I can tailor the equations accordingly.
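The backward rules above can be checked numerically. A minimal NumPy sketch, simulating the collectives in a single process (with one unsharded copy of `C`, "reduce" is a no-op and scatter is a column split; all names here are illustrative):

```python
import numpy as np

m, k, n, p = 4, 3, 6, 2  # n divisible by p ranks
rng = np.random.default_rng(0)
A = rng.normal(size=(m, k))
B = rng.normal(size=(k, n))

# forward: C = A B, then reduce_scatter along columns
C = A @ B
C_locals = np.split(C, p, axis=1)          # rank r holds C_locals[r]: [m, n/p]

# take loss = sum of all entries => each grad_C_local is all ones
grad_C_locals = [np.ones_like(c) for c in C_locals]

# backward of reduce_scatter is all_gather: concatenate the local grads
grad_C = np.concatenate(grad_C_locals, axis=1)   # [m, n]

# backward of matmul
grad_A = grad_C @ B.T    # [m, k]
grad_B = A.T @ grad_C    # [k, n]

# finite-difference check on one entry of A
eps = 1e-6
A2 = A.copy()
A2[0, 0] += eps
num = ((A2 @ B).sum() - (A @ B).sum()) / eps
assert abs(num - grad_A[0, 0]) < 1e-4
```

The finite-difference check passes because, with `loss = C.sum()`, `grad_A[0, 0]` reduces to the sum of row 0 of `B`, exactly what perturbing `A[0, 0]` changes.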
I'll try to add more examples. Meanwhile, the above hopefully serves as a high-level description. In summary, they'll still be allgather+matmul, but different dimensions get sharded and certain dimensions are row-major vs. column-major.
I added the forward and backward of a linear+allreduce. It's not exactly linear+ReduceScatter, but hopefully close enough for you to get a clue.
I'll take a look, thank you very much
!test

I removed
CMakeLists.txt (outdated)

```
${CMAKE_SOURCE_DIR}/third_party/benchmark/include
${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include
${CMAKE_SOURCE_DIR}/third_party/googletest/googletest/include
${CMAKE_SOURCE_DIR}/third_party/googletest/googlemock/include
```
I'm not sure why this change is here. Seems unrelated to this PR.
It is kind of related. This was required because I added the `#include <tests/cpp/validator.h>` header in `tests/cpp/multidevice.h`, which otherwise leads to the compilation error:
```
[1/5] Building CXX object CMakeFiles/nvfuser_multidevice_bench.dir/benchmarks/cpp/transformer.cpp.o
FAILED: CMakeFiles/nvfuser_multidevice_bench.dir/benchmarks/cpp/transformer.cpp.o
[...]
In file included from /opt/pytorch/Fuser/tests/cpp/multidevice.h:15,
                 from /opt/pytorch/Fuser/benchmarks/cpp/transformer.cpp:15:
/opt/pytorch/Fuser/tests/cpp/validator.h:10:10: fatal error: gmock/gmock-matchers.h: No such file or directory
   10 | #include <gmock/gmock-matchers.h>
      |          ^~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
```
In the meantime, I moved the header inclusion to `tests/cpp/test_multidevice_stream_parallel_type.cpp` only. This makes it possible to remove the CMakeLists patch -- but if we want to use those gmock matchers more extensively, we should reconsider.
CMakeLists.txt (outdated)

```
target_include_directories(nvfuser_multidevice_bench PUBLIC ${NVFUSER_ROOT})
target_link_libraries(nvfuser_multidevice_bench PRIVATE
  GTest::gtest
  GTest::gmock
```
CMakeLists.txt (outdated)

```
if(NOT MSVC)
  target_compile_options(nvfuser_bench PRIVATE
  target_compile_options(nvfuser_multidevice_bench PRIVATE
```
This is unrelated, but a bug IIUC. I'm reverting it here to merge the current PR ASAP.
```
auto index =
    expr_evaluator_.evaluate(hir_alias_select->index()).as<int64_t>();
auto indexed_id =
    hir_alias_select->in()->getLogicalDomain().at(hir_alias_select->axis());
```
A logical domain may have reduction dimensions, which don't exist in the corresponding at::Tensor. So it's likely problematic to use the same int axis on both in()->getLogicalDomain() and input, unless in() is guaranteed to be reduction free.
right, fixed. Thanks
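The reviewer's point can be illustrated with a small model (a hypothetical Python helper, not nvFuser code): a logical domain may contain reduction dims with no counterpart in the concrete `at::Tensor`, so a logical-axis index must be shifted past them before indexing the tensor.

```python
# Illustration of the axis-remapping issue discussed above (hypothetical model).
# A "logical domain" may contain reduction dims that are not materialized in the
# concrete tensor, so an index into the logical domain must be remapped to an
# index among the non-reduction dims only.

def tensor_axis_from_logical(logical_domain, logical_axis):
    """Map a logical-domain axis to the corresponding concrete-tensor axis,
    skipping reduction dims (which the tensor does not materialize)."""
    assert not logical_domain[logical_axis]["is_reduction"]
    return sum(1 for d in logical_domain[:logical_axis] if not d["is_reduction"])

# logical domain: [iS, rS (reduction), iS] -> the tensor has only 2 dims
dom = [{"is_reduction": False}, {"is_reduction": True}, {"is_reduction": False}]
print(tensor_axis_from_logical(dom, 0))  # 0
print(tensor_axis_from_logical(dom, 2))  # 1: skips the reduction dim
```

Using the raw logical axis on the tensor would be correct only when the input is guaranteed reduction-free, which is exactly the caveat raised above.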
!test

!test
on top of - #4387 # What Add Stream lowering to Allgather p2p linear, with NCCL backend For example: `MultiDeviceStreamParallelTypeTest.AllgatherP2p` from `tests/cpp/test_multidevice_stream_parallel_type.cpp`: ``` TensorView* tv0 = makeContigTensor(2); TensorView* tv1 = set(tv0); fusion->addInput(tv0); fusion->addOutput(tv1); const DeviceMesh mesh = DeviceMesh::createForNumDevices(communicator_->size()); tv0->setDeviceMesh(mesh); tv1->setDeviceMesh(mesh); tv0->axis(0)->parallelize(ParallelType::DIDx); tv1->axis(0)->parallelize(ParallelType::Stream); ``` is lowered to: ``` %HostIrContainer { (T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx2{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx2{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx ) IF Manual ( StreamIdx == deviceIdx.x ): T2_l_float[iS4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{i0}, index = 0 ) T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = Set( T2_l_float[iS4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming ) ELSE: StartCoalescing P2PCommunication 30 (type=recv, buffer=T3_l_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL) P2PCommunication 31 (type=send, buffer=T0_g_float[ideviceIdx.x0{i0}, iS1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, 
backend=NCCL) EndCoalescing 32 Wait Communication 32 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer ``` A test with an overlapped matmul is also proposed in `AG_matmul_P2p`, which generates the following host program: ``` %HostIrContainer { (T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i2 ) * i3 ), zero_init=false, resets_to_zero=false) T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i2 ) * i5 ), zero_init=false, resets_to_zero=false ) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx9{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx9{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T3_g_float[iStreamIdx9{i0}, iS10{i2}, iS11{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx9{i0}, index = StreamIdx ) IF Manual ( StreamIdx == deviceIdx.x ): T4_l_float[iS12{i2}, iS13{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = ideviceIdx.x0{i0}, index = 0 ) T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = Set( T4_l_float[iS12{i2}, iS13{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), cache_op=Streaming ) ELSE: StartCoalescing P2PCommunication 41 (type=recv, buffer=T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3
4 5 6 7}), peer=StreamIdx, backend=NCCL) P2PCommunication 42 (type=send, buffer=T0_g_float[ideviceIdx.x0{i0}, iS1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), peer=StreamIdx, backend=NCCL) EndCoalescing 43 Wait Communication 43 T6_l_float[iS16{i2}, iS17{i5}, rS18{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T2_g_float[iStreamIdx5{i0}, iS6{i2}, iS7{i5}, rS8{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx5{i0}, index = StreamIdx ) T6_l_float[iS16{i2}, iS17{i5}, rS18{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = matmul(T5_l_float[iS14{i2}, iS15{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})) SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer ```
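To summarize what the lowered `AllgatherP2p` program does, here is a single-process Python simulation of the exchange (hypothetical; no real NCCL, and the coalesced send/recv pairs are modeled by direct list access):

```python
# Single-process simulation of the p2p-based allgather in the lowered program
# above: on rank r, output slot StreamIdx is filled by a local copy when
# StreamIdx == r, and otherwise by a recv from peer StreamIdx (which sends its
# own shard in the same coalesced group).

world = 4
shards = [f"shard{r}" for r in range(world)]       # tv0: one shard per rank

outputs = [[None] * world for _ in range(world)]   # tv1 on each rank
for r in range(world):                             # each rank runs the loop
    for stream_idx in range(world):                # FOR StreamIdx in ...
        if stream_idx == r:                        # IF StreamIdx == deviceIdx.x
            outputs[r][stream_idx] = shards[r]     # local copy (Set)
        else:                                      # ELSE: recv from that peer
            outputs[r][stream_idx] = shards[stream_idx]

# every rank ends up with the full gathered tensor
assert all(out == shards for out in outputs)
```

In the real program each iteration runs on stream `StreamIdx % numberOfStreams`, so the per-peer exchanges can overlap; the simulation only captures the dataflow.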
Stacked on top of: - NVIDIA#4401 - NVIDIA#4402 Implements stream lowering to collective based pipelines. # Test dumps: generated with the command line: ``` mpirun -x NVFUSER_DUMP=host_ir -np 8 $BUILD_DIRECTORY/test_multidevice --gtest_filter=*MultiDeviceStreamParallelTypeTest.* ``` ``` MultiDeviceStreamParallelTypeTest.Allgather %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx2{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx2{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx ) T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i2, zero_init=false, resets_to_zero=false) Communication 39 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 39 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.Allreduce %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> 
(T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i3 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx3{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx3{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx3{i0}, index = StreamIdx ) T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i3, zero_init=false, resets_to_zero=false) Communication 39 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 39 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.ReduceScatter %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 
* i3 ) * i4 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx4{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx4{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx4{i0}, index = StreamIdx ) T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i4 ), zero_init=false, resets_to_zero=false) Communication 48 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 48 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.AG_matmul %HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=false) 
T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx11{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx11{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx11{i0}, index = StreamIdx ) T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i2 * i3 ) * i4 ), zero_init=false, resets_to_zero=false) Communication 59 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 59 T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx6{i0}, index = StreamIdx ) T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = matmul(T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) SetCurrentStream to Stream 0 Synchronize 
Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.matmul_AR %HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) : T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i2 * i0 ) * i3 ) * i7 ), zero_init=false, resets_to_zero=false) T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i7 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx7{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx7{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx7{i0}, index = StreamIdx ) T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = matmul(T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) T6_g_float[rS23{i2}, 
iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx12{i0}, index = StreamIdx ) T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i7 ), zero_init=false, resets_to_zero=false) Communication 64 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 64 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) } // %HostIrContainer MultiDeviceStreamParallelTypeTest.matmul_RS_through_bcast %HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) : PostOnStream (HostUnit0, Inputs:{T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }, Outputs:{T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }) T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( ( i2 * i0 ) * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false) T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 
7}), mem_type=global, size=( ( ( i0 * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false) GetCurrentStream into Stream 0 FOR StreamIdx in iStreamIdx13{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) Synchronize Stream 0 FOR StreamIdx in iStreamIdx13{i0}: SetCurrentStream to Stream ( StreamIdx % numberOfStreams ) T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx ) T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = bS8{1}, index = StreamIdx ) T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx13{i0}, index = StreamIdx ) T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = matmul(T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = HirAliasSelect( T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx19{i0}, index = StreamIdx ) T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i3 * i4 ) * i8 ), zero_init=false, resets_to_zero=false) Communication 80 (type=ReduceScatter, team=(0 1 2 3 4 
5 6 7), input=T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL) Wait Communication 80 SetCurrentStream to Stream 0 Synchronize Stream ( StreamIdx % numberOfStreams ) HostUnit0: Inputs={T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), } -> Outputs={T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }Inputs: T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) Outputs: T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) %kernel { T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = broadcast( T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), flags = {true, false, true, false, false} ) } // %kernel } // %HostIrContainer ```
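The stream-pipelined `AG_matmul` dump above can be condensed into a NumPy sketch of the math one rank performs (simplified shapes; streams and NCCL are ignored, and "allgather" is simulated by concatenating per-rank shards):

```python
import numpy as np

# NumPy sketch of the AG_matmul pipeline: the stream axis is looped, and each
# iteration "allgathers" one slice of the device-sharded input, then matmuls
# it into the matching output slice. In the real program each iteration runs
# on its own stream, so communication and compute overlap across iterations.

s, d, m, k, n = 3, 4, 2, 5, 6   # stream extent, devices, per-shard dims
rng = np.random.default_rng(1)
t0_shards = [rng.normal(size=(s, m, k)) for _ in range(d)]  # sharded input
t1 = rng.normal(size=(k, n))                                # replicated weight

out = np.zeros((s, d * m, n))   # T2: [stream, gathered rows, n]
for stream_idx in range(s):     # FOR StreamIdx (each on its own stream)
    # allgather this stream's slice across the d devices
    gathered = np.concatenate([sh[stream_idx] for sh in t0_shards], axis=0)
    out[stream_idx] = gathered @ t1   # matmul for this stream iteration

# reference: gather everything first, then one big matmul
full = np.concatenate(t0_shards, axis=1)   # [s, d*m, k]
assert np.allclose(out, full @ t1)
```

The pipelined loop and the gather-then-matmul reference compute the same result; the pipeline only changes when each slice's communication and matmul happen.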