Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions include/tvm/arith/analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,9 @@ class ConstIntBoundAnalyzer {
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's range.
*/
TVM_DLL void Bind(const Var& var, const Range& range);
TVM_DLL void Bind(const Var& var, const Range& range, bool override = false);

private:
friend class Analyzer;
Expand Down Expand Up @@ -411,8 +412,9 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param expr The expression we bind to.
* \param override Whether we allow overriding an existing var's expression.
*/
void Bind(const Var& var, const PrimExpr& expr);
void Bind(const Var& var, const PrimExpr& expr, bool override = false);
/*!
* \brief Notify all the sub-analyzers that var
* is created and binded to a range.
Expand All @@ -421,14 +423,16 @@ class TVM_DLL Analyzer {
*
* \param var The variable.
* \param range The range we bind to.
* \param override Whether we allow overriding an existing var's expression.
*/
void Bind(const Var& var, const Range& range);
void Bind(const Var& var, const Range& range, bool override = false);
/*!
* \brief Bind all the vars in the Map
*
* \param variables The {variable -> range} map.
* \param override Whether we allow overriding an existing var's expression.
*/
void Bind(const Map<Var, Range>& variables);
void Bind(const Map<Var, Range>& variables, bool override = false);
/*!
* \brief Whether can we prove expr >= val.

Expand All @@ -442,6 +446,19 @@ class TVM_DLL Analyzer {
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound);
/*!
* \brief Whether can we prove expr < val.

* Non-negative proof is very useful in integer analysis
* to lower divisions and mods given difference in trunc and ceil mode.
*
* \param expr The expression.
* \param upper_bound The upper bound.
* \return Whether we can prove it.
*
* \note Analyzer will call into sub-analyzers to get the result.
*/
bool CanProveLess(const PrimExpr& expr, int64_t upper_bound);
/*!
* \brief Whether can we prove condition.
*
Expand Down
12 changes: 8 additions & 4 deletions include/tvm/arith/int_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,13 @@ class IntSet : public ObjectRef {
//-----------------------------------------------
// Integer set legacy API.
//------------------------------------------------
/*!
* \brief Convert std::unordered_map<const VarNode*, IntSet> to Map<Var, IntSet>
*
* \param dom_map The domain map to convert.
* \return The converted map.
*/
Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>& dom_map);
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
Expand All @@ -160,8 +167,7 @@ class IntSet : public ObjectRef {
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet EvalSet(PrimExpr e,
const Map<IterVar, IntSet>& dom_map);
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map);
/*!
* \brief Same as EvalSet, but takes unordered_map
*
Expand All @@ -171,7 +177,6 @@ IntSet EvalSet(PrimExpr e,
*/
IntSet EvalSet(PrimExpr e,
const std::unordered_map<const tir::VarNode*, IntSet>& dom_map);

/*!
* \brief Find an symbolic integer set that contains is union over
* all the possible conditional values in dom_map.
Expand Down Expand Up @@ -202,7 +207,6 @@ IntSet EvalSet(IntSet s,
*/
IntSet EvalSet(Range r,
const std::unordered_map<const VarNode*, IntSet>& dom_map);

