diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc
index 77afcf8266f8..343fb7617886 100644
--- a/src/tir/schedule/concrete_schedule.cc
+++ b/src/tir/schedule/concrete_schedule.cc
@@ -555,9 +555,7 @@ void ConcreteScheduleNode::Bind(const LoopRV& loop_rv, const String& thread_axis
                     "`vthread.x`, `vthread.y` and `vthread.z` instead";
   }
   TVM_TIR_SCHEDULE_BEGIN();
-  tir::Bind(state_, this->GetSRef(loop_rv),
-            IterVar(/*dom=*/Range(nullptr), /*var=*/Var(thread_axis), /*iter_type=*/kThreadIndex,
-                    /*thread_tag=*/thread_axis));
+  tir::Bind(state_, this->GetSRef(loop_rv), thread_axis);
   this->state_->DebugVerify();
   TVM_TIR_SCHEDULE_END("bind", this->error_render_level_);
 }
diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h
index fe6280e1c4b1..02fb982f5ed9 100644
--- a/src/tir/schedule/primitive.h
+++ b/src/tir/schedule/primitive.h
@@ -303,7 +303,7 @@ TVM_DLL void Vectorize(ScheduleState self, const StmtSRef& loop_sref);
  * \param loop_sref The sref of the loop to be bound to the thread axis
  * \param thread_axis The thread axis to be bound to the loop
  */
-TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis);
+TVM_DLL void Bind(ScheduleState self, const StmtSRef& loop_sref, const String& thread_axis);
 /*!
  * \brief Unroll the input loop. It requires nothing
  * \param self The state of the schedule
diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc
index 02d8866e8e9d..9690cd78c82f 100644
--- a/src/tir/schedule/primitive/for_kind.cc
+++ b/src/tir/schedule/primitive/for_kind.cc
@@ -144,7 +144,7 @@ void CheckParallelizability(const ScheduleState& self, const For& loop, ForKind
  * `for_kind` is `kThreadBinding`
  */
 void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref, ForKind for_kind,
-                            Optional<IterVar> thread_axis) {
+                            Optional<String> thread_axis) {
   const ForNode* loop = TVM_SREF_TO_FOR(loop_sref);
 
   /*
@@ -164,14 +164,21 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref
   // Step 2. Check whether the loop can be parallelized/vectorized/bound with regard to each
   // underlying block.
   CheckParallelizability(self, GetRef<For>(loop), for_kind,
-                         thread_axis.defined()
-                             ? runtime::ThreadScope::Create(thread_axis.value()->thread_tag)
-                             : runtime::ThreadScope{-1, -1});
+                         thread_axis.defined() ? runtime::ThreadScope::Create(thread_axis.value())
+                                               : runtime::ThreadScope{-1, -1});
 
   // Step 3. Loop update and IR replacement
   ObjectPtr<ForNode> new_loop = make_object<ForNode>(*loop);
   new_loop->kind = for_kind;
-  new_loop->thread_binding = std::move(thread_axis);
+  if (thread_axis.defined()) {
+    const String& thread_tag = thread_axis.value();
+    new_loop->thread_binding = IterVar(/*dom=*/Range(nullptr),                           //
+                                       /*var=*/Var(thread_tag, loop->loop_var.dtype()),  //
+                                       /*iter_type=*/kThreadIndex,                       //
+                                       /*thread_tag=*/thread_tag);
+  } else {
+    new_loop->thread_binding = NullOpt;
+  }
   self->Replace(loop_sref, For(new_loop), {});
 }
 
@@ -183,7 +190,7 @@ void Vectorize(ScheduleState self, const StmtSRef& loop_sref) {
   ParallelizeComputation(self, loop_sref, ForKind::kVectorized, NullOpt);
 }
 
-void Bind(ScheduleState self, const StmtSRef& loop_sref, const IterVar& thread_axis) {
+void Bind(ScheduleState self, const StmtSRef& loop_sref, const String& thread_axis) {
   ParallelizeComputation(self, loop_sref, ForKind::kThreadBinding, thread_axis);
 }
 
diff --git a/tests/python/unittest/test_tir_schedule_for_kind.py b/tests/python/unittest/test_tir_schedule_for_kind.py
index fb0939f99086..7ae406445530 100644
--- a/tests/python/unittest/test_tir_schedule_for_kind.py
+++ b/tests/python/unittest/test_tir_schedule_for_kind.py
@@ -668,5 +668,34 @@ def test_scatter_parallelize():
     verify_trace_roundtrip(s, mod=scatter_compute)
 
 
+def test_bind_thread_iter_var_dtype():
+    @T.prim_func(private=True)
+    def before(
+        A: T.Buffer((T.int64(128), T.int64(128))),
+        B: T.Buffer((T.int64(128), T.int64(128))),
+    ) -> None:
+        for i, j in T.grid(T.int64(128), T.int64(128)):
+            with T.block("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+
+    @T.prim_func(private=True)
+    def expected(
+        A: T.Buffer((T.int64(128), T.int64(128))),
+        B: T.Buffer((T.int64(128), T.int64(128))),
+    ) -> None:
+        for i0 in T.thread_binding(T.int64(128), thread="threadIdx.x"):
+            for i1 in range(T.int64(128)):
+                with T.block("B"):
+                    vi, vj = T.axis.remap("SS", [i0, i1])
+                    B[vi, vj] = A[vi, vj] * 2.0
+
+    s = tir.Schedule(before, debug_mask="all")
+    i, _ = s.get_loops(s.get_block("B"))
+    s.bind(i, "threadIdx.x")
+    assert_structural_equal_ignore_global_symbol(s.mod["main"], expected)
+    verify_trace_roundtrip(s, mod=before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
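Context note (not part of the patch): a minimal standalone sketch of the behavior the change is meant to guarantee, using the same public tvm.script and tir.Schedule APIs as the test above. The prim_func name `elemwise` and the final assertion are illustrative assumptions; the point is that after `bind`, the thread-binding IterVar created in for_kind.cc inherits the dtype of the int64 loop variable instead of defaulting to int32.

# Sketch only: checks the dtype of the thread-binding IterVar after s.bind().
import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def elemwise(
    A: T.Buffer((T.int64(128),), "float32"),
    B: T.Buffer((T.int64(128),), "float32"),
) -> None:
    for i in range(T.int64(128)):
        with T.block("B"):
            vi = T.axis.spatial(T.int64(128), i)
            B[vi] = A[vi] * 2.0


s = tir.Schedule(elemwise)
(i,) = s.get_loops(s.get_block("B"))
s.bind(i, "threadIdx.x")
bound_loop = s.get(i)  # the For node after binding
# With the fix, the bound IterVar's Var follows the loop var's dtype (int64).
assert bound_loop.thread_binding.var.dtype == "int64"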