Grid reduction with serialized blocks #1405
Merged
jacobhinkle merged 36 commits into main on Dec 8, 2023
Conversation
Soon I will add the serial gridReduce codegen version so that we can compare to one another.
... even if there is no wait(). Just writing a value to the semaphore deadlocks for me on my 3090 Ti.
What's left:
- modify allocation size for work buffer
- codegen the serial grid reduction
jacobhinkle commented Dec 2, 2023
I think this will have negligible perf impact by adding a sync in the last block after the last loop. The upside is that we can keep the sync buffer clean (i.e. all zeros), which at some point might help us to re-use sync buffers, removing one memset kernel launch per execution.
!build
naoyam reviewed Dec 7, 2023
naoyam reviewed Dec 8, 2023
jacobhinkle added a commit that referenced this pull request on Jan 23, 2024
This change enables `ReductionOp`s to be lowered as serial reductions (see #1405) if requested during scheduling.

1. At scheduling, a `ReductionOp` is modified by calling its `requestSerialGridReduction()` method. The output tensor can be scheduled before or after this method call, and should result in the op having all its reduction axes parallelized as grid dimensions.
2. Early in lowering, we find `ReductionOp`s having `serialGridReductionRequested() == true`, and we place syncs around their outer loop. At this point, we also analyze that outer loop to determine if there are any conflicting expressions, such as conflicting grid reductions.
3. Later in lowering, during the indexing pass, we convert the `ReductionOp` to a `GridReduction` that has its serial buffer set. The serial buffer is a temporary `TensorIndex` indexed like a global memory version of the reduction output tensor.

The generated kernel looks like this:

```c++
// Allocate global tensor T4
grid_sync::blockSerializeWait<false, false, true>(
    &T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
#pragma unroll
for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
  nvfuser_index_t i14;
  i14 = 8LL * i13;
  nvfuser_index_t i15;
  i15 = 2048LL * i13;
  nvfuser_index_t i16;
  i16 = i4 + i15;
  nvfuser_index_t i17;
  i17 = -i15;
  #pragma unroll
  for(nvfuser_index_t i18 = 0; i18 < 8LL; ++i18) {
    nvfuser_index_t i19;
    i19 = 256LL * (i18 + nvfuser_zero);
    nvfuser_index_t i20;
    i20 = i16 + i19;
    float T3[1LL];
    T3[0LL] = 0.000000000e+00f;
    // Allocate global tensor T5
    reduction::serialReductionStep(
        T3[0LL],
        T2[(i14 + i18)],
        0.000000000e+00f,
        T5[i20],
        [](float &a, float b) { a = a + b; },
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim)
            == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        true,
        true);
    if ((b6 && (i5 < (i17 - i19)))) {
      T1[i20] = T3[0LL];
    }
  }
}
NVFUSER_UPDATE_MAGIC_ZERO;
grid_sync::blockSerializeRelease<false, false, true>(
    &T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
```

Notice that the index `i20` now matches between the output `T1` and the intermediate `T5`. In another PR, I will attempt to extend our buffer reuse machinery to recognize this as a chance to use `T1` in place of `T5` (i.e. inner aliasing, in-place reduction). Also notice that I have not yet hoisted the sync flags index, or the `first_block` and `last_block` predicates.
This PR introduces kernel IR nodes for the syncs that need to occur before and after a loop containing a serial grid reduction. That grid reduction can be inlined with other computation in a loop nest, and the sync nodes will be placed around the outer loop in the generated kernel. The `kir::GridReduction` node itself is modified to have an attribute, available via `bool kir::GridReduction::isSerial() const`, indicating whether this is a serial grid reduction. This PR tests that codegen is correct.

Default CUDA kernel for the included test:
The serial reduction kernel looks like this:
What is not included in this PR
- There is no automatic scheduling or lowering to serial reductions in this PR. The included test works via a post-lowering hook in `FusionExecutor` simply to test that we can codegen the nodes properly once they are manually placed.
- There is also no re-use of global buffers currently, so this is not yet an "in-place" reduction; i.e., at this time we must manually allocate a work buffer that is the full size of the grid reduction output. In the future, we can avoid the need for that workspace by aliasing an output buffer.
- The work buffer must currently have the same dtype as the reduction element. In the future, we could relax this in order to cast to lower precision in the work buffer. This would enable us to re-use the global memory allocated for TST and HSH matmul output, at the expense of a small loss in precision.
Related to #1316 and #991.