From e199228470e67004cb7ae98a0e9b3d77e2c2999f Mon Sep 17 00:00:00 2001 From: Mercy Date: Mon, 4 Mar 2019 19:36:57 +0800 Subject: [PATCH 1/2] Fix intersect of modular set --- include/tvm/arithmetic.h | 4 +- src/arithmetic/modular_set.cc | 55 ++++++++++++++++--- .../python/unittest/test_arith_modular_set.py | 16 ++++++ 3 files changed, 65 insertions(+), 10 deletions(-) diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index d023f8f1cf7e..15e796d6d715 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -164,10 +164,10 @@ class ModularSetAnalyzer { */ ModularSet operator()(const Expr& expr); /*! - * \brief Update constant int bound information of var. + * \brief Update modular set information of var. * * \param var The variable of interest. - * \param info The bound information. + * \param info The modular set information. * \param override Whether do we allow override of existing information. */ void Update(const Var& var, diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 8112beef7551..0986845e10cd 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -263,17 +263,26 @@ class ModularSetAnalyzer::Impl : } } /*! - * \brief Create interect of two sets. + * \brief Create intersection of two sets. * \param a The left operand. * \param b the right operand. */ - static Entry Intersect(Entry a, Entry b) { - // simple rule for now: pick higher constraints. - // TODO(team-team): Use extended euclidean algorithm. - if (a.coeff == 0) return a; - if (b.coeff == 0) return b; - if (a.coeff >= b.coeff) return a; - return b; + static Entry Intersect(Entry x, Entry y) { + int64_t n, m; + int64_t a = x.coeff, b = x.base, c = y.coeff, d = y.base; + int64_t gcd = ExtendedEuclidean(a, c, &n, &m); + int64_t v = d - b; + if (v % gcd == 0) { + n = v / gcd * n; + m = v / gcd * (-m); + + Entry ret; + ret.coeff = a / gcd * c; + ret.base = BaseSimplify(n * a + b, ret.coeff); + return ret; + } else { + return Nothing(); + } } /*! * \brief Simplify base so that it is in [0, coeff) when coeff != 0. @@ -306,6 +315,26 @@ class ModularSetAnalyzer::Impl : } return b; } + + /*! + * \brief Use Extended Euclidean algorithm to solve ax + by = 1 + * \param a The first operand. + * \param b The second operand. + * \param x The solution of x. + * \param y The solution of y. + * \return The GCD of a and b. + */ + static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t *x, int64_t *y) { + if (b == 0) { + *x = 1; + *y = 0; + return a; + } + int64_t q = ExtendedEuclidean(b, a % b, y, x); + *y -= a / b * (*x); + return q; + } + /*! * \brief return everything dtype can represent. * \return Bound that represent everything dtype can represent. @@ -315,6 +344,16 @@ class ModularSetAnalyzer::Impl : ret.coeff = 1; ret.base = 0; return ret; } + + /*! + * \brief return an empty set + * \return An empty modular set. + */ + static Entry Nothing() { + Entry ret; + ret.coeff = 0; ret.base = 1; + return ret; + } }; ModularSet ModularSetAnalyzer::operator()(const Expr& expr) { diff --git a/tests/python/unittest/test_arith_modular_set.py b/tests/python/unittest/test_arith_modular_set.py index 06ae5197b974..64ec96d573b1 100644 --- a/tests/python/unittest/test_arith_modular_set.py +++ b/tests/python/unittest/test_arith_modular_set.py @@ -117,6 +117,21 @@ def test_constraint_scope(): assert m.coeff == 1 assert m.base == 0 +def test_intersect(): + a = tvm.var("a") + analyzer = tvm.arith.Analyzer() + with analyzer.constraint_scope(a % 4 == 1): + with analyzer.constraint_scope(a % 3 == 1): + m = analyzer.modular_set(a) + assert m.coeff == 12 + assert m.base == 1 + + with analyzer.constraint_scope(a % 3 == 2): + with analyzer.constraint_scope(a % 5 == 3): + with analyzer.constraint_scope(a % 7 == 2): + m = analyzer.modular_set(a) + assert m.coeff == 105 + assert m.base == 23 if __name__ == "__main__": test_cast() @@ -126,3 +141,4 @@ def test_constraint_scope(): test_min_max_select() test_mix_index() test_constraint_scope() + test_intersect() From 5663e9bf91d5bf984f21140588352cf0ceae10df Mon Sep 17 00:00:00 2001 From: Mercy Date: Sun, 10 Mar 2019 08:53:00 +0800 Subject: [PATCH 2/2] add constraint check to the constructor of modular set entry --- 3rdparty/HalideIR | 2 +- src/arithmetic/modular_set.cc | 140 +++++++++++++++------------------- 2 files changed, 63 insertions(+), 79 deletions(-) diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index 86351c40824d..b257a9221ee1 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit 86351c40824dfc4cbb7447d70e5e63d9bd76eb90 +Subproject commit b257a9221ee1e5180d994b3488ddcc259b0ac157 diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index 0986845e10cd..d9a139fa5e8d 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -32,10 +32,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) // internal entry for const int bound +// This condition holds for all instances: coeff >= 0, base in [0, coeff] struct ModularSetAnalyzer::Entry { int64_t coeff{1}; int64_t base{0}; + Entry() = default; + + Entry(int64_t coeff, int64_t base) { + this->coeff = coeff; + + if (coeff < 0) { + coeff = -coeff; + } + + if (coeff != 0) { + base = base % coeff; + if (base < 0) base += coeff; + } + this->base = base; + } + bool is_const() const { return coeff == 0; } @@ -53,10 +70,7 @@ class ModularSetAnalyzer::Impl : if (!override) { CHECK(!var_map_.count(var)); } - Entry e; - e.coeff = info->coeff; - e.base = info->base; - var_map_[var] = e; + var_map_[var] = Entry(info->coeff, info->base); } // Detect useful constraints and use them in the analysis scope. @@ -65,10 +79,7 @@ class ModularSetAnalyzer::Impl : PVar coeff, base; // pattern match interesting constraints if (((var % coeff) == base).Match(constraint)) { - Entry entry; - entry.coeff = coeff.Eval()->value; - entry.base = base.Eval()->value; - return UpdateByIntersect(var.Eval(), entry); + return UpdateByIntersect(var.Eval(), Entry(coeff.Eval()->value, base.Eval()->value)); } return nullptr; } @@ -83,18 +94,12 @@ class ModularSetAnalyzer::Impl : } Entry VisitExpr_(const IntImm* op) final { - Entry ret; - ret.base = op->value; - ret.coeff = 0; - return ret; + return Entry(0, op->value); } Entry VisitExpr_(const UIntImm* op) final { if (op->value < std::numeric_limits::max()) { - Entry ret; - ret.base = static_cast(op->value); - ret.coeff = 0; - return ret; + return Entry(0, static_cast(op->value)); } else { return Everything(); } @@ -103,19 +108,15 @@ class ModularSetAnalyzer::Impl : Entry VisitExpr_(const Add* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); - Entry ret; - ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); - ret.base = BaseSimplify(a.base + b.base, ret.coeff); - return ret; + int64_t coeff = GCD(a.coeff, b.coeff); + return Entry(coeff, a.base + b.base); } Entry VisitExpr_(const Sub* op) final { Entry a = VisitExpr(op->a); Entry b = VisitExpr(op->b); - Entry ret; - ret.coeff = ZeroAwareGCD(a.coeff, b.coeff); - ret.base = BaseSimplify(a.base - b.base, ret.coeff); - return ret; + int64_t coeff = GCD(a.coeff, b.coeff); + return Entry(coeff, a.base - b.base); } Entry VisitExpr_(const Mul* op) final { @@ -128,10 +129,9 @@ class ModularSetAnalyzer::Impl : int64_t pq = a.coeff * b.coeff; int64_t pm = a.coeff * b.base; int64_t qn = a.base * b.coeff; - Entry ret; - ret.coeff = ZeroAwareGCD(pq, ZeroAwareGCD(pm, qn)); - ret.base = BaseSimplify(a.base * b.base, ret.coeff); - return ret; + + int64_t coeff = GCD(pq, GCD(pm, qn)); + return Entry(coeff, a.base * b.base); } Entry DivByConst(const Expr& lhs, @@ -140,20 +140,15 @@ class ModularSetAnalyzer::Impl : Entry a = VisitExpr(lhs); CHECK_NE(val, 0); if (a.coeff % val == 0) { - Entry ret; if (a.base == 0) { // a c x / c -> a x - ret.coeff = std::abs(a.coeff / val); - ret.base = 0; - return ret; + return Entry(std::abs(a.coeff / val), 0); } // positive division have a clear rounding mode. // Only handle case where we clearly know we need to round down. if (a.base > 0 && val > 0 && (round_down || parent_->CanProveGreaterEqual(lhs, 0))) { - ret.coeff = a.coeff / val; - ret.base = a.base / val; - return ret; + return Entry(a.coeff / val, a.base / val); } } return Everything(); @@ -244,22 +239,17 @@ class ModularSetAnalyzer::Impl : */ static Entry Union(Entry a, Entry b) { // {ax + y} \cup {bz + h} => {gcd(a, b) x + {y or h}} - int64_t coeff = ZeroAwareGCD(a.coeff, b.coeff); + int64_t coeff = GCD(a.coeff, b.coeff); if (coeff == 0) { if (a.base == b.base) return a; return Everything(); } int64_t base0 = a.base % coeff; int64_t base1 = b.base % coeff; - Entry ret; if (base0 == base1) { - ret.coeff = coeff; - ret.base = base0; - return ret; + return Entry(coeff, base0); } else { - ret.coeff = ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff); - ret.base = 0; - return ret; + return Entry(GCD(GCD(base0, base1), coeff), 0); } } /*! @@ -276,35 +266,20 @@ class ModularSetAnalyzer::Impl : n = v / gcd * n; m = v / gcd * (-m); - Entry ret; - ret.coeff = a / gcd * c; - ret.base = BaseSimplify(n * a + b, ret.coeff); - return ret; + int64_t coeff = a / gcd * c; + return Entry(coeff, n*a + b); } else { return Nothing(); } } - /*! - * \brief Simplify base so that it is in [0, coeff) when coeff != 0. - * \param base The base value. - * \param coeff The coeff value. - * \return The simplified base. - */ - static int64_t BaseSimplify(int64_t base, int64_t coeff) { - if (coeff == 0) return base; - base = base % coeff; - if (base < 0) base += coeff; - return base; - } + /*! * \brief Take GCD of a and b. * \param a The first operand. * \param b The second operand. * \return The result. */ - static int64_t ZeroAwareGCD(int64_t a, int64_t b) { - if (a < 0) a = -a; - if (b < 0) b = -b; + static int64_t GCD(int64_t a, int64_t b) { if (a < b) std::swap(a, b); if (b == 0) return a; // perform GCD (greatest common divisor) @@ -317,22 +292,35 @@ class ModularSetAnalyzer::Impl : } /*! - * \brief Use Extended Euclidean algorithm to solve ax + by = 1 - * \param a The first operand. - * \param b The second operand. + * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b) + * \param a The first coefficient. (a >= 0) + * \param b The second coefficient. (b >= 0) * \param x The solution of x. * \param y The solution of y. * \return The GCD of a and b. */ static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t *x, int64_t *y) { - if (b == 0) { - *x = 1; - *y = 0; - return a; + int64_t s = 0, old_s = 1; + int64_t r = b, old_r = a; + + while (r != 0) { + int64_t q = old_r / r; + int64_t tmp = old_r; + old_r = r; + r = tmp - q * r; + tmp = old_s; + old_s = s; + s = tmp - q * s; } - int64_t q = ExtendedEuclidean(b, a % b, y, x); - *y -= a / b * (*x); - return q; + + *x = old_s; + if (b != 0) { + *y = (old_r - old_s * a) / b; + } else { + *y = 1; + } + + return old_r; } /*! @@ -340,9 +328,7 @@ class ModularSetAnalyzer::Impl : * \return Bound that represent everything dtype can represent. */ static Entry Everything() { - Entry ret; - ret.coeff = 1; ret.base = 0; - return ret; + return Entry(1, 0); } /*! @@ -350,9 +336,7 @@ class ModularSetAnalyzer::Impl : * \return An empty modular set. */ static Entry Nothing() { - Entry ret; - ret.coeff = 0; ret.base = 1; - return ret; + return Entry(0, 1); } };