/*! \brief Map from Expr to IntSet */
using ExprIntSetMap = std::unordered_map<PrimExpr, IntSet, ObjectHash, ObjectEqual>;
/*!
Expand Down
29 changes: 19 additions & 10 deletions src/arith/analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,31 @@ Analyzer::Analyzer()
int_set(this) {
}

void Analyzer::Bind(const Var& var, const PrimExpr& expr) {
void Analyzer::Bind(const Var& var, const PrimExpr& expr, bool override) {
PrimExpr new_expr = expr;
new_expr = this->canonical_simplify(new_expr);
new_expr = this->rewrite_simplify(new_expr);

this->const_int_bound.Update(var, this->const_int_bound(new_expr));
this->modular_set.Update(var, this->modular_set(new_expr));
this->rewrite_simplify.Update(var, new_expr);
this->canonical_simplify.Update(var, new_expr);
this->const_int_bound.Update(var, this->const_int_bound(new_expr), override);
this->modular_set.Update(var, this->modular_set(new_expr), override);
this->rewrite_simplify.Update(var, new_expr, override);
this->canonical_simplify.Update(var, new_expr, override);
}

void Analyzer::Bind(const Var& var, const Range& range) {
void Analyzer::Bind(const Var& var, const Range& range, bool override) {
CHECK(range.defined());
if (tir::is_one(range->extent)) {
this->Bind(var, range->min);
this->Bind(var, range->min, override);
} else {
this->const_int_bound.Bind(var, range);
this->const_int_bound.Bind(var, range, override);
}
// skip modular_set
// skip rewrite simplify
}

void Analyzer::Bind(const Map<Var, Range>& variables) {
void Analyzer::Bind(const Map<Var, Range>& variables, bool override) {
for (const auto& iter : variables) {
this->Bind(iter.first, iter.second);
this->Bind(iter.first, iter.second, override);
}
}

Expand Down Expand Up @@ -92,6 +92,15 @@ bool Analyzer::CanProveGreaterEqual(const PrimExpr& expr, int64_t lower_bound) {
return false;
}

bool Analyzer::CanProveLess(const PrimExpr& expr, int64_t upper_bound) {
if (const auto* ptr = expr.as<tir::IntImmNode>()) {
return ptr->value < upper_bound;
}
auto bd = this->const_int_bound(this->rewrite_simplify(expr));
if (bd->max_value < upper_bound) return true;
return false;
}

bool Analyzer::CanProve(const PrimExpr& expr) {
if (const auto* ptr = expr.as<IntImmNode>()) {
return ptr->value != 0;
Expand Down
18 changes: 10 additions & 8 deletions src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ class ConstIntBoundAnalyzer::Impl :
}
};

void Bind(const Var& var, const Range& range) {
void Bind(const Var& var, const Range& range, bool override) {
Entry a = VisitExpr(range->min);
Entry b = VisitExpr(range->extent);
Entry ret;
ret.min_value = a.min_value;
ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1));
Update(var, ret, false);
Update(var, ret, override);
}

void Update(const Var& var,
Expand Down Expand Up @@ -150,10 +150,12 @@ class ConstIntBoundAnalyzer::Impl :
const PrimExprNode* op = expr.as<PrimExprNode>();
auto val = bound_->find(op);
if (val != bound_->end()) {
CHECK(val->second->min_value == res.min_value &&
val->second->max_value == res.max_value)
<< "Detected bound for " << expr
<< "conflicts with memorization";
auto everything = Everything(op->dtype);
CHECK(
(val->second->min_value == res.min_value && val->second->max_value == res.max_value) ||
(val->second->min_value == everything.min_value &&
val->second->max_value == everything.max_value))
<< "Detected bound for " << expr << "conflicts with memorization";
}
(*bound_)[op] = ConstIntBound(res.min_value, res.max_value);
}
Expand Down Expand Up @@ -574,8 +576,8 @@ void ConstIntBoundAnalyzer::Update(const Var& var,
impl_->Update(var, info, override);
}

void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) {
impl_->Bind(var, range);
void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range, bool override) {
impl_->Bind(var, range, override);
}

std::function<void()> ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& constraint) {
Expand Down
10 changes: 10 additions & 0 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,16 @@ inline IntervalSet Combine<tir::FloorModNode>(Analyzer* analyzer,
LOG(FATAL) << "Modular by zero in CombineInterval Mod";
}
if (analyzer->CanProveGreaterEqual(divisor, 0)) {
if (divisor.as<tir::IntImmNode>()) {
// a mod b = a - (a / b) * b if a_max / b == a_min / b
auto qmax = floordiv(a->max_value, divisor);
auto qmin = floordiv(a->min_value, divisor);
if (analyzer->CanProve(qmax == qmin)) {
auto tmax = a->max_value - divisor * qmin;
auto tmin = a->min_value - divisor * qmin;
return IntervalSet(tmin, tmax);
}
}
return IntervalSet(make_zero(divisor.dtype()), divisor - 1);
} else {
PrimExpr bound = abs(divisor) - 1;
Expand Down
14 changes: 8 additions & 6 deletions src/te/operation/compute_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,20 +231,22 @@ void ComputeOpNode::PropBoundToInputs(
// undefined behaviour), so we can intersect the estimated set of the argument with the
// range expected by the tensor. However, intersection may result in overly complex
// expressions, so we perform a more relaxed form of intersection.
IntSet arg_intset = EvalSet(call->args[i], dom_map);
IntSet arg_intset = analyzer->int_set(call->args[i], ConvertDomMap(dom_map));
const arith::IntervalSetNode* arg_interval = arg_intset.as<arith::IntervalSetNode>();
if (arg_interval) {
PrimExpr shape_i_min_value = make_zero(t->shape[i].dtype());
PrimExpr shape_i_max_value = t->shape[i] - 1;
PrimExpr min_value = arg_interval->min_value;
PrimExpr max_value = arg_interval->max_value;
// Prefer the shape bounds only when we can prove they are tighter.
if (arith::is_neg_inf(min_value) ||
analyzer->CanProve(shape_i_min_value >= min_value)) {
// We must update bound's ends in pairs. Here is an counter example: shape_i is
// [0, 0] and arg_interval is [threadIdx.y, threadIdx.y], where threadIdx.y's range is
// [0, 7]. If we allowed updating one end, the bound would become [threadIdx.y, 0],
// awkward for further analysis.
if ((arith::is_pos_inf(max_value) && arith::is_neg_inf(min_value)) ||
(analyzer->CanProve(shape_i_min_value >= min_value) &&
analyzer->CanProve(shape_i_max_value <= max_value))) {
min_value = shape_i_min_value;
}
if (arith::is_pos_inf(max_value) ||
analyzer->CanProve(shape_i_max_value <= max_value)) {
max_value = shape_i_max_value;
}
dom.data[i].push_back(IntSet::interval(min_value, max_value));
Expand Down
15 changes: 10 additions & 5 deletions src/te/schedule/bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ void InferRootBound(const Stage& stage,
Array<IterVar> stage_attach = ctx.attach_path.at(stage->op);
// The parent set.
for (const Operation& op : consumers) {
std::unordered_map<const VarNode*, IntSet> relax_set;
Map<Var, IntSet> relax_set;
std::unordered_map<IterVar, IntSet> up_state;
bool found_attach = false;
CHECK(ctx.op2stage_.count(op.get()));
Expand Down Expand Up @@ -177,9 +177,9 @@ void InferRootBound(const Stage& stage,
<< "InferBound requires every leaf iter var's min equals 0, "
<< "call schedule.normalize to achieve this.";
if (NeedRelax(iv, found_attach, ctx.bind_map, scope)) {
relax_set[iv->var.get()] = IntSet::range(vrange);
relax_set.Set(iv->var, IntSet::range(vrange));
if (ctx.bind_map.count(iv)) {
relax_set[ctx.bind_map.at(iv)->var.get()] = IntSet::range(vrange);
relax_set.Set(ctx.bind_map.at(iv)->var, IntSet::range(vrange));
}
}
}
Expand All @@ -191,6 +191,9 @@ void InferRootBound(const Stage& stage,
// Relax if needed.
std::unordered_map<const VarNode*, IntSet> dom_map;
arith::Analyzer analyzer;
for (auto entry : *rmap) {
analyzer.Bind(entry.first->var, entry.second);
}
for (auto iv : op->root_iter_vars()) {
Range r;
if (up_state.count(iv)) {
Expand All @@ -199,11 +202,13 @@ void InferRootBound(const Stage& stage,
r = iv->dom;
}
if (relax_set.size() != 0) {
dom_map[iv->var.get()] = EvalSet(r, relax_set);
dom_map[iv->var.get()] = IntSet::interval(
analyzer.int_set(r->min, relax_set).min(),
analyzer.int_set(r->min + r->extent - 1, relax_set).max());
} else {
dom_map[iv->var.get()] = IntSet::range(r);
}
analyzer.Bind(iv->var, r);
analyzer.Bind(iv->var, r, true);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The previous Bind call binds root IterVars. We have to override them here. Is there a easy way to tell an IterVar root IterVar? If so, we can avoid such binding/overriding.

}
op->PropBoundToInputs(op, &analyzer, dom_map, &tmap);
}
Expand Down
12 changes: 8 additions & 4 deletions src/te/schedule/message_passing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -580,19 +580,23 @@ std::vector<PrimExpr> MakeBoundCheck(
PassUpBoundCheck(stage, dom_map, &bound_state, &analyzer);

std::vector<PrimExpr> preds;
std::unordered_map<const VarNode*, IntSet> iset_dmap;
Map<Var, IntSet> iset_dmap;

// setup domain map for set analysis
for (const auto& kv : dom_map) {
iset_dmap[kv.first->var.get()] = IntSet::range(kv.second);
iset_dmap.Set(kv.first->var, IntSet::range(kv.second));
}

for (auto entry : dom_map) {
analyzer.Bind(entry.first->var, entry.second);
}

for (const IterVar& iv : stage->all_iter_vars) {
if (skip_iter.count(iv) || iv->iter_type == kOpaque) continue;
if (bound_state.at(iv)) {
Range dom = dom_map.at(iv);
PrimExpr value = value_map.at(iv) - dom->min;
PrimExpr vmax = EvalSet(value, iset_dmap).max();
PrimExpr vmax = analyzer.int_set(value, iset_dmap).max();
if (vmax.dtype() != value.dtype() || !analyzer.CanProve(vmax < dom->extent)) {
preds.emplace_back(value < dom->extent);
}
Expand All @@ -604,7 +608,7 @@ std::vector<PrimExpr> MakeBoundCheck(
CHECK(iv->dom.defined());
if (!skip_ivar_domain && !IsRangeSame(iv->dom, dom)) {
PrimExpr value = value_map.at(iv) - iv->dom->min;
IntSet s = EvalSet(value, iset_dmap);
IntSet s = analyzer.int_set(value, iset_dmap);
PrimExpr vmin = s.min();
PrimExpr vmax = s.max();
// The range of `value` resides in [vmin, vmax]
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_arith_intset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ def test_mod():

flm = tvm.te.floormod
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(-10, 10)}, (0, 9))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 5)}, (3, 5))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(13, 15)}, (3, 5))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 15)}, (0, 9))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(3, 11)}, (0, 9))
ck.verify(flm(x, 10), {x : tvm.arith.IntervalSet(1, 21)}, (0, 9))

floordiv = tvm.te.floordiv
z = te.var("z")
ck.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 3))
ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)},
(0, 7))
ck1 = IntSetChecker()
ck1.analyzer.bind(x, tvm.ir.Range.make_by_min_extent(0, 2))
ck1.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (x*4, x*4+3))


def test_max_min():
Expand Down
Loading