Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,21 @@ class ScheduleNode : public runtime::Object {
* \return The rfactor block
*/
virtual BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) = 0;
/******** Schedule: Block annotation ********/
/*!
* \brief Set alignment requirement for specific dimension such that
* stride[axis] == k * factor + offset for some k. This is useful to set memory layout for
* more friendly memory access pattern. For example, we can set alignment to be factor=2,
* offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared
* memory.
* \param block_rv The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param axis The dimension to be specified for alignment
* \param factor The factor multiple of alignment
* \param offset The required offset factor
*/
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) = 0;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
73 changes: 73 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,79 @@ def after_rfactor(a: ty.handle, b: ty.handle) -> None:
"""
return _ffi_api.ScheduleRFactor(self, loop, factor_axis) # type: ignore # pylint: disable=no-member

######## Schedule: Block annotatoin ########

def storage_align( # pylint: disable=too-many-arguments
self, block: BlockRV, buffer_index: int, axis: int, factor: int, offset: int
) -> None:
"""Set alignment requirement for specific dimension such that
stride[axis] == k * factor + offset for some k. This is useful to set memory layout for more
friendly memory access pattern. For example, we can set alignment to be factor=2, offset=1
to avoid bank conflict for thread access on higher dimension in GPU shared memory.

Parameters
----------
block : BlockRV
The producer block of the buffer.
buffer_index : int
The index of the buffer in block's write region.
axis : int
The dimension to be specified for alignment.
factor : int
The factor multiple of alignment.
offset : int
The required offset factor.

Examples
--------

Before storage_align, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_storage_align(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0

Create the schedule and do storage_align:

.. code-block:: python

sch = tir.Schedule(before_storage_align)
sch.storage_align(sch.get_block("B"), buffer_index=0, axis=0, factor=128, offset=1)
print(tvm.script.asscript(sch.mod["main"]))

After applying rfactor, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_storage_align(a: ty.handle, c: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.alloc_buffer((128, 128))
C = tir.match_buffer(c, (128, 128))
with tir.block([128, 128], "B") as [vi, vj]:
tir.block_attr({"buffer_dim_align": [[[0, 128, 1]]]})
B[vi, vj] = A[vi, vj] * 2.0
with tir.block([128, 128], "C") as [vi, vj]:
C[vi, vj] = B[vi, vj] + 1.0

After lowering passes, buffer B will have strides as [129, 1].

Note
----
Storage_align requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
_ffi_api.ScheduleStorageAlign( # type: ignore # pylint: disable=no-member
self, block, buffer_index, axis, factor, offset
)

########## Schedule: Blockize & Tensorize ##########

########## Schedule: Annotation ##########
Expand Down
13 changes: 13 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,19 @@ BlockRealize CheckGetSingleChildBlockRealizeOnSRefTree(const ScheduleState& self
*/
BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sref);

/******** Block-buffer relation ********/

/*!
* \brief Get the BlockRealize of the single child block of the block or loop specified by
* `parent_sref` on SRef tree, or throw an exception if there is 0 or multiple child blocks
* \param self The schedule state
* \param block The queried block
* \param n The index of the queried buffer
* \return The buffer of the n-th write region of the block.
* \throw ScheduleError If the buffer index is out of bound.
*/
Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n);

/******** Commutative Reducer ********/

/*!
Expand Down
39 changes: 39 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,45 @@ BlockRealize GetBlockRealize(const ScheduleState& self, const StmtSRef& block_sr
}
}

/******** Block-buffer relation ********/

Buffer GetNthWriteBuffer(const ScheduleState& self, const Block& block, int n) {
class WriteBufferIndexOutOfRangeError : public ScheduleError {
public:
explicit WriteBufferIndexOutOfRangeError(IRModule mod, Block block, int buffer_index)
: mod_(std::move(mod)), block_(std::move(block)), buffer_index_(buffer_index) {}

String FastErrorString() const final {
return "ScheduleError: The input `buffer_index` is out of range. It is required to be in "
"range [0, num_write_regions) where `num_write_regions` is the number of buffer "
"regions written by the block.";
}

String DetailRenderTemplate() const final {
std::ostringstream os;
size_t num_writes = block_->writes.size();
os << "The block {0} has " << num_writes
<< " write regions, so `buffer_index` is required to be in [0, " << num_writes
<< "). However, the input `buffer_index` is " << buffer_index_
<< ", which is out of the expected range";
return os.str();
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

private:
IRModule mod_;
Block block_;
int buffer_index_;
};

if (n < 0 || static_cast<size_t>(n) >= block->writes.size()) {
throw WriteBufferIndexOutOfRangeError(self->mod, block, n);
}
return block->writes[n]->buffer;
}

/******** Pattern Matcher ********/

/*!
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,16 @@ void ConcreteScheduleNode::ReverseComputeInline(const BlockRV& block_rv) {
}

/******** Schedule: loop binding/annotation ********/
/******** Schedule: block annotation ********/

void ConcreteScheduleNode::StorageAlign(const BlockRV& block_rv, int buffer_index, int axis,
int factor, int offset) {
TVM_TIR_SCHEDULE_BEGIN();
tir::StorageAlign(state_, this->GetSRef(block_rv), buffer_index, axis, factor, offset);
TVM_TIR_SCHEDULE_END("storage-align", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: cache read/write ********/
/******** Schedule: reduction ********/

Expand Down
3 changes: 3 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ class ConcreteScheduleNode : public ScheduleNode {
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
BlockRV RFactor(const LoopRV& loop_rv, int factor_axis) override;
/******** Schedule: Block annotation ********/
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
20 changes: 20 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,26 @@ TVM_DLL void ReverseComputeInline(ScheduleState self, const StmtSRef& block_sref
* \return The sref of the rfactor block
*/
TVM_DLL StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis);
/******** Schedule: Block annotation ********/
/*!
* \brief Set alignment requirement for specific dimension such that
* stride[axis] == k * factor + offset for some k. This is useful to set memory layout for
* more friendly memory access pattern. For example, we can set alignment to be factor=2,
* offset=1 to avoid bank conflict for thread access on higher dimension in GPU shared
* memory.
* \param block_sref The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param axis The dimension to be specified for alignment
* \param factor The factor multiple of alignment
* \param offset The required offset factor
*/
TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
int axis, int factor, int offset);

/******** Annotation types for StorageAlign ********/
using StorageAlignTuple = Array<Integer>; // (buffer_idx, axis, factor, offset)
using StorageAlignAnnotation = Array<StorageAlignTuple>; // unordered array of StorageAlignTuple

/******** Schedule: Blockize & Tensorize ********/
/******** Schedule: Annotation ********/
/******** Schedule: Misc ********/
Expand Down
Loading