From fda7b035c76d5806e5ced0fb34901f01eec9f7ed Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 30 Aug 2022 05:09:26 -0700 Subject: [PATCH 1/2] Complete winograd scheduling. --- python/tvm/topi/cuda/conv2d_winograd.py | 1 + src/meta_schedule/schedule_rule/winograd.cc | 29 ++++++++++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/conv2d_winograd.py b/python/tvm/topi/cuda/conv2d_winograd.py index f5e6cd88a5e3..239d05844b40 100644 --- a/python/tvm/topi/cuda/conv2d_winograd.py +++ b/python/tvm/topi/cuda/conv2d_winograd.py @@ -104,6 +104,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ kernel[co][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw] ), name="kernel_pack", + attrs={"schedule_rule": "meta_schedule.winograd_kernel_pack.nchw.cuda"}, ) else: kernel_pack = kernel diff --git a/src/meta_schedule/schedule_rule/winograd.cc b/src/meta_schedule/schedule_rule/winograd.cc index 8ae8118731dd..22e2300d63b6 100644 --- a/src/meta_schedule/schedule_rule/winograd.cc +++ b/src/meta_schedule/schedule_rule/winograd.cc @@ -185,6 +185,32 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.nchw.cuda") return {sch}; }); +TVM_REGISTER_GLOBAL("meta_schedule.winograd_kernel_pack.nchw.cuda") + .set_body_typed([](Schedule sch, BlockRV kernel_pack) -> Array { + Array loops = sch->GetLoops(kernel_pack); + ICHECK_EQ(loops.size(), 6); + if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[0]))) { + if (*i <= 16) { + sch->Unroll(loops[0]); + } + } + if (const int64_t* i = tir::GetLoopIntExtent(sch->GetSRef(loops[1]))) { + if (*i <= 16) { + sch->Unroll(loops[1]); + } + } + sch->Unroll(loops[4]); + sch->Unroll(loops[5]); + + LoopRV fused = sch->Fuse({loops[2], loops[3]}); + + int64_t max_threadblocks = 256; + int64_t max_threads_per_block = 1024; + auto get_factor = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024}); + BindBlockThreadIdx(sch, kernel_pack, max_threadblocks, max_threads_per_block, get_factor); + return {sch}; + }); + TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cuda") .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { BlockRV input_tile = GetOnlyProducer(sch, data_pack); @@ -206,9 +232,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.nchw.cuda") BlockRV data_pad = GetOnlyProducer(sch, input_tile); BlockRV data_l = sch->CacheWrite(data_pack, /*buffer_index=*/0, /*storage_scope=*/"local"); + BlockRV d = sch->CacheRead(data_pack, /*buffer_index=*/0, /*storage_scope=*/"local"); LoopRV loop = ScheduleDataPackNCHW(sch, data_pack); sch->ReverseComputeAt(data_l, loop, /*preserve_unit_loops=*/true); - sch->ComputeAt(input_tile, /*loop_rv=*/loop, /*preserve_unit_loops=*/true); + sch->ComputeAt(d, /*loop_rv=*/loop, /*preserve_unit_loops=*/true); sch->ComputeInline(data_pad); int64_t max_threadblocks = 256; From bbe5c05d8073aa491e9db73628382713d5bb1700 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Tue, 30 Aug 2022 15:44:34 -0700 Subject: [PATCH 2/2] Fix test. --- .../unittest/test_meta_schedule_space_cuda.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_space_cuda.py b/tests/python/unittest/test_meta_schedule_space_cuda.py index ce333887ec83..ffa2b57ba8ec 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda.py @@ -1338,11 +1338,20 @@ def winograd_nchw_conv2d(data: T.Buffer[(1, 64, 224, 224), "float32"], kernel: T bgemm = T.alloc_buffer([6, 6, 64, 3136], dtype="float32") inverse_local = T.alloc_buffer([64, 3136, 4, 4], dtype="float32", scope="local") data_pack_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local") + d_local = T.alloc_buffer([64, 3136, 6, 6], dtype="float32", scope="local") bgemm_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local") kernel_shared = T.alloc_buffer([6, 6, 64, 64], dtype="float32", scope="shared") data_pack_shared = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="shared") for i2_i3_0_fused_i3_1_fused_0 in T.thread_binding(3136, thread="blockIdx.x"): for i2_i3_0_fused_i3_1_fused_1 in T.thread_binding(64, thread="threadIdx.x"): + for ax0, ax1, ax2, ax3 in T.grid(1, 1, 6, 6): + with T.block("d_local"): + v0 = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136 + ax0) + v1 = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 7 * 7 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 7 + ax1) + v2, v3 = T.axis.remap("SS", [ax2, ax3]) + T.reads(data[v1 // 3136, v0, v1 % 3136 // 56 * 4 + v2 - 1, v1 % 56 * 4 + v3 - 1]) + T.writes(d_local[v0, v1, v2, v3]) + d_local[v0, v1, v2, v3] = T.if_then_else(1 <= v1 % 3136 // 56 * 4 + v2 and v1 % 3136 // 56 * 4 + v2 < 225 and 1 <= v1 % 56 * 4 + v3 and v1 % 56 * 4 + v3 < 225, data[v1 // 3136, v0, v1 % 3136 // 56 * 4 + v2 - 1, v1 % 56 * 4 + v3 - 1], T.float32(0), dtype="float32") for i0 in T.unroll(6): for i1 in T.unroll(6): for i4 in T.unroll(6): @@ -1352,12 +1361,12 @@ def winograd_nchw_conv2d(data: T.Buffer[(1, 64, 224, 224), "float32"], kernel: T ci = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136) p = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 7 * 7 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 7) r_a, r_a_1 = T.axis.remap("RR", [i4, i5]) - T.reads(data[p // 3136, ci, p % 3136 // 56 * 4 + r_a - 1, p % 56 * 4 + r_a_1 - 1]) + T.reads(d_local[ci, p, r_a, r_a_1]) T.writes(data_pack_local[eps, nu, ci, p]) T.block_attr({"schedule_rule":"meta_schedule.winograd_data_pack.nchw.cuda"}) with T.init(): data_pack_local[eps, nu, ci, p] = T.float32(0) - data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + T.if_then_else(1 <= p % 3136 // 56 * 4 + r_a and p % 3136 // 56 * 4 + r_a < 225 and 1 <= p % 56 * 4 + r_a_1 and p % 56 * 4 + r_a_1 < 225, data[p // 3136, ci, p % 3136 // 56 * 4 + r_a - 1, p % 56 * 4 + r_a_1 - 1], T.float32(0), dtype="float32") * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_a_1 % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_a_1 % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_a_1 % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_a_1 % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_a_1 % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_a_1 % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + d_local[ci, p, r_a, r_a_1] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_a_1 % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_a_1 % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_a_1 % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_a_1 % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_a_1 % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_a_1 % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): with T.block("data_pack_local"): v0, v1 = T.axis.remap("SS", [ax0, ax1])