Skip to content

Place epilogue in unswizzling loop#1770

Closed
jacobhinkle wants to merge 3 commits into main from
epilogue_before_smem
Closed

Place epilogue in unswizzling loop#1770
jacobhinkle wants to merge 3 commits into main from
epilogue_before_smem

Conversation

@jacobhinkle
Copy link
Collaborator

@jacobhinkle jacobhinkle commented Feb 15, 2024

The matmul scheduler has the ability to use smem to "unswizzle" the result of the matrix multiplication, which is held in registers. If enabled, this means we insert a loop between the main loop and the epilogue loop in which we populate a shared memory buffer that holds values from the register result. Then the epilogue loop can traverse the output sequentially, loading values from the new smem buffer and computing the epilogue before a coalesced store to gmem.

When we have an HSH fusion, the output tile will be half the size of the float-valued register buffer. The idea behind this PR is simply to go ahead and compute the epilogue in the second loop so that we have a half-precision smem buffer.

Before:

  // ... main loop ...
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i163 = 0; i163 < 4; ++i163) {
    nvfuser_index_t i164;
    i164 = 32 * i163;
    nvfuser_index_t i165;
    i165 = i49 + (2048LL * i163);
    #pragma unroll
    for(nvfuser_index_t i166 = 0; i166 < 4; ++i166) {
      nvfuser_index_t i167;
      i167 = i164 + (8 * i166);
      nvfuser_index_t i168;
      i168 = i12 + (2 * i166);
      #pragma unroll
      for(nvfuser_index_t i169 = 0; i169 < 2; ++i169) {
        nvfuser_index_t i170;
        i170 = i167 + (4LL * i169);
        nvfuser_index_t i171;
        i171 = i168 + i169;
        nvfuser_index_t i172;
        i172 = (i165 + (32LL * (i171 / 4))) + (8LL * (i50 ^ (i171 % 4)));
        #pragma unroll
        for(nvfuser_index_t i173 = 0; i173 < 2; ++i173) {
          loadGeneric<float, 2>( &T7[(i172 + (1024LL * i173))],  &T4[(i170 + (2LL * i173))]);
        }
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i174 = 0; i174 < 16; ++i174) {
    nvfuser_index_t i175;
    i175 = i51 + (1024 * i174);
    Array<__half, 8, 8> T6;
    #pragma unroll
    for(nvfuser_index_t i176 = 0; i176 < 8; ++i176) {
      nvfuser_index_t i177;
      i177 = i52 + i176;
      nvfuser_index_t i178;
      i178 = i177 % 128;
      nvfuser_index_t i179;
      i179 = i178 / 8;
      nvfuser_index_t i180;
      i180 = i177 / 128;
      T6[i176]
         = __float2half(T7[((((i175 + (128LL * i180)) + (32LL * (i179 / 4))) + (i178 % 8)) + (8LL * ((i179 % 4) ^ ((i53 + i180) % 4))))]);
    }
    if ((b68 && (i69 < (-(8 * i174))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T5[(i57 + (i58 * i174))], &T6[0]);
    }
  }
}

After:

  // ... main loop ...
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i157 = 0; i157 < 4; ++i157) {
    nvfuser_index_t i158;
    i158 = 32 * i157;
    nvfuser_index_t i159;
    i159 = i46 + (2048LL * i157);
    #pragma unroll
    for(nvfuser_index_t i160 = 0; i160 < 4; ++i160) {
      nvfuser_index_t i161;
      i161 = i158 + (8 * i160);
      nvfuser_index_t i162;
      i162 = i159 + (16 * i160);
      #pragma unroll
      for(nvfuser_index_t i163 = 0; i163 < 2; ++i163) {
        nvfuser_index_t i164;
        i164 = i161 + (4LL * i163);
        nvfuser_index_t i165;
        i165 = i162 + (8 * i163);
        #pragma unroll
        for(nvfuser_index_t i166 = 0; i166 < 2; ++i166) {
          nvfuser_index_t i167;
          i167 = i164 + (2LL * i166);
          Array<__half, 2, 2> T6;
          #pragma unroll
          for(nvfuser_index_t i168 = 0; i168 < 2; ++i168) {
            T6[i168]
               = __float2half(T4[(i167 + i168)]);
          }
          loadGeneric<__half, 2>( &T7[(i165 + (1024LL * i166))],  &T6[0]);
        }
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i169 = 0; i169 < 16; ++i169) {
    if ((b62 && (i63 < (-(8 * i169))))) {
      loadGeneric<__half, 8>( &T5[(i51 + (i52 * i169))],  &T7[(i47 + (1024 * i169))]);
    }
  }
}

Status

As you can see above, currently the smem buffer is not swizzled so we hit bank conflicts. However, removing the bank conflict check shows that performance is good already. A non-split-K HSH TN benchmark of size 16384,16384,128 showed the following perf on this branch using a 3090 TI:

  • no smem: 1506 us
  • smem epilogue (TOT): 1992 us
  • smem epilogue (this branch): 1386 us

This is not due to occupancy change as we use 2 blocks per SM in each case.

This popped up in internal benchmarking. It doesn't seem to happen on
TOT so there's something about this branch that's breaking.
This fixes the problematic case while avoiding bank conflicts for HSS
fusions. HSH fusions still have bank conflicts, but that might be
acceptable for now as it seems to still provide a speedup over the
status quo.
@jacobhinkle
Copy link
Collaborator Author

jacobhinkle commented Feb 21, 2024

Summary

I now think the root of the issue with HSH fusions is the swizzling of the epilogue smem buffer. This limitation is actually referred to in a comment in swizzleSharedMemory:

// Only tested for (1) ldmatrix access with sizeof(T) == 16bit (i.e.
// half/bfloat16) and (2) epilogue general access with sizeof(T) == 32bit
// (i.e. float)

I think the best solution is for me to understand and extend this function to better support the 16-bit epilogue case.

More detail

Here is FusionAmpereMatmulSmemEpilogue_CUDA:

    // main loop
  }
  #pragma unroll
  for(nvfuser_index_t i118 = 0; i118 < 2; ++i118) {
    nvfuser_index_t i119;
    i119 = 32 * i118;
    nvfuser_index_t i120;
    i120 = i57 + (2048LL * i118);
    #pragma unroll
    for(nvfuser_index_t i121 = 0; i121 < 8; ++i121) {
      nvfuser_index_t i122;
      i122 = i119 + (4 * i121);
      nvfuser_index_t i123;
      i123 = i11 + i121;
      nvfuser_index_t i124;
      i124 = (i120 + (32LL * (i123 / 4))) + (8LL * (i58 ^ (i123 % 4)));
      #pragma unroll
      for(nvfuser_index_t i125 = 0; i125 < 2; ++i125) {
        loadGeneric<float, 2>( &T7[(i124 + (1024LL * i125))],  &T6[(i122 + (2LL * i125))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i126 = 0; i126 < 16; ++i126) {
    if ((b73 && (i74 < (-(4LL * i126))))) {
      loadGeneric<float, 4>( &T3[(i64 + (i65 * i126))],  &T7[(i61 + (512LL * i126))]);
    }
  }
}

Notice that swizzle indexing is only present in the middle loop. The global stores are fully vectorized at width 4 with no expensive indexing ops, but note that they are strided by 512. I think that stride means T7 does not have enough contiguity that this write could be vectorized at width 8 even if the hardware supported it.

Now have a look at FusionAmpereMatmulSmemEpilogueCast_CUDA:

    // main loop
  }                                                                                                                              
  #pragma unroll                                                                                                                 
  for(nvfuser_index_t i123 = 0; i123 < 2; ++i123) {                                                                              
    nvfuser_index_t i124;                                                                                                        
    i124 = 32 * i123;                                                                                                            
    nvfuser_index_t i125;                                                                                                        
    i125 = i56 + (2048LL * i123);                                                                                                
    #pragma unroll                                                                                                               
    for(nvfuser_index_t i126 = 0; i126 < 8; ++i126) {                                                                            
      nvfuser_index_t i127;
      i127 = i124 + (4 * i126);
      nvfuser_index_t i128;
      i128 = i11 + i126;
      nvfuser_index_t i129;
      i129 = (i125 + (32LL * (i128 / 4))) + (8LL * (i57 ^ (i128 % 4)));
      #pragma unroll
      for(nvfuser_index_t i130 = 0; i130 < 2; ++i130) {
        loadGeneric<float, 2>( &T8[(i129 + (1024LL * i130))],  &T3[(i127 + (2LL * i130))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i131 = 0; i131 < 8; ++i131) {
    nvfuser_index_t i132;
    i132 = i58 + (1024 * i131);
    Array<__half, 8, 8> T7;
    #pragma unroll
    for(nvfuser_index_t i133 = 0; i133 < 8; ++i133) {
      nvfuser_index_t i134;
      i134 = i59 + i133;
      nvfuser_index_t i135;
      i135 = i134 % 128;
      nvfuser_index_t i136;
      i136 = i135 / 8;
      nvfuser_index_t i137;
      i137 = i134 / 128;
      T7[i133]
         = __float2half(T8[((((i132 + (128LL * i137)) + (32LL * (i136 / 4))) + (i135 % 8)) + (8LL * ((i136 % 4) ^ ((i60 + i137) % 4))))]);
    }
    if ((b73 && (i74 < (-(8 * i131))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T4[(i63 + (i64 * i131))], &T7[0]);
    }
  }
}

The global write is fully vectorized at width 8 now, but preparing the register T7 requires some fancy indexing. I think that's because of how T7 is written.

To fix this, we could extend swizzling to support 16-bit outputs; we may also want to keep the current changes in this branch as they reduce the size of the smem buffer by computing the epilogue before writing to smem, but we should see how that affects vectorization of the final loop. We could instead change the schedule for the output tensor in order to vectorize the writes at width 4, but I think that would lead to at most 64-byte coalescing instead of the full 128 bytes.

cc @zasdfgbnm @liqiangxl

@jacobhinkle
Copy link
Collaborator Author

jacobhinkle commented Feb 21, 2024

I would expect a relu epilogue to not display any real problems here since it is still an HSS matmul. However, this is FusionAmpereMatmulSmemEpilogueRelu_CUDA on main:

    // main loop
  }
  #pragma unroll
  for(nvfuser_index_t i125 = 0; i125 < 2; ++i125) {
    nvfuser_index_t i126;
    i126 = 32 * i125;
    nvfuser_index_t i127;
    i127 = i57 + (2048LL * i125);
    #pragma unroll
    for(nvfuser_index_t i128 = 0; i128 < 8; ++i128) {
      nvfuser_index_t i129;
      i129 = i126 + (4 * i128);
      nvfuser_index_t i130;
      i130 = i11 + i128;
      nvfuser_index_t i131;
      i131 = (i127 + (32LL * (i130 / 4))) + (8LL * (i58 ^ (i130 % 4)));
      #pragma unroll
      for(nvfuser_index_t i132 = 0; i132 < 2; ++i132) {
        loadGeneric<float, 2>( &T8[(i131 + (1024LL * i132))],  &T3[(i129 + (2LL * i132))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i133 = 0; i133 < 16; ++i133) {
    nvfuser_index_t i134;
    i134 = i59 + (512LL * i133);
    Array<float, 4, 4> T7;
    #pragma unroll
    for(nvfuser_index_t i135 = 0; i135 < 4; ++i135) {
      nvfuser_index_t i136;
      i136 = i60 + i135;
      nvfuser_index_t i137;
      i137 = i136 % 128;
      nvfuser_index_t i138;
      i138 = i137 / 8;
      nvfuser_index_t i139;
      i139 = i136 / 128;
      T7[i135]
         = relu(T8[((((i134 + (128LL * i139)) + (32LL * (i138 / 4))) + (i137 % 8)) + (8LL * ((i138 % 4) ^ ((i62 + i139) % 4))))]);
    }
    if ((b75 && (i76 < (-(4LL * i133))))) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T4[(i65 + (i66 * i133))], &T7[0]);
    }
  }
}

versus with this branch:

    // main loop
  }
  #pragma unroll
  for(nvfuser_index_t i165 = 0; i165 < 2; ++i165) {
    nvfuser_index_t i166;
    i166 = 32 * i165;
    nvfuser_index_t i167;
    i167 = i57 + (2048LL * i165);
    #pragma unroll
    for(nvfuser_index_t i168 = 0; i168 < 8; ++i168) {
      nvfuser_index_t i169;
      i169 = i166 + (4 * i168);
      nvfuser_index_t i170;
      i170 = i11 + i168;
      nvfuser_index_t i171;
      i171 = (i167 + (32LL * (i170 / 4))) + (8LL * (i58 ^ (i170 % 4)));
      #pragma unroll
      for(nvfuser_index_t i172 = 0; i172 < 2; ++i172) {
        nvfuser_index_t i173;
        i173 = i169 + (2LL * i172);
        Array<float, 2, 2> T7;
        #pragma unroll
        for(nvfuser_index_t i174 = 0; i174 < 2; ++i174) {
          T7[i174]
             = relu(T5[(i173 + i174)]);
        }
        loadGeneric<float, 2>( &T8[(i171 + (1024LL * i172))],  &T7[0]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i175 = 0; i175 < 16; ++i175) {
    if ((b74 && (i75 < (-(4LL * i175))))) {
      loadGeneric<float, 4>( &T6[(i64 + (i65 * i175))],  &T8[(i61 + (512LL * i175))]);
    }
  }
}

This looks better to me. I'm not sure why the original needs complicated indexing when reading T8 in the last loop.

jacobhinkle added a commit that referenced this pull request Mar 22, 2024
This came up when working on #1770. In a private conversation,
@zasdfgbnm noticed wisely that the problematic indexing is really a
failure of expression simplification; if we could fully simplify the
swizzling expression it could be entirely hoisted and we would be left
with a nice clean linear index for the smem buffer in the epilogue loop.

This is `NVFuserTest.FusionAmpereMatmulSmemEpilogueCast_CUDA` on `main`:
```c++
    // main loop
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) {
    nvfuser_index_t i124;
    i124 = 32 * i123;
    nvfuser_index_t i125;
    i125 = i56 + (2048LL * i123);
    #pragma unroll
    for(nvfuser_index_t i126 = 0; i126 < 8; ++i126) {
      nvfuser_index_t i127;
      i127 = i124 + (4 * i126);
      nvfuser_index_t i128;
      i128 = i11 + i126;
      nvfuser_index_t i129;
      i129 = (i125 + (32LL * (i128 / 4))) + (8LL * (i57 ^ (i128 % 4)));
      #pragma unroll
      for(nvfuser_index_t i130 = 0; i130 < 2; ++i130) {
        loadGeneric<float, 2>( &T8[(i129 + (1024LL * i130))],  &T3[(i127 + (2LL * i130))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i131 = 0; i131 < 16; ++i131) {
    nvfuser_index_t i132;
    i132 = i58 + (1024 * i131);
    Array<__half, 8, 8> T7;
    #pragma unroll
    for(nvfuser_index_t i133 = 0; i133 < 8; ++i133) {
      nvfuser_index_t i134;
      i134 = i59 + i133;
      nvfuser_index_t i135;
      i135 = i134 % 128;
      nvfuser_index_t i136;
      i136 = i135 / 8;
      nvfuser_index_t i137;
      i137 = i134 / 128;
      T7[i133]
         = __float2half(T8[((((i132 + (128LL * i137)) + (32LL * (i136 / 4))) + (i135 % 8)) + (8LL * ((i136 % 4) ^ ((i31 + i137) % 4))))]);
    }
    if ((b72 && (i73 < (-(8 * i131))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T4[(i62 + (i63 * i131))], &T7[0]);
    }
  }
}
```
This PR:
```c++
    // main loop
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i114 = 0; i114 < 4; ++i114) {
    nvfuser_index_t i115;
    i115 = 32 * i114;
    nvfuser_index_t i116;
    i116 = i50 + (2048LL * i114);
    #pragma unroll
    for(nvfuser_index_t i117 = 0; i117 < 8; ++i117) {
      nvfuser_index_t i118;
      i118 = i115 + (4 * i117);
      nvfuser_index_t i119;
      i119 = i12 + i117;
      nvfuser_index_t i120;
      i120 = (i116 + (32LL * (i119 / 4))) + (8LL * (i51 ^ (i119 % 4)));
      #pragma unroll
      for(nvfuser_index_t i121 = 0; i121 < 2; ++i121) {
        loadGeneric<float, 2>( &T7[(i120 + (1024LL * i121))],  &T2[(i118 + (2LL * i121))]);
      }
    }
  }
  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i122 = 0; i122 < 16; ++i122) {
    nvfuser_index_t i123;
    i123 = i53 + (1024 * i122);
    Array<__half, 8, 8> T6;
    #pragma unroll
    for(nvfuser_index_t i124 = 0; i124 < 8; ++i124) {
      T6[i124]
         = __float2half(T7[(i123 + i124)]);
    }
    if ((b67 && (i68 < (-(8 * i122))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T3[(i56 + (i57 * i122))], &T6[0]);
    }
  }
}
```
~~If we can also get `i134 % 8` simplified to `i134` and `i134 / 8`
simplified to 0 then this should give a nice and efficient last loop.~~
This is done

~~Currently this PR is super slow (e.g. 101 s vs 8 s on main in debug
mode) due to the added recursion. Memoizing past results would be
beneficial, but that's a topic for another PR.~~ This PR is no longer
slow, thanks to limited recursion depth and #1972.

Fixes #1828
@jacobhinkle
Copy link
Collaborator Author

This slowdown was addressed by #1827

@zasdfgbnm zasdfgbnm deleted the epilogue_before_smem branch March 22, 2024 18:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant