Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,15 @@ class ScheduleNode : public runtime::Object {
* block
*/
virtual Array<BlockRV> GetConsumers(const BlockRV& block_rv) = 0;
/*!
 * \brief Get the list of output blocks within the given scope.
 * An output block is a block that writes to at least one buffer which is
 * not allocated within the PrimFunc.
 * \param scope_block_rv The scope block from which output blocks are collected
 * \return A list of all blocks that write to some output buffer
 */
virtual Array<BlockRV> GetOutputBlocks(const BlockRV& scope_block_rv) = 0;
/******** Schedule: Transform loops ********/
/*!
* \brief Merge a list of loops into one. The loops under their LCA requires:
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,29 @@ def get_consumers(self, block: Union[BlockRV, str]) -> List[BlockRV]:
block = self._normalize_block_arg(block)
return list(_ffi_api.ScheduleGetConsumers(self, block)) # type: ignore # pylint: disable=no-member

@type_checked
def get_output_blocks(
    self,
    scope_block: Union[BlockRV, str],
) -> List[BlockRV]:
    """Collect the output blocks found inside a scope block.

    An output block is a block that writes to at least one buffer which is
    not allocated within the PrimFunc.

    Parameters
    ----------
    scope_block : Union[BlockRV, str]
        The scope block from which output blocks are collected

    Returns
    -------
    output_blocks : List[BlockRV]
        A list of all blocks that write to some output buffer

    """
    # Passed through the shared block-argument normalizer before reaching the FFI.
    scope_rv = self._normalize_block_arg(scope_block)
    output_blocks = _ffi_api.ScheduleGetOutputBlocks(self, scope_rv)  # type: ignore # pylint: disable=no-member
    return list(output_blocks)

########## Schedule: Transform loops ##########
@type_checked
def merge(
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,16 @@ Array<StmtSRef> GetProducers(const StmtSRef& block_sref, const BlockScope& scope
*/
Array<StmtSRef> GetConsumers(const StmtSRef& block_sref, const BlockScope& scope);

/*!
 * \brief Get the list of output blocks within the given scope.
 * An output block is a block that writes to at least one buffer which is
 * not allocated within the PrimFunc.
 * \param self The schedule state
 * \param scope_block The scope block from which output blocks are collected
 * \return A list of all blocks that write to some output buffer
 */
Array<StmtSRef> GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block);

/*!
* \brief A solution to split a ordered list of subtrees into two parts,
* where producers are on the LHS and consumers are on the RHS.
Expand Down
27 changes: 27 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,33 @@ Array<StmtSRef> GetConsumers(const StmtSRef& block_sref, const BlockScope& scope
return results;
}

/*!
 * \brief Collect the srefs of the output blocks found in the body of a scope block
 * \param self The schedule state
 * \param scope_block The block whose body is searched; the block itself is not visited
 * \return The srefs of the matching blocks, in the order the visitor reaches them
 */
Array<StmtSRef> GetOutputBlocks(const ScheduleState& self, const BlockNode* scope_block) {
  // Visitor that checks every block it reaches (nested ones included) against
  // the scope root reported by GetScopeRoot for that block.
  struct OutputBlockFinder : public StmtVisitor {
    explicit OutputBlockFinder(const ScheduleState& self) : self_(self) {}

    void VisitStmt_(const BlockNode* block) override {
      auto it = self_->stmt2ref.find(block);
      ICHECK(it != self_->stmt2ref.end());
      const StmtSRef sref = it->second;
      // The scope-root lookup only happens for blocks that have a parent sref.
      const bool has_parent = sref->parent != nullptr;
      if (has_parent &&
          IsOutputBlock(self_, sref, GetScopeRoot(self_, sref, /*require_stage_pipeline=*/false))) {
        output_srefs_.push_back(sref);
      }
      // Keep descending so that blocks nested inside this one are examined too.
      StmtVisitor::VisitStmt_(block);
    }

    const ScheduleState& self_;
    Array<StmtSRef> output_srefs_;
  };

  OutputBlockFinder finder(self);
  finder(scope_block->body);
  return finder.output_srefs_;
}

ProducerConsumerSplit ProducerConsumerSplit::Find(
const ScheduleState& self, const Array<Stmt>& subtrees,
const Array<StmtSRef>& producer_block_srefs, const Array<StmtSRef>& consumer_block_srefs,
Expand Down
7 changes: 7 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,13 @@ Array<BlockRV> ConcreteScheduleNode::GetConsumers(const BlockRV& block_rv) {
throw;
}

/*!
 * \brief Resolve the scope block RV to its sref, collect the output blocks under it and
 * wrap the resulting srefs as fresh BlockRVs
 */
Array<BlockRV> ConcreteScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) {
  TVM_TIR_SCHEDULE_BEGIN();
  const StmtSRef scope_sref = this->GetSRef(scope_block_rv);
  const Array<StmtSRef> output_srefs = tir::GetOutputBlocks(state_, scope_sref);
  return CreateRV<BlockRV>(output_srefs);
  TVM_TIR_SCHEDULE_END("get-output-blocks", this->error_render_level_);
  // NOTE(review): presumably only here so the function does not end without a
  // return/throw after the macro's error handler - confirm against the macro.
  throw;
}

