From 918134d8067e391873b2e413a31090b3f76225bd Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Tue, 25 Apr 2023 17:29:45 +0300 Subject: [PATCH 1/2] [MetaSchedule] skip custom_rule == None attribute Signed-off-by: Alexander Peskov --- .../space_generator/post_order_apply.cc | 3 +- ...chedule_schedule_rule_apply_custom_rule.py | 34 +++++++++++++++---- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 491af6e28f77..1495fc604e8a 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -142,7 +142,8 @@ class PostOrderApplyNode : public SpaceGeneratorNode { continue; } if (!ScheduleRule::IsApplyCustomRule(sch_rule)) { - if (tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule").defined()) { + auto sch_rule = tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule"); + if (sch_rule.defined() && sch_rule.value() != "None") { stack.emplace_back(sch, blocks); continue; } diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py b/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py index 2bfa3070d1b4..811dfadad714 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_apply_custom_rule.py @@ -14,33 +14,35 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +# pylint: disable=invalid-name,missing-module-docstring,missing-function-docstring,missing-class-docstring from typing import List import tempfile import pytest import tvm from tvm import meta_schedule as ms -from tvm.meta_schedule.schedule_rule import ApplyCustomRule +from tvm.meta_schedule.schedule_rule import ApplyCustomRule, MultiLevelTiling +from tvm.meta_schedule.testing.space_generation import generate_design_space from tvm.script import tir as T -@tvm.script.ir_module -class Matmul: +def create_matmul(rule_name: str): @T.prim_func - def main(a: T.handle, b: T.handle, c: T.handle) -> None: + def func(a: T.handle, b: T.handle, c: T.handle) -> None: T.func_attr({"global_symbol": "main"}) A = T.match_buffer(a, (1024, 1024), "float32") B = T.match_buffer(b, (1024, 1024), "float32") C = T.match_buffer(c, (1024, 1024), "float32") for i, j, k in T.grid(1024, 1024, 1024): with T.block("matmul"): - T.block_attr({"schedule_rule": "test_apply_custom_rule"}) + T.block_attr({"schedule_rule": rule_name}) vi, vj, vk = T.axis.remap("SSR", [i, j, k]) with T.init(): C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + return func + @tvm.register_func("meta_schedule.cpu.test_apply_custom_rule") def sch_fn(sch: tvm.tir.Schedule, block: tvm.tir.Block) -> List[tvm.tir.Schedule]: @@ -53,7 +55,7 @@ def test_custom_rule(): sch_rules = [ApplyCustomRule()] space_gen = ms.space_generator.PostOrderApply(sch_rules=sch_rules) ms.tune_tir( - mod=Matmul, + mod=create_matmul(rule_name="test_apply_custom_rule"), target="llvm -num-cores=1", work_dir=tmpdir, max_trials_global=10, @@ -62,5 +64,23 @@ def test_custom_rule(): assert "ValueError: Intended for meta_schedule.cpu.test_apply_custom_rule" in str(e_info.value) +def test_custom_rule_with_none(): + """Should ignore custom_rule and apply MultiLevelTiling""" + schs = generate_design_space( + "llvm", + mod=create_matmul(rule_name="None"), + target=tvm.target.Target("llvm -num-cores=1"), + types=None, + sch_rules=[ApplyCustomRule(), MultiLevelTiling("SSR")], + ) + assert len(schs) == 1 + tiling_annotations = [ + inst + for inst in schs[0].trace.insts + if inst.kind.name == "Annotate" and inst.attrs[0] == "meta_schedule.tiling_structure" + ] + assert len(tiling_annotations) == 1, "Tiling rule was not applied" + + if __name__ == "__main__": test_custom_rule() From dadc794cdf7ef450f30e15c3584cb1e7f2f8d5fa Mon Sep 17 00:00:00 2001 From: Alexander Peskov Date: Fri, 28 Apr 2023 14:57:18 +0300 Subject: [PATCH 2/2] Fix tests with hardcoded prim funcs Signed-off-by: Alexander Peskov --- .../test_meta_schedule_space_cuda_winograd.py | 45 +++---------------- 1 file changed, 6 insertions(+), 39 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py index 27fe47ab8699..8a29820f2d05 100644 --- a/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py +++ b/tests/python/unittest/test_meta_schedule_space_cuda_winograd.py @@ -48,7 +48,6 @@ def cuda_nhwc_0(data: T.Buffer((1, 14, 14, 128), "float32"), weight: T.Buffer((6 T.reads() T.writes() T.block_attr({"meta_schedule.unroll_explicit": 16}) - input_tile_local = T.alloc_buffer((6, 6, 9, 128), scope="local") data_pack = T.alloc_buffer((6, 6, 9, 128)) bgemm = T.alloc_buffer((6, 6, 9, 128)) inverse = T.alloc_buffer((4, 4, 9, 128)) @@ -58,16 +57,6 @@ def cuda_nhwc_0(data: T.Buffer((1, 14, 14, 128), "float32"), weight: T.Buffer((6 weight_shared = T.alloc_buffer((6, 6, 128, 128), scope="shared") for p_0_ci_0_p_1_ci_1_fused_0 in T.thread_binding(2, thread="blockIdx.x"): for p_0_ci_0_p_1_ci_1_fused_1 in T.thread_binding(1024, thread="threadIdx.x"): - for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): - with T.block("input_tile"): - v_eps, v_nu = T.axis.remap("SS", [ax0, ax1]) - v_p = T.axis.spatial(9, (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) // 384 * 3 + (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 24 // 8 + ax2) - v_ci = T.axis.spatial(128, (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 384 // 24 * 8 + (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 8 + ax3) - T.where(p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1 < 1152) - T.reads(data[v_p // 9, v_p % 9 // 3 * 4 + v_eps, v_p % 3 * 4 + v_nu, v_ci]) - T.writes(input_tile_local[v_eps, v_nu, v_p, v_ci]) - T.block_attr({"schedule_rule": "None"}) - input_tile_local[v_eps, v_nu, v_p, v_ci] = T.if_then_else(0 <= v_p % 9 // 3 * 4 + v_eps and v_p % 9 // 3 * 4 + v_eps < 14 and 0 <= v_p % 3 * 4 + v_nu and v_p % 3 * 4 + v_nu < 14, data[v_p // 9, v_p % 9 // 3 * 4 + v_eps, v_p % 3 * 4 + v_nu, v_ci], T.float32(0)) for eps in T.unroll(6): for nu in T.unroll(6): for r_a in T.unroll(6): @@ -78,12 +67,12 @@ def cuda_nhwc_0(data: T.Buffer((1, 14, 14, 128), "float32"), weight: T.Buffer((6 v_ci = T.axis.spatial(128, (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 384 // 24 * 8 + (p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1) % 8) v_r_a, v_r_b = T.axis.remap("RR", [r_a, r_b]) T.where(p_0_ci_0_p_1_ci_1_fused_0 * 1024 + p_0_ci_0_p_1_ci_1_fused_1 < 1152) - T.reads(input_tile_local[v_r_a, v_r_b, v_p, v_ci]) + T.reads(data[v_p // 9, v_p % 9 // 3 * 4 + v_r_a, v_p % 3 * 4 + v_r_b, v_ci]) T.writes(data_pack_local[v_eps, v_nu, v_p, v_ci]) T.block_attr({"schedule_rule": "conv2d_nhwc_winograd_data_pack"}) with T.init(): data_pack_local[v_eps, v_nu, v_p, v_ci] = T.float32(0) - data_pack_local[v_eps, v_nu, v_p, v_ci] = data_pack_local[v_eps, v_nu, v_p, v_ci] + input_tile_local[v_r_a, v_r_b, v_p, v_ci] * T.Select(v_r_a % 6 == 5 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 0, T.float32(0), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 5, T.float32(1.5), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 4, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 3, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 2, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 0, T.float32(1), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 5, T.float32(-2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 4, T.float32(-0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 3, T.float32(2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 2, T.float32(2.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 1, T.float32(0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 0, T.float32(1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 5, T.float32(-1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 4, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 3, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 2, T.float32(0.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 1, T.float32(-2.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 0, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 4, T.float32(0.5), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 3, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 2, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 0, T.float32(-1.5), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 5, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(v_r_b % 6 == 5 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 0, T.float32(0), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 5, T.float32(1.5), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 4, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 3, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 2, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 0, T.float32(1), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 5, T.float32(-2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 4, T.float32(-0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 3, T.float32(2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 2, T.float32(2.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 1, T.float32(0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 0, T.float32(1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 5, T.float32(-1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 4, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 3, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 2, T.float32(0.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 1, T.float32(-2.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 0, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 4, T.float32(0.5), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 3, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 2, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 0, T.float32(-1.5), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 5, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + data_pack_local[v_eps, v_nu, v_p, v_ci] = data_pack_local[v_eps, v_nu, v_p, v_ci] + T.if_then_else(0 <= v_p % 9 // 3 * 4 + v_r_a and v_p % 9 // 3 * 4 + v_r_a < 14 and 0 <= v_p % 3 * 4 + v_r_b and v_p % 3 * 4 + v_r_b < 14, data[v_p // 9, v_p % 9 // 3 * 4 + v_r_a, v_p % 3 * 4 + v_r_b, v_ci], T.float32(0)) * T.Select(v_r_a % 6 == 5 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 0, T.float32(0), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 5, T.float32(1.5), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 4, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 3, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 2, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 0, T.float32(1), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 5, T.float32(-2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 4, T.float32(-0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 3, T.float32(2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 2, T.float32(2.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 1, T.float32(0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 0, T.float32(1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 5, T.float32(-1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 4, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 3, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 2, T.float32(0.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 1, T.float32(-2.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 0, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 4, T.float32(0.5), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 3, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 2, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 0, T.float32(-1.5), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 5, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(v_r_b % 6 == 5 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 0, T.float32(0), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 5, T.float32(1.5), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 4, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 3, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 2, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 0, T.float32(1), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 5, T.float32(-2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 4, T.float32(-0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 3, T.float32(2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 2, T.float32(2.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 1, T.float32(0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 0, T.float32(1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 5, T.float32(-1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 4, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 3, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 2, T.float32(0.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 1, T.float32(-2.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 0, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 4, T.float32(0.5), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 3, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 2, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 0, T.float32(-1.5), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 5, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): with T.block("data_pack_local"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) @@ -204,7 +193,6 @@ def cuda_nchw_0(data: T.Buffer((1, 64, 56, 56), "float32"), weight: T.Buffer((6, T.reads() T.writes() T.block_attr({"meta_schedule.unroll_explicit": 16}) - input_tile_local = T.alloc_buffer((64, 196, 6, 6), scope="local") data_pack = T.alloc_buffer((6, 6, 64, 196)) bgemm = T.alloc_buffer((6, 6, 64, 196)) inverse_local = T.alloc_buffer((64, 196, 4, 4), scope="local") @@ -214,16 +202,6 @@ def cuda_nchw_0(data: T.Buffer((1, 64, 56, 56), "float32"), weight: T.Buffer((6, weight_shared = T.alloc_buffer((6, 6, 64, 64), scope="shared") for ci_p_fused_0 in T.thread_binding(25, thread="blockIdx.x"): for ci_p_fused_1 in T.thread_binding(512, thread="threadIdx.x"): - for ax0, ax1, ax2, ax3 in T.grid(1, 1, 6, 6): - with T.block("input_tile"): - v_ci = T.axis.spatial(64, (ci_p_fused_0 * 512 + ci_p_fused_1) // 196 + ax0) - v_p = T.axis.spatial(196, (ci_p_fused_0 * 120 + ci_p_fused_1) % 196 + ax1) - v_eps, v_nu = T.axis.remap("SS", [ax2, ax3]) - T.where(ci_p_fused_0 * 512 + ci_p_fused_1 < 12544) - T.reads(data[v_p // 196, v_ci, v_p % 196 // 14 * 4 + v_eps - 1, v_p % 14 * 4 + v_nu - 1]) - T.writes(input_tile_local[v_ci, v_p, v_eps, v_nu]) - T.block_attr({"schedule_rule": "None"}) - input_tile_local[v_ci, v_p, v_eps, v_nu] = T.if_then_else(1 <= v_p % 196 // 14 * 4 + v_eps and v_p % 196 // 14 * 4 + v_eps < 57 and 1 <= v_p % 14 * 4 + v_nu and v_p % 14 * 4 + v_nu < 57, data[v_p // 196, v_ci, v_p % 196 // 14 * 4 + v_eps - 1, v_p % 14 * 4 + v_nu - 1], T.float32(0)) for eps in T.unroll(6): for nu in T.unroll(6): for r_a in T.unroll(6): @@ -234,12 +212,12 @@ def cuda_nchw_0(data: T.Buffer((1, 64, 56, 56), "float32"), weight: T.Buffer((6, v_p = T.axis.spatial(196, (ci_p_fused_0 * 512 + ci_p_fused_1) % 196) v_r_a, v_r_b = T.axis.remap("RR", [r_a, r_b]) T.where(ci_p_fused_0 * 512 + ci_p_fused_1 < 12544) - T.reads(input_tile_local[v_ci, v_p, v_r_a, v_r_b]) + T.reads(data[v_p // 196, v_ci, v_p % 196 // 14 * 4 + v_r_a - 1, v_p % 14 * 4 + v_r_b - 1]) T.writes(data_pack_local[v_eps, v_nu, v_ci, v_p]) T.block_attr({"schedule_rule": "conv2d_nchw_winograd_data_pack"}) with T.init(): data_pack_local[v_eps, v_nu, v_ci, v_p] = T.float32(0) - data_pack_local[v_eps, v_nu, v_ci, v_p] = data_pack_local[v_eps, v_nu, v_ci, v_p] + input_tile_local[v_ci, v_p, v_r_a, v_r_b] * T.Select(v_r_a % 6 == 5 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 0, T.float32(0), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 5, T.float32(1.5), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 4, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 3, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 2, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 0, T.float32(1), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 5, T.float32(-2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 4, T.float32(-0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 3, T.float32(2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 2, T.float32(2.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 1, T.float32(0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 0, T.float32(1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 5, T.float32(-1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 4, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 3, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 2, T.float32(0.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 1, T.float32(-2.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 0, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 4, T.float32(0.5), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 3, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 2, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 0, T.float32(-1.5), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 5, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(v_r_b % 6 == 5 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 0, T.float32(0), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 5, T.float32(1.5), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 4, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 3, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 2, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 0, T.float32(1), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 5, T.float32(-2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 4, T.float32(-0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 3, T.float32(2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 2, T.float32(2.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 1, T.float32(0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 0, T.float32(1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 5, T.float32(-1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 4, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 3, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 2, T.float32(0.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 1, T.float32(-2.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 0, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 4, T.float32(0.5), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 3, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 2, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 0, T.float32(-1.5), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 5, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + data_pack_local[v_eps, v_nu, v_ci, v_p] = data_pack_local[v_eps, v_nu, v_ci, v_p] + T.if_then_else(1 <= v_p % 196 // 14 * 4 + v_r_a and v_p % 196 // 14 * 4 + v_r_a < 57 and 1 <= v_p % 14 * 4 + v_r_b and v_p % 14 * 4 + v_r_b < 57, data[v_p // 196, v_ci, v_p % 196 // 14 * 4 + v_r_a - 1, v_p % 14 * 4 + v_r_b - 1], T.float32(0)) * T.Select(v_r_a % 6 == 5 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 5 and v_eps % 6 == 0, T.float32(0), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 5, T.float32(1.5), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 4, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 3, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 2, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 4 and v_eps % 6 == 0, T.float32(1), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 5, T.float32(-2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 4, T.float32(-0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 3, T.float32(2), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 2, T.float32(2.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 1, T.float32(0.5), T.Select(v_r_a % 6 == 3 and v_eps % 6 == 0, T.float32(1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 5, T.float32(-1.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 4, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 3, T.float32(-1), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 2, T.float32(0.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 1, T.float32(-2.5), T.Select(v_r_a % 6 == 2 and v_eps % 6 == 0, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 5, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 4, T.float32(0.5), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 3, T.float32(-2), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 2, T.float32(-1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 1, T.float32(1), T.Select(v_r_a % 6 == 1 and v_eps % 6 == 0, T.float32(-1.5), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 5, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 4, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 3, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 2, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 1, T.float32(0), T.Select(v_r_a % 6 == 0 and v_eps % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) * T.Select(v_r_b % 6 == 5 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 5 and v_nu % 6 == 0, T.float32(0), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 5, T.float32(1.5), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 4, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 3, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 2, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 4 and v_nu % 6 == 0, T.float32(1), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 5, T.float32(-2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 4, T.float32(-0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 3, T.float32(2), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 2, T.float32(2.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 1, T.float32(0.5), T.Select(v_r_b % 6 == 3 and v_nu % 6 == 0, T.float32(1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 5, T.float32(-1.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 4, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 3, T.float32(-1), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 2, T.float32(0.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 1, T.float32(-2.5), T.Select(v_r_b % 6 == 2 and v_nu % 6 == 0, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 5, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 4, T.float32(0.5), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 3, T.float32(-2), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 2, T.float32(-1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 1, T.float32(1), T.Select(v_r_b % 6 == 1 and v_nu % 6 == 0, T.float32(-1.5), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 5, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 4, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 3, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 2, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 1, T.float32(0), T.Select(v_r_b % 6 == 0 and v_nu % 6 == 0, T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) for ax0, ax1, ax2, ax3 in T.grid(6, 6, 1, 1): with T.block("data_pack_local"): v0, v1 = T.axis.remap("SS", [ax0, ax1]) @@ -442,7 +420,6 @@ def nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Bu T.reads() T.writes() T.block_attr({"meta_schedule.unroll_explicit": 1024}) - input_tile_local = T.alloc_buffer((2048, 1900, 4, 4), scope="local") data_pack = T.alloc_buffer((4, 4, 2048, 1900)) bgemm = T.alloc_buffer((4, 4, 2048, 1900)) inverse_local = T.alloc_buffer((2048, 1900, 2, 2), scope="local") @@ -453,16 +430,6 @@ def nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Bu for i2_i3_fused_1 in T.thread_binding(256, thread="blockIdx.x"): for i2_i3_fused_2 in T.thread_binding(1024, thread="threadIdx.x"): for i2_i3_fused_0 in range(15): - for ax0, ax1, ax2, ax3 in T.grid(1, 1, 4, 4): - with T.block("input_tile"): - ci = T.axis.spatial(2048, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) // 1900 + ax0) - p = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) % 1900 + ax1) - eps, nu = T.axis.remap("SS", [ax2, ax3]) - T.where(i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2 < 3891200) - T.reads(p0[p // 950, ci, p % 950 // 38 * 2 + eps - 1, p % 38 * 2 + nu - 1]) - T.writes(input_tile_local[ci, p, eps, nu]) - T.block_attr({"schedule_rule": "None"}) - input_tile_local[ci, p, eps, nu] = T.if_then_else(1 <= p % 950 // 38 * 2 + eps and p % 950 // 38 * 2 + eps < 51 and 1 <= p % 38 * 2 + nu and p % 38 * 2 + nu < 76, p0[p // 950, ci, p % 950 // 38 * 2 + eps - 1, p % 38 * 2 + nu - 1], T.float32(0)) for i0 in T.unroll(4): for i1 in T.unroll(4): for i4 in T.unroll(4): @@ -473,12 +440,12 @@ def nchw_add_relu_scheduled(p0: T.Buffer((2, 2048, 50, 75), "float32"), p1: T.Bu p = T.axis.spatial(1900, (i2_i3_fused_0 * 262144 + i2_i3_fused_1 * 1024 + i2_i3_fused_2) % 1900) r_a, r_b = T.axis.remap("RR", [i4, i5]) T.where((i2_i3_fused_0 * 256 + i2_i3_fused_1) * 1024 + i2_i3_fused_2 < 3891200) - T.reads(input_tile_local[ci, p, r_a, r_b]) + T.reads(p0[p // 950, ci, p % 950 // 38 * 2 + r_a - 1, p % 38 * 2 + r_b - 1]) T.writes(data_pack_local[eps, nu, ci, p]) T.block_attr({"schedule_rule": "conv2d_nchw_winograd_data_pack"}) with T.init(): data_pack_local[eps, nu, ci, p] = T.float32(0) - data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + input_tile_local[ci, p, r_a, r_b] * T.Select(r_a % 4 == 3 and eps % 4 == 3, T.float32(1), T.Select(r_a % 4 == 3 and eps % 4 == 2, T.float32(0), T.Select(r_a % 4 == 3 and eps % 4 == 1, T.float32(0), T.Select(r_a % 4 == 3 and eps % 4 == 0, T.float32(0), T.Select(r_a % 4 == 2 and eps % 4 == 3, T.float32(0), T.Select(r_a % 4 == 2 and eps % 4 == 2, T.float32(1), T.Select(r_a % 4 == 2 and eps % 4 == 1, T.float32(1), T.Select(r_a % 4 == 2 and eps % 4 == 0, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 3, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 2, T.float32(1), T.Select(r_a % 4 == 1 and eps % 4 == 1, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 0, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 3, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 2, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 1, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) * T.Select(r_b % 4 == 3 and nu % 4 == 3, T.float32(1), T.Select(r_b % 4 == 3 and nu % 4 == 2, T.float32(0), T.Select(r_b % 4 == 3 and nu % 4 == 1, T.float32(0), T.Select(r_b % 4 == 3 and nu % 4 == 0, T.float32(0), T.Select(r_b % 4 == 2 and nu % 4 == 3, T.float32(0), T.Select(r_b % 4 == 2 and nu % 4 == 2, T.float32(1), T.Select(r_b % 4 == 2 and nu % 4 == 1, T.float32(1), T.Select(r_b % 4 == 2 and nu % 4 == 0, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 3, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 2, T.float32(1), T.Select(r_b % 4 == 1 and nu % 4 == 1, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 0, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 3, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 2, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 1, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) + data_pack_local[eps, nu, ci, p] = data_pack_local[eps, nu, ci, p] + T.if_then_else(1 <= p % 950 // 38 * 2 + r_a and p % 950 // 38 * 2 + r_a < 51 and 1 <= p % 38 * 2 + r_b and p % 38 * 2 + r_b < 76, p0[p // 950, ci, p % 950 // 38 * 2 + r_a - 1, p % 38 * 2 + r_b - 1], T.float32(0)) * T.Select(r_a % 4 == 3 and eps % 4 == 3, T.float32(1), T.Select(r_a % 4 == 3 and eps % 4 == 2, T.float32(0), T.Select(r_a % 4 == 3 and eps % 4 == 1, T.float32(0), T.Select(r_a % 4 == 3 and eps % 4 == 0, T.float32(0), T.Select(r_a % 4 == 2 and eps % 4 == 3, T.float32(0), T.Select(r_a % 4 == 2 and eps % 4 == 2, T.float32(1), T.Select(r_a % 4 == 2 and eps % 4 == 1, T.float32(1), T.Select(r_a % 4 == 2 and eps % 4 == 0, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 3, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 2, T.float32(1), T.Select(r_a % 4 == 1 and eps % 4 == 1, T.float32(-1), T.Select(r_a % 4 == 1 and eps % 4 == 0, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 3, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 2, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 1, T.float32(0), T.Select(r_a % 4 == 0 and eps % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) * T.Select(r_b % 4 == 3 and nu % 4 == 3, T.float32(1), T.Select(r_b % 4 == 3 and nu % 4 == 2, T.float32(0), T.Select(r_b % 4 == 3 and nu % 4 == 1, T.float32(0), T.Select(r_b % 4 == 3 and nu % 4 == 0, T.float32(0), T.Select(r_b % 4 == 2 and nu % 4 == 3, T.float32(0), T.Select(r_b % 4 == 2 and nu % 4 == 2, T.float32(1), T.Select(r_b % 4 == 2 and nu % 4 == 1, T.float32(1), T.Select(r_b % 4 == 2 and nu % 4 == 0, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 3, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 2, T.float32(1), T.Select(r_b % 4 == 1 and nu % 4 == 1, T.float32(-1), T.Select(r_b % 4 == 1 and nu % 4 == 0, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 3, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 2, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 1, T.float32(0), T.Select(r_b % 4 == 0 and nu % 4 == 0, T.float32(1), T.float32(0))))))))))))))))) for ax0, ax1, ax2, ax3 in T.grid(4, 4, 1, 1): with T.block("data_pack_local"): v0, v1 = T.axis.remap("SS", [ax0, ax1])