@@ -388,8 +388,12 @@ def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]
388388 ]
389389 elif (dtype == 'fp16' or dtype == 'bf16' ) and tr_load == 't' :
390390 return [
391+ FmhaBwdDQDKDVTileSize ( 32 , 128 , 64 , 32 , 64 , 32 , 32 , 64 , 64 , 1 , 4 , 1 , 4 , 1 , 1 , 1 , 4 , 1 , 16 , 16 , 32 , 16 , 16 , 32 , 1 ),
391392 FmhaBwdDQDKDVTileSize ( 32 , 128 , 128 , 32 , 128 , 32 , 32 , 128 , 128 , 1 , 4 , 1 , 4 , 1 , 1 , 1 , 4 , 1 , 16 , 16 , 32 , 16 , 16 , 32 , 1 ),
392393 FmhaBwdDQDKDVTileSize ( 16 , 192 , 128 , 16 , 128 , 16 , 32 , 128 , 128 , 1 , 4 , 1 , 4 , 1 , 1 , 1 , 4 , 1 , 16 , 16 , 32 , 16 , 16 , 16 , 1 ),
394+
395+ # FmhaBwdDQDKDVTileSize( 32, 32, 64, 32, 64, 32, 32, 64, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, 1, 32),
396+ FmhaBwdDQDKDVTileSize ( 32 , 16 , 64 , 32 , 64 , 32 , 16 , 64 , 64 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 16 , 16 , 32 , 16 , 16 , 16 , 2 , 32 ),
393397 # FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
394398 FmhaBwdDQDKDVTileSize ( 16 , 16 , 128 , 16 , 128 , 16 , 16 , 128 , 128 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 16 , 16 , 32 , 16 , 16 , 16 , 2 , 16 ),
395399 ]
@@ -812,7 +816,9 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
812816 if ("wg32" in dropout ):
813817 continue
814818 if tr_load == "t" :
815- continue # tr_load cannot work with dpad or dvpad
819+ # tr_load can only work with 8 pad
820+ if dpad != dvpad or dpad == 1 :
821+ continue
816822 else : # tr_load == "f"
817823 # do not generate instance with only 1 of dpad/dvpad being 8
818824 if dpad != dvpad and dpad == 8 :
0 commit comments