17 changes: 17 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -681,6 +681,23 @@ class ScheduleNode : public runtime::Object {
*/
virtual void PadEinsum(const BlockRV& block_rv, const Array<Integer>& padding) = 0;

/******** Schedule: Buffer transformation ********/
/*!
* \brief Compute the target buffer via rolling buffering.
 * \details This primitive selects the outermost rollable axis with a positive bound overlap that
 * appears in the block's ancestor loops as the `rolling axis`, folds and circularizes the buffer
 * along the rolling dimension, and appends a block predicate to avoid recomputing overlapping
 * elements. It requires:
 * 1) The buffer is an intermediate buffer defined via `alloc_buffer`.
 * 2) The LCA of the producer and consumer of the buffer is a for loop; typically, the producer
 * and consumer of the buffer are cascaded through compute_at.
* 3) The access region of the buffer has at least one dimension that contains
* a positive bound overlap.
* \param block_rv The producer block of the buffer.
 * \param write_buffer_index The index of the buffer in the block's write region.
*/
virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
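The second requirement in the doc comment above, that the producer and consumer of the buffer are cascaded through compute_at, is easiest to see in a driving script. The sketch below is illustrative only: `cascade_pool_ops`, the block names `B`/`C`, and the loop chosen for compute_at are assumed stand-ins rather than part of this change; only `rolling_buffer` itself is introduced by this PR.

.. code-block:: python

    from tvm import tir

    # `cascade_pool_ops` stands in for a TVMScript PrimFunc with two chained
    # pooling-style stages that communicate through an intermediate buffer B.
    sch = tir.Schedule(cascade_pool_ops)
    producer = sch.get_block("B")  # block that writes the intermediate buffer
    consumer = sch.get_block("C")  # block that reads it

    # Cascade the producer under the consumer's outer tile loops so that the
    # LCA of producer and consumer becomes a for loop (requirement 2 above).
    h_outer, w_outer, *_ = sch.get_loops(consumer)
    sch.compute_at(producer, w_outer)

    # Roll the producer's first write buffer: the outermost rollable axis with
    # a positive bound overlap is folded and circularized, and a predicate is
    # appended so the overlapping region is not recomputed.
    sch.rolling_buffer(producer, write_buffer_index=0)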
108 changes: 108 additions & 0 deletions python/tvm/tir/schedule/schedule.py
@@ -3042,6 +3042,114 @@ def after_pad_einsum(
self, block, padding
)

######## Schedule: Buffer transformation ########

@type_checked
def rolling_buffer(
self,
block: Union[BlockRV, str],
write_buffer_index: int,
) -> None:
"""Compute the target buffer via rolling buffering, select the outermost rollable
axis with a positive bound overlap that appears in the block's ancestor loops
as `rolling axis`, fold and circularize the buffer along the rolling dimension,
append block predicate to avoid recomputing overlapping elements. It requires:

1) The block is not an output block and has only RAW dependencies.

2) The buffer is an intermediate buffer defined via `alloc_buffer`.

3) The LCA of the producer and consumer of the buffer is a for loop; typically,
the producer and consumer of the buffer are cascaded through compute_at.

4) The access region of the buffer has at least one dimension that contains
a positive bound overlap.

Parameters
----------
block : Union[BlockRV, str]
The producer block of the buffer.
write_buffer_index : int
The index of the buffer in the block's write region.

Examples
--------

Before rolling_buffer, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_rolling_buffer(
A: T.Buffer[(12, 12), "int8"], C: T.Buffer[(8, 8), "int8"]
) -> None:
# body
# with T.block("root")
B = T.alloc_buffer([10, 10], dtype="int8")
for i0, i1 in T.grid(2, 2):
for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
with T.block("B"):
ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
B[ax0_1, ax1_1] = T.max(
B[ax0_1, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1]
)
for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
with T.block("C"):
ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
C[ax0_1, ax1_1] = T.max(
C[ax0_1, ax1_1], B[ax0_1 + rv0, ax1_1 + rv1]
)

Create the schedule and do rolling_buffer:

.. code-block:: python

sch = tir.Schedule(before_rolling_buffer)
sch.rolling_buffer(sch.get_block("B"), write_buffer_index=0)
print(sch.mod["main"].script())

After applying rolling_buffer, the IR becomes:

.. code-block:: python

@T.prim_func
def after_rolling_buffer(
A: T.Buffer[(12, 12), "int8"],
C: T.Buffer[(8, 8), "int8"]
) -> None:
# body
# with T.block("root")
B = T.alloc_buffer([6, 10], dtype="int8")
for i0, i1 in T.grid(2, 2):
for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
with T.block("B"):
T.where((i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1))
ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
B[ax0_1 % 6, ax1_1] = T.max(
B[ax0_1 % 6, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1]
)
for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
with T.block("C"):
ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
C[ax0_1, ax1_1] = T.max(
C[ax0_1, ax1_1], B[ax0_1 % 6 + rv0, ax1_1 + rv1]
)

Note
----
The region_cover property of the consumer block of the target buffer will become false.
"""
block = self._normalize_block_arg(block)
return _ffi_api.ScheduleRollingBuffer(self, block, write_buffer_index) # type: ignore # pylint: disable=no-member

########## Schedule: Misc ##########

@type_checked
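As a quick sanity check on the numbers in the docstring example above (an editorial aside, not part of this change): each (i0, i1) tile of block "B" writes a 6x10 window of B and advances by 4 rows, so consecutive tiles overlap by 2 rows. The rolling extent is therefore 6, which is why B folds from [10, 10] to [6, 10], the rolling index is taken modulo 6, and the `T.where` predicate skips the 2 already-computed rows on every tile after the first.

.. code-block:: python

    # Enumerate which rows of B each tile writes, mirroring the predicate
    # `i0 < 1 or 2 <= ax0` from after_rolling_buffer (columns are analogous).
    window, stride = 6, 4          # rows written per tile / rows advanced per tile
    overlap = window - stride      # 2 rows shared with the previous tile
    rolling_extent = window        # B folds from [10, 10] to [6, 10]

    for i0 in range(2):            # outer tile loop over rows
        for ax0 in range(window):
            ax0_1 = i0 * stride + ax0        # logical row in the un-rolled buffer
            slot = ax0_1 % rolling_extent    # physical row in the rolled buffer
            executes = i0 < 1 or overlap <= ax0   # the appended block predicate
            status = "computed into" if executes else "reused from"
            print(f"tile {i0}: row {ax0_1} {status} slot {slot}")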
12 changes: 12 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
@@ -801,6 +801,8 @@ void ConcreteScheduleNode::SetAxisSeparator(const BlockRV& block_rv, int buffer_
this->state_->DebugVerify();
}

/******** Schedule: Padding ********/

BlockRV ConcreteScheduleNode::DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
@@ -816,6 +818,16 @@ void ConcreteScheduleNode::PadEinsum(const BlockRV& block_rv, const Array<Intege
TVM_TIR_SCHEDULE_END("pad-einsum", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: Buffer transformation ********/

void ConcreteScheduleNode::RollingBuffer(const BlockRV& block_rv, int write_buffer_index) {
TVM_TIR_SCHEDULE_BEGIN();
tir::RollingBuffer(state_, this->GetSRef(block_rv), write_buffer_index);
TVM_TIR_SCHEDULE_END("rolling-buffer", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: Misc ********/

} // namespace tir
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
@@ -153,6 +153,8 @@ class ConcreteScheduleNode : public ScheduleNode {
const Array<IntImm>& axis_separators) override;
/******** Schedule: Padding decomposition ********/
BlockRV DecomposePadding(const BlockRV& block_rv, const LoopRV& loop_rv) override;
/******** Schedule: Buffer transformation ********/
void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) override;
/******** Schedule: Misc ********/
void EnterPostproc() override {}

16 changes: 16 additions & 0 deletions src/tir/schedule/primitive.h
@@ -524,6 +524,22 @@ TVM_DLL StmtSRef DecomposePadding(ScheduleState self, const StmtSRef& block_sref
TVM_DLL void PadEinsum(ScheduleState self, const StmtSRef& block_sref,
const Array<Integer>& padding);

/******** Schedule: Buffer transformation ********/
/*!
* \brief Compute the target buffer via rolling buffering.
 * \details This primitive selects the outermost rollable axis with a positive bound overlap that
 * appears in the block's ancestor loops as the `rolling axis`, folds and circularizes the buffer
 * along the rolling dimension, and appends a block predicate to avoid recomputing overlapping
 * elements. It requires:
 * 1) The buffer is an intermediate buffer defined via `alloc_buffer`.
 * 2) The LCA of the producer and consumer of the buffer is a for loop; typically, the producer
 * and consumer of the buffer are cascaded through compute_at.
* 3) The access region of the buffer has at least one dimension that contains
* a positive bound overlap.
 * \param self The state of the schedule.
 * \param block_sref The sref of the producer block of the buffer.
 * \param write_buffer_index The index of the buffer in the block's write region.
*/
TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index);
/******** Schedule: Misc ********/

} // namespace tir