-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[ARITH] Fix intersect of modular set #2726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,10 +32,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) | |
|
|
||
|
|
||
| // internal entry for const int bound | ||
| // This condition holds for all instances: coeff >= 0, base in [0, coeff] | ||
| struct ModularSetAnalyzer::Entry { | ||
| int64_t coeff{1}; | ||
| int64_t base{0}; | ||
|
|
||
| Entry() = default; | ||
|
|
||
| Entry(int64_t coeff, int64_t base) { | ||
| this->coeff = coeff; | ||
|
|
||
| if (coeff < 0) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure if we want to support such smart conversion here, coeff<0 could be more likely an error, and we can do CHECK_GE(coeff, 0) |
||
| coeff = -coeff; | ||
| } | ||
|
|
||
| if (coeff != 0) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because this constructor is "too smart", I would recommend us to create a static function instead. C++ constructor also have a restriction of not throw exception, having a factory function and put checks there avoids the problem
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
As far as I know, throwing exceptions from constructors is absolutely OK. (Unless you want to avoid exceptions completely and return an error code or an |
||
| base = base % coeff; | ||
| if (base < 0) base += coeff; | ||
| } | ||
| this->base = base; | ||
| } | ||
|
|
||
| bool is_const() const { | ||
| return coeff == 0; | ||
| } | ||
|
|
@@ -53,10 +70,7 @@ class ModularSetAnalyzer::Impl : | |
| if (!override) { | ||
| CHECK(!var_map_.count(var)); | ||
| } | ||
| Entry e; | ||
| e.coeff = info->coeff; | ||
| e.base = info->base; | ||
| var_map_[var] = e; | ||
| var_map_[var] = Entry(info->coeff, info->base); | ||
| } | ||
|
|
||
| // Detect useful constraints and use them in the analysis scope. | ||
|
|
@@ -65,10 +79,7 @@ class ModularSetAnalyzer::Impl : | |
| PVar<Integer> coeff, base; | ||
| // pattern match interesting constraints | ||
| if (((var % coeff) == base).Match(constraint)) { | ||
| Entry entry; | ||
| entry.coeff = coeff.Eval()->value; | ||
| entry.base = base.Eval()->value; | ||
| return UpdateByIntersect(var.Eval(), entry); | ||
| return UpdateByIntersect(var.Eval(), Entry(coeff.Eval()->value, base.Eval()->value)); | ||
| } | ||
| return nullptr; | ||
| } | ||
|
|
@@ -83,18 +94,12 @@ class ModularSetAnalyzer::Impl : | |
| } | ||
|
|
||
| Entry VisitExpr_(const IntImm* op) final { | ||
| Entry ret; | ||
| ret.base = op->value; | ||
| ret.coeff = 0; | ||
| return ret; | ||
| return Entry(0, op->value); | ||
| } | ||
|
|
||
| Entry VisitExpr_(const UIntImm* op) final { | ||
| if (op->value < std::numeric_limits<int64_t>::max()) { | ||
| Entry ret; | ||
| ret.base = static_cast<int>(op->value); | ||
| ret.coeff = 0; | ||
| return ret; | ||
| return Entry(0, static_cast<int>(op->value)); | ||
| } else { | ||
| return Everything(); | ||
| } | ||
|
|
@@ -103,19 +108,15 @@ class ModularSetAnalyzer::Impl : | |
| Entry VisitExpr_(const Add* op) final { | ||
| Entry a = VisitExpr(op->a); | ||
| Entry b = VisitExpr(op->b); | ||
| Entry ret; | ||
| ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); | ||
| ret.base = BaseSimplify(a.base + b.base, ret.coeff); | ||
| return ret; | ||
| int64_t coeff = GCD(a.coeff, b.coeff); | ||
| return Entry(coeff, a.base + b.base); | ||
| } | ||
|
|
||
| Entry VisitExpr_(const Sub* op) final { | ||
| Entry a = VisitExpr(op->a); | ||
| Entry b = VisitExpr(op->b); | ||
| Entry ret; | ||
| ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); | ||
| ret.base = BaseSimplify(a.base - b.base, ret.coeff); | ||
| return ret; | ||
| int64_t coeff = GCD(a.coeff, b.coeff); | ||
| return Entry(coeff, a.base - b.base); | ||
| } | ||
|
|
||
| Entry VisitExpr_(const Mul* op) final { | ||
|
|
@@ -128,10 +129,9 @@ class ModularSetAnalyzer::Impl : | |
| int64_t pq = a.coeff * b.coeff; | ||
| int64_t pm = a.coeff * b.base; | ||
| int64_t qn = a.base * b.coeff; | ||
| Entry ret; | ||
| ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); | ||
| ret.base = BaseSimplify(a.base * b.base, ret.coeff); | ||
| return ret; | ||
|
|
||
| int64_t coeff = GCD(pq, GCD(pm, qn)); | ||
| return Entry(coeff, a.base * b.base); | ||
| } | ||
|
|
||
| Entry DivByConst(const Expr& lhs, | ||
|
|
@@ -140,20 +140,15 @@ class ModularSetAnalyzer::Impl : | |
| Entry a = VisitExpr(lhs); | ||
| CHECK_NE(val, 0); | ||
| if (a.coeff % val == 0) { | ||
| Entry ret; | ||
| if (a.base == 0) { | ||
| // a c x / c -> a x | ||
| ret.coeff = std::abs(a.coeff / val); | ||
| ret.base = 0; | ||
| return ret; | ||
| return Entry(std::abs(a.coeff / val), 0); | ||
| } | ||
| // positive division have a clear rounding mode. | ||
| // Only handle case where we clearly know we need to round down. | ||
| if (a.base > 0 && val > 0 && | ||
| (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { | ||
| ret.coeff = a.coeff / val; | ||
| ret.base = a.base / val; | ||
| return ret; | ||
| return Entry(a.coeff / val, a.base / val); | ||
| } | ||
| } | ||
| return Everything(); | ||
|
|
@@ -244,58 +239,47 @@ class ModularSetAnalyzer::Impl : | |
| */ | ||
| static Entry Union(Entry a, Entry b) { | ||
| // {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}} | ||
| int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); | ||
| int64_t coeff = GCD(a.coeff, b.coeff); | ||
| if (coeff == 0) { | ||
| if (a.base == b.base) return a; | ||
| return Everything(); | ||
| } | ||
| int64_t base0 = a.base % coeff; | ||
| int64_t base1 = b.base % coeff; | ||
| Entry ret; | ||
| if (base0 == base1) { | ||
| ret.coeff = coeff; | ||
| ret.base = base0; | ||
| return ret; | ||
| return Entry(coeff, base0); | ||
| } else { | ||
| ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff); | ||
| ret.base = 0; | ||
| return ret; | ||
| return Entry(GCD(GCD(base0, base1), coeff), 0); | ||
| } | ||
| } | ||
| /*! | ||
| * \brief Create interect of two sets. | ||
| * \brief Create intersection of two sets. | ||
| * \param a The left operand. | ||
| * \param b the right operand. | ||
| */ | ||
| static Entry Intersect(Entry a, Entry b) { | ||
| // simple rule for now: pick higher constraints. | ||
| // TODO(team-team): Use extended euclidean algorithm. | ||
| if (a.coeff == 0) return a; | ||
| if (b.coeff == 0) return b; | ||
| if (a.coeff >= b.coeff) return a; | ||
| return b; | ||
| } | ||
| /*! | ||
| * \brief Simplify base so that it is in [0, coeff) when coeff != 0. | ||
| * \param base The base value. | ||
| * \param coeff The coeff value. | ||
| * \return The simplified base. | ||
| */ | ||
| static int64_t BaseSimplify(int64_t base, int64_t coeff) { | ||
| if (coeff == 0) return base; | ||
| base = base % coeff; | ||
| if (base < 0) base += coeff; | ||
| return base; | ||
| static Entry Intersect(Entry x, Entry y) { | ||
| int64_t n, m; | ||
| int64_t a = x.coeff, b = x.base, c = y.coeff, d = y.base; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A bit more comment on the derivation
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let z be the integer, we know z satisfies z = x * c1 + b1 ... derivation |
||
| int64_t gcd = ExtendedEuclidean(a, c, &n, &m); | ||
| int64_t v = d - b; | ||
| if (v % gcd == 0) { | ||
| n = v / gcd * n; | ||
| m = v / gcd * (-m); | ||
|
|
||
| int64_t coeff = a / gcd * c; | ||
| return Entry(coeff, n*a + b); | ||
| } else { | ||
| return Nothing(); | ||
| } | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Take GCD of a and b. | ||
| * \param a The first operand. | ||
| * \param b The second operand. | ||
| * \return The result. | ||
| */ | ||
| static int64_t ZeroAwareGCD(int64_t a, int64_t b) { | ||
| if (a < 0) a = -a; | ||
| if (b < 0) b = -b; | ||
| static int64_t GCD(int64_t a, int64_t b) { | ||
| if (a < b) std::swap(a, b); | ||
| if (b == 0) return a; | ||
| // perform GCD (greatest common divisor) | ||
|
|
@@ -306,14 +290,53 @@ class ModularSetAnalyzer::Impl : | |
| } | ||
| return b; | ||
| } | ||
|
|
||
| /*! | ||
| * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b) | ||
| * \param a The first coefficient. (a >= 0) | ||
| * \param b The second coefficient. (b >= 0) | ||
| * \param x The solution of x. | ||
| * \param y The solution of y. | ||
| * \return The GCD of a and b. | ||
| */ | ||
| static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t *x, int64_t *y) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style int64_t* x |
||
| int64_t s = 0, old_s = 1; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add invariance check CHECK_GE(a, 0);
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Making the algorithm work for negative arguments would be a better way in my opinion. GCD and Bezout coefficients are defined even for negative numbers, and requiring the arguments to be non-negative will make this function harder to use for some applications. |
||
| int64_t r = b, old_r = a; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Think a bit more about how to make the code more readable: comment a bit on what is r and what is s, so it is easier for others to understand
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some useful comment that contains the derivation // Extended Euclidean algorithm |
||
|
|
||
| while (r != 0) { | ||
| int64_t q = old_r / r; | ||
| int64_t tmp = old_r; | ||
| old_r = r; | ||
| r = tmp - q * r; | ||
| tmp = old_s; | ||
| old_s = s; | ||
| s = tmp - q * s; | ||
| } | ||
|
|
||
| *x = old_s; | ||
| if (b != 0) { | ||
| *y = (old_r - old_s * a) / b; | ||
| } else { | ||
| *y = 1; | ||
| } | ||
|
|
||
| return old_r; | ||
| } | ||
|
|
||
| /*! | ||
| * \brief return everything dtype can represent. | ||
| * \return Bound that represent everything dtype can represent. | ||
| */ | ||
| static Entry Everything() { | ||
| Entry ret; | ||
| ret.coeff = 1; ret.base = 0; | ||
| return ret; | ||
| return Entry(1, 0); | ||
| } | ||
|
|
||
| /*! | ||
| * \brief return an empty set | ||
| * \return An empty modular set. | ||
| */ | ||
| static Entry Nothing() { | ||
| return Entry(0, 1); | ||
| } | ||
| }; | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
base in [0, coeff) for coeff != 0, for coeff=0 base can be any number that indicates constant