Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/conv2d_winograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_
kernel[co][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
),
name="kernel_pack",
attrs={"schedule_rule": "meta_schedule.winograd_kernel_pack.nchw.cuda"},
)
else:
kernel_pack = kernel
Expand Down
29 changes: 28 additions & 1 deletion src/meta_schedule/schedule_rule/winograd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,32 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse.nchw.cuda")
return {sch};
});

TVM_REGISTER_GLOBAL("meta_schedule.winograd_kernel_pack.nchw.cuda")
    .set_body_typed([](Schedule sch, BlockRV kernel_pack) -> Array<Schedule> {
      // Schedule rule registered under "meta_schedule.winograd_kernel_pack.nchw.cuda".
      // Takes the schedule and the `kernel_pack` block, mutates the schedule in place,
      // and returns it as a single-element array.
      Array<LoopRV> nest = sch->GetLoops(kernel_pack);
      // The loop indices used below rely on the block being nested in exactly six loops.
      ICHECK_EQ(nest.size(), 6);

      // Unrolls `loop` only when its extent is a known integer that does not exceed the
      // threshold; loops whose extent cannot be read as an integer are left untouched.
      auto unroll_if_small = [&sch](const LoopRV& loop) {
        constexpr int64_t kMaxUnrollExtent = 16;
        const int64_t* extent = tir::GetLoopIntExtent(sch->GetSRef(loop));
        if (extent != nullptr && *extent <= kMaxUnrollExtent) {
          sch->Unroll(loop);
        }
      };
      unroll_if_small(nest[0]);
      unroll_if_small(nest[1]);

      // The two innermost loops are unrolled unconditionally.
      sch->Unroll(nest[4]);
      sch->Unroll(nest[5]);

      // Fuse the two middle loops. The call is made for its effect on the schedule;
      // the returned loop handle is not referenced again within this rule.
      sch->Fuse({nest[2], nest[3]});

      // Limits and candidate factors forwarded to BindBlockThreadIdx.
      const int64_t block_limit = 256;
      const int64_t thread_limit = 1024;
      auto factor_sampler = MakeFactorSampler(sch, {32, 64, 128, 256, 512, 1024});
      BindBlockThreadIdx(sch, kernel_pack, block_limit, thread_limit, factor_sampler);
      return {sch};
    });

TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cuda")
.set_body_typed([](Schedule sch, BlockRV data_pack) -> Array<Schedule> {
BlockRV input_tile = GetOnlyProducer(sch, data_pack);
Expand All @@ -206,9 +232,10 @@ TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.nchw.cuda")
BlockRV data_pad = GetOnlyProducer(sch, input_tile);

BlockRV data_l = sch->CacheWrite(data_pack, /*buffer_index=*/0, /*storage_scope=*/"local");
BlockRV d = sch->CacheRead(data_pack, /*buffer_index=*/0, /*storage_scope=*/"local");
LoopRV loop = ScheduleDataPackNCHW(sch, data_pack);
sch->ReverseComputeAt(data_l, loop, /*preserve_unit_loops=*/true);
sch->ComputeAt(input_tile, /*loop_rv=*/loop, /*preserve_unit_loops=*/true);
sch->ComputeAt(d, /*loop_rv=*/loop, /*preserve_unit_loops=*/true);
sch->ComputeInline(data_pad);

int64_t max_threadblocks = 256;
Expand Down
13 changes: 11 additions & 2 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -1338,11 +1338,20 @@ def winograd_nchw_conv2d(data: T.Buffer[(1, 64, 224, 224), "float32"], kernel: T
bgemm = T.alloc_buffer([6, 6, 64, 3136], dtype="float32")
inverse_local = T.alloc_buffer([64, 3136, 4, 4], dtype="float32", scope="local")
data_pack_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local")
d_local = T.alloc_buffer([64, 3136, 6, 6], dtype="float32", scope="local")
bgemm_local = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="local")
kernel_shared = T.alloc_buffer([6, 6, 64, 64], dtype="float32", scope="shared")
data_pack_shared = T.alloc_buffer([6, 6, 64, 3136], dtype="float32", scope="shared")
for i2_i3_0_fused_i3_1_fused_0 in T.thread_binding(3136, thread="blockIdx.x"):
for i2_i3_0_fused_i3_1_fused_1 in T.thread_binding(64, thread="threadIdx.x"):
for ax0, ax1, ax2, ax3 in T.grid(1, 1, 6, 6):
with T.block("d_local"):
v0 = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136 + ax0)
v1 = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 7 * 7 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 7 + ax1)
v2, v3 = T.axis.remap("SS", [ax2, ax3])
T.reads(data[v1 // 3136, v0, v1 % 3136 // 56 * 4 + v2 - 1, v1 % 56 * 4 + v3 - 1])
T.writes(d_local[v0, v1, v2, v3])
d_local[v0, v1, v2, v3] = T.if_then_else(1 <= v1 % 3136 // 56 * 4 + v2 and v1 % 3136 // 56 * 4 + v2 < 225 and 1 <= v1 % 56 * 4 + v3 and v1 % 56 * 4 + v3 < 225, data[v1 // 3136, v0, v1 % 3136 // 56 * 4 + v2 - 1, v1 % 56 * 4 + v3 - 1], T.float32(0), dtype="float32")
for i0 in T.unroll(6):
for i1 in T.unroll(6):
for i4 in T.unroll(6):
Expand All @@ -1352,12 +1361,12 @@ def winograd_nchw_conv2d(data: T.Buffer[(1, 64, 224, 224), "float32"], kernel: T
ci = T.axis.spatial(64, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) // 3136)
p = T.axis.spatial(3136, (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 3136 // 7 * 7 + (i2_i3_0_fused_i3_1_fused_0 * 64 + i2_i3_0_fused_i3_1_fused_1) % 7)
r_a, r_a_1 = T.axis.remap("RR", [i4, i5])
T.reads(data[p // 3136, ci, p % 3136 // 56 * 4 + r_a - 1, p % 56 * 4 + r_a_1 - 1])
T.reads(d_local[ci, p, r_a, r_a_1])
T.writes(data_pack_local[eps, nu, ci, p])
T.block_attr({"schedule_rule":"meta_schedule.winograd_data_pack.nchw.cuda"})
with T.init():
data_pack_local[eps, nu, ci, p] = T.float32(0)
data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + T.if_then_else(1 <= p % 3136 // 56 * 4 + r_a and p % 3136 // 56 * 4 + r_a < 225 and 1 <= p % 56 * 4 + r_a_1 and p % 56 * 4 + r_a_1 < 225, data[p // 3136, ci, p % 3136 // 56 * 4 + r_a - 1, p % 56 * 4 + r_a_1 - 1], T.float32(0), dtype="float32") * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 
== 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_a_1 % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_a_1 % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_a_1 % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_a_1 % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_a_1 % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 
== 0, T.float32(-1.5), T.Select(r_a_1 % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))
data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + d_local[ci, p, r_a, r_a_1] * T.Select(r_a % 6 == 5 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 5 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 1, T.float32(0), T.Select(r_a % 6 == 5 and eps % 6 == 0, T.float32(0), T.Select(r_a % 6 == 4 and eps % 6 == 5, T.float32(1.5), T.Select(r_a % 6 == 4 and eps % 6 == 4, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 3, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 2, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 4 and eps % 6 == 0, T.float32(1), T.Select(r_a % 6 == 3 and eps % 6 == 5, T.float32(-2), T.Select(r_a % 6 == 3 and eps % 6 == 4, T.float32(-0.5), T.Select(r_a % 6 == 3 and eps % 6 == 3, T.float32(2), T.Select(r_a % 6 == 3 and eps % 6 == 2, T.float32(2.5), T.Select(r_a % 6 == 3 and eps % 6 == 1, T.float32(0.5), T.Select(r_a % 6 == 3 and eps % 6 == 0, T.float32(1.5), T.Select(r_a % 6 == 2 and eps % 6 == 5, T.float32(-1.5), T.Select(r_a % 6 == 2 and eps % 6 == 4, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 3, T.float32(-1), T.Select(r_a % 6 == 2 and eps % 6 == 2, T.float32(0.5), T.Select(r_a % 6 == 2 and eps % 6 == 1, T.float32(-2.5), T.Select(r_a % 6 == 2 and eps % 6 == 0, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 5, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 4, T.float32(0.5), T.Select(r_a % 6 == 1 and eps % 6 == 3, T.float32(-2), T.Select(r_a % 6 == 1 and eps % 6 == 2, T.float32(-1), T.Select(r_a % 6 == 1 and eps % 6 == 1, T.float32(1), T.Select(r_a % 6 == 1 and eps % 6 == 0, T.float32(-1.5), T.Select(r_a % 6 == 0 and eps % 6 == 5, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 4, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 3, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 2, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 
1, T.float32(0), T.Select(r_a % 6 == 0 and eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(r_a_1 % 6 == 5 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 5 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 5 and nu % 6 == 0, T.float32(0), T.Select(r_a_1 % 6 == 4 and nu % 6 == 5, T.float32(1.5), T.Select(r_a_1 % 6 == 4 and nu % 6 == 4, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 3, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 2, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 4 and nu % 6 == 0, T.float32(1), T.Select(r_a_1 % 6 == 3 and nu % 6 == 5, T.float32(-2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 4, T.float32(-0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 3, T.float32(2), T.Select(r_a_1 % 6 == 3 and nu % 6 == 2, T.float32(2.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 1, T.float32(0.5), T.Select(r_a_1 % 6 == 3 and nu % 6 == 0, T.float32(1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 5, T.float32(-1.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 4, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 3, T.float32(-1), T.Select(r_a_1 % 6 == 2 and nu % 6 == 2, T.float32(0.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 1, T.float32(-2.5), T.Select(r_a_1 % 6 == 2 and nu % 6 == 0, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 5, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 4, T.float32(0.5), T.Select(r_a_1 % 6 == 1 and nu % 6 == 3, T.float32(-2), T.Select(r_a_1 % 6 == 1 and nu % 6 == 2, T.float32(-1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 1, T.float32(1), T.Select(r_a_1 % 6 == 1 and nu % 6 == 0, T.float32(-1.5), T.Select(r_a_1 % 6 == 0 and nu % 6 == 5, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 4, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 3, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 
6 == 2, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 1, T.float32(0), T.Select(r_a_1 % 6 == 0 and nu % 6 == 0, T.float32(1), T.float32(0)))))))))))))))))))))))))))))))))))))
for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1):
with T.block("data_pack_local"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
Expand Down