diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index c294d0ae8762..69f05201177b 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -115,6 +115,8 @@ class ScheduleNode : public runtime::Object { virtual ScheduleState state() const = 0; /*! \return The internally maintained trace of scheduling program execution */ virtual Optional trace() const = 0; + /*! \return The GlobalVar of the func that the schedule is currently working on */ + virtual Optional func_working_on() const = 0; /*! * \brief Instruct the schedule to work on a function in the IRModule. * diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b19e30848f69..34fd649a5d13 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -19,7 +19,7 @@ from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error -from tvm.ir import IRModule, PrimExpr +from tvm.ir import GlobalVar, IRModule, PrimExpr from tvm.runtime import Object, String from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc @@ -207,6 +207,11 @@ def trace(self) -> Optional[Trace]: """Returns the internally maintained trace of scheduling program execution""" return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member + @property + def func_working_on(self) -> Optional[GlobalVar]: + """Returns the GlobalVar of the func that the schedule is currently working on""" + return _ffi_api.ScheduleGetFuncWorkingOn(self) # type: ignore # pylint: disable=no-member + def work_on(self, func_name: str) -> None: """Instruct the schedule to work on a function in the IRModule. diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index eb7c38753c94..16065df3cd93 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -64,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } + Optional func_working_on() const final { return func_working_on_; } void WorkOn(const String& func_name) final; Schedule Copy() override; void Seed(support::LinearCongruentialEngine::TRandState seed) final; diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index fdf473ff7972..6c5998f52b4d 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -50,6 +50,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // .set_body_method(&ScheduleNode::state); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // .set_body_method(&ScheduleNode::trace); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") // + .set_body_method(&ScheduleNode::func_working_on); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // .set_body_method(&ScheduleNode::Copy); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index ba2c134def7c..a8be97488b25 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -193,6 +193,7 @@ def test_tir_schedule_work_on(): sch.get_block(name="init") sch.work_on(func_name="vector_add") sch.get_block(name="init") + assert sch.func_working_on == sch.mod.get_global_var("vector_add") def test_tir_schedule_get_loops(use_block_name):