/******** Schedule: Transform loops ********/

LoopRV ConcreteScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class ConcreteScheduleNode : public ScheduleNode {
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
Array<BlockRV> GetProducers(const BlockRV& block_rv) override;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) override;
Array<BlockRV> GetOutputBlocks(const BlockRV& scope_block_rv) override;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) override;
LoopRV Merge(const Array<LoopRV>& loop_rvs) override;
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,15 @@ Array<StmtSRef> GetProducers(const ScheduleState& self, const StmtSRef& block_sr
* \return A list of blocks, the consumers of the given block
*/
Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sref);
/*!
 * \brief Get the list of output blocks within the given scope.
 * An output block is a block that writes to at least one buffer which is
 * not allocated within the PrimFunc.
 * \param self The schedule state
 * \param scope_sref The sref of the scope block from which output blocks are collected
 * \return A list of all blocks that write to some output buffer
 */
Array<StmtSRef> GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref);
/******** Schedule: Transform loops ********/
/*!
* Split a loop into a list of consecutive loops. It requires:
Expand Down
30 changes: 30 additions & 0 deletions src/tir/schedule/primitive/get_block_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ Array<StmtSRef> GetConsumers(const ScheduleState& self, const StmtSRef& block_sr
return tir::GetConsumers(block_sref, self->GetBlockScope(scope_root));
}

// Sref-based entry point: converts the scope sref to its BlockNode via
// TVM_SREF_TO_BLOCK and forwards to the BlockNode-based overload.
Array<StmtSRef> GetOutputBlocks(const ScheduleState& self, const StmtSRef& scope_sref) {
  return tir::GetOutputBlocks(self, TVM_SREF_TO_BLOCK(scope_sref));
}

/******** InstructionKind Registration ********/

struct GetBlockTraits : public UnpackedInstTraits<GetBlockTraits> {
Expand Down Expand Up @@ -218,11 +223,36 @@ struct GetConsumersTraits : public UnpackedInstTraits<GetConsumersTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};

/*!
 * \brief Instruction traits of the "GetOutputBlocks" schedule primitive.
 * Takes one input (the scope block RV), no attributes and no decisions, and
 * yields the list of output block RVs returned by Schedule::GetOutputBlocks.
 */
struct GetOutputBlocksTraits : public UnpackedInstTraits<GetOutputBlocksTraits> {
  static constexpr const char* kName = "GetOutputBlocks";
  static constexpr bool kIsPure = true;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 0;
  static constexpr size_t kNumDecisions = 0;

  /*! \brief Apply the instruction by calling the schedule primitive on the scope block */
  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv) {
    return sch->GetOutputBlocks(block_rv);
  }

  /*! \brief Print the instruction as a call to the Python `get_output_blocks` method */
  static String UnpackedAsPython(Array<String> outputs, String block_rv) {
    PythonAPICall py("get_output_blocks");
    // The input name is printed as the keyword of the generated Python call, so
    // it has to match the `scope_block` parameter of `Schedule.get_output_blocks`;
    // with "block" the printed trace is not a valid call to that method.
    py.Input("scope_block", block_rv);
    py.OutputList(outputs);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

// Register the instruction-kind traits of the block/loop query primitives.
// TracedScheduleNode looks the kind up by its kName, e.g.
// InstructionKind::Get("GetOutputBlocks").
TVM_REGISTER_INST_KIND_TRAITS(GetBlockTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetLoopsTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetChildBlocksTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetProducersTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetConsumersTraits);
TVM_REGISTER_INST_KIND_TRAITS(GetOutputBlocksTraits);

} // namespace tir
} // namespace tvm
2 changes: 2 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetProducers")
.set_body_method<Schedule>(&ScheduleNode::GetProducers);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetConsumers")
.set_body_method<Schedule>(&ScheduleNode::GetConsumers);
// FFI entry point bound to ScheduleNode::GetOutputBlocks; the Python wrapper
// `get_output_blocks` calls `_ffi_api.ScheduleGetOutputBlocks`.
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetOutputBlocks")
.set_body_method<Schedule>(&ScheduleNode::GetOutputBlocks);
/******** (FFI) Transform loops ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleMerge").set_body_method<Schedule>(&ScheduleNode::Merge);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleFuse").set_body_method<Schedule>(&ScheduleNode::Fuse);
Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,17 @@ Array<BlockRV> TracedScheduleNode::GetConsumers(const BlockRV& block_rv) {
return results;
}

/*!
 * \brief Run the concrete GetOutputBlocks primitive and append the matching
 * "GetOutputBlocks" instruction to the trace
 */
Array<BlockRV> TracedScheduleNode::GetOutputBlocks(const BlockRV& scope_block_rv) {
  // The primitive runs first, so nothing is appended to the trace if it throws.
  Array<BlockRV> output_block_rvs = ConcreteScheduleNode::GetOutputBlocks(scope_block_rv);

  static const InstructionKind& kind = InstructionKind::Get("GetOutputBlocks");
  Instruction inst(/*kind=*/kind,
                   /*inputs=*/{scope_block_rv},
                   /*attrs=*/{},
                   /*outputs=*/{output_block_rvs.begin(), output_block_rvs.end()});
  trace_->Append(/*inst=*/inst);
  return output_block_rvs;
}

/******** Schedule: Transform loops ********/

LoopRV TracedScheduleNode::Merge(const Array<LoopRV>& loop_rvs) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) final;
Array<BlockRV> GetProducers(const BlockRV& block_rv) final;
Array<BlockRV> GetConsumers(const BlockRV& block_rv) final;
Array<BlockRV> GetOutputBlocks(const BlockRV& scope_block_rv) final;
/******** Schedule: Transform loops ********/
LoopRV Fuse(const Array<LoopRV>& loop_rvs, bool preserve_unit_iters) final;
LoopRV Merge(const Array<LoopRV>& loop_rvs) final;
Expand Down
59 changes: 59 additions & 0 deletions tests/python/unittest/test_tir_schedule_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,5 +358,64 @@ def test_annotate_unannotate_block():
verify_trace_roundtrip(sch=sch, mod=matmul_relu)


def test_get_output_blocks_single_output():
    """Expect exactly one output block, named `relu`, under the root of `matmul_relu`."""
    sch = tir.Schedule(mod=matmul_relu, debug_mask="all")
    found_rvs = sch.get_output_blocks("root")
    assert len(found_rvs) == 1, "Unexpected number of blocks when 1 was expected"
    (only_rv,) = found_rvs
    found_block = sch.get(only_rv)
    assert found_block.name_hint == "relu"
    # Looking the block up by name must resolve to the very same Block object.
    expected_block = sch.get(sch.get_block("relu"))
    assert expected_block.same_as(found_block)


def test_get_output_blocks_multiple_outputs():
    """Expect `init` and `update`, in that order, as the output blocks of `matmul`."""
    sch = tir.Schedule(mod=matmul, debug_mask="all")
    found_rvs = sch.get_output_blocks("root")
    assert len(found_rvs) == 2, "Unexpected number of blocks when 2 were expected"

    expected_names = ("init", "update")
    # First pass: the returned blocks carry the expected names, in order.
    found_blocks = []
    for block_rv, expected_name in zip(found_rvs, expected_names):
        found_block = sch.get(block_rv)
        assert found_block.name_hint == expected_name
        found_blocks.append(found_block)
    # Second pass: each one is the same Block object a lookup by name returns.
    for found_block, expected_name in zip(found_blocks, expected_names):
        named_rv = sch.get_block(expected_name)
        assert sch.get(named_rv).same_as(found_block)


def test_get_output_blocks_nested():
    """Check output-block collection on a block nested inside another block.

    From the root scope both `blockized_B` and the inner `B` are expected;
    from the `blockized_B` scope only `B` is expected.
    """

    @T.prim_func
    def blockized(
        A: T.Buffer((128, 128), "float32"),
        B: T.Buffer((128, 128), "float32"),
    ) -> None:
        with T.block("blockized_B"):
            vio = T.axis.spatial(1, 0)
            vjo = T.axis.spatial(1, 0)
            for i, j in T.grid(128, 128):
                with T.block("B"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj] * 2.0

    # Scope = root: outer block first, then the block nested inside it.
    sch = tir.Schedule(mod=blockized, debug_mask="all")
    root_rvs = sch.get_output_blocks("root")
    assert len(root_rvs) == 2, "Unexpected number of blocks when 2 were expected"
    expected_names = ("blockized_B", "B")
    root_blocks = []
    for block_rv, expected_name in zip(root_rvs, expected_names):
        found_block = sch.get(block_rv)
        assert found_block.name_hint == expected_name
        root_blocks.append(found_block)
    for found_block, expected_name in zip(root_blocks, expected_names):
        named_rv = sch.get_block(expected_name)
        assert sch.get(named_rv).same_as(found_block)

    # Scope = the outer block, on a fresh schedule: only the inner block is returned.
    sch = tir.Schedule(mod=blockized, debug_mask="all")
    inner_rvs = sch.get_output_blocks("blockized_B")
    assert len(inner_rvs) == 1, "Unexpected number of blocks when 1 were expected"
    (inner_rv,) = inner_rvs
    inner_block = sch.get(inner_rv)
    assert inner_block.name_hint == "B"
    named_inner = sch.get(sch.get_block("B"))
    assert named_inner.same_as(inner_block)


if __name__ == "__main__":
tvm.testing.main()