diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 570560c62d8c..e7b7e1f45340 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -115,6 +115,8 @@ class ScheduleNode : public runtime::Object { virtual ScheduleState state() const = 0; /*! \return The internally maintained trace of scheduling program execution */ virtual Optional trace() const = 0; + /*! \return The GlobalVar of the func that the schedule is currently working on */ + virtual Optional func_working_on() const = 0; /*! * \brief Instruct the schedule to work on a function in the IRModule. * diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 68f0b9454cb1..7221fa48b0b9 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -19,7 +19,7 @@ from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error -from tvm.ir import IRModule, PrimExpr +from tvm.ir import GlobalVar, IRModule, PrimExpr from tvm.runtime import Object, String from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc @@ -207,6 +207,11 @@ def trace(self) -> Optional[Trace]: """Returns the internally maintained trace of scheduling program execution""" return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member + @property + def func_working_on(self) -> Optional[GlobalVar]: + """Returns the GlobalVar of the func that the schedule is currently working on""" + return _ffi_api.ScheduleGetFuncWorkingOn(self) # type: ignore # pylint: disable=no-member + def work_on(self, func_name: str) -> None: """Instruct the schedule to work on a function in the IRModule. diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 955381b740c8..753974571a17 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -566,16 +566,26 @@ class BlockCollector : public tir::StmtVisitor { /*! \brief Entry point */ Array Run() { std::vector results; - for (const auto& [gv, base_func] : sch_->mod()->functions) { - // `gv->name_hint` is the name of the function - // `base_func` can be PrimFunc or relay::Function - if (const auto* func = base_func.as()) { - func_name_ = gv->name_hint; - block_names_.clear(); - blocks_to_collect_.clear(); - VisitStmt(func->body); - for (const String& name : blocks_to_collect_) { - results.push_back(sch_->GetBlock(name, func_name_)); + auto f_collect = [this, &results](tir::PrimFunc func, String func_name) { + func_name_ = func_name; + block_names_.clear(); + blocks_to_collect_.clear(); + VisitStmt(func->body); + for (const String& name : blocks_to_collect_) { + results.push_back(sch_->GetBlock(name, func_name_)); + } + }; + + if (sch_->func_working_on().defined()) { + GlobalVar gv = sch_->func_working_on().value(); + tir::PrimFunc func = Downcast(sch_->mod()->functions[gv]); + f_collect(func, gv->name_hint); + } else { + for (const auto& [gv, base_func] : sch_->mod()->functions) { + // `gv->name_hint` is the name of the function + // `base_func` can be PrimFunc or relay::Function + if (const auto* func = base_func.as()) { + f_collect(GetRef(func), gv->name_hint); } } } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 227288b232d9..d68683c45fd8 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -64,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } + Optional func_working_on() const final { return func_working_on_; } void WorkOn(const String& func_name) final; Schedule Copy() override; void Seed(support::LinearCongruentialEngine::TRandState seed) final; diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index a0e39b74d31b..ce28c39a81f1 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -50,6 +50,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // .set_body_method(&ScheduleNode::state); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // .set_body_method(&ScheduleNode::trace); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") // + .set_body_method(&ScheduleNode::func_working_on); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // .set_body_method(&ScheduleNode::Copy); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 53ee6a58cd9a..0ce2f0ea914d 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -193,6 +193,7 @@ def test_tir_schedule_work_on(): sch.get_block(name="init") sch.work_on(func_name="vector_add") sch.get_block(name="init") + assert sch.func_working_on == sch.mod.get_global_var("vector_add") def test_tir_schedule_get_loops(use_block_name):