diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index 6069f4289cf3..4fea3ccaa736 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -271,11 +271,19 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, Var old_loop_var = old_loop->loop_var; Var new_loop_var = old_loop_var.copy_with_suffix("_init"); loop_var_map[old_loop_var] = new_loop_var; + Optional opt_thread_binding = old_loop->thread_binding; + if (opt_thread_binding) { + auto thread_binding = opt_thread_binding.value(); + auto new_var = thread_binding->var.copy_with_suffix(""); + thread_binding.CopyOnWrite()->var = new_var; + opt_thread_binding = thread_binding; + } body = For(/*loop_var=*/new_loop_var, /*min=*/old_loop->min, /*extent=*/old_loop->extent, /*kind=*/old_loop->kind, - /*body=*/body); + /*body=*/body, + /*thread_binding=*/opt_thread_binding); } body = Substitute(body, loop_var_map); // Step 6. Mutate IR diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py index b24ecee3762a..3f54c7f9ed5b 100644 --- a/tests/python/unittest/test_tir_schedule_reduction.py +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -350,5 +350,38 @@ def decomposed_nested_block(A: T.Buffer((1, 64), "float32"), B: T.Buffer((1,), " verify_trace_roundtrip(sch, mod=nested_block) +class TestDecomposeReductionWithThreadBinding(tvm.testing.CompareBeforeAfter): + def transform(self): + def func(mod): + sch = tir.Schedule(mod) + t, _ = sch.get_loops("B") + sch.decompose_reduction("B", t) + return sch.mod + + return func + + @T.prim_func + def before(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")): + for t in T.thread_binding(0, 32, thread="threadIdx.x"): + for r in T.serial(16): + with T.block("B"): + vi, vr = T.axis.remap("SR", [t, r]) + with T.init(): + B[vi] = T.float32(0) + B[vi] += A[vi, vr] + + @T.prim_func + def expected(A: T.Buffer((32, 16), "float32"), B: T.Buffer((32,), "float32")): + for t_init in T.thread_binding(0, 32, thread="threadIdx.x"): + with T.block("B_init"): + vi = T.axis.remap("S", [t_init]) + B[vi] = T.float32(0) + for t in T.thread_binding(0, 32, thread="threadIdx.x"): + for r in T.serial(16): + with T.block("B"): + vi, vr = T.axis.remap("SR", [t, r]) + B[vi] += A[vi, vr] + + if __name__ == "__main__": tvm.testing.main()