@@ -582,6 +582,40 @@ def tiled_after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
582582 C [vi , vj ] = B [vi , vj ] + 1.0
583583
584584
# Input workload: two elementwise stages over buffers whose leading dimension
# has extent 1 and is always indexed with the constant 0.
# Block "B" runs under a tiled loop nest (8 x 8 tiles of 16 x 16 elements);
# block "C" runs under a plain 128 x 128 loop nest.
@T.prim_func
def tiled_trivial_binding(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (1, 128, 128), "float32")
    B = T.alloc_buffer((1, 128, 128), "float32")
    C = T.match_buffer(c, (1, 128, 128), "float32")
    # Producer stage: B = A * 2, iterated tile by tile.
    for tile_i, tile_j in T.grid(8, 8):
        for in_i, in_j in T.grid(16, 16):
            with T.block("B"):
                vi = T.axis.spatial(128, tile_i * 16 + in_i)
                vj = T.axis.spatial(128, tile_j * 16 + in_j)
                B[0, vi, vj] = A[0, vi, vj] * 2.0
    # Consumer stage: C = B + 1, iterated over the full 128 x 128 domain.
    for row in T.serial(0, 128):
        for col in T.serial(0, 128):
            with T.block("C"):
                vi = T.axis.spatial(128, row)
                vj = T.axis.spatial(128, col)
                C[0, vi, vj] = B[0, vi, vj] + 1.0
599+
600+
# Expected IR once block "C" has been moved under the third loop of block "B"'s
# tiled loop nest: both blocks share the three outer loops and each keeps its
# own innermost loop of extent 16. The leading unit dimension stays bound to 0.
@T.prim_func
def tiled_trivial_binding_after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (1, 128, 128), "float32")
    B = T.alloc_buffer((1, 128, 128), "float32")
    C = T.match_buffer(c, (1, 128, 128), "float32")
    for tile_i, tile_j in T.grid(8, 8):
        for in_i in T.serial(0, 16):
            # Producer: fill one row segment of the current tile in B.
            for in_j in T.serial(0, 16):
                with T.block("B"):
                    vi = T.axis.spatial(128, tile_i * 16 + in_i)
                    vj = T.axis.spatial(128, tile_j * 16 + in_j)
                    B[0, vi, vj] = A[0, vi, vj] * 2.0
            # Consumer: read back the same row segment right after it is produced.
            for out_j in T.serial(0, 16):
                with T.block("C"):
                    vi = T.axis.spatial(128, tile_i * 16 + in_i)
                    vj = T.axis.spatial(128, tile_j * 16 + out_j)
                    C[0, vi, vj] = B[0, vi, vj] + 1.0
617+
618+
585619@T .prim_func
586620def factorized (a : T .handle , b : T .handle ) -> None :
587621 A = T .match_buffer (a , [16 , 16 , 16 ], "float32" )
@@ -1149,6 +1183,15 @@ def test_reverse_compute_at_tiled():
11491183 verify_trace_roundtrip (sch = sch , mod = tiled )
11501184
11511185
def test_reverse_compute_at_tiled_trivial_binding():
    """Check reverse_compute_at of block "C" under the third loop of block "B".

    The scheduled module must be structurally equal to
    ``tiled_trivial_binding_after_reverse_compute_at``, and the recorded
    trace must round-trip against the original workload.
    """
    schedule = tir.Schedule(tiled_trivial_binding, debug_mask="all")
    consumer = schedule.get_block("C")
    # Unpacking four names also checks that block "B" sits under exactly four loops.
    _, _, target_loop, _ = schedule.get_loops(schedule.get_block("B"))
    schedule.reverse_compute_at(consumer, target_loop, preserve_unit_loops=False)
    tvm.ir.assert_structural_equal(
        tiled_trivial_binding_after_reverse_compute_at,
        schedule.mod["main"],
    )
    verify_trace_roundtrip(sch=schedule, mod=tiled_trivial_binding)
1193+
1194+
11521195def test_reverse_compute_at_blockized_2 ():
11531196 sch = tir .Schedule (blockized_2 , debug_mask = "all" )
11541197 block = sch .get_block ("C" )
0 commit comments