diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 3a71e5eb5fbf..6ca3ba9cfd55 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -112,7 +112,7 @@ class ConstIntBoundAnalyzer { * \param expr The expression of interest. * \return the result of the analysis. */ - ConstIntBound operator()(const PrimExpr& expr); + TVM_DLL ConstIntBound operator()(const PrimExpr& expr); /*! * \brief analyze the expr with the intermediate memorized to avoid redundant computation @@ -120,8 +120,8 @@ class ConstIntBoundAnalyzer { * \param bound The lookup table to store the intermediate results * \return the result of the analysis. */ - ConstIntBound operator()(const PrimExpr& expr, - std::unordered_map* bound); + TVM_DLL ConstIntBound operator()(const PrimExpr& expr, + std::unordered_map* bound); /*! * \brief Update constant int bound information of var. @@ -130,22 +130,22 @@ class ConstIntBoundAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const ConstIntBound& info, - bool override = false); + TVM_DLL void Update(const Var& var, + const ConstIntBound& info, + bool override = false); /*! * \brief Bind variable to a range. * * \param var The variable. * \param range The range we bind to. */ - void Bind(const Var& var, const Range& range); + TVM_DLL void Bind(const Var& var, const Range& range); private: friend class Analyzer; friend class ConstraintContext; explicit ConstIntBoundAnalyzer(Analyzer* parent); - ~ConstIntBoundAnalyzer(); + TVM_DLL ~ConstIntBoundAnalyzer(); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -212,7 +212,7 @@ class ModularSetAnalyzer { * \param expr The expression of interest. * \return the result of the analysis. */ - ModularSet operator()(const PrimExpr& expr); + TVM_DLL ModularSet operator()(const PrimExpr& expr); /*! * \brief Update constant int bound information of var. * @@ -220,15 +220,15 @@ class ModularSetAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const ModularSet& info, - bool override = false); + TVM_DLL void Update(const Var& var, + const ModularSet& info, + bool override = false); private: friend class Analyzer; friend class ConstraintContext; explicit ModularSetAnalyzer(Analyzer* parent); - ~ModularSetAnalyzer(); + TVM_DLL ~ModularSetAnalyzer(); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -252,7 +252,7 @@ class RewriteSimplifier { * \param expr The expression of interest. * \return the result of the analysis. */ - PrimExpr operator()(const PrimExpr& expr); + TVM_DLL PrimExpr operator()(const PrimExpr& expr); /*! * \brief Update binding of var to a new expression. @@ -261,9 +261,9 @@ class RewriteSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, + const PrimExpr& new_expr, + bool override = false); std::function EnterConstraint(const PrimExpr& constraint); @@ -272,7 +272,7 @@ class RewriteSimplifier { friend class ConstraintContext; friend class CanonicalSimplifier; explicit RewriteSimplifier(Analyzer* parent); - ~RewriteSimplifier(); + TVM_DLL ~RewriteSimplifier(); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -288,7 +288,7 @@ class CanonicalSimplifier { * \param expr The expression of interest. * \return the result of the analysis. */ - PrimExpr operator()(const PrimExpr& expr); + TVM_DLL PrimExpr operator()(const PrimExpr& expr); /*! * \brief Update binding of var to a new expression. @@ -297,15 +297,15 @@ class CanonicalSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, + const PrimExpr& new_expr, + bool override = false); private: friend class Analyzer; friend class ConstraintContext; explicit CanonicalSimplifier(Analyzer* parent); - ~CanonicalSimplifier(); + TVM_DLL ~CanonicalSimplifier(); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -363,12 +363,12 @@ class IntSetAnalyzer { * \param dom_map The domain map to indicate which variable to relax. * \return the result of the analysis. */ - IntSet operator()(const PrimExpr& expr, const Map& dom_map); + TVM_DLL IntSet operator()(const PrimExpr& expr, const Map& dom_map); private: friend class Analyzer; explicit IntSetAnalyzer(Analyzer* parent); - ~IntSetAnalyzer(); + TVM_DLL ~IntSetAnalyzer(); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -384,7 +384,7 @@ class IntSetAnalyzer { * If the analyzer uses memoization, we need to clear the internal * cache when information about a Var has been overridden. */ -class Analyzer { +class TVM_DLL Analyzer { public: /* * Disable copy constructor. diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 592a79fb86d1..4980c41016fd 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -41,39 +41,6 @@ namespace tvm { namespace tir { -/*! - * \brief Simplify the expression. - * \param expr The expression to be simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized statement. - */ -TVM_DLL PrimExpr Simplify(PrimExpr expr, Map vrange = Map()); - -/*! - * \brief Simplify the statement. - * \param stmt The statement to be simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized statement. - */ -Stmt Simplify(Stmt stmt, Map vrange = Map()); - -/*! - * \brief Simplify by applying canonical form. - * \param stmt The statement to be canonically simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized statement. - */ -Stmt CanonicalSimplify(Stmt stmt, - Map vrange = Map()); - -/*! - * \brief Simplify by applying canonical form. - * \param expr The statement to be canonically simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized expression. - */ -TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr, - Map vrange = Map()); /*! * \brief verifies whether the IR stmt or Expr is in SSA form. diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 01d50e86a88a..db1662c7b52a 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -23,8 +23,8 @@ from random import randrange import numpy as np - -from tvm.tir import expr, ir_pass +import tvm.arith +from tvm.tir import expr logger = logging.getLogger('autotvm') @@ -156,7 +156,8 @@ def get_const_int(exp): if isinstance(exp, int): return exp if not isinstance(exp, (expr.IntImm,)): - exp = ir_pass.Simplify(exp) + ana = tvm.arith.Analyzer() + exp = ana.simplify(exp) if not isinstance(exp, (expr.IntImm,)): raise ValueError("Expect value to be constant int") return exp.value @@ -180,7 +181,8 @@ def get_const_tuple(in_tuple): if isinstance(elem, expr.Var): ret.append(elem) elif not isinstance(elem, (expr.IntImm, int)): - elem = ir_pass.Simplify(elem) + ana = tvm.arith.Analyzer() + elem = ana.simplify(elem) if not isinstance(elem, (expr.IntImm)): ret.append(elem) else: diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 5c92965a0130..35700badb04b 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -287,6 +287,7 @@ def _build_for_device(input_mod, target, target_host): lambda f: "calling_conv" in f.attrs and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH), tvm.tir.transform.LowerWarpMemory(), + tvm.tir.transform.Simplify(), tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerIntrin()]) mod_dev = opt_device(mod_mixed) diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 107f51b8bbcc..765efa0b976c 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -29,10 +29,10 @@ import tvm.tir import tvm.te import tvm.te._ffi_api +import tvm.arith from tvm.tir import expr as _expr from tvm.tir import stmt as _stmt -from tvm.tir import ir_pass as _ir_pass from tvm.te.tensor import Tensor, Operation from tvm.tir import all as _all from tvm.tir import any as _any @@ -160,6 +160,7 @@ def __init__(self, args, usage, symbols, closure_vars, func_name=None): self.outputs = [] # Output tensors' name self.side_effect = set() # Tensors with side effects self.parsed_body = None # The parsed HalideIR body + self.analyzer = tvm.arith.Analyzer() self.returned = False # If this function has a valid return @@ -326,7 +327,7 @@ def visit_Assign(self, node): _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") lhs = node.targets[0] if isinstance(rhs, _expr.PrimExpr): - rhs = _ir_pass.Simplify(rhs) + rhs = self.analyzer.simplify(rhs) if isinstance(lhs, ast.Name): #TODO: support defined intermediate buffer later lhs_ = lhs @@ -410,7 +411,7 @@ def visit_With(self, node): def visit_If(self, node): - cond = _ir_pass.CanonicalSimplify(self.visit(node.test)) + cond = self.analyzer.simplify(self.visit(node.test)) # Return no IfThenElse if proven if isinstance(cond, _expr.IntImm): @@ -501,8 +502,8 @@ def visit_For(self, node): _name = node.target.id if isinstance(for_type, tuple): - low = _ir_pass.CanonicalSimplify(low) - ext = _ir_pass.CanonicalSimplify(ext) + low = self.analyzer.simplify(low) + ext = self.analyzer.simplify(ext) _internal_assert(isinstance(low, _expr.ConstExpr) and isinstance(ext, _expr.ConstExpr), \ "Const range should start from a const " + \ diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 0f50636d68d8..5a3d394c098f 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -20,6 +20,8 @@ import logging import numpy as np import tvm +import tvm.arith +import tvm.tir import tvm._ffi @@ -168,4 +170,23 @@ def compare_derivative(j, n_der, grad): x_name, grad.shape, dist, max_diff, avg_diff) +def assert_prim_expr_equal(lhs, rhs): + """Assert lhs and rhs equals to each iother. + + Parameters + ---------- + lhs : tvm.tir.PrimExpr + The left operand. + + rhs : tvm.tir.PrimExpr + The left operand. + """ + ana = tvm.arith.Analyzer() + res = ana.simplify(lhs - rhs) + equal = isinstance(res, tvm.tir.IntImm) and res.value == 0 + if not equal: + raise ValueError("{} and {} are not equal".format(lhs, rhs)) + + + tvm._ffi._init_api("testing", __name__) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 0c4c36888eb5..4dd541e9bdbb 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -21,7 +21,6 @@ from . import stmt as _stmt from . import expr as _expr -from . import ir_pass as _pass class WithScope(object): @@ -212,7 +211,7 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): self.nidx += 1 self._seq_stack.append([]) loop_var = _expr.Var(name, dtype=dtype) - extent = end if begin == 0 else _pass.Simplify(end - begin) + extent = end if begin == 0 else (end - begin) def _exit_cb(): if for_type == "serial": for_type_id = 0 diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index cc9c745a24b8..58723170b3ca 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -207,8 +207,9 @@ bool DetectClipBound( return false; } LinearEqEntry ret; + Analyzer analyzer; if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; - ret.coeff = Simplify(ret.coeff); + ret.coeff = analyzer.Simplify(ret.coeff); IntervalEntry& p = (*bmap)[var.get()]; if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift @@ -254,14 +255,15 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { for (PrimExpr cond : splits) { if (!DetectClipBound(cond, &rmap)) return Array(); } + Analyzer analyzer; Array ret; for (Var v : vars) { IntervalEntry e = rmap[v.get()]; if (e.min_value.defined()) { - e.min_value = Simplify(e.min_value); + e.min_value = analyzer.Simplify(e.min_value); } if (e.max_value.defined()) { - e.max_value = Simplify(e.max_value); + e.max_value = analyzer.Simplify(e.max_value); } ret.push_back(e.min_value); ret.push_back(e.max_value); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 8c5afb1be8b5..027259a4d225 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -570,11 +570,12 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, // TODO(tqchen): revisit IntSet interface as well. Range IntSet::cover_range(Range max_range) const { IntSet temp; + Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); CHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { return Range::make_by_min_extent( - s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value)); + s_int->min_value, analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); } return max_range; } @@ -607,26 +608,30 @@ bool IntSet::is_single_point() const { } bool IntSet::can_prove_positive() const { + Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_positive_const(tir::Simplify(s_int->min_value))); + return (s_int && is_positive_const(analyzer.Simplify(s_int->min_value))); } bool IntSet::can_prove_negative() const { + Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_negative_const(tir::Simplify(s_int->max_value))); + return (s_int && is_negative_const(analyzer.Simplify(s_int->max_value))); } bool IntSet::can_prove_non_positive() const { + Analyzer analyzer; if (const auto* s_int = (*this).as()) { - auto max = tir::Simplify(s_int->max_value); + auto max = analyzer.Simplify(s_int->max_value); return is_zero(max) || is_negative_const(max); } return false; } bool IntSet::can_prove_non_negative() const { + Analyzer analyzer; if (const IntervalSetNode* s_int = (*this).as()) { - auto min = tir::Simplify(s_int->min_value); + auto min = analyzer.Simplify(s_int->min_value); return is_zero(min) || is_positive_const(min); } return false; @@ -669,8 +674,8 @@ IntSet IntSet::interval(PrimExpr min, PrimExpr max) { } // Range related code -inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) { - return is_zero(tir::Simplify(lhs - rhs)); +inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) { + return is_zero(analyzer->Simplify(lhs - rhs)); } IntSet IntSet::range(Range r) { @@ -685,8 +690,9 @@ bool IntSet::match_range(const Range& b) const { const IntSet& a = *this; const IntervalSetNode* a_int = a.as(); if (!a_int) return false; - return ProveEqual(a_int->min_value, b->min) && - ProveEqual(a_int->max_value, b->extent + b->min - 1); + Analyzer ana; + return ProveEqual(&ana, a_int->min_value, b->min) && + ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); } IntSet Union(const Array& sets) { @@ -697,8 +703,8 @@ IntSet Union(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Union(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(tir::Simplify(x->min_value), - tir::Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), + ana.Simplify(x->max_value)); } IntSet Intersect(const Array& sets) { @@ -709,8 +715,8 @@ IntSet Intersect(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Intersect(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(tir::Simplify(x->min_value), - tir::Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), + ana.Simplify(x->max_value)); } Map ConvertDomMap(const Map& dom_map) { @@ -758,7 +764,7 @@ IntSet EvalSet(Range r, IntervalSetEvaluator m(&ana, dom_map); // Simplifying first can give tighter bounds if r->min and r->extent share variables PrimExpr sum = r->min + r->extent - 1; - auto res = m.Eval(IntervalSet(r->min, Simplify(sum))); + auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); return std::move(res); } diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index c3802b1a63f4..7e2ef701265c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -236,6 +236,7 @@ split_dev_host_funcs(IRModule mod_mixed, }), BindTarget(target), tir::transform::LowerWarpMemory(), + tir::transform::Simplify(), tir::transform::LowerIntrin(), tir::transform::LowerDeviceStorageAccessInfo(), }; diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index f9653e24b1b9..e2e7f4994349 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -22,9 +22,10 @@ * \brief A set of utilities and common functionality * for type relations. */ +#include +#include #include #include -#include #include #include "./type_relations.h" @@ -48,7 +49,8 @@ bool EqualCheck(const IndexExpr& lhs, return pdiff[0] == 0; } // symbolic - diff = tvm::tir::CanonicalSimplify(diff); + tvm::arith::Analyzer ana; + diff = ana.Simplify(diff); if (const int64_t* pdiff = tir::as_const_int(diff)) { return pdiff[0] == 0; } diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 1d8004e9938f..d4631aaa8023 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -414,7 +414,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = tir::Simplify( + PrimExpr vec_index = analyzer_->Simplify( ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); @@ -492,7 +492,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = tir::Simplify( + PrimExpr vec_index = analyzer_->Simplify( ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index 1a324588537f..9ebb89ee95a6 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -24,9 +24,11 @@ * The result Jacobian shape will be (Y.shape, X.shape) */ #include +#include #include #include -#include +#include + #include #include "ad_util.h" @@ -264,7 +266,7 @@ class JacobianMutator : public ExprMutator { CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); // Also simplify the resulting combiner // (mostly to get rid of unused components, e.g., the original expressions) - return Simplify( + return analyzer_.Simplify( ReduceNode::make(new_combiner, new_source, new_op->axis, new_op->condition, new_op->value_index)); } @@ -302,6 +304,7 @@ class JacobianMutator : public ExprMutator { Tensor input_; Array indices_; Var input_var_; + arith::Analyzer analyzer_; }; PrimExpr Derivative(const PrimExpr& expr, const Var& var) { @@ -341,11 +344,11 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { // Differentiate wrt input[input_indices] input_indices.push_back(new_v); } - + arith::Analyzer analzyer; // Compute Jacobian PrimExpr new_body = Jacobian( Substitute(op->body[output->value_index], vmap), input, input_indices); - new_body = Simplify(new_body); + new_body = analzyer.Simplify(new_body); int value_index = 0; Array new_bodies; diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 2ee5b273d4f6..1916b4a4823e 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -39,10 +39,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_REGISTER_NODE_TYPE(ScanOpNode); -inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) { - return is_zero(tir::Simplify(lhs - rhs)); -} - int ScanOpNode::num_outputs() const { return static_cast(update.size()); } @@ -77,6 +73,10 @@ Operation ScanOpNode::make(std::string name, auto n = make_object(); CHECK_EQ(init.size(), update.size()); CHECK_EQ(init.size(), state_placeholder.size()); + arith::Analyzer analyzer; + auto prove_equal = [&](PrimExpr lhs, PrimExpr rhs) { + return is_zero(analyzer.Simplify(lhs - rhs)); + }; for (size_t i = 0; i < init.size(); ++i) { CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); @@ -232,10 +232,11 @@ void ScanOpNode::GatherBound( time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end()); } CHECK(!out_dom_map->count(this->scan_axis)); + arith::Analyzer analyzer; Range sdom = this->scan_axis->dom; Range r = arith::Union(time_dom).cover_range(sdom); (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent( - sdom->min, tir::Simplify(r->extent + r->min - sdom->min)); + sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min)); Map fix_pt = ScanFixPointAnalysis(self); // Update for spatial axis. size_t sp_idx = 0; @@ -260,10 +261,11 @@ Stmt ScanOpNode::BuildRealize( const Stage& stage, const std::unordered_map& dom_map, const Stmt& body) const { + arith::Analyzer analyzer; CHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); Range tdom = Range::make_by_min_extent( - 0, tir::Simplify(sdom->extent + sdom->min)); + 0, analyzer.Simplify(sdom->extent + sdom->min)); Stmt ret = body; size_t sp_idx = 0; for (size_t i = 0; i < update.size(); ++i) { diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 6064f5c4e008..b66406969c76 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -222,6 +222,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { compute_intrin_iter_space->Set(iv->var, vrange); } } + analyzer_.Bind(*compute_intrin_iter_space); // input remap. Array inputs = self->InputTensors(); @@ -234,7 +235,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { // Enable fuzzy matching, to match [1, n, m] to [n, m] e.start = e.region.size() - e.tensor.ndim(); for (size_t j = 0; j < e.start; ++j) { - auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space); + auto canonical_extent = analyzer_.Simplify(e.region[j]->extent); CHECK(is_one(canonical_extent)) << "Tensorize " << intrin->name << ":" << " Input dimension mismatch with tensor intrin " @@ -304,6 +305,8 @@ class TensorIntrinMatcher final : public StmtExprMutator { std::unordered_map var_remap_; // IterVar remap. std::unordered_map axis_remap_; + // arith analyzer + arith::Analyzer analyzer_; }; // Try to match tensor dataflow of the stage with the intrinsic @@ -339,11 +342,12 @@ void VerifyTensorizeBody( CHECK(intrin_compute) << "Only support compute intrinsic for now"; CHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch"; + arith::Analyzer ana; + ana.Bind(compute_intrin_iter_space); + for (size_t i = 0; i < body.size(); ++i) { - PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space); - lhs = CanonicalSimplify(lhs, compute_intrin_iter_space); - PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space); - rhs = CanonicalSimplify(rhs, compute_intrin_iter_space); + PrimExpr lhs = ana.Simplify(body[i]); + PrimExpr rhs = ana.Simplify(intrin_compute->body[i]); if (lhs.dtype() != rhs.dtype()) { LOG(FATAL) << "Failed to match the data type with TensorIntrin " diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 4ff8586bccdf..6ed9438ec90f 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -324,6 +324,7 @@ void PassUpDomain(const FuseNode* s, CHECK(dom_map.count(s->outer)); CHECK(dom_map.count(s->inner)); CHECK(dom_map.count(s->fused)); + arith::Analyzer ana; if (fused.match_range(dom_map.at(s->fused))) { *outer = IntSet::range(dom_map.at(s->outer)); @@ -348,15 +349,15 @@ void PassUpDomain(const FuseNode* s, *outer = IntSet::interval( outer_min + indexdiv(fused.min(), inner_extent), outer_min + indexdiv(fused.max(), inner_extent)); - if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) && - is_zero(Simplify(indexmod(fused.min(), fused_extent)))) { + if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) && + is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) { // fused never spans multiple rows, make a tight bounding box // there may be other cases when bounding box could be tightened *inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent), inner_min + indexmod(fused.max(), inner_extent)); } else { // fused may span multiple rows, use full row widths - if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) || - !is_zero(Simplify(indexmod(fused.min(), inner_extent)))) { + if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) || + !is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) { LOG(WARNING) << "fused and original axes are not aligned, this may cause redundant computations"; } diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 57b637df0570..c818218fa65e 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -181,7 +181,7 @@ class SchedulePostProc : public StmtExprMutator { // delete duplicated thread extent attr auto it = thread_extent_scope_.find(op->node.get()); if (it != thread_extent_scope_.end()) { - CHECK(is_zero(tir::Simplify(it->second - op->value))); + CHECK(is_zero(analyzer_.Simplify(it->second - op->value))); return this->VisitStmt(op->body); } else { thread_extent_scope_[op->node.get()] = op->value; @@ -335,6 +335,8 @@ class SchedulePostProc : public StmtExprMutator { std::unordered_map replace_realize_; // replace producer consumer. std::unordered_map replace_op_; + // integer analyzer + arith::Analyzer analyzer_; }; Stmt ScheduleOps( diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 56235591f722..797a8b2b7b88 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -508,7 +508,7 @@ class BufferAnalyser : public StmtExprVisitor { return; } auto index = rel_index[i]; - auto simplified_index = tir::Simplify(index); + auto simplified_index = analyzer_.Simplify(index); index_visitor(simplified_index); } @@ -611,7 +611,7 @@ class BufferAnalyser : public StmtExprVisitor { index_visitor.scaling_factor_ = shape->value; } auto index = rel_index[i]; - auto simplified_index = tir::Simplify(index); + auto simplified_index = analyzer_.Simplify(index); index_visitor(simplified_index); } } @@ -645,7 +645,7 @@ class BufferAnalyser : public StmtExprVisitor { PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); stride = stride + \ indexmod(factor + offset - indexmod(stride, factor), factor); - stride = tir::Simplify(stride); + stride = analyzer_.Simplify(stride); } rstrides.push_back(stride); stride = stride * shape[dim]; @@ -773,6 +773,7 @@ class BufferAnalyser : public StmtExprVisitor { IndexVisitor index_visitor; Tile warp_tile_; Tile thread_tile_; + arith::Analyzer analyzer_; int warp_threads_y_{-1}; bool invalid_{false}; }; @@ -1148,7 +1149,7 @@ class TensorCoreIRMutator : public StmtExprMutator { buffer_node->strides = strides; buffer_node->shape = shape; buffer_node->data_alignment = 1; - buffer_node->elem_offset = Simplify(elem_offset); + buffer_node->elem_offset = analyzer_.Simplify(elem_offset); buffer_node->offset_factor = 1; Buffer buffer(buffer_node); @@ -1184,6 +1185,7 @@ class TensorCoreIRMutator : public StmtExprMutator { std::unordered_map frag_load_; std::unordered_map frag_store_; std::unordered_map bounds_; + arith::Analyzer analyzer_; Tile warp_tile_; int warp_threads_y_{-1}; }; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 6bbf6451b7ac..a7bc822fdd30 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -37,9 +38,9 @@ namespace tir { using IndexMod = tir::FloorModNode; using IndexDiv = tir::FloorDivNode; -Array SimplifyArray(Array array) { +Array SimplifyArray(arith::Analyzer* ana, Array array) { for (size_t i = 0; i < array.size(); ++i) { - array.Set(i, tir::Simplify(array[i])); + array.Set(i, ana->Simplify(array[i])); } return array; } @@ -185,14 +186,14 @@ inline void MergeMulModInsertElements(const std::vector& eles, // The search will be performed repeatively until no pattern is found. // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized -inline PrimExpr MergeMulMod(const PrimExpr &base) { +inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { using namespace tir; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and // a list that contain all the elements that match Mod. // The elements in the Mod will be used to match against the elements in Mul. // The result will then be split and pushed back to these two lists. - PrimExpr simplified_base = Simplify(base); + PrimExpr simplified_base = analyzer->Simplify(base); std::vector eles = ExprSplitAddition(simplified_base); std::list mult_exprs; std::list > mod_exprs; @@ -254,6 +255,7 @@ inline PrimExpr MergeMulMod(const PrimExpr &base) { // We also perform optimization to simplify the indexing expression. inline PrimExpr ElemOffset(const BufferNode* n, Array index) { PrimExpr base = n->elem_offset; + arith::Analyzer ana; if (n->strides.size() == 0) { // Scalar case if (n->shape.size() == 0 && index.size() == 1) { @@ -265,7 +267,7 @@ inline PrimExpr ElemOffset(const BufferNode* n, Array index) { if (index.size() > 0) { PrimExpr offset = index[0]; for (size_t i = 1; i < index.size(); ++i) { - offset = MergeMulMod(offset * n->shape[i] + index[i]); + offset = MergeMulMod(&ana, offset * n->shape[i] + index[i]); } base = base + offset; } @@ -273,12 +275,12 @@ inline PrimExpr ElemOffset(const BufferNode* n, Array index) { } else { CHECK_EQ(n->strides.size(), index.size()); if (is_zero(base)) { - base = MergeMulMod(index[0] * n->strides[0]); + base = MergeMulMod(&ana, index[0] * n->strides[0]); } else { - base = MergeMulMod(base + index[0] * n->strides[0]); + base = MergeMulMod(&ana, base + index[0] * n->strides[0]); } for (size_t i = 1; i < index.size(); ++i) { - base = MergeMulMod(base + index[i] * n->strides[i]); + base = MergeMulMod(&ana, base + index[i] * n->strides[i]); } } return base; @@ -353,8 +355,9 @@ Buffer Buffer::MakeStrideView() const { Buffer Buffer::MakeSlice(Array begins, Array extents) const { const BufferNode* n = operator->(); - begins = SimplifyArray(begins); - PrimExpr elem_offset = tir::Simplify(ElemOffset(n, begins)); + arith::Analyzer ana; + begins = SimplifyArray(&ana, begins); + PrimExpr elem_offset = ana.Simplify(ElemOffset(n, begins)); Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; @@ -363,7 +366,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const for (size_t i = 0; i < extents.size(); ++i) { if (!can_relax) { if (!is_zero(begins[i]) || - !is_zero(tir::Simplify(extents[i] - n->shape[i]))) { + !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { need_stride = true; } } diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 85842a0b9dcf..fb63fed623cf 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -24,6 +24,8 @@ #include #include #include +#include + #include namespace tvm { @@ -253,15 +255,16 @@ inline bool GetStoreRule(Array* rule, } inline Array TransformIndex(const Array& src_index, - const Array& src_axis, - const Array& transform_rule) { + const Array& src_axis, + const Array& transform_rule) { + arith::Analyzer ana; Array result; std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; } for (PrimExpr rule : transform_rule) { - result.push_back(tir::Simplify(tir::Substitute(rule, bind_map))); + result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); } return result; } @@ -284,9 +287,10 @@ Array BijectiveLayout::BackwardIndex(const Array& dst_index) } inline Array TransformShape(const Array& src_shape, - const Array& src_axis, - const Array& target_axis, - const Array& transform_rule) { + const Array& src_axis, + const Array& target_axis, + const Array& transform_rule) { + arith::Analyzer ana; CHECK_EQ(src_shape.size(), src_axis.size()); // bind variables for original axes // for major-axis, bind the corresponding size @@ -329,7 +333,7 @@ inline Array TransformShape(const Array& src_shape, if (symbolic_var_set.count(i)) { result.push_back(tir::AnyNode::make()); } else { - result.push_back(tir::Simplify(tir::Substitute(rule, bind_map))); + result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); } } } diff --git a/src/tir/pass/arg_binder.cc b/src/tir/pass/arg_binder.cc index c684b9e68038..51a6d8bf5fed 100644 --- a/src/tir/pass/arg_binder.cc +++ b/src/tir/pass/arg_binder.cc @@ -31,10 +31,11 @@ namespace tvm { namespace tir { -void BinderAddAssert(PrimExpr cond, +void BinderAddAssert(arith::Analyzer* ana, + PrimExpr cond, const std::string& arg_name, std::vector* asserts) { - PrimExpr scond = Simplify(cond); + PrimExpr scond = ana->Simplify(cond); if (is_zero(scond)) { LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " << " on argument " << arg_name; @@ -65,10 +66,10 @@ bool ArgBinder::Bind_(const PrimExpr& arg, } return true; } else { - BinderAddAssert(it->second == value, arg_name, &asserts_); + BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_); } } else { - BinderAddAssert(arg == value, arg_name, &asserts_); + BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_); } return false; } @@ -121,7 +122,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, PrimExpr offset = value->elem_offset; PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(truncmod(offset, factor) == zero, + BinderAddAssert(&analyzer_, + truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_); } } @@ -130,7 +132,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch"; size_t diff = value->shape.size() - arg->shape.size(); for (size_t i = 0; i < diff; ++i) { - CHECK(is_one(Simplify(value->shape[i]))) + CHECK(is_one(analyzer_.Simplify(value->shape[i]))) << "Argument " << arg_name << " shape mismatch" << arg->shape << " vs " << value->shape; } @@ -269,7 +271,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); - stride = Simplify(stride * buffer->shape[k]); + stride = analyzer_.Simplify(stride * buffer->shape[k]); } } else { std::ostringstream stride_null_err_msg; @@ -304,7 +306,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, PrimExpr offset = buffer->elem_offset; PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_); + BinderAddAssert(&analyzer_, + truncmod(offset, factor) == zero, + arg_name + ".elem_offset", &asserts_); } } } diff --git a/src/tir/pass/arg_binder.h b/src/tir/pass/arg_binder.h index dfeb82853529..0ff51e8c98f1 100644 --- a/src/tir/pass/arg_binder.h +++ b/src/tir/pass/arg_binder.h @@ -26,6 +26,8 @@ #include #include +#include + #include #include #include @@ -153,6 +155,8 @@ class ArgBinder { Map def_handle_dtype_; /*! \brief asserts generated */ std::vector asserts_; + /*! \brief internal analyzer. */ + arith::Analyzer analyzer_; }; } // namespace tir } // namespace tvm diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 60b5bd9a7f9c..ea762453cc9c 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -32,40 +32,6 @@ namespace tvm { namespace tir { -TVM_REGISTER_GLOBAL("ir_pass.Simplify") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - if (args.size() > 1) { - *ret = Simplify(args[0].operator Stmt(), args[1]); - } else { - *ret = Simplify(args[0].operator Stmt()); - } - } else { - if (args.size() > 1) { - *ret = Simplify(args[0].operator PrimExpr(), args[1]); - } else { - *ret = Simplify(args[0].operator PrimExpr()); - } - } - }); - -TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - if (args.size() > 1) { - *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]); - } else { - *ret = CanonicalSimplify(args[0].operator Stmt()); - } - } else { - if (args.size() > 1) { - *ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]); - } else { - *ret = CanonicalSimplify(args[0].operator PrimExpr()); - } - } - }); - TVM_REGISTER_GLOBAL("ir_pass.Substitute") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef()) { diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 5e40eb2d9025..d409ffc4a15d 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include "../../arith/pattern_match.h" @@ -125,7 +126,7 @@ class CopyIntrinInjector : public StmtMutator { DataType t = loop_vars[i].dtype(); PrimExpr svalue = src_shape[i]; if (min_value.defined()) { - PrimExpr pbefore = Simplify(MaxNode::make(min_value, make_zero(t))); + PrimExpr pbefore = analyzer_.Simplify(MaxNode::make(min_value, make_zero(t))); src_elem_offset = src_elem_offset + pbefore * load_strides[i]; svalue = svalue - pbefore; pad_before.push_back(pbefore); @@ -133,16 +134,16 @@ class CopyIntrinInjector : public StmtMutator { pad_before.push_back(make_zero(t)); } if (max_value.defined()) { - PrimExpr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1), - make_zero(t))); + PrimExpr pafter = analyzer_.Simplify( + max(loops[i]->extent - max_value - make_const(t, 1), make_zero(t))); svalue = svalue - pafter; pad_after.push_back(pafter); } else { pad_after.push_back(make_zero(t)); } - src_shape.Set(i, Simplify(svalue)); + src_shape.Set(i, analyzer_.Simplify(svalue)); } - src_elem_offset = Simplify(src_elem_offset); + src_elem_offset = analyzer_.Simplify(src_elem_offset); } CHECK_EQ(load_strides.size(), store_strides.size()); CHECK_EQ(load_strides.size(), loop_var_size + 1); @@ -189,6 +190,8 @@ class CopyIntrinInjector : public StmtMutator { const PackedFunc& flower_copy_fromto_; // Storage scope std::unordered_map storage_scope_; + // arith analyzer + arith::Analyzer analyzer_; }; Stmt InjectCopyIntrin(Stmt stmt, diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 9fa72303e2d8..a77d529e7764 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -24,11 +24,9 @@ #include #include #include +#include #include #include - -#include - #include "../pass/ir_util.h" #include "../../runtime/thread_storage_scope.h" @@ -123,8 +121,8 @@ class StorageAccessInfoLower : public StmtExprMutator { int dtype_bits = dtype.bits() * dtype.lanes(); CHECK_EQ(info->unit_bits % dtype_bits, 0); return cast(ptr_type, - tir::Simplify(offset / make_const( - offset.dtype(), info->unit_bits / dtype_bits))); + analyzer_.Simplify(offset / make_const( + offset.dtype(), info->unit_bits / dtype_bits))); } // The storage entry. struct StorageEntry { @@ -137,6 +135,8 @@ class StorageAccessInfoLower : public StmtExprMutator { }; // The storage scope of each buffer std::unordered_map storage_info_; + // analyzer + arith::Analyzer analyzer_; }; Stmt LowerStorageAccessInfo(Stmt stmt) { diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 655a0074c7fd..85744d11a9e5 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -313,6 +313,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } return ret; } + // The local buffer index. + PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) { + if (!is_zero(group_index)) { + return analyzer_.Simplify(group_index * reduce_extent + reduce_index); + } else { + return reduce_index; + } + } // sync thread op. static Stmt SyncThread(const std::string& sync) { return EvaluateNode::make( @@ -320,14 +328,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { {StringImmNode::make(sync)}, CallNode::Intrinsic)); } - // The local buffer index. - static PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) { - if (!is_zero(group_index)) { - return tir::Simplify(group_index * reduce_extent + reduce_index); - } else { - return reduce_index; - } - } // The warp size of the device. int warp_size_{1}; @@ -338,6 +338,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::unordered_map load_remap_; // Allocate remap std::unordered_map alloc_remap_; + // Internal analyzer + arith::Analyzer analyzer_; }; namespace transform { diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 0aee3c284422..ac08e6fd07a4 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -371,7 +371,6 @@ class WarpMemoryRewriter : private StmtMutator { BindVarBoundInfo binder(&analyzer_); binder(stmt); stmt = operator()(std::move(stmt)); - stmt = CanonicalSimplify(stmt); return stmt; } diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index ecfa25e28975..1e4fd73f6d0c 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -98,36 +98,6 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } // namespace arith namespace tir { - -Stmt CanonicalSimplify(Stmt stmt, Map vrange) { - arith::Analyzer analyzer; - for (auto kv : vrange) { - analyzer.Bind(kv.first, kv.second); - } - return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt)); -} - -PrimExpr CanonicalSimplify(PrimExpr expr, Map vrange) { - arith::Analyzer analyzer; - for (auto kv : vrange) { - analyzer.Bind(kv.first, kv.second); - } - return analyzer.canonical_simplify(expr); -} - -PrimExpr Simplify(PrimExpr expr, Map vrange) { - arith::Analyzer analyzer; - for (auto kv : vrange) { - analyzer.Bind(kv.first, kv.second); - } - expr = analyzer.Simplify(expr); - return expr; -} - -Stmt Simplify(Stmt stmt, Map vrange) { - return CanonicalSimplify(std::move(stmt), vrange); -} - namespace transform { Pass Simplify() { diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index c13879c31c64..f960306f2ee8 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -625,7 +625,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (!divided) { combo_size = combo_size + make_const(DataType::Int(32), 1); } - combo_size = tir::Simplify(combo_size); + combo_size = analyzer_.Simplify(combo_size); e->new_alloc = AllocateNode::make( e->alloc_var, alloc_type, {combo_size}, const_true(), EvaluateNode::make(0)); diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 27c39d4c18aa..9ff5429329c5 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -25,9 +25,10 @@ #include #include #include -#include #include #include +#include +#include #include #include #include @@ -160,7 +161,7 @@ class LoopUnroller : public StmtExprMutator { // returns the extent of the loop if it's a constant integer, otherwise return -1 int GetExtent(const ForNode* op) { // constant folding. - PrimExpr extent = tir::Simplify(op->extent); + PrimExpr extent = analyzer_.Simplify(op->extent); const IntImmNode *v1 = extent.as(); int value = -1; // integers that do not fit in int32_t are treated as symbolic, @@ -184,6 +185,8 @@ class LoopUnroller : public StmtExprMutator { int unroll_depth_{0}; // Number of total steps unrolled int step_count_{0}; + // analyzer + arith::Analyzer analyzer_; }; diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/arith_simplify_test.cc similarity index 79% rename from tests/cpp/ir_simplify_test.cc rename to tests/cpp/arith_simplify_test.cc index 69cf1298a320..f4c259fca342 100644 --- a/tests/cpp/ir_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -19,35 +19,38 @@ #include #include -#include +#include #include -TEST(IRSIMPLIFY, MinMax) { +TEST(Simplify, MinMax) { + tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)) ; - auto e1s = tvm::tir::CanonicalSimplify(e1); + auto e1s = ana.canonical_simplify(e1); CHECK(tvm::tir::is_zero(e1s)); auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1)); - auto e2s = tvm::tir::CanonicalSimplify(e2); + auto e2s = ana.canonical_simplify(e2); CHECK(tvm::tir::is_zero(e2s)); } -TEST(IRSIMPLIFY, Mul) { +TEST(Simplify, Mul) { + tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); auto e = (x * x) - (x * x) ; - auto es = tvm::tir::CanonicalSimplify(e); + auto es = ana.canonical_simplify(e); CHECK(tvm::tir::is_zero(es)); } -TEST(IRSIMPLIFY, Mod) { +TEST(Simplify, Mod) { + tvm::arith::Analyzer ana; auto x = tvm::Integer(10); auto y = tvm::Integer(12); // Mod::make is used instead of % to avoid constant folding during // calling operator%(x,y). Mod::make doesn't try constant folding, // and therefore, the constant folding will be attempted in CanonicalSimplify - auto mod = tvm::tir::CanonicalSimplify(tvm::tir::ModNode::make(x, y)); - auto es = tvm::tir::CanonicalSimplify(mod - x); + auto mod = ana.canonical_simplify(tvm::tir::ModNode::make(x, y)); + auto es = ana.canonical_simplify(mod - x); CHECK(tvm::tir::is_zero(es)); } int main(int argc, char ** argv) { diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 5baabd16c615..6efb67b19bad 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -18,13 +18,6 @@ from tvm import te -def assert_expr_equal(a, b): - res = tvm.tir.ir_pass.Simplify(a - b) - equal = isinstance(res, tvm.tir.IntImm) and res.value == 0 - if not equal: - raise ValueError("{} and {} are not equal".format(a, b)) - - def test_deduce(): a = te.var('a') b = te.var('b') @@ -41,32 +34,32 @@ def test_deduce(): e0 = (-b)*a+c-d res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = fdiv(d - c, b*-1) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) e0 = d*a+c-d res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = fdiv(d-c, d) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) e1 = (a*4+b < c) res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) ans1 = fdiv(c-1-b, 4) - assert_expr_equal(res1.max_value, ans1) + tvm.testing.assert_prim_expr_equal(res1.max_value, ans1) # expression containing variable a is on rhs e1 = (c > a*4+b) res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res1.max_value, ans1) + tvm.testing.assert_prim_expr_equal(res1.max_value, ans1) e2 = (tvm.te.max(5, a * 4) < 0) @@ -83,15 +76,15 @@ def test_deduce(): e3 = (-b)+a*c-d res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) ans3 = fdiv(2,c)+1 - assert str(tvm.tir.ir_pass.Simplify(res3.min_value)) == str(ans3) + tvm.testing.assert_prim_expr_equal(res3.min_value, ans3) res3 = tvm.arith.deduce_bound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) - assert str(tvm.tir.ir_pass.Simplify(res3.min_value)) == str(ans3) + tvm.testing.assert_prim_expr_equal(res3.min_value, ans3) # tests for `EQ` op res4 = tvm.arith.deduce_bound(a, a == b, {}, {}) - assert_expr_equal(res4.max_value, b) - assert_expr_equal(res4.min_value, b) + tvm.testing.assert_prim_expr_equal(res4.max_value, b) + tvm.testing.assert_prim_expr_equal(res4.min_value, b) # Unsatisfiable `EQ`, variable as one of the Operand res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s}) @@ -100,20 +93,20 @@ def test_deduce(): # variable `a` on the RHS side res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {}) - assert_expr_equal(res6.max_value, 10) - assert_expr_equal(res6.min_value, 10) + tvm.testing.assert_prim_expr_equal(res6.max_value, 10) + tvm.testing.assert_prim_expr_equal(res6.min_value, 10) # Add, Sub in `EQ` e4 = ((a - c) == (b + d)) ans4 = (b + d + c) res7 = tvm.arith.deduce_bound(a, e4, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res7.max_value, ans4) - assert_expr_equal(res7.min_value, ans4) + tvm.testing.assert_prim_expr_equal(res7.max_value, ans4) + tvm.testing.assert_prim_expr_equal(res7.min_value, ans4) # Satisfiable Mul in `EQ` with negative sign res8 = tvm.arith.deduce_bound(a, (5 * a == -10), {}, {}) - assert_expr_equal(res8.max_value, -2) - assert_expr_equal(res8.min_value, -2) + tvm.testing.assert_prim_expr_equal(res8.max_value, -2) + tvm.testing.assert_prim_expr_equal(res8.min_value, -2) # Unsatisfiable Mul in `EQ` e5 = (4 * a == b) @@ -158,21 +151,22 @@ def test_basic(a1, a2, coff): res1 = tvm.arith.deduce_bound(a, e0<17, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) < 17, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32") < e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) > 17, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 + + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) <= 17, True) res1 = tvm.arith.deduce_bound(a, e0>=17, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) >= 17, True) test_basic(0, 4, 4) test_basic(1, 5, 4) @@ -190,21 +184,21 @@ def test_complex(a1, a2, coff): res1 = tvm.arith.deduce_bound(a, e0<63, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) < 63, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) <= 63, True) res1 = tvm.arith.deduce_bound(a, e0>63, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) > 63, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) >= 63, True) test_complex(0, 4, 4) test_complex(0, 4, -4) diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py index d6953713f14b..129237a8c58b 100644 --- a/tests/python/unittest/test_arith_detect_clip_bound.py +++ b/tests/python/unittest/test_arith_detect_clip_bound.py @@ -23,15 +23,15 @@ def test_basic(): c = te.var("c") m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a]) - assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], b * 6 - 1) assert m[0].value == 2 m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a, b]) assert len(m) == 0 m = tvm.arith.detect_clip_bound(tvm.tir.all(a + 10 * c <= 20, b - 1 > 0), [a, b]) - assert tvm.tir.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0 - assert tvm.tir.ir_pass.Simplify(m[2] - 2).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], 20 - 10 * c) + tvm.testing.assert_prim_expr_equal(m[2], 2) if __name__ == "__main__": diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 278581d0cacd..82153ab5207e 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -22,14 +22,14 @@ def test_basic(): b = te.var("b") m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a]) assert m[0].value == 4 - assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7) m = tvm.arith.detect_linear_equation(a * 4 * (a+1) + b * 6 + 7, [a]) assert len(m) == 0 m = tvm.arith.detect_linear_equation(a * 4 + (a+1) + b * 6 + 7, [a]) assert m[0].value == 5 - assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7 + 1) m = tvm.arith.detect_linear_equation(a * b + 7, [a]) assert m[0] == b @@ -39,13 +39,15 @@ def test_basic(): m = tvm.arith.detect_linear_equation(b * 7, []) assert len(m) == 1 - assert tvm.tir.ir_pass.Simplify(m[0] - b * 7).value == 0 + tvm.testing.assert_prim_expr_equal(m[0], b * 7) def test_multivariate(): v = [te.var("v%d" % i) for i in range(4)] b = te.var("b") m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v) - assert(tvm.tir.analysis.expr_deep_equal(tvm.tir.ir_pass.Simplify(m[0]), b + 5)) + + tvm.testing.assert_prim_expr_equal(m[0], b + 5) + assert(m[1].value == 8) m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v) @@ -61,11 +63,12 @@ def test_multivariate(): m = tvm.arith.detect_linear_equation((v[0] - v[1]), [v[2]]) assert(m[0].value == 0) - assert(tvm.tir.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) + + tvm.testing.assert_prim_expr_equal(m[1], v[0] - v[1]) m = tvm.arith.detect_linear_equation((v[0] - v[1]), []) assert(len(m) == 1) - assert(tvm.tir.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0) + tvm.testing.assert_prim_expr_equal(m[0], v[0] - v[1]) if __name__ == "__main__": test_basic() diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py index 45f8fc10aaf0..645b1a2b537c 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -55,12 +55,13 @@ def check_bruteforce(bool_expr, vranges, cond=None): counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) raise AssertionError("Expression {}\nis not true on {}\n" "Counterexample: {}" - .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex)) + .format(tir.arith.Analyzer().simplify(bool_expr), vranges, counterex)) def check_solution(solution, vranges={}): """Check that solution is a bijective transformation""" def _check_forward(constraints1, constraints2, varmap, backvarmap): + ana = tvm.arith.Analyzer() all_vranges = vranges.copy() all_vranges.update({v: r for v, r in constraints1.ranges.items()}) @@ -68,7 +69,7 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): cond_on_vars = tir.const(1, 'bool') for v in constraints1.variables: # variable mapping is consistent - v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap)) + v_back = ana.simplify(tir.ir_pass.Substitute(varmap[v], backvarmap)) cond_on_vars = te.all(cond_on_vars, v == v_back) # Also we have to check that the new relations are true when old relations are true cond_subst = tir.ir_pass.Substitute( @@ -80,7 +81,7 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): range_cond = te.all(v >= r.min, v < r.min + r.extent) range_cond = tir.ir_pass.Substitute(range_cond, backvarmap) cond_subst = te.all(cond_subst, range_cond) - cond_subst = tir.ir_pass.Simplify(cond_subst) + cond_subst = ana.simplify(cond_subst) check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges, cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 5b4a1c92a7e4..8afd65330f98 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -25,7 +25,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): val = tvm.tir.ir_pass.Substitute(val, var_dict) - val = tvm.tir.ir_pass.Simplify(val) + val = tvm.arith.Analyzer().simplify(val) assert isinstance(val, (tvm.tir.IntImm,)) return val.value diff --git a/tests/python/unittest/test_te_schedule_bound_inference.py b/tests/python/unittest/test_te_schedule_bound_inference.py index edae527c0183..6b6c519c8fa3 100644 --- a/tests/python/unittest/test_te_schedule_bound_inference.py +++ b/tests/python/unittest/test_te_schedule_bound_inference.py @@ -139,19 +139,20 @@ def test_bound_fusesplit1(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) idxdiv = tvm.tir.indexdiv - assert(tvm.tir.ir_pass.Simplify( - bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0) + tvm.testing.assert_prim_expr_equal( + bounds[A1.op.axis[0]].min, idxdiv(xo * split1, l)) expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1) for i in range(1, 6): for j in range(1, 6): for k in range(1, 6): vars = tvm.runtime.convert({split1: tvm.tir.const(i, "int32"), l: tvm.tir.const(j, "int32"), xo.var: tvm.tir.const(k, "int32")}) - comp_ext = tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value - exp_ext = tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(expected_extent, vars)).value - assert(comp_ext == exp_ext) + tvm.testing.assert_prim_expr_equal( + tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars), + tvm.tir.ir_pass.Substitute(expected_extent, vars) + ) - assert(tvm.tir.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0) + tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[1]].extent, l) def test_bound_fusesplit2(): m = te.var("m") @@ -169,10 +170,10 @@ def test_bound_fusesplit2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) vars = tvm.runtime.convert({xo.var: tvm.tir.const(5, "int32")}) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars)).value == 2) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars)).value == 3) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value == 1) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3) + tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars), 2) + tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars), 3) + tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars), 1) + tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars), 3) def test_bound_warp(): diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index dafffed9bd44..ef5b3fd3a44e 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -105,9 +105,10 @@ def check(factor): assert tvm.ir.structural_equal(in_dom.items()[0][1][0].extent, factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[z], out_dom, in_dom, vadd) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(vadd.op.body[0])) + ana.simplify(body[0]), + ana.simplify(vadd.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [x, y, z]) @@ -139,9 +140,11 @@ def check(factor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() + assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -164,9 +167,10 @@ def check_rfactor(factor, rfactor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -188,9 +192,10 @@ def check_rfactor_no_reset(factor, rfactor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -213,9 +218,10 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index fe23955017a0..7ee1e539204b 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -48,17 +48,14 @@ def test_buffer_access_ptr_offset(): n = te.size_var('n') Ab = tvm.tir.decl_buffer((m, n), "float32") aptr = Ab.access_ptr("rw", offset=100) - offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.ir.structural_equal(offset, 100) + tvm.testing.assert_prim_expr_equal(aptr.args[2], 100) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE v = te.size_var('int32') aptr = Ab.access_ptr("rw", offset=100 + 100 + v) - offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.ir.structural_equal(offset, 200 + v) + tvm.testing.assert_prim_expr_equal(aptr.args[2], 200 + v) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern('int32', "test_call", 100 + 100 + v)) - offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.ir.structural_equal(offset, tvm.tir.call_extern('int32', "test_call", 200 + v)) + tvm.testing.assert_prim_expr_equal(aptr.args[2], tvm.tir.call_extern('int32', "test_call", 200 + v)) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE @@ -80,8 +77,7 @@ def test_buffer_vload(): n = te.size_var('n') Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) - offset = tvm.tir.ir_pass.Simplify(load.index) - assert tvm.ir.structural_equal(offset, n * 2 + 103) + tvm.testing.assert_prim_expr_equal(load.index, n * 2 + 103) def test_buffer_index_merge_mult_mod(): diff --git a/tests/python/unittest/test_tir_pass_basic.py b/tests/python/unittest/test_tir_pass_basic.py index 228e0c52c435..23982873075f 100644 --- a/tests/python/unittest/test_tir_pass_basic.py +++ b/tests/python/unittest/test_tir_pass_basic.py @@ -17,16 +17,6 @@ import tvm from tvm import te -def test_simplify(): - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod - x = te.var('x') - e1 = tvm.tir.ir_pass.Simplify(x + 2 + 1) - assert(tvm.ir.structural_equal(e1, x + 3)) - e2 = tvm.tir.ir_pass.Simplify(x * 3 + 5 * x) - assert(tvm.ir.structural_equal(e2, x * 8)) - e3 = tvm.tir.ir_pass.Simplify(x - tdiv(x, 3) * 3) - assert(tvm.ir.structural_equal(e3, tmod(x, 3))) def test_verify_ssa(): diff --git a/tests/python/unittest/test_tir_pass_decorate_device_scope.py b/tests/python/unittest/test_tir_pass_decorate_device_scope.py index 327cfd9ed548..9c58431158b9 100644 --- a/tests/python/unittest/test_tir_pass_decorate_device_scope.py +++ b/tests/python/unittest/test_tir_pass_decorate_device_scope.py @@ -31,8 +31,7 @@ def test_decorate_device(): s[A1].set_scope("shared") bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - stmt1 = tvm.tir.ir_pass.Simplify(stmt) + stmt1 = tvm.te.schedule.ScheduleOps(s, bounds) stmt2 = tvm.tir.ir_pass.DecorateDeviceScope(stmt1) assert isinstance(stmt2, tvm.tir.AttrStmt) assert stmt2.attr_key == "device_scope" diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 9d1641366d7d..887b8b0c2b75 100644 --- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -57,7 +57,7 @@ def test_copy_pad(): mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): - assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 + tvm.testing.assert_prim_expr_equal(src.elem_offset, 0) assert pad_before[0].value == 1 assert pad_before[1].value == 0 assert pad_after[0].value == 1 @@ -82,18 +82,15 @@ def test_single_point_test(): mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): - assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 - assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0 - assert tvm.tir.ir_pass.Simplify(src.strides[0]).value == 1 - assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1 + tvm.testing.assert_prim_expr_equal(src.elem_offset, 0) + tvm.testing.assert_prim_expr_equal(dst.elem_offset, 0) + tvm.testing.assert_prim_expr_equal(src.strides[0], 1) + tvm.testing.assert_prim_expr_equal(dst.strides[0], 1) return tvm.tir.Evaluate(0) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body -def assert_expr_equal(a, b): - assert tvm.tir.ir_pass.Simplify(a - b).value == 0 - def test_copy_pad_split(): m = 4 * 3 A = te.placeholder((m, ), name="A") @@ -115,13 +112,13 @@ def test_copy_pad_split(): def cb(src, dst, pad_before, pad_after, pad_value): assert(dst.elem_offset.value == 0) - assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1) + tvm.testing.assert_prim_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1) rpad_before = tvm.te.max(1 - xo * 4, 0) rpad_after = tvm.te.max(xo * 4 - 7, 0) - assert_expr_equal(pad_before[0], rpad_before) - assert_expr_equal(pad_after[0], rpad_after) - assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) + tvm.testing.assert_prim_expr_equal(pad_before[0], rpad_before) + tvm.testing.assert_prim_expr_equal(pad_after[0], rpad_after) + tvm.testing.assert_prim_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) return tvm.tir.Evaluate(0) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body diff --git a/tests/python/unittest/test_tir_transform_lower_intrin.py b/tests/python/unittest/test_tir_transform_lower_intrin.py index b2e984aae6fd..98c79e414226 100644 --- a/tests/python/unittest/test_tir_transform_lower_intrin.py +++ b/tests/python/unittest/test_tir_transform_lower_intrin.py @@ -22,10 +22,13 @@ def lower_intrin(params, stmt): """wrapper to call transformation in stmt""" lower_expr = isinstance(stmt, tvm.tir.PrimExpr) stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - func = tvm.tir.PrimFunc(params, stmt).with_attr( - "target", tvm.target.create("llvm")) - func = tvm.tir.transform.LowerIntrin()(tvm.IRModule.from_expr(func))["main"] + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc(params, stmt).with_attr( + "target", tvm.target.create("llvm"))) + mod = tvm.transform.Sequential([ + tvm.tir.transform.Simplify(), + tvm.tir.transform.LowerIntrin() + ])(mod) + func = mod["main"] stmt = func.body return stmt.value if lower_expr else stmt.body diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index e2bfeb009a11..57eb349f18df 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -51,9 +51,10 @@ def test_flatten_prefetch(): [_A], stmt, {A: _A}) mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.transform.Sequential([ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.Simplify()])(mod) stmt = mod["main"].body - stmt = tvm.tir.ir_pass.Simplify(stmt) assert stmt.extent.value == 2 assert isinstance(stmt.body, tvm.tir.For) assert stmt.body.extent.value == 2 @@ -74,9 +75,11 @@ def test_flatten_storage_align(): func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.transform.Sequential([ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.Simplify()])(mod) + stmt = mod["main"].body - stmt = tvm.tir.ir_pass.Simplify(stmt) assert(stmt.body.extents[0].value == 17 * 8) @@ -103,11 +106,12 @@ def test_flatten_double_buffer(): mod = tvm.IRModule.from_expr( tvm.tir.PrimFunc([A, C], stmt)) - mod = tvm.tir.transform.StorageFlatten(64)(mod) - stmt = mod["main"].body + mod = tvm.transform.Sequential([ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.InjectDoubleBuffer(2), + tvm.tir.transform.Simplify()])(mod) - stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2) - stmt = tvm.tir.ir_pass.Simplify(stmt) + stmt = mod["main"].body assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 74be9453ae61..afa8833b3490 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -25,8 +25,9 @@ #define TOPI_DETAIL_CONSTANT_UTILS_H_ #include -#include +#include #include +#include #include #include @@ -119,7 +120,7 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { bool result = expr_equal(lhs, rhs); if (!result) { PrimExpr zero(0); - result = expr_equal(tvm::tir::CanonicalSimplify(lhs-rhs), zero); + result = expr_equal(tvm::arith::Analyzer().Simplify(lhs-rhs), zero); } return result; } diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index a1ee8c1a5901..7569bb0a7ba2 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -26,8 +26,8 @@ #include #include +#include #include -#include #include #include @@ -184,6 +184,7 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, pad_after.push_back(pad_before[i]); } } + arith::Analyzer analyzer; CHECK_GE(pad_before.size(), 1); CHECK_EQ(pad_before.size(), pad_after.size()); tvm::Array output_shape; @@ -200,13 +201,14 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, output_shape.push_back(t->shape[i]); } else { output_shape.push_back( - tvm::tir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); + analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); } } if (!pad_value.defined()) { pad_value = tvm::tir::make_const(t->dtype, 0); } + auto l = [&](tvm::Array ovars) { tvm::Array indices; tvm::Array sel; @@ -223,7 +225,7 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, indices.push_back(ovars[i]); } if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) { - sel.push_back(tvm::tir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); + sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); } if (pad_mode == "edge") { pad_idx.push_back(tvm::if_then_else( diff --git a/topi/include/topi/nn/bnn.h b/topi/include/topi/nn/bnn.h index 6bda65317706..c69fc5406e33 100644 --- a/topi/include/topi/nn/bnn.h +++ b/topi/include/topi/nn/bnn.h @@ -25,7 +25,7 @@ #define TOPI_NN_BNN_H_ #include -#include +#include #include #include @@ -55,11 +55,12 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, CHECK_EQ(GetConstInt(ishape[axis]) % 32, 0) << "binarize_pack: axis size must be a multiple of 32"; + arith::Analyzer analyzer; auto n = ishape.size(); Array oshape; for (size_t i = 0; i < n; ++i) { oshape.push_back(i == static_cast(axis) ? - tvm::tir::Simplify(indexdiv(ishape[i], 32)) : + analyzer.Simplify(indexdiv(ishape[i], 32)) : ishape[i]); } diff --git a/topi/include/topi/nn/dilate.h b/topi/include/topi/nn/dilate.h index a67bf3a300b2..32ee1392e846 100644 --- a/topi/include/topi/nn/dilate.h +++ b/topi/include/topi/nn/dilate.h @@ -25,7 +25,7 @@ #define TOPI_NN_DILATE_H_ #include -#include +#include #include #include @@ -75,8 +75,9 @@ inline Tensor dilate(const Tensor& x, << ") must match dimension of x (" << n << ")"; Array out_shape; + arith::Analyzer analyzer; for (size_t i = 0; i < n; ++i) { - out_shape.push_back(tvm::tir::Simplify( + out_shape.push_back(analyzer.Simplify( (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1))); } diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index 20b7b246317b..324ecadfe95a 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -102,10 +102,10 @@ inline Tensor pool_impl(const Tensor& x, Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); - - auto out_height = tvm::tir::Simplify( + arith::Analyzer analyzer; + auto out_height = analyzer.Simplify( indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1); - auto out_width = tvm::tir::Simplify( + auto out_width = analyzer.Simplify( indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); auto dheight = tvm::te::reduce_axis(Range(0, kernel_height)); @@ -212,11 +212,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); - + arith::Analyzer analyzer; auto out_height = - tvm::tir::Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1); + analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1); auto out_width = - tvm::tir::Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1); + analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1); auto dheight = tvm::te::reduce_axis(Range(0, kernel_height)); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width)); @@ -711,7 +711,8 @@ inline Tensor pool_impl_nd(const Tensor& x, pad_before.Set(ii, pad_head[i]); pad_after.Set(ii, pad_tail[i]); - auto out_dim = tvm::tir::Simplify( + arith::Analyzer analyzer; + auto out_dim = analyzer.Simplify( indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); out_shape.Set(ii, out_dim); diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 431ace5bc11e..0609020b5c81 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -375,12 +375,12 @@ inline Tensor concatenate(const Array& inputs, for (auto t : inputs) { axis_sizes.push_back(t->shape[axis]); } - + arith::Analyzer analyzer; PrimExpr join_size = axis_sizes[0]; for (size_t i = 1; i < axis_sizes.size(); ++i) { join_size += axis_sizes[i]; } - join_size = tvm::tir::Simplify(join_size); + join_size = analyzer.Simplify(join_size); Array out_shape; for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { out_shape.push_back(i == static_cast(axis) ? join_size : inputs[0]->shape[i]); diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py index db9da844e3af..b7cb32d58d01 100644 --- a/topi/python/topi/cuda/depthwise_conv2d.py +++ b/topi/python/topi/cuda/depthwise_conv2d.py @@ -167,7 +167,7 @@ def _schedule(temp, Filter, DepthwiseConv2d): b, h, w, c = s[Output].op.axis # num_thread here could be 728, it is larger than cuda.max_num_threads - num_thread = tvm.tir.ir_pass.Simplify(temp.shape[3]).value + num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value target = tvm.target.Target.current() if target and (target.target_name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads diff --git a/topi/python/topi/intel_graphics/depthwise_conv2d.py b/topi/python/topi/intel_graphics/depthwise_conv2d.py index a54941315a1a..650809985279 100644 --- a/topi/python/topi/intel_graphics/depthwise_conv2d.py +++ b/topi/python/topi/intel_graphics/depthwise_conv2d.py @@ -168,7 +168,7 @@ def _schedule(temp, Filter, DepthwiseConv2d): b, h, w, c = s[Output].op.axis # num_thread here could be 728, it is larger than cuda.max_num_threads - num_thread = tvm.tir.ir_pass.Simplify(temp.shape[3]).value + num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value target = tvm.target.Target.current() if target and (target.target_name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py index f628fadee96e..ebcf478033fb 100644 --- a/topi/python/topi/nn/dilate.py +++ b/topi/python/topi/nn/dilate.py @@ -45,9 +45,9 @@ def dilate(data, strides, name="DilatedInput"): if len(strides) != n: raise ValueError("data dimension and strides size dismatch : %d vs %d" % ( n, len(strides))) - + ana = tvm.arith.Analyzer() out_shape = tuple( - tvm.tir.ir_pass.Simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n)) + ana.simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n)) def _dilate(*indices): not_zero = [] diff --git a/topi/python/topi/nn/pad.py b/topi/python/topi/nn/pad.py index 8fe53374f2b5..b298a0a2bb95 100644 --- a/topi/python/topi/nn/pad.py +++ b/topi/python/topi/nn/pad.py @@ -55,9 +55,9 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"): if len(pad_after) != n: raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % ( n, len(pad_before))) + ana = tvm.arith.Analyzer() out_shape = tuple( - tvm.tir.ir_pass.Simplify( - (data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n)) + ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n)) pad_value = (pad_value if isinstance(pad_value, tvm.tir.PrimExpr) else tvm.tir.const(pad_value, data.dtype)) def _pad(*indices): @@ -115,8 +115,9 @@ def mirror_pad(data, if len(pad_after) != n: raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (n, len(pad_before))) + ana = tvm.arith.Analyzer() out_shape = tuple( - tvm.tir.ir_pass.Simplify((data.shape[i] + pad_before[i] + pad_after[i])) + ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n)) assert mode in ('SYMMETRIC', 'REFLECT') mode = int(mode == 'SYMMETRIC') diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 50a6a36edc46..cc437325e0d6 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -101,7 +101,8 @@ def get_const_int(expr): if isinstance(expr, Integral): return expr if not isinstance(expr, tvm.tir.IntImm): - expr = tvm.tir.ir_pass.Simplify(expr) + ana = tvm.arith.Analyzer() + expr = ana.simplify(expr) if not isinstance(expr, tvm.tir.IntImm): raise ValueError("Expect value to be constant int") return int(expr.value) @@ -123,7 +124,8 @@ def get_const_float(expr): if isinstance(expr, float): return float(expr) if not isinstance(expr, tvm.tir.FloatImm): - expr = tvm.tir.ir_pass.Simplify(expr) + ana = tvm.arith.Analyzer() + expr = ana.simplify(expr) if not isinstance(expr, tvm.tir.FloatImm): raise ValueError("Expect value to be constant float") return float(expr.value) @@ -145,7 +147,8 @@ def equal_const_int(expr, value): if isinstance(expr, Integral): return expr == value if not isinstance(expr, tvm.tir.IntImm): - expr = tvm.tir.ir_pass.Simplify(expr) + ana = tvm.arith.Analyzer() + expr = ana.simplify(expr) if not isinstance(expr, tvm.tir.IntImm): return False return expr.value == value @@ -165,11 +168,13 @@ def get_const_tuple(in_tuple): The output. """ ret = [] + ana = None for elem in in_tuple: if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)): ret.append(elem) elif not isinstance(elem, (tvm.tir.IntImm, int)): - elem = tvm.tir.ir_pass.Simplify(elem) + ana = tvm.arith.Analyzer() if ana is None else ana + elem = ana.simplify(elem) if not isinstance(elem, tvm.tir.IntImm): ret.append(elem) else: @@ -208,7 +213,7 @@ def simplify(expr): out : Expr or int The simplified output """ - return tvm.tir.ir_pass.Simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr + return tvm.arith.Analyzer().simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr def ravel_index(indices, shape): diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index c2684d159b98..9836d133ceb7 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -364,6 +364,7 @@ def _fold_buffer_dim(buf, scope, elem_block): shape.append(1) strides.append(elem_block) + analyzer = tvm.arith.Analyzer() while base < ndim + 1: x_size = 1 x_stride = buf.strides[ndim - base] @@ -378,7 +379,7 @@ def _fold_buffer_dim(buf, scope, elem_block): break x_size = x_size * buf.shape[k] next_base = i + 1 - shape.append(tvm.tir.ir_pass.Simplify(x_size)) + shape.append(analyzer.simplify(x_size)) strides.append(x_stride) assert next_base != base base = next_base @@ -769,10 +770,11 @@ def inject_alu_intrin(stmt_in): """ env = get_env() idxm = tvm.tir.indexmod + analyzer = tvm.arith.Analyzer() def _do_fold(stmt): def _equal(x, y): - return tvm.ir.structural_equal(tvm.tir.ir_pass.Simplify(x - y), 0) + return tvm.ir.structural_equal(analyzer.simplify(x - y), 0) def _flatten_loop(src_coeff, dst_coeff, extents): src_coeff = list(src_coeff) @@ -791,7 +793,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): next_ext = extents.pop() if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext): - vext = tvm.tir.ir_pass.Simplify(vext * next_ext) + vext = analyzer.simplify(vext * next_ext) else: rev_src_coeff.append(vsrc) rev_dst_coeff.append(vdst) @@ -851,7 +853,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): if loop_body.value.name == 'shift_left': alu_opcode = env.dev.ALU_OPCODE_SHR lhs = loop_body.value.args[0] - rhs = tvm.tir.ir_pass.Simplify(-loop_body.value.args[1]) + rhs = analyzer.simplify(-loop_body.value.args[1]) elif loop_body.value.name == 'shift_right': alu_opcode = env.dev.ALU_OPCODE_SHR lhs = loop_body.value.args[0] @@ -914,10 +916,10 @@ def _flatten_loop(src_coeff, dst_coeff, extents): assert len(dst_coeff) > 1 assert len(extents) != 0 assert tvm.ir.structural_equal( - tvm.tir.ir_pass.Simplify( + analyzer.simplify( idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) assert tvm.ir.structural_equal( - tvm.tir.ir_pass.Simplify( + analyzer.simplify( idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) assert tvm.ir.structural_equal(src_coeff[-2], 1) assert tvm.ir.structural_equal(dst_coeff[-2], 1) @@ -942,9 +944,9 @@ def _flatten_loop(src_coeff, dst_coeff, extents): src_coeff.append(src_offset) dst_coeff.append(dst_offset) src_coeff = [ - tvm.tir.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff] + analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff] dst_coeff = [ - tvm.tir.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff] + analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff] # Flatten the outer loops if extents: