6 changes: 4 additions & 2 deletions include/tvm/tir/schedule/schedule.h
@@ -419,10 +419,12 @@ class ScheduleNode : public runtime::Object {
* \brief Create a block to cache precomputed index for later use.
* If there is no index computation, the block is kept unchanged.
* \param block_rv The target block
* \param buffer_index The index of the target buffer in block's read region
* \param storage_scope The storage scope of cached block
* \param cse_thresh The repeat threshold that determines a common sub expr
* \return The cache stage blocks.
*/
virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, int buffer_index) = 0;
virtual Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
int cse_thresh) = 0;
/*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
* The layout of the cache will be the same as by the iterators of the block that reads/writes the
21 changes: 12 additions & 9 deletions python/tvm/tir/schedule/schedule.py
@@ -1294,7 +1294,10 @@ def cache_inplace(data_io: T.Buffer[64, "int32"]) -> None:

@type_checked
def cache_index(
self, block: Union[BlockRV, str], buffer_index: Union[int, str, Buffer]
self,
block: Union[BlockRV, str],
storage_scope: str,
cse_thresh: int = 0,
) -> List[BlockRV]:
"""Create a block to cache precomputed index for later use.
If there is no index computation, the block is kept unchanged.
@@ -1304,8 +1307,12 @@ def cache_index(
block : Union[BlockRV, str]
The target block that operates on the target buffer.

buffer_index: int
The index of the target buffer in block's read region
storage_scope: str
The storage scope of cached block.

cse_thresh: int
The repeat threshold that determines a common sub expr;
the default of 0 means every index computation is cached.


Returns
@@ -1334,7 +1341,7 @@ def resize(a: T.handle, b: T.handle) -> None:

sch = tir.Schedule(resize)
block_a = sch.get_block("A")
sch.cache_index(block_a, 0)
sch.cache_index(block_a, "global", 1)
print(sch.mod["main"].script())

After applying cache_index, the IR becomes:
@@ -1370,12 +1377,8 @@ def resize_cache_index(
"""
block = self._normalize_block_arg(block)

if not isinstance(buffer_index, int):
_, buffer_index, _ = self._normalize_buffer_arg(
block, buffer_index, required_buffer_type="read"
)
return _ffi_api.ScheduleCacheIndex( # type: ignore # pylint: disable=no-member
self, block, buffer_index
self, block, storage_scope, cse_thresh
)

@type_checked
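To make the new Python signature concrete, here is a minimal usage sketch. The gather prim_func, its buffer sizes, and the repeated index expression are illustrative inventions rather than part of this PR, and the snippet assumes a TVM build that already contains this change:

import tvm
from tvm.script import tir as T

@T.prim_func
def gather(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (32,), "float32")
    B = T.match_buffer(b, (32,), "float32")
    for i in T.serial(32):
        with T.block("B"):
            vi = T.axis.spatial(32, i)
            # the sub-expression vi * 7 % 32 repeats, so it is a candidate for index caching
            B[vi] = A[vi * 7 % 32] + A[vi * 7 % 32]

sch = tvm.tir.Schedule(gather)
block_b = sch.get_block("B")
# cache index computations repeated at least twice into buffers in "global" scope
index_blocks = sch.cache_index(block_b, "global", 2)
print(sch.mod["main"].script())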
5 changes: 3 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
@@ -581,10 +581,11 @@ Array<BlockRV> ConcreteScheduleNode::CacheInplace(const BlockRV& block_rv, int w
return return_blocks;
}

Array<BlockRV> ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv, int buffer_index) {
Array<BlockRV> ConcreteScheduleNode::CacheIndex(const BlockRV& block_rv,
const String& storage_scope, int cse_thresh) {
Array<StmtSRef> result;
TVM_TIR_SCHEDULE_BEGIN();
result = tir::CacheIndex(state_, this->GetSRef(block_rv), buffer_index);
result = tir::CacheIndex(state_, this->GetSRef(block_rv), storage_scope, cse_thresh);
TVM_TIR_SCHEDULE_END("cache-index", this->error_render_level_);
this->state_->DebugVerify();
Array<BlockRV> return_blocks;
3 changes: 2 additions & 1 deletion src/tir/schedule/concrete_schedule.h
@@ -118,7 +118,8 @@ class ConcreteScheduleNode : public ScheduleNode {
const Array<BlockRV> consumer_blocks = {}) override;
Array<BlockRV> CacheInplace(const BlockRV& block_rv, int read_buffer_index,
const String& storage_scope) override;
Array<BlockRV> CacheIndex(const BlockRV& block_rv, int write_buffer_index) override;
Array<BlockRV> CacheIndex(const BlockRV& block_rv, const String& storage_scope,
int cse_thresh) override;
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
5 changes: 3 additions & 2 deletions src/tir/schedule/primitive.h
@@ -285,11 +285,12 @@ TVM_DLL Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_s
* \brief Create a block to cache precomputed index for later use.
* If there is no index computation, the block is kept unchanged.
* \param block_sref The target block
* \param buffer_index The index of the target buffer in block's read region,
* \param storage_scope The storage scope of cached block
* \param cse_thresh The repeat threshold that determines a common sub expr
* \return The cache stage blocks.
*/
TVM_DLL Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
int buffer_index);
const String& storage_scope, int cse_thresh);
/*!
*!
* \brief Create a block that read/write a buffer region into a read/write cache with reindexing.
147 changes: 98 additions & 49 deletions src/tir/schedule/primitive/cache_index.cc
@@ -18,6 +18,8 @@
*/
#include <tvm/arith/int_set.h>

#include "../../transforms/common_subexpr_elim_tools.h"
#include "../../transforms/replace_selected_expr.h"
#include "../utils.h"

namespace tvm {
@@ -27,8 +29,10 @@ namespace tir {

/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */
struct IndexInfo {
/*! \brief The target buffer to cache the index. */
Buffer target_buffer;
/*! \brief The target block to perform cache_index */
StmtSRef target_block;
/*! \brief Record the common subexpr extract threshold */
size_t cse_thresh;
/*! \brief The cache buffer to store the precomputed index */
std::vector<Buffer> cache_buffer;
/*! \brief The expr to be precomputed */
Expand Down Expand Up @@ -74,9 +78,8 @@ class IndexInfoCollector : public StmtExprVisitor {
*/
static void Collect(const ScheduleState& self, const StmtSRef& block_sref,
const StmtSRef& scope_sref, IndexInfo* info) {
IndexInfoCollector collector(self, block_sref, scope_sref, info->target_buffer);
IndexInfoCollector collector(self, block_sref, scope_sref, info->cse_thresh);
collector(GetRef<Stmt>(scope_sref->stmt));
// info->loc_sref = collector.loc_sref_;
info->loc_pos = collector.loc_pos_;
info->index_exprs = collector.exprs_;
info->range_map = collector.range_map_;
@@ -88,11 +91,11 @@ class IndexInfoCollector : public StmtExprVisitor {
* \param self The state of the schedule
* \param block_sref The sref of the target block of the buffer being applied cache_index
* \param scope_sref The sref of the scope block of the target block
* \param buffer The target buffer to cache the indexs
* \param cse_thresh The repeat threshold that determines a common subexpr
*/
IndexInfoCollector(const ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& scope_sref, const Buffer& buffer)
: self_(self), block_sref_(block_sref), scope_sref_(scope_sref), buffer_(buffer) {}
const StmtSRef& scope_sref, int cse_thresh)
: self_(self), block_sref_(block_sref), scope_sref_(scope_sref), cse_thresh_(cse_thresh) {}

void VisitStmt_(const SeqStmtNode* seq_stmt) final {
for (size_t i = 0; i < seq_stmt->size(); ++i) {
@@ -110,8 +113,9 @@ class IndexInfoCollector : public StmtExprVisitor {
}

void VisitStmt_(const BlockNode* block) final {
// Only visit the target's parent block
visiting_target_block = static_cast<bool>(block_sref_->stmt == block);
StmtVisitor::VisitStmt_(block);
visiting_target_block = false;
if (block == scope_sref_->stmt) {
// The block visited is the current parent scope
// Handling cases when no SeqStmt in the scope
@@ -136,15 +140,56 @@ class IndexInfoCollector : public StmtExprVisitor {
}
}

void VisitExpr_(const BufferLoadNode* load) final {
if (load->buffer.same_as(buffer_)) {
for (const PrimExpr& it : load->indices) {
if (!it->IsInstance<VarNode>()) {
exprs_.push_back(it);
void VisitStmt_(const BufferStoreNode* store) final {
// Only analyze cache candidates for stores in the target block
if (visiting_target_block) {
auto IsEligibleComputation = [](const PrimExpr& expr) {
return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 &&
(expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
};

// Analyze sub expr candidates
ComputationTable table_syntactic_comp_done_by_stmt =
ComputationsDoneBy::GetComputationsDoneBy(GetRef<Stmt>(store), IsEligibleComputation,
[](const PrimExpr& expr) { return true; });
std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, true);

// For candidates whose repeat count is below cse_thresh_, descend into their direct sub-expressions
for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];
if (computation_and_nb.second < cse_thresh_) {
std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
computation_and_nb.first, IsEligibleComputation,
[](const PrimExpr& expr) { return true; });
InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs,
true, computation_and_nb.second);
}
}

// Record the final sub-expressions whose repeat count is at least cse_thresh_
// In order to make the result stable, sort it by post order and then by complexity
PostOrderVisit(store->value, [&semantic_comp_done_by_stmt, this](const ObjectRef& node) {
if (node->IsInstance<PrimExprNode>()) {
PrimExpr this_expr = Downcast<PrimExpr>(node);
for (auto& it : semantic_comp_done_by_stmt) {
if (it.second >= this->cse_thresh_ && EquivalentTerms(this_expr, it.first, true)) {
auto find_result =
std::find_if(this->exprs_.begin(), this->exprs_.end(),
[&](PrimExpr expr) { return expr.get() == it.first.get(); });
if (find_result == this->exprs_.end()) {
this->exprs_.push_back(it.first);
}
}
}
}
});
auto cmp = [&](const PrimExpr& lhs, const PrimExpr& rhs) -> bool {
return CalculateExprComplexity(lhs) > CalculateExprComplexity(rhs);
};
std::stable_sort(exprs_.begin(), exprs_.end(), cmp);
}
ExprVisitor::VisitExpr_(load);
StmtVisitor::VisitStmt_(store);
}
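For intuition about the cse_thresh check above, here is a small self-contained Python sketch of the selection policy. It only mirrors the threshold rule: it does not use the real ComputationsDoneBy / DirectSubexpr helpers, it skips the descent into sub-expressions of below-threshold candidates, and the expressions are represented as plain strings.

from collections import Counter

def select_cached_exprs(candidate_exprs, cse_thresh):
    """Keep the sub-expressions whose repeat count reaches cse_thresh.

    candidate_exprs: occurrences of eligible index sub-expressions found in the
    stores of the target block, represented here as strings for illustration.
    """
    counts = Counter(candidate_exprs)
    # cse_thresh == 0 keeps everything, matching "the default of 0 caches every index computation"
    selected = [expr for expr, count in counts.items() if count >= cse_thresh]
    # more complex expressions first, mirroring the stable sort by expression complexity
    # (string length stands in for CalculateExprComplexity in this sketch)
    return sorted(selected, key=len, reverse=True)

# select_cached_exprs(["v0 * 7 % 32", "v0 * 7 % 32", "v1 + 1"], 2) -> ["v0 * 7 % 32"]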

/*! \brief The schedule class */
Expand All @@ -153,12 +198,14 @@ class IndexInfoCollector : public StmtExprVisitor {
const StmtSRef& block_sref_;
/*! \brief The parent scope of the target block */
const StmtSRef& scope_sref_;
/*! \brief The target buffer to cache the index */
const Buffer& buffer_;
/*! \brief Record the common subexpr extract threshold */
size_t cse_thresh_;
/*! \brief The calculation expr to be precomputed */
std::vector<PrimExpr> exprs_;
/*! \brief The flag whether we have visited the target block */
bool visited_block_{false};
/*! \brief The flag indicating currently visiting target block */
bool visiting_target_block{false};
/*! \brief The index to insert the cache_index stage */
int loc_pos_{-1};
/*! \brief The flag indicating the right scope to update seq pos */
Expand All @@ -169,17 +216,15 @@ class IndexInfoCollector : public StmtExprVisitor {

/*!
* \brief Create a loop nest that writes precomputed index into index buffer.
* \param cache_region The cached copy region.
* \param info The cache stage information, which will be updated in the function.
* \param storage_scope The storage scope of the cache buffer
* \returns A block indicating the body of the loop nesting.
*/
Array<Block> MakeIndexCacheStage(IndexInfo* info) {
Array<Block> MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) {
Array<Block> blocks;
Array<Stmt> bodies;
bodies.reserve(info->index_exprs.size());
info->cache_buffer.reserve(info->index_exprs.size());
const String& storage_scope = info->target_buffer.scope();

// For each index calculation, create a block to pre-compute.
for (size_t expr_index = 0; expr_index < info->index_exprs.size(); expr_index++) {
@@ -214,10 +259,7 @@ Array<Block> MakeIndexCacheStage(IndexInfo* info) {
});
}

// Inference the shape and create cache buffer
arith::IntSet val_range =
arith::EvalSet(Substitute(index_expr, info->var_binding), arith::AsIntSet(info->range_map));
DataType data_type = DetermineDatatype(val_range);
DataType data_type = index_expr.dtype();
Var index_buffer_var("index_var_" + std::to_string(expr_index),
PointerType(PrimType(data_type), storage_scope));
Array<PrimExpr> buffer_shape;
@@ -346,7 +388,9 @@ class CacheIndexRewriter : public StmtExprMutator {
Stmt VisitStmt_(const BlockNode* block) final {
Block old_stmt = GetRef<Block>(block);
// Mutate the body
visiting_target_block = static_cast<bool>(block == info_->target_block->stmt);
Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
visiting_target_block = false;

// Check if it is the block corresponding to the parent scope
if (block == scope_sref_->stmt) {
@@ -362,24 +406,23 @@
return std::move(stmt);
}

PrimExpr VisitExpr_(const BufferLoadNode* load) final {
if (load->buffer.same_as(info_->target_buffer)) {
// Rewrite the target buffer load
Array<PrimExpr> new_indices;
for (const PrimExpr& index : load->indices) {
auto it = std::find_if(info_->index_exprs.begin(), info_->index_exprs.end(),
[&](PrimExpr& e) { return e.get() == index.get(); });
if (it == info_->index_exprs.end()) {
new_indices.push_back(index);
} else {
// Replace load index with cached index
auto offset = std::distance(info_->index_exprs.begin(), it);
new_indices.push_back(BufferLoad(info_->cache_buffer[offset], cache_indices_[offset]));
}
Stmt VisitStmt_(const BufferStoreNode* store) final {
Stmt ret_stmt = StmtMutator::VisitStmt_(store);
// Replace common sub-expressions in the target block with loads from the cache buffers
if (visiting_target_block) {
for (size_t i = 0; i < info_->index_exprs.size(); i++) {
PrimExpr& computation = info_->index_exprs[i];
std::function<bool(const PrimExpr&)> predicate_selector =
[computation](const PrimExpr& current_expr) {
return (EquivalentTerms(current_expr, computation, true));
};
BufferLoad load = BufferLoad(info_->cache_buffer[i], cache_indices_[i]);
ret_stmt = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(
ret_stmt, predicate_selector, std::move(load),
[](const PrimExpr& expr) { return true; });
}
return BufferLoad(load->buffer, new_indices);
}
return ExprMutator::VisitExpr_(load);
return ret_stmt;
}

PrimExpr VisitExpr_(const LoadNode* op) final {
@@ -393,9 +436,12 @@
IndexInfo* info_;
/*! \brief The indices for the cache buffer */
std::vector<Array<PrimExpr>> cache_indices_;
/*! \brief Flag indicating we are currently inside the target block, where index replacement is performed */
bool visiting_target_block{false};
};

Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_index) {
Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
const String& storage_scope, int cse_thresh) {
/*!
* Check:
* - The index is in the array of block reading region
@@ -407,9 +453,9 @@

// Step 0. Checking index, getting the target buffer and the parent scope
IndexInfo info;
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
info.target_buffer =
GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kRead);
info.target_block = block_sref;
CHECK_GE(cse_thresh, 0) << "cse_thresh should not be a negative number";
info.cse_thresh = cse_thresh;
StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);

// Step 1. Collect the indexing info of target buffer.
@@ -418,7 +464,7 @@
// Step 2. Create cache stages and rewrite the stmt.
BlockRealize realize = GetBlockRealize(self, block_sref);
info.var_binding = GetBindings(realize);
Array<Block> cache_stages = MakeIndexCacheStage(&info);
Array<Block> cache_stages = MakeIndexCacheStage(&info, storage_scope);
Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);

bool old_stage_pipeline = self->block_info[block_sref].scope->stage_pipeline;
@@ -458,17 +504,20 @@ struct CacheIndexTraits : public UnpackedInstTraits<CacheIndexTraits> {

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index) {
return sch->CacheIndex(block, buffer_index->value);
static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block, String storage_scope,
Integer cse_thresh) {
return sch->CacheIndex(block, storage_scope, cse_thresh->value);
}

static String UnpackedAsPython(Array<String> outputs, String block, Integer buffer_index) {
static String UnpackedAsPython(Array<String> outputs, String block, String storage_scope,
Integer cse_thresh) {
PythonAPICall py("cache_index");
py.Input("block", block);
py.Input("buffer_index", buffer_index->value);
py.Input("storage_scope", storage_scope);
py.Input("cse_thresh", cse_thresh->value);
py.OutputList(outputs);
return py.Str();
}
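Given the argument packing in CacheIndexTraits above, a traced schedule should render this instruction with the new parameter names. A hedged sketch, reusing the gather prim_func from the usage example after the schedule.py section; the exact block names and output arity in the printed trace come from TVM, not from this text:

sch = tvm.tir.Schedule(gather)  # the Python-side Schedule records a trace by default
sch.cache_index(sch.get_block("B"), "global", 2)
print(sch.trace)
# Expected to contain lines roughly of the form:
#   b0 = sch.get_block(name="B", func_name="main")
#   b1, = sch.cache_index(block=b0, storage_scope="global", cse_thresh=2)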