From 3d02af91c61e7f7d2d12defa21e9fe4278bea0ae Mon Sep 17 00:00:00 2001 From: Noah Verke Date: Mon, 14 Nov 2022 09:12:01 -0800 Subject: [PATCH 1/2] [TIR] Add test to cover specific case of reducer match buffer checking --- ..._transform_lower_cross_thread_reduction.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index 3ab09f01dd01..de1bab7d7d53 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -548,6 +548,49 @@ def single_reduction_loop_with_tensorize( ) +@T.prim_func +def nested_reduction_loop_with_inner_match_buffers( + in0: T.Buffer[(4, 16), "int8"], + in1: T.Buffer[(4, 16), "int8"], + out: T.Buffer[(4, 4), "int32"], +) -> None: + # body + # with T.block("root") + for y in T.serial(4): + with T.block("C"): + yi = T.axis.spatial(4, y) + T.reads(in0[yi, 0:16], in1[yi, 0:16]) + T.writes(out[yi, 0:4]) + for x in T.serial(4): + with T.block("C"): + xr = T.axis.reduce(4, x) + with T.init(): + for i in T.serial(4): + with T.block("C_init"): + ii = T.axis.spatial(4, i) + T.reads() + T.writes(out[yi, ii]) + out[yi, ii] = 0 + with T.block("C"): + T.reads( + out[yi, xr], + in0[yi, yi * 4 + xr : yi * 4 + xr + 4], + in1[yi, yi * 4 + xr : yi * 4 + xr + 4], + ) + T.writes(out[yi, xr]) + A = T.match_buffer( + in0[yi, yi * 4 + xr : yi * 4 + xr + 4], [4], dtype="int8", offset_factor=1 + ) + B = T.match_buffer( + in1[yi, yi * 4 + xr : yi * 4 + xr + 4], [4], dtype="int8", offset_factor=1 + ) + C = T.match_buffer(out[yi, xr], [1], dtype="int32", offset_factor=1) + A_i8x4: T.int8x4 = A[0:4] + A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32") + B_i8x4: T.int8x4 = B[0:4] + B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32") + C[0] = A_i32 + B_i32 + C[0] + @T.prim_func def reducer_max(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32") @@ -1247,6 +1290,13 @@ def test_single_reduction_loop_with_tensorize(): ) +def test_nested_reduction_loop_with_inner_match_buffers(): + _check( + nested_reduction_loop_with_inner_match_buffers, + nested_reduction_loop_with_inner_match_buffers, + ) + + def test_reducer_max(): _check(reducer_max, lowered_reducer_max) From 55b9b404d5bc108b957066472bf045505ac9a552 Mon Sep 17 00:00:00 2001 From: Noah Verke Date: Tue, 15 Nov 2022 13:55:31 -0800 Subject: [PATCH 2/2] lint --- ...test_tir_transform_lower_cross_thread_reduction.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py index de1bab7d7d53..2bf898e66b08 100644 --- a/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py +++ b/tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py @@ -579,10 +579,16 @@ def nested_reduction_loop_with_inner_match_buffers( ) T.writes(out[yi, xr]) A = T.match_buffer( - in0[yi, yi * 4 + xr : yi * 4 + xr + 4], [4], dtype="int8", offset_factor=1 + in0[yi, yi * 4 + xr : yi * 4 + xr + 4], + [4], + dtype="int8", + offset_factor=1, ) B = T.match_buffer( - in1[yi, yi * 4 + xr : yi * 4 + xr + 4], [4], dtype="int8", offset_factor=1 + in1[yi, yi * 4 + xr : yi * 4 + xr + 4], + [4], + dtype="int8", + offset_factor=1, ) C = T.match_buffer(out[yi, xr], [1], dtype="int32", offset_factor=1) A_i8x4: T.int8x4 = A[0:4] @@ -591,6 +597,7 @@ def nested_reduction_loop_with_inner_match_buffers( B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32") C[0] = A_i32 + B_i32 + C[0] + @T.prim_func def reducer_max(a: T.handle, b: T.handle) -> None: A = T.match_buffer(a, [128, 128], dtype="float32")