Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,7 @@ BufferRegion BufferRegion::FromPoint(Buffer buffer, Array<PrimExpr> indices) {
region.push_back(
Range::FromMinExtent(ramp_index->base, ramp_index->stride * ramp_index->lanes));
} else {
region.push_back(Range::FromMinExtent(index, 1));
region.push_back(Range::FromMinExtent(index, make_const(index.dtype(), 1)));
}
}
return BufferRegion(buffer, region);
Expand Down
12 changes: 12 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,18 @@ Optional<Array<Var>> CheckTrivialBufferIndices(const T& buffer_access) {
return indices;
}

/*!
* \brief Simplify non-trivial expressions
* \param expr The expression to be simplified
* \param analyzer The analyzer
* \return The simplified expression
*
* During scheduling, we often need preserve block iters in trivial expressions that can be
* simplified to constant values for further scheduling and analysis because simplifing away the
* block iters may result in loss of information for further analysis.
*/
PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer);

/*! \brief Necessary information used for tensorization */
class TensorizeInfoNode : public Object {
public:
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1646,6 +1646,15 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, //
}
}

PrimExpr SimplifyNonTrivialExpr(const PrimExpr& expr, arith::Analyzer* analyzer) {
auto simplified = analyzer->Simplify(expr);
if (simplified->IsInstance<IntImmNode>()) {
return expr;
} else {
return simplified;
}
}

TVM_REGISTER_NODE_TYPE(TensorizeInfoNode);

