Scheduling/lowering for serial grid reduction #1456
Conversation
Force-pushed from 3b5be8e to a89be48
csrc/device_lower/lower2device.cpp
```c++
    {"UnrollPass", UnrollPass::runPass},
    {"processMisalignedVectorization", processMisalignedVectorization},
    // NOTE: serial GridReduction introduced here. New syncs can be
    // introduced here which could impact smem reuse
```
We should probably factor out the logic for detecting serial grid reductions and insert those syncs much earlier than this point, so that they are available to `reuseMemoryAllocations`. This would let us reuse the wait sync instead of inserting a new one when we reuse prologue memory for the epilogue.
Force-pushed from b2fae7c to cd1d8b2
Will revisit once the sync pass is done, when we have a `TensorIndex`.
Force-pushed from c587404 to ebef797
Still missing allocation/indexing of work buffer
I need to replay leaf transforms, then get index.
Codegen is now like
```c++
// Allocate global tensor T5
reduction::serialReductionStep(
T3[0LL],
T2[(i14 + i18)],
0.000000000e+00f,
T5[((((((((((((nvfuser_index_t)blockIdx.x) * 8LL) + ((nvfuser_index_t)blockIdx.y)) * 4LL) + i13) * 8LL) + (i18 + nvfuser_zero)) * 4LL) + ((nvfuser_index_t)threadIdx.y)) * 32LL) + ((nvfuser_index_t)threadIdx.x))],
[](float &a, float b) { a = a + b; },
index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
true,
true);
```
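As an aside, the `first_block` and `last_block` predicates in the codegen above are built from `index_utils::maskedOffset` and `index_utils::maskedSize`. A plain-C++ sketch of their apparent semantics (the x/y/z linearization order here is my assumption for illustration, not taken from the actual header):

```cpp
#include <cassert>

// Stand-in for CUDA's dim3 / blockIdx / gridDim.
struct Dim3 {
  long x, y, z;
};

// Sketch of index_utils::maskedOffset: linearize the block index using
// only the dimensions whose template flag is true. Ordering is assumed.
template <bool X, bool Y, bool Z>
long maskedOffset(Dim3 idx, Dim3 dim) {
  long offset = 0;
  if (X) {
    offset = offset * dim.x + idx.x;
  }
  if (Y) {
    offset = offset * dim.y + idx.y;
  }
  if (Z) {
    offset = offset * dim.z + idx.z;
  }
  return offset;
}

// Sketch of index_utils::maskedSize: product of the selected dimensions.
template <bool X, bool Y, bool Z>
long maskedSize(Dim3 dim) {
  long size = 1;
  if (X) {
    size *= dim.x;
  }
  if (Y) {
    size *= dim.y;
  }
  if (Z) {
    size *= dim.z;
  }
  return size;
}
```

With only the z flag set, as in the predicates above, `maskedOffset` reduces to `blockIdx.z`, so `first_block` is true exactly when `blockIdx.z == 0` and `last_block` when `blockIdx.z == gridDim.z - 1`.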
This looks OK, although it will get a little better with hoisting. This
compiles, but I get an error in `runFusion`:
```
C++ exception with description "Expected T5_g[ iblockIdx.x59{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(262144, 32) ), 4) ), 8) ), 4) ), 8) )}, iblockIdx.y60{8}, ithreadIdx.y54{4}, ithreadIdx.x52{32}, iS58{4}, iS56{8}, rblockIdx.z49{5} ] to be bound to a tensor of rank 1, but got a tensor of rank 6
Exception raised from validateValWithConcreteValue at /opt/pytorch/nvfuser/csrc/expr_evaluator.cpp:38 (most recent call first):
```
I believe this is happening when binding inputs.
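Stepping back to the `reduction::serialReductionStep` call in the codegen above: my mental model of its semantics is a per-block serial accumulation through the global work buffer. A host-side sketch (simplified and hypothetical; the real device helper also takes read- and write-predicate arguments, which I've dropped here):

```cpp
#include <cassert>

// Simplified sketch of reduction::serialReductionStep semantics: blocks
// along the serialized grid dimension run one after another; the first
// block seeds the running value with `init`, later blocks read it from
// the global work buffer, and the last block writes the result to `out`.
template <typename T, typename Func>
void serialReductionStep(
    T& out,            // destination register (valid on the last block)
    T in,              // this block's partial value
    T init,            // reduction identity
    T& work,           // global work buffer element
    Func reduction_op, // e.g. [](T& a, T b) { a = a + b; }
    bool first_block,
    bool last_block) {
  T v = first_block ? init : work;
  reduction_op(v, in);
  if (last_block) {
    out = v; // final block produces the reduced value
  } else {
    work = v; // intermediate blocks stash the running value
  }
}
```

Four "blocks" contributing 1, 2, 3, 4 in sequence would leave 10 in `out` on the last step.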
Fixes execution error. Test passes!
```diff
   NVF_ERROR(buffer->as<TensorView>()->getMemoryType() == memory_type);
   const auto domain = buffer->as<TensorView>()->domain();
-  for (auto axis : domain->noReductions()) {
+  for (auto axis : TensorDomain::noReductions(domain->maybeAllocation())) {
```
This was done to work around the error described in 910ff09
Can you give more details? Does T5 in the example in the commit have an allocation domain?
```c++
namespace {

Val* getGridSyncBufferSize(const ParallelTypeBitmap& ptb) {
```
Moved to lower_utils::
Generated kernel now looks like
```c++
// Allocate global tensor T4
grid_sync::blockSerializeWait<false, false, true>(&T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
#pragma unroll
for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
nvfuser_index_t i14;
i14 = 8LL * i13;
nvfuser_index_t i15;
i15 = 2048LL * i13;
nvfuser_index_t i16;
i16 = i4 + i15;
nvfuser_index_t i17;
i17 = -i15;
#pragma unroll
for(nvfuser_index_t i18 = 0; i18 < 8LL; ++i18) {
nvfuser_index_t i19;
i19 = 256LL * (i18 + nvfuser_zero);
nvfuser_index_t i20;
i20 = i16 + i19;
float T3[1LL];
T3[0LL] = 0.000000000e+00f;
// Allocate global tensor T5
reduction::serialReductionStep(
T3[0LL],
T2[(i14 + i18)],
0.000000000e+00f,
T5[i20],
[](float &a, float b) { a = a + b; },
index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
true,
true);
if ((b6 && (i5 < (i17 - i19)))) {
T1[i20]
= T3[0LL];
}
}
}
NVFUSER_UPDATE_MAGIC_ZERO;
grid_sync::blockSerializeRelease<false, false, true>(&T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
```
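The `blockSerializeWait`/`blockSerializeRelease` pair above serializes blocks along the reduction grid dimension. My mental model of the protocol is a ticket lock, sketched here with host threads standing in for blocks (assumed semantics; the real helpers spin on a global sync buffer, with one semaphore per non-serialized block index):

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Ticket-lock sketch: "block" i spins until the semaphore reaches its
// index, runs its serial section, then hands the ticket to block i+1.
std::atomic<int> semaphore{0};

void blockSerializeWait(int block_idx) {
  while (semaphore.load(std::memory_order_acquire) != block_idx) {
    std::this_thread::yield();
  }
}

void blockSerializeRelease(int block_idx) {
  semaphore.store(block_idx + 1, std::memory_order_release);
}
```

Even if the "blocks" start in arbitrary order, their serial sections execute strictly in index order, which is what makes the serial reduction's read-modify-write of the work buffer race-free.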
Note that the index `i20` matches the output `T1`. This is what we need in order to reclaim `T1` in a later PR; it will still be a challenge in that work to exact-map `T5` and `T3` so that `T1` and `T5` are exact mapped...
Here is a diff of the kernel math printout between main and this PR:

--- main 2023-12-13 11:32:30.706530912 -0500
+++ thispr 2023-12-13 11:36:04.551566798 -0500
@@ -1,40 +1,40 @@
-%kernel_math { [7/142]
+%kernel_math {
T1_l[ rS2{i0}, iS3{i2} ]
= reduction( T0_g[ iS0{i0}, iS1{i2} ], op = add, initial value = float(0), allreduce = false )
T2_l[ iS4{i2} ]
= Set( T1_l[ rS2{i0}, iS3{i2} ], cache_op=Streaming )
T3_g[ rS5{i2} ]
= reduction( T2_l[ iS4{i2} ], op = add, initial value = float(0), allreduce = false )
-T5_l[ rS9{i13}, iS10{i14}, iS11{i15} ]
- = reduction( T4_g[ iS6{i13}, iS7{i14}, iS8{i15} ], op = add, initial value = float(0), allreduce = false )
-T6_l[ iS12{i14}, iS13{i15} ]
- = Set( T5_l[ rS9{i13}, iS10{i14}, iS11{i15} ], cache_op=Streaming )
-T7_l[ iS14{i14}, iS15{i15} ]
- = Set( T6_l[ iS12{i14}, iS13{i15} ], cache_op=Streaming )
-T8_l[ rS16{i14}, iS17{i15} ]
- = reduction( T7_l[ iS14{i14}, iS15{i15} ], op = add, initial value = float(0), allreduce = false )
-T14_l[ iS28{i15} ]
- = Set( T8_l[ rS16{i14}, iS17{i15} ], cache_op=Streaming )
-T9_l[ iS18{i14}, iS19{i15} ]
- = Set( T5_l[ rS9{i13}, iS10{i14}, iS11{i15} ], cache_op=Streaming )
-T10_l[ iS20{i14}, iS21{i15} ]
- = Set( T9_l[ iS18{i14}, iS19{i15} ], cache_op=Streaming )
-T11_l[ rS22{i14}, iS23{i15} ]
- = reduction( T10_l[ iS20{i14}, iS21{i15} ], op = add, initial value = float(0), allreduce = false )
-T15_l[ iS29{i15} ]
- = Set( T11_l[ rS22{i14}, iS23{i15} ], cache_op=Streaming )
-T17_l[ iS31{i15} ]
- = T14_l[ iS28{i15} ]
- + T15_l[ iS29{i15} ];
-T12_l[ iS24{i14}, iS25{i15} ]
- = Set( T7_l[ iS14{i14}, iS15{i15} ], cache_op=Streaming )
-T13_g[ rS26{i14}, iS27{i15} ]
- = reduction( T12_l[ iS24{i14}, iS25{i15} ], op = add, initial value = float(0), allreduce = false )
-T16_l[ iS30{i15} ]
- = Set( T13_g[ rS26{i14}, iS27{i15} ], cache_op=Streaming )
-T18_l[ iS32{i15} ]
- = T17_l[ iS31{i15} ]
- + T16_l[ iS30{i15} ];
-T19_g[ rS33{i15} ]
- = reduction( T18_l[ iS32{i15} ], op = add, initial value = float(0), allreduce = false )
+T5_l[ rS9{i15}, iS10{i16}, iS11{i17} ]
+ = reduction( T4_g[ iS6{i15}, iS7{i16}, iS8{i17} ], op = add, initial value = float(0), allreduce = false )
+T6_l[ iS12{i16}, iS13{i17} ]
+ = Set( T5_l[ rS9{i15}, iS10{i16}, iS11{i17} ], cache_op=Streaming )
+T7_l[ iS14{i16}, iS15{i17} ]
+ = Set( T6_l[ iS12{i16}, iS13{i17} ], cache_op=Streaming )
+T8_l[ rS16{i16}, iS17{i17} ]
+ = reduction( T7_l[ iS14{i16}, iS15{i17} ], op = add, initial value = float(0), allreduce = false )
+T14_l[ iS28{i17} ]
+ = Set( T8_l[ rS16{i16}, iS17{i17} ], cache_op=Streaming )
+T9_l[ iS18{i16}, iS19{i17} ]
+ = Set( T5_l[ rS9{i15}, iS10{i16}, iS11{i17} ], cache_op=Streaming )
+T10_l[ iS20{i16}, iS21{i17} ]
+ = Set( T9_l[ iS18{i16}, iS19{i17} ], cache_op=Streaming )
+T11_l[ rS22{i16}, iS23{i17} ]
+ = reduction( T10_l[ iS20{i16}, iS21{i17} ], op = add, initial value = float(0), allreduce = false )
+T15_l[ iS29{i17} ]
+ = Set( T11_l[ rS22{i16}, iS23{i17} ], cache_op=Streaming )
+T17_l[ iS31{i17} ]
+ = T14_l[ iS28{i17} ]
+ + T15_l[ iS29{i17} ];
+T12_l[ iS24{i16}, iS25{i17} ]
+ = Set( T7_l[ iS14{i16}, iS15{i17} ], cache_op=Streaming )
+T13_g[ rS26{i16}, iS27{i17} ]
+ = reduction( T12_l[ iS24{i16}, iS25{i17} ], op = add, initial value = float(0), allreduce = false )
+T16_l[ iS30{i17} ]
+ = Set( T13_g[ rS26{i16}, iS27{i17} ], cache_op=Streaming )
+T18_l[ iS32{i17} ]
+ = T17_l[ iS31{i17} ]
+ + T16_l[ iS30{i17} ];
+T19_g[ rS33{i17} ]
+ = reduction( T18_l[ iS32{i17} ], op = add, initial value = float(0), allreduce = false )
 }

I'm not sure how to make it show character diffs, but what's happening is that the IterDomain extents are bumped starting at the second reduction, as if we had replaced that reduction op or its inputs.
This is because we now create an additional `Val` for the new `ReductionOp` attribute, which bumps the scalar name counter.
Also sort expected output by line to give clearer error messages.
```c++
std::string sortByLine(const std::string& input) {
  auto ss = std::stringstream(input);
  std::vector<std::string> lines;
  std::string line;
  while (std::getline(ss, line, '\n')) {
    lines.push_back(line);
  }
  std::sort(lines.begin(), lines.end());
  std::stringstream output;
  bool first = true;
  for (const auto& l : lines) {
    if (!first) {
      output << std::endl;
    }
    first = false;
    output << l;
  }
  return output.str();
}
```
Just cleaning up the TODO in this test while I'm here. This makes the error message provide a useful diff.
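A quick self-contained illustration of the helper's effect (the function body is repeated from above so this snippet compiles on its own): sorting both the expected and actual strings line-by-line before comparison makes order-insensitive mismatches show up as a clean diff.

```cpp
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

// Sort the lines of a multi-line string lexicographically, preserving
// no trailing newline. Same behavior as the helper added in this PR.
std::string sortByLine(const std::string& input) {
  auto ss = std::stringstream(input);
  std::vector<std::string> lines;
  std::string line;
  while (std::getline(ss, line, '\n')) {
    lines.push_back(line);
  }
  std::sort(lines.begin(), lines.end());
  std::stringstream output;
  bool first = true;
  for (const auto& l : lines) {
    if (!first) {
      output << '\n';
    }
    first = false;
    output << l;
  }
  return output.str();
}
```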
```diff
-      " PipelineVal representing Val T4_g[ iS6{i13}, iS7{i14}, iS8{i15} ] on stage " +
+      " PipelineVal representing Val T4_g[ iS6{i15}, iS7{i16}, iS8{i17} ] on stage " +
```
These changes were necessitated by the addition of an attribute in ReductionOp. A Val* is created when each attribute is added, incrementing the scalar name counter.
!build
```c++
NVF_ERROR(!rop->isAllreduce(), "Serial grid allReduce is not implemented");
```

```c++
// Allocate global work buffer TensorIndex.
```
Why does the work buffer need to be scheduled? Can't it be simply just a buffer with the leaf domains of the reduction output tensor?
I think you're right. It doesn't need to be scheduled explicitly; we just need to use the leaf domains of the output to determine how to index it (indexing as if the output were in global memory).
Just pushed a change to simplify this by reusing out_tv->domain() instead of replaying to recreate it.
This will share the same IterDomains with multiple TensorDomains, which I think we should avoid since we implicitly assume each tensor has unique IterDomains.
I think what we should do here would be something like:
```c++
std::vector<IterDomain*> work_buffer_ids;
work_buffer_ids.reserve(out_tv->nDims());
for (IterDomain* out_id : out_tv->leaf()) {
  work_buffer_ids.push_back(IterDomainBuilder(out_id).build());
}
auto work_buffer_domain = IrBuilder::create<TensorDomain>(work_buffer_ids);
```
Ah OK. I see what you mean now. I was trying to preserve the transforms originally so that we could potentially re-use a fusion output buffer for the work buffer in the future (i.e. "in-place" reduction). However, that is low priority since it only applies when the output has the same type as the mma accumulator and we commonly have single precision accumulator and half-precision output tensors so we can't reuse them anyway. For simplicity's sake I'll do as you suggest and just use the leaf domain of the mma output as the work buffer's root/leaf. I'll leave a comment to remind myself to revisit this if we implement global buffer re-use in the future.
Sorry I wasn't clear. Yes, this should be enough for now.
Stacked on #1456. This simply enables serial reduction in the matmul scheduler when split-K is used.
This change enables `ReductionOp`s to be lowered as serial reductions (see #1405) if requested during scheduling.

- A `ReductionOp` is modified by calling its `requestSerialGridReduction()` method. The output tensor can be scheduled before or after this method call, and should result in the op having all its reduction axes parallelized as grid dimensions.
- During lowering, we detect `ReductionOp`s having `serialGridReductionRequested() == true`, and we place syncs around their outer loop. At this point, we also analyze that outer loop to determine if there are any conflicting expressions, such as conflicting grid reductions.
- We then convert the `ReductionOp` to a `GridReduction` that has its serial buffer set. The serial buffer is a temporary `TensorIndex` indexed like a global memory version of the reduction output tensor.

The generated kernel looks like the example shown earlier in this thread.
Notice that the index `i20` now matches between the output `T1` and the intermediate `T5`. In another PR, I will attempt to extend our buffer reuse machinery to recognize this as a chance to use `T1` in place of `T5` (i.e. inner aliasing, in-place reduction).

Also notice that I have not yet hoisted the sync flag index, or the `first_block` and `last_block` predicates.