Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -349,12 +349,13 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
* \param input_iters Map from variable to iterator's range.
* \param input_pred The predicate constraints on the input iterators
* \param check_level The iter mapping checking level.
* \param analyzer Analyzer used to get context information.
* \param simplify_trivial_iterators If true, iterators with unit extents are simplified
* \return The indices after rewrite
*/
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, IterMapLevel check_level,
bool simplify_trivial_iterators = true);
arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);

/*!
* \brief Apply the inverse of the affine transformation to the outputs.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
from .iter_affine_map import (
detect_iter_map,
iter_map_simplify,
normalize_iter_map_to_expr,
subspace_divide,
inverse_affine_iter_map,
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,49 @@ def detect_iter_map(
)


def iter_map_simplify(
indices,
input_iters,
predicate=True,
check_level=IterMapLevel.Surjective,
simplify_trivial_iterators=True,
):
"""Simplify the indices using iter map detection.

Parameters
----------
indices : List[PrimExpr]
The input indices

input_iters : Map[Var, Range]
The domain of each input iterators.

predicate : PrimExpr
The predicate constraints on the input iterators

check_level : Union[str, IterMapLevel]
Checking level of iteration mapping

simplify_trivial_iterators: bool
If true, iterators with extent of 1 will be replaced with a
constant value.

Returns
-------
results : IterMapResult
The iter map matching result.
The result's .indices is empty array if no match can be found.

"""
if isinstance(check_level, str):
check_level = IterMapLevel.from_str(check_level)
elif check_level is None:
check_level = IterMapLevel.NoCheck
return _ffi_api.IterMapSimplify(
indices, input_iters, predicate, check_level, simplify_trivial_iterators
)


def normalize_iter_map_to_expr(expr):
"""Given an IterMapExpr, transform it to normal PrimExpr

Expand Down
10 changes: 8 additions & 2 deletions src/arith/canonical_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include "const_fold.h"
#include "pattern_match.h"
#include "product_normal_form.h"
#include "rewrite_simplify.h"

namespace tvm {
Expand Down Expand Up @@ -808,12 +809,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const MulNode* op) {
}

// normal path.
// this only happens when b is symbolic
a = Normalize(a);
b = Normalize(b);
if (op->a.same_as(a) && op->b.same_as(b)) {

PrimExpr ret = MulAndNormalize(a, b);
const MulNode* mul = ret.as<MulNode>();

if (mul && mul->a.same_as(op->a) && mul->b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
return Mul(a, b);
return ret;
}
}

Expand Down
14 changes: 10 additions & 4 deletions src/arith/ir_mutator_with_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,17 @@ namespace arith {
using namespace tir;

Stmt IRMutatorWithAnalyzer::VisitStmt_(const ForNode* op) {
analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
// record the loop variable as iterators
Range dom = Range::FromMinExtent(op->min, op->extent);
analyzer_->Bind(op->loop_var, dom);
iter_vars_.Set(op->loop_var, dom);
return StmtExprMutator::VisitStmt_(op);
}

Stmt IRMutatorWithAnalyzer::VisitStmt_(const BlockNode* op) {
for (const auto& iter_var : op->iter_vars) {
analyzer_->Bind(iter_var->var, iter_var->dom);
iter_vars_.Set(iter_var->var, iter_var->dom);
}
return StmtExprMutator::VisitStmt_(op);
}
Expand Down Expand Up @@ -75,7 +79,7 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) {
Optional<Stmt> else_case;
{
With<ConstraintContext> ctx(analyzer_, real_condition);
then_case = this->VisitStmt(op->then_case);
WithRecordIterPredicate(real_condition, [&] { then_case = this->VisitStmt(op->then_case); });
}
if (op->else_case) {
With<ConstraintContext> ctx(analyzer_, analyzer_->rewrite_simplify(Not(real_condition)));
Expand All @@ -102,7 +106,9 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const AttrStmtNode* op) {
if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) {
IterVar iv = Downcast<IterVar>(op->node);
ICHECK_NE(iv->thread_tag.length(), 0U);
analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
Range dom = Range::FromMinExtent(make_zero(op->value.dtype()), op->value);
analyzer_->Bind(iv->var, dom);
iter_vars_.Set(iv->var, dom);
Stmt stmt = StmtExprMutator::VisitStmt_(op);
return stmt;
} else {
Expand Down Expand Up @@ -135,7 +141,7 @@ PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) {
PrimExpr true_value, false_value;
{
With<ConstraintContext> constraint(analyzer_, cond);
true_value = this->VisitExpr(op->args[1]);
WithRecordIterPredicate(cond, [&] { true_value = this->VisitExpr(op->args[1]); });
}
{
With<ConstraintContext> constraint(analyzer_, analyzer_->rewrite_simplify(Not(cond)));
Expand Down
29 changes: 28 additions & 1 deletion src/arith/ir_mutator_with_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>

#include <utility>
Expand Down Expand Up @@ -63,8 +64,34 @@ class IRMutatorWithAnalyzer : public tir::StmtExprMutator {
protected:
/*! \brief internal analyzer field. */
Analyzer* analyzer_;
// the following two fields are useful in case we want
// note however that iter map analysis are usually more
// expensive and we only encourage doing them during
// necessary cases like layout remapping
/*! \brief Recorded loop iterators */
Map<Var, Range> iter_vars_;
/*! \brief iterator predicates */
Array<PrimExpr> iter_predicates_;
/*!
* \brief Run callback while trying to record iter predicate
* \param conditon Condition to be checked.
* \param callback The callback to be called.
*/
template <typename FLambda>
void WithRecordIterPredicate(PrimExpr condition, FLambda callback) {
auto f_use_itervar = [this](const tir::VarNode* v) {
return iter_vars_.count(GetRef<tir::Var>(v));
};
// simple heuristics for detecting predicate
if (tir::UsesVar(condition, f_use_itervar)) {
iter_predicates_.push_back(condition);
callback();
iter_predicates_.pop_back();
} else {
callback();
}
}
};

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_IR_MUTATOR_WITH_ANALYZER_H_
Loading