From f558a2f7406072b9e48f002514c7c4759a8b3724 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 17 Nov 2022 11:54:06 -0800 Subject: [PATCH 1/8] Fix reorder. --- include/tvm/meta_schedule/schedule/cuda/thread_bind.h | 4 +++- src/meta_schedule/schedule/cuda/thread_bind.cc | 4 ++-- src/meta_schedule/schedule/cuda/winograd.cc | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h index ae6d492bfe12..f5d7b4695084 100644 --- a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h +++ b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h @@ -45,10 +45,12 @@ std::function MakeFactorSampler(tir::Schedule sch, * \param max_threadblocks The maximum number of threadblocks allowed. * \param max_threads_per_block The maximum number of threads allowed. * \param get_factor A function that returns the tiling factor. + * \param allow_reorder Whether to allow reorder. */ Array BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop, // int64_t max_threadblocks, int64_t max_threads_per_block, - std::function get_factor = nullptr); + std::function get_factor = nullptr, + bool allow_reorder = true); /*! * \brief Bind the given block if it is not bound to blockIdx or threadIdx. diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc index e5dd5068783d..68d4ed075339 100644 --- a/src/meta_schedule/schedule/cuda/thread_bind.cc +++ b/src/meta_schedule/schedule/cuda/thread_bind.cc @@ -55,14 +55,14 @@ std::function MakeFactorSampler(Schedule sch, Array th Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblocks, int64_t max_threads_per_block, - std::function get_factor) { + std::function get_factor, bool allow_reorder) { int64_t extent = -1; if (const int64_t* e = as_const_int(sch->Get(loop)->extent)) { extent = *e; } else { extent = std::numeric_limits::max(); } - if (extent <= max_threadblocks * max_threads_per_block) { + if (extent <= max_threadblocks * max_threads_per_block || !allow_reorder) { if (!get_factor) { get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); } diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc index 5334c4df2ac9..c7ea9a65bda3 100644 --- a/src/meta_schedule/schedule/cuda/winograd.cc +++ b/src/meta_schedule/schedule/cuda/winograd.cc @@ -117,7 +117,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack") sch->Unroll(loops[4]); sch->Unroll(loops[5]); outer = BindSpatialLoop(sch, sch->Fuse({loops[2], loops[3]}), max_threadblocks, - max_threads_per_block)[1]; + max_threads_per_block, /*get_factor=*/nullptr, + /*allow_reorder=*/false)[1]; } { BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local"); From a4327c60037bae9c26b3b25728ba252f3d1dde0d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 17 Nov 2022 12:07:23 -0800 Subject: [PATCH 2/8] Add test. 
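
Tune the NCHW Winograd conv2d workload (fused with add and relu) end to
end on CUDA and assert that at least one valid schedule is committed to
the database. A minimal sketch of the flow the new test exercises
(WinogradConv2d is the IRModule added below; the target string assumes a
local RTX 3070):

    with tempfile.TemporaryDirectory() as work_dir:
        database = ms.tune_tir(
            WinogradConv2d,
            target="nvidia/geforce-rtx-3070",
            max_trials_global=10,
            work_dir=work_dir,
        )
        records = database.get_top_k(database.commit_workload(WinogradConv2d), 1)
        assert len(records) == 1, "No valid schedule found!"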
--- .../unittest/test_meta_schedule_tune_tir.py | 227 ++++++++++++++++++ 1 file changed, 227 insertions(+) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index aa45120c2316..e341c0dd73a6 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -33,6 +33,221 @@ logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) +@tvm.script.ir_module +class WinogradConv2d: + @T.prim_func + def main( + p0: T.Buffer[(2, 2048, 50, 75), "float32"], + p1: T.Buffer[(4, 4, 2048, 2048), "float32"], + p2: T.Buffer[(1, 2048, 1, 1), "float32"], + T_relu: T.Buffer[(2, 2048, 50, 75), "float32"], + ): + # function attr dict + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + # body + # with T.block("root") + data_pad = T.alloc_buffer([2, 2048, 52, 77], dtype="float32") + input_tile = T.alloc_buffer([2048, 1900, 4, 4], dtype="float32") + B = T.alloc_buffer([4, 4], dtype="float32") + data_pack = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") + bgemm = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") + A = T.alloc_buffer([4, 2], dtype="float32") + inverse = T.alloc_buffer([2048, 1900, 2, 2], dtype="float32") + conv2d_winograd = T.alloc_buffer([2, 2048, 50, 75], dtype="float32") + T_add = T.alloc_buffer([2, 2048, 50, 75], dtype="float32") + for i0, i1, i2, i3 in T.grid(2, 2048, 52, 77): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1]) + T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + 1 <= i2_1 and i2_1 < 51 and 1 <= i3_1 and i3_1 < 76, + p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1], + T.float32(0), + dtype="float32", + ) + for i0, i1, i2, i3 in T.grid(2048, 1900, 4, 4): + with T.block("input_tile"): + ci, p, eps, nu = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data_pad[p // 950, ci, p % 950 // 38 * 2 + eps, p % 38 * 2 + nu]) + T.writes(input_tile[ci, p, eps, nu]) + T.block_attr({"schedule_rule": "None"}) + input_tile[ci, p, eps, nu] = data_pad[ + p // 950, ci, p % 950 // 38 * 2 + eps, p % 38 * 2 + nu + ] + for i0, i1 in T.grid(4, 4): + with T.block("B"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(B[i, j]) + T.block_attr({"schedule_rule": "None"}) + B[i, j] = T.Select( + i % 4 == 3 and j % 4 == 3, + T.float32(1), + T.Select( + i % 4 == 3 and j % 4 == 2, + T.float32(0), + T.Select( + i % 4 == 3 and j % 4 == 1, + T.float32(0), + T.Select( + i % 4 == 3 and j % 4 == 0, + T.float32(0), + T.Select( + i % 4 == 2 and j % 4 == 3, + T.float32(0), + T.Select( + i % 4 == 2 and j % 4 == 2, + T.float32(1), + T.Select( + i % 4 == 2 and j % 4 == 1, + T.float32(1), + T.Select( + i % 4 == 2 and j % 4 == 0, + T.float32(-1), + T.Select( + i % 4 == 1 and j % 4 == 3, + T.float32(-1), + T.Select( + i % 4 == 1 and j % 4 == 2, + T.float32(1), + T.Select( + i % 4 == 1 and j % 4 == 1, + T.float32(-1), + T.Select( + i % 4 == 1 and j % 4 == 0, + T.float32(0), + T.Select( + i % 4 == 0 and j % 4 == 3, + T.float32(0), + T.Select( + i % 4 == 0 and j % 4 == 2, + T.float32(0), + T.Select( + i % 4 == 0 + and j % 4 == 1, + T.float32(0), + T.Select( + i % 4 == 0 + and j % 4 == 0, + T.float32(1), + T.float32(0), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ) + for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 2048, 1900, 4, 4): + with T.block("data_pack"): + eps, nu, ci, p, r_a, r_b = 
T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads( + input_tile[ci, p, r_a, r_b], + B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1], + ) + T.writes(data_pack[eps, nu, ci, p]) + T.block_attr({"schedule_rule": "conv2d_nchw_winograd_data_pack"}) + with T.init(): + data_pack[eps, nu, ci, p] = T.float32(0) + data_pack[eps, nu, ci, p] = ( + data_pack[eps, nu, ci, p] + + input_tile[ci, p, r_a, r_b] * B[r_a, eps] * B[r_b, nu] + ) + for i0, i1, i2, i3, i4 in T.grid(4, 4, 2048, 1900, 2048): + with T.block("bgemm"): + eps, nu, co, p, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) + T.reads(data_pack[eps, nu, ci, p], p1[eps, nu, ci, co]) + T.writes(bgemm[eps, nu, co, p]) + with T.init(): + bgemm[eps, nu, co, p] = T.float32(0) + bgemm[eps, nu, co, p] = ( + bgemm[eps, nu, co, p] + data_pack[eps, nu, ci, p] * p1[eps, nu, ci, co] + ) + for i0, i1 in T.grid(4, 2): + with T.block("A"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(A[i, j]) + T.block_attr({"schedule_rule": "None"}) + A[i, j] = T.Select( + i % 4 == 3 and j % 2 == 1, + T.float32(1), + T.Select( + i % 4 == 3 and j % 2 == 0, + T.float32(0), + T.Select( + i % 4 == 2 and j % 2 == 1, + T.float32(1), + T.Select( + i % 4 == 2 and j % 2 == 0, + T.float32(1), + T.Select( + i % 4 == 1 and j % 2 == 1, + T.float32(-1), + T.Select( + i % 4 == 1 and j % 2 == 0, + T.float32(1), + T.Select( + i % 4 == 0 and j % 2 == 1, + T.float32(0), + T.Select( + i % 4 == 0 and j % 2 == 0, + T.float32(1), + T.float32(0), + ), + ), + ), + ), + ), + ), + ), + ) + for i0, i1, i2, i3, i4, i5 in T.grid(2048, 1900, 2, 2, 4, 4): + with T.block("inverse"): + co, p, vh, vw, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads( + bgemm[r_a, r_b, co, p], + A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1], + ) + T.writes(inverse[co, p, vh, vw]) + T.block_attr({"schedule_rule": "conv2d_nchw_winograd_inverse"}) + with T.init(): + inverse[co, p, vh, vw] = T.float32(0) + inverse[co, p, vh, vw] = ( + inverse[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * A[r_a, vh] * A[r_b, vw] + ) + for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): + with T.block("conv2d_winograd"): + n, co, h, w = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inverse[co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2]) + T.writes(conv2d_winograd[n, co, h, w]) + conv2d_winograd[n, co, h, w] = inverse[ + co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2 + ] + for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_winograd[ax0, ax1, ax2, ax3], p2[0, ax1, 0, 0]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = conv2d_winograd[ax0, ax1, ax2, ax3] + p2[0, ax1, 0, 0] + for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): + with T.block("T_relu"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3]) + T.writes(T_relu[ax0, ax1, ax2, ax3]) + T_relu[ax0, ax1, ax2, ax3] = T.max(T_add[ax0, ax1, ax2, ax3], T.float32(0)) + + @T.prim_func def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) @@ -178,8 +393,20 @@ def clone(self) -> "RemoveBlock": sch.trace.show() +@pytest.skip("Slow test and requires rtx-3070") +def test_tune_winograd_conv2d_cuda(): + mod = WinogradConv2d + with tempfile.TemporaryDirectory() as work_dir: + database = ms.tune_tir( + mod, target="nvidia/geforce-rtx-3070", max_trials_global=10, work_dir=work_dir + ) + records = 
database.get_top_k(database.commit_workload(mod), 1) + assert len(records) == 1, "No valid schedule found!" + + if __name__ == """__main__""": test_tune_matmul_cpu() test_tune_matmul_cuda() test_tune_run_module_via_rpc() test_tune_block_cpu() + test_tune_winograd_conv2d_cuda() From 312b90fe5cdb00109b58b729caaf41677ce21d7a Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 17 Nov 2022 12:11:53 -0800 Subject: [PATCH 3/8] Fix format. --- .../unittest/test_meta_schedule_tune_tir.py | 162 ++---------------- 1 file changed, 19 insertions(+), 143 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index e341c0dd73a6..9d572b1ca551 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -32,18 +32,13 @@ logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) - +# fmt: off @tvm.script.ir_module class WinogradConv2d: @T.prim_func - def main( - p0: T.Buffer[(2, 2048, 50, 75), "float32"], - p1: T.Buffer[(4, 4, 2048, 2048), "float32"], - p2: T.Buffer[(1, 2048, 1, 1), "float32"], - T_relu: T.Buffer[(2, 2048, 50, 75), "float32"], - ): + def main(p0: T.Buffer[(2, 2048, 50, 75), "float32"], p1: T.Buffer[(4, 4, 2048, 2048), "float32"], p2: T.Buffer[(1, 2048, 1, 1), "float32"], T_relu: T.Buffer[(2, 2048, 50, 75), "float32"]): # function attr dict - T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) # body # with T.block("root") data_pad = T.alloc_buffer([2, 2048, 52, 77], dtype="float32") @@ -60,109 +55,30 @@ def main( i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1]) T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) - data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( - 1 <= i2_1 and i2_1 < 51 and 1 <= i3_1 and i3_1 < 76, - p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1], - T.float32(0), - dtype="float32", - ) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i2_1 and i2_1 < 51 and 1 <= i3_1 and i3_1 < 76, p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") for i0, i1, i2, i3 in T.grid(2048, 1900, 4, 4): with T.block("input_tile"): ci, p, eps, nu = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(data_pad[p // 950, ci, p % 950 // 38 * 2 + eps, p % 38 * 2 + nu]) T.writes(input_tile[ci, p, eps, nu]) - T.block_attr({"schedule_rule": "None"}) - input_tile[ci, p, eps, nu] = data_pad[ - p // 950, ci, p % 950 // 38 * 2 + eps, p % 38 * 2 + nu - ] + T.block_attr({"schedule_rule":"None"}) + input_tile[ci, p, eps, nu] = data_pad[p // 950, ci, p % 950 // 38 * 2 + eps, p % 38 * 2 + nu] for i0, i1 in T.grid(4, 4): with T.block("B"): i, j = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(B[i, j]) - T.block_attr({"schedule_rule": "None"}) - B[i, j] = T.Select( - i % 4 == 3 and j % 4 == 3, - T.float32(1), - T.Select( - i % 4 == 3 and j % 4 == 2, - T.float32(0), - T.Select( - i % 4 == 3 and j % 4 == 1, - T.float32(0), - T.Select( - i % 4 == 3 and j % 4 == 0, - T.float32(0), - T.Select( - i % 4 == 2 and j % 4 == 3, - T.float32(0), - T.Select( - i % 4 == 2 and j % 4 == 2, - T.float32(1), - T.Select( - i % 4 == 2 and j % 4 == 1, - T.float32(1), - T.Select( - i % 4 == 2 and j % 4 == 0, - T.float32(-1), - T.Select( - i % 4 == 1 and j % 4 == 3, - T.float32(-1), - T.Select( - i % 4 == 1 and j % 4 == 2, - T.float32(1), - T.Select( - i % 4 == 1 and j % 4 == 1, - 
T.float32(-1), - T.Select( - i % 4 == 1 and j % 4 == 0, - T.float32(0), - T.Select( - i % 4 == 0 and j % 4 == 3, - T.float32(0), - T.Select( - i % 4 == 0 and j % 4 == 2, - T.float32(0), - T.Select( - i % 4 == 0 - and j % 4 == 1, - T.float32(0), - T.Select( - i % 4 == 0 - and j % 4 == 0, - T.float32(1), - T.float32(0), - ), - ), - ), - ), - ), - ), - ), - ), - ), - ), - ), - ), - ), - ), - ), - ) + T.block_attr({"schedule_rule":"None"}) + B[i, j] = T.Select(i % 4 == 3 and j % 4 == 3, T.float32(1), T.Select(i % 4 == 3 and j % 4 == 2, T.float32(0), T.Select(i % 4 == 3 and j % 4 == 1, T.float32(0), T.Select(i % 4 == 3 and j % 4 == 0, T.float32(0), T.Select(i % 4 == 2 and j % 4 == 3, T.float32(0), T.Select(i % 4 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 4 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 4 == 2 and j % 4 == 0, T.float32(-1), T.Select(i % 4 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 4 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 4 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 4 == 1 and j % 4 == 0, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 2048, 1900, 4, 4): with T.block("data_pack"): eps, nu, ci, p, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads( - input_tile[ci, p, r_a, r_b], - B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1], - ) + T.reads(input_tile[ci, p, r_a, r_b], B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1]) T.writes(data_pack[eps, nu, ci, p]) - T.block_attr({"schedule_rule": "conv2d_nchw_winograd_data_pack"}) + T.block_attr({"schedule_rule":"conv2d_nchw_winograd_data_pack"}) with T.init(): data_pack[eps, nu, ci, p] = T.float32(0) - data_pack[eps, nu, ci, p] = ( - data_pack[eps, nu, ci, p] - + input_tile[ci, p, r_a, r_b] * B[r_a, eps] * B[r_b, nu] - ) + data_pack[eps, nu, ci, p] = data_pack[eps, nu, ci, p] + input_tile[ci, p, r_a, r_b] * B[r_a, eps] * B[r_b, nu] for i0, i1, i2, i3, i4 in T.grid(4, 4, 2048, 1900, 2048): with T.block("bgemm"): eps, nu, co, p, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) @@ -170,70 +86,29 @@ def main( T.writes(bgemm[eps, nu, co, p]) with T.init(): bgemm[eps, nu, co, p] = T.float32(0) - bgemm[eps, nu, co, p] = ( - bgemm[eps, nu, co, p] + data_pack[eps, nu, ci, p] * p1[eps, nu, ci, co] - ) + bgemm[eps, nu, co, p] = bgemm[eps, nu, co, p] + data_pack[eps, nu, ci, p] * p1[eps, nu, ci, co] for i0, i1 in T.grid(4, 2): with T.block("A"): i, j = T.axis.remap("SS", [i0, i1]) T.reads() T.writes(A[i, j]) - T.block_attr({"schedule_rule": "None"}) - A[i, j] = T.Select( - i % 4 == 3 and j % 2 == 1, - T.float32(1), - T.Select( - i % 4 == 3 and j % 2 == 0, - T.float32(0), - T.Select( - i % 4 == 2 and j % 2 == 1, - T.float32(1), - T.Select( - i % 4 == 2 and j % 2 == 0, - T.float32(1), - T.Select( - i % 4 == 1 and j % 2 == 1, - T.float32(-1), - T.Select( - i % 4 == 1 and j % 2 == 0, - T.float32(1), - T.Select( - i % 4 == 0 and j % 2 == 1, - T.float32(0), - T.Select( - i % 4 == 0 and j % 2 == 0, - T.float32(1), - T.float32(0), - ), - ), - ), - ), - ), - ), - ), - ) + T.block_attr({"schedule_rule":"None"}) + A[i, j] = T.Select(i % 4 == 3 and j % 2 == 1, T.float32(1), T.Select(i % 4 == 3 and j % 2 == 0, T.float32(0), T.Select(i % 4 == 2 and j % 2 == 1, T.float32(1), T.Select(i % 4 == 2 and j % 2 == 0, T.float32(1), 
T.Select(i % 4 == 1 and j % 2 == 1, T.float32(-1), T.Select(i % 4 == 1 and j % 2 == 0, T.float32(1), T.Select(i % 4 == 0 and j % 2 == 1, T.float32(0), T.Select(i % 4 == 0 and j % 2 == 0, T.float32(1), T.float32(0))))))))) for i0, i1, i2, i3, i4, i5 in T.grid(2048, 1900, 2, 2, 4, 4): with T.block("inverse"): co, p, vh, vw, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads( - bgemm[r_a, r_b, co, p], - A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1], - ) + T.reads(bgemm[r_a, r_b, co, p], A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1]) T.writes(inverse[co, p, vh, vw]) - T.block_attr({"schedule_rule": "conv2d_nchw_winograd_inverse"}) + T.block_attr({"schedule_rule":"conv2d_nchw_winograd_inverse"}) with T.init(): inverse[co, p, vh, vw] = T.float32(0) - inverse[co, p, vh, vw] = ( - inverse[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * A[r_a, vh] * A[r_b, vw] - ) + inverse[co, p, vh, vw] = inverse[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * A[r_a, vh] * A[r_b, vw] for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): with T.block("conv2d_winograd"): n, co, h, w = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inverse[co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2]) T.writes(conv2d_winograd[n, co, h, w]) - conv2d_winograd[n, co, h, w] = inverse[ - co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2 - ] + conv2d_winograd[n, co, h, w] = inverse[co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2] for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): with T.block("T_add"): ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) @@ -274,6 +149,7 @@ def two_step(a: T.handle, c: T.handle) -> None: with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 3.0 +# fmt: on @tvm.testing.requires_llvm From 894d59ccff7f323447e3adc40c811efb2e700bbc Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 17 Nov 2022 17:04:46 -0800 Subject: [PATCH 4/8] Fix test skip. --- tests/python/unittest/test_meta_schedule_tune_tir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index 9d572b1ca551..bd5500779611 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -269,7 +269,7 @@ def clone(self) -> "RemoveBlock": sch.trace.show() -@pytest.skip("Slow test and requires rtx-3070") +@pytest.mark.skip(reason="slow test and requires rtx-3070") def test_tune_winograd_conv2d_cuda(): mod = WinogradConv2d with tempfile.TemporaryDirectory() as work_dir: From 51efc936341f67c4a8760537cf299d6f2413e622 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 17 Nov 2022 17:43:50 -0800 Subject: [PATCH 5/8] Change loop to fuse. --- include/tvm/meta_schedule/schedule/cuda/thread_bind.h | 4 ++-- src/meta_schedule/schedule/cuda/thread_bind.cc | 6 +++--- src/meta_schedule/schedule/cuda/winograd.cc | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h index f5d7b4695084..5cd41b73bca2 100644 --- a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h +++ b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h @@ -46,11 +46,11 @@ std::function MakeFactorSampler(tir::Schedule sch, * \param max_threads_per_block The maximum number of threads allowed. * \param get_factor A function that returns the tiling factor. 
* \param allow_reorder Whether to allow reorder.
+ * \return The bound loops in the order of blockIdx.x, threadIdx.x, and the rest.
  */
 Array BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop,  //
                       int64_t max_threadblocks, int64_t max_threads_per_block,
-                      std::function get_factor = nullptr,
-                      bool allow_reorder = true);
+                      std::function get_factor = nullptr);
 
 /*!
  * \brief Bind the given block if it is not bound to blockIdx or threadIdx.
diff --git a/src/meta_schedule/schedule/cuda/thread_bind.cc b/src/meta_schedule/schedule/cuda/thread_bind.cc
index 68d4ed075339..b651b1f401cb 100644
--- a/src/meta_schedule/schedule/cuda/thread_bind.cc
+++ b/src/meta_schedule/schedule/cuda/thread_bind.cc
@@ -55,14 +55,14 @@ std::function MakeFactorSampler(Schedule sch, Array th
 
 Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblocks,
                       int64_t max_threads_per_block,
-                      std::function get_factor, bool allow_reorder) {
+                      std::function get_factor) {
   int64_t extent = -1;
   if (const int64_t* e = as_const_int(sch->Get(loop)->extent)) {
     extent = *e;
   } else {
     extent = std::numeric_limits::max();
   }
-  if (extent <= max_threadblocks * max_threads_per_block || !allow_reorder) {
+  if (extent <= max_threadblocks * max_threads_per_block) {
     if (!get_factor) {
       get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024});
     }
@@ -80,7 +80,7 @@ Array BindSpatialLoop(Schedule sch, LoopRV loop, int64_t max_threadblock
     sch->Reorder({splits[1], splits[2], splits[0]});
     sch->Bind(splits[1], "blockIdx.x");
     sch->Bind(splits[2], "threadIdx.x");
-    return {splits[1], splits[2]};
+    return {splits[1], splits[2], splits[0]};
   }
 }
 
diff --git a/src/meta_schedule/schedule/cuda/winograd.cc b/src/meta_schedule/schedule/cuda/winograd.cc
index c7ea9a65bda3..59ed7bdc009a 100644
--- a/src/meta_schedule/schedule/cuda/winograd.cc
+++ b/src/meta_schedule/schedule/cuda/winograd.cc
@@ -117,8 +117,8 @@ TVM_REGISTER_GLOBAL("meta_schedule.cuda.conv2d_nchw_winograd_data_pack")
       sch->Unroll(loops[4]);
       sch->Unroll(loops[5]);
       outer = BindSpatialLoop(sch, sch->Fuse({loops[2], loops[3]}), max_threadblocks,
-                              max_threads_per_block, /*get_factor=*/nullptr,
-                              /*allow_reorder=*/false)[1];
+                              max_threads_per_block, /*get_factor=*/nullptr)
+                  .back();
     }
     {
       BlockRV data_pack_local = sch->CacheWrite(data_pack, 0, "local");

From 7f12dc545fa79887df473b789069119ce37943bf Mon Sep 17 00:00:00 2001
From: Xiyou Zhou
Date: Thu, 17 Nov 2022 17:44:27 -0800
Subject: [PATCH 6/8] Fix comments.

---
 include/tvm/meta_schedule/schedule/cuda/thread_bind.h | 1 -
 1 file changed, 1 deletion(-)

diff --git a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h
index 5cd41b73bca2..125d6dc11fc8 100644
--- a/include/tvm/meta_schedule/schedule/cuda/thread_bind.h
+++ b/include/tvm/meta_schedule/schedule/cuda/thread_bind.h
@@ -45,7 +45,6 @@ std::function MakeFactorSampler(tir::Schedule sch,
  * \param max_threadblocks The maximum number of threadblocks allowed.
  * \param max_threads_per_block The maximum number of threads allowed.
  * \param get_factor A function that returns the tiling factor.
- * \param allow_reorder Whether to allow reorder.
  * \return The bound loops in the order of blockIdx.x, threadIdx.x, and the rest.
  */
 Array BindSpatialLoop(tir::Schedule sch, tir::LoopRV loop,  //
                       int64_t max_threadblocks, int64_t max_threads_per_block,
                       std::function get_factor = nullptr);
 
 /*!
  * \brief Bind the given block if it is not bound to blockIdx or threadIdx.

From f8cd7d00d734f6889fa7ec3abfaa9e8e6b4749d9 Mon Sep 17 00:00:00 2001
From: Xiyou Zhou
Date: Thu, 17 Nov 2022 20:35:48 -0800
Subject: [PATCH 7/8] Move test case. 
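
The end-to-end tuning test is slow and needs an RTX 3070, so move the
coverage into test_meta_schedule_space_cuda_winograd.py as a
deterministic design-space check that runs without a GPU. Following the
pattern of the existing tests in that file, the check generates the
design space for the workload and compares the resulting sketches
against an expected scheduled module and its sampling decisions:

    with _target():
        actual = _design_space(nchw_add_relu)
        check_sketches(
            nchw_add_relu,
            sketches=actual,
            expected_mods=[nchw_add_relu_scheduled],
            expected_decisions=[decision_0],
            debug_mask=0,
        )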
--- .../test_meta_schedule_space_cuda_winograd.py | 241 ++++++++++++++++++ .../unittest/test_meta_schedule_tune_tir.py | 89 ------- 2 files changed, 241 insertions(+), 89 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py index 16f9e64252ad..53a153b90522 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py @@ -350,6 +350,247 @@ def cuda_nchw_0(data: T.Buffer[(1, 64, 56, 56), "float32"], weight: T.Buffer[(6, ) +def test_cuda_nchw_add_relu(): + # fmt: off + @T.prim_func + def nchw_add_relu(p0: T.Buffer[(2, 2048, 50, 75), "float32"], p1: T.Buffer[(4, 4, 2048, 2048), "float32"], p2: T.Buffer[(1, 2048, 1, 1), "float32"], T_relu: T.Buffer[(2, 2048, 50, 75), "float32"]): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) + # body + # with T.block("root") + data_pad = T.alloc_buffer([2, 2048, 52, 77], dtype="float32") + input_tile = T.alloc_buffer([2048, 1900, 4, 4], dtype="float32") + B = T.alloc_buffer([4, 4], dtype="float32") + data_pack = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") + bgemm = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") + A = T.alloc_buffer([4, 2], dtype="float32") + inverse = T.alloc_buffer([2048, 1900, 2, 2], dtype="float32") + conv2d_winograd = T.alloc_buffer([2, 2048, 50, 75], dtype="float32") + T_add = T.alloc_buffer([2, 2048, 50, 75], dtype="float32") + for i0, i1, i2, i3 in T.grid(2, 2048, 52, 77): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1]) + T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i2_1 and i2_1 < 51 and 1 <= i3_1 and i3_1 < 76, p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") + for i0, i1, i2, i3 in T.grid(2048, 1900, 4, 4): + with T.block("input_tile"): + ci, p, eps, nu = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(data_pad[p // 950, ci, p % 950 // 38 * 2 + eps, p % 38 * 2 + nu]) + T.writes(input_tile[ci, p, eps, nu]) + T.block_attr({"schedule_rule":"None"}) + input_tile[ci, p, eps, nu] = data_pad[p // 950, ci, p % 950 // 38 * 2 + eps, p % 38 * 2 + nu] + for i0, i1 in T.grid(4, 4): + with T.block("B"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(B[i, j]) + T.block_attr({"schedule_rule":"None"}) + B[i, j] = T.Select(i % 4 == 3 and j % 4 == 3, T.float32(1), T.Select(i % 4 == 3 and j % 4 == 2, T.float32(0), T.Select(i % 4 == 3 and j % 4 == 1, T.float32(0), T.Select(i % 4 == 3 and j % 4 == 0, T.float32(0), T.Select(i % 4 == 2 and j % 4 == 3, T.float32(0), T.Select(i % 4 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 4 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 4 == 2 and j % 4 == 0, T.float32(-1), T.Select(i % 4 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 4 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 4 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 4 == 1 and j % 4 == 0, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 2, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 2048, 1900, 4, 4): + with T.block("data_pack"): + eps, nu, ci, p, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + 
T.reads(input_tile[ci, p, r_a, r_b], B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1]) + T.writes(data_pack[eps, nu, ci, p]) + T.block_attr({"schedule_rule":"conv2d_nchw_winograd_data_pack"}) + with T.init(): + data_pack[eps, nu, ci, p] = T.float32(0) + data_pack[eps, nu, ci, p] = data_pack[eps, nu, ci, p] + input_tile[ci, p, r_a, r_b] * B[r_a, eps] * B[r_b, nu] + for i0, i1, i2, i3, i4 in T.grid(4, 4, 2048, 1900, 2048): + with T.block("bgemm"): + eps, nu, co, p, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) + T.reads(data_pack[eps, nu, ci, p], p1[eps, nu, ci, co]) + T.writes(bgemm[eps, nu, co, p]) + with T.init(): + bgemm[eps, nu, co, p] = T.float32(0) + bgemm[eps, nu, co, p] = bgemm[eps, nu, co, p] + data_pack[eps, nu, ci, p] * p1[eps, nu, ci, co] + for i0, i1 in T.grid(4, 2): + with T.block("A"): + i, j = T.axis.remap("SS", [i0, i1]) + T.reads() + T.writes(A[i, j]) + T.block_attr({"schedule_rule":"None"}) + A[i, j] = T.Select(i % 4 == 3 and j % 2 == 1, T.float32(1), T.Select(i % 4 == 3 and j % 2 == 0, T.float32(0), T.Select(i % 4 == 2 and j % 2 == 1, T.float32(1), T.Select(i % 4 == 2 and j % 2 == 0, T.float32(1), T.Select(i % 4 == 1 and j % 2 == 1, T.float32(-1), T.Select(i % 4 == 1 and j % 2 == 0, T.float32(1), T.Select(i % 4 == 0 and j % 2 == 1, T.float32(0), T.Select(i % 4 == 0 and j % 2 == 0, T.float32(1), T.float32(0))))))))) + for i0, i1, i2, i3, i4, i5 in T.grid(2048, 1900, 2, 2, 4, 4): + with T.block("inverse"): + co, p, vh, vw, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) + T.reads(bgemm[r_a, r_b, co, p], A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1]) + T.writes(inverse[co, p, vh, vw]) + T.block_attr({"schedule_rule":"conv2d_nchw_winograd_inverse"}) + with T.init(): + inverse[co, p, vh, vw] = T.float32(0) + inverse[co, p, vh, vw] = inverse[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * A[r_a, vh] * A[r_b, vw] + for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): + with T.block("conv2d_winograd"): + n, co, h, w = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inverse[co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2]) + T.writes(conv2d_winograd[n, co, h, w]) + conv2d_winograd[n, co, h, w] = inverse[co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2] + for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): + with T.block("T_add"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(conv2d_winograd[ax0, ax1, ax2, ax3], p2[0, ax1, 0, 0]) + T.writes(T_add[ax0, ax1, ax2, ax3]) + T_add[ax0, ax1, ax2, ax3] = conv2d_winograd[ax0, ax1, ax2, ax3] + p2[0, ax1, 0, 0] + for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): + with T.block("T_relu"): + ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(T_add[ax0, ax1, ax2, ax3]) + T.writes(T_relu[ax0, ax1, ax2, ax3]) + T_relu[ax0, ax1, ax2, ax3] = T.max(T_add[ax0, ax1, ax2, ax3], T.float32(0)) + + @T.prim_func + def nchw_add_relu_scheduled(p0: T.Buffer[(2, 2048, 50, 75), "float32"], p1: T.Buffer[(4, 4, 2048, 2048), "float32"], p2: T.Buffer[(1, 2048, 1, 1), "float32"], T_relu: T.Buffer[(2, 2048, 50, 75), "float32"]): + # function attr dict + T.func_attr({"layout_free_buffers": [1], "tir.noalias": True, "global_symbol": "main"}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr({"meta_schedule.unroll_explicit":1024}) + input_tile_local = T.alloc_buffer([2048, 1900, 4, 4], dtype="float32", scope="local") + data_pack = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") + bgemm = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") + 
inverse_local = T.alloc_buffer([2048, 1900, 2, 2], dtype="float32", scope="local") + data_pack_local = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32", scope="local") + bgemm_local = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32", scope="local") + data_pack_shared = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32", scope="shared") + p1_shared = T.alloc_buffer([4, 4, 2048, 2048], dtype="float32", scope="shared") + for i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"): + for i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): + for i2_i3_fused_0 in T.serial(15): + for ax0, ax1, ax2, ax3 in T.grid(1, 1, 4, 4): + with T.block("input_tile"): + T.where(i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2 < 3891200) + ci = T.axis.spatial(2048, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) // 1900 + ax0) + p = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) % 1900 + ax1) + eps, nu = T.axis.remap("SS", [ax2, ax3]) + T.reads(p0[p // 950, ci, p % 950 // 38 * 2 + eps - 1, p % 38 * 2 + nu - 1]) + T.writes(input_tile_local[ci, p, eps, nu]) + T.block_attr({"schedule_rule":"None"}) + input_tile_local[ci, p, eps, nu] = T.if_then_else(1 <= p % 950 // 38 * 2 + eps and p % 950 // 38 * 2 + eps < 51 and 1 <= p % 38 * 2 + nu and p % 38 * 2 + nu < 76, p0[p // 950, ci, p % 950 // 38 * 2 + eps - 1, p % 38 * 2 + nu - 1], T.float32(0), dtype="float32") + for i0 in T.unroll(4): + for i1 in T.unroll(4): + for i4 in T.unroll(4): + for i5 in T.unroll(4): + with T.block("data_pack"): + T.where((i2_i3_fused_0 * 256 + i2_i3_fused_1) * 1024 + i2_i3_fused_2 < 3891200) + eps, nu = T.axis.remap("SS", [i0, i1]) + ci = T.axis.spatial(2048, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) // 1900) + p = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) % 1900) + r_a, r_b = T.axis.remap("RR", [i4, i5]) + T.reads(input_tile_local[ci, p, r_a, r_b]) + T.writes(data_pack_local[eps, nu, ci, p]) + T.block_attr({"schedule_rule":"conv2d_nchw_winograd_data_pack"}) + with T.init(): + data_pack_local[eps, nu, ci, p] = T.float32(0) + data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + input_tile_local[ci, p, r_a, r_b] * T.Select(r_a % 4 == 3 and eps % 4 == 3, T.float32(1), T.Select(r_a % 4 == 3 and eps % 4 == 2, T.float32(0), T.Select(r_a % 4 == 3 and eps % 4 == 1, T.float32(0), T.Select(r_a % 4 == 3 and eps % 4 == 0, T.float32(0), T.Select(r_a % 4 == 2 and eps % 4 == 3, T.float32(0), T.Select(r_a % 4 == 2 and eps % 4 == 2, T.float32(1), T.Select(r_a % 4 == 2 and eps % 4 == 1, T.float32(1), T.Select(r_a % 4 == 2 and eps % 4 == 0, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 3, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 2, T.float32(1), T.Select(r_a % 4 == 1 and eps % 4 == 1, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 0, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 3, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 2, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 1, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) * T.Select(r_b % 4 == 3 and nu % 4 == 3, T.float32(1), T.Select(r_b % 4 == 3 and nu % 4 == 2, T.float32(0), T.Select(r_b % 4 == 3 and nu % 4 == 1, T.float32(0), T.Select(r_b % 4 == 3 and nu % 4 == 0, T.float32(0), T.Select(r_b % 4 == 2 and nu % 4 == 3, T.float32(0), T.Select(r_b % 4 == 2 and nu % 4 == 2, T.float32(1), T.Select(r_b % 4 == 2 and nu % 4 == 1, T.float32(1), T.Select(r_b % 4 == 2 
and nu % 4 == 0, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 3, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 2, T.float32(1), T.Select(r_b % 4 == 1 and nu % 4 == 1, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 0, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 3, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 2, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 1, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) + for ax0, ax1, ax2, ax3 in T.grid(4, 4, 1, 1): + with T.block("data_pack_local"): + T.where(i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2 < 3891200) + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(2048, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) // 1900 + ax2) + v3 = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) % 1900 + ax3) + T.reads(data_pack_local[v0, v1, v2, v3]) + T.writes(data_pack[v0, v1, v2, v3]) + data_pack[v0, v1, v2, v3] = data_pack_local[v0, v1, v2, v3] + for i0_0_i1_0_i2_0_i3_0_fused in T.thread_binding(24320, thread="blockIdx.x"): + for i0_1_i1_1_i2_1_i3_1_fused in T.thread_binding(2, thread="vthread.x"): + for i0_2_i1_2_i2_2_i3_2_fused in T.thread_binding(64, thread="threadIdx.x"): + for i4_0 in T.serial(256): + for ax0_ax1_ax2_ax3_fused in T.serial(640): + with T.block("data_pack_shared"): + v0 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused // 12160 * 2 + ax0_ax1_ax2_ax3_fused // 320) + v1 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused % 12160 // 6080 * 2 + ax0_ax1_ax2_ax3_fused % 320 // 160) + v2 = T.axis.spatial(2048, i4_0 * 8 + ax0_ax1_ax2_ax3_fused % 160 // 20) + v3 = T.axis.spatial(1900, i0_0_i1_0_i2_0_i3_0_fused % 95 * 20 + ax0_ax1_ax2_ax3_fused % 20) + T.reads(data_pack[v0, v1, v2, v3]) + T.writes(data_pack_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch":1}) + data_pack_shared[v0, v1, v2, v3] = data_pack[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(1024): + with T.block("p1_shared"): + v0 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused // 12160 * 2 + ax0_ax1_ax2_ax3_fused // 512) + v1 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused % 12160 // 6080 * 2 + ax0_ax1_ax2_ax3_fused % 512 // 256) + v2 = T.axis.spatial(2048, i4_0 * 8 + ax0_ax1_ax2_ax3_fused % 256 // 32) + v3 = T.axis.spatial(2048, i0_0_i1_0_i2_0_i3_0_fused % 6080 // 95 * 32 + ax0_ax1_ax2_ax3_fused % 32) + T.reads(p1[v0, v1, v2, v3]) + T.writes(p1_shared[v0, v1, v2, v3]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + p1_shared[v0, v1, v2, v3] = p1[v0, v1, v2, v3] + for i4_1, i0_3, i1_3, i2_3, i3_3, i4_2, i0_4, i1_4, i2_4, i3_4 in T.grid(1, 1, 2, 1, 1, 8, 1, 1, 2, 5): + with T.block("bgemm"): + eps = T.axis.spatial(4, i0_4 + i0_0_i1_0_i2_0_i3_0_fused // 12160 * 2 + i0_2_i1_2_i2_2_i3_2_fused // 32 + i0_3) + nu = T.axis.spatial(4, i1_4 + i0_0_i1_0_i2_0_i3_0_fused % 12160 // 6080 * 2 + i1_3) + co = T.axis.spatial(2048, i0_0_i1_0_i2_0_i3_0_fused % 6080 // 95 * 32 + i0_1_i1_1_i2_1_i3_1_fused * 16 + i0_2_i1_2_i2_2_i3_2_fused % 32 // 4 * 2 + i2_3 * 2 + i2_4) + p = T.axis.spatial(1900, i0_0_i1_0_i2_0_i3_0_fused % 95 * 20 + i0_2_i1_2_i2_2_i3_2_fused % 4 * 5 + i3_3 * 5 + i3_4) + ci = T.axis.reduce(2048, i4_0 * 8 + i4_1 * 8 + i4_2) + T.reads(data_pack_shared[eps, nu, ci, p], p1_shared[eps, nu, ci, co]) + T.writes(bgemm_local[eps, nu, co, p]) + T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"}) + with 
T.init(): + bgemm_local[eps, nu, co, p] = T.float32(0) + bgemm_local[eps, nu, co, p] = bgemm_local[eps, nu, co, p] + data_pack_shared[eps, nu, ci, p] * p1_shared[eps, nu, ci, co] + for ax0, ax1, ax2, ax3 in T.grid(1, 2, 2, 5): + with T.block("bgemm_local"): + v0 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused // 12160 * 2 + i0_2_i1_2_i2_2_i3_2_fused // 32 + ax0) + v1 = T.axis.spatial(4, i0_0_i1_0_i2_0_i3_0_fused % 12160 // 6080 * 2 + ax1) + v2 = T.axis.spatial(2048, i0_0_i1_0_i2_0_i3_0_fused % 6080 // 95 * 32 + i0_1_i1_1_i2_1_i3_1_fused * 16 + i0_2_i1_2_i2_2_i3_2_fused % 32 // 4 * 2 + ax2) + v3 = T.axis.spatial(1900, i0_0_i1_0_i2_0_i3_0_fused % 95 * 20 + i0_2_i1_2_i2_2_i3_2_fused % 4 * 5 + ax3) + T.reads(bgemm_local[v0, v1, v2, v3]) + T.writes(bgemm[v0, v1, v2, v3]) + bgemm[v0, v1, v2, v3] = bgemm_local[v0, v1, v2, v3] + for i0, i1, i2_0, i3_0, ax0, ax1 in T.grid(2, 2048, 25, 38, 1, 1): + for ax2 in T.unroll(2): + for ax3 in T.unroll(2): + for ax4 in T.unroll(4): + for ax5 in T.unroll(4): + with T.block("inverse"): + co = T.axis.spatial(2048, i1 + ax0) + p = T.axis.spatial(1900, i0 * 950 + i2_0 * 38 + i3_0 + ax1) + vh, vw, r_a, r_b = T.axis.remap("SSRR", [ax2, ax3, ax4, ax5]) + T.reads(bgemm[r_a, r_b, co, p]) + T.writes(inverse_local[co, p, vh, vw]) + T.block_attr({"schedule_rule":"conv2d_nchw_winograd_inverse"}) + with T.init(): + inverse_local[co, p, vh, vw] = T.float32(0) + inverse_local[co, p, vh, vw] = inverse_local[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * T.Select(r_a % 4 == 3 and vh % 2 == 1, T.float32(1), T.Select(r_a % 4 == 3 and vh % 2 == 0, T.float32(0), T.Select(r_a % 4 == 2 and vh % 2 == 1, T.float32(1), T.Select(r_a % 4 == 2 and vh % 2 == 0, T.float32(1), T.Select(r_a % 4 == 1 and vh % 2 == 1, T.float32(-1), T.Select(r_a % 4 == 1 and vh % 2 == 0, T.float32(1), T.Select(r_a % 4 == 0 and vh % 2 == 1, T.float32(0), T.Select(r_a % 4 == 0 and vh % 2 == 0, T.float32(1), T.float32(0))))))))) * T.Select(r_b % 4 == 3 and vw % 2 == 1, T.float32(1), T.Select(r_b % 4 == 3 and vw % 2 == 0, T.float32(0), T.Select(r_b % 4 == 2 and vw % 2 == 1, T.float32(1), T.Select(r_b % 4 == 2 and vw % 2 == 0, T.float32(1), T.Select(r_b % 4 == 1 and vw % 2 == 1, T.float32(-1), T.Select(r_b % 4 == 1 and vw % 2 == 0, T.float32(1), T.Select(r_b % 4 == 0 and vw % 2 == 1, T.float32(0), T.Select(r_b % 4 == 0 and vw % 2 == 0, T.float32(1), T.float32(0))))))))) + for i0_i1_i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"): + for i0_i1_i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): + for i0_i1_i2_i3_fused_0 in T.serial(59): + with T.block("T_add"): + T.where((i0_i1_i2_i3_fused_0 * 256 + i0_i1_i2_i3_fused_1) * 1024 + i0_i1_i2_i3_fused_2 < 15360000) + ax0 = T.axis.spatial(2, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) // 7680000) + ax1 = T.axis.spatial(2048, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) % 7680000 // 3750) + ax2 = T.axis.spatial(50, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) % 3750 // 75) + ax3 = T.axis.spatial(75, (i0_i1_i2_i3_fused_0 * 262144 + i0_i1_i2_i3_fused_1 * 1024 + i0_i1_i2_i3_fused_2) % 75) + T.reads(inverse_local[ax1, ax0 * 950 + ax2 // 2 * 38 + ax3 // 2, ax2 % 2, ax3 % 2], p2[0, ax1, 0, 0]) + T.writes(T_relu[ax0, ax1, ax2, ax3]) + T_relu[ax0, ax1, ax2, ax3] = T.max(inverse_local[ax1, ax0 * 950 + ax2 // 2 * 38 + ax3 // 2, ax2 % 2, ax3 % 2] + p2[0, ax1, 0, 0], T.float32(0)) + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [2, 1, 2, 1, 1]), + 
("SamplePerfectTile", [2, 1, 1, 2, 1]), + ("SamplePerfectTile", [64, 2, 8, 1, 2]), + ("SamplePerfectTile", [95, 1, 4, 1, 5]), + ("SamplePerfectTile", [256, 1, 8]), + ("SampleCategorical", 0), + ("SampleCategorical", 3), + ("SampleCategorical", 4), + ] + with _target(): + mod = nchw_add_relu + actual = _design_space(mod) + check_sketches( + mod, + sketches=actual, + expected_mods=[nchw_add_relu_scheduled], + expected_decisions=[decision_0], + debug_mask=0, + ) + + if __name__ == "__main__": test_cuda_nhwc() test_cuda_nchw() + test_cuda_nchw_add_relu() diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index bd5500779611..37eb43058487 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -33,94 +33,6 @@ logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) # fmt: off -@tvm.script.ir_module -class WinogradConv2d: - @T.prim_func - def main(p0: T.Buffer[(2, 2048, 50, 75), "float32"], p1: T.Buffer[(4, 4, 2048, 2048), "float32"], p2: T.Buffer[(1, 2048, 1, 1), "float32"], T_relu: T.Buffer[(2, 2048, 50, 75), "float32"]): - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True, "layout_free_buffers": [1]}) - # body - # with T.block("root") - data_pad = T.alloc_buffer([2, 2048, 52, 77], dtype="float32") - input_tile = T.alloc_buffer([2048, 1900, 4, 4], dtype="float32") - B = T.alloc_buffer([4, 4], dtype="float32") - data_pack = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") - bgemm = T.alloc_buffer([4, 4, 2048, 1900], dtype="float32") - A = T.alloc_buffer([4, 2], dtype="float32") - inverse = T.alloc_buffer([2048, 1900, 2, 2], dtype="float32") - conv2d_winograd = T.alloc_buffer([2, 2048, 50, 75], dtype="float32") - T_add = T.alloc_buffer([2, 2048, 50, 75], dtype="float32") - for i0, i1, i2, i3 in T.grid(2, 2048, 52, 77): - with T.block("data_pad"): - i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1]) - T.writes(data_pad[i0_1, i1_1, i2_1, i3_1]) - data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i2_1 and i2_1 < 51 and 1 <= i3_1 and i3_1 < 76, p0[i0_1, i1_1, i2_1 - 1, i3_1 - 1], T.float32(0), dtype="float32") - for i0, i1, i2, i3 in T.grid(2048, 1900, 4, 4): - with T.block("input_tile"): - ci, p, eps, nu = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(data_pad[p // 950, ci, p % 950 // 38 * 2 + eps, p % 38 * 2 + nu]) - T.writes(input_tile[ci, p, eps, nu]) - T.block_attr({"schedule_rule":"None"}) - input_tile[ci, p, eps, nu] = data_pad[p // 950, ci, p % 950 // 38 * 2 + eps, p % 38 * 2 + nu] - for i0, i1 in T.grid(4, 4): - with T.block("B"): - i, j = T.axis.remap("SS", [i0, i1]) - T.reads() - T.writes(B[i, j]) - T.block_attr({"schedule_rule":"None"}) - B[i, j] = T.Select(i % 4 == 3 and j % 4 == 3, T.float32(1), T.Select(i % 4 == 3 and j % 4 == 2, T.float32(0), T.Select(i % 4 == 3 and j % 4 == 1, T.float32(0), T.Select(i % 4 == 3 and j % 4 == 0, T.float32(0), T.Select(i % 4 == 2 and j % 4 == 3, T.float32(0), T.Select(i % 4 == 2 and j % 4 == 2, T.float32(1), T.Select(i % 4 == 2 and j % 4 == 1, T.float32(1), T.Select(i % 4 == 2 and j % 4 == 0, T.float32(-1), T.Select(i % 4 == 1 and j % 4 == 3, T.float32(-1), T.Select(i % 4 == 1 and j % 4 == 2, T.float32(1), T.Select(i % 4 == 1 and j % 4 == 1, T.float32(-1), T.Select(i % 4 == 1 and j % 4 == 0, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 3, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 2, T.float32(0), 
T.Select(i % 4 == 0 and j % 4 == 1, T.float32(0), T.Select(i % 4 == 0 and j % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) - for i0, i1, i2, i3, i4, i5 in T.grid(4, 4, 2048, 1900, 4, 4): - with T.block("data_pack"): - eps, nu, ci, p, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads(input_tile[ci, p, r_a, r_b], B[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(eps, nu) : T.max(eps, nu) + 1]) - T.writes(data_pack[eps, nu, ci, p]) - T.block_attr({"schedule_rule":"conv2d_nchw_winograd_data_pack"}) - with T.init(): - data_pack[eps, nu, ci, p] = T.float32(0) - data_pack[eps, nu, ci, p] = data_pack[eps, nu, ci, p] + input_tile[ci, p, r_a, r_b] * B[r_a, eps] * B[r_b, nu] - for i0, i1, i2, i3, i4 in T.grid(4, 4, 2048, 1900, 2048): - with T.block("bgemm"): - eps, nu, co, p, ci = T.axis.remap("SSSSR", [i0, i1, i2, i3, i4]) - T.reads(data_pack[eps, nu, ci, p], p1[eps, nu, ci, co]) - T.writes(bgemm[eps, nu, co, p]) - with T.init(): - bgemm[eps, nu, co, p] = T.float32(0) - bgemm[eps, nu, co, p] = bgemm[eps, nu, co, p] + data_pack[eps, nu, ci, p] * p1[eps, nu, ci, co] - for i0, i1 in T.grid(4, 2): - with T.block("A"): - i, j = T.axis.remap("SS", [i0, i1]) - T.reads() - T.writes(A[i, j]) - T.block_attr({"schedule_rule":"None"}) - A[i, j] = T.Select(i % 4 == 3 and j % 2 == 1, T.float32(1), T.Select(i % 4 == 3 and j % 2 == 0, T.float32(0), T.Select(i % 4 == 2 and j % 2 == 1, T.float32(1), T.Select(i % 4 == 2 and j % 2 == 0, T.float32(1), T.Select(i % 4 == 1 and j % 2 == 1, T.float32(-1), T.Select(i % 4 == 1 and j % 2 == 0, T.float32(1), T.Select(i % 4 == 0 and j % 2 == 1, T.float32(0), T.Select(i % 4 == 0 and j % 2 == 0, T.float32(1), T.float32(0))))))))) - for i0, i1, i2, i3, i4, i5 in T.grid(2048, 1900, 2, 2, 4, 4): - with T.block("inverse"): - co, p, vh, vw, r_a, r_b = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5]) - T.reads(bgemm[r_a, r_b, co, p], A[T.min(r_a, r_b) : T.max(r_a, r_b) + 1, T.min(vh, vw) : T.max(vh, vw) + 1]) - T.writes(inverse[co, p, vh, vw]) - T.block_attr({"schedule_rule":"conv2d_nchw_winograd_inverse"}) - with T.init(): - inverse[co, p, vh, vw] = T.float32(0) - inverse[co, p, vh, vw] = inverse[co, p, vh, vw] + bgemm[r_a, r_b, co, p] * A[r_a, vh] * A[r_b, vw] - for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): - with T.block("conv2d_winograd"): - n, co, h, w = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(inverse[co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2]) - T.writes(conv2d_winograd[n, co, h, w]) - conv2d_winograd[n, co, h, w] = inverse[co, n * 950 + h // 2 * 38 + w // 2, h % 2, w % 2] - for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): - with T.block("T_add"): - ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(conv2d_winograd[ax0, ax1, ax2, ax3], p2[0, ax1, 0, 0]) - T.writes(T_add[ax0, ax1, ax2, ax3]) - T_add[ax0, ax1, ax2, ax3] = conv2d_winograd[ax0, ax1, ax2, ax3] + p2[0, ax1, 0, 0] - for i0, i1, i2, i3 in T.grid(2, 2048, 50, 75): - with T.block("T_relu"): - ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) - T.reads(T_add[ax0, ax1, ax2, ax3]) - T.writes(T_relu[ax0, ax1, ax2, ax3]) - T_relu[ax0, ax1, ax2, ax3] = T.max(T_add[ax0, ax1, ax2, ax3], T.float32(0)) @T.prim_func @@ -285,4 +197,3 @@ def test_tune_winograd_conv2d_cuda(): test_tune_matmul_cuda() test_tune_run_module_via_rpc() test_tune_block_cpu() - test_tune_winograd_conv2d_cuda() From 9ff9f0e96b65950d3a98e36445383675ed92cee1 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 17 Nov 2022 20:37:47 -0800 Subject: [PATCH 8/8] Remove change on test tir. 
--- .../python/unittest/test_meta_schedule_tune_tir.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index 37eb43058487..aa45120c2316 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -32,8 +32,6 @@ logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) -# fmt: off - @T.prim_func def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: @@ -61,7 +59,6 @@ def two_step(a: T.handle, c: T.handle) -> None: with T.block("B"): vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 3.0 -# fmt: on @tvm.testing.requires_llvm @@ -181,17 +178,6 @@ def clone(self) -> "RemoveBlock": sch.trace.show() -@pytest.mark.skip(reason="slow test and requires rtx-3070") -def test_tune_winograd_conv2d_cuda(): - mod = WinogradConv2d - with tempfile.TemporaryDirectory() as work_dir: - database = ms.tune_tir( - mod, target="nvidia/geforce-rtx-3070", max_trials_global=10, work_dir=work_dir - ) - records = database.get_top_k(database.commit_workload(mod), 1) - assert len(records) == 1, "No valid schedule found!" - - if __name__ == """__main__""": test_tune_matmul_cpu() test_tune_matmul_cuda()