From 8e8c46b933bc6543ffeeaa680a3f2e2f53d179e3 Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Mon, 1 Aug 2022 11:46:21 -0700
Subject: [PATCH] [MetaSchedule][Test] Add unittests for TBG

---
 .../unittest/test_meta_schedule_space_cpu.py  | 185 ++++++++++++++++++
 .../unittest/test_meta_schedule_space_cuda.py |  85 ++++++++
 2 files changed, 270 insertions(+)

diff --git a/tests/python/unittest/test_meta_schedule_space_cpu.py b/tests/python/unittest/test_meta_schedule_space_cpu.py
index e0d7b29c8915..25dc14fd5cb7 100644
--- a/tests/python/unittest/test_meta_schedule_space_cpu.py
+++ b/tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -2418,6 +2418,190 @@ def cbr_2(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
     )
 
 
+def test_cpu_tbg():
+    # fmt: off
+    @T.prim_func
+    def tbg_0(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, 12, 64), "float32"], C: T.Buffer[(1, 12, 128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
+            query_T = T.alloc_buffer([1, 12, 128, 64], dtype="float32")
+            value_T = T.alloc_buffer([1, 12, 64, 128], dtype="float32")
+            C_global = T.alloc_buffer([1, 12, 128, 128], dtype="float32")
+            for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1 in T.grid(1, 1, 1, 2, 1, 6, 2):
+                for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 64):
+                    with T.block("value_T"):
+                        b = T.axis.spatial(1, ax0)
+                        h = T.axis.spatial(12, i1_1 * 2 + ax1)
+                        d = T.axis.spatial(64, ax2)
+                        l = T.axis.spatial(128, i3_0 * 64 + ax3)
+                        T.reads(value[b, l, h, d])
+                        T.writes(value_T[b, h, d, l])
+                        value_T[b, h, d, l] = value[b, l, h, d]
+                for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 64):
+                    with T.block("query_T"):
+                        b = T.axis.spatial(1, ax0)
+                        h = T.axis.spatial(12, i1_1 * 2 + ax1)
+                        l = T.axis.spatial(128, i2_1 * 64 + ax2)
+                        d = T.axis.spatial(64, ax3)
+                        T.reads(query[b, l, h, d])
+                        T.writes(query_T[b, h, l, d])
+                        query_T[b, h, l, d] = query[b, l, h, d]
+                for i3_1 in T.serial(8):
+                    for i4_0, i0_2, i1_2, i2_2, i3_2, i4_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2):
+                        with T.block("C"):
+                            b = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                            h = T.axis.spatial(12, i1_0 * 12 + i1_1 * 2 + i1_2 + i1_3)
+                            i = T.axis.spatial(128, i2_0 * 128 + i2_1 * 64 + i2_2 * 32 + i2_3)
+                            j = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + i3_3)
+                            k = T.axis.reduce(64, i4_0 * 64 + i4_1)
+                            T.reads(query_T[b, h, i, k], value_T[b, h, k, j])
+                            T.writes(C_global[b, h, i, j])
+                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                            with T.init():
+                                C_global[b, h, i, j] = T.float32(0)
+                            C_global[b, h, i, j] = C_global[b, h, i, j] + query_T[b, h, i, k] * value_T[b, h, k, j]
+                    for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 8):
+                        with T.block("C_global"):
+                            v0 = T.axis.spatial(1, ax0)
+                            v1 = T.axis.spatial(12, i1_1 * 2 + ax1)
+                            v2 = T.axis.spatial(128, i2_1 * 64 + ax2)
+                            v3 = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + ax3)
+                            T.reads(C_global[v0, v1, v2, v3])
+                            T.writes(C[v0, v1, v2, v3])
+                            C[v0, v1, v2, v3] = C_global[v0, v1, v2, v3]
+    @T.prim_func
+    def tbg_1(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, 12, 64), "float32"], C: T.Buffer[(1, 12, 128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":64, "meta_schedule.vectorize":64})
+            query_T = T.alloc_buffer([1, 12, 128, 64], dtype="float32")
+            value_T = T.alloc_buffer([1, 12, 64, 128], dtype="float32")
+            C_global = T.alloc_buffer([1, 12, 128, 128], dtype="float32")
+            for i0, i1, i2, i3 in T.grid(1, 12, 128, 64):
+                with T.block("query_T"):
+                    b, h, l, d = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                    T.reads(query[b, l, h, d])
+                    T.writes(query_T[b, h, l, d])
+                    query_T[b, h, l, d] = query[b, l, h, d]
+            for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 1, 1, 2):
+                for i0_1, i1_1, i2_1, i3_1, i4_0, i0_2, i1_2, i2_2, i3_2, i4_1 in T.grid(1, 6, 2, 8, 1, 1, 2, 2, 4, 64):
+                    for ax0, ax1, ax2, ax3 in T.grid(1, 1, 1, 2):
+                        with T.block("value_T"):
+                            b = T.axis.spatial(1, ax0)
+                            h = T.axis.spatial(12, i1_1 * 2 + i1_2 + ax1)
+                            d = T.axis.spatial(64, i4_1 + ax2)
+                            l = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + ax3)
+                            T.reads(value[b, l, h, d])
+                            T.writes(value_T[b, h, d, l])
+                            value_T[b, h, d, l] = value[b, l, h, d]
+                    for i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 32, 2):
+                        with T.block("C"):
+                            b = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                            h = T.axis.spatial(12, i1_0 * 12 + i1_1 * 2 + i1_2 + i1_3)
+                            i = T.axis.spatial(128, i2_0 * 128 + i2_1 * 64 + i2_2 * 32 + i2_3)
+                            j = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + i3_3)
+                            k = T.axis.reduce(64, i4_0 * 64 + i4_1)
+                            T.reads(query_T[b, h, i, k], value_T[b, h, k, j])
+                            T.writes(C_global[b, h, i, j])
+                            T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                            with T.init():
+                                C_global[b, h, i, j] = T.float32(0)
+                            C_global[b, h, i, j] = C_global[b, h, i, j] + query_T[b, h, i, k] * value_T[b, h, k, j]
+                for ax0, ax1, ax2, ax3 in T.grid(1, 12, 128, 64):
+                    with T.block("C_global"):
+                        v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                        v3 = T.axis.spatial(128, i3_0 * 64 + ax3)
+                        T.reads(C_global[v0, v1, v2, v3])
+                        T.writes(C[v0, v1, v2, v3])
+                        C[v0, v1, v2, v3] = C_global[v0, v1, v2, v3]
+    @T.prim_func
+    def tbg_2(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, 12, 64), "float32"], C: T.Buffer[(1, 12, 128, 128), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":512, "meta_schedule.vectorize":64})
+            value_T = T.alloc_buffer([1, 12, 64, 128], dtype="float32")
+            for i0_0, i1_0, i2_0, i3_0, i0_1, i1_1, i2_1, i3_1 in T.grid(1, 1, 1, 2, 1, 6, 2, 8):
+                for ax0, ax1, ax2, ax3 in T.grid(1, 2, 64, 8):
+                    with T.block("value_T"):
+                        b = T.axis.spatial(1, ax0)
+                        h = T.axis.spatial(12, i1_1 * 2 + ax1)
+                        d = T.axis.spatial(64, ax2)
+                        l = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + ax3)
+                        T.reads(value[b, l, h, d])
+                        T.writes(value_T[b, h, d, l])
+                        value_T[b, h, d, l] = value[b, l, h, d]
+                for i4_0, i0_2, i1_2, i2_2, i3_2, i4_1, i0_3, i1_3, i2_3, i3_3 in T.grid(1, 1, 2, 2, 4, 64, 1, 1, 32, 2):
+                    with T.block("C"):
+                        b = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                        h = T.axis.spatial(12, i1_0 * 12 + i1_1 * 2 + i1_2 + i1_3)
+                        i = T.axis.spatial(128, i2_0 * 128 + i2_1 * 64 + i2_2 * 32 + i2_3)
+                        j = T.axis.spatial(128, i3_0 * 64 + i3_1 * 8 + i3_2 * 2 + i3_3)
+                        k = T.axis.reduce(64, i4_0 * 64 + i4_1)
+                        T.reads(query[b, i, h, k], value_T[b, h, k, j])
+                        T.writes(C[b, h, i, j])
+                        T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
+                        with T.init():
+                            C[b, h, i, j] = T.float32(0)
+                        C[b, h, i, j] = C[b, h, i, j] + query[b, i, h, k] * value_T[b, h, k, j]
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 6, 2, 1]),
+        ("SamplePerfectTile", [1, 2, 2, 32]),
+        ("SamplePerfectTile", [2, 8, 4, 2]),
+        ("SamplePerfectTile", [1, 64]),
+        ("SampleCategorical", 2),
+        ("SampleComputeLocation", 6),
+        ("SampleComputeLocation", 6),
+    ]
+    decision_1 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 6, 2, 1]),
+        ("SamplePerfectTile", [1, 2, 2, 32]),
+        ("SamplePerfectTile", [2, 8, 4, 2]),
+        ("SamplePerfectTile", [1, 64]),
+        ("SampleCategorical", 2),
+        ("SampleComputeLocation", 13),
+        ("SampleComputeLocation", -1),
+    ]
+    decision_2 = [
+        ("SamplePerfectTile", [1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 6, 2, 1]),
+        ("SamplePerfectTile", [1, 2, 2, 32]),
+        ("SamplePerfectTile", [2, 8, 4, 2]),
+        ("SamplePerfectTile", [1, 64]),
+        ("SampleCategorical", 3),
+        ("SampleComputeLocation", 7),
+        ("SampleComputeLocation", -2),
+    ]
+    mod = create_te_workload("TBG", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[tbg_0, tbg_1, tbg_2],
+        expected_decisions=[decision_0, decision_1, decision_2],
+    )
+
+
 if __name__ == "__main__":
     test_cpu_c1d()
     test_cpu_c2d()
@@ -2431,3 +2615,4 @@ def cbr_2(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
     test_cpu_nrm()
     test_cpu_sfm()
     test_cpu_cbr()
+    test_cpu_tbg()
diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py
index ae4737a362a3..d617742d9457 100644
--- a/tests/python/unittest/test_meta_schedule_space_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -1217,6 +1217,90 @@ def cbr_0(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
     )
 
 
+def test_cuda_tbg():
+    # fmt: off
+    @T.prim_func
+    def tbg_0(query: T.Buffer[(1, 128, 12, 64), "float32"], value: T.Buffer[(1, 128, 12, 64), "float32"], C: T.Buffer[(1, 12, 128, 128), "float32"]) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        with T.block("root"):
+            T.reads()
+            T.writes()
+            T.block_attr({"meta_schedule.unroll_explicit":1024})
+            C_local = T.alloc_buffer([1, 12, 128, 128], dtype="float32", scope="local")
+            query_T_shared = T.alloc_buffer([1, 12, 128, 64], dtype="float32", scope="shared")
+            value_T_shared = T.alloc_buffer([1, 12, 64, 128], dtype="float32", scope="shared")
+            for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(4, thread="blockIdx.x"):
+                for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(192, thread="vthread.x"):
+                    for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(32, thread="threadIdx.x"):
+                        for i4_0 in T.serial(8):
+                            for ax0_ax1_ax2_ax3_fused in T.serial(12288):
+                                with T.block("query_T_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(12, ax0_ax1_ax2_ax3_fused // 1024)
+                                    v2 = T.axis.spatial(128, ax0_ax1_ax2_ax3_fused % 1024 // 8)
+                                    v3 = T.axis.spatial(64, i4_0 * 8 + ax0_ax1_ax2_ax3_fused % 8)
+                                    T.reads(query[v0, v2, v1, v3])
+                                    T.writes(query_T_shared[v0, v1, v2, v3])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":3})
+                                    query_T_shared[v0, v1, v2, v3] = query[v0, v2, v1, v3]
+                            for ax0_ax1_ax2_ax3_fused in T.serial(3072):
+                                with T.block("value_T_shared"):
+                                    v0 = T.axis.spatial(1, 0)
+                                    v1 = T.axis.spatial(12, ax0_ax1_ax2_ax3_fused // 256)
+                                    v2 = T.axis.spatial(64, i4_0 * 8 + ax0_ax1_ax2_ax3_fused % 256 // 32)
+                                    v3 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused * 32 + ax0_ax1_ax2_ax3_fused % 32)
+                                    T.reads(value[v0, v3, v1, v2])
+                                    T.writes(value_T_shared[v0, v1, v2, v3])
+                                    T.block_attr({"meta_schedule.cooperative_fetch":4})
+                                    value_T_shared[v0, v1, v2, v3] = value[v0, v3, v1, v2]
+                            for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(4, 1, 2, 1, 1, 2, 1, 1, 4, 1):
+                                with T.block("C"):
+                                    b = T.axis.spatial(1, i0_4 + i0_3)
+                                    h = T.axis.spatial(12, i1_4 + i0_1_i1_1_i2_1_i3_1_fused // 32 * 2 + i1_3)
+                                    i = T.axis.spatial(128, i0_1_i1_1_i2_1_i3_1_fused % 32 // 8 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 4 * 4 + i2_3 * 4 + i2_4)
+                                    j = T.axis.spatial(128, i3_4 + i0_0_i1_0_i2_0_i3_0_fused * 32 + i0_1_i1_1_i2_1_i3_1_fused % 8 * 4 + i0_2_i1_2_i2_2_i3_2_fused % 4 + i3_3)
+                                    k = T.axis.reduce(64, i4_0 * 8 + i4_1 * 2 + i4_2)
+                                    T.reads(query_T_shared[b, h, i, k], value_T_shared[b, h, k, j])
+                                    T.writes(C_local[b, h, i, j])
+                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
+                                    with T.init():
+                                        C_local[b, h, i, j] = T.float32(0)
+                                    C_local[b, h, i, j] = C_local[b, h, i, j] + query_T_shared[b, h, i, k] * value_T_shared[b, h, k, j]
+                        for ax0, ax1, ax2, ax3 in T.grid(1, 2, 4, 1):
+                            with T.block("C_local"):
+                                v0 = T.axis.spatial(1, ax0)
+                                v1 = T.axis.spatial(12, i0_1_i1_1_i2_1_i3_1_fused // 32 * 2 + ax1)
+                                v2 = T.axis.spatial(128, i0_1_i1_1_i2_1_i3_1_fused % 32 // 8 * 32 + i0_2_i1_2_i2_2_i3_2_fused // 4 * 4 + ax2)
+                                v3 = T.axis.spatial(128, i0_0_i1_0_i2_0_i3_0_fused * 32 + i0_1_i1_1_i2_1_i3_1_fused % 8 * 4 + i0_2_i1_2_i2_2_i3_2_fused % 4 + ax3)
+                                T.reads(C_local[v0, v1, v2, v3])
+                                T.writes(C[v0, v1, v2, v3])
+                                C[v0, v1, v2, v3] = C_local[v0, v1, v2, v3]
+    # fmt: on
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1, 1, 1]),
+        ("SamplePerfectTile", [1, 6, 1, 2, 1]),
+        ("SamplePerfectTile", [1, 4, 8, 1, 4]),
+        ("SamplePerfectTile", [4, 8, 4, 1, 1]),
+        ("SamplePerfectTile", [8, 4, 2]),
+        ("SampleCategorical", 2),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 4),
+    ]
+    mod = create_te_workload("TBG", 0)
+    actual = ms.TuneContext(
+        mod=mod,
+        target=_target(),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules="default",
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[tbg_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     test_cuda_c1d()
     test_cuda_c2d()
@@ -1230,3 +1314,4 @@ def cbr_0(data: T.Buffer[(1, 224, 224, 3), "float32"], kernel: T.Buffer[(7, 7, 3
     test_cuda_nrm()
     test_cuda_sfm()
     test_cuda_cbr()
+    test_cuda_tbg()
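Reviewer note (not part of the patch): for anyone who wants to reproduce the sketches these tests pin down, the snippet below is a minimal sketch of what `test_cpu_tbg()` does, using the same calls the patch itself makes (`create_te_workload`, `ms.TuneContext`, `generate_design_space`). The import path for `create_te_workload` and the `llvm` target string are assumptions on my part; in the test files the target comes from a shared `_target()` helper, and a different core count can change which sketches are generated.

```python
# Minimal repro sketch (assumes a TVM build with MetaSchedule enabled).
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.te_workload import create_te_workload  # assumed import path
from tvm.target import Target

# "TBG" is the transpose-batch-matmul workload from the testing te_workload registry:
# query/value of shape (1, 128, 12, 64) producing scores C of shape (1, 12, 128, 128).
mod = create_te_workload("TBG", 0)

ctx = ms.TuneContext(
    mod=mod,
    target=Target("llvm --num-cores=16"),  # assumption: the tests use _target() instead
    space_generator=ms.space_generator.PostOrderApply(),
    sch_rules="default",
)

# Each returned Schedule is one sketch of the design space; with the CPU target in
# the test file this yields the tbg_0/tbg_1/tbg_2 variants above, which differ in
# where the query_T/value_T layout transforms are computed.
for idx, sch in enumerate(ctx.generate_design_space()):
    print(f"=== sketch {idx} ===")
    print(sch.mod.script())  # the sketched TIR
    print(sch.trace)         # the trace, including its sampling instructions
```

The `decision_*` lists in the tests fix the random choices (`SamplePerfectTile`, `SampleCategorical`, `SampleComputeLocation`) so that `check_sketches` can replay each trace deterministically and compare the result against the expected TIR module.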