Conversation
This popped up in internal benchmarking. It doesn't seem to happen on TOT (top of tree), so something about this branch is breaking it.
This fixes the problematic case while avoiding bank conflicts for HSS fusions. HSH fusions still have bank conflicts, but that might be acceptable for now as it seems to still provide a speedup over the status quo.
Summary: I now think the root of the issue with HSH fusions is the swizzling of the epilogue smem buffer. This limitation is actually referred to in a comment in Fuser/csrc/scheduler/matmul.cpp (lines 153 to 155 at ca29b31). I think the best solution is for me to understand and extend this function to better support the 16-bit epilogue case. More detail: here is the generated code: // main loop
}
#pragma unroll
for(nvfuser_index_t i118 = 0; i118 < 2; ++i118) {
nvfuser_index_t i119;
i119 = 32 * i118;
nvfuser_index_t i120;
i120 = i57 + (2048LL * i118);
#pragma unroll
for(nvfuser_index_t i121 = 0; i121 < 8; ++i121) {
nvfuser_index_t i122;
i122 = i119 + (4 * i121);
nvfuser_index_t i123;
i123 = i11 + i121;
nvfuser_index_t i124;
i124 = (i120 + (32LL * (i123 / 4))) + (8LL * (i58 ^ (i123 % 4)));
#pragma unroll
for(nvfuser_index_t i125 = 0; i125 < 2; ++i125) {
loadGeneric<float, 2>( &T7[(i124 + (1024LL * i125))], &T6[(i122 + (2LL * i125))]);
}
}
}
__syncthreads();
#pragma unroll
for(nvfuser_index_t i126 = 0; i126 < 16; ++i126) {
if ((b73 && (i74 < (-(4LL * i126))))) {
loadGeneric<float, 4>( &T3[(i64 + (i65 * i126))], &T7[(i61 + (512LL * i126))]);
}
}
} Notice that swizzle indexing is only present in the middle loop. The global stores are fully vectorized at width 4 with no expensive indexing ops, but note that they are strided by 512. I think that stride means [sentence truncated in the original]. Now have a look at the half-precision-output case: // main loop
}
#pragma unroll
for(nvfuser_index_t i123 = 0; i123 < 2; ++i123) {
nvfuser_index_t i124;
i124 = 32 * i123;
nvfuser_index_t i125;
i125 = i56 + (2048LL * i123);
#pragma unroll
for(nvfuser_index_t i126 = 0; i126 < 8; ++i126) {
nvfuser_index_t i127;
i127 = i124 + (4 * i126);
nvfuser_index_t i128;
i128 = i11 + i126;
nvfuser_index_t i129;
i129 = (i125 + (32LL * (i128 / 4))) + (8LL * (i57 ^ (i128 % 4)));
#pragma unroll
for(nvfuser_index_t i130 = 0; i130 < 2; ++i130) {
loadGeneric<float, 2>( &T8[(i129 + (1024LL * i130))], &T3[(i127 + (2LL * i130))]);
}
}
}
__syncthreads();
#pragma unroll
for(nvfuser_index_t i131 = 0; i131 < 8; ++i131) {
nvfuser_index_t i132;
i132 = i58 + (1024 * i131);
Array<__half, 8, 8> T7;
#pragma unroll
for(nvfuser_index_t i133 = 0; i133 < 8; ++i133) {
nvfuser_index_t i134;
i134 = i59 + i133;
nvfuser_index_t i135;
i135 = i134 % 128;
nvfuser_index_t i136;
i136 = i135 / 8;
nvfuser_index_t i137;
i137 = i134 / 128;
T7[i133]
= __float2half(T8[((((i132 + (128LL * i137)) + (32LL * (i136 / 4))) + (i135 % 8)) + (8LL * ((i136 % 4) ^ ((i60 + i137) % 4))))]);
}
if ((b73 && (i74 < (-(8 * i131))))) {
loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T4[(i63 + (i64 * i131))], &T7[0]);
}
}
} The global write is fully vectorized at width 8 now, but preparing the register buffer T7 requires the expensive swizzle indexing shown above. To fix this, we could extend swizzling to support 16-bit outputs; we may also want to keep the current changes in this branch as they reduce the size of the smem buffer by computing the epilogue before writing to smem, but we should see how that affects vectorization of the final loop. We could instead change the schedule for the output tensor in order to vectorize the writes at width 4, but I think that would lead to at most 64-byte coalescing instead of the full 128 bytes.
|
I would expect a relu epilogue to not display any real problems here since it is still an HSS matmul. However, this is the code generated on main: // main loop
}
#pragma unroll
for(nvfuser_index_t i125 = 0; i125 < 2; ++i125) {
nvfuser_index_t i126;
i126 = 32 * i125;
nvfuser_index_t i127;
i127 = i57 + (2048LL * i125);
#pragma unroll
for(nvfuser_index_t i128 = 0; i128 < 8; ++i128) {
nvfuser_index_t i129;
i129 = i126 + (4 * i128);
nvfuser_index_t i130;
i130 = i11 + i128;
nvfuser_index_t i131;
i131 = (i127 + (32LL * (i130 / 4))) + (8LL * (i58 ^ (i130 % 4)));
#pragma unroll
for(nvfuser_index_t i132 = 0; i132 < 2; ++i132) {
loadGeneric<float, 2>( &T8[(i131 + (1024LL * i132))], &T3[(i129 + (2LL * i132))]);
}
}
}
__syncthreads();
#pragma unroll
for(nvfuser_index_t i133 = 0; i133 < 16; ++i133) {
nvfuser_index_t i134;
i134 = i59 + (512LL * i133);
Array<float, 4, 4> T7;
#pragma unroll
for(nvfuser_index_t i135 = 0; i135 < 4; ++i135) {
nvfuser_index_t i136;
i136 = i60 + i135;
nvfuser_index_t i137;
i137 = i136 % 128;
nvfuser_index_t i138;
i138 = i137 / 8;
nvfuser_index_t i139;
i139 = i136 / 128;
T7[i135]
= relu(T8[((((i134 + (128LL * i139)) + (32LL * (i138 / 4))) + (i137 % 8)) + (8LL * ((i138 % 4) ^ ((i62 + i139) % 4))))]);
}
if ((b75 && (i76 < (-(4LL * i133))))) {
loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T4[(i65 + (i66 * i133))], &T7[0]);
}
}
} versus with this branch: // main loop
}
#pragma unroll
for(nvfuser_index_t i165 = 0; i165 < 2; ++i165) {
nvfuser_index_t i166;
i166 = 32 * i165;
nvfuser_index_t i167;
i167 = i57 + (2048LL * i165);
#pragma unroll
for(nvfuser_index_t i168 = 0; i168 < 8; ++i168) {
nvfuser_index_t i169;
i169 = i166 + (4 * i168);
nvfuser_index_t i170;
i170 = i11 + i168;
nvfuser_index_t i171;
i171 = (i167 + (32LL * (i170 / 4))) + (8LL * (i58 ^ (i170 % 4)));
#pragma unroll
for(nvfuser_index_t i172 = 0; i172 < 2; ++i172) {
nvfuser_index_t i173;
i173 = i169 + (2LL * i172);
Array<float, 2, 2> T7;
#pragma unroll
for(nvfuser_index_t i174 = 0; i174 < 2; ++i174) {
T7[i174]
= relu(T5[(i173 + i174)]);
}
loadGeneric<float, 2>( &T8[(i171 + (1024LL * i172))], &T7[0]);
}
}
}
__syncthreads();
#pragma unroll
for(nvfuser_index_t i175 = 0; i175 < 16; ++i175) {
if ((b74 && (i75 < (-(4LL * i175))))) {
loadGeneric<float, 4>( &T6[(i64 + (i65 * i175))], &T8[(i61 + (512LL * i175))]);
}
}
} This looks better to me. I'm not sure why the original needs complicated indexing when reading from the smem buffer.
This came up when working on #1770. In a private conversation, @zasdfgbnm noticed wisely that the problematic indexing is really a failure of expression simplification; if we could fully simplify the swizzling expression it could be entirely hoisted and we would be left with a nice clean linear index for the smem buffer in the epilogue loop. This is `NVFuserTest.FusionAmpereMatmulSmemEpilogueCast_CUDA` on `main`: ```c++ // main loop } __syncthreads(); #pragma unroll for(nvfuser_index_t i123 = 0; i123 < 4; ++i123) { nvfuser_index_t i124; i124 = 32 * i123; nvfuser_index_t i125; i125 = i56 + (2048LL * i123); #pragma unroll for(nvfuser_index_t i126 = 0; i126 < 8; ++i126) { nvfuser_index_t i127; i127 = i124 + (4 * i126); nvfuser_index_t i128; i128 = i11 + i126; nvfuser_index_t i129; i129 = (i125 + (32LL * (i128 / 4))) + (8LL * (i57 ^ (i128 % 4))); #pragma unroll for(nvfuser_index_t i130 = 0; i130 < 2; ++i130) { loadGeneric<float, 2>( &T8[(i129 + (1024LL * i130))], &T3[(i127 + (2LL * i130))]); } } } __syncthreads(); #pragma unroll for(nvfuser_index_t i131 = 0; i131 < 16; ++i131) { nvfuser_index_t i132; i132 = i58 + (1024 * i131); Array<__half, 8, 8> T7; #pragma unroll for(nvfuser_index_t i133 = 0; i133 < 8; ++i133) { nvfuser_index_t i134; i134 = i59 + i133; nvfuser_index_t i135; i135 = i134 % 128; nvfuser_index_t i136; i136 = i135 / 8; nvfuser_index_t i137; i137 = i134 / 128; T7[i133] = __float2half(T8[((((i132 + (128LL * i137)) + (32LL * (i136 / 4))) + (i135 % 8)) + (8LL * ((i136 % 4) ^ ((i31 + i137) % 4))))]); } if ((b72 && (i73 < (-(8 * i131))))) { loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T4[(i62 + (i63 * i131))], &T7[0]); } } } ``` This PR: ```c++ // main loop } __syncthreads(); #pragma unroll for(nvfuser_index_t i114 = 0; i114 < 4; ++i114) { nvfuser_index_t i115; i115 = 32 * i114; nvfuser_index_t i116; i116 = i50 + (2048LL * i114); #pragma unroll for(nvfuser_index_t i117 = 0; i117 < 8; ++i117) { nvfuser_index_t i118; i118 = i115 + (4 * 
i117); nvfuser_index_t i119; i119 = i12 + i117; nvfuser_index_t i120; i120 = (i116 + (32LL * (i119 / 4))) + (8LL * (i51 ^ (i119 % 4))); #pragma unroll for(nvfuser_index_t i121 = 0; i121 < 2; ++i121) { loadGeneric<float, 2>( &T7[(i120 + (1024LL * i121))], &T2[(i118 + (2LL * i121))]); } } } __syncthreads(); #pragma unroll for(nvfuser_index_t i122 = 0; i122 < 16; ++i122) { nvfuser_index_t i123; i123 = i53 + (1024 * i122); Array<__half, 8, 8> T6; #pragma unroll for(nvfuser_index_t i124 = 0; i124 < 8; ++i124) { T6[i124] = __float2half(T7[(i123 + i124)]); } if ((b67 && (i68 < (-(8 * i122))))) { loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T3[(i56 + (i57 * i122))], &T6[0]); } } } ``` ~~If we can also get `i134 % 8` simplified to `i134` and `i134 / 8` simplified to 0 then this should give a nice and efficient last loop.~~ This is done ~~Currently this PR is super slow (e.g. 101 s vs 8 s on main in debug mode) due to the added recursion. Memoizing past results would be beneficial, but that's a topic for another PR.~~ This PR is no longer slow, thanks to limited recursion depth and #1972. Fixes #1828
|
This slowdown was addressed by #1827.
The matmul scheduler has the ability to use smem to "unswizzle" the result of the matrix multiplication, which is held in registers. If enabled, this means we insert a loop between the main loop and the epilogue loop in which we populate a shared memory buffer that holds values from the register result. Then the epilogue loop can traverse the output sequentially, loading values from the new smem buffer and computing the epilogue before a coalesced store to gmem.
When we have an HSH fusion, the output tile will be half the size of the float-valued register buffer. The idea behind this PR is simply to go ahead and compute the epilogue in the second loop so that we have a half-precision smem buffer.
Before:
After:
Status
As you can see above, currently the smem buffer is not swizzled so we hit bank conflicts. However, removing the bank conflict check shows that performance is good already. A non-split-K HSH TN benchmark of size 16384,16384,128 showed the following perf on this branch using a 3090 TI:
This is not due to an occupancy change, as we use 2 blocks per SM in each case.