Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/HalideIR
4 changes: 2 additions & 2 deletions include/tvm/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ class ModularSetAnalyzer {
*/
ModularSet operator()(const Expr& expr);
/*!
* \brief Update constant int bound information of var.
* \brief Update modular set information of var.
*
* \param var The variable of interest.
* \param info The bound information.
* \param info The modular set information.
* \param override Whether do we allow override of existing information.
*/
void Update(const Var& var,
Expand Down
161 changes: 92 additions & 69 deletions src/arithmetic/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)


// internal entry for const int bound
// This condition holds for all instances: coeff >= 0, base in [0, coeff]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

base in [0, coeff) for coeff != 0, for coeff=0 base can be any number that indicates constant

struct ModularSetAnalyzer::Entry {
int64_t coeff{1};
int64_t base{0};

Entry() = default;

Entry(int64_t coeff, int64_t base) {
this->coeff = coeff;

if (coeff < 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if we want to support such smart conversion here, coeff<0 could be more likely an error, and we can do CHECK_GE(coeff, 0)

coeff = -coeff;
}

if (coeff != 0) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because this constructor is "too smart", I would recommend us to create a static function instead. C++ constructor also have a restriction of not throw exception, having a factory function and put checks there avoids the problem

Entry::make(coeff, base);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C++ constructor also have a restriction of not throw exception

As far as I know, throwing exceptions from constructors is absolutely OK. (Unless you want to avoid exceptions completely and return an error code or an optional, in which case you have to use a factory function.)

base = base % coeff;
if (base < 0) base += coeff;
}
this->base = base;
}

bool is_const() const {
return coeff == 0;
}
Expand All @@ -53,10 +70,7 @@ class ModularSetAnalyzer::Impl :
if (!override) {
CHECK(!var_map_.count(var));
}
Entry e;
e.coeff = info->coeff;
e.base = info->base;
var_map_[var] = e;
var_map_[var] = Entry(info->coeff, info->base);
}

// Detect useful constraints and use them in the analysis scope.
Expand All @@ -65,10 +79,7 @@ class ModularSetAnalyzer::Impl :
PVar<Integer> coeff, base;
// pattern match interesting constraints
if (((var % coeff) == base).Match(constraint)) {
Entry entry;
entry.coeff = coeff.Eval()->value;
entry.base = base.Eval()->value;
return UpdateByIntersect(var.Eval(), entry);
return UpdateByIntersect(var.Eval(), Entry(coeff.Eval()->value, base.Eval()->value));
}
return nullptr;
}
Expand All @@ -83,18 +94,12 @@ class ModularSetAnalyzer::Impl :
}

Entry VisitExpr_(const IntImm* op) final {
Entry ret;
ret.base = op->value;
ret.coeff = 0;
return ret;
return Entry(0, op->value);
}

Entry VisitExpr_(const UIntImm* op) final {
if (op->value < std::numeric_limits<int64_t>::max()) {
Entry ret;
ret.base = static_cast<int>(op->value);
ret.coeff = 0;
return ret;
return Entry(0, static_cast<int>(op->value));
} else {
return Everything();
}
Expand All @@ -103,19 +108,15 @@ class ModularSetAnalyzer::Impl :
Entry VisitExpr_(const Add* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base + b.base, ret.coeff);
return ret;
int64_t coeff = GCD(a.coeff, b.coeff);
return Entry(coeff, a.base + b.base);
}

Entry VisitExpr_(const Sub* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Entry ret;
ret.coeff = ZeroAwareGCD(a.coeff, b.coeff);
ret.base = BaseSimplify(a.base - b.base, ret.coeff);
return ret;
int64_t coeff = GCD(a.coeff, b.coeff);
return Entry(coeff, a.base - b.base);
}

Entry VisitExpr_(const Mul* op) final {
Expand All @@ -128,10 +129,9 @@ class ModularSetAnalyzer::Impl :
int64_t pq = a.coeff * b.coeff;
int64_t pm = a.coeff * b.base;
int64_t qn = a.base * b.coeff;
Entry ret;
ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn));
ret.base = BaseSimplify(a.base * b.base, ret.coeff);
return ret;

int64_t coeff = GCD(pq, GCD(pm, qn));
return Entry(coeff, a.base * b.base);
}

Entry DivByConst(const Expr& lhs,
Expand All @@ -140,20 +140,15 @@ class ModularSetAnalyzer::Impl :
Entry a = VisitExpr(lhs);
CHECK_NE(val, 0);
if (a.coeff % val == 0) {
Entry ret;
if (a.base == 0) {
// a c x / c -> a x
ret.coeff = std::abs(a.coeff / val);
ret.base = 0;
return ret;
return Entry(std::abs(a.coeff / val), 0);
}
// positive division have a clear rounding mode.
// Only handle case where we clearly know we need to round down.
if (a.base > 0 && val > 0 &&
(round_down || parent_->CanProveGreaterEqual(lhs, 0))) {
ret.coeff = a.coeff / val;
ret.base = a.base / val;
return ret;
return Entry(a.coeff / val, a.base / val);
}
}
return Everything();
Expand Down Expand Up @@ -244,58 +239,47 @@ class ModularSetAnalyzer::Impl :
*/
static Entry Union(Entry a, Entry b) {
// {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}}
int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff);
int64_t coeff = GCD(a.coeff, b.coeff);
if (coeff == 0) {
if (a.base == b.base) return a;
return Everything();
}
int64_t base0 = a.base % coeff;
int64_t base1 = b.base % coeff;
Entry ret;
if (base0 == base1) {
ret.coeff = coeff;
ret.base = base0;
return ret;
return Entry(coeff, base0);
} else {
ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff);
ret.base = 0;
return ret;
return Entry(GCD(GCD(base0, base1), coeff), 0);
}
}
/*!
* \brief Create interect of two sets.
* \brief Create intersection of two sets.
* \param a The left operand.
* \param b the right operand.
*/
static Entry Intersect(Entry a, Entry b) {
// simple rule for now: pick higher constraints.
// TODO(team-team): Use extended euclidean algorithm.
if (a.coeff == 0) return a;
if (b.coeff == 0) return b;
if (a.coeff >= b.coeff) return a;
return b;
}
/*!
* \brief Simplify base so that it is in [0, coeff) when coeff != 0.
* \param base The base value.
* \param coeff The coeff value.
* \return The simplified base.
*/
static int64_t BaseSimplify(int64_t base, int64_t coeff) {
if (coeff == 0) return base;
base = base % coeff;
if (base < 0) base += coeff;
return base;
static Entry Intersect(Entry x, Entry y) {
int64_t n, m;
int64_t a = x.coeff, b = x.base, c = y.coeff, d = y.base;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit more comment on the derivation

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let z be the integer, we know z satisfies

z = x * c1 + b1
z = y * c2 + b2

... derivation
z's general pattern

int64_t gcd = ExtendedEuclidean(a, c, &n, &m);
int64_t v = d - b;
if (v % gcd == 0) {
n = v / gcd * n;
m = v / gcd * (-m);

int64_t coeff = a / gcd * c;
return Entry(coeff, n*a + b);
} else {
return Nothing();
}
}

/*!
* \brief Take GCD of a and b.
* \param a The first operand.
* \param b The second operand.
* \return The result.
*/
static int64_t ZeroAwareGCD(int64_t a, int64_t b) {
if (a < 0) a = -a;
if (b < 0) b = -b;
static int64_t GCD(int64_t a, int64_t b) {
if (a < b) std::swap(a, b);
if (b == 0) return a;
// perform GCD (greatest common divisor)
Expand All @@ -306,14 +290,53 @@ class ModularSetAnalyzer::Impl :
}
return b;
}

/*!
* \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
* \param a The first coefficient. (a >= 0)
* \param b The second coefficient. (b >= 0)
* \param x The solution of x.
* \param y The solution of y.
* \return The GCD of a and b.
*/
static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t *x, int64_t *y) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style int64_t* x

int64_t s = 0, old_s = 1;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add invariance check CHECK_GE(a, 0);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making the algorithm work for negative arguments would be a better way in my opinion. GCD and Bezout coefficients are defined even for negative numbers, and requiring the arguments to be non-negative will make this function harder to use for some applications.

int64_t r = b, old_r = a;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think a bit more about how to make the code more readable: comment a bit on what is r and what is s, so it is easier for others to understand

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some useful comment that contains the derivation

// Extended Euclidean algorithm
// initial condition:
// a * 0 + b * 1 = a
// a * 1 + b * 0 = b
//
// Iteration (r2 < r1):
// a * x1 + b * y1 = r1
// a * x2 + b * y2 = r2
// The above two eqs can derive the following eq (q = r2 / r1)
// a * (x1 - x2 * q) + b * (y1 - y2 * q) = r2 - r1 * q = r3
// Because r3 < r2, the iteration can eventually terminate


while (r != 0) {
int64_t q = old_r / r;
int64_t tmp = old_r;
old_r = r;
r = tmp - q * r;
tmp = old_s;
old_s = s;
s = tmp - q * s;
}

*x = old_s;
if (b != 0) {
*y = (old_r - old_s * a) / b;
} else {
*y = 1;
}

return old_r;
}

/*!
* \brief return everything dtype can represent.
* \return Bound that represent everything dtype can represent.
*/
static Entry Everything() {
Entry ret;
ret.coeff = 1; ret.base = 0;
return ret;
return Entry(1, 0);
}

/*!
* \brief return an empty set
* \return An empty modular set.
*/
static Entry Nothing() {
return Entry(0, 1);
}
};

Expand Down
16 changes: 16 additions & 0 deletions tests/python/unittest/test_arith_modular_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,21 @@ def test_constraint_scope():
assert m.coeff == 1
assert m.base == 0

def test_intersect():
a = tvm.var("a")
analyzer = tvm.arith.Analyzer()
with analyzer.constraint_scope(a % 4 == 1):
with analyzer.constraint_scope(a % 3 == 1):
m = analyzer.modular_set(a)
assert m.coeff == 12
assert m.base == 1

with analyzer.constraint_scope(a % 3 == 2):
with analyzer.constraint_scope(a % 5 == 3):
with analyzer.constraint_scope(a % 7 == 2):
m = analyzer.modular_set(a)
assert m.coeff == 105
assert m.base == 23

if __name__ == "__main__":
test_cast()
Expand All @@ -126,3 +141,4 @@ def test_constraint_scope():
test_min_max_select()
test_mix_index()
test_constraint_scope()
test_intersect()