From 087dd91d5bf2d3c1c689f4a7b3916f967c4630b9 Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Tue, 4 Apr 2023 18:22:21 +0530 Subject: [PATCH 1/4] [TIR] [Schedule] Add get_output_blocks primitive When scheduling fused ops, its really useful to be able to get all the output blocks and schedule other blocks in the fused op with respect to output blocks. This helps avoid hardcoding block names in manually written schedules in many cases --- include/tvm/tir/schedule/schedule.h | 8 +++ python/tvm/tir/schedule/schedule.py | 17 +++++++ src/tir/schedule/analysis.h | 9 ++++ src/tir/schedule/analysis/analysis.cc | 26 ++++++++++ src/tir/schedule/concrete_schedule.cc | 17 +++++++ src/tir/schedule/concrete_schedule.h | 1 + src/tir/schedule/primitive.h | 8 +++ src/tir/schedule/primitive/get_block_loop.cc | 30 +++++++++++ src/tir/schedule/schedule.cc | 2 + src/tir/schedule/traced_schedule.cc | 11 ++++ src/tir/schedule/traced_schedule.h | 1 + .../unittest/test_tir_schedule_utilities.py | 51 +++++++++++++++++++ 12 files changed, 181 insertions(+) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 257c9c5d04eb..bfaf38794f73 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -291,6 +291,14 @@ class ScheduleNode : public runtime::Object { * block */ virtual Array GetConsumers(const BlockRV& block_rv) = 0; + /*! + * \brief Get the list of output blocks + * An output block is a block which has atleast one buffer being written + * to, but is not allocated within the PrimFunc + * \return A list of all blocks that write to some output buffer + * block + */ + virtual Array GetOutputBlocks(const Optional& func_name = NullOpt) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Merge a list of loops into one. The loops under their LCA requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 8113097003fa..bdb330e14510 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -540,6 +540,23 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]: block = self._normalize_block_arg(block) return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member + @type_checked + def get_output_blocks( + self, + func_name: Optional[str] = None, + ) -> List[BlockRV]: + """Get the list of output blocks + An output block is a block which has atleast one buffer being written + to, but is not allocated within the PrimFunc + + Returns + ------- + output_blocks : List[BlockRV] + A list of all blocks that write to some output buffer + + """ + return list(_ffi_api.ScheduleGetOutputBlocks(self, func_name)) # type: ignore # pylint: disable=no-member + ########## Schedule: Transform loops ########## @type_checked def merge( diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index bc505a0104be..e8bd4e4d3a2a 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -385,6 +385,15 @@ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope */ Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); +/*! + * \brief Get the list of output blocks + * An output block is a block which has atleast one buffer being written + * to, but is not allocated within the PrimFunc + * \return A list of all blocks that write to some output buffer + * block + */ +Array GetOutputBlocks(const ScheduleState& self, const PrimFuncNode* func); + /*! * \brief A solution to split a ordered list of subtrees into two parts, * where producers are on the LHS and consumers are on the RHS. diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 674abe28a3e0..53c4634952d9 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1043,6 +1043,32 @@ Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope return results; } +Array GetOutputBlocks(const ScheduleState& self, const PrimFuncNode* func) { + struct OutputBlockCollector : public StmtVisitor { + explicit OutputBlockCollector(const ScheduleState& self) : self_(self) {} + + void VisitStmt_(const BlockNode* block) override { + auto it = self_->stmt2ref.find(block); + ICHECK(it != self_->stmt2ref.end()); + auto block_sref = it->second; + if (block_sref->parent != nullptr) { + StmtSRef scope_root_sref = GetScopeRoot(self_, block_sref, /*require_stage_pipeline=*/false); + if (IsOutputBlock(self_, block_sref, scope_root_sref)) { + results_.push_back(block_sref); + } + } + StmtVisitor::VisitStmt_(block); + } + + const ScheduleState& self_; + Array results_; + }; + OutputBlockCollector collector(self); + collector(func->body); + auto results = collector.results_; + return results; +} + ProducerConsumerSplit ProducerConsumerSplit::Find( const ScheduleState& self, const Array& subtrees, const Array& producer_block_srefs, const Array& consumer_block_srefs, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 8e5eefdb6ac4..ddc1de019b98 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -354,6 +354,23 @@ Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) { throw; } +Array ConcreteScheduleNode::GetOutputBlocks(const Optional& func_name) { + TVM_TIR_SCHEDULE_BEGIN(); + GlobalVar gv = NullValue(); + if (func_name.defined()) { + gv = state_->mod->GetGlobalVar(func_name.value()); + } else if (func_working_on_.defined()) { + gv = this->func_working_on_.value(); + } else { + LOG(FATAL) << "ValueError: `get_output_blocks` does not know which function to be working on. Please " + "specify the function name explicitly, or call `work_on` to specify the function " + "before using `get_output_blocks`."; + } + return CreateRV(tir::GetOutputBlocks(state_, gv)); + TVM_TIR_SCHEDULE_END("get-output-blocks", this->error_render_level_); + throw; +} + /******** Schedule: Transform loops ********/ LoopRV ConcreteScheduleNode::Merge(const Array& loop_rvs) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 903658e33949..d0fb6b465bb2 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -99,6 +99,7 @@ class ConcreteScheduleNode : public ScheduleNode { Array GetChildBlocks(const LoopRV& loop_rv) override; Array GetProducers(const BlockRV& block_rv) override; Array GetConsumers(const BlockRV& block_rv) override; + Array GetOutputBlocks(const Optional& func_name) override; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) override; LoopRV Merge(const Array& loop_rvs) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0050789243f8..b011a3b23d66 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -148,6 +148,14 @@ Array GetProducers(const ScheduleState& self, const StmtSRef& block_sr * \return A list of blocks, the consumers of the given block */ Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Get the list of output blocks + * An output block is a block which has atleast one buffer being written + * to, but is not allocated within the PrimFunc + * \return A list of all blocks that write to some output buffer + * block + */ +Array GetOutputBlocks(const ScheduleState& self, const GlobalVar& gv); /******** Schedule: Transform loops ********/ /*! * Split a loop into a list of consecutive loops. It requires: diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 72f43a8d4929..85ad5b4d7c48 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -88,6 +88,12 @@ Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sr return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root)); } +Array GetOutputBlocks(const ScheduleState& self, const GlobalVar& gv) { + BaseFunc func = self->mod->Lookup(gv); + const auto* prim_func = TVM_TYPE_AS(func, PrimFuncNode); + return tir::GetOutputBlocks(self, prim_func); +} + /******** InstructionKind Registration ********/ struct GetBlockTraits : public UnpackedInstTraits { @@ -218,11 +224,35 @@ struct GetConsumersTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct GetOutputBlocksTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetOutputBlocks"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 0; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch) { + return sch->GetOutputBlocks(); + } + + static String UnpackedAsPython(Array outputs) { + PythonAPICall py("get_output_blocks"); + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits); TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits); TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits); TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits); TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetOutputBlocksTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index aafb29800bc1..fdf473ff7972 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -152,6 +152,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers") .set_body_method(&ScheduleNode::GetProducers); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers") .set_body_method(&ScheduleNode::GetConsumers); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetOutputBlocks") + .set_body_method(&ScheduleNode::GetOutputBlocks); /******** (FFI) Transform loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method(&ScheduleNode::Merge); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index fe48c52e3103..ac72f7aa7392 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -174,6 +174,17 @@ Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { return results; } +Array TracedScheduleNode::GetOutputBlocks(const Optional& func_name) { + Array results = ConcreteScheduleNode::GetOutputBlocks(func_name); + + static const InstructionKind& kind = InstructionKind::Get("GetOutputBlocks"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + /******** Schedule: Transform loops ********/ LoopRV TracedScheduleNode::Merge(const Array& loop_rvs) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 9630d1513e8d..70a63b22be7f 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -59,6 +59,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Array GetChildBlocks(const LoopRV& loop_rv) final; Array GetProducers(const BlockRV& block_rv) final; Array GetConsumers(const BlockRV& block_rv) final; + Array GetOutputBlocks(const Optional& func_name) final; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) final; LoopRV Merge(const Array& loop_rvs) final; diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 53ee6a58cd9a..d087b90e6984 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -358,5 +358,56 @@ def test_annotate_unannotate_block(): verify_trace_roundtrip(sch=sch, mod=matmul_relu) +def test_get_output_blocks_single_output(): + sch = tir.Schedule(mod=matmul_relu, debug_mask="all") + output_blocks = sch.get_output_blocks() + assert len(output_blocks) == 1, "Unexpected number of blocks when 1 was expected" + block = sch.get(output_blocks[0]) + assert block.name_hint == "relu" + relu_block = sch.get_block("relu") + assert sch.get(relu_block).same_as(block) + + +def test_get_output_blocks_multiple_outputs(): + sch = tir.Schedule(mod=matmul, debug_mask="all") + output_blocks = sch.get_output_blocks() + assert len(output_blocks) == 2, "Unexpected number of blocks when 2 were expected" + block_1 = sch.get(output_blocks[0]) + assert block_1.name_hint == "init" + block_2 = sch.get(output_blocks[1]) + assert block_2.name_hint == "update" + init_block = sch.get_block("init") + assert sch.get(init_block).same_as(block_1) + update_block = sch.get_block("update") + assert sch.get(update_block).same_as(block_2) + + +def test_get_output_blocks_nested(): + @T.prim_func + def blockized( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + ) -> None: + with T.block("blockized_B"): + vio = T.axis.spatial(1, 0) + vjo = T.axis.spatial(1, 0) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + sch = tir.Schedule(mod=blockized, debug_mask="all") + output_blocks = sch.get_output_blocks() + assert len(output_blocks) == 2, "Unexpected number of blocks when 2 were expected" + block_1 = sch.get(output_blocks[0]) + assert block_1.name_hint == "blockized_B" + block_2 = sch.get(output_blocks[1]) + assert block_2.name_hint == "B" + blockized_block = sch.get_block("blockized_B") + assert sch.get(blockized_block).same_as(block_1) + b_block = sch.get_block("B") + assert sch.get(b_block).same_as(block_2) + + if __name__ == "__main__": tvm.testing.main() From ebd64148f58d418ed76edde480788c4bbd9dca1b Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Tue, 4 Apr 2023 22:04:25 +0530 Subject: [PATCH 2/4] Fix cpplint errors --- src/tir/schedule/analysis/analysis.cc | 3 ++- src/tir/schedule/concrete_schedule.cc | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 53c4634952d9..e2c5f1652cdf 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1052,7 +1052,8 @@ Array GetOutputBlocks(const ScheduleState& self, const PrimFuncNode* f ICHECK(it != self_->stmt2ref.end()); auto block_sref = it->second; if (block_sref->parent != nullptr) { - StmtSRef scope_root_sref = GetScopeRoot(self_, block_sref, /*require_stage_pipeline=*/false); + StmtSRef scope_root_sref = + GetScopeRoot(self_, block_sref, /*require_stage_pipeline=*/false); if (IsOutputBlock(self_, block_sref, scope_root_sref)) { results_.push_back(block_sref); } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index ddc1de019b98..892dcc4564d1 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -362,9 +362,10 @@ Array ConcreteScheduleNode::GetOutputBlocks(const Optional& fun } else if (func_working_on_.defined()) { gv = this->func_working_on_.value(); } else { - LOG(FATAL) << "ValueError: `get_output_blocks` does not know which function to be working on. Please " - "specify the function name explicitly, or call `work_on` to specify the function " - "before using `get_output_blocks`."; + LOG(FATAL) + << "ValueError: `get_output_blocks` does not know which function to be working on. Please " + "specify the function name explicitly, or call `work_on` to specify the function " + "before using `get_output_blocks`."; } return CreateRV(tir::GetOutputBlocks(state_, gv)); TVM_TIR_SCHEDULE_END("get-output-blocks", this->error_render_level_); From 599a4b804b7043a1fa13ea7346389e89dffba4b4 Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Tue, 4 Apr 2023 23:11:34 +0530 Subject: [PATCH 3/4] Another lint --- src/tir/schedule/primitive/get_block_loop.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 85ad5b4d7c48..5b0801f695b3 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -233,9 +233,7 @@ struct GetOutputBlocksTraits : public UnpackedInstTraits static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch) { - return sch->GetOutputBlocks(); - } + static Array UnpackedApplyToSchedule(Schedule sch) { return sch->GetOutputBlocks(); } static String UnpackedAsPython(Array outputs) { PythonAPICall py("get_output_blocks"); From 0b41504bf375c7927915daf2c627f60165e6a344 Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Wed, 5 Apr 2023 16:02:54 +0530 Subject: [PATCH 4/4] Update primitive to take scope block as argument This commit updates the primitive to take scope block as argument and return all output blocks under that scope. In order to get all output blocks of a PrimFunc, the root block can be passed --- include/tvm/tir/schedule/schedule.h | 5 +++-- python/tvm/tir/schedule/schedule.py | 12 +++++++++--- src/tir/schedule/analysis.h | 5 +++-- src/tir/schedule/analysis/analysis.cc | 4 ++-- src/tir/schedule/concrete_schedule.cc | 15 ++------------- src/tir/schedule/concrete_schedule.h | 2 +- src/tir/schedule/primitive.h | 5 +++-- src/tir/schedule/primitive/get_block_loop.cc | 16 +++++++++------- src/tir/schedule/traced_schedule.cc | 6 +++--- src/tir/schedule/traced_schedule.h | 2 +- .../unittest/test_tir_schedule_utilities.py | 14 +++++++++++--- 11 files changed, 47 insertions(+), 39 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index bfaf38794f73..c294d0ae8762 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -292,13 +292,14 @@ class ScheduleNode : public runtime::Object { */ virtual Array GetConsumers(const BlockRV& block_rv) = 0; /*! - * \brief Get the list of output blocks + * \brief Get the list of output blocks within the given scope * An output block is a block which has atleast one buffer being written * to, but is not allocated within the PrimFunc + * \param scope_block_rv The scope block from which output blocks are collected * \return A list of all blocks that write to some output buffer * block */ - virtual Array GetOutputBlocks(const Optional& func_name = NullOpt) = 0; + virtual Array GetOutputBlocks(const BlockRV& scope_block_rv) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Merge a list of loops into one. The loops under their LCA requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index bdb330e14510..b19e30848f69 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -543,19 +543,25 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]: @type_checked def get_output_blocks( self, - func_name: Optional[str] = None, + scope_block: Union[BlockRV, str], ) -> List[BlockRV]: - """Get the list of output blocks + """Get the list of output blocks within the given scope An output block is a block which has atleast one buffer being written to, but is not allocated within the PrimFunc + Parameters + ---------- + scope_block : Union[BlockRV, str], + The scope block from which output blocks are collected + Returns ------- output_blocks : List[BlockRV] A list of all blocks that write to some output buffer """ - return list(_ffi_api.ScheduleGetOutputBlocks(self, func_name)) # type: ignore # pylint: disable=no-member + scope_block = self._normalize_block_arg(scope_block) + return list(_ffi_api.ScheduleGetOutputBlocks(self, scope_block)) # type: ignore # pylint: disable=no-member ########## Schedule: Transform loops ########## @type_checked diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index e8bd4e4d3a2a..7b9c2a9a2448 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -386,13 +386,14 @@ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); /*! - * \brief Get the list of output blocks + * \brief Get the list of output blocks within the given scope * An output block is a block which has atleast one buffer being written * to, but is not allocated within the PrimFunc + * \param scope_block_rv The scope block from which output blocks are collected * \return A list of all blocks that write to some output buffer * block */ -Array GetOutputBlocks(const ScheduleState& self, const PrimFuncNode* func); +Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block); /*! * \brief A solution to split a ordered list of subtrees into two parts, diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index e2c5f1652cdf..2c4da4aaf731 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1043,7 +1043,7 @@ Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope return results; } -Array GetOutputBlocks(const ScheduleState& self, const PrimFuncNode* func) { +Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block) { struct OutputBlockCollector : public StmtVisitor { explicit OutputBlockCollector(const ScheduleState& self) : self_(self) {} @@ -1065,7 +1065,7 @@ Array GetOutputBlocks(const ScheduleState& self, const PrimFuncNode* f Array results_; }; OutputBlockCollector collector(self); - collector(func->body); + collector(scope_block->body); auto results = collector.results_; return results; } diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 892dcc4564d1..7192a4809994 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -354,20 +354,9 @@ Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) { throw; } -Array ConcreteScheduleNode::GetOutputBlocks(const Optional& func_name) { +Array ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { TVM_TIR_SCHEDULE_BEGIN(); - GlobalVar gv = NullValue(); - if (func_name.defined()) { - gv = state_->mod->GetGlobalVar(func_name.value()); - } else if (func_working_on_.defined()) { - gv = this->func_working_on_.value(); - } else { - LOG(FATAL) - << "ValueError: `get_output_blocks` does not know which function to be working on. Please " - "specify the function name explicitly, or call `work_on` to specify the function " - "before using `get_output_blocks`."; - } - return CreateRV(tir::GetOutputBlocks(state_, gv)); + return CreateRV(tir::GetOutputBlocks(state_, this->GetSRef(scope_block_rv))); TVM_TIR_SCHEDULE_END("get-output-blocks", this->error_render_level_); throw; } diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index d0fb6b465bb2..eb7c38753c94 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -99,7 +99,7 @@ class ConcreteScheduleNode : public ScheduleNode { Array GetChildBlocks(const LoopRV& loop_rv) override; Array GetProducers(const BlockRV& block_rv) override; Array GetConsumers(const BlockRV& block_rv) override; - Array GetOutputBlocks(const Optional& func_name) override; + Array GetOutputBlocks(const BlockRV& scope_block_rv) override; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) override; LoopRV Merge(const Array& loop_rvs) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index b011a3b23d66..5f5591ac45a0 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -149,13 +149,14 @@ Array GetProducers(const ScheduleState& self, const StmtSRef& block_sr */ Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref); /*! - * \brief Get the list of output blocks + * \brief Get the list of output blocks within the given scope * An output block is a block which has atleast one buffer being written * to, but is not allocated within the PrimFunc + * \param scope_block_rv The scope block from which output blocks are collected * \return A list of all blocks that write to some output buffer * block */ -Array GetOutputBlocks(const ScheduleState& self, const GlobalVar& gv); +Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref); /******** Schedule: Transform loops ********/ /*! * Split a loop into a list of consecutive loops. It requires: diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 5b0801f695b3..87ec6e550dcd 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -88,10 +88,9 @@ Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sr return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root)); } -Array GetOutputBlocks(const ScheduleState& self, const GlobalVar& gv) { - BaseFunc func = self->mod->Lookup(gv); - const auto* prim_func = TVM_TYPE_AS(func, PrimFuncNode); - return tir::GetOutputBlocks(self, prim_func); +Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref) { + const auto* scope_block = TVM_SREF_TO_BLOCK(scope_sref); + return tir::GetOutputBlocks(self, scope_block); } /******** InstructionKind Registration ********/ @@ -229,14 +228,17 @@ struct GetOutputBlocksTraits : public UnpackedInstTraits static constexpr bool kIsPure = true; private: - static constexpr size_t kNumInputs = 0; + static constexpr size_t kNumInputs = 1; static constexpr size_t kNumAttrs = 0; static constexpr size_t kNumDecisions = 0; - static Array UnpackedApplyToSchedule(Schedule sch) { return sch->GetOutputBlocks(); } + static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->GetOutputBlocks(block_rv); + } - static String UnpackedAsPython(Array outputs) { + static String UnpackedAsPython(Array outputs, String block_rv) { PythonAPICall py("get_output_blocks"); + py.Input("block", block_rv); py.OutputList(outputs); return py.Str(); } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index ac72f7aa7392..4d820078e527 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -174,12 +174,12 @@ Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { return results; } -Array TracedScheduleNode::GetOutputBlocks(const Optional& func_name) { - Array results = ConcreteScheduleNode::GetOutputBlocks(func_name); +Array TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { + Array results = ConcreteScheduleNode::GetOutputBlocks(scope_block_rv); static const InstructionKind& kind = InstructionKind::Get("GetOutputBlocks"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // - /*inputs=*/{}, + /*inputs=*/{scope_block_rv}, /*attrs=*/{}, /*outputs=*/{results.begin(), results.end()})); return results; diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 70a63b22be7f..16ec86f22709 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -59,7 +59,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Array GetChildBlocks(const LoopRV& loop_rv) final; Array GetProducers(const BlockRV& block_rv) final; Array GetConsumers(const BlockRV& block_rv) final; - Array GetOutputBlocks(const Optional& func_name) final; + Array GetOutputBlocks(const BlockRV& scope_block_rv) final; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) final; LoopRV Merge(const Array& loop_rvs) final; diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index d087b90e6984..ba2c134def7c 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -360,7 +360,7 @@ def test_annotate_unannotate_block(): def test_get_output_blocks_single_output(): sch = tir.Schedule(mod=matmul_relu, debug_mask="all") - output_blocks = sch.get_output_blocks() + output_blocks = sch.get_output_blocks("root") assert len(output_blocks) == 1, "Unexpected number of blocks when 1 was expected" block = sch.get(output_blocks[0]) assert block.name_hint == "relu" @@ -370,7 +370,7 @@ def test_get_output_blocks_single_output(): def test_get_output_blocks_multiple_outputs(): sch = tir.Schedule(mod=matmul, debug_mask="all") - output_blocks = sch.get_output_blocks() + output_blocks = sch.get_output_blocks("root") assert len(output_blocks) == 2, "Unexpected number of blocks when 2 were expected" block_1 = sch.get(output_blocks[0]) assert block_1.name_hint == "init" @@ -397,7 +397,7 @@ def blockized( B[vi, vj] = A[vi, vj] * 2.0 sch = tir.Schedule(mod=blockized, debug_mask="all") - output_blocks = sch.get_output_blocks() + output_blocks = sch.get_output_blocks("root") assert len(output_blocks) == 2, "Unexpected number of blocks when 2 were expected" block_1 = sch.get(output_blocks[0]) assert block_1.name_hint == "blockized_B" @@ -408,6 +408,14 @@ def blockized( b_block = sch.get_block("B") assert sch.get(b_block).same_as(block_2) + sch = tir.Schedule(mod=blockized, debug_mask="all") + output_blocks = sch.get_output_blocks("blockized_B") + assert len(output_blocks) == 1, "Unexpected number of blocks when 1 were expected" + block = sch.get(output_blocks[0]) + assert block.name_hint == "B" + b_block = sch.get_block("B") + assert sch.get(b_block).same_as(block) + if __name__ == "__main__": tvm.testing.main()