-
Notifications
You must be signed in to change notification settings - Fork 79
Closed
Labels
Description
I ran NVFUSER_DUMP=scheduler_params,cuda_kernel NVFUSER_DISABLE=aten_expr_eval NVFUSER_ENABLE=fuse_matmul pytest benchmarks/python/test_matmul.py -k 'torch.float16-16224_1936_56_TT' -vs and inspected the kernel and I was surprised to see an additional loop just before the output store:
__syncthreads();
#pragma unroll
for(nvfuser_index_t i162 = 0; i162 < 16; ++i162) {
nvfuser_index_t i163;
i163 = i51 + (1024 * i162);
Array<__half, 8, 8> T5;
#pragma unroll
for(nvfuser_index_t i164 = 0; i164 < 8; ++i164) {
nvfuser_index_t i165;
i165 = i163 + i164;
#pragma unroll 1
for(nvfuser_index_t i166 = 0; i166 < T0.logical_size[1LL]; ++i166) {
T5[i164]
= __float2half(T11[i165]);
}
}
if ((b54 && (i63 < (-(16 * i162))))) {
loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T2[(i53 + (i11 * i162))], &T5[0]);
}
}
Here T11 is the shared memory buffer which was filled with the unswizzled output in a previous loop. This code is "correct", but the i166 loop is large and unnecessary; its trip count is T0.logical_size[1LL] even though the body does not depend on i166 at all. We could hoist the body out of that loop and we'd have the kernel I expected to see.
Scheduler params used:
===== Matmul Parameters ========
MMA macro: Ampere_16_16_16
DoubleBufferOptions:
double_buffer_smem_write: true
double_buffer_smem_read: true
smem_double_buffer_stage: 3
SupportedVectorization:
a: 8
b: 8
epilogue: 8
MatMulTileOptions: instruction tile [16, 16, 16], warp tile [64, 64, 32], CTA tile [256, 64, 32]
Rotate ldmatrix out of main loop: true
Async global mem load: true
Indexing mode: int32_t
Tile rastrization order: row-major
Grid swizzle factor: 1
Use shared memory epilogue: 1
Promote re-use of prologue shared memory: 1
Split-K factor: 1
====================================
Relevant part of the schedule:
T11_s[ iblockIdx.x39{( ceilDiv(i0, 256) )}, iblockIdx.y41{( ceilDiv(i4, 64) )}, ithreadIdx.z270{( ceilDiv(( ( ceilDiv(256, 4) ) * 4 ), 64) )}, ithreadIdx.y272{( ceilDiv(( ( ( ceilDiv(( ceilDiv(64, 8) ), 4) ) * 4 ) * 8 ), 64) )}, iS274{( ceilDiv(64, 16) )}, iS276{( ceilDiv(64, 16) )}, ithreadIdx.x287{( ( ( ceilDiv(( ceilDiv(16, 8) ), 2) ) * 8 ) * ( ceilDiv(8, 2) ) )}, iS283{( ceilDiv(16, 8) )}, iS281{2}, iV286{2} ] ca_pos( 2 ) produce_pos( 2 )
= Set( T9_l[ iblockIdx.x31{( ceilDiv(i0, 256) )}, iblockIdx.y33{( ceilDiv(i4, 64) )}, rS35{( ceilDiv(i1, 32) )}, rS104{( ceilDiv(32, 16) )}, ithreadIdx.z96{( ceilDiv(( ( ceilDiv(256, 4) ) * 4 ), 64) )}, ithreadIdx.y98{( ceilDiv(( ( ( ceilDiv(( ceilDiv(64, 8) ), 4) ) * 4 ) * 8 ), 64) )}, iS100{( ceilDiv(64, 16) )}, iS102{( ceilDiv(64, 16) )}, ithreadIdx.x245{( ( ( ceilDiv(( ceilDiv(16, 8) ), 2) ) * 8 ) * ( ceilDiv(8, 2) ) )}, iMMA240{( ceilDiv(16, 8) )}, iMMA239{2}, iMMA243{2}, rMMA248{( ceilDiv(( ceilDiv(16, 2) ), 4) )}, rMMA249{4}, rMMA247{2} ] ca_pos( 2 ) produce_pos( 6 ), cache_op=Streaming )
T5_l[ iS79{( ceilDiv(i0, 256) )}, iS81{( ceilDiv(i4, 64) )}, iS304{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( 256 * 64 ), 8) ), 32) ), 1) ), 4) )}, iS305{4}, iS303{1}, iS301{32}, iS299{8}, rS6{i1} ] ca_pos( 6 ) produce_pos( 2 )
= __float2half(T11_s[ iblockIdx.x39{( ceilDiv(i0, 256) )}, iblockIdx.y41{( ceilDiv(i4, 64) )}, ithreadIdx.z270{( ceilDiv(( ( ceilDiv(256, 4) ) * 4 ), 64) )}, ithreadIdx.y272{( ceilDiv(( ( ( ceilDiv(( ceilDiv(64, 8) ), 4) ) * 4 ) * 8 ), 64) )}, iS274{( ceilDiv(64, 16) )}, iS276{( ceilDiv(64, 16) )}, ithreadIdx.x287{( ( ( ceilDiv(( ceilDiv(16, 8) ), 2) ) * 8 ) * ( ceilDiv(8, 2) ) )}, iS283{( ceilDiv(16, 8) )}, iS281{2}, iV286{2} ] ca_pos( 2 ) produce_pos( 2 ));
T2_g[ iblockIdx.x83{( ceilDiv(i0, 256) )}, iblockIdx.y85{( ceilDiv(i4, 64) )}, iS295{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( 256 * 64 ), 8) ), 32) ), 1) ), 4) )}, ithreadIdx.z296{4}, ithreadIdx.y294{1}, ithreadIdx.x292{32}, iV290{8} ] ca_pos( 6 ) produce_pos( 6 )
= Set( T5_l[ iS79{( ceilDiv(i0, 256) )}, iS81{( ceilDiv(i4, 64) )}, iS304{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( 256 * 64 ), 8) ), 32) ), 1) ), 4) )}, iS305{4}, iS303{1}, iS301{32}, iS299{8}, rS6{i1} ] ca_pos( 6 ) produce_pos( 2 ), cache_op=Streaming )
}
I can't yet see why we have this i166 loop.
Reactions are currently unavailable