
Confusing extra epilogue casting loop between lds and stg #2404

@jacobhinkle

Description


I ran the following and inspected the kernel:

NVFUSER_DUMP=scheduler_params,cuda_kernel NVFUSER_DISABLE=aten_expr_eval NVFUSER_ENABLE=fuse_matmul pytest benchmarks/python/test_matmul.py -k 'torch.float16-16224_1936_56_TT' -vs

I was surprised to see an additional loop just before the output store:

  __syncthreads();
  #pragma unroll
  for(nvfuser_index_t i162 = 0; i162 < 16; ++i162) {
    nvfuser_index_t i163;
    i163 = i51 + (1024 * i162);
    Array<__half, 8, 8> T5;
    #pragma unroll
    for(nvfuser_index_t i164 = 0; i164 < 8; ++i164) {
      nvfuser_index_t i165;
      i165 = i163 + i164;
      #pragma unroll 1
      for(nvfuser_index_t i166 = 0; i166 < T0.logical_size[1LL]; ++i166) {
        T5[i164]
           = __float2half(T11[i165]);
      }
    }
    if ((b54 && (i63 < (-(16 * i162))))) {
      loadLocalToGlobal<__half, /*vec_size=*/8, /*is_volatile=*/false>( &T2[(i53 + (i11 * i162))], &T5[0]);
    }
  }

Here T11 is the shared-memory buffer that was filled with the unswizzled output in a previous loop. This code is "correct", but the i166 loop is large and unnecessary: its body never reads i166, so we could hoist the contents out of the loop and we'd have the kernel I expected to see.
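Since the body of the i166 loop does not read the loop index, removing it is plain loop-invariant code motion. A minimal host-side C++ analogue of the before/after (function and variable names are illustrative, not taken from the kernel or from nvFuser):

```cpp
#include <cassert>

// Analogue of the generated i166 loop: the body ignores the loop
// index, so all K iterations perform the same assignment.
inline float cast_with_redundant_loop(const float* smem, int idx, int K) {
    float out = 0.0f;
    for (int k = 0; k < K; ++k) {  // mirrors "for (i166 = 0; i166 < ...; ++i166)"
        out = smem[idx];           // body is independent of k
    }
    return out;
}

// The expected codegen after hoisting: a single assignment, no loop.
inline float cast_hoisted(const float* smem, int idx) {
    return smem[idx];
}
```

For any trip count K >= 1 the two agree, which is why the loop can be removed without changing the stored result.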

Scheduler params used:

===== Matmul Parameters ========

MMA macro: Ampere_16_16_16
DoubleBufferOptions:
  double_buffer_smem_write: true
  double_buffer_smem_read: true
  smem_double_buffer_stage: 3
SupportedVectorization:
  a: 8
  b: 8
  epilogue: 8
MatMulTileOptions: instruction tile [16, 16, 16], warp tile [64, 64, 32], CTA tile [256, 64, 32]
Rotate ldmatrix out of main loop: true
Async global mem load: true
Indexing mode: int32_t
Tile rastrization order: row-major
Grid swizzle factor: 1
Use shared memory epilogue: 1
Promote re-use of prologue shared memory: 1
Split-K factor: 1
====================================

Relevant part of the schedule:

T11_s[ iblockIdx.x39{( ceilDiv(i0, 256) )}, iblockIdx.y41{( ceilDiv(i4, 64) )}, ithreadIdx.z270{( ceilDiv(( ( ceilDiv(256, 4) ) * 4 ), 64) )}, ithreadIdx.y272{( ceilDiv(( ( ( ceilDiv(( ceilDiv(64, 8) ), 4) ) * 4 ) * 8 ), 64) )}, iS274{( ceilDiv(64, 16) )}, iS276{( ceilDiv(64, 16) )}, ithreadIdx.x287{( ( ( ceilDiv(( ceilDiv(16, 8) ), 2) ) * 8 ) * ( ceilDiv(8, 2) ) )}, iS283{( ceilDiv(16, 8) )}, iS281{2}, iV286{2} ] ca_pos( 2 ) produce_pos( 2 )
   = Set( T9_l[ iblockIdx.x31{( ceilDiv(i0, 256) )}, iblockIdx.y33{( ceilDiv(i4, 64) )}, rS35{( ceilDiv(i1, 32) )}, rS104{( ceilDiv(32, 16) )}, ithreadIdx.z96{( ceilDiv(( ( ceilDiv(256, 4) ) * 4 ), 64) )}, ithreadIdx.y98{( ceilDiv(( ( ( ceilDiv(( ceilDiv(64, 8) ), 4) ) * 4 ) * 8 ), 64) )}, iS100{( ceilDiv(64, 16) )}, iS102{( ceilDiv(64, 16) )}, ithreadIdx.x245{( ( ( ceilDiv(( ceilDiv(16, 8) ), 2) ) * 8 ) * ( ceilDiv(8, 2) ) )}, iMMA240{( ceilDiv(16, 8) )}, iMMA239{2}, iMMA243{2}, rMMA248{( ceilDiv(( ceilDiv(16, 2) ), 4) )}, rMMA249{4}, rMMA247{2} ] ca_pos( 2 ) produce_pos( 6 ), cache_op=Streaming )
T5_l[ iS79{( ceilDiv(i0, 256) )}, iS81{( ceilDiv(i4, 64) )}, iS304{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( 256 * 64 ), 8) ), 32) ), 1) ), 4) )}, iS305{4}, iS303{1}, iS301{32}, iS299{8}, rS6{i1} ] ca_pos( 6 ) produce_pos( 2 )
   = __float2half(T11_s[ iblockIdx.x39{( ceilDiv(i0, 256) )}, iblockIdx.y41{( ceilDiv(i4, 64) )}, ithreadIdx.z270{( ceilDiv(( ( ceilDiv(256, 4) ) * 4 ), 64) )}, ithreadIdx.y272{( ceilDiv(( ( ( ceilDiv(( ceilDiv(64, 8) ), 4) ) * 4 ) * 8 ), 64) )}, iS274{( ceilDiv(64, 16) )}, iS276{( ceilDiv(64, 16) )}, ithreadIdx.x287{( ( ( ceilDiv(( ceilDiv(16, 8) ), 2) ) * 8 ) * ( ceilDiv(8, 2) ) )}, iS283{( ceilDiv(16, 8) )}, iS281{2}, iV286{2} ] ca_pos( 2 ) produce_pos( 2 ));
T2_g[ iblockIdx.x83{( ceilDiv(i0, 256) )}, iblockIdx.y85{( ceilDiv(i4, 64) )}, iS295{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( 256 * 64 ), 8) ), 32) ), 1) ), 4) )}, ithreadIdx.z296{4}, ithreadIdx.y294{1}, ithreadIdx.x292{32}, iV290{8} ] ca_pos( 6 ) produce_pos( 6 )
   = Set( T5_l[ iS79{( ceilDiv(i0, 256) )}, iS81{( ceilDiv(i4, 64) )}, iS304{( ceilDiv(( ceilDiv(( ceilDiv(( ceilDiv(( 256 * 64 ), 8) ), 32) ), 1) ), 4) )}, iS305{4}, iS303{1}, iS301{32}, iS299{8}, rS6{i1} ] ca_pos( 6 ) produce_pos( 2 ), cache_op=Streaming )
}
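As an aside, the outer store loop's shape is consistent with the schedule; only the i166 loop is extra. The numbers below are read off the dumps above, but the arithmetic is my own:

```cpp
// Sanity check relating the scheduler params to the epilogue store loop.
constexpr int cta_m = 256, cta_n = 64;  // CTA tile [256, 64, 32]
constexpr int threads = 4 * 1 * 32;     // ithreadIdx.z296 * .y294 * .x292 on T2_g
constexpr int vec = 8;                  // iV290{8}: 8-wide vectorized store

constexpr int elements = cta_m * cta_n;     // halfs stored per CTA tile
constexpr int per_iter = threads * vec;     // elements stored per loop iteration
constexpr int iters = elements / per_iter;  // serial iterations needed

static_assert(iters == 16, "matches the i162 < 16 bound in the kernel");
static_assert(per_iter == 1024, "matches the 1024 * i162 index stride");
```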

I can't yet see why we have this i166 loop.


Labels: Matmuls, question
