Skip to content

Grid reduction with serialized blocks#1405

Merged
jacobhinkle merged 36 commits into main from serial_grid_reduce
Dec 8, 2023
Merged

Grid reduction with serialized blocks#1405
jacobhinkle merged 36 commits into main from serial_grid_reduce

Conversation

@jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Nov 29, 2023

This PR introduces kernel IR nodes for the syncs that need to occur before and after a loop containing a serial grid reduction. That grid reduction can be inlined with other computation in a loop nest, and the sync nodes will be placed around the outer loop in the generated kernel. The kir::GridReduction node itself is modified to have an attribute available via bool kir::GridReduction::isSerial() const indicating whether this is a serial grid reduction. This PR tests that codegen is correct.

Default CUDA kernel for the included test:

// Default (non-serialized) grid-reduction kernel generated for the included
// test. T0 is the 2-D input, T1 the reduction output; T4 and T5 are global
// buffers handed to reduction::gridReduce (T4 as the float work buffer, T5
// presumably as the int64 sync-flag buffer -- confirm against the gridReduce
// signature in the runtime headers).
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T4, Tensor<int64_t, 1, 1> T5) {
  // Dynamic shared memory, passed to gridReduce as scratch space below.
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO;
  // Hoisted per-thread index arithmetic: base offsets into the flattened
  // tensors, built from threadIdx/blockIdx and constant strides.
  nvfuser_index_t i0;
  i0 = 32LL * ((nvfuser_index_t)threadIdx.y);
  nvfuser_index_t i1;
  i1 = 32768LL * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i2;
  i2 = 262144LL * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i3;
  i3 = ((((2097152LL * ((nvfuser_index_t)blockIdx.z)) + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  nvfuser_index_t i4;
  i4 = ((((nvfuser_index_t)threadIdx.x) + i0) + i1) + i2;
  nvfuser_index_t i5;
  i5 = (((-2097152LL + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  // b6: true only in the last block along the z grid dimension; it gates the
  // final write of the reduced value into T1.
  bool b6;
  b6 = ((nvfuser_index_t)blockIdx.z) == (((nvfuser_index_t)gridDim.z) + -1LL);
  // Allocate global tensor T4
  // Allocate global tensor T5
  // Per-thread local buffer of partial values, zero-initialized before loads.
  float T2[128LL];
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      T2[(i7 + (32LL * i8))] = 0.000000000e+00f;
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  // Predicated load of this thread's input tile from T0 into T2.
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    nvfuser_index_t i9;
    i9 = 256LL * i7;
    nvfuser_index_t i10;
    i10 = i3 + i9;
    nvfuser_index_t i11;
    i11 = -i9;
    // NOTE(review): "#pragma unroll" had been fused onto the end of the
    // previous statement line in the pasted output; a preprocessor directive
    // must start on its own line (compare the serial kernel below).
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      nvfuser_index_t i12;
      i12 = 8192LL * (i8 + nvfuser_zero);
      if ((i5 < (i11 - i12))) {
        T2[(i7 + (32LL * i8))]
           = T0[(i10 + i12)];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  // Grid reduction: each iteration reduces one element across the grid via
  // reduction::gridReduce into T3; only the last z-block (b6), and only
  // in-bounds lanes, write the final value out to T1.
  #pragma unroll
  for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
    nvfuser_index_t i14;
    i14 = 32LL * i13;
    nvfuser_index_t i15;
    i15 = 8192LL * i13;
    nvfuser_index_t i16;
    i16 = i4 + i15;
    nvfuser_index_t i17;
    i17 = -i15;
    #pragma unroll
    for(nvfuser_index_t i18 = 0; i18 < 32LL; ++i18) {
      nvfuser_index_t i19;
      i19 = 256LL * (i18 + nvfuser_zero);
      float T3[1LL];
      T3[0LL] = 0.000000000e+00f;
      reduction::gridReduce<false, false, true, false, false, false, false, true>(
        T3[0LL],
        T2[(i14 + i18)],
        [](float &a, float b) { a = a + b; },
        &T4[0],
        &T5[0],
        static_cast<float*>(shared_mem),
        true,
        true,
        float(0.000000000e+00f),
        ((i13 * 32LL) + i18),
        128LL);
      if ((b6 && (i5 < (i17 - i19)))) {
        T1[(i16 + i19)]
           = T3[0LL];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
}

The serial reduction kernel looks like this:

// Serialized grid-reduction kernel for the same fusion. Compared with the
// default kernel above, the call to reduction::gridReduce is replaced by
// reduction::serialReductionStep accumulating into the global work buffer T6,
// and the whole reduction loop nest is bracketed by
// grid_sync::blockSerializeWait / blockSerializeRelease on the sync buffer T5
// so blocks take turns executing it (presumably serialized along the z
// dimension, matching the <false, false, true, ...> template flags --
// confirm against grid_sync in the runtime headers).
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T6, Tensor<int64_t, 1, 1> T5) {
  // Dynamic shared memory (unused by serialReductionStep, kept by codegen).
  alignas(16) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO;
  // Hoisted per-thread index arithmetic, identical to the default kernel.
  nvfuser_index_t i0;
  i0 = 32LL * ((nvfuser_index_t)threadIdx.y);
  nvfuser_index_t i1;
  i1 = 32768LL * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i2;
  i2 = 262144LL * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i3;
  i3 = ((((2097152LL * ((nvfuser_index_t)blockIdx.z)) + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  nvfuser_index_t i4;
  i4 = ((((nvfuser_index_t)threadIdx.x) + i0) + i1) + i2;
  nvfuser_index_t i5;
  i5 = (((-2097152LL + ((nvfuser_index_t)threadIdx.x)) + i0) + i1) + i2;
  // b6: true only in the last block along the z grid dimension; gates the
  // final write of the reduced value into T1.
  bool b6;
  b6 = ((nvfuser_index_t)blockIdx.z) == (((nvfuser_index_t)gridDim.z) + -1LL);
  // Allocate global tensor T6
  // Allocate global tensor T5
  // Per-thread local buffer of partial values, zero-initialized before loads.
  float T2[128LL];
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      T2[(i7 + (32LL * i8))] = 0.000000000e+00f;
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  // Predicated load of this thread's input tile from T0 into T2.
  #pragma unroll
  for(nvfuser_index_t i7 = 0; i7 < 32LL; ++i7) {
    nvfuser_index_t i9;
    i9 = 256LL * i7;
    nvfuser_index_t i10;
    i10 = i3 + i9;
    nvfuser_index_t i11;
    i11 = -i9;
    #pragma unroll
    for(nvfuser_index_t i8 = 0; i8 < 4LL; ++i8) {
      nvfuser_index_t i12;
      i12 = 8192LL * (i8 + nvfuser_zero);
      if ((i5 < (i11 - i12))) {
        T2[(i7 + (32LL * i8))]
           = T0[(i10 + i12)];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  // Entry sync: wait on this (x, y) position's flag in T5 before entering the
  // serialized region. This is the kir node placed before the outer loop.
  grid_sync::blockSerializeWait<false, false, true, false>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  #pragma unroll
  for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
    nvfuser_index_t i14;
    i14 = 32LL * i13;
    nvfuser_index_t i15;
    i15 = 8192LL * i13;
    nvfuser_index_t i16;
    i16 = i4 + i15;
    nvfuser_index_t i17;
    i17 = -i15;
    #pragma unroll
    for(nvfuser_index_t i18 = 0; i18 < 32LL; ++i18) {
      nvfuser_index_t i19;
      i19 = 256LL * (i18 + nvfuser_zero);
      float T3[1LL];
      T3[0LL] = 0.000000000e+00f;
      // One serial reduction step: fold this block's partial T2 value into
      // the global work buffer T6, leaving the running total in T3. The two
      // maskedOffset/maskedSize comparisons flag the first and last block
      // along the serialized (z) dimension.
      reduction::serialReductionStep(
        T3[0LL],
        T2[(i14 + i18)],
        0.000000000e+00f,
        T6[(i16 + i19)],
        [](float &a, float b) { a = a + b; },
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        true,
        true);
      // Only the last z-block (b6), for in-bounds lanes, writes the final
      // reduced value out to T1.
      if ((b6 && (i5 < (i17 - i19)))) {
        T1[(i16 + i19)]
           = T3[0LL];
      }
    }
  }
  // Exit sync: release the flag so the next block in the serial order may
  // enter the region. This is the kir node placed after the outer loop.
  grid_sync::blockSerializeRelease<false, false, true, false>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  NVFUSER_UPDATE_MAGIC_ZERO;
}

What is not included in this PR

There is no automatic scheduling or lowering to serial reductions in this PR. The included test works via a post-lowering hook in FusionExecutor to simply test that we can codegen the nodes properly once they are manually placed.

There is also no re-use of global buffers currently, so this is not yet an "in-place" reduction. I.e. we must manually allocate a work buffer that is the full size of the grid reduction output at this time. In the future, we can avoid the need for that workspace by aliasing an output buffer.

The work buffer must currently be the same dtype as the reduction element. In the future, we could relax this in order to cast to lower precision in the work buffer. This would enable us to re-use the global memory allocated for TST and HSH matmul output, at the expense of a small loss in precision.

Related to #1316 and #991.

@jacobhinkle jacobhinkle changed the title Serial grid reduce Serial grid reduction Nov 29, 2023
@jacobhinkle jacobhinkle changed the title Serial grid reduction Grid reduction with serialized blocks Nov 29, 2023
@Priya2698 Priya2698 mentioned this pull request Dec 4, 2023
@jacobhinkle jacobhinkle marked this pull request as ready for review December 6, 2023 15:14
I think this will have negligible perf impact by adding a sync in the
last block after the last loop. The upside is that we can keep the sync
buffer clean (i.e. all zeros), which at some point might help us to
re-use sync buffers, removing one memset kernel launch per execution.
@jacobhinkle
Copy link
Collaborator Author

!build

Copy link
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jacobhinkle jacobhinkle merged commit 34c7fa4 into main Dec 8, 2023
@jacobhinkle jacobhinkle deleted the serial_grid_reduce branch December 8, 2023 17:40
jacobhinkle added a commit that referenced this pull request Jan 23, 2024
This change enables `ReductionOp`s to be lowered as serial reductions
(see #1405) if requested during scheduling.

1. At scheduling, a `ReductionOp` is modified by calling its
`requestSerialGridReduction()` method. The output tensor can be
scheduled before or after this method call, and should result in the op
having all its reduction axes parallelized as grid dimensions.
2. Early in lowering, we find `ReductionOp`s having
`serialGridReductionRequested() == true`, and we place syncs around
their outer loop. At this point, we also analyze that outer loop to
determine if there are any conflicting expressions, such as conflicting
grid reductions.
3. Later in lowering, during the indexing pass, we convert the
`ReductionOp` to a `GridReduction` that has its serial buffer set. The
serial buffer is a temporary `TensorIndex` indexed like a global memory
version of the reduction output tensor.

The generated kernel looks like this:
```c++
  // Allocate global tensor T4
  grid_sync::blockSerializeWait<false, false, true>(&T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  #pragma unroll
  for(nvfuser_index_t i13 = 0; i13 < 4LL; ++i13) {
    nvfuser_index_t i14;
    i14 = 8LL * i13;
    nvfuser_index_t i15;
    i15 = 2048LL * i13;
    nvfuser_index_t i16;
    i16 = i4 + i15;
    nvfuser_index_t i17;
    i17 = -i15;
    #pragma unroll
    for(nvfuser_index_t i18 = 0; i18 < 8LL; ++i18) {
      nvfuser_index_t i19;
      i19 = 256LL * (i18 + nvfuser_zero);
      nvfuser_index_t i20;
      i20 = i16 + i19;
      float T3[1LL];
      T3[0LL] = 0.000000000e+00f;
      // Allocate global tensor T5
      reduction::serialReductionStep(
        T3[0LL],
        T2[(i14 + i18)],
        0.000000000e+00f,
        T5[i20],
        [](float &a, float b) { a = a + b; },
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        true,
        true);
      if ((b6 && (i5 < (i17 - i19)))) {
        T1[i20]
           = T3[0LL];
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  grid_sync::blockSerializeRelease<false, false, true>(&T4[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
```
Notice that the index `i20` now matches between the output `T1` and the
intermediate `T5`. In another PR, I will attempt to extend our buffer
reuse machinery to recognize this as a chance to use `T1` in place of
`T5` (i.e. inner aliasing, in-place reduction).

Also notice that I have not yet hoisted the sync flags index, or the
`first_block` and `last_block` predicates.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants