From b44f4f7a0c587bf0418afaac6089c17ceaf3225e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 21 Sep 2025 22:24:19 +0800 Subject: [PATCH 1/8] Enhance ConstIntBoundAnalyzer and IntervalSet with modular set analysis - Added modular set analysis to ConstIntBoundAnalyzer for tighter bounds when min_value equals max_value. - Introduced ComputeGCD function to calculate the GCD of two integers. - Updated Combine functions in IntervalSet to accept operation nodes for better type handling. - Enhanced tests for modular set bounds in both const integer bounds and interval sets. --- src/arith/const_int_bound.cc | 61 +++++++++++- src/arith/int_set.cc | 92 +++++++++++++------ .../arith/test_arith_const_int_bound.py | 12 +++ tests/python/arith/test_arith_intset.py | 10 ++ 4 files changed, 146 insertions(+), 29 deletions(-) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index b8e5db483f4f..7fe6cdcd560c 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -102,6 +102,7 @@ struct ConstIntBoundAnalyzer::Entry { class ConstIntBoundAnalyzer::Impl : public ExprFunctor { public: + explicit Impl(Analyzer* parent) : parent_(parent) {} /*! \brief additional bound info about expr in bound */ struct BoundInfo { /*! \brief The expr */ @@ -129,8 +130,7 @@ class ConstIntBoundAnalyzer::Impl auto it = var_map_.find(var); if (it != var_map_.end()) { ICHECK(it->second == info) - << "Trying to update var \'" << var << "\'" - << " with a different const bound: " + << "Trying to update var \'" << var << "\'" << " with a different const bound: " << "original=" << ConstIntBound(it->second.min_value, it->second.max_value) << ", new=" << ConstIntBound(info.min_value, info.max_value); } @@ -278,6 +278,25 @@ class ConstIntBoundAnalyzer::Impl if (b.min_value > 0) { int64_t b_max_cap = InfAwareAdd(b.max_value, -1); + + // Try to get tighter bounds using modular set information + if (parent_ && b.min_value == b.max_value) { + ModularSet mod_a = parent_->modular_set(op->a); + int64_t modulus = b.min_value; + int64_t gcd_coeff_mod = ComputeGCD(mod_a->coeff, modulus); + + // If gcd_coeff_mod > 1, we can get tighter bounds + // The result will be of the form gcd_coeff_mod * k + (base % modulus) + // where k ranges to cover [0, modulus - gcd_coeff_mod] + if (gcd_coeff_mod > 1) { + int64_t base_mod = mod_a->base % modulus; + if (base_mod < 0) base_mod += modulus; + int64_t tight_max = modulus - gcd_coeff_mod + base_mod; + if (tight_max >= modulus) tight_max -= modulus; + return MakeBound(base_mod, tight_max); + } + } + if (a.min_value >= 0) { // 0 <= [a_min, a_max] < b_min if (a.max_value < b.min_value) return a; @@ -324,6 +343,24 @@ class ConstIntBoundAnalyzer::Impl if (b.min_value > 0) { int64_t b_max_cap = InfAwareAdd(b.max_value, -1); + // Try to get tighter bounds using modular set information + if (parent_ && b.min_value == b.max_value) { + ModularSet mod_a = parent_->modular_set(op->a); + int64_t modulus = b.min_value; + int64_t gcd_coeff_mod = ComputeGCD(mod_a->coeff, modulus); + + // If gcd_coeff_mod > 1, we can get tighter bounds + // The result will be of the form gcd_coeff_mod * k + (base % modulus) + // where k ranges to cover [0, modulus - gcd_coeff_mod] + if (gcd_coeff_mod > 1) { + int64_t base_mod = mod_a->base % modulus; + if (base_mod < 0) base_mod += modulus; + int64_t tight_max = modulus - gcd_coeff_mod + base_mod; + if (tight_max >= modulus) tight_max -= modulus; + return MakeBound(base_mod, tight_max); + } + } + if (a.min_value >= 0) { // 0 <= [a_min, a_max] < b_min if (a.max_value < b.min_value) return a; @@ -458,6 +495,8 @@ class ConstIntBoundAnalyzer::Impl private: friend class ConstIntBoundAnalyzer; + // parent analyzer + Analyzer* parent_; // internal variable map std::unordered_map var_map_; // additional bound info @@ -525,6 +564,22 @@ class ConstIntBoundAnalyzer::Impl // If the range of b does not have 0, use BinaryOpBoundary. return BinaryOpBoundary(a, b, op); } + /*! + * \brief Compute GCD of two integers. + * \param a The first integer. + * \param b The second integer. + * \return the result. + */ + static int64_t ComputeGCD(int64_t a, int64_t b) { + a = std::abs(a); + b = std::abs(b); + while (b != 0) { + int64_t temp = b; + b = a % b; + a = temp; + } + return a; + } /*! * \brief Compute x + y, aware of inf. * \param x The left operand. @@ -805,7 +860,7 @@ std::function ConstIntBoundAnalyzer::EnterConstraint(const PrimExpr& con return impl_->EnterConstraint(constraint); } -ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl()) {} +ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(Analyzer* parent) : impl_(new Impl(parent)) {} ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index aa15284b3e03..6f6730352a22 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -111,8 +112,9 @@ TVM_DECLARE_LOGICAL_OP(Not); * \brief Combine two interval set under arithmetic operations. * \note this can possibly relax the set. */ -template -inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, DataType dtype) { +template +inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, const OpNode* op) { + DataType dtype = op->dtype; if (a->IsSinglePoint() && b->IsSinglePoint()) { PrimExpr expr; if (auto res = TryConstFold(a->min_value, b->min_value)) { @@ -134,7 +136,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, Dat template <> inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::AddNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value + b->min_value); } @@ -149,7 +151,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS template <> inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::SubNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value - b->min_value); } @@ -164,7 +166,7 @@ inline IntervalSet Combine(Analyzer* analyer, IntervalSet a, IntervalS template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::MulNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value * b->min_value); } @@ -198,7 +200,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::DivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(a->min_value / b->min_value); } @@ -232,7 +234,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::ModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(truncmod(a->min_value, b->min_value)); } @@ -261,7 +263,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::FloorDivNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floordiv(a->min_value, b->min_value)); } @@ -295,7 +297,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int template <> inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::FloorModNode* op) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(floormod(a->min_value, b->min_value)); } @@ -321,6 +323,39 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int return IntervalSet(tmin, tmax); } } + // Enhanced: Use ModularSet analysis for better bounds + if (auto* div_imm = divisor.as()) { + int64_t div_val = div_imm->value; + + // Analyze the modular properties of the dividend + ModularSet dividend_mod = analyzer->modular_set(op->a); + + if (dividend_mod.defined() && dividend_mod->coeff > 0) { + // Calculate GCD of dividend coefficient and divisor + int64_t gcd = 1; + if (dividend_mod->coeff != 0 && div_val != 0) { + int64_t a_coeff = std::abs(dividend_mod->coeff); + int64_t b_val = std::abs(div_val); + while (b_val != 0) { + int64_t temp = b_val; + b_val = a_coeff % b_val; + a_coeff = temp; + } + gcd = a_coeff; + } + + if (gcd > 1 && div_val % gcd == 0) { + // The dividend is a multiple of gcd, and divisor is also a multiple of gcd + // So the result is also a multiple of gcd, with max value = (div_val/gcd - 1) * gcd + int64_t max_quotient = (div_val / gcd) - 1; + int64_t max_mod_result = max_quotient * gcd + (dividend_mod->base % gcd); + + if (max_mod_result >= 0 && max_mod_result < div_val) { + return IntervalSet(make_zero(op->dtype), make_const(op->dtype, max_mod_result)); + } + } + } + } return IntervalSet(make_zero(divisor.dtype()), divisor - 1); } else { PrimExpr bound = abs(divisor) - 1; @@ -333,7 +368,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int template <> inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::MaxNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(max(a->min_value, b->min_value)); } @@ -344,7 +379,7 @@ inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, Interval template <> inline IntervalSet Combine(Analyzer* analzyer, IntervalSet a, IntervalSet b, - DataType /* dtype */) { + const tir::MinNode* /* op */) { if (a->IsSinglePoint() && b->IsSinglePoint()) { return IntervalSet::SinglePoint(min(a->min_value, b->min_value)); } @@ -475,19 +510,25 @@ class IntervalSetEvaluator : public ExprFunctor { if (op->lanes->IsInstance()) { int lanes = static_cast(Downcast(op->lanes)->value); if (vstride > 0) { - return Combine(analyzer_, base, - IntervalSet(make_zero(t), make_const(t, vstride * (lanes - 1))), - op->dtype); + PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); + auto add_op = tir::Add(op->base, stride_expr); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(make_zero(t), stride_expr), add_node); } else { - return Combine(analyzer_, base, - IntervalSet(make_const(t, vstride * (lanes - 1)), make_zero(t)), - op->dtype); + PrimExpr stride_expr = make_const(t, vstride * (lanes - 1)); + auto add_op = tir::Add(op->base, stride_expr); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(stride_expr, make_zero(t)), add_node); } } else { /* Scalable vector */ if (vstride > 0) { - return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), op->dtype); + auto add_op = tir::Add(op->base, make_zero(t)); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(make_zero(t), pos_inf()), add_node); } else { - return Combine(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), op->dtype); + auto add_op = tir::Add(op->base, make_zero(t)); + auto add_node = add_op.as(); + return Combine(analyzer_, base, IntervalSet(neg_inf(), make_zero(t)), add_node); } } } @@ -563,7 +604,7 @@ class IntervalSetEvaluator : public ExprFunctor { if (MatchPoint(a, op->a) && MatchPoint(b, op->b)) { return IntervalSet::SinglePoint(ffi::GetRef(op)); } - return Combine(analyzer_, a, b, op->dtype); + return Combine(analyzer_, a, b, op); } // recursive depth @@ -640,13 +681,13 @@ void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool can_o ICHECK(ExprDeepEqual()(old_info.min(), info.min())) << "Trying to update var \'" << var << "\'" - << " with a different minimum value: " - << "original=" << old_info.min() << ", new=" << info.min(); + << " with a different minimum value: " << "original=" << old_info.min() + << ", new=" << info.min(); ICHECK(ExprDeepEqual()(old_info.max(), info.max())) << "Trying to update var \'" << var << "\'" - << " with a different maximum value: " - << "original=" << old_info.max() << ", new=" << info.max(); + << " with a different maximum value: " << "original=" << old_info.max() + << ", new=" << info.max(); } } dom_map_.Set(var, info); @@ -1194,8 +1235,7 @@ ffi::Array EstimateRegionUpperBound(const ffi::Array& region, TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << "IntervalSet" - << "[" << op->min_value << ", " << op->max_value << ']'; + p->stream << "IntervalSet" << "[" << op->min_value << ", " << op->max_value << ']'; }); TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/tests/python/arith/test_arith_const_int_bound.py b/tests/python/arith/test_arith_const_int_bound.py index 14bfec2328f2..8728df7e3f3a 100644 --- a/tests/python/arith/test_arith_const_int_bound.py +++ b/tests/python/arith/test_arith_const_int_bound.py @@ -298,5 +298,17 @@ class TestRampBound(BaseCompare): ) +class TestModularSetBound(BaseCompare): + analyzer = tvm.arith.Analyzer() + tx = tvm.te.var("tx", dtype="int32") + bx = tvm.te.var("bx", dtype="int32") + + expr = (bx * 2048 + tx * 16) % 7168 + + test_case = tvm.testing.parameter( + TestCase(expr, (0, 7152), {bx: (0, 3584), tx: (0, 128)}), + ) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/arith/test_arith_intset.py b/tests/python/arith/test_arith_intset.py index 18865a73df45..04014ca30095 100644 --- a/tests/python/arith/test_arith_intset.py +++ b/tests/python/arith/test_arith_intset.py @@ -387,5 +387,15 @@ def test_union_lower_bound(): assert result.max_value.same_as(pos_inf) +def test_modular_set(): + ck = IntSetChecker() + x = tvm.te.var("x", dtype="int32") + y = tvm.te.var("y", dtype="int32") + expr = (x * 2048 + y * 16) % 7168 + ck.verify( + expr, {x: tvm.arith.IntervalSet(0, 128), y: tvm.arith.IntervalSet(0, 3584)}, (0, 7152) + ) + + if __name__ == "__main__": tvm.testing.main() From 2f7584f7927e9b886fe31296dd46148feb11286a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 5 Oct 2025 01:01:29 +0800 Subject: [PATCH 2/8] replace gcd compute with ZeroAwareGCD --- src/arith/int_set.cc | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 6f6730352a22..0393e57bf5aa 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -34,6 +34,7 @@ #include #include "constraint_extract.h" +#include "int_operator.h" #include "interval_set.h" #include "pattern_match.h" @@ -332,17 +333,7 @@ inline IntervalSet Combine(Analyzer* analyzer, IntervalSet a, Int if (dividend_mod.defined() && dividend_mod->coeff > 0) { // Calculate GCD of dividend coefficient and divisor - int64_t gcd = 1; - if (dividend_mod->coeff != 0 && div_val != 0) { - int64_t a_coeff = std::abs(dividend_mod->coeff); - int64_t b_val = std::abs(div_val); - while (b_val != 0) { - int64_t temp = b_val; - b_val = a_coeff % b_val; - a_coeff = temp; - } - gcd = a_coeff; - } + int64_t gcd = ZeroAwareGCD(dividend_mod->coeff, div_val); if (gcd > 1 && div_val % gcd == 0) { // The dividend is a multiple of gcd, and divisor is also a multiple of gcd From 256a40c1fd9597f37b5fe2a459549a27002987d6 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 5 Oct 2025 01:25:03 +0800 Subject: [PATCH 3/8] doc op node --- src/arith/int_set.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 0393e57bf5aa..5cb47109d474 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -111,6 +111,10 @@ TVM_DECLARE_LOGICAL_OP(Not); /*! * \brief Combine two interval set under arithmetic operations. + * \param analyzer The analyzer for simplification and proving + * \param a The first interval set + * \param b The second interval set + * \param op The operation node, used to extract dtype and other properties * \note this can possibly relax the set. */ template From b6d0935be1464f3f672746fb892c7208a4d32d91 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 5 Oct 2025 01:26:01 +0800 Subject: [PATCH 4/8] replace Compute GCD with ZeroAwareGCD --- src/arith/const_int_bound.cc | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 7fe6cdcd560c..0d40acc8e68d 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -283,7 +283,7 @@ class ConstIntBoundAnalyzer::Impl if (parent_ && b.min_value == b.max_value) { ModularSet mod_a = parent_->modular_set(op->a); int64_t modulus = b.min_value; - int64_t gcd_coeff_mod = ComputeGCD(mod_a->coeff, modulus); + int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus); // If gcd_coeff_mod > 1, we can get tighter bounds // The result will be of the form gcd_coeff_mod * k + (base % modulus) @@ -347,7 +347,7 @@ class ConstIntBoundAnalyzer::Impl if (parent_ && b.min_value == b.max_value) { ModularSet mod_a = parent_->modular_set(op->a); int64_t modulus = b.min_value; - int64_t gcd_coeff_mod = ComputeGCD(mod_a->coeff, modulus); + int64_t gcd_coeff_mod = ZeroAwareGCD(mod_a->coeff, modulus); // If gcd_coeff_mod > 1, we can get tighter bounds // The result will be of the form gcd_coeff_mod * k + (base % modulus) @@ -564,22 +564,7 @@ class ConstIntBoundAnalyzer::Impl // If the range of b does not have 0, use BinaryOpBoundary. return BinaryOpBoundary(a, b, op); } - /*! - * \brief Compute GCD of two integers. - * \param a The first integer. - * \param b The second integer. - * \return the result. - */ - static int64_t ComputeGCD(int64_t a, int64_t b) { - a = std::abs(a); - b = std::abs(b); - while (b != 0) { - int64_t temp = b; - b = a % b; - a = temp; - } - return a; - } + /*! * \brief Compute x + y, aware of inf. * \param x The left operand. From 21eda0b1d09ea519b6be1b9769496123798a828e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sun, 5 Oct 2025 01:43:00 +0800 Subject: [PATCH 5/8] add example --- src/arith/const_int_bound.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 0d40acc8e68d..e6b5a8e16be6 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -288,6 +288,14 @@ class ConstIntBoundAnalyzer::Impl // If gcd_coeff_mod > 1, we can get tighter bounds // The result will be of the form gcd_coeff_mod * k + (base % modulus) // where k ranges to cover [0, modulus - gcd_coeff_mod] + // + // Example: expr = (bx * 2048 + tx * 16) % 7168 + // where bx in [0, 3584), tx in [0, 128) + // ModularSet(expr) = 16*k (coeff=16, base=0) + // GCD(16, 7168) = 16 + // Result can only be {0, 16, 32, ..., 7152} + // Without this optimization: bound = [0, 7167] + // With this optimization: bound = [0, 7152] if (gcd_coeff_mod > 1) { int64_t base_mod = mod_a->base % modulus; if (base_mod < 0) base_mod += modulus; @@ -352,6 +360,14 @@ class ConstIntBoundAnalyzer::Impl // If gcd_coeff_mod > 1, we can get tighter bounds // The result will be of the form gcd_coeff_mod * k + (base % modulus) // where k ranges to cover [0, modulus - gcd_coeff_mod] + // + // Example: expr = (bx * 2048 + tx * 16) % 7168 + // where bx in [0, 3584), tx in [0, 128) + // ModularSet(expr) = 16*k (coeff=16, base=0) + // GCD(16, 7168) = 16 + // Result can only be {0, 16, 32, ..., 7152} + // Without this optimization: bound = [0, 7167] + // With this optimization: bound = [0, 7152] if (gcd_coeff_mod > 1) { int64_t base_mod = mod_a->base % modulus; if (base_mod < 0) base_mod += modulus; From 8d7c919c012e3cad32f644c3084876d0d8d4fc8f Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Fri, 17 Oct 2025 13:37:47 +0800 Subject: [PATCH 6/8] test fix --- tests/python/te/test_te_create_primfunc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index c8a095280230..426272584bb5 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -852,7 +852,7 @@ def tir_workload( v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)]) T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) - for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30): + for rv0, rv1 in T.grid((v_ax2 % 3 * 4 + 16) // 12 + 1, (v_ax3 % 3 * 10 + 40) // 30 + 1): with T.block("adaptive_pool_sum"): v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0) v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1) @@ -870,7 +870,7 @@ def tir_workload( T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3]) T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3]) T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"}) - adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30)) + adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", (v_ax2 % 3 * 4 + 16) // 12 + 1) * T.Cast("float32", (v_ax3 % 3 * 10 + 40) // 30 + 1)) # fmt: on def te_workload(): From 6d22dc236cb184f0c7320622d6061fed9741f868 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 18 Oct 2025 00:35:54 +0800 Subject: [PATCH 7/8] test fix --- ...ule_feature_extractor_per_store_feature.py | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py index 057cd0e9f7ae..b901c3ce1372 100644 --- a/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py +++ b/tests/python/meta_schedule/test_meta_schedule_feature_extractor_per_store_feature.py @@ -846,21 +846,21 @@ def _create_schedule(): 1.0, 0.0, 0.0, - 25.000000042995662, - 20.000001375860553, - 23.00000017198264, - 14.000088052430122, + 25.00000004, + 19.99718086, + 23.00000017, + 13.99726771, 1.0, 0.0, 0.0, - 18.00000550343433, - 20.00562591970089, - 2.321928094887362, - 23.00000017198264, - 18.00000550343433, - 21.000000687930438, - 12.0003521774803, - 12.0003521774803, + 18.0000055, + 20.00000138, + 2.32192809, + 23.00000017, + 17.997185, + 21.00000069, + 11.99753235, + 12.00035218, ], rtol=1e-5, atol=1e-5, @@ -872,21 +872,21 @@ def _create_schedule(): 0.0, 1.0, 0.0, - 25.000000042995662, - 12.0003521774803, - 23.00000017198264, - 9.002815015607053, + 25.00000004, + 11.00070427, + 23.00000017, + 5.04439412, 1.0, 0.0, 0.0, - 6.022367813028454, - 11.98049663618346, - 8.005624549193879, - 17.000011006847668, - 4.087462841250339, - 15.000044026886828, - 1.584962500721156, - 4.087462841250339, + 6.02236781, + 11.98049664, + 8.00562455, + 17.00001101, + 3.169925, + 15.00004403, + 0.169925, + 4.08746284, ], rtol=1e-5, atol=1e-5, @@ -1052,21 +1052,21 @@ def _create_schedule(): 1.0, 0.0, 0.0, - 22.00000034396526, - 20.000001375860553, - 20.000001375860553, - 14.000088052430122, + 22.00000034, + 19.85798251, + 20.00000138, + 13.85807816, 1.0, 0.0, 0.0, - 15.000044026886828, - 20.17555076886471, - 2.321928094887362, - 20.000001375860553, - 18.00000550343433, - 18.00000550343433, - 12.0003521774803, - 4.087462841250339, + 15.00004403, + 20.04456622, + 2.32192809, + 20.00000138, + 17.85798707, + 18.0000055, + 11.8583696, + 4.08746284, ], rtol=1e-5, atol=1e-5, @@ -1078,20 +1078,20 @@ def _create_schedule(): 0.0, 1.0, 0.0, - 22.00000034396526, - 9.002815015607053, - 20.000001375860553, - 3.169925001442312, + 22.00000034, + 7.01122726, + 20.00000138, + 4.08746284, 1.0, 0.0, 0.0, 3.169925001442312, - 9.61654884377899, + 4.08746284, 8.005624549193879, 14.000088052430122, - 1.584962500721156, - 12.0003521774803, - 0.044394119358453436, + 0.5849625, + 12.00035218, + 0.08746284, 4.087462841250339, ], rtol=1e-5, From c404a13ee0b687bc360342d28b7fb4eab8def598 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 18 Oct 2025 02:51:22 +0800 Subject: [PATCH 8/8] lint fix --- src/arith/const_int_bound.cc | 3 ++- src/arith/int_set.cc | 11 ++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index e6b5a8e16be6..7e1d8fb3fb89 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -130,7 +130,8 @@ class ConstIntBoundAnalyzer::Impl auto it = var_map_.find(var); if (it != var_map_.end()) { ICHECK(it->second == info) - << "Trying to update var \'" << var << "\'" << " with a different const bound: " + << "Trying to update var \'" << var << "\'" + << " with a different const bound: " << "original=" << ConstIntBound(it->second.min_value, it->second.max_value) << ", new=" << ConstIntBound(info.min_value, info.max_value); } diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 5cb47109d474..1433ceb70fc0 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -676,13 +676,13 @@ void IntSetAnalyzer::Impl::Update(const Var& var, const IntSet& info, bool can_o ICHECK(ExprDeepEqual()(old_info.min(), info.min())) << "Trying to update var \'" << var << "\'" - << " with a different minimum value: " << "original=" << old_info.min() - << ", new=" << info.min(); + << " with a different minimum value: " + << "original=" << old_info.min() << ", new=" << info.min(); ICHECK(ExprDeepEqual()(old_info.max(), info.max())) << "Trying to update var \'" << var << "\'" - << " with a different maximum value: " << "original=" << old_info.max() - << ", new=" << info.max(); + << " with a different maximum value: " + << "original=" << old_info.max() << ", new=" << info.max(); } } dom_map_.Set(var, info); @@ -1230,7 +1230,8 @@ ffi::Array EstimateRegionUpperBound(const ffi::Array& region, TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << "IntervalSet" << "[" << op->min_value << ", " << op->max_value << ']'; + p->stream << "IntervalSet" + << "[" << op->min_value << ", " << op->max_value << ']'; }); TVM_FFI_STATIC_INIT_BLOCK() {