diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 257c9c5d04eb..c294d0ae8762 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -291,6 +291,15 @@ class ScheduleNode : public runtime::Object { * block */ virtual Array GetConsumers(const BlockRV& block_rv) = 0; + /*! + * \brief Get the list of output blocks within the given scope + * An output block is a block that writes to at least one buffer + * that is not allocated within the PrimFunc + * \param scope_block_rv The scope block from which output blocks are collected + * \return A list of all blocks that write to some output buffer + * within the given scope + */ + virtual Array GetOutputBlocks(const BlockRV& scope_block_rv) = 0; /******** Schedule: Transform loops ********/ /*! * \brief Merge a list of loops into one. The loops under their LCA requires: diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 8113097003fa..b19e30848f69 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -540,6 +540,29 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]: block = self._normalize_block_arg(block) return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member + @type_checked + def get_output_blocks( + self, + scope_block: Union[BlockRV, str], + ) -> List[BlockRV]: + """Get the list of output blocks within the given scope + An output block is a block that writes to at least one buffer + that is not allocated within the PrimFunc + + Parameters + ---------- + scope_block : Union[BlockRV, str] + The scope block from which output blocks are collected + + Returns + ------- + output_blocks : List[BlockRV] + A list of all blocks that write to some output buffer + + """ + scope_block = self._normalize_block_arg(scope_block) + return list(_ffi_api.ScheduleGetOutputBlocks(self, scope_block)) # type: ignore
# pylint: disable=no-member + ########## Schedule: Transform loops ########## @type_checked def merge( diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index bc505a0104be..7b9c2a9a2448 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -385,6 +385,16 @@ Array GetProducers(const StmtSRef& block_sref, const BlockScope& scope */ Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope); +/*! + * \brief Get the list of output blocks within the given scope + * An output block is a block that writes to at least one buffer not allocated within the PrimFunc + * \param self The schedule state + * \param scope_block The scope block from which output blocks are collected + * \return A list of all blocks that write to some output buffer + * within the given scope + */ +Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block); + /*! * \brief A solution to split a ordered list of subtrees into two parts, * where producers are on the LHS and consumers are on the RHS.
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 674abe28a3e0..2c4da4aaf731 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1043,6 +1043,33 @@ Array GetConsumers(const StmtSRef& block_sref, const BlockScope& scope return results; } +Array GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block) { + struct OutputBlockCollector : public StmtVisitor { + explicit OutputBlockCollector(const ScheduleState& self) : self_(self) {} + + void VisitStmt_(const BlockNode* block) override { + auto it = self_->stmt2ref.find(block); + ICHECK(it != self_->stmt2ref.end()); /* every visited block must already have an sref */ + auto block_sref = it->second; + if (block_sref->parent != nullptr) { /* blocks without a parent sref are not classified */ + StmtSRef scope_root_sref = + GetScopeRoot(self_, block_sref, /*require_stage_pipeline=*/false); + if (IsOutputBlock(self_, block_sref, scope_root_sref)) { + results_.push_back(block_sref); + } + } + StmtVisitor::VisitStmt_(block); /* keep descending so nested blocks are visited too */ + } + + const ScheduleState& self_; + Array results_; + }; + OutputBlockCollector collector(self); + collector(scope_block->body); /* start at the body, so scope_block itself is never collected */ + auto results = collector.results_; + return results; +} + ProducerConsumerSplit ProducerConsumerSplit::Find( const ScheduleState& self, const Array& subtrees, const Array& producer_block_srefs, const Array& consumer_block_srefs, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 8e5eefdb6ac4..7192a4809994 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -354,6 +354,13 @@ Array ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) { throw; } +Array ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(tir::GetOutputBlocks(state_, this->GetSRef(scope_block_rv))); + TVM_TIR_SCHEDULE_END("get-output-blocks", this->error_render_level_); + throw; +} + /******** Schedule: Transform loops ********/ LoopRV
ConcreteScheduleNode::Merge(const Array& loop_rvs) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 903658e33949..eb7c38753c94 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -99,6 +99,7 @@ class ConcreteScheduleNode : public ScheduleNode { Array GetChildBlocks(const LoopRV& loop_rv) override; Array GetProducers(const BlockRV& block_rv) override; Array GetConsumers(const BlockRV& block_rv) override; + Array GetOutputBlocks(const BlockRV& scope_block_rv) override; /******** Schedule: Transform loops ********/ LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) override; LoopRV Merge(const Array& loop_rvs) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0050789243f8..5f5591ac45a0 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -148,6 +148,15 @@ Array GetProducers(const ScheduleState& self, const StmtSRef& block_sr * \return A list of blocks, the consumers of the given block */ Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sref); +/*! + * \brief Get the list of output blocks within the given scope + * An output block is a block that writes to at least one buffer not allocated within the PrimFunc + * \param self The schedule state + * \param scope_sref The sref of the scope block from which output blocks are collected + * \return A list of all blocks that write to some output buffer + * within the given scope + */ +Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref); /******** Schedule: Transform loops ********/ /*! * Split a loop into a list of consecutive loops.
It requires: diff --git a/src/tir/schedule/primitive/get_block_loop.cc b/src/tir/schedule/primitive/get_block_loop.cc index 72f43a8d4929..87ec6e550dcd 100644 --- a/src/tir/schedule/primitive/get_block_loop.cc +++ b/src/tir/schedule/primitive/get_block_loop.cc @@ -88,6 +88,11 @@ Array GetConsumers(const ScheduleState& self, const StmtSRef& block_sr return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root)); } +Array GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref) { + const auto* scope_block = TVM_SREF_TO_BLOCK(scope_sref); + return tir::GetOutputBlocks(self, scope_block); +} + /******** InstructionKind Registration ********/ struct GetBlockTraits : public UnpackedInstTraits { @@ -218,11 +223,36 @@ struct GetConsumersTraits : public UnpackedInstTraits { friend struct ::tvm::tir::UnpackedInstTraits; }; +struct GetOutputBlocksTraits : public UnpackedInstTraits { + static constexpr const char* kName = "GetOutputBlocks"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 1; + static constexpr size_t kNumAttrs = 0; + static constexpr size_t kNumDecisions = 0; + + static Array UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) { + return sch->GetOutputBlocks(block_rv); + } + + static String UnpackedAsPython(Array outputs, String block_rv) { + PythonAPICall py("get_output_blocks"); + py.Input("scope_block", block_rv); /* must match the Python parameter name */ + py.OutputList(outputs); + return py.Str(); + } + + template + friend struct ::tvm::tir::UnpackedInstTraits; +}; + TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits); TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits); TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits); TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits); TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits); +TVM_REGISTER_INST_KIND_TRAITS(GetOutputBlocksTraits); } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index aafb29800bc1..fdf473ff7972 100644 ---
a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -152,6 +152,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers") .set_body_method(&ScheduleNode::GetProducers); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers") .set_body_method(&ScheduleNode::GetConsumers); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetOutputBlocks") + .set_body_method(&ScheduleNode::GetOutputBlocks); /******** (FFI) Transform loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method(&ScheduleNode::Merge); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method(&ScheduleNode::Fuse); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index fe48c52e3103..4d820078e527 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -174,6 +174,17 @@ Array TracedScheduleNode::GetConsumers(const BlockRV& block_rv) { return results; } +Array TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) { + Array results = ConcreteScheduleNode::GetOutputBlocks(scope_block_rv); /* run the primitive first, then record it */ + + static const InstructionKind& kind = InstructionKind::Get("GetOutputBlocks"); /* same name as GetOutputBlocksTraits::kName */ + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{scope_block_rv}, + /*attrs=*/{}, + /*outputs=*/{results.begin(), results.end()})); + return results; +} + /******** Schedule: Transform loops ********/ LoopRV TracedScheduleNode::Merge(const Array& loop_rvs) { diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 9630d1513e8d..16ec86f22709 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -59,6 +59,7 @@ class TracedScheduleNode : public ConcreteScheduleNode { Array GetChildBlocks(const LoopRV& loop_rv) final; Array GetProducers(const BlockRV& block_rv) final; Array GetConsumers(const BlockRV& block_rv) final; + Array GetOutputBlocks(const BlockRV& scope_block_rv) final; /******** Schedule: Transform loops ********/
LoopRV Fuse(const Array& loop_rvs, bool preserve_unit_iters) final; LoopRV Merge(const Array& loop_rvs) final; diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 53ee6a58cd9a..ba2c134def7c 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -358,5 +358,64 @@ def test_annotate_unannotate_block(): verify_trace_roundtrip(sch=sch, mod=matmul_relu) +def test_get_output_blocks_single_output(): + sch = tir.Schedule(mod=matmul_relu, debug_mask="all") + output_blocks = sch.get_output_blocks("root") + assert len(output_blocks) == 1, "Unexpected number of blocks when 1 was expected" + block = sch.get(output_blocks[0]) + assert block.name_hint == "relu" + relu_block = sch.get_block("relu") + assert sch.get(relu_block).same_as(block) + + +def test_get_output_blocks_multiple_outputs(): + sch = tir.Schedule(mod=matmul, debug_mask="all") + output_blocks = sch.get_output_blocks("root") + assert len(output_blocks) == 2, "Unexpected number of blocks when 2 were expected" + block_1 = sch.get(output_blocks[0]) + assert block_1.name_hint == "init" + block_2 = sch.get(output_blocks[1]) + assert block_2.name_hint == "update" + init_block = sch.get_block("init") + assert sch.get(init_block).same_as(block_1) + update_block = sch.get_block("update") + assert sch.get(update_block).same_as(block_2) + + +def test_get_output_blocks_nested(): + @T.prim_func + def blockized( + A: T.Buffer((128, 128), "float32"), + B: T.Buffer((128, 128), "float32"), + ) -> None: + with T.block("blockized_B"): + vio = T.axis.spatial(1, 0) + vjo = T.axis.spatial(1, 0) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + + sch = tir.Schedule(mod=blockized, debug_mask="all") + output_blocks = sch.get_output_blocks("root") + assert len(output_blocks) == 2, "Unexpected number of blocks when 2 were 
expected" + block_1 = sch.get(output_blocks[0]) + assert block_1.name_hint == "blockized_B" + block_2 = sch.get(output_blocks[1]) + assert block_2.name_hint == "B" + blockized_block = sch.get_block("blockized_B") + assert sch.get(blockized_block).same_as(block_1) + b_block = sch.get_block("B") + assert sch.get(b_block).same_as(block_2) + + sch = tir.Schedule(mod=blockized, debug_mask="all") + output_blocks = sch.get_output_blocks("blockized_B") + assert len(output_blocks) == 1, "Unexpected number of blocks when 1 were expected" + block = sch.get(output_blocks[0]) + assert block.name_hint == "B" + b_block = sch.get_block("B") + assert sch.get(b_block).same_as(block) + + if __name__ == "__main__": tvm.testing.main()