Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class ScheduleNode : public runtime::Object {
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution */
virtual Optional<Trace> trace() const = 0;
/*! \return The GlobalVar of the func that the schedule is currently working on */
virtual Optional<GlobalVar> func_working_on() const = 0;
/*!
* \brief Instruct the schedule to work on a function in the IRModule.
*
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.ir import GlobalVar, IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc

Expand Down Expand Up @@ -207,6 +207,11 @@ def trace(self) -> Optional[Trace]:
"""Returns the internally maintained trace of scheduling program execution"""
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member

@property
def func_working_on(self) -> Optional[GlobalVar]:
"""Returns the GlobalVar of the func that the schedule is currently working on"""
return _ffi_api.ScheduleGetFuncWorkingOn(self) # type: ignore # pylint: disable=no-member

def work_on(self, func_name: str) -> None:
"""Instruct the schedule to work on a function in the IRModule.

Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode {
public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
Optional<GlobalVar> func_working_on() const final { return func_working_on_; }
void WorkOn(const String& func_name) final;
Schedule Copy() override;
void Seed(support::LinearCongruentialEngine::TRandState seed) final;
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") //
.set_body_method<Schedule>(&ScheduleNode::state);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") //
.set_body_method<Schedule>(&ScheduleNode::trace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") //
.set_body_method<Schedule>(&ScheduleNode::func_working_on);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") //
.set_body_method<Schedule>(&ScheduleNode::Copy);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") //
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_tir_schedule_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def test_tir_schedule_work_on():
sch.get_block(name="init")
sch.work_on(func_name="vector_add")
sch.get_block(name="init")
assert sch.func_working_on == sch.mod.get_global_var("vector_add")


def test_tir_schedule_get_loops(use_block_name):
Expand Down