From 2cfe81d595c8ce7144be4a793824565317490658 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 2 Aug 2023 12:57:29 -0700 Subject: [PATCH 1/2] [TIR, Schedule] Fix decompose reduction with thread binding loops --- src/tir/schedule/primitive/reduction.cc | 8 ++++- .../unittest/test_tir_schedule_reduction.py | 33 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 6069f4289cf3..aec6d5b8cb7d 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -271,11 +271,17 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, Var old_loop_var = old_loop->loop_var; Var new_loop_var = old_loop_var.copy_with_suffix("_init"); loop_var_map[old_loop_var] = new_loop_var; + Optional opt_thread_binding = old_loop->thread_binding; + if (opt_thread_binding) { + auto thread_binding = opt_thread_binding.value(); + thread_binding.CopyOnWrite()->var = new_loop_var; + } body = For(/*loop_var=*/new_loop_var, /*min=*/old_loop->min, /*extent=*/old_loop->extent, /*kind=*/old_loop->kind, - /*body=*/body); + /*body=*/body, + /*thread_binding*/opt_thread_binding); } body = Substitute(body, loop_var_map); // Step 6. Mutate IR diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index b24ecee3762a..3f54c7f9ed5b 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -350,5 +350,38 @@ def decomposed_nested_block(A: T.Buffer((1, 64), "float32"), B: T.Buffer((1,), " verify_trace_roundtrip(sch, mod=nested_block) +class TestDecomposeReductionWithThreadBinding(tvm.testing.CompareBeforeAfter): + def transform(self): + def func(mod): + sch = tir.Schedule(mod) + t, _ = sch.get_loops("B") + sch.decompose_reduction("B", t) + return sch.mod + + return func + + @T.prim_func + def before(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")): + for t in T.thread_binding(0, 32, thread="threadIdx.x"): + for r in T.serial(16): + with T.block("B"): + vi, vr = T.axis.remap("SR", [t, r]) + with T.init(): + B[vi] = T.float32(0) + B[vi] += A[vi, vr] + + @T.prim_func + def expected(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")): + for t_init in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B_init"): + vi = T.axis.remap("S", [t_init]) + B[vi] = T.float32(0) + for t in T.thread_binding(0, 32, thread="threadIdx.x"): + for r in T.serial(16): + with T.block("B"): + vi, vr = T.axis.remap("SR", [t, r]) + B[vi] += A[vi, vr] + + if __name__ == "__main__": tvm.testing.main() From 6ae6096e3bab01ac4148a565c0e00d5a9399cb9a Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 2 Aug 2023 13:14:29 -0700 Subject: [PATCH 2/2] fix --- src/tir/schedule/primitive/reduction.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index aec6d5b8cb7d..4fea3ccaa736 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -274,14 +274,16 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, Optional opt_thread_binding = old_loop->thread_binding; if (opt_thread_binding) { auto thread_binding = opt_thread_binding.value(); - thread_binding.CopyOnWrite()->var = new_loop_var; + auto new_var = thread_binding->var.copy_with_suffix(""); + thread_binding.CopyOnWrite()->var = new_var; + opt_thread_binding = thread_binding; } body = For(/*loop_var=*/new_loop_var, /*min=*/old_loop->min, /*extent=*/old_loop->extent, /*kind=*/old_loop->kind, /*body=*/body, - /*thread_binding*/opt_thread_binding); + /*thread_binding=*/opt_thread_binding); } body = Substitute(body, loop_var_map); // Step 6. Mutate IR