
[WIP] Peel off epilogue loop for circular buffering#2005

Closed

jacobhinkle wants to merge 3 commits into main from circular_buffer_epilogue

Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented Mar 26, 2024

Currently, when a TensorView is circular buffered, we convert a loop like

for i in 0..N
  for j in ...
    ... = y[i, j]

to

allocate x[S*D] // where S is the single-buffered size of the TensorView
for i in 0..D-1 // prologue
  for j in ...
    if pred:
      x[i*S + j] = y[i, j]
  cp.async.commit
cp.async.wait D-2 // Block until all but the last D-2 transactions complete
for i in 0..N // main loop
  for j in ...
    if pred:
      x[(((i+D-1)%D)*S+j)] = y[i+D-1, j]
    ... = x[((i%D)*S+j)]
    cp.async.commit
    cp.async.wait D-2
    __syncthreads()

As mentioned in #2000, the way we predicate loads for circular buffering means that we actually have "in-flight" committed cp.async calls which will store zeros into smem. Since we allow up to D-2 of those to remain in flight after the end of the main loop, they are never waited on, leading to the data race encountered in #2000.
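To see the leak concretely, the in-flight commit-group count of the current schedule can be modeled on the host (a hypothetical simulation, not nvFuser code): each iteration that issues loads commits one group, and wait K blocks until at most K groups remain pending.

```cpp
#include <cassert>

// Model the in-flight cp.async commit groups for the current (non-peeled)
// lowering: the prologue commits D-1 groups, then every main-loop iteration
// commits one more and waits until at most D-2 remain.
int inFlightAfterCurrentSchedule(int N, int D) {
  int in_flight = 0;
  auto wait = [&](int keep) {
    if (in_flight > keep) {
      in_flight = keep;
    }
  };
  for (int i = 0; i < D - 1; ++i) { // prologue: one commit per iteration
    in_flight++;
  }
  wait(D - 2);
  for (int i = 0; i < N; ++i) { // main loop: commit, then partial wait
    in_flight++;
    wait(D - 2);
  }
  // D-2 groups are still pending here and are never waited on.
  return in_flight;
}
```

For any N and stage depth D, the model reports D-2 groups left in flight after the main loop, which is exactly the set of transactions the race in #2000 involves.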

This PR instead adds an epilogue loop, so that we now lower a circular buffered TV like this:

allocate x[S*D] // where S is the single-buffered size of the TensorView
for i in 0..D-1 // prologue
  for j in ...
    x[i*S + j] = y[i, j]
  cp.async.commit
cp.async.wait D-2 // Block until all but the last D-2 transactions complete
for i in 0..(N-(D-1)) // main loop
  for j in ...
    if pred:
      x[(((i+D-1)%D)*S+j)] = y[i+D-1, j]
    ... = x[((i%D)*S+j)]
    cp.async.commit
    cp.async.wait D-2
    __syncthreads()
for i in (N-(D-1))..N // epilogue
  cp.async.wait N-1-i
  for j in ...
    ... = x[((i%D)*S+j)]

Note that we have shortened the main loop and peeled it into the new "epilogue" loop which is identical to the main loop but without any additional async loads.
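The same host-side model (hypothetical, not nvFuser code) shows the peeled schedule drains completely: the epilogue issues no commits and its wait N-1-i counts down to wait 0 on the final iteration.

```cpp
#include <cassert>

// Model the in-flight cp.async commit groups for the peeled schedule:
// prologue commits D-1 groups, the shortened main loop runs N-(D-1)
// iterations, and the epilogue progressively drains the remainder.
int inFlightAfterPeeledSchedule(int N, int D) {
  int in_flight = 0;
  auto wait = [&](int keep) {
    if (in_flight > keep) {
      in_flight = keep;
    }
  };
  for (int i = 0; i < D - 1; ++i) { // prologue commits
    in_flight++;
  }
  wait(D - 2);
  for (int i = 0; i < N - (D - 1); ++i) { // shortened main loop
    in_flight++;
    wait(D - 2);
  }
  for (int i = N - (D - 1); i < N; ++i) { // epilogue: no commits
    wait(N - 1 - i); // final iteration waits on everything (wait 0)
  }
  return in_flight; // 0: every committed load has been waited on
}
```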

Challenges encountered so far

This is going reasonably well, but there is a limitation I didn't notice at first: as mentioned in the PTX docs, the N argument in cp.async.wait_group N must be an integer constant.

Comment on lines -654 to +660
-  int64_t keepStages() const {
-    return attribute<int64_t>(1);
+  Val* keepStages() const {
+    return attributeVal(1);
Collaborator Author

@jacobhinkle jacobhinkle Mar 26, 2024


I attempted to make keepStages() a Val*. The issue is that this must be a compile-time constant argument to the inline asm instruction. It might still work if we are able to unroll the epilogue loop, but that might not always be preferable/acceptable. Alternatively, we could add a kir node that calls a runtime helper wrapping cp.async.wait_group N for variable N, handling values up to say 5 or 6 inside a switch statement.
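A host-side sketch of that switch-based helper might look like the following (hypothetical names; the real instruction is device-only inline asm, stubbed out here so the dispatch logic can be exercised on the host):

```cpp
#include <cassert>

// Host stand-in for asm volatile("cp.async.wait_group %0;" :: "n"(N)),
// which on device requires N to be a compile-time constant.
static int last_wait = -1;
template <int N>
void waitGroupConst() {
  last_wait = N;
}

// Hypothetical runtime dispatcher: maps a runtime keep_stages value onto
// the constant-operand instruction via a switch, up to a fixed cap.
void cpAsyncWaitGroup(int keep_stages) {
  switch (keep_stages) {
    case 0: waitGroupConst<0>(); break;
    case 1: waitGroupConst<1>(); break;
    case 2: waitGroupConst<2>(); break;
    case 3: waitGroupConst<3>(); break;
    case 4: waitGroupConst<4>(); break;
    case 5: waitGroupConst<5>(); break;
    default: waitGroupConst<0>(); break; // conservative: wait for all
  }
}
```

The cap (here 5) bounds the switch; in practice it would only need to cover up to the stage depth of the kernel being compiled.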

Collaborator Author

@jacobhinkle jacobhinkle Mar 27, 2024

The reason will be displayed to describe this comment to others. Learn more.

UPDATE: the switch-statement helper approach seems to work and doesn't require unrolling the epilogue (we can't avoid the switch statement even with unrolling), but it means we need to set an upper limit on the number of unsynced stages. We could set that to something high like 10, but we only need it to be num_stages; i.e. we don't need a switch statement with 10 cases if we have only 3 circular buffering stages. However, the requirement for a constant N in cp.async.wait_group, combined with the requirement that inline asm take a string literal, has stumped me: I tried all kinds of combinations of templates and macros but got nowhere.

Collaborator

Another alternative is to, instead of making the epilogue

for i in (N-(D-1))..N // epilogue
  cp.async.wait N-1-i
  for j in ...
    ... = x[((i%D)*S+j)]

we change it to

for i in 0..(D-1) // epilogue
  cp.async.wait D-2-i
  for j in ...
    ... = x[(((N-(D-1) + i)%D)*S+j)]

so that the epilogue is naturally unrolled.

And at the same time

Fuser/csrc/kernel_ir.cpp

Lines 321 to 323 in b108bca

if (in->isConst()) {
  constraint = "n";
} else {

needs to change to generate an n constraint for PTX.

You will also need to update

Fuser/csrc/index_compute.cpp

Lines 2354 to 2400 in b108bca

if (consumer_tv->isDoubleBuffered() || consumer_tv->isCircularBuffered()) {
  auto db_loop =
      gpu_lower->doubleBufferInfo().getDoubleBufferLoop(consumer_tv, loops);
  auto stage_depth = (int64_t)gpu_lower->doubleBufferInfo().getStageDepthFor(
      db_loop->iter_domain());
  bool is_circular_buffer_loop = stage_depth > 2;
  bool is_prolog =
      db_loop->doubleBufferLoopStage() == DoubleBufferLoopStage::Prolog;
  Val* db_switch_index = nullptr;
  // In double buffered we don't materialize the prolog loop as there will
  // be only one iteration. In circular buffer case we materialize the
  // prolog loop as well covering the first N-1 iterations, N being the
  // stage depth.
  if (!is_prolog || is_circular_buffer_loop) {
    if (is_prolog && is_circular_buffer_loop) {
      // The buffer switching logic is the same as original index
      // in the case of circular buffer prolog.
      db_switch_index = db_loop->indexOrStartIfTrivial();
      if (rotated_loops.count(db_loop)) {
        db_switch_index =
            SimplifyingIrBuilder::addExpr(db_switch_index, db_loop->step());
      }
    } else {
      auto loop_index = db_loop->indexOrStartIfTrivial();
      if (rotated_loops.count(db_loop)) {
        loop_index =
            SimplifyingIrBuilder::addExpr(loop_index, db_loop->step());
      }
      // Switching index generated for main loop or epilog component.
      db_switch_index = SimplifyingIrBuilder::modExpr(
          SimplifyingIrBuilder::addExpr(
              loop_index,
              SimplifyingIrBuilder::create<Val>(
                  stage_depth - 1, DataType::Index)),
          SimplifyingIrBuilder::create<Val>(stage_depth, DataType::Index));
    }
    // Use the generated switching buffer index to access the buffer space.
    auto original_alloc_size =
        gpu_lower->doubleBufferInfo().getOriginalAllocSize(consumer_tv);
    auto db_strided_index =
        SimplifyingIrBuilder::mulExpr(db_switch_index, original_alloc_size);
    strided_inds.push_back(db_strided_index);
  }
}

to change the loop index.

Collaborator Author

If I'm understanding, you mean to use the unrolled loop index as the argument to wait_group. That fails: even when the loop variable is the actual argument, we get

CUDA NVRTC compile error: __tmp_kernel_none_f0_c0_r0_g0.cu(10249): error: an asm operand must be an integral constant expression.
      asm volatile("cp.async.wait_group %0;"::"n"(ii)); 

That is, in PTX it is seeing this as a non-constant argument. I have tried interpolating it into that string, but for inline assembly the command must be a string literal...

Collaborator Author

@mmigdal-nv helped come up with this solution:

template <int num_stages>
__inline__ __device__ void cpAsyncPartialBarrier(int keep_stages) {
  if constexpr (num_stages < 0) {
    return;
  }
  if (keep_stages == num_stages) {
    asm volatile("cp.async.wait_group %0;"::"n"(num_stages));
  } else {
    cpAsyncPartialBarrier<num_stages - 1>(keep_stages);
  }
}

template <>
__inline__ __device__ void cpAsyncPartialBarrier<-1>(int keep_stages) {
}

...
  #pragma unroll
  for(nvfuser_index_t i13 = 12793; i13 < 12800LL; ++i13) {
    ...
    __syncthreads();
    cpAsyncPartialBarrier<8>((12800 - 2) - i13);
  }
  NVFUSER_UPDATE_MAGIC_ZERO;

We can replace ((12800 - 2) - i13) with the Val* we currently have. The compiler will evaluate the recursive template and prune the dead branches. I think the only downside is that we need to unroll the new loop, which will probably hurt compilation time.
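The recursive-template dispatch can be exercised directly on the host with a stand-in for the PTX instruction (a hypothetical adaptation; on device the body would be the inline asm with a constant operand):

```cpp
#include <cassert>

static int issued = -1; // records the constant the "instruction" received

// Recursive dispatch: walks num_stages, num_stages-1, ... until the
// runtime value keep_stages matches a compile-time constant. On device,
// the matching branch would execute
//   asm volatile("cp.async.wait_group %0;" :: "n"(num_stages));
template <int num_stages>
void cpAsyncPartialBarrier(int keep_stages) {
  if (keep_stages == num_stages) {
    issued = num_stages;
  } else {
    cpAsyncPartialBarrier<num_stages - 1>(keep_stages);
  }
}

// Base case terminates the recursion; out-of-range values are a no-op.
template <>
void cpAsyncPartialBarrier<-1>(int) {}
```

Instantiating cpAsyncPartialBarrier<8> generates the chain of nine comparisons; with the loop unrolled, the compiler can fold each call down to the single matching constant-operand instruction.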

Collaborator

Fun C++...

template<int i>
void f() {}

int main() {
    #pragma unroll
    for (int i = 0; i < 10; i++) {
        f<i>();
    }
    return 0;
}
<source>: In function 'int main()':
<source>:7:13: error: no matching function for call to 'f<i>()'
    7 |         f<i>();
      |         ~~~~^~
<source>:2:6: note: candidate: 'template<int i> void f()'
    2 | void f() {}
      |      ^
<source>:2:6: note:   template argument deduction/substitution failed:
<source>:7:13: error: the value of 'i' is not usable in a constant expression
    7 |         f<i>();
      |         ~~~~^~
<source>:6:14: note: 'int i' is not const
    6 |     for (int i = 0; i < 10; i++) {
      |              ^
<source>:7:13: note: in template argument for type 'int'
    7 |         f<i>();
      |         ~~~~^~

Collaborator Author

BTW I looked to see what CUTLASS does. It seems they do not peel off an epilogue loop; instead they just issue wait_group 0 after the main loop. https://github.com/NVIDIA/cutlass/blob/c4e3e122e266644c61b4af33d0cc09f4c391a64b/include/cutlass/gemm/threadblock/mma_multistage.h

jacobhinkle added a commit that referenced this pull request Mar 27, 2024
This is an alternative to #2005
@jacobhinkle
Collaborator Author

Closing in favor of #2008.

@zasdfgbnm zasdfgbnm deleted the circular_buffer_epilogue branch March 28, 2024 00:05
jacobhinkle added a commit that referenced this pull request Mar 28, 2024
This just places a `cp.async.wait_group 0` instruction immediately after
any circular buffer main loop which is the approach taken by CUTLASS for
pipelining GEMMs: (see
[mma_multistage.h#L664-L665](https://github.com/NVIDIA/cutlass/blob/c4e3e122e266644c61b4af33d0cc09f4c391a64b/include/cutlass/gemm/threadblock/mma_multistage.h#L664-L665)).
The previous fix for #2000, #2001, is reverted.

This is an alternative to #2005.

Fixes #2000