Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,16 @@ class ScheduleNode : public runtime::Object {
*/
virtual BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) = 0;
/*!
* \brief Create 2 blocks that read&write a buffer region into a read/write cache.
* It requires the the target block both read & write the target buffer.
* \param block_rv The target block operates on the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope
* \return The cache stage blocks, cache read block together with cache write block.
*/
virtual Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) = 0;
/*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
* The layout of the cache will be the same as by the iterators of the block that reads/writes the
Expand Down
89 changes: 89 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,6 +1189,95 @@ def after_cache_write(a: T.handle, b: T.handle) -> None:
self, block, write_buffer_index, storage_scope
)

@type_checked
def cache_inplace(
self,
block: Union[BlockRV, str],
read_buffer_index: Union[int, str, Buffer],
storage_scope: str,
) -> List[BlockRV]:
"""Create blocks that reads & write a buffer region into a cache block.
It requires the the target block both read & write the target buffer.
Mainly for inplace operation.

Parameters
----------
block : Union[BlockRV, str]
The target block operates on the target buffer.

read_buffer_index: int
The index of the buffer in block's read region, the unique
name of a read buffer in the block, or a Buffer object
that is within the blocks read region.

storage_scope: str
The target storage scope.


Returns
-------
cached_blocks : List[BlockRV]
The blocks of the cache stage, read cache first, write cache second

Examples
--------
Before cache_inplace, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def before_cache_inplace(data_io: T.Buffer[(64), "int32"]):
for i0 in T.serial(1):
with T.block("A"):
T.reads(data_io[:64])
T.writes(data_io[:64])
T.evaluate(T.call_extern("call_impl", data_io.data, dtype=""))

Create the schedule and cache_inplace:

.. code-block:: python

sch = tir.Schedule(before_cache_inplace)
block_a = sch.get_block("A")
sch.cache_inplace(block_a, 0, "local")
print(sch.mod["main"].script())

After applying cache_inplace, the IR becomes:

.. code-block:: python

@T.prim_func
def cache_inplace(data_io: T.Buffer[64, "int32"]) -> None:
data_io_local = T.alloc_buffer([64], dtype="int32", scope="local")
for i0 in T.serial(1):
for ax0 in T.serial(64):
with T.block("data_io_local"):
v0 = T.axis.spatial(64, ax0)
T.reads(data_io[v0])
T.writes(data_io_local[v0])
data_io_local[v0] = data_io[v0]
with T.block("A"):
T.reads(data_io_local[0 : 64])
T.writes(data_io_local[0 : 64])
T.evaluate(T.call_extern("call_impl", data_io_local.data, dtype=""))
for ax0 in T.serial(64):
with T.block("data_io_local"):
v0 = T.axis.spatial(64, ax0)
T.reads(data_io_local[v0])
T.writes(data_io[v0])
data_io[v0] = data_io_local[v0]

"""
block = self._normalize_block_arg(block)

if not isinstance(read_buffer_index, int):
_, read_buffer_index, _ = self._normalize_buffer_arg(
block, read_buffer_index, required_buffer_type="read"
)
return _ffi_api.ScheduleCacheInplace( # type: ignore # pylint: disable=no-member
self, block, read_buffer_index, storage_scope
)

@type_checked
def reindex(
self,
Expand Down
13 changes: 13 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,19 @@ BlockRV ConcreteScheduleNode::CacheWrite(const BlockRV& block_rv, int write_buff
return CreateRV<BlockRV>(result);
}

Array<BlockRV> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) {
Array<StmtSRef> results;
TVM_TIR_SCHEDULE_BEGIN();
results = tir::CacheInplace(state_, this->GetSRef(block_rv), write_buffer_index, storage_scope);
TVM_TIR_SCHEDULE_END("cache-buffer", this->error_render_level_);
this->state_->DebugVerify();
Array<BlockRV> return_blocks;
return_blocks.push_back(CreateRV<BlockRV>(results[0]));
return_blocks.push_back(CreateRV<BlockRV>(results[1]));
return return_blocks;
}

BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) {
StmtSRef result{nullptr};
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ class ConcreteScheduleNode : public ScheduleNode {
const Array<BlockRV> consumer_blocks = {}) override;
BlockRV CacheWrite(const BlockRV& block_rv, int write_buffer_index,
const String& storage_scope) override;
Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
Expand Down
14 changes: 13 additions & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,18 @@ TVM_DLL StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int r
*/
TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
const String& storage_scope);
/*!
*!
* \brief Create 2 blocks that read&write a buffer region into a read/write cache.
* It requires the the target block both read & write the target buffer.
* \param self The state of the schedule
* \param block_sref The target block operates on the target buffer.
* \param read_buffer_index The index of the buffer in block's read region.
* \param storage_scope The target storage scope
* \return The cache stage blocks, cache read block together with cache write block.
*/
TVM_DLL Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref,
int read_buffer_index, const String& storage_scope);
/*!
*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
Expand All @@ -275,7 +287,7 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int
* 1) There is only one block who reads/writes the target buffer
* 2) There is only one buffer load/store of this buffer in the block
* \param self The state of the schedule
* \param block_rv The block operates on the target buffer.
* \param block_sref The block operates on the target buffer.
* \param buffer_index The index of the buffer in block's read or write region.
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \return The reindex stage block.
Expand Down
Loading