/*! \brief Auxiliary data structure of information extracted from tensor intrin description */
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/ir_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,8 @@ bool AutoTensorizeComparator::CompareBufferAccess(const T* lhs, const T* rhs) {
return false;
}
std::vector<PrimExpr> lhs_indices;
for (const auto& index : lhs->indices) {
lhs_indices.push_back(analyzer_.Simplify(index));
for (const PrimExpr& index : lhs->indices) {
lhs_indices.push_back(SimplifyNonTrivialExpr(index, &analyzer_));
}

auto is_scalar_access = [](const Array<PrimExpr>& indices, PrimExpr index) {
Expand Down
23 changes: 5 additions & 18 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,6 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
Array<IterVar> new_block_iters;
// the substition map from the original block iter to the iters of the reindex block
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectEqual> block_var_replace_map;
// block access region of reindexed buffer and target buffer
Region reindex_region, target_region;
// indices to access the reindex buffer and the target buffer
Array<PrimExpr> reindex_indices, target_indices;

Expand All @@ -201,24 +199,19 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
Var var("v" + std::to_string(new_block_iters.size()), iter->var->dtype);
bool used = covered.count(iter->var);
if (used) {
new_block_iters.push_back(IterVar(/*dom=*/used ? iter->dom : Range::FromMinExtent(0, 1),
new_block_iters.push_back(IterVar(/*dom=*/iter->dom,
/*var=*/var,
/*IterVarType=*/kDataPar));
} else {
skipped_block_iters.insert(i);
}
if (used) {
reindex_indices.push_back(var);
reindex_region.push_back(Range::FromMinExtent(var, IntImm(var->dtype, 1)));
}
block_var_replace_map[iter->var] = var;
}

// Step 2: Replace the original block iters with the new block iters
BufferRegion buffer_region = buffer_index_type == BufferIndexType::kWrite
? block->writes[buffer_index]
: block->reads[buffer_index];
target_region = Substitute(buffer_region->region, block_var_replace_map);
for (const PrimExpr& index : original_indices) {
target_indices.push_back(Substitute(index, block_var_replace_map));
}
Expand All @@ -232,25 +225,19 @@ Block MakeReIndexStage(const Block& block, CacheStageInfo* info,
Array<PrimExpr> dst_indices{nullptr};

if (buffer_index_type == BufferIndexType::kWrite) {
src_region = reindex_region;
dst_region = target_region;
src_indices = reindex_indices;
dst_indices = target_indices;
} else {
src_region = target_region;
dst_region = reindex_region;
src_indices = target_indices;
dst_indices = reindex_indices;
}

// Create the body block
Block new_block(
/*iter_vars=*/new_block_iters,
/*reads=*/
{BufferRegion(info->read_buffer, src_region)},
/*writes=*/
{BufferRegion(info->write_buffer, dst_region)},
/*name_hint=*/buffer_region->buffer->name + "_reindex",
/*reads=*/{BufferRegion::FromPoint(info->read_buffer, src_indices)},
/*writes=*/{BufferRegion::FromPoint(info->write_buffer, dst_indices)},
/*name_hint=*/info->write_buffer->name + "_reindex",
/*body=*/
BufferStore(info->write_buffer, BufferLoad(info->read_buffer, src_indices), dst_indices));

Expand Down Expand Up @@ -1169,7 +1156,7 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde
analyzer.Bind(iter->var, iter->dom);
}
original_indices.MutateByApply(
[&analyzer](const PrimExpr& expr) { return analyzer.Simplify(expr); });
[&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); });

// Collect block iters appearing in the original_indices
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> covered;
Expand Down
48 changes: 26 additions & 22 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,9 @@ class TransformLayoutRewriter : private arith::IRMutatorWithAnalyzer {

void RewriteBufferAccess(Buffer* buffer, Array<PrimExpr>* indices) {
*buffer = new_buffer_;
*indices = index_map_->MapIndices(*indices, analyzer_);
*indices = index_map_->MapIndices(*indices);
(*indices).MutateByApply(
[&](const PrimExpr& e) { return SimplifyNonTrivialExpr(e, analyzer_); });
}

using Parent = arith::IRMutatorWithAnalyzer;
Expand Down Expand Up @@ -1113,7 +1115,7 @@ class IndexMapNotApplicableToBlockIterError : public ScheduleError {

IRModule mod() const final { return mod_; }

Array<ObjectRef> LocationsOfInterest() const final { return {}; }
Array<ObjectRef> LocationsOfInterest() const final { return {block_}; }

private:
IRModule mod_;
Expand Down Expand Up @@ -1194,22 +1196,14 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
Array<PrimExpr> transformed_block_iters = index_map->MapIndices(block_vars);
Array<PrimExpr> new_block_iter_range = index_map->MapShape(block_iter_range_array);

auto iter_map = arith::DetectIterMap(
/*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom, /*predicate=*/Bool(true),
/*check_level=*/arith::IterMapLevel::Bijective, &analyzer,
/*simplify_trivial_iterators=*/true);
if (iter_map->indices.empty()) {
throw NotBijectiveAffineIndexMapError(self->mod, index_map);
}

// Step 5: Create the new block after transformation.

// Step 5.1: Create new block iters. After applying the IndexMap f to block iters ax_0, ..., ax_n,
// create block iter each expression in f(ax_0, ..., ax_n).
Array<IterVar> new_block_iters; // new block iters
Array<PrimExpr> new_block_vars; // iter_var->var of new block iters
for (size_t i = 0; i < index_map->final_indices.size(); ++i) {
Var new_block_var{"v" + std::to_string(i), DataType::Int(32)};
for (size_t i = 0; i < transformed_block_iters.size(); ++i) {
Var new_block_var{"v" + std::to_string(i), transformed_block_iters[i]->dtype};
new_block_vars.push_back(new_block_var);
IterVarType iter_type = DetectNewBlockIterType(transformed_block_iters[i], block_iter_type);
if (iter_type == kOpaque) {
Expand All @@ -1221,18 +1215,28 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,

// Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters
// in the body.
Map<Var, PrimExpr> inverse_subst_map;
// Construct the inverse map
{
Array<Range> initial_ranges;
for (const PrimExpr& extent : block_iter_range_array) {
initial_ranges.push_back(Range::FromMinExtent(make_const(extent.dtype(), 0), extent));
}
IndexMap inverse_index_map{nullptr};
try {
inverse_index_map = index_map.Inverse(initial_ranges);
} catch (...) {
throw NotBijectiveAffineIndexMapError(self->mod, index_map);
}

auto inverse_map = arith::InverseAffineIterMap(iter_map->indices, new_block_vars);
// Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant
// zero.
for (const auto& iter_var : block_ptr->iter_vars) {
if (inverse_map.find(iter_var->var) == inverse_map.end()) {
ICHECK(is_one(iter_var->dom->extent));
inverse_map.Set(iter_var->var, 0);
Array<PrimExpr> inversed_new_block_vars = inverse_index_map->MapIndices(
new_block_vars); // old block vars written in terms of new block vars

for (int i = 0, n = block_vars.size(); i < n; ++i) {
inverse_subst_map.Set(Downcast<Var>(block_vars[i]), inversed_new_block_vars[i]);
}
}

Block new_block = Downcast<Block>(Substitute(GetRef<Block>(block_ptr), inverse_map));
Block new_block = Downcast<Block>(Substitute(GetRef<Block>(block_ptr), inverse_subst_map));
new_block.CopyOnWrite()->iter_vars = new_block_iters;
new_block = Downcast<Block>(BlockBufferAccessSimplifier::Simplify(new_block, &analyzer));

Expand All @@ -1241,7 +1245,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
// Make new loop vars
Array<PrimExpr> new_loop_vars;
for (int i = 0; i < static_cast<int>(new_block_iters.size()); ++i) {
new_loop_vars.push_back(Var("ax" + std::to_string(i), DataType::Int(32)));
new_loop_vars.push_back(Var("ax" + std::to_string(i), new_block_iters[i]->var.dtype()));
}

// Make new block realize
Expand Down
25 changes: 19 additions & 6 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,14 +359,25 @@ void BlockBufferAccessSimplifier::SimplifyAccessRegion(Array<BufferRegion>* old_
auto fmutate = [this](const BufferRegion& buffer_region) {
std::vector<Range> new_buffer_region;
for (const auto& range : buffer_region->region) {
new_buffer_region.push_back(Range::FromMinExtent(analyzer_->Simplify(range->min),
analyzer_->Simplify(range->extent)));
if (is_one(range->extent) && range->min->IsInstance<VarNode>()) {
new_buffer_region.push_back(Range::FromMinExtent(
SimplifyNonTrivialExpr(range->min, analyzer_), make_const(range->min.dtype(), 1)));
} else {
new_buffer_region.push_back(
Range::FromMinExtent(SimplifyNonTrivialExpr(range->min, analyzer_),
SimplifyNonTrivialExpr(range->extent, analyzer_)));
}
}
return BufferRegion(buffer_region->buffer, new_buffer_region);
};
(*old_access_regions).MutateByApply(fmutate);
}

void BlockBufferAccessSimplifier::SimplifyBufferIndices(Array<PrimExpr>* indices) {
(*indices).MutateByApply(
[this](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, analyzer_); });
}

Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) {
Block block = Downcast<Block>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
auto* n = block.CopyOnWrite();
Expand All @@ -376,13 +387,15 @@ Stmt BlockBufferAccessSimplifier::VisitStmt_(const BlockNode* op) {
}

Stmt BlockBufferAccessSimplifier::VisitStmt_(const BufferStoreNode* op) {
auto node = Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
return VisitBufferAccess(std::move(node));
BufferStore node = Downcast<BufferStore>(arith::IRMutatorWithAnalyzer::VisitStmt_(op));
SimplifyBufferIndices(&node.CopyOnWrite()->indices);
return std::move(node);
}

PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) {
auto node = Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
return VisitBufferAccess(std::move(node));
BufferLoad node = Downcast<BufferLoad>(arith::IRMutatorWithAnalyzer::VisitExpr_(op));
SimplifyBufferIndices(&node.CopyOnWrite()->indices);
return std::move(node);
}

} // namespace tir
Expand Down
9 changes: 2 additions & 7 deletions src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,11 @@ class BlockBufferAccessSimplifier : public arith::IRMutatorWithAnalyzer {
using IRMutatorWithAnalyzer::VisitStmt_;

void SimplifyAccessRegion(Array<BufferRegion>* old_access_regions);
void SimplifyBufferIndices(Array<PrimExpr>* indices);

Stmt VisitStmt_(const BlockNode* op) final;
Stmt VisitStmt_(const BufferStoreNode* op) final;
PrimExpr VisitExpr_(const BufferLoadNode* op) final;

template <typename Node>
Node VisitBufferAccess(Node node) {
node.CopyOnWrite()->indices.MutateByApply(
[this](const PrimExpr& expr) { return analyzer_->Simplify(expr); });
return node;
}
};

} // namespace tir
Expand Down
Loading