
Unswizzle before grid reduction in split-K#1534

Merged
jacobhinkle merged 84 commits into main from splitk_smem_epilogue
Mar 22, 2024
Conversation

@jacobhinkle (Collaborator) commented Dec 15, 2023

Serial grid reductions are used in split-K matmuls as of #1510. This means we load and store elements of the reduction tensor according to the indexing of the work buffer. This is unlike ordinary grid reductions, which use gridReduce: that helper reduces individual elements using a scheme that ensures coalescing by indexing into the work buffer based on threadIdx and blockIdx. Currently these split-K accesses are inefficient due to this lack of coalescing.

We already ensure coalesced output stores in matmuls, when possible, by using smem for the epilogue (#387). A shared memory buffer is used to communicate elements between threads so that the resulting tensor has a proper global access pattern when it is written out to global memory as a tile of the output. Before this PR, if we used split-K with use_smem_epilogue = true, the store to global memory was coalesced, but there were uncoalesced accesses during the split-K reduction. This PR modifies scheduling so that in those cases the smem epilogue tensor is placed before the split-K sum, so that unswizzling happens before the reduction completes. The result is that the reduction accesses are coalesced.

This is a generated kernel from NVFuserTest.FusionAmpereMatmulSplitKBias_CUDA:

// ... (main loop) ...
     #pragma unroll
      for(nvfuser_index_t i59 = 0; i59 < 4LL; ++i59) {
        nvfuser_index_t i104;
        i104 = 8LL * i59;
        nvfuser_index_t i105;
        i105 = 32LL * i59;
        #pragma unroll
        for(nvfuser_index_t i61 = 0; i61 < 8LL; ++i61) {
          nvfuser_index_t i106;
          i106 = 4LL * i61;
          asm(
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
            :"=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]),
             "=f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3])
            :"r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[0]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[1]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[2]),
             "r"((*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T4[i104]))[3]),
             "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[0]),
             "r"((*reinterpret_cast<Array<uint32_t, 2, 1>*>(&T5[i106]))[1]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[0]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[1]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[2]),
             "f"((*reinterpret_cast<Array<float, 4, 1>*>(&T16[(i105 + i106)]))[3])
          );
        }
      }
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    __syncthreads();
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i107 = 0; i107 < 4LL; ++i107) {
    nvfuser_index_t i108;
    i108 = 32LL * i107;
    nvfuser_index_t i109;
    i109 = i38 + (2048LL * i107);
    #pragma unroll
    for(nvfuser_index_t i110 = 0; i110 < 8LL; ++i110) {
      nvfuser_index_t i111;
      i111 = i108 + (4LL * i110);
      nvfuser_index_t i112;
      i112 = i11 + i110;
      nvfuser_index_t i113;
      i113 = (i109 + (32LL * (i112 / 4LL))) + (8LL * (i39 ^ (i112 % 4LL)));
      #pragma unroll
      for(nvfuser_index_t i114 = 0; i114 < 2LL; ++i114) {
        loadGeneric<float, 2>( &T17[(i113 + (1024LL * i114))],  &T16[(i111 + (2LL * i114))]);
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  // Allocate global tensor T19
  grid_sync::blockSerializeWait<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i115 = 0; i115 < 32LL; ++i115) {
    nvfuser_index_t i116;
    i116 = i115 + nvfuser_zero;
    nvfuser_index_t i117;
    i117 = i44 + (i45 * i116);
    nvfuser_index_t i118;
    i118 = i47 + (4LL * i115);
    bool b119;
    b119 = i55 < (-(4LL * i116));
    bool b120;
    b120 = b54 && b119;
    Array<float, 4LL, 4> T6;
    T6.set(float(0.000000000e+00f));
    // Allocate global tensor T20
    reduction::serialReductionStep</*vec_size=*/4>(
      &T6[0LL],
      &T17[(i42 + (512LL * i115))],
      0.000000000e+00f,
      &T20[i117],
      [](float &a, float b) { a = a + b; },
      index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
      index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
      b120,
      b120);
    Array<float, 4LL, 4> T10;
    #pragma unroll
    for(nvfuser_index_t i121 = 0; i121 < 4LL; ++i121) {
      __half T18[1LL];
      T18[0LL] = 0LL;
      if (b119) {
        T18[0LL]
           = T2[(i118 + ((i48 + (i121 + nvfuser_zero)) / 128LL))];
      }
      __half T7[1LL];
      T7[0LL]
         = T18[0LL];
      float T8[1LL];
      T8[0LL]
         = __half2float(T7[0LL]);
      T10[i121]
        = T6[i121]
        + T8[0LL];
    }
    if ((b56 && b119)) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T9[i117], &T10[0LL]);
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  grid_sync::blockSerializeRelease<false, false, true>(&T19[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
}

Note that the i135 loop will be smaller once we have #1528, at which point it will more clearly show the reduction followed by the loop for the predicated bias epilogue.

(The diff should be viewed with whitespace changes hidden, since many changes are indentation-only.)

jacobhinkle and others added 22 commits December 8, 2023 17:46
Will revisit once sync pass is done, when we have a TensorIndex
Still missing allocation/indexing of work buffer
I need to replay leaf transforms, then get index.
Codegen is now like
```c++
  // Allocate global tensor T5
  reduction::serialReductionStep(
    T3[0LL],
    T2[(i14 + i18)],
    0.000000000e+00f,
    T5[((((((((((((nvfuser_index_t)blockIdx.x) * 8LL) + ((nvfuser_index_t)blockIdx.y)) * 4LL) + i13) * 8LL) + (i18 + nvfuser_zero)) * 4LL) + ((nvfuser_index_t)threadIdx.y)) * 32LL) + ((nvfuser_index_t)threadIdx.x))],
    [](float &a, float b) { a = a + b; },
    index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
    index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
    true,
    true);
```
This looks OK, although it will get a little better with hoisting. This
compiles, but I get an error in `runFusion`:
```
C++ exception with description "Expected T5_g[ iblockIdx.x59{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(262144, 32) ), 4) ), 8) ), 4) ), 8) )}, iblockIdx.y60{8}, ithreadIdx.y54{4}, ithreadIdx.x52{32}, iS58{4}, iS56{8}, rblockIdx.z49{5} ] to be bound to a tensor of rank 1, but got a tensor of rank 6
Exception raised from validateValWithConcreteValue at /opt/pytorch/nvfuser/csrc/expr_evaluator.cpp:38 (most recent call first):
```
I believe this happens when binding inputs.
Fixes execution error. Test passes!
Generated kernel now looks like
```c++
  // Allocate global tensor T4
  grid_sync::blockSerializeWait<false, false, true>(&T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  #pragma unroll
  for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
    nvfuser_index_t i14;
    i14 = 8LL * i13;
    nvfuser_index_t i15;
    i15 = 2048LL * i13;
    nvfuser_index_t i16;
    i16 = i4 + i15;
    nvfuser_index_t i17;
    i17 = -i15;
    #pragma unroll
    for(nvfuser_index_t i18 = 0; i18 < 8LL; ++i18) {
      nvfuser_index_t i19;
      i19 = 256LL * (i18 + nvfuser_zero);
      nvfuser_index_t i20;
      i20 = i16 + i19;
      float T3[1LL];
      T3[0LL] = 0.000000000e+00f;
      // Allocate global tensor T5
      reduction::serialReductionStep(
        T3[0LL],
        T2[(i14 + i18)],
        0.000000000e+00f,
        T5[i20],
        [](float &a, float b) { a = a + b; },
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        true,
        true);
      if ((b6 && (i5 < (i17 - i19)))) {
        T1[i20]
           = T3[0LL];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  grid_sync::blockSerializeRelease<false, false, true>(&T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
```
Note that the index `i20` matches the output `T1`. This is what we need
in order to reclaim `T1` in a later PR; the remaining challenge in that
work will be to exact-map `T5` and `T3` so that `T1` and `T5` become
exact mapped...
Also sort expected output by line to give clearer error messages.
@jacobhinkle
Collaborator Author

In tracking down some bugs, I noticed this:

int num_warp_k = cta_tile.k / warp_tile.k;

I tried to find an example where this is not equal to 1, but it seems like we never use a warp_tile.k that differs from cta_tile.k. Should we convert this to an assertion?

@jacobhinkle jacobhinkle force-pushed the matmul_serial_splitk branch from e5b68b4 to 5dc198c Compare January 8, 2024 17:49
We should not expect to re-use prologue smem when we have batch and
split-K: in that case the batch loop exists in the kernel, so there is
no smem re-use. Note that this _should_ be possible, but it requires a
block sync at the end of the batch loop, and we currently do not
support that pattern. See #1899
@jacobhinkle
Collaborator Author

!build --diff

@jacobhinkle jacobhinkle marked this pull request as ready for review March 12, 2024 13:27
@jacobhinkle
Collaborator Author

With this PR I am seeing the following generated kernel for HSH split-K matmul using smem epilogue:

    // main loop
  }
  #pragma unroll
  for(nvfuser_index_t i146 = 0; i146 < 4; ++i146) {
    nvfuser_index_t i147;
    i147 = 32 * i146;
    nvfuser_index_t i148;
    i148 = i50 + (2048LL * i146);
    #pragma unroll
    for(nvfuser_index_t i149 = 0; i149 < 4; ++i149) {
      nvfuser_index_t i150;
      i150 = i147 + (8 * i149);
      nvfuser_index_t i151;
      i151 = i24 + (2 * i149);
      #pragma unroll
      for(nvfuser_index_t i152 = 0; i152 < 2; ++i152) {
        nvfuser_index_t i153;
        i153 = i150 + (4LL * i152);
        nvfuser_index_t i154;
        i154 = i151 + i152;
        nvfuser_index_t i155;
        i155 = (i148 + (32LL * (i154 / 4))) + (8LL * (i51 ^ (i154 % 4)));
        #pragma unroll
        for(nvfuser_index_t i156 = 0; i156 < 2; ++i156) {
          loadGeneric<float, 2>( &T12[(i155 + (1024LL * i156))],  &T11[(i153 + (2LL * i156))]);
        }
      }
    }
  }
  // Allocate global tensor T13
  grid_sync::blockSerializeWait<false, false, true>(&T13[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i157 = 0; i157 < 16; ++i157) {
    nvfuser_index_t i158;
    i158 = i52 + (1024 * i157);
    nvfuser_index_t i159;
    i159 = i10 * i157;
    nvfuser_index_t i160;
    i160 = i55 + i159;
    nvfuser_index_t i161;
    i161 = 8 * i157;
    nvfuser_index_t i162;
    i162 = -i161;
    nvfuser_index_t i163;
    i163 = i67 + i161;
    Array<float, 8, 4> T5;
    #pragma unroll
    for(nvfuser_index_t i164 = 0; i164 < 2; ++i164) {
      T5.set(float(0.000000000e+00f));
    }
    if ((b64 && (i65 < i162))) {
      #pragma unroll
      for(nvfuser_index_t i164 = 0; i164 < 2; ++i164) {
        nvfuser_index_t i165;
        i165 = i53 + i164;
        nvfuser_index_t i166;
        i166 = i165 % 32;
        nvfuser_index_t i167;
        i167 = i166 / 2;
        nvfuser_index_t i168;
        i168 = i165 / 32;
        // Allocate global tensor T14
        reduction::serialReductionStep</*vec_size=*/4>(
          &T5[(4LL * i164)],
          &T12[((((i158 + (128LL * i168)) + (32LL * (i167 / 4))) + (4 * (i166 % 2))) + (8LL * ((i167 % 4) ^ ((i15 + i168) % 4))))],
          0.000000000e+00f,
          &T14[((i160 + (4 * i166)) + (T1.logical_size[1LL] * i168))],
          [](float &a, float b) { a = a + b; },
          index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
          index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
          true,
          true);
      }
    } else {
      #pragma unroll
      for(nvfuser_index_t i164 = 0; i164 < 2; ++i164) {
        nvfuser_index_t i169;
        i169 = 4LL * i164;
        nvfuser_index_t i170;
        i170 = i53 + i164;
        nvfuser_index_t i171;
        i171 = i170 % 32;
        nvfuser_index_t i172;
        i172 = i171 / 2;
        nvfuser_index_t i173;
        i173 = i170 / 32;
        nvfuser_index_t i174;
        i174 = i66 + i169;
        bool b175;
        b175 = ((i6 + (i174 % 128)) < T1.logical_size[1LL]) && ((i163 + (i174 / 128)) < T0.logical_size[0LL]);
        // Allocate global tensor T15
        reduction::serialReductionStep</*vec_size=*/4>(
          &T5[i169],
          &T12[((((i158 + (128LL * i173)) + (32LL * (i172 / 4))) + (4 * (i171 % 2))) + (8LL * ((i172 % 4) ^ ((i15 + i173) % 4))))],
          0.000000000e+00f,
          &T15[((i160 + (4 * i171)) + (T1.logical_size[1LL] * i173))],
          [](float &a, float b) { a = a + b; },
          index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
          index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
          b175,
          b175);
      }
    }
    Array<__half, 8, 8> T9;
    #pragma unroll
    for(nvfuser_index_t i176 = 0; i176 < 8; ++i176) {
      T9[i176]
         = __float2half(T5[i176]);
    }
    if ((b68 && (i69 < i162))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T6[(i57 + i159)], &T9[0]);
    }
  }
  grid_sync::blockSerializeRelease<false, false, true>(&T13[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
}

Two things are noticeable:

  1. The index simplification issue from #1827 ("Strengthen index simplification for cast epilogue matmul") is present. I think merging that branch will fix this and provide a speed-up for split-K, just as it did for epilogues in that PR.
  2. There is a predicate causing different reduction steps in the two branches, due to the write/read predicates. I think this is OK, but note that it presently means we use twice as much global workspace as we need, since each of these calls has a separate buffer. A fancier buffer-aliasing algorithm would detect this and re-use the buffer, but even better (cleaner-looking, at least) would be to somehow collect the predicate into the read/write predicates of a single call to serialReductionStep.

@jacobhinkle
Collaborator Author

This does have the expected perf impact:
[benchmark images]

.propagateToBoundary());
smem_epilogue->axis(-1)->parallelize(ParallelType::Vectorize);
if (num_splitk_dims != 0) {
splitk_sum->axis(-3)->parallelize(ParallelType::BIDz);
Collaborator

Could you move all the scheduling of splitk_sum into a single utility function? Also, I don't understand what this line is trying to do. Is it parallelizing the iNw as BIDz?

Collaborator Author

I refactored this PR to place all the scheduling of splitk_sum in a single function. That was a great idea as it showed me a few ways to simplify things. I think it's clearer now.

@jacobhinkle jacobhinkle requested a review from zasdfgbnm March 20, 2024 00:48
@@ -906,11 +954,26 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
splitk_sum = mma_result;
mma_result = splitk_sum->rFactor({-4, -1});
Collaborator

Is it possible to keep the split(-4) here but move the definition of splitk_sum and the rFactor call inside scheduleSplitKSum? That is, we would schedule the smem epilogue and apply the mma swizzle to mma_result first, then do the rFactor. This is just about readability and modularity: a reader would not need to think about split-K here, and could defer it until reading scheduleSplitKSum.

Collaborator Author

That is a good idea. We need to define smem_epilogue as mma_result->cacheAfter(), and since mma_result is the intermediate tensor in the split-K rfactor, we can't define smem_epilogue until we've done the split-K rfactor. However, I think we could work around that and hide both smem_epilogue and splitk_sum in a new scheduleEpilogue function. Since this PR is functioning, I will merge as-is (pending CI) and start a new PR to do that cleanup.

@jacobhinkle
Copy link
Collaborator Author

!build --diff-bench

@jacobhinkle jacobhinkle merged commit 7525cc0 into main Mar 22, 2024
@jacobhinkle jacobhinkle deleted the splitk_smem_epilogue branch March 22, 2024 12:01