Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,15 @@ Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>&
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each variables.
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,14 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) = 0;
/*!
* \brief Create a block to cache precomputed index for later use.
* if there is no index computation, keep unchanged.
* \param block_rv The target block
* \param buffer_index The index of the target buffer in block's read region
* \return The cache stage blocks.
*/
virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, int buffer_index) = 0;
/*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
* The layout of the cache will be the same as by the iterators of the block that reads/writes the
Expand Down
86 changes: 86 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1278,6 +1278,92 @@ def cache_inplace(data_io: T.Buffer[64, "int32"]) -> None:
self, block, read_buffer_index, storage_scope
)

@type_checked
def cache_index(
self, block: Union[BlockRV, str], buffer_index: Union[int, str, Buffer]
) -> List[BlockRV]:
"""Create a block to cache precomputed index for later use.
if there is no index computation, keep unchanged.

Parameters
----------
block : Union[BlockRV, str]
The target block operates on the target buffer.

buffer_index: int
The index of the target buffer in block's read region


Returns
-------
cached_blocks : List[BlockRV]
The blocks of the stage writing the cache buffers

Examples
--------
Before cache_inplace, in TensorIR, the IR is:

.. code-block:: python

@T.prim_func
def resize(a: T.handle, b: T.handle) -> None:
A = T.match_buffer(a, (1, 3, 40, 40))
B = T.match_buffer(b, (1, 3, 80, 80))
for i0, i1, i2, i3 in T.grid(1, 3, 80, 80):
with T.block("A"):
n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3])
B[n, c, vi, vj] = A[n, c, vi//4 + vj//4, vj//2]

Create the schedule and cache_index:

.. code-block:: python

sch = tir.Schedule(resize)
block_a = sch.get_block("A")
sch.cache_index(block_a, 0)
print(sch.mod["main"].script())

After applying cache_index, the IR becomes:

.. code-block:: python

@T.prim_func
def resize_cache_index(
A: T.Buffer[(1, 3, 40, 40), "float32"], B: T.Buffer[(1, 3, 80, 80), "float32"]
) -> None:
index_var_0 = T.alloc_buffer([80, 80], dtype="int32", strides=[1])
index_var_1 = T.alloc_buffer([80], dtype="int32", strides=[1])
for ax0, ax1 in T.grid(80, 80):
with T.block("index_0"):
v0 = T.axis.spatial(80, ax0)
v1 = T.axis.spatial(80, ax1)
T.reads()
T.writes(index_var_0[v0, v1])
index_var_0[v0, v1] = v0 // 4 + v1 // 4
for ax0 in T.serial(80):
with T.block("index_1"):
v0 = T.axis.spatial(80, ax0)
T.reads()
T.writes(index_var_1[v0])
index_var_1[v0] = v0 // 2
for i0, i1, i2, i3 in T.grid(1, 3, 80, 80):
with T.block("A"):
n, c, vi, vj = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(A[n, c, vi // 4 + vj // 4, vj // 2])
T.writes(B[n, c, vi, vj])
B[n, c, vi, vj] = A[n, c, index_var_0[vi, vj], index_var_1[vj]]

"""
block = self._normalize_block_arg(block)

if not isinstance(buffer_index, int):
_, buffer_index, _ = self._normalize_buffer_arg(
block, buffer_index, required_buffer_type="read"
)
return _ffi_api.ScheduleCacheIndex( # type: ignore # pylint: disable=no-member
self, block, buffer_index
)

@type_checked
def reindex(
self,
Expand Down
13 changes: 13 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,19 @@ Array<BlockRV> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int w
return return_blocks;
}

Array<BlockRV> ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, int buffer_index) {
Array<StmtSRef> result;
TVM_TIR_SCHEDULE_BEGIN();
result = tir::CacheIndex(state_, this->GetSRef(block_rv), buffer_index);
TVM_TIR_SCHEDULE_END("cache-index", this->error_render_level_);
this->state_->DebugVerify();
Array<BlockRV> return_blocks;
for (const StmtSRef& blockrv : result) {
return_blocks.push_back(CreateRV<BlockRV>(blockrv));
}
return return_blocks;
}

BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) {
StmtSRef result{nullptr};
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class ConcreteScheduleNode : public ScheduleNode {
const String& storage_scope) override;
Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
Array<BlockRV> CacheIndex(const BlockRV& block_rv, int write_buffer_index) override;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ TVM_DLL StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int
*/
TVM_DLL Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref,
int read_buffer_index, const String& storage_scope);
/*!
* \brief Create a block to cache precomputed index for later use.
* if there is no index computation, keep unchanged.
* \param block_sref The target block
* \param buffer_index The index of the target buffer in block's read region,
* \return The cache stage block.
*/
TVM_DLL Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
int buffer_index);
/*!
*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
Expand Down
Loading