diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index d05e0a6e9216..45e8eb0f68c6 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -24,8 +24,6 @@
 from tvm.script import tir as T
 from numpy.random import default_rng
 
-from tvm.tir.function import TensorIntrin
-
 VRMPY_SIZE_B = 128
 VRMPY_SIZE_INT32 = 32
 
@@ -72,9 +70,23 @@ def operator(a: T.handle, b: T.handle, c: T.handle) -> None:
     return tvm.tir.Schedule(operator)
 
 
-def evaluate(hexagon_session, sch, a, b, size_a, expected_output, use_async_copy=0):
+def evaluate(
+    hexagon_session,
+    sch,
+    a,
+    b,
+    c,
+    expected_output=None,
+    use_async_copy=0,
+    merge_async_commit_queue_scope=False,
+):
     target_hexagon = tvm.target.hexagon("v68", link_params=True)
-    with tvm.transform.PassContext(config={"tir.use_async_copy": use_async_copy}):
+    with tvm.transform.PassContext(
+        config={
+            "tir.use_async_copy": use_async_copy,
+            "tir.merge_async_commit_queue_scope": merge_async_commit_queue_scope,
+        }
+    ):
         func_tir = tvm.build(
             sch.mod["main"], target=tvm.target.Target(target_hexagon, host=target_hexagon)
         )
@@ -82,9 +94,7 @@ def evaluate(hexagon_session, sch, a, b, size_a, expected_output, use_async_copy
 
     a_hexagon = tvm.runtime.ndarray.array(a, device=hexagon_session.device)
     b_hexagon = tvm.runtime.ndarray.array(b, device=hexagon_session.device)
-    c_hexagon = tvm.runtime.ndarray.array(
-        np.zeros((size_a, VRMPY_SIZE_INT32), dtype="int32"), device=hexagon_session.device
-    )
+    c_hexagon = tvm.runtime.ndarray.array(c, device=hexagon_session.device)
 
     if tvm.testing.utils.IS_IN_CI:
         # Run with reduced number and repeat for CI
@@ -93,7 +103,8 @@ def evaluate(hexagon_session, sch, a, b, size_a, expected_output, use_async_copy
 
     timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10)
    time = timer(a_hexagon, b_hexagon, c_hexagon)
-    tvm.testing.assert_allclose(c_hexagon.asnumpy(), expected_output)
+    if expected_output is not None:
+        tvm.testing.assert_allclose(c_hexagon.asnumpy(), expected_output)
 
     return round(time.mean * 1000, 4)
 
@@ -252,9 +263,32 @@ def get_fake_conv_vtcm_schedule(size_a, size_w, blocks=2):
     sch.compute_at(cache_read_block_a, no)
     sch.fuse(*sch.get_loops(cache_read_block_a)[1:])
 
-    cache_read_block_c = sch.cache_write(compute_block, 0, "global.vtcm")
-    sch.reverse_compute_at(cache_read_block_c, no)
-    sch.fuse(*sch.get_loops(cache_read_block_c)[1:])
+    cache_write_block_c = sch.cache_write(compute_block, 0, "global.vtcm")
+    sch.reverse_compute_at(cache_write_block_c, no)
+    sch.fuse(*sch.get_loops(cache_write_block_c)[1:])
+
+    return sch
+
+
+def get_multi_input_fake_conv_vtcm_schedule(size_a, size_w, blocks=2):
+    sch = conv_approximation(size_a, size_w)
+
+    compute_block = sch.get_block("C")
+
+    n = sch.get_loops(compute_block)[0]
+    no, _ = sch.split(n, [blocks, None])
+
+    cache_read_block_a = sch.cache_read(compute_block, 0, "global.vtcm")
+    sch.compute_at(cache_read_block_a, no)
+    sch.fuse(*sch.get_loops(cache_read_block_a)[1:])
+
+    cache_read_block_b = sch.cache_read(compute_block, 1, "global.vtcm")
+    sch.compute_at(cache_read_block_b, no)
+    sch.fuse(*sch.get_loops(cache_read_block_b)[1:])
+
+    cache_write_block_c = sch.cache_write(compute_block, 0, "global.vtcm")
+    sch.reverse_compute_at(cache_write_block_c, no)
+    sch.fuse(*sch.get_loops(cache_write_block_c)[1:])
 
     return sch
 
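A note on the `evaluate` refactor above: the helper now takes the output buffer `c` from the caller instead of allocating one internally from `size_a`, verification only runs when `expected_output` is supplied, and `merge_async_commit_queue_scope` is threaded through to the `tir.merge_async_commit_queue_scope` pass config. A minimal sketch of the new calling convention (the session, schedule, and input arrays are assumed to come from the surrounding tests):

import numpy as np

VRMPY_SIZE_INT32 = 32
size_a = 1024

# The caller now owns the output buffer that evaluate() previously built
# internally from size_a.
c = np.zeros((size_a, VRMPY_SIZE_INT32), dtype="int32")

# Timing-only run (no correctness check), as test_meta does below:
#     evaluate(hexagon_session, sch, input_a, input_w, c, use_async_copy=1)
# Verified run, as in test_loading_vtcm_for_vrmpy:
#     evaluate(hexagon_session, sch, input_a, input_w, c, expected_output)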
@@ -271,13 +305,12 @@ class TestAsyncDMAPipeline:
     size_a = tvm.testing.parameter(
         1024,
         64 * 64,
-        128 * 128,
+        128 * 64,
     )
 
     size_w = tvm.testing.parameter(
         1 * 1,
         3 * 3,
-        7 * 7,
         9 * 9,
     )
 
@@ -296,11 +329,24 @@ def test_loading_vtcm_for_vrmpy(
             pytest.skip("Skipping test since it takes too long in CI.")
 
         sch = conv_approximation(size_a, size_w)
-        base_runtime = evaluate(hexagon_session, sch, input_a, input_w, size_a, expected_output)
+        base_runtime = evaluate(
+            hexagon_session,
+            sch,
+            input_a,
+            input_w,
+            np.zeros(expected_output.shape, "int32"),
+            expected_output,
+        )
 
         sch = get_fake_conv_vtcm_schedule(size_a, size_w)
         base_vtcm_runtime = evaluate(
-            hexagon_session, sch, input_a, input_w, size_a, expected_output, use_async_copy=1
+            hexagon_session,
+            sch,
+            input_a,
+            input_w,
+            np.zeros(expected_output.shape, "int32"),
+            expected_output,
+            use_async_copy=1,
         )
 
         sch = get_fake_conv_vtcm_schedule(size_a, size_w)
@@ -309,7 +355,13 @@ def test_loading_vtcm_for_vrmpy(
         sch.annotate(n, "software_pipeline_order", [0, 1, 2])
         sch.annotate(n, "software_pipeline_async_stages", [0])
         async_input_runtime = evaluate(
-            hexagon_session, sch, input_a, input_w, size_a, expected_output, use_async_copy=1
+            hexagon_session,
+            sch,
+            input_a,
+            input_w,
+            np.zeros(expected_output.shape, "int32"),
+            expected_output,
+            use_async_copy=1,
         )
 
         sch = get_fake_conv_vtcm_schedule(size_a, size_w)
@@ -318,7 +370,44 @@ def test_loading_vtcm_for_vrmpy(
         sch.annotate(n, "software_pipeline_order", [0, 1, 2])
         sch.annotate(n, "software_pipeline_async_stages", [0, 2])
         async_input_output_runtime = evaluate(
-            hexagon_session, sch, input_a, input_w, size_a, expected_output, use_async_copy=1
+            hexagon_session,
+            sch,
+            input_a,
+            input_w,
+            np.zeros(expected_output.shape, "int32"),
+            expected_output,
+            use_async_copy=1,
+        )
+
+        sch = get_fake_conv_vtcm_schedule(size_a, size_w)
+        n = sch.get_loops(sch.get_block("C"))[0]
+        sch.annotate(n, "software_pipeline_stage", [0, 3, 6])
+        sch.annotate(n, "software_pipeline_order", [0, 1, 2])
+        sch.annotate(n, "software_pipeline_async_stages", [0, 6])
+        async_input_output_runtime_larger_buffers = evaluate(
+            hexagon_session,
+            sch,
+            input_a,
+            input_w,
+            np.zeros(expected_output.shape, "int32"),
+            expected_output,
+            use_async_copy=1,
+        )
+
+        sch = get_multi_input_fake_conv_vtcm_schedule(size_a, size_w)
+        n = sch.get_loops(sch.get_block("C"))[0]
+        sch.annotate(n, "software_pipeline_stage", [0, 0, 1, 2])
+        sch.annotate(n, "software_pipeline_order", [0, 1, 2, 3])
+        sch.annotate(n, "software_pipeline_async_stages", [0, 2])
+        async_multi_input_output_runtime = evaluate(
+            hexagon_session,
+            sch,
+            input_a,
+            input_w,
+            np.zeros(expected_output.shape, "int32"),
+            expected_output,
+            use_async_copy=1,
+            merge_async_commit_queue_scope=False,
         )
 
         sch = get_fake_conv_vtcm_schedule(size_a, size_w)
@@ -327,12 +416,23 @@ def test_loading_vtcm_for_vrmpy(
         sch.annotate(n, "software_pipeline_order", [0, 1, 2])
         sch.annotate(n, "software_pipeline_async_stages", [2])
         async_output_runtime = evaluate(
-            hexagon_session, sch, input_a, input_w, size_a, expected_output, use_async_copy=1
+            hexagon_session,
+            sch,
+            input_a,
+            input_w,
+            np.zeros(expected_output.shape, "int32"),
+            expected_output,
+            use_async_copy=1,
         )
 
         sch = get_single_dma_schedule(size_a, size_w)
         single_dma_runtime = evaluate(
-            hexagon_session, sch, input_a, input_w, size_a, expected_output
+            hexagon_session,
+            sch,
+            input_a,
+            input_w,
+            np.zeros(expected_output.shape, "int32"),
+            expected_output,
         )
 
         # Total transfer size is equal to the size of A + W + C which is equal to 2 * size_a * 128 + size_w * 128
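For readers new to these annotations: `software_pipeline_stage` assigns each block under the annotated loop to a pipeline stage, `software_pipeline_order` fixes the blocks' order within a pipelined iteration, and `software_pipeline_async_stages` selects the stages lowered to asynchronous DMA copies. A commented restatement of the `[0, 3, 6]` variant used above (same schedule calls as in the diff; the comments are explanatory):

sch = get_fake_conv_vtcm_schedule(size_a, size_w)
n = sch.get_loops(sch.get_block("C"))[0]

# Loop n contains three blocks: the VTCM input copy, the compute block,
# and the VTCM output copy. Spreading the stage indices apart increases
# the producer/consumer stage distance, so the pipeline keeps more buffer
# versions in flight (hence "larger_buffers" in the timing table below).
sch.annotate(n, "software_pipeline_stage", [0, 3, 6])
# Relative ordering of the three blocks within one pipelined iteration.
sch.annotate(n, "software_pipeline_order", [0, 1, 2])
# Stages 0 and 6 (the input and output copies) become async DMA.
sch.annotate(n, "software_pipeline_async_stages", [0, 6])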
@@ -349,5 +449,313 @@ def test_loading_vtcm_for_vrmpy(
             "async_dma_input": async_input_runtime,
             "async_dma_output": async_output_runtime,
             "async_dma_input_output": async_input_output_runtime,
+            "async_dma_multi_input_output": async_multi_input_output_runtime,
+            "async_input_output_runtime_larger_buffers": async_input_output_runtime_larger_buffers,
         },
     )
+
+
+# from tvm.script import tir as T
+@tvm.script.ir_module
+class ModulePipelined:
+    @T.prim_func
+    def main(
+        p0: T.Buffer[(1, 1, 230, 230, 4), "uint8"],
+        p1: T.Buffer[(2, 1, 7, 7, 1, 32, 4), "int8"],
+        T_cast: T.Buffer[(1, 2, 112, 112, 32), "int32"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"tir.noalias": True, "global_symbol": "main"})
+        # body
+        # with T.block("root")
+        conv2d_NCHWc_int8 = T.alloc_buffer([1, 2, 112, 112, 32], dtype="int32", scope="global.vtcm")
+        p0_global_vtcm = T.alloc_buffer([1, 1, 230, 230, 4], dtype="uint8", scope="global.vtcm")
+        p1_global_vtcm = T.alloc_buffer([2, 1, 7, 7, 1, 32, 4], dtype="int8", scope="global.vtcm")
+        for ax0, ax1, ax2, ax3, ax4, ax5, ax6 in T.grid(2, 1, 7, 7, 1, 32, 4):
+            with T.block("p1_global.vtcm"):
+                v0, v1, v2, v3, v4, v5, v6 = T.axis.remap(
+                    "SSSSSSS", [ax0, ax1, ax2, ax3, ax4, ax5, ax6]
+                )
+                T.reads(p1[v0, v1, v2, v3, v4, v5, v6])
+                T.writes(p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6])
+                p1_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = p1[v0, v1, v2, v3, v4, v5, v6]
+        for po in T.serial(4):
+            for i in T.serial(55876):
+                with T.block("p0_global.vtcm"):
+                    v0 = T.axis.spatial(1, 0)
+                    v1 = T.axis.spatial(1, 0)
+                    v2 = T.axis.spatial(230, po * 56 + i // 916)
+                    v3 = T.axis.spatial(230, i % 916 // 4)
+                    v4 = T.axis.spatial(4, i % 4)
+                    T.reads(p0[v0, v1, v2, v3, v4])
+                    T.writes(p0_global_vtcm[v0, v1, v2, v3, v4])
+                    p0_global_vtcm[v0, v1, v2, v3, v4] = p0[v0, v1, v2, v3, v4]
+            for i in T.parallel(28):
+                for ii, iii, iiii in T.grid(2, 14, 8):
+                    with T.block("conv2d_NCHWc_int8_o_init"):
+                        n = T.axis.spatial(1, 0)
+                        oc_chunk = T.axis.spatial(2, ii)
+                        oh = T.axis.spatial(112, (po * 28 + i) // 14 * 14 + iii)
+                        ow = T.axis.spatial(112, (po * 28 + i) % 14 * 8 + iiii)
+                        oc_block_o = T.axis.spatial(1, 0)
+                        T.reads()
+                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32])
+                        for i4_1 in T.vectorized(32):
+                            with T.block("conv2d_NCHWc_int8_init"):
+                                oc_block_i_init = T.axis.spatial(32, i4_1)
+                                T.reads()
+                                T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init])
+                                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0
+                for i1_1, i5_1, i6_1, i2_2, i3_2 in T.grid(2, 7, 7, 14, 8):
+                    with T.block("conv2d_NCHWc_int8_o_update"):
+                        n = T.axis.spatial(1, 0)
+                        oc_chunk = T.axis.spatial(2, i1_1)
+                        oh = T.axis.spatial(112, (po * 28 + i) // 14 * 14 + i2_2)
+                        ow = T.axis.spatial(112, (po * 28 + i) % 14 * 8 + i3_2)
+                        oc_block_o = T.axis.spatial(1, 0)
+                        kh = T.axis.reduce(7, i5_1)
+                        kw = T.axis.reduce(7, i6_1)
+                        ic_outer = T.axis.reduce(1, 0)
+                        ic_f_inner = T.axis.reduce(1, 0)
+                        ic_s_inner_o = T.axis.reduce(1, 0)
+                        T.reads(
+                            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32],
+                            p0_global_vtcm[
+                                n,
+                                ic_outer,
+                                oh * 2 + kh,
+                                ow * 2 + kw,
+                                ic_f_inner * 4 : ic_f_inner * 4 + 4,
+                            ],
+                            p1_global_vtcm[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:32, 0:4],
+                        )
+                        T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32])
+                        A = T.match_buffer(
+                            p0_global_vtcm[
+                                n,
+                                ic_outer,
+                                oh * 2 + kh,
+                                ow * 2 + kw,
+                                ic_f_inner * 4 : ic_f_inner * 4 + 4,
+                            ],
+                            [4],
+                            dtype="uint8",
+                            offset_factor=1,
+                            scope="global.vtcm",
+                        )
+                        B = T.match_buffer(
+                            p1_global_vtcm[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:32, 0:4],
+                            [32, 4],
+                            dtype="int8",
+                            offset_factor=1,
+                            scope="global.vtcm",
+                        )
+                        C = T.match_buffer(
+                            conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32],
+                            [32],
+                            dtype="int32",
+                            offset_factor=1,
+                            scope="global.vtcm",
+                        )
+                        A_u8x4: T.uint8x4 = A[0:4]
+                        A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+                        B_i8x128 = B[0, 0:128]
+                        B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+                        C[0:32] = T.call_llvm_pure_intrin(
+                            4217,
+                            T.uint32(3),
+                            C[0:32],
+                            T.broadcast(A_i32, 32),
+                            B_i32x32,
+                            dtype="int32x32",
+                        )
+            for i in T.serial(200704):
+                with T.block("conv2d_NCHWc_int8.vtcm"):
+                    ax0_1 = T.axis.spatial(1, 0)
+                    ax1_1 = T.axis.spatial(2, i % 7168 // 3584)
+                    ax2_1 = T.axis.spatial(112, (po * 28 + i // 7168) // 14 * 14 + i % 3584 // 256)
+                    ax3_1 = T.axis.spatial(112, (po * 28 + i // 7168) % 14 * 8 + i % 256 // 32)
+                    ax4 = T.axis.spatial(32, i % 32)
+                    T.reads(conv2d_NCHWc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4])
+                    T.writes(T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4])
+                    T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4] = conv2d_NCHWc_int8[
+                        ax0_1, ax1_1, ax2_1, ax3_1, ax4
+                    ]
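In `ModulePipelined` above, `p0` and `p1` are staged through `global.vtcm` buffers and the work is chunked by `po` so the copy blocks can later be mapped to async DMA stages; `ModuleBase` below is the same kernel reading directly from DDR. The `T.call_llvm_pure_intrin(4217, ...)` call is the HVX vrmpy accumulate intrinsic (the intrinsic id is LLVM-build specific). As a reference for what one call computes, a NumPy sketch (`vrmpy_reference` is an illustrative name, not an API used by the test):

import numpy as np

def vrmpy_reference(c_i32x32, a_u8x4, b_i8x32x4):
    # Each of the 32 int32 lanes accumulates a 4-element dot product:
    # c[j] += sum_k a[k] * b[j, k], i.e. C[0:32] = vrmpy(C, bcast(A), B).
    acc = (b_i8x32x4.astype("int32") * a_u8x4.astype("int32")).sum(axis=1)
    return c_i32x32 + acc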
dtype="int8", + offset_factor=1, + scope="global.vtcm", + ) + C = T.match_buffer( + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32], + [32], + dtype="int32", + offset_factor=1, + scope="global.vtcm", + ) + A_u8x4: T.uint8x4 = A[0:4] + A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32") + B_i8x128 = B[0, 0:128] + B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32") + C[0:32] = T.call_llvm_pure_intrin( + 4217, + T.uint32(3), + C[0:32], + T.broadcast(A_i32, 32), + B_i32x32, + dtype="int32x32", + ) + for i in T.serial(200704): + with T.block("conv2d_NCHWc_int8.vtcm"): + ax0_1 = T.axis.spatial(1, 0) + ax1_1 = T.axis.spatial(2, i % 7168 // 3584) + ax2_1 = T.axis.spatial(112, (po * 28 + i // 7168) // 14 * 14 + i % 3584 // 256) + ax3_1 = T.axis.spatial(112, (po * 28 + i // 7168) % 14 * 8 + i % 256 // 32) + ax4 = T.axis.spatial(32, i % 32) + T.reads(conv2d_NCHWc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) + T.writes(T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4]) + T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4] = conv2d_NCHWc_int8[ + ax0_1, ax1_1, ax2_1, ax3_1, ax4 + ] + + +# from tvm.script import tir as T +@tvm.script.ir_module +class ModuleBase: + @T.prim_func + def main( + p0: T.Buffer[(1, 1, 230, 230, 4), "uint8"], + p1: T.Buffer[(2, 1, 7, 7, 1, 32, 4), "int8"], + T_cast: T.Buffer[(1, 2, 112, 112, 32), "int32"], + ) -> None: + # function attr dict + T.func_attr({"tir.noalias": True, "global_symbol": "main"}) + # buffer definition + # body + # with T.block("root") + conv2d_NCHWc_int8 = T.alloc_buffer([1, 2, 112, 112, 32], dtype="int32") + for i0_0_i1_0_i2_0_i3_0_fused in T.parallel( + 112, annotations={"pragma_auto_unroll_max_step": 64, "pragma_unroll_explicit": 1} + ): + for i4_0_0 in T.serial(1): + for i1_1_init, i2_1_init, i3_1_init, i1_2_init, i2_2_init, i3_2_init in T.grid( + 2, 1, 1, 1, 14, 8 + ): + with T.block("conv2d_NCHWc_int8_o_init"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(2, i1_1_init + i1_2_init) + oh = T.axis.spatial( + 112, i0_0_i1_0_i2_0_i3_0_fused // 14 * 14 + i2_1_init * 14 + i2_2_init + ) + ow = T.axis.spatial( + 112, i0_0_i1_0_i2_0_i3_0_fused % 14 * 8 + i3_1_init * 8 + i3_2_init + ) + oc_block_o = T.axis.spatial(1, 0) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32]) + for i4_1 in T.vectorized(32): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_i_init = T.axis.spatial(32, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init]) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0 + for i5_0, i6_0, i7_0, i8_0, i9_0_0 in T.grid(1, 1, 1, 1, 1): + for ( + i0_1, + i1_1, + i2_1, + i3_1, + i4_0_1, + i5_1, + i6_1, + i7_1, + i8_1, + i9_0_1, + i0_2, + i1_2, + i2_2, + i3_2, + i4_0_2, + ) in T.grid(1, 2, 1, 1, 1, 7, 7, 1, 1, 1, 1, 1, 14, 8, 1): + with T.block("conv2d_NCHWc_int8_o_update"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(2, i1_1 + i1_2) + oh = T.axis.spatial( + 112, i0_0_i1_0_i2_0_i3_0_fused // 14 * 14 + i2_1 * 14 + i2_2 + ) + ow = T.axis.spatial( + 112, i0_0_i1_0_i2_0_i3_0_fused % 14 * 8 + i3_1 * 8 + i3_2 + ) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(7, i5_0 * 7 + i5_1) + kw = T.axis.reduce(7, i6_0 * 7 + i6_1) + ic_outer = T.axis.reduce(1, 0) + ic_f_inner = T.axis.reduce(1, 0) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads( + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32], + p0[ + n, + ic_outer, + oh * 2 + kh, + ow * 2 + kw, + ic_f_inner * 4 : ic_f_inner * 4 + 4, + ], + p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:32, 0:4], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 
+                            A = T.match_buffer(
+                                p0[
+                                    n,
+                                    ic_outer,
+                                    oh * 2 + kh,
+                                    ow * 2 + kw,
+                                    ic_f_inner * 4 : ic_f_inner * 4 + 4,
+                                ],
+                                [4],
+                                dtype="uint8",
+                                offset_factor=1,
+                            )
+                            B = T.match_buffer(
+                                p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0:32, 0:4],
+                                [32, 4],
+                                dtype="int8",
+                                offset_factor=1,
+                            )
+                            C = T.match_buffer(
+                                conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0:32],
+                                [32],
+                                dtype="int32",
+                                offset_factor=1,
+                            )
+                            A_u8x4: T.uint8x4 = A[0:4]
+                            A_i32: T.int32 = T.reinterpret(A_u8x4, dtype="int32")
+                            B_i8x128 = B[0, 0:128]
+                            B_i32x32: T.int32x32 = T.reinterpret(B_i8x128, dtype="int32x32")
+                            C[0:32] = T.call_llvm_pure_intrin(
+                                4217,
+                                T.uint32(3),
+                                C[0:32],
+                                T.broadcast(A_i32, 32),
+                                B_i32x32,
+                                dtype="int32x32",
+                            )
+                for ax0, ax1, ax2, ax3 in T.grid(1, 2, 14, 8):
+                    for ax4_fused in T.vectorized(32):
+                        with T.block("T_cast_2"):
+                            ax0_1, ax1_1 = T.axis.remap("SS", [ax0, ax1])
+                            ax2_1 = T.axis.spatial(
+                                112, i0_0_i1_0_i2_0_i3_0_fused // 14 * 14 + ax2
+                            )
+                            ax3_1 = T.axis.spatial(
+                                112, i0_0_i1_0_i2_0_i3_0_fused % 14 * 8 + ax3
+                            )
+                            ax4 = T.axis.spatial(32, ax4_fused)
+                            T.reads(conv2d_NCHWc_int8[ax0_1, ax1_1, ax2_1, ax3_1, ax4])
+                            T.writes(T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4])
+                            T_cast[ax0_1, ax1_1, ax2_1, ax3_1, ax4] = conv2d_NCHWc_int8[
+                                ax0_1, ax1_1, ax2_1, ax3_1, ax4
+                            ]
+
+
+@tvm.testing.requires_hexagon
+def test_meta(hexagon_session):
+    if tvm.testing.utils.IS_IN_CI:
+        pytest.skip("Skipping test since it takes too long in CI.")
+
+    a = default_rng().integers(1, 8, (1, 1, 230, 230, 4), dtype="uint8")
+    w = default_rng().integers(1, 8, (2, 1, 7, 7, 1, 32, 4), dtype="int8")
+    c = np.zeros((1, 2, 112, 112, 32), dtype="int32")
+
+    sch = tvm.tir.Schedule(ModuleBase)
+    base_runtime = evaluate(hexagon_session, sch, a, w, c)
+
+    sch = tvm.tir.Schedule(ModulePipelined)
+    compute_block = sch.get_block("conv2d_NCHWc_int8_o_update")
+    o = sch.get_loops(compute_block)[0]
+
+    unscheduled_vtcm_runtime = evaluate(hexagon_session, sch, a, w, c, use_async_copy=1)
+
+    sch = tvm.tir.Schedule(ModulePipelined)
+    compute_block = sch.get_block("conv2d_NCHWc_int8_o_update")
+    o = sch.get_loops(compute_block)[0]
+
+    sch.annotate(o, "software_pipeline_stage", [0, 1, 2])
+    sch.annotate(o, "software_pipeline_order", [0, 1, 2])
+    sch.annotate(o, "software_pipeline_async_stages", [0, 2])
+
+    pipeline_runtime = evaluate(hexagon_session, sch, a, w, c, use_async_copy=1)
+
+    transfer_mb = round((a.size + w.size + c.size) / 1e6, 2)
+    print_results(
+        f"Test with A.size: {a.size}, W.size: {w.size}, and total memory transfer of {transfer_mb} MB...",
+        {
+            "without_vtcm": base_runtime,
+            "unscheduled_vtcm_runtime": unscheduled_vtcm_runtime,
+            "pipeline_runtime": pipeline_runtime,
+        },
+    )
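For intuition on what the `[0, 1, 2]` stage assignment with async stages `[0, 2]` does to the `po` loop in `test_meta`: in steady state, chunk `i`'s input DMA, chunk `i - 1`'s compute, and chunk `i - 2`'s output DMA are in flight together. A rough, runnable sketch of that interleaving (the stub functions are hypothetical stand-ins; the real lowering emits Hexagon DMA queue commit/wait intrinsics rather than Python calls):

NUM_CHUNKS = 4  # extent of the po loop in ModulePipelined

def async_copy_in(i):  # stage 0: DMA the chunk's input slice into VTCM
    print(f"dma-in  chunk {i}")

def compute(i):  # stage 1: vrmpy compute on the chunk resident in VTCM
    print(f"compute chunk {i}")

def async_copy_out(i):  # stage 2: DMA the chunk's results back to DDR
    print(f"dma-out chunk {i}")

for i in range(NUM_CHUNKS + 2):  # prologue and epilogue folded into one loop
    if i < NUM_CHUNKS:
        async_copy_in(i)
    if 0 <= i - 1 < NUM_CHUNKS:
        compute(i - 1)
    if 0 <= i - 2 < NUM_CHUNKS:
        async_copy_out(i - 2)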