diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 3ab09f01dd01..2bf898e66b08 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -548,6 +548,56 @@ def single_reduction_loop_with_tensorize( ) +@T.prim_func +def nested_reduction_loop_with_inner_match_buffers( + in0: T.Buffer[(4, 16), "int8"], + in1: T.Buffer[(4, 16), "int8"], + out: T.Buffer[(4, 4), "int32"], +) -> None: + # body + # with T.block("root") + for y in T.serial(4): + with T.block("C"): + yi = T.axis.spatial(4, y) + T.reads(in0[yi, 0:16], in1[yi, 0:16]) + T.writes(out[yi, 0:4]) + for x in T.serial(4): + with T.block("C"): + xr = T.axis.reduce(4, x) + with T.init(): + for i in T.serial(4): + with T.block("C_init"): + ii = T.axis.spatial(4, i) + T.reads() + T.writes(out[yi, ii]) + out[yi, ii] = 0 + with T.block("C"): + T.reads( + out[yi, xr], + in0[yi, yi * 4 + xr : yi * 4 + xr + 4], + in1[yi, yi * 4 + xr : yi * 4 + xr + 4], + ) + T.writes(out[yi, xr]) + A = T.match_buffer( + in0[yi, yi * 4 + xr : yi * 4 + xr + 4], + [4], + dtype="int8", + offset_factor=1, + ) + B = T.match_buffer( + in1[yi, yi * 4 + xr : yi * 4 + xr + 4], + [4], + dtype="int8", + offset_factor=1, + ) + C = T.match_buffer(out[yi, xr], [1], dtype="int32", offset_factor=1) + A_i8x4: T.int8x4 = A[0:4] + A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32") + B_i8x4: T.int8x4 = B[0:4] + B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32") + C[0] = A_i32 + B_i32 + C[0] + + @T.prim_func def reducer_max(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") @@ -1247,6 +1297,13 @@ def test_single_reduction_loop_with_tensorize(): ) +def test_nested_reduction_loop_with_inner_match_buffers(): + _check( + nested_reduction_loop_with_inner_match_buffers, + nested_reduction_loop_with_inner_match_buffers, + ) + + def test_reducer_max(): _check(reducer_max, lowered_reducer_max)