From fba40a3a1a5d525a13aef63db5268004a54b6f45 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 2 Jul 2022 17:35:34 -0700 Subject: [PATCH 1/2] [TIR] Add sugar method `Schedule.work_on` This PR introduces `Schedule.work_on`, which instructs `Schedule.get_block` to find the correct PrimFunc to retrieve from without having to specify `func_name` in every time if the PrimFunc's name is not `main`. --- include/tvm/tir/schedule/schedule.h | 24 ++++++++++- python/tvm/tir/schedule/schedule.py | 25 ++++++++++- src/meta_schedule/arg_info.cc | 41 ++++++++++++++++++ src/meta_schedule/mutator/mutate_parallel.cc | 3 +- src/meta_schedule/utils.h | 42 ------------------- src/tir/schedule/analysis.h | 9 ++++ src/tir/schedule/analysis/analysis.cc | 41 ++++++++++++++++++ src/tir/schedule/concrete_schedule.cc | 25 ++++++++++- src/tir/schedule/concrete_schedule.h | 8 +++- src/tir/schedule/primitive.h | 4 +- src/tir/schedule/primitive/get_block_loop.cc | 4 +- src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 21 +++++++++- src/tir/schedule/traced_schedule.h | 2 +- .../unittest/test_tir_schedule_utilities.py | 32 +++++++++++++- 15 files changed, 225 insertions(+), 58 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index d95a9d4e7e5e..8e160c61328c 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -115,6 +115,21 @@ class ScheduleNode : public runtime::Object { virtual ScheduleState state() const = 0; /*! \return The internally maintained trace of scheduling program execution */ virtual Optional trace() const = 0; + /*! + * \brief Instruct the schedule to work on a function in the IRModule. + * + * By default, the schedule works on the function with the name "main", or the only function in + * the IRModule if there is only one. If there is multiple functions in the IRModule, and none of + * their names are "main", users will have to call this method to explicitly specify which + * function to work on. + * + * This sugar function will guide the `GetBlock` method if its `func_name` is not specified. + * + * \param func_name The name of the function to be working on + * + * \sa GetBlock + */ + virtual void WorkOn(const String& func_name) = 0; /*! * \brief Returns a copy of the schedule, including both its state and its symbol table, * guaranteeing that @@ -231,12 +246,19 @@ class ScheduleNode : public runtime::Object { /******** Schedule: Get blocks & loops ********/ /*! * \brief Retrieve a block in a specific function with its name + * + * By default, if `func_name` is not specified, the schedule will search for the block in the + * function that is currently being "worked on". To switch the function to be worked on, use + * `WorkOn` before calling this method. + * * \param name The name of the block to be retrieved * \param func_name The name of the function * \return The block retrieved * \note Indexing error is raised if 0 or multiple blocks exist with the specific name + * + * \sa WorkOn */ - virtual BlockRV GetBlock(const String& name, const String& func_name = "main") = 0; + virtual BlockRV GetBlock(const String& name, const Optional& func_name = NullOpt) = 0; /*! * \brief Get the parent loops of the block in its scope, from outer to inner * \param block_rv The query block diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 7a1e244604b7..b3fff7622df9 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -186,6 +186,23 @@ def trace(self) -> Optional[Trace]: """Returns the internally maintained trace of scheduling program execution""" return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member + def work_on(self, func_name: str) -> None: + """Instruct the schedule to work on a function in the IRModule. + + By default, the schedule works on the function with the name "main", or the only function in + the IRModule if there is only one. If there is multiple functions in the IRModule, and none + of their names are "main", users will have to call this method to explicitly specify which + function to work on. + + This sugar function will guide the `GetBlock` method if its `func_name` is not specified. + + Parameters + ---------- + func_name : str + The name of the function to work on. + """ + _ffi_api.ScheduleWorkOn(self, func_name) + def copy(self) -> "Schedule": """Returns a copy of the schedule, including both the state and the symbol table, * guaranteeing that @@ -403,15 +420,19 @@ def sample_compute_location( def get_block( self, name: str, - func_name: str = "main", + func_name: Optional[str] = None, ) -> BlockRV: """Retrieve a block in a specific function with its name + By default, if `func_name` is not specified, the schedule will search for the block in the + function that is currently being "worked on". To switch the function to be worked on, use + `work_on` before calling this method. + Parameters ---------- name : str The name of the block - func_name : str = "main" + func_name : Optional[str] = None The name of the function Returns diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index 672df86deb9d..21de9d719d00 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -21,6 +21,47 @@ namespace tvm { namespace meta_schedule { +/*! + * \brief Find the entry function of the given IRModule, i.e, functions marked by + * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. + * \param mod The IRModule to find the entry function. + * \return The entry function. + */ +inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { + // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` + int num_prim_func = 0; + const tir::PrimFuncNode* main_func = nullptr; + const tir::PrimFuncNode* last_func = nullptr; + for (const auto& kv : mod->functions) { + GlobalVar gv = kv.first; + BaseFunc base_func = kv.second; + if (const auto* func = base_func.as()) { + last_func = func; + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + return GetRef(func); + } + if (gv->name_hint == "main") { + main_func = func; + } + ++num_prim_func; + } + } + // Priority 2: PrimFunc whose name is `main` + if (main_func != nullptr) { + return GetRef(main_func); + } + // Priority 3: The only PrimFunc in the IRModule + if (num_prim_func == 0) { + LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " + << tir::AsTVMScript(mod); + } + if (num_prim_func > 1) { + LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " + "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" + << tir::AsTVMScript(mod); + } + return GetRef(last_func); +} /******** ArgInfo ********/ ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) { diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index 7c973879f2cc..5b7fe7f5148d 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -79,7 +79,8 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) { std::vector> AnalyzeParallel(const ScheduleState& self, const String& block_name, const String& func_name, int64_t limit) { - Array block_srefs = tir::GetBlocks(self, block_name, func_name); + Array block_srefs = + tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name)); ICHECK_EQ(block_srefs.size(), 1); const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]); ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef(block)); diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index ca696da71e00..b5cb73c26e00 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -174,48 +174,6 @@ inline String SHash2Hex(const ObjectRef& obj) { return os.str(); } -/*! - * \brief Find the entry function of the given IRModule, i.e, functions marked by - * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. - * \param mod The IRModule to find the entry function. - * \return The entry function. - */ -inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { - // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` - int num_prim_func = 0; - const tir::PrimFuncNode* main_func = nullptr; - const tir::PrimFuncNode* last_func = nullptr; - for (const auto& kv : mod->functions) { - GlobalVar gv = kv.first; - BaseFunc base_func = kv.second; - if (const auto* func = base_func.as()) { - last_func = func; - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - return GetRef(func); - } - if (gv->name_hint == "main") { - main_func = func; - } - ++num_prim_func; - } - } - // Priority 2: PrimFunc whose name is `main` - if (main_func != nullptr) { - return GetRef(main_func); - } - // Priority 3: The only PrimFunc in the IRModule - if (num_prim_func == 0) { - LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " - << tir::AsTVMScript(mod); - } - if (num_prim_func > 1) { - LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " - "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`" - << tir::AsTVMScript(mod); - } - return GetRef(last_func); -} - /*! * \brief Fork a random state into another, i.e. PRNG splitting. * The given random state is also mutated. diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index b30cef829f1e..317b3625f0b6 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -71,6 +71,15 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl */ StmtSRef GetSRefTreeRoot(const StmtSRef& sref); +/*! + * \brief Find the entry function of the given IRModule, i.e, functions marked by + * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. + * \param mod The IRModule to find the entry function. + * \param result_g_var The result GlobalVar of the entry function. + * \return The entry function. + */ +const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var); + /******** Scope ********/ /*! * \brief Checks if scope the specified sref is in is a stage-pipeline and return it diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 3ee1ed28b857..ac73ac3ce2c1 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -49,6 +49,47 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl throw; } +const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) { + GlobalVar result = NullValue(); + // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` + int num_prim_func = 0; + const tir::PrimFuncNode* main_func = nullptr; + const tir::PrimFuncNode* last_func = nullptr; + for (const auto& kv : mod->functions) { + GlobalVar gv = kv.first; + BaseFunc base_func = kv.second; + if (const auto* func = base_func.as()) { + last_func = func; + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + if (result_g_var != nullptr) { + *result_g_var = gv; + } + return func; + } + if (gv->name_hint == "main") { + main_func = func; + result = gv; + } + ++num_prim_func; + } + } + // Priority 2: PrimFunc whose name is `main` + if (main_func != nullptr) { + if (result_g_var != nullptr) { + *result_g_var = result; + } + return main_func; + } + // Priority 3: The only PrimFunc in the IRModule + if (num_prim_func == 1) { + if (result_g_var != nullptr) { + *result_g_var = result; + } + return last_func; + } + return nullptr; +} + /******** Scope ********/ StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index b2f48753b555..c19735025ddc 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -31,6 +31,12 @@ Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRa n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->Seed(seed); + GlobalVar gv = NullValue(); + if (FindEntryFunc(mod, &gv) != nullptr) { + n->func_working_on_ = gv; + } else { + n->func_working_on_ = NullOpt; + } return Schedule(std::move(n)); } @@ -177,6 +183,10 @@ class ScheduleCopier { std::unordered_map old2new_; }; +void ConcreteScheduleNode::WorkOn(const String& func_name) { + this->func_working_on_ = this->state_->mod->GetGlobalVar(func_name); +} + void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const { ScheduleCopier::Copy(this, new_state, new_symbol_table); new_state->get()->DebugVerify(); @@ -184,6 +194,7 @@ void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symb Schedule ConcreteScheduleNode::Copy() { ObjectPtr n = make_object(); + n->func_working_on_ = this->func_working_on_; n->error_render_level_ = this->error_render_level_; ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful @@ -251,7 +262,7 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv, /******** Schedule: Get blocks & loops ********/ -BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { +BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional& func_name) { class NotSingleResult : public ScheduleError { public: explicit NotSingleResult(String name, IRModule mod, const Array& blocks) @@ -286,7 +297,17 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_na IRModule mod_; Array blocks_; }; - Array blocks = tir::GetBlocks(this->state_, name, func_name); + GlobalVar gv = NullValue(); + if (func_name.defined()) { + gv = state_->mod->GetGlobalVar(func_name.value()); + } else if (func_working_on_.defined()) { + gv = this->func_working_on_.value(); + } else { + LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please " + "specify the function name explicitly, or call `work_on` to specify the function " + "before using `get_block`."; + } + Array blocks = tir::GetBlocks(this->state_, name, gv); if (blocks.size() != 1) { TVM_TIR_SCHEDULE_BEGIN(); throw NotSingleResult(name, this->state_->mod, blocks); diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index dfbacb530a36..feea310bd7af 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -38,6 +38,8 @@ class ConcreteScheduleNode : public ScheduleNode { protected: /*! \brief The internal state of scheduling */ ScheduleState state_; + /*! \brief The function to be worked on. */ + Optional func_working_on_; /*! \brief The level of error rendering */ ScheduleErrorRenderLevel error_render_level_; /*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */ @@ -50,10 +52,11 @@ class ConcreteScheduleNode : public ScheduleNode { public: void VisitAttrs(tvm::AttrVisitor* v) { // `state_` is not visited + // `func_working_on_` is not visited // `error_render_level_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visited - // `rand_state_` is not visited + // `rgnd_state_` is not visited } virtual ~ConcreteScheduleNode() = default; @@ -61,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } + void WorkOn(const String& func_name) final; Schedule Copy() override; void Seed(support::LinearCongruentialEngine::TRandState seed) final; support::LinearCongruentialEngine::TRandState ForkSeed() final; @@ -89,7 +93,7 @@ class ConcreteScheduleNode : public ScheduleNode { LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ - BlockRV GetBlock(const String& name, const String& func_name = "main") override; + BlockRV GetBlock(const String& name, const Optional& func_name) override; Array GetLoops(const BlockRV& block_rv) override; Array GetChildBlocks(const BlockRV& block_rv) override; Array GetChildBlocks(const LoopRV& loop_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 212571df1027..608368fbb31f 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -116,10 +116,10 @@ TVM_DLL tir::StmtSRef SampleComputeLocation( * \brief Retrieves blocks in a specific function with its name * \param self The schedule state * \param name The name of the blocks to be retrieved - * \param func_name The name of the function + * \param gvar The function to be retrieved * \return A list of blocks with the specific name */ -Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name); +Array GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv); /*! * \brief Gets the parent loops of the block in its scope, from outer to inner * \param self The schedule state diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index a13e52515708..746918ac4e34 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -21,7 +21,7 @@ namespace tvm { namespace tir { -Array GetBlocks(const ScheduleState& self, const String& name, const String& func_name) { +Array GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv) { struct Finder : public StmtVisitor { explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {} @@ -39,7 +39,7 @@ Array GetBlocks(const ScheduleState& self, const String& name, const S Array results_; }; - BaseFunc func = self->mod->Lookup(func_name); + BaseFunc func = self->mod->Lookup(gv); const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode); Finder finder(self, name); finder(prim_func->body); diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 372d94a15025..e386061ebfbd 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -56,6 +56,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // .set_body_method(&ScheduleNode::Seed); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // .set_body_method(&ScheduleNode::ForkSeed); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn") // + .set_body_method(&ScheduleNode::WorkOn); /**************** (FFI) Constructor ****************/ diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 733b5d872f93..93e4c984a41b 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -30,6 +30,12 @@ Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRand n->analyzer_ = std::make_unique(); n->trace_ = Trace(); n->Seed(seed); + GlobalVar gv = NullValue(); + if (FindEntryFunc(mod, &gv) != nullptr) { + n->func_working_on_ = gv; + } else { + n->func_working_on_ = NullOpt; + } return Schedule(std::move(n)); } @@ -37,6 +43,7 @@ Schedule TracedScheduleNode::Copy() { ObjectPtr n = make_object(); n->error_render_level_ = this->error_render_level_; ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_); + n->func_working_on_ = this->func_working_on_; n->analyzer_ = std::make_unique(); // new analyzer needed because it is stateful n->rand_state_ = ForkSeed(); n->trace_ = Trace(this->trace_->insts, this->trace_->decisions); @@ -90,13 +97,23 @@ LoopRV TracedScheduleNode::SampleComputeLocation(const BlockRV& block_rv, /******** Schedule: Get blocks & loops ********/ -BlockRV TracedScheduleNode::GetBlock(const String& name, const String& func_name) { +BlockRV TracedScheduleNode::GetBlock(const String& name, const Optional& func_name) { + GlobalVar gv = NullValue(); + if (func_name.defined()) { + gv = state_->mod->GetGlobalVar(func_name.value()); + } else if (func_working_on_.defined()) { + gv = this->func_working_on_.value(); + } else { + LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please " + "specify the function name explicitly, or call `work_on` to specify the function " + "before using `get_block`."; + } BlockRV result = ConcreteScheduleNode::GetBlock(name, func_name); static const InstructionKind& kind = InstructionKind::Get("GetBlock"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{}, - /*attrs=*/{name, func_name}, + /*attrs=*/{name, gv->name_hint}, /*outputs=*/{result})); return result; } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 178026d9eaf8..f6405d77a195 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -53,7 +53,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Optional> decision = NullOpt) final; LoopRV SampleComputeLocation(const BlockRV& block_rv, Optional decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ - BlockRV GetBlock(const String& name, const String& func_name = "main") final; + BlockRV GetBlock(const String& name, const Optional& func_name) final; Array GetLoops(const BlockRV& block_rv) final; Array GetChildBlocks(const BlockRV& block_rv) final; Array GetChildBlocks(const LoopRV& loop_rv) final; diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index b7517aab7cd3..c479555590d2 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -20,7 +20,6 @@ import pytest import tvm import tvm.testing - from tvm import tir from tvm.ir import IRModule from tvm.script import tir as T @@ -102,6 +101,29 @@ def matmul_relu_ann2(a: T.handle, b: T.handle, d: T.handle) -> None: D[vi, vj] = T.max(C[vi, vj], 0.0) +@tvm.script.ir_module +class ModuleWithMultipleFuncs: + @T.prim_func + def vector_add( + A: T.Buffer[128, "float32"], + B: T.Buffer[128, "float32"], + ) -> None: + for i in range(128): + with T.block("init"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + @T.prim_func + def vector_add_2( + A: T.Buffer[128, "float32"], + B: T.Buffer[128, "float32"], + ) -> None: + for i in range(128): + with T.block("init"): + vi = T.axis.remap("S", [i]) + B[vi] = A[vi] + + # pylint: enable=no-member,invalid-name,unused-variable use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) @@ -133,6 +155,14 @@ def test_tir_schedule_get_block(): assert block.same_as(matmul.body.block.body.body.body[1].body.block) +def test_tir_schedule_work_on(): + sch = tir.Schedule(ModuleWithMultipleFuncs, debug_mask="all") + with pytest.raises(ValueError, match="does not know which function to be working on"): + sch.get_block(name="init") + sch.work_on(func_name="vector_add") + sch.get_block(name="init") + + def test_tir_schedule_get_loops(use_block_name): # Tests: # - Schedule.get_loops From 1a845c055a0a8b3764fd8bcac5c032422fabc15f Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 2 Jul 2022 18:56:23 -0700 Subject: [PATCH 2/2] fix lint --- python/tvm/tir/schedule/schedule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index b3fff7622df9..28bdf63872d9 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -201,7 +201,7 @@ def work_on(self, func_name: str) -> None: func_name : str The name of the function to work on. """ - _ffi_api.ScheduleWorkOn(self, func_name) + _ffi_api.ScheduleWorkOn(self, func_name) # type: ignore # pylint: disable=no-member def copy(self) -> "Schedule": """Returns a copy of the schedule, including both the state and the symbol table,