From b49a433304c32a71c9b3c5e2d1b0fc622371ddda Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 20 Apr 2020 10:27:30 -0700 Subject: [PATCH] [ARITH] Remove the legacy Simplify, migrate to Analyzer. The legacy Simplify/CanonicalSimplify are now a thin wrapper around the Analyzer. This PR removes these functions and migrated every place that requires simplification to enforce Analyzer creation. The new API would encourage more Analyzer sharing and potentially enable context-aware analyzer-based simplification. --- include/tvm/arith/analyzer.h | 52 ++++++++--------- include/tvm/tir/ir_pass.h | 33 ----------- python/tvm/autotvm/util.py | 10 ++-- python/tvm/driver/build_module.py | 1 + python/tvm/te/hybrid/parser.py | 11 ++-- python/tvm/testing.py | 21 +++++++ python/tvm/tir/ir_builder.py | 3 +- src/arith/detect_linear_equation.cc | 8 ++- src/arith/int_set.cc | 34 ++++++----- src/driver/driver_api.cc | 1 + src/relay/op/type_relations.cc | 6 +- src/target/spirv/codegen_spirv.cc | 4 +- src/te/autodiff/jacobian.cc | 11 ++-- src/te/operation/scan_op.cc | 14 +++-- src/te/operation/tensorize.cc | 14 +++-- src/te/schedule/message_passing.cc | 9 +-- src/te/schedule/schedule_ops.cc | 4 +- ...hedule_postproc_rewrite_for_tensor_core.cc | 10 ++-- src/tir/ir/buffer.cc | 25 +++++---- src/tir/ir/data_layout.cc | 18 +++--- src/tir/pass/arg_binder.cc | 20 ++++--- src/tir/pass/arg_binder.h | 4 ++ src/tir/pass/ffi_api.cc | 34 ----------- src/tir/transforms/inject_copy_intrin.cc | 13 +++-- .../lower_device_storage_access_info.cc | 10 ++-- src/tir/transforms/lower_thread_allreduce.cc | 20 ++++--- src/tir/transforms/lower_warp_memory.cc | 1 - src/tir/transforms/simplify.cc | 30 ---------- src/tir/transforms/storage_rewrite.cc | 2 +- src/tir/transforms/unroll_loop.cc | 7 ++- ...implify_test.cc => arith_simplify_test.cc} | 21 ++++--- .../unittest/test_arith_deduce_bound.py | 56 +++++++++---------- .../unittest/test_arith_detect_clip_bound.py | 6 +- .../test_arith_detect_linear_equation.py | 15 +++-- .../test_arith_solve_linear_system.py | 7 ++- .../python/unittest/test_te_hybrid_script.py | 2 +- .../test_te_schedule_bound_inference.py | 21 +++---- .../unittest/test_te_schedule_tensorize.py | 26 +++++---- tests/python/unittest/test_tir_buffer.py | 12 ++-- tests/python/unittest/test_tir_pass_basic.py | 10 ---- .../test_tir_pass_decorate_device_scope.py | 3 +- .../test_tir_transform_inject_copy_intrin.py | 21 +++---- .../test_tir_transform_lower_intrin.py | 11 ++-- .../test_tir_transform_storage_flatten.py | 20 ++++--- topi/include/topi/detail/constant_utils.h | 5 +- topi/include/topi/nn.h | 8 ++- topi/include/topi/nn/bnn.h | 5 +- topi/include/topi/nn/dilate.h | 5 +- topi/include/topi/nn/pooling.h | 17 +++--- topi/include/topi/transform.h | 4 +- topi/python/topi/cuda/depthwise_conv2d.py | 2 +- .../topi/intel_graphics/depthwise_conv2d.py | 2 +- topi/python/topi/nn/dilate.py | 4 +- topi/python/topi/nn/pad.py | 7 ++- topi/python/topi/util.py | 15 +++-- vta/python/vta/ir_pass.py | 18 +++--- 56 files changed, 369 insertions(+), 384 deletions(-) rename tests/cpp/{ir_simplify_test.cc => arith_simplify_test.cc} (79%) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 3a71e5eb5fbf..6ca3ba9cfd55 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -112,7 +112,7 @@ class ConstIntBoundAnalyzer { * \param expr The expression of interest. * \return the result of the analysis. */ - ConstIntBound operator()(const PrimExpr& expr); + TVM_DLL ConstIntBound operator()(const PrimExpr& expr); /*! * \brief analyze the expr with the intermediate memorized to avoid redundant computation @@ -120,8 +120,8 @@ class ConstIntBoundAnalyzer { * \param bound The lookup table to store the intermediate results * \return the result of the analysis. */ - ConstIntBound operator()(const PrimExpr& expr, - std::unordered_map* bound); + TVM_DLL ConstIntBound operator()(const PrimExpr& expr, + std::unordered_map* bound); /*! * \brief Update constant int bound information of var. @@ -130,22 +130,22 @@ class ConstIntBoundAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const ConstIntBound& info, - bool override = false); + TVM_DLL void Update(const Var& var, + const ConstIntBound& info, + bool override = false); /*! * \brief Bind variable to a range. * * \param var The variable. * \param range The range we bind to. */ - void Bind(const Var& var, const Range& range); + TVM_DLL void Bind(const Var& var, const Range& range); private: friend class Analyzer; friend class ConstraintContext; explicit ConstIntBoundAnalyzer(Analyzer* parent); - ~ConstIntBoundAnalyzer(); + TVM_DLL ~ConstIntBoundAnalyzer(); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -212,7 +212,7 @@ class ModularSetAnalyzer { * \param expr The expression of interest. * \return the result of the analysis. */ - ModularSet operator()(const PrimExpr& expr); + TVM_DLL ModularSet operator()(const PrimExpr& expr); /*! * \brief Update constant int bound information of var. * @@ -220,15 +220,15 @@ class ModularSetAnalyzer { * \param info The bound information. * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const ModularSet& info, - bool override = false); + TVM_DLL void Update(const Var& var, + const ModularSet& info, + bool override = false); private: friend class Analyzer; friend class ConstraintContext; explicit ModularSetAnalyzer(Analyzer* parent); - ~ModularSetAnalyzer(); + TVM_DLL ~ModularSetAnalyzer(); /*! * \brief Update the internal state to enter constraint. * \param constraint A constraint expression. @@ -252,7 +252,7 @@ class RewriteSimplifier { * \param expr The expression of interest. * \return the result of the analysis. */ - PrimExpr operator()(const PrimExpr& expr); + TVM_DLL PrimExpr operator()(const PrimExpr& expr); /*! * \brief Update binding of var to a new expression. @@ -261,9 +261,9 @@ class RewriteSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, + const PrimExpr& new_expr, + bool override = false); std::function EnterConstraint(const PrimExpr& constraint); @@ -272,7 +272,7 @@ class RewriteSimplifier { friend class ConstraintContext; friend class CanonicalSimplifier; explicit RewriteSimplifier(Analyzer* parent); - ~RewriteSimplifier(); + TVM_DLL ~RewriteSimplifier(); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -288,7 +288,7 @@ class CanonicalSimplifier { * \param expr The expression of interest. * \return the result of the analysis. */ - PrimExpr operator()(const PrimExpr& expr); + TVM_DLL PrimExpr operator()(const PrimExpr& expr); /*! * \brief Update binding of var to a new expression. @@ -297,15 +297,15 @@ class CanonicalSimplifier { * \param new_expr * \param override Whether do we allow override of existing information. */ - void Update(const Var& var, - const PrimExpr& new_expr, - bool override = false); + TVM_DLL void Update(const Var& var, + const PrimExpr& new_expr, + bool override = false); private: friend class Analyzer; friend class ConstraintContext; explicit CanonicalSimplifier(Analyzer* parent); - ~CanonicalSimplifier(); + TVM_DLL ~CanonicalSimplifier(); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -363,12 +363,12 @@ class IntSetAnalyzer { * \param dom_map The domain map to indicate which variable to relax. * \return the result of the analysis. */ - IntSet operator()(const PrimExpr& expr, const Map& dom_map); + TVM_DLL IntSet operator()(const PrimExpr& expr, const Map& dom_map); private: friend class Analyzer; explicit IntSetAnalyzer(Analyzer* parent); - ~IntSetAnalyzer(); + TVM_DLL ~IntSetAnalyzer(); class Impl; /*! \brief Internal impl */ Impl* impl_; @@ -384,7 +384,7 @@ class IntSetAnalyzer { * If the analyzer uses memoization, we need to clear the internal * cache when information about a Var has been overridden. */ -class Analyzer { +class TVM_DLL Analyzer { public: /* * Disable copy constructor. diff --git a/include/tvm/tir/ir_pass.h b/include/tvm/tir/ir_pass.h index 592a79fb86d1..4980c41016fd 100644 --- a/include/tvm/tir/ir_pass.h +++ b/include/tvm/tir/ir_pass.h @@ -41,39 +41,6 @@ namespace tvm { namespace tir { -/*! - * \brief Simplify the expression. - * \param expr The expression to be simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized statement. - */ -TVM_DLL PrimExpr Simplify(PrimExpr expr, Map vrange = Map()); - -/*! - * \brief Simplify the statement. - * \param stmt The statement to be simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized statement. - */ -Stmt Simplify(Stmt stmt, Map vrange = Map()); - -/*! - * \brief Simplify by applying canonical form. - * \param stmt The statement to be canonically simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized statement. - */ -Stmt CanonicalSimplify(Stmt stmt, - Map vrange = Map()); - -/*! - * \brief Simplify by applying canonical form. - * \param expr The statement to be canonically simplifed. - * \param vrange The range information about the variable. - * \return Canonicalized expression. - */ -TVM_DLL PrimExpr CanonicalSimplify(PrimExpr expr, - Map vrange = Map()); /*! * \brief verifies whether the IR stmt or Expr is in SSA form. diff --git a/python/tvm/autotvm/util.py b/python/tvm/autotvm/util.py index 01d50e86a88a..db1662c7b52a 100644 --- a/python/tvm/autotvm/util.py +++ b/python/tvm/autotvm/util.py @@ -23,8 +23,8 @@ from random import randrange import numpy as np - -from tvm.tir import expr, ir_pass +import tvm.arith +from tvm.tir import expr logger = logging.getLogger('autotvm') @@ -156,7 +156,8 @@ def get_const_int(exp): if isinstance(exp, int): return exp if not isinstance(exp, (expr.IntImm,)): - exp = ir_pass.Simplify(exp) + ana = tvm.arith.Analyzer() + exp = ana.simplify(exp) if not isinstance(exp, (expr.IntImm,)): raise ValueError("Expect value to be constant int") return exp.value @@ -180,7 +181,8 @@ def get_const_tuple(in_tuple): if isinstance(elem, expr.Var): ret.append(elem) elif not isinstance(elem, (expr.IntImm, int)): - elem = ir_pass.Simplify(elem) + ana = tvm.arith.Analyzer() + elem = ana.simplify(elem) if not isinstance(elem, (expr.IntImm)): ret.append(elem) else: diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 5c92965a0130..35700badb04b 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -287,6 +287,7 @@ def _build_for_device(input_mod, target, target_host): lambda f: "calling_conv" in f.attrs and f.attrs["calling_conv"].value == CallingConv.DEVICE_KERNEL_LAUNCH), tvm.tir.transform.LowerWarpMemory(), + tvm.tir.transform.Simplify(), tvm.tir.transform.LowerDeviceStorageAccessInfo(), tvm.tir.transform.LowerIntrin()]) mod_dev = opt_device(mod_mixed) diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 107f51b8bbcc..765efa0b976c 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -29,10 +29,10 @@ import tvm.tir import tvm.te import tvm.te._ffi_api +import tvm.arith from tvm.tir import expr as _expr from tvm.tir import stmt as _stmt -from tvm.tir import ir_pass as _ir_pass from tvm.te.tensor import Tensor, Operation from tvm.tir import all as _all from tvm.tir import any as _any @@ -160,6 +160,7 @@ def __init__(self, args, usage, symbols, closure_vars, func_name=None): self.outputs = [] # Output tensors' name self.side_effect = set() # Tensors with side effects self.parsed_body = None # The parsed HalideIR body + self.analyzer = tvm.arith.Analyzer() self.returned = False # If this function has a valid return @@ -326,7 +327,7 @@ def visit_Assign(self, node): _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") lhs = node.targets[0] if isinstance(rhs, _expr.PrimExpr): - rhs = _ir_pass.Simplify(rhs) + rhs = self.analyzer.simplify(rhs) if isinstance(lhs, ast.Name): #TODO: support defined intermediate buffer later lhs_ = lhs @@ -410,7 +411,7 @@ def visit_With(self, node): def visit_If(self, node): - cond = _ir_pass.CanonicalSimplify(self.visit(node.test)) + cond = self.analyzer.simplify(self.visit(node.test)) # Return no IfThenElse if proven if isinstance(cond, _expr.IntImm): @@ -501,8 +502,8 @@ def visit_For(self, node): _name = node.target.id if isinstance(for_type, tuple): - low = _ir_pass.CanonicalSimplify(low) - ext = _ir_pass.CanonicalSimplify(ext) + low = self.analyzer.simplify(low) + ext = self.analyzer.simplify(ext) _internal_assert(isinstance(low, _expr.ConstExpr) and isinstance(ext, _expr.ConstExpr), \ "Const range should start from a const " + \ diff --git a/python/tvm/testing.py b/python/tvm/testing.py index 0f50636d68d8..5a3d394c098f 100644 --- a/python/tvm/testing.py +++ b/python/tvm/testing.py @@ -20,6 +20,8 @@ import logging import numpy as np import tvm +import tvm.arith +import tvm.tir import tvm._ffi @@ -168,4 +170,23 @@ def compare_derivative(j, n_der, grad): x_name, grad.shape, dist, max_diff, avg_diff) +def assert_prim_expr_equal(lhs, rhs): + """Assert lhs and rhs equals to each iother. + + Parameters + ---------- + lhs : tvm.tir.PrimExpr + The left operand. + + rhs : tvm.tir.PrimExpr + The left operand. + """ + ana = tvm.arith.Analyzer() + res = ana.simplify(lhs - rhs) + equal = isinstance(res, tvm.tir.IntImm) and res.value == 0 + if not equal: + raise ValueError("{} and {} are not equal".format(lhs, rhs)) + + + tvm._ffi._init_api("testing", __name__) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 0c4c36888eb5..4dd541e9bdbb 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -21,7 +21,6 @@ from . import stmt as _stmt from . import expr as _expr -from . import ir_pass as _pass class WithScope(object): @@ -212,7 +211,7 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): self.nidx += 1 self._seq_stack.append([]) loop_var = _expr.Var(name, dtype=dtype) - extent = end if begin == 0 else _pass.Simplify(end - begin) + extent = end if begin == 0 else (end - begin) def _exit_cb(): if for_type == "serial": for_type_id = 0 diff --git a/src/arith/detect_linear_equation.cc b/src/arith/detect_linear_equation.cc index cc9c745a24b8..58723170b3ca 100644 --- a/src/arith/detect_linear_equation.cc +++ b/src/arith/detect_linear_equation.cc @@ -207,8 +207,9 @@ bool DetectClipBound( return false; } LinearEqEntry ret; + Analyzer analyzer; if (!LinearEqDetector(var).Detect(canonical, &ret)) return false; - ret.coeff = Simplify(ret.coeff); + ret.coeff = analyzer.Simplify(ret.coeff); IntervalEntry& p = (*bmap)[var.get()]; if (is_const_int(ret.coeff, 1)) { // var + shift >=0 -> var >= -shift @@ -254,14 +255,15 @@ Array DetectClipBound(const PrimExpr& e, const Array& vars) { for (PrimExpr cond : splits) { if (!DetectClipBound(cond, &rmap)) return Array(); } + Analyzer analyzer; Array ret; for (Var v : vars) { IntervalEntry e = rmap[v.get()]; if (e.min_value.defined()) { - e.min_value = Simplify(e.min_value); + e.min_value = analyzer.Simplify(e.min_value); } if (e.max_value.defined()) { - e.max_value = Simplify(e.max_value); + e.max_value = analyzer.Simplify(e.max_value); } ret.push_back(e.min_value); ret.push_back(e.max_value); diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc index 8c5afb1be8b5..027259a4d225 100644 --- a/src/arith/int_set.cc +++ b/src/arith/int_set.cc @@ -570,11 +570,12 @@ IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, // TODO(tqchen): revisit IntSet interface as well. Range IntSet::cover_range(Range max_range) const { IntSet temp; + Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); CHECK(s_int != nullptr); if (s_int->HasUpperBound() && s_int->HasLowerBound()) { return Range::make_by_min_extent( - s_int->min_value, Simplify(s_int->max_value + 1 - s_int->min_value)); + s_int->min_value, analyzer.Simplify(s_int->max_value + 1 - s_int->min_value)); } return max_range; } @@ -607,26 +608,30 @@ bool IntSet::is_single_point() const { } bool IntSet::can_prove_positive() const { + Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_positive_const(tir::Simplify(s_int->min_value))); + return (s_int && is_positive_const(analyzer.Simplify(s_int->min_value))); } bool IntSet::can_prove_negative() const { + Analyzer analyzer; const IntervalSetNode* s_int = (*this).as(); - return (s_int && is_negative_const(tir::Simplify(s_int->max_value))); + return (s_int && is_negative_const(analyzer.Simplify(s_int->max_value))); } bool IntSet::can_prove_non_positive() const { + Analyzer analyzer; if (const auto* s_int = (*this).as()) { - auto max = tir::Simplify(s_int->max_value); + auto max = analyzer.Simplify(s_int->max_value); return is_zero(max) || is_negative_const(max); } return false; } bool IntSet::can_prove_non_negative() const { + Analyzer analyzer; if (const IntervalSetNode* s_int = (*this).as()) { - auto min = tir::Simplify(s_int->min_value); + auto min = analyzer.Simplify(s_int->min_value); return is_zero(min) || is_positive_const(min); } return false; @@ -669,8 +674,8 @@ IntSet IntSet::interval(PrimExpr min, PrimExpr max) { } // Range related code -inline bool ProveEqual(PrimExpr lhs, PrimExpr rhs) { - return is_zero(tir::Simplify(lhs - rhs)); +inline bool ProveEqual(Analyzer* analyzer, PrimExpr lhs, PrimExpr rhs) { + return is_zero(analyzer->Simplify(lhs - rhs)); } IntSet IntSet::range(Range r) { @@ -685,8 +690,9 @@ bool IntSet::match_range(const Range& b) const { const IntSet& a = *this; const IntervalSetNode* a_int = a.as(); if (!a_int) return false; - return ProveEqual(a_int->min_value, b->min) && - ProveEqual(a_int->max_value, b->extent + b->min - 1); + Analyzer ana; + return ProveEqual(&ana, a_int->min_value, b->min) && + ProveEqual(&ana, a_int->max_value, b->extent + b->min - 1); } IntSet Union(const Array& sets) { @@ -697,8 +703,8 @@ IntSet Union(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Union(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(tir::Simplify(x->min_value), - tir::Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), + ana.Simplify(x->max_value)); } IntSet Intersect(const Array& sets) { @@ -709,8 +715,8 @@ IntSet Intersect(const Array& sets) { for (size_t i = 1; i < sets.size(); ++i) { x = Intersect(&ana, x, ToIntervalSet(sets[i])); } - return IntervalSet(tir::Simplify(x->min_value), - tir::Simplify(x->max_value)); + return IntervalSet(ana.Simplify(x->min_value), + ana.Simplify(x->max_value)); } Map ConvertDomMap(const Map& dom_map) { @@ -758,7 +764,7 @@ IntSet EvalSet(Range r, IntervalSetEvaluator m(&ana, dom_map); // Simplifying first can give tighter bounds if r->min and r->extent share variables PrimExpr sum = r->min + r->extent - 1; - auto res = m.Eval(IntervalSet(r->min, Simplify(sum))); + auto res = m.Eval(IntervalSet(r->min, ana.Simplify(sum))); return std::move(res); } diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index c3802b1a63f4..7e2ef701265c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -236,6 +236,7 @@ split_dev_host_funcs(IRModule mod_mixed, }), BindTarget(target), tir::transform::LowerWarpMemory(), + tir::transform::Simplify(), tir::transform::LowerIntrin(), tir::transform::LowerDeviceStorageAccessInfo(), }; diff --git a/src/relay/op/type_relations.cc b/src/relay/op/type_relations.cc index f9653e24b1b9..e2e7f4994349 100644 --- a/src/relay/op/type_relations.cc +++ b/src/relay/op/type_relations.cc @@ -22,9 +22,10 @@ * \brief A set of utilities and common functionality * for type relations. */ +#include +#include #include #include -#include #include #include "./type_relations.h" @@ -48,7 +49,8 @@ bool EqualCheck(const IndexExpr& lhs, return pdiff[0] == 0; } // symbolic - diff = tvm::tir::CanonicalSimplify(diff); + tvm::arith::Analyzer ana; + diff = ana.Simplify(diff); if (const int64_t* pdiff = tir::as_const_int(diff)) { return pdiff[0] == 0; } diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 1d8004e9938f..d4631aaa8023 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -414,7 +414,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) { CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = tir::Simplify( + PrimExpr vec_index = analyzer_->Simplify( ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); @@ -492,7 +492,7 @@ void CodeGenSPIRV::VisitStmt_(const StoreNode* op) { CHECK((me->coeff % ramp->lanes) == 0 && (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; - PrimExpr vec_index = tir::Simplify( + PrimExpr vec_index = analyzer_->Simplify( ramp->base / make_const(ramp->base.dtype(), ramp->lanes)); spirv::Value ptr = builder_->StructArrayAccess( ptr_type, buffer, MakeValue(vec_index)); diff --git a/src/te/autodiff/jacobian.cc b/src/te/autodiff/jacobian.cc index 1a324588537f..9ebb89ee95a6 100644 --- a/src/te/autodiff/jacobian.cc +++ b/src/te/autodiff/jacobian.cc @@ -24,9 +24,11 @@ * The result Jacobian shape will be (Y.shape, X.shape) */ #include +#include #include #include -#include +#include + #include #include "ad_util.h" @@ -264,7 +266,7 @@ class JacobianMutator : public ExprMutator { CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); // Also simplify the resulting combiner // (mostly to get rid of unused components, e.g., the original expressions) - return Simplify( + return analyzer_.Simplify( ReduceNode::make(new_combiner, new_source, new_op->axis, new_op->condition, new_op->value_index)); } @@ -302,6 +304,7 @@ class JacobianMutator : public ExprMutator { Tensor input_; Array indices_; Var input_var_; + arith::Analyzer analyzer_; }; PrimExpr Derivative(const PrimExpr& expr, const Var& var) { @@ -341,11 +344,11 @@ Tensor Jacobian(const Tensor& output, const Tensor& input) { // Differentiate wrt input[input_indices] input_indices.push_back(new_v); } - + arith::Analyzer analzyer; // Compute Jacobian PrimExpr new_body = Jacobian( Substitute(op->body[output->value_index], vmap), input, input_indices); - new_body = Simplify(new_body); + new_body = analzyer.Simplify(new_body); int value_index = 0; Array new_bodies; diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index 2ee5b273d4f6..1916b4a4823e 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -39,10 +39,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); TVM_REGISTER_NODE_TYPE(ScanOpNode); -inline bool prove_equal(PrimExpr lhs, PrimExpr rhs) { - return is_zero(tir::Simplify(lhs - rhs)); -} - int ScanOpNode::num_outputs() const { return static_cast(update.size()); } @@ -77,6 +73,10 @@ Operation ScanOpNode::make(std::string name, auto n = make_object(); CHECK_EQ(init.size(), update.size()); CHECK_EQ(init.size(), state_placeholder.size()); + arith::Analyzer analyzer; + auto prove_equal = [&](PrimExpr lhs, PrimExpr rhs) { + return is_zero(analyzer.Simplify(lhs - rhs)); + }; for (size_t i = 0; i < init.size(); ++i) { CHECK_EQ(init[i]->dtype, state_placeholder[i]->dtype); @@ -232,10 +232,11 @@ void ScanOpNode::GatherBound( time_dom.insert(time_dom.end(), d.data[0].begin(), d.data[0].end()); } CHECK(!out_dom_map->count(this->scan_axis)); + arith::Analyzer analyzer; Range sdom = this->scan_axis->dom; Range r = arith::Union(time_dom).cover_range(sdom); (*out_dom_map)[this->scan_axis] = Range::make_by_min_extent( - sdom->min, tir::Simplify(r->extent + r->min - sdom->min)); + sdom->min, analyzer.Simplify(r->extent + r->min - sdom->min)); Map fix_pt = ScanFixPointAnalysis(self); // Update for spatial axis. size_t sp_idx = 0; @@ -260,10 +261,11 @@ Stmt ScanOpNode::BuildRealize( const Stage& stage, const std::unordered_map& dom_map, const Stmt& body) const { + arith::Analyzer analyzer; CHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); Range tdom = Range::make_by_min_extent( - 0, tir::Simplify(sdom->extent + sdom->min)); + 0, analyzer.Simplify(sdom->extent + sdom->min)); Stmt ret = body; size_t sp_idx = 0; for (size_t i = 0; i < update.size(); ++i) { diff --git a/src/te/operation/tensorize.cc b/src/te/operation/tensorize.cc index 6064f5c4e008..b66406969c76 100644 --- a/src/te/operation/tensorize.cc +++ b/src/te/operation/tensorize.cc @@ -222,6 +222,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { compute_intrin_iter_space->Set(iv->var, vrange); } } + analyzer_.Bind(*compute_intrin_iter_space); // input remap. Array inputs = self->InputTensors(); @@ -234,7 +235,7 @@ class TensorIntrinMatcher final : public StmtExprMutator { // Enable fuzzy matching, to match [1, n, m] to [n, m] e.start = e.region.size() - e.tensor.ndim(); for (size_t j = 0; j < e.start; ++j) { - auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space); + auto canonical_extent = analyzer_.Simplify(e.region[j]->extent); CHECK(is_one(canonical_extent)) << "Tensorize " << intrin->name << ":" << " Input dimension mismatch with tensor intrin " @@ -304,6 +305,8 @@ class TensorIntrinMatcher final : public StmtExprMutator { std::unordered_map var_remap_; // IterVar remap. std::unordered_map axis_remap_; + // arith analyzer + arith::Analyzer analyzer_; }; // Try to match tensor dataflow of the stage with the intrinsic @@ -339,11 +342,12 @@ void VerifyTensorizeBody( CHECK(intrin_compute) << "Only support compute intrinsic for now"; CHECK_EQ(body.size(), intrin_compute->body.size()) << "Tensorize failed: body size mismatch"; + arith::Analyzer ana; + ana.Bind(compute_intrin_iter_space); + for (size_t i = 0; i < body.size(); ++i) { - PrimExpr lhs = Simplify(body[i], compute_intrin_iter_space); - lhs = CanonicalSimplify(lhs, compute_intrin_iter_space); - PrimExpr rhs = Simplify(intrin_compute->body[i], compute_intrin_iter_space); - rhs = CanonicalSimplify(rhs, compute_intrin_iter_space); + PrimExpr lhs = ana.Simplify(body[i]); + PrimExpr rhs = ana.Simplify(intrin_compute->body[i]); if (lhs.dtype() != rhs.dtype()) { LOG(FATAL) << "Failed to match the data type with TensorIntrin " diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index 4ff8586bccdf..6ed9438ec90f 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -324,6 +324,7 @@ void PassUpDomain(const FuseNode* s, CHECK(dom_map.count(s->outer)); CHECK(dom_map.count(s->inner)); CHECK(dom_map.count(s->fused)); + arith::Analyzer ana; if (fused.match_range(dom_map.at(s->fused))) { *outer = IntSet::range(dom_map.at(s->outer)); @@ -348,15 +349,15 @@ void PassUpDomain(const FuseNode* s, *outer = IntSet::interval( outer_min + indexdiv(fused.min(), inner_extent), outer_min + indexdiv(fused.max(), inner_extent)); - if (is_zero(Simplify(indexmod(inner_extent, fused_extent))) && - is_zero(Simplify(indexmod(fused.min(), fused_extent)))) { + if (is_zero(ana.Simplify(indexmod(inner_extent, fused_extent))) && + is_zero(ana.Simplify(indexmod(fused.min(), fused_extent)))) { // fused never spans multiple rows, make a tight bounding box // there may be other cases when bounding box could be tightened *inner = IntSet::interval(inner_min + indexmod(fused.min(), inner_extent), inner_min + indexmod(fused.max(), inner_extent)); } else { // fused may span multiple rows, use full row widths - if (!is_zero(Simplify(indexmod(fused_extent, inner_extent))) || - !is_zero(Simplify(indexmod(fused.min(), inner_extent)))) { + if (!is_zero(ana.Simplify(indexmod(fused_extent, inner_extent))) || + !is_zero(ana.Simplify(indexmod(fused.min(), inner_extent)))) { LOG(WARNING) << "fused and original axes are not aligned, this may cause redundant computations"; } diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 57b637df0570..c818218fa65e 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -181,7 +181,7 @@ class SchedulePostProc : public StmtExprMutator { // delete duplicated thread extent attr auto it = thread_extent_scope_.find(op->node.get()); if (it != thread_extent_scope_.end()) { - CHECK(is_zero(tir::Simplify(it->second - op->value))); + CHECK(is_zero(analyzer_.Simplify(it->second - op->value))); return this->VisitStmt(op->body); } else { thread_extent_scope_[op->node.get()] = op->value; @@ -335,6 +335,8 @@ class SchedulePostProc : public StmtExprMutator { std::unordered_map replace_realize_; // replace producer consumer. std::unordered_map replace_op_; + // integer analyzer + arith::Analyzer analyzer_; }; Stmt ScheduleOps( diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index 56235591f722..797a8b2b7b88 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -508,7 +508,7 @@ class BufferAnalyser : public StmtExprVisitor { return; } auto index = rel_index[i]; - auto simplified_index = tir::Simplify(index); + auto simplified_index = analyzer_.Simplify(index); index_visitor(simplified_index); } @@ -611,7 +611,7 @@ class BufferAnalyser : public StmtExprVisitor { index_visitor.scaling_factor_ = shape->value; } auto index = rel_index[i]; - auto simplified_index = tir::Simplify(index); + auto simplified_index = analyzer_.Simplify(index); index_visitor(simplified_index); } } @@ -645,7 +645,7 @@ class BufferAnalyser : public StmtExprVisitor { PrimExpr offset = make_const(stride.dtype(), avec[dim].align_offset); stride = stride + \ indexmod(factor + offset - indexmod(stride, factor), factor); - stride = tir::Simplify(stride); + stride = analyzer_.Simplify(stride); } rstrides.push_back(stride); stride = stride * shape[dim]; @@ -773,6 +773,7 @@ class BufferAnalyser : public StmtExprVisitor { IndexVisitor index_visitor; Tile warp_tile_; Tile thread_tile_; + arith::Analyzer analyzer_; int warp_threads_y_{-1}; bool invalid_{false}; }; @@ -1148,7 +1149,7 @@ class TensorCoreIRMutator : public StmtExprMutator { buffer_node->strides = strides; buffer_node->shape = shape; buffer_node->data_alignment = 1; - buffer_node->elem_offset = Simplify(elem_offset); + buffer_node->elem_offset = analyzer_.Simplify(elem_offset); buffer_node->offset_factor = 1; Buffer buffer(buffer_node); @@ -1184,6 +1185,7 @@ class TensorCoreIRMutator : public StmtExprMutator { std::unordered_map frag_load_; std::unordered_map frag_store_; std::unordered_map bounds_; + arith::Analyzer analyzer_; Tile warp_tile_; int warp_threads_y_{-1}; }; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 6bbf6451b7ac..a7bc822fdd30 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -37,9 +38,9 @@ namespace tir { using IndexMod = tir::FloorModNode; using IndexDiv = tir::FloorDivNode; -Array SimplifyArray(Array array) { +Array SimplifyArray(arith::Analyzer* ana, Array array) { for (size_t i = 0; i < array.size(); ++i) { - array.Set(i, tir::Simplify(array[i])); + array.Set(i, ana->Simplify(array[i])); } return array; } @@ -185,14 +186,14 @@ inline void MergeMulModInsertElements(const std::vector& eles, // The search will be performed repeatively until no pattern is found. // Return: a pair with (false, Expr()) if cannot be optimized. // a pair with (true, optimized_expr) if can be optimized -inline PrimExpr MergeMulMod(const PrimExpr &base) { +inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr &base) { using namespace tir; // 1. Prepare the lists. // We store two lists, a list that contain all the elements that match Mul and // a list that contain all the elements that match Mod. // The elements in the Mod will be used to match against the elements in Mul. // The result will then be split and pushed back to these two lists. - PrimExpr simplified_base = Simplify(base); + PrimExpr simplified_base = analyzer->Simplify(base); std::vector eles = ExprSplitAddition(simplified_base); std::list mult_exprs; std::list > mod_exprs; @@ -254,6 +255,7 @@ inline PrimExpr MergeMulMod(const PrimExpr &base) { // We also perform optimization to simplify the indexing expression. inline PrimExpr ElemOffset(const BufferNode* n, Array index) { PrimExpr base = n->elem_offset; + arith::Analyzer ana; if (n->strides.size() == 0) { // Scalar case if (n->shape.size() == 0 && index.size() == 1) { @@ -265,7 +267,7 @@ inline PrimExpr ElemOffset(const BufferNode* n, Array index) { if (index.size() > 0) { PrimExpr offset = index[0]; for (size_t i = 1; i < index.size(); ++i) { - offset = MergeMulMod(offset * n->shape[i] + index[i]); + offset = MergeMulMod(&ana, offset * n->shape[i] + index[i]); } base = base + offset; } @@ -273,12 +275,12 @@ inline PrimExpr ElemOffset(const BufferNode* n, Array index) { } else { CHECK_EQ(n->strides.size(), index.size()); if (is_zero(base)) { - base = MergeMulMod(index[0] * n->strides[0]); + base = MergeMulMod(&ana, index[0] * n->strides[0]); } else { - base = MergeMulMod(base + index[0] * n->strides[0]); + base = MergeMulMod(&ana, base + index[0] * n->strides[0]); } for (size_t i = 1; i < index.size(); ++i) { - base = MergeMulMod(base + index[i] * n->strides[i]); + base = MergeMulMod(&ana, base + index[i] * n->strides[i]); } } return base; @@ -353,8 +355,9 @@ Buffer Buffer::MakeStrideView() const { Buffer Buffer::MakeSlice(Array begins, Array extents) const { const BufferNode* n = operator->(); - begins = SimplifyArray(begins); - PrimExpr elem_offset = tir::Simplify(ElemOffset(n, begins)); + arith::Analyzer ana; + begins = SimplifyArray(&ana, begins); + PrimExpr elem_offset = ana.Simplify(ElemOffset(n, begins)); Array strides = n->strides; if (strides.size() == 0) { bool can_relax = true; @@ -363,7 +366,7 @@ Buffer Buffer::MakeSlice(Array begins, Array extents) const for (size_t i = 0; i < extents.size(); ++i) { if (!can_relax) { if (!is_zero(begins[i]) || - !is_zero(tir::Simplify(extents[i] - n->shape[i]))) { + !is_zero(ana.Simplify(extents[i] - n->shape[i]))) { need_stride = true; } } diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 85842a0b9dcf..fb63fed623cf 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -24,6 +24,8 @@ #include #include #include +#include + #include namespace tvm { @@ -253,15 +255,16 @@ inline bool GetStoreRule(Array* rule, } inline Array TransformIndex(const Array& src_index, - const Array& src_axis, - const Array& transform_rule) { + const Array& src_axis, + const Array& transform_rule) { + arith::Analyzer ana; Array result; std::unordered_map bind_map; for (size_t i = 0; i < src_index.size(); ++i) { bind_map[src_axis[i]->var.get()] = src_index[i]; } for (PrimExpr rule : transform_rule) { - result.push_back(tir::Simplify(tir::Substitute(rule, bind_map))); + result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); } return result; } @@ -284,9 +287,10 @@ Array BijectiveLayout::BackwardIndex(const Array& dst_index) } inline Array TransformShape(const Array& src_shape, - const Array& src_axis, - const Array& target_axis, - const Array& transform_rule) { + const Array& src_axis, + const Array& target_axis, + const Array& transform_rule) { + arith::Analyzer ana; CHECK_EQ(src_shape.size(), src_axis.size()); // bind variables for original axes // for major-axis, bind the corresponding size @@ -329,7 +333,7 @@ inline Array TransformShape(const Array& src_shape, if (symbolic_var_set.count(i)) { result.push_back(tir::AnyNode::make()); } else { - result.push_back(tir::Simplify(tir::Substitute(rule, bind_map))); + result.push_back(ana.Simplify(tir::Substitute(rule, bind_map))); } } } diff --git a/src/tir/pass/arg_binder.cc b/src/tir/pass/arg_binder.cc index c684b9e68038..51a6d8bf5fed 100644 --- a/src/tir/pass/arg_binder.cc +++ b/src/tir/pass/arg_binder.cc @@ -31,10 +31,11 @@ namespace tvm { namespace tir { -void BinderAddAssert(PrimExpr cond, +void BinderAddAssert(arith::Analyzer* ana, + PrimExpr cond, const std::string& arg_name, std::vector* asserts) { - PrimExpr scond = Simplify(cond); + PrimExpr scond = ana->Simplify(cond); if (is_zero(scond)) { LOG(FATAL) << "Bind have an unmet assertion: " << cond << ", " << " on argument " << arg_name; @@ -65,10 +66,10 @@ bool ArgBinder::Bind_(const PrimExpr& arg, } return true; } else { - BinderAddAssert(it->second == value, arg_name, &asserts_); + BinderAddAssert(&analyzer_, it->second == value, arg_name, &asserts_); } } else { - BinderAddAssert(arg == value, arg_name, &asserts_); + BinderAddAssert(&analyzer_, arg == value, arg_name, &asserts_); } return false; } @@ -121,7 +122,8 @@ void ArgBinder::BindBuffer(const Buffer& arg, PrimExpr offset = value->elem_offset; PrimExpr factor = make_const(offset.dtype(), arg->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(truncmod(offset, factor) == zero, + BinderAddAssert(&analyzer_, + truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_); } } @@ -130,7 +132,7 @@ void ArgBinder::BindBuffer(const Buffer& arg, CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch"; size_t diff = value->shape.size() - arg->shape.size(); for (size_t i = 0; i < diff; ++i) { - CHECK(is_one(Simplify(value->shape[i]))) + CHECK(is_one(analyzer_.Simplify(value->shape[i]))) << "Argument " << arg_name << " shape mismatch" << arg->shape << " vs " << value->shape; } @@ -269,7 +271,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, value = tvm::if_then_else(is_null, stride, value); value = tvm::if_then_else(buffer->shape[k] == 1, 0, value); Bind_(buffer->strides[k], value, field_name.str(), true); - stride = Simplify(stride * buffer->shape[k]); + stride = analyzer_.Simplify(stride * buffer->shape[k]); } } else { std::ostringstream stride_null_err_msg; @@ -304,7 +306,9 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, PrimExpr offset = buffer->elem_offset; PrimExpr factor = make_const(offset.dtype(), buffer->offset_factor); PrimExpr zero = make_zero(offset.dtype()); - BinderAddAssert(truncmod(offset, factor) == zero, arg_name + ".elem_offset", &asserts_); + BinderAddAssert(&analyzer_, + truncmod(offset, factor) == zero, + arg_name + ".elem_offset", &asserts_); } } } diff --git a/src/tir/pass/arg_binder.h b/src/tir/pass/arg_binder.h index dfeb82853529..0ff51e8c98f1 100644 --- a/src/tir/pass/arg_binder.h +++ b/src/tir/pass/arg_binder.h @@ -26,6 +26,8 @@ #include #include +#include + #include #include #include @@ -153,6 +155,8 @@ class ArgBinder { Map def_handle_dtype_; /*! \brief asserts generated */ std::vector asserts_; + /*! \brief internal analyzer. */ + arith::Analyzer analyzer_; }; } // namespace tir } // namespace tvm diff --git a/src/tir/pass/ffi_api.cc b/src/tir/pass/ffi_api.cc index 60b5bd9a7f9c..ea762453cc9c 100644 --- a/src/tir/pass/ffi_api.cc +++ b/src/tir/pass/ffi_api.cc @@ -32,40 +32,6 @@ namespace tvm { namespace tir { -TVM_REGISTER_GLOBAL("ir_pass.Simplify") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - if (args.size() > 1) { - *ret = Simplify(args[0].operator Stmt(), args[1]); - } else { - *ret = Simplify(args[0].operator Stmt()); - } - } else { - if (args.size() > 1) { - *ret = Simplify(args[0].operator PrimExpr(), args[1]); - } else { - *ret = Simplify(args[0].operator PrimExpr()); - } - } - }); - -TVM_REGISTER_GLOBAL("ir_pass.CanonicalSimplify") -.set_body([](TVMArgs args, TVMRetValue *ret) { - if (args[0].IsObjectRef()) { - if (args.size() > 1) { - *ret = CanonicalSimplify(args[0].operator Stmt(), args[1]); - } else { - *ret = CanonicalSimplify(args[0].operator Stmt()); - } - } else { - if (args.size() > 1) { - *ret = CanonicalSimplify(args[0].operator PrimExpr(), args[1]); - } else { - *ret = CanonicalSimplify(args[0].operator PrimExpr()); - } - } - }); - TVM_REGISTER_GLOBAL("ir_pass.Substitute") .set_body([](TVMArgs args, TVMRetValue *ret) { if (args[0].IsObjectRef()) { diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 5e40eb2d9025..d409ffc4a15d 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include "../../arith/pattern_match.h" @@ -125,7 +126,7 @@ class CopyIntrinInjector : public StmtMutator { DataType t = loop_vars[i].dtype(); PrimExpr svalue = src_shape[i]; if (min_value.defined()) { - PrimExpr pbefore = Simplify(MaxNode::make(min_value, make_zero(t))); + PrimExpr pbefore = analyzer_.Simplify(MaxNode::make(min_value, make_zero(t))); src_elem_offset = src_elem_offset + pbefore * load_strides[i]; svalue = svalue - pbefore; pad_before.push_back(pbefore); @@ -133,16 +134,16 @@ class CopyIntrinInjector : public StmtMutator { pad_before.push_back(make_zero(t)); } if (max_value.defined()) { - PrimExpr pafter = Simplify(MaxNode::make(loops[i]->extent - max_value - make_const(t, 1), - make_zero(t))); + PrimExpr pafter = analyzer_.Simplify( + max(loops[i]->extent - max_value - make_const(t, 1), make_zero(t))); svalue = svalue - pafter; pad_after.push_back(pafter); } else { pad_after.push_back(make_zero(t)); } - src_shape.Set(i, Simplify(svalue)); + src_shape.Set(i, analyzer_.Simplify(svalue)); } - src_elem_offset = Simplify(src_elem_offset); + src_elem_offset = analyzer_.Simplify(src_elem_offset); } CHECK_EQ(load_strides.size(), store_strides.size()); CHECK_EQ(load_strides.size(), loop_var_size + 1); @@ -189,6 +190,8 @@ class CopyIntrinInjector : public StmtMutator { const PackedFunc& flower_copy_fromto_; // Storage scope std::unordered_map storage_scope_; + // arith analyzer + arith::Analyzer analyzer_; }; Stmt InjectCopyIntrin(Stmt stmt, diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 9fa72303e2d8..a77d529e7764 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -24,11 +24,9 @@ #include #include #include +#include #include #include - -#include - #include "../pass/ir_util.h" #include "../../runtime/thread_storage_scope.h" @@ -123,8 +121,8 @@ class StorageAccessInfoLower : public StmtExprMutator { int dtype_bits = dtype.bits() * dtype.lanes(); CHECK_EQ(info->unit_bits % dtype_bits, 0); return cast(ptr_type, - tir::Simplify(offset / make_const( - offset.dtype(), info->unit_bits / dtype_bits))); + analyzer_.Simplify(offset / make_const( + offset.dtype(), info->unit_bits / dtype_bits))); } // The storage entry. struct StorageEntry { @@ -137,6 +135,8 @@ class StorageAccessInfoLower : public StmtExprMutator { }; // The storage scope of each buffer std::unordered_map storage_info_; + // analyzer + arith::Analyzer analyzer_; }; Stmt LowerStorageAccessInfo(Stmt stmt) { diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 655a0074c7fd..85744d11a9e5 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -313,6 +313,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } return ret; } + // The local buffer index. + PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) { + if (!is_zero(group_index)) { + return analyzer_.Simplify(group_index * reduce_extent + reduce_index); + } else { + return reduce_index; + } + } // sync thread op. static Stmt SyncThread(const std::string& sync) { return EvaluateNode::make( @@ -320,14 +328,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { {StringImmNode::make(sync)}, CallNode::Intrinsic)); } - // The local buffer index. - static PrimExpr BufIndex(PrimExpr reduce_index, PrimExpr group_index, int reduce_extent) { - if (!is_zero(group_index)) { - return tir::Simplify(group_index * reduce_extent + reduce_index); - } else { - return reduce_index; - } - } // The warp size of the device. int warp_size_{1}; @@ -338,6 +338,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { std::unordered_map load_remap_; // Allocate remap std::unordered_map alloc_remap_; + // Internal analyzer + arith::Analyzer analyzer_; }; namespace transform { diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 0aee3c284422..ac08e6fd07a4 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -371,7 +371,6 @@ class WarpMemoryRewriter : private StmtMutator { BindVarBoundInfo binder(&analyzer_); binder(stmt); stmt = operator()(std::move(stmt)); - stmt = CanonicalSimplify(stmt); return stmt; } diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index ecfa25e28975..1e4fd73f6d0c 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -98,36 +98,6 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } // namespace arith namespace tir { - -Stmt CanonicalSimplify(Stmt stmt, Map vrange) { - arith::Analyzer analyzer; - for (auto kv : vrange) { - analyzer.Bind(kv.first, kv.second); - } - return arith::StmtSimplifier(&analyzer).Simplify(std::move(stmt)); -} - -PrimExpr CanonicalSimplify(PrimExpr expr, Map vrange) { - arith::Analyzer analyzer; - for (auto kv : vrange) { - analyzer.Bind(kv.first, kv.second); - } - return analyzer.canonical_simplify(expr); -} - -PrimExpr Simplify(PrimExpr expr, Map vrange) { - arith::Analyzer analyzer; - for (auto kv : vrange) { - analyzer.Bind(kv.first, kv.second); - } - expr = analyzer.Simplify(expr); - return expr; -} - -Stmt Simplify(Stmt stmt, Map vrange) { - return CanonicalSimplify(std::move(stmt), vrange); -} - namespace transform { Pass Simplify() { diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index c13879c31c64..f960306f2ee8 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -625,7 +625,7 @@ class StoragePlanRewriter : public StmtExprMutator { if (!divided) { combo_size = combo_size + make_const(DataType::Int(32), 1); } - combo_size = tir::Simplify(combo_size); + combo_size = analyzer_.Simplify(combo_size); e->new_alloc = AllocateNode::make( e->alloc_var, alloc_type, {combo_size}, const_true(), EvaluateNode::make(0)); diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 27c39d4c18aa..9ff5429329c5 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -25,9 +25,10 @@ #include #include #include -#include #include #include +#include +#include #include #include #include @@ -160,7 +161,7 @@ class LoopUnroller : public StmtExprMutator { // returns the extent of the loop if it's a constant integer, otherwise return -1 int GetExtent(const ForNode* op) { // constant folding. - PrimExpr extent = tir::Simplify(op->extent); + PrimExpr extent = analyzer_.Simplify(op->extent); const IntImmNode *v1 = extent.as(); int value = -1; // integers that do not fit in int32_t are treated as symbolic, @@ -184,6 +185,8 @@ class LoopUnroller : public StmtExprMutator { int unroll_depth_{0}; // Number of total steps unrolled int step_count_{0}; + // analyzer + arith::Analyzer analyzer_; }; diff --git a/tests/cpp/ir_simplify_test.cc b/tests/cpp/arith_simplify_test.cc similarity index 79% rename from tests/cpp/ir_simplify_test.cc rename to tests/cpp/arith_simplify_test.cc index 69cf1298a320..f4c259fca342 100644 --- a/tests/cpp/ir_simplify_test.cc +++ b/tests/cpp/arith_simplify_test.cc @@ -19,35 +19,38 @@ #include #include -#include +#include #include -TEST(IRSIMPLIFY, MinMax) { +TEST(Simplify, MinMax) { + tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); auto e1 = (tvm::max(x, 1) - tvm::max(x, 1)) ; - auto e1s = tvm::tir::CanonicalSimplify(e1); + auto e1s = ana.canonical_simplify(e1); CHECK(tvm::tir::is_zero(e1s)); auto e2 = (x * tvm::min(x, 1)) - (x * tvm::min(x, 1)); - auto e2s = tvm::tir::CanonicalSimplify(e2); + auto e2s = ana.canonical_simplify(e2); CHECK(tvm::tir::is_zero(e2s)); } -TEST(IRSIMPLIFY, Mul) { +TEST(Simplify, Mul) { + tvm::arith::Analyzer ana; auto x = tvm::te::var("x"); auto e = (x * x) - (x * x) ; - auto es = tvm::tir::CanonicalSimplify(e); + auto es = ana.canonical_simplify(e); CHECK(tvm::tir::is_zero(es)); } -TEST(IRSIMPLIFY, Mod) { +TEST(Simplify, Mod) { + tvm::arith::Analyzer ana; auto x = tvm::Integer(10); auto y = tvm::Integer(12); // Mod::make is used instead of % to avoid constant folding during // calling operator%(x,y). Mod::make doesn't try constant folding, // and therefore, the constant folding will be attempted in CanonicalSimplify - auto mod = tvm::tir::CanonicalSimplify(tvm::tir::ModNode::make(x, y)); - auto es = tvm::tir::CanonicalSimplify(mod - x); + auto mod = ana.canonical_simplify(tvm::tir::ModNode::make(x, y)); + auto es = ana.canonical_simplify(mod - x); CHECK(tvm::tir::is_zero(es)); } int main(int argc, char ** argv) { diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 5baabd16c615..6efb67b19bad 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -18,13 +18,6 @@ from tvm import te -def assert_expr_equal(a, b): - res = tvm.tir.ir_pass.Simplify(a - b) - equal = isinstance(res, tvm.tir.IntImm) and res.value == 0 - if not equal: - raise ValueError("{} and {} are not equal".format(a, b)) - - def test_deduce(): a = te.var('a') b = te.var('b') @@ -41,32 +34,32 @@ def test_deduce(): e0 = (-b)*a+c-d res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = fdiv(d - c, b*-1) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) e0 = d*a+c-d res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {}) ans0 = fdiv(d-c, d) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) # expression containing variable a is on rhs res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res0.max_value, ans0) + tvm.testing.assert_prim_expr_equal(res0.max_value, ans0) e1 = (a*4+b < c) res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) ans1 = fdiv(c-1-b, 4) - assert_expr_equal(res1.max_value, ans1) + tvm.testing.assert_prim_expr_equal(res1.max_value, ans1) # expression containing variable a is on rhs e1 = (c > a*4+b) res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res1.max_value, ans1) + tvm.testing.assert_prim_expr_equal(res1.max_value, ans1) e2 = (tvm.te.max(5, a * 4) < 0) @@ -83,15 +76,15 @@ def test_deduce(): e3 = (-b)+a*c-d res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) ans3 = fdiv(2,c)+1 - assert str(tvm.tir.ir_pass.Simplify(res3.min_value)) == str(ans3) + tvm.testing.assert_prim_expr_equal(res3.min_value, ans3) res3 = tvm.arith.deduce_bound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s}) - assert str(tvm.tir.ir_pass.Simplify(res3.min_value)) == str(ans3) + tvm.testing.assert_prim_expr_equal(res3.min_value, ans3) # tests for `EQ` op res4 = tvm.arith.deduce_bound(a, a == b, {}, {}) - assert_expr_equal(res4.max_value, b) - assert_expr_equal(res4.min_value, b) + tvm.testing.assert_prim_expr_equal(res4.max_value, b) + tvm.testing.assert_prim_expr_equal(res4.min_value, b) # Unsatisfiable `EQ`, variable as one of the Operand res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s}) @@ -100,20 +93,20 @@ def test_deduce(): # variable `a` on the RHS side res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {}) - assert_expr_equal(res6.max_value, 10) - assert_expr_equal(res6.min_value, 10) + tvm.testing.assert_prim_expr_equal(res6.max_value, 10) + tvm.testing.assert_prim_expr_equal(res6.min_value, 10) # Add, Sub in `EQ` e4 = ((a - c) == (b + d)) ans4 = (b + d + c) res7 = tvm.arith.deduce_bound(a, e4, {b: b_s, c: c_s, d: d_s}, {}) - assert_expr_equal(res7.max_value, ans4) - assert_expr_equal(res7.min_value, ans4) + tvm.testing.assert_prim_expr_equal(res7.max_value, ans4) + tvm.testing.assert_prim_expr_equal(res7.min_value, ans4) # Satisfiable Mul in `EQ` with negative sign res8 = tvm.arith.deduce_bound(a, (5 * a == -10), {}, {}) - assert_expr_equal(res8.max_value, -2) - assert_expr_equal(res8.min_value, -2) + tvm.testing.assert_prim_expr_equal(res8.max_value, -2) + tvm.testing.assert_prim_expr_equal(res8.min_value, -2) # Unsatisfiable Mul in `EQ` e5 = (4 * a == b) @@ -158,21 +151,22 @@ def test_basic(a1, a2, coff): res1 = tvm.arith.deduce_bound(a, e0<17, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1 + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) < 17, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32") < e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1 + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) > 17, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(17, "int32")>= e0, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1 + + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) <= 17, True) res1 = tvm.arith.deduce_bound(a, e0>=17, {b: b_s}, {b: b_s}) [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1 + tvm.testing.assert_prim_expr_equal((x * coff + 3 + y) >= 17, True) test_basic(0, 4, 4) test_basic(1, 5, 4) @@ -190,21 +184,21 @@ def test_complex(a1, a2, coff): res1 = tvm.arith.deduce_bound(a, e0<63, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) < 63, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32")>= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) <= 63, True) res1 = tvm.arith.deduce_bound(a, e0>63, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) > 63, True) # expression containing variable a is on rhs res1 = tvm.arith.deduce_bound(a, tvm.tir.const(63, "int32") <= e0, {b: b_s}, {b: b_s}) [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value] - assert (tvm.tir.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1 + tvm.testing.assert_prim_expr_equal(((x*3 + t* coff) * 4) >= 63, True) test_complex(0, 4, 4) test_complex(0, 4, -4) diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py index d6953713f14b..129237a8c58b 100644 --- a/tests/python/unittest/test_arith_detect_clip_bound.py +++ b/tests/python/unittest/test_arith_detect_clip_bound.py @@ -23,15 +23,15 @@ def test_basic(): c = te.var("c") m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a]) - assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], b * 6 - 1) assert m[0].value == 2 m = tvm.arith.detect_clip_bound(tvm.tir.all(a * 1 < b * 6, a - 1 > 0), [a, b]) assert len(m) == 0 m = tvm.arith.detect_clip_bound(tvm.tir.all(a + 10 * c <= 20, b - 1 > 0), [a, b]) - assert tvm.tir.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0 - assert tvm.tir.ir_pass.Simplify(m[2] - 2).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], 20 - 10 * c) + tvm.testing.assert_prim_expr_equal(m[2], 2) if __name__ == "__main__": diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py index 278581d0cacd..82153ab5207e 100644 --- a/tests/python/unittest/test_arith_detect_linear_equation.py +++ b/tests/python/unittest/test_arith_detect_linear_equation.py @@ -22,14 +22,14 @@ def test_basic(): b = te.var("b") m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a]) assert m[0].value == 4 - assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7) m = tvm.arith.detect_linear_equation(a * 4 * (a+1) + b * 6 + 7, [a]) assert len(m) == 0 m = tvm.arith.detect_linear_equation(a * 4 + (a+1) + b * 6 + 7, [a]) assert m[0].value == 5 - assert tvm.tir.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0 + tvm.testing.assert_prim_expr_equal(m[1], b * 6 + 7 + 1) m = tvm.arith.detect_linear_equation(a * b + 7, [a]) assert m[0] == b @@ -39,13 +39,15 @@ def test_basic(): m = tvm.arith.detect_linear_equation(b * 7, []) assert len(m) == 1 - assert tvm.tir.ir_pass.Simplify(m[0] - b * 7).value == 0 + tvm.testing.assert_prim_expr_equal(m[0], b * 7) def test_multivariate(): v = [te.var("v%d" % i) for i in range(4)] b = te.var("b") m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v) - assert(tvm.tir.analysis.expr_deep_equal(tvm.tir.ir_pass.Simplify(m[0]), b + 5)) + + tvm.testing.assert_prim_expr_equal(m[0], b + 5) + assert(m[1].value == 8) m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v) @@ -61,11 +63,12 @@ def test_multivariate(): m = tvm.arith.detect_linear_equation((v[0] - v[1]), [v[2]]) assert(m[0].value == 0) - assert(tvm.tir.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0) + + tvm.testing.assert_prim_expr_equal(m[1], v[0] - v[1]) m = tvm.arith.detect_linear_equation((v[0] - v[1]), []) assert(len(m) == 1) - assert(tvm.tir.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0) + tvm.testing.assert_prim_expr_equal(m[0], v[0] - v[1]) if __name__ == "__main__": test_basic() diff --git a/tests/python/unittest/test_arith_solve_linear_system.py b/tests/python/unittest/test_arith_solve_linear_system.py index 45f8fc10aaf0..645b1a2b537c 100644 --- a/tests/python/unittest/test_arith_solve_linear_system.py +++ b/tests/python/unittest/test_arith_solve_linear_system.py @@ -55,12 +55,13 @@ def check_bruteforce(bool_expr, vranges, cond=None): counterex = ", ".join([v + " = " + str(i) for v, i in counterex]) raise AssertionError("Expression {}\nis not true on {}\n" "Counterexample: {}" - .format(tir.ir_pass.CanonicalSimplify(bool_expr), vranges, counterex)) + .format(tir.arith.Analyzer().simplify(bool_expr), vranges, counterex)) def check_solution(solution, vranges={}): """Check that solution is a bijective transformation""" def _check_forward(constraints1, constraints2, varmap, backvarmap): + ana = tvm.arith.Analyzer() all_vranges = vranges.copy() all_vranges.update({v: r for v, r in constraints1.ranges.items()}) @@ -68,7 +69,7 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): cond_on_vars = tir.const(1, 'bool') for v in constraints1.variables: # variable mapping is consistent - v_back = tir.ir_pass.Simplify(tir.ir_pass.Substitute(varmap[v], backvarmap)) + v_back = ana.simplify(tir.ir_pass.Substitute(varmap[v], backvarmap)) cond_on_vars = te.all(cond_on_vars, v == v_back) # Also we have to check that the new relations are true when old relations are true cond_subst = tir.ir_pass.Substitute( @@ -80,7 +81,7 @@ def _check_forward(constraints1, constraints2, varmap, backvarmap): range_cond = te.all(v >= r.min, v < r.min + r.extent) range_cond = tir.ir_pass.Substitute(range_cond, backvarmap) cond_subst = te.all(cond_subst, range_cond) - cond_subst = tir.ir_pass.Simplify(cond_subst) + cond_subst = ana.simplify(cond_subst) check_bruteforce(te.all(cond_subst, cond_on_vars), all_vranges, cond=te.all(tir.const(1, 'bool'), *constraints1.relations)) diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 5b4a1c92a7e4..8afd65330f98 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -25,7 +25,7 @@ def run_and_check(func, args, var_dict={}, target='llvm', sch=None, outs=None): def tvm_val_2_py_val(val): val = tvm.tir.ir_pass.Substitute(val, var_dict) - val = tvm.tir.ir_pass.Simplify(val) + val = tvm.arith.Analyzer().simplify(val) assert isinstance(val, (tvm.tir.IntImm,)) return val.value diff --git a/tests/python/unittest/test_te_schedule_bound_inference.py b/tests/python/unittest/test_te_schedule_bound_inference.py index edae527c0183..6b6c519c8fa3 100644 --- a/tests/python/unittest/test_te_schedule_bound_inference.py +++ b/tests/python/unittest/test_te_schedule_bound_inference.py @@ -139,19 +139,20 @@ def test_bound_fusesplit1(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) idxdiv = tvm.tir.indexdiv - assert(tvm.tir.ir_pass.Simplify( - bounds[A1.op.axis[0]].min - idxdiv(xo * split1, l)).value == 0) + tvm.testing.assert_prim_expr_equal( + bounds[A1.op.axis[0]].min, idxdiv(xo * split1, l)) expected_extent = (idxdiv((xo + 1) * split1 - 1, l) - idxdiv(xo * split1, l) + 1) for i in range(1, 6): for j in range(1, 6): for k in range(1, 6): vars = tvm.runtime.convert({split1: tvm.tir.const(i, "int32"), l: tvm.tir.const(j, "int32"), xo.var: tvm.tir.const(k, "int32")}) - comp_ext = tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value - exp_ext = tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(expected_extent, vars)).value - assert(comp_ext == exp_ext) + tvm.testing.assert_prim_expr_equal( + tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars), + tvm.tir.ir_pass.Substitute(expected_extent, vars) + ) - assert(tvm.tir.ir_pass.Simplify(bounds[A1.op.axis[1]].extent - l).value == 0) + tvm.testing.assert_prim_expr_equal(bounds[A1.op.axis[1]].extent, l) def test_bound_fusesplit2(): m = te.var("m") @@ -169,10 +170,10 @@ def test_bound_fusesplit2(): bounds = tvm.te.schedule.InferBound(s) assert isinstance(bounds, tvm.container.Map) vars = tvm.runtime.convert({xo.var: tvm.tir.const(5, "int32")}) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars)).value == 2) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars)).value == 3) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars)).value == 1) - assert(tvm.tir.ir_pass.Simplify(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars)).value == 3) + tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].min, vars), 2) + tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].min, vars), 3) + tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[0]].extent, vars), 1) + tvm.testing.assert_prim_expr_equal(tvm.tir.ir_pass.Substitute(bounds[A1.op.axis[1]].extent, vars), 3) def test_bound_warp(): diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index dafffed9bd44..ef5b3fd3a44e 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -105,9 +105,10 @@ def check(factor): assert tvm.ir.structural_equal(in_dom.items()[0][1][0].extent, factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[z], out_dom, in_dom, vadd) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(vadd.op.body[0])) + ana.simplify(body[0]), + ana.simplify(vadd.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [x, y, z]) @@ -139,9 +140,11 @@ def check(factor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() + assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -164,9 +167,10 @@ def check_rfactor(factor, rfactor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -188,9 +192,10 @@ def check_rfactor_no_reset(factor, rfactor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) @@ -213,9 +218,10 @@ def check_rfactor_no_reset_multi_reduction(factor, rfactor): assert tvm.ir.structural_equal(out_dom[y].min, yo * factor) fmatch = tvm.get_global_func("test.op.MatchTensorizeBody") body = fmatch(s[C], out_dom, in_dom, gemv) + ana = tvm.arith.Analyzer() assert tvm.ir.structural_equal( - tvm.tir.ir_pass.CanonicalSimplify(body[0]), - tvm.tir.ir_pass.CanonicalSimplify(gemv.op.body[0])) + ana.simplify(body[0]), + ana.simplify(gemv.op.body[0])) stmt = tvm.te.schedule.ScheduleOps(s, dom_map) tvm.lower(s, [A, B, C]) diff --git a/tests/python/unittest/test_tir_buffer.py b/tests/python/unittest/test_tir_buffer.py index fe23955017a0..7ee1e539204b 100644 --- a/tests/python/unittest/test_tir_buffer.py +++ b/tests/python/unittest/test_tir_buffer.py @@ -48,17 +48,14 @@ def test_buffer_access_ptr_offset(): n = te.size_var('n') Ab = tvm.tir.decl_buffer((m, n), "float32") aptr = Ab.access_ptr("rw", offset=100) - offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.ir.structural_equal(offset, 100) + tvm.testing.assert_prim_expr_equal(aptr.args[2], 100) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE v = te.size_var('int32') aptr = Ab.access_ptr("rw", offset=100 + 100 + v) - offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.ir.structural_equal(offset, 200 + v) + tvm.testing.assert_prim_expr_equal(aptr.args[2], 200 + v) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE aptr = Ab.access_ptr("rw", offset=tvm.tir.call_extern('int32', "test_call", 100 + 100 + v)) - offset = tvm.tir.ir_pass.Simplify(aptr.args[2]) - assert tvm.ir.structural_equal(offset, tvm.tir.call_extern('int32', "test_call", 200 + v)) + tvm.testing.assert_prim_expr_equal(aptr.args[2], tvm.tir.call_extern('int32', "test_call", 200 + v)) assert aptr.args[4].value == Buffer.READ | Buffer.WRITE @@ -80,8 +77,7 @@ def test_buffer_vload(): n = te.size_var('n') Ab = tvm.tir.decl_buffer((m, n), "float32", elem_offset=100) load = Ab.vload([2, 3]) - offset = tvm.tir.ir_pass.Simplify(load.index) - assert tvm.ir.structural_equal(offset, n * 2 + 103) + tvm.testing.assert_prim_expr_equal(load.index, n * 2 + 103) def test_buffer_index_merge_mult_mod(): diff --git a/tests/python/unittest/test_tir_pass_basic.py b/tests/python/unittest/test_tir_pass_basic.py index 228e0c52c435..23982873075f 100644 --- a/tests/python/unittest/test_tir_pass_basic.py +++ b/tests/python/unittest/test_tir_pass_basic.py @@ -17,16 +17,6 @@ import tvm from tvm import te -def test_simplify(): - tdiv = tvm.tir.truncdiv - tmod = tvm.tir.truncmod - x = te.var('x') - e1 = tvm.tir.ir_pass.Simplify(x + 2 + 1) - assert(tvm.ir.structural_equal(e1, x + 3)) - e2 = tvm.tir.ir_pass.Simplify(x * 3 + 5 * x) - assert(tvm.ir.structural_equal(e2, x * 8)) - e3 = tvm.tir.ir_pass.Simplify(x - tdiv(x, 3) * 3) - assert(tvm.ir.structural_equal(e3, tmod(x, 3))) def test_verify_ssa(): diff --git a/tests/python/unittest/test_tir_pass_decorate_device_scope.py b/tests/python/unittest/test_tir_pass_decorate_device_scope.py index 327cfd9ed548..9c58431158b9 100644 --- a/tests/python/unittest/test_tir_pass_decorate_device_scope.py +++ b/tests/python/unittest/test_tir_pass_decorate_device_scope.py @@ -31,8 +31,7 @@ def test_decorate_device(): s[A1].set_scope("shared") bounds = tvm.te.schedule.InferBound(s) - stmt = tvm.te.schedule.ScheduleOps(s, bounds) - stmt1 = tvm.tir.ir_pass.Simplify(stmt) + stmt1 = tvm.te.schedule.ScheduleOps(s, bounds) stmt2 = tvm.tir.ir_pass.DecorateDeviceScope(stmt1) assert isinstance(stmt2, tvm.tir.AttrStmt) assert stmt2.attr_key == "device_scope" diff --git a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py index 9d1641366d7d..887b8b0c2b75 100644 --- a/tests/python/unittest/test_tir_transform_inject_copy_intrin.py +++ b/tests/python/unittest/test_tir_transform_inject_copy_intrin.py @@ -57,7 +57,7 @@ def test_copy_pad(): mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): - assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 + tvm.testing.assert_prim_expr_equal(src.elem_offset, 0) assert pad_before[0].value == 1 assert pad_before[1].value == 0 assert pad_after[0].value == 1 @@ -82,18 +82,15 @@ def test_single_point_test(): mod = tvm.tir.transform.StorageFlatten(64)(mod) def cb(src, dst, pad_before, pad_after, pad_value): - assert tvm.tir.ir_pass.Simplify(src.elem_offset).value == 0 - assert tvm.tir.ir_pass.Simplify(dst.elem_offset).value == 0 - assert tvm.tir.ir_pass.Simplify(src.strides[0]).value == 1 - assert tvm.tir.ir_pass.Simplify(dst.strides[0]).value == 1 + tvm.testing.assert_prim_expr_equal(src.elem_offset, 0) + tvm.testing.assert_prim_expr_equal(dst.elem_offset, 0) + tvm.testing.assert_prim_expr_equal(src.strides[0], 1) + tvm.testing.assert_prim_expr_equal(dst.strides[0], 1) return tvm.tir.Evaluate(0) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body -def assert_expr_equal(a, b): - assert tvm.tir.ir_pass.Simplify(a - b).value == 0 - def test_copy_pad_split(): m = 4 * 3 A = te.placeholder((m, ), name="A") @@ -115,13 +112,13 @@ def test_copy_pad_split(): def cb(src, dst, pad_before, pad_after, pad_value): assert(dst.elem_offset.value == 0) - assert_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1) + tvm.testing.assert_prim_expr_equal(src.elem_offset, tvm.te.max(xo * 4, 1) - 1) rpad_before = tvm.te.max(1 - xo * 4, 0) rpad_after = tvm.te.max(xo * 4 - 7, 0) - assert_expr_equal(pad_before[0], rpad_before) - assert_expr_equal(pad_after[0], rpad_after) - assert_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) + tvm.testing.assert_prim_expr_equal(pad_before[0], rpad_before) + tvm.testing.assert_prim_expr_equal(pad_after[0], rpad_after) + tvm.testing.assert_prim_expr_equal(src.shape[0], 6 - rpad_before - rpad_after) return tvm.tir.Evaluate(0) stmt = tvm.tir.transform.InjectCopyIntrin("memcpy", cb)(mod)["main"].body diff --git a/tests/python/unittest/test_tir_transform_lower_intrin.py b/tests/python/unittest/test_tir_transform_lower_intrin.py index b2e984aae6fd..98c79e414226 100644 --- a/tests/python/unittest/test_tir_transform_lower_intrin.py +++ b/tests/python/unittest/test_tir_transform_lower_intrin.py @@ -22,10 +22,13 @@ def lower_intrin(params, stmt): """wrapper to call transformation in stmt""" lower_expr = isinstance(stmt, tvm.tir.PrimExpr) stmt = tvm.tir.Evaluate(stmt) if lower_expr else stmt - stmt = tvm.tir.ir_pass.CanonicalSimplify(stmt) - func = tvm.tir.PrimFunc(params, stmt).with_attr( - "target", tvm.target.create("llvm")) - func = tvm.tir.transform.LowerIntrin()(tvm.IRModule.from_expr(func))["main"] + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc(params, stmt).with_attr( + "target", tvm.target.create("llvm"))) + mod = tvm.transform.Sequential([ + tvm.tir.transform.Simplify(), + tvm.tir.transform.LowerIntrin() + ])(mod) + func = mod["main"] stmt = func.body return stmt.value if lower_expr else stmt.body diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py index e2bfeb009a11..57eb349f18df 100644 --- a/tests/python/unittest/test_tir_transform_storage_flatten.py +++ b/tests/python/unittest/test_tir_transform_storage_flatten.py @@ -51,9 +51,10 @@ def test_flatten_prefetch(): [_A], stmt, {A: _A}) mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.transform.Sequential([ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.Simplify()])(mod) stmt = mod["main"].body - stmt = tvm.tir.ir_pass.Simplify(stmt) assert stmt.extent.value == 2 assert isinstance(stmt.body, tvm.tir.For) assert stmt.body.extent.value == 2 @@ -74,9 +75,11 @@ def test_flatten_storage_align(): func = tvm.te.schedule.SchedulePostProcToPrimFunc([A, A2], stmt, None) mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.StorageFlatten(64)(mod) + mod = tvm.transform.Sequential([ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.Simplify()])(mod) + stmt = mod["main"].body - stmt = tvm.tir.ir_pass.Simplify(stmt) assert(stmt.body.extents[0].value == 17 * 8) @@ -103,11 +106,12 @@ def test_flatten_double_buffer(): mod = tvm.IRModule.from_expr( tvm.tir.PrimFunc([A, C], stmt)) - mod = tvm.tir.transform.StorageFlatten(64)(mod) - stmt = mod["main"].body + mod = tvm.transform.Sequential([ + tvm.tir.transform.StorageFlatten(64), + tvm.tir.transform.InjectDoubleBuffer(2), + tvm.tir.transform.Simplify()])(mod) - stmt = tvm.tir.ir_pass.InjectDoubleBuffer(stmt, 2) - stmt = tvm.tir.ir_pass.Simplify(stmt) + stmt = mod["main"].body assert isinstance(stmt.body.body, tvm.tir.Allocate) assert stmt.body.body.extents[0].value == 2 diff --git a/topi/include/topi/detail/constant_utils.h b/topi/include/topi/detail/constant_utils.h index 74be9453ae61..afa8833b3490 100644 --- a/topi/include/topi/detail/constant_utils.h +++ b/topi/include/topi/detail/constant_utils.h @@ -25,8 +25,9 @@ #define TOPI_DETAIL_CONSTANT_UTILS_H_ #include -#include +#include #include +#include #include #include @@ -119,7 +120,7 @@ inline bool EqualCheck(PrimExpr lhs, PrimExpr rhs) { bool result = expr_equal(lhs, rhs); if (!result) { PrimExpr zero(0); - result = expr_equal(tvm::tir::CanonicalSimplify(lhs-rhs), zero); + result = expr_equal(tvm::arith::Analyzer().Simplify(lhs-rhs), zero); } return result; } diff --git a/topi/include/topi/nn.h b/topi/include/topi/nn.h index a1ee8c1a5901..7569bb0a7ba2 100644 --- a/topi/include/topi/nn.h +++ b/topi/include/topi/nn.h @@ -26,8 +26,8 @@ #include #include +#include #include -#include #include #include @@ -184,6 +184,7 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, pad_after.push_back(pad_before[i]); } } + arith::Analyzer analyzer; CHECK_GE(pad_before.size(), 1); CHECK_EQ(pad_before.size(), pad_after.size()); tvm::Array output_shape; @@ -200,13 +201,14 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, output_shape.push_back(t->shape[i]); } else { output_shape.push_back( - tvm::tir::Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); + analyzer.Simplify(t->shape[i] + pad_before_int32[i] + pad_after_int32[i])); } } if (!pad_value.defined()) { pad_value = tvm::tir::make_const(t->dtype, 0); } + auto l = [&](tvm::Array ovars) { tvm::Array indices; tvm::Array sel; @@ -223,7 +225,7 @@ inline tvm::te::Tensor pad(const tvm::te::Tensor& t, indices.push_back(ovars[i]); } if (!topi::detail::EqualCheck(pad_after_int32[i], 0)) { - sel.push_back(tvm::tir::Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); + sel.push_back(analyzer.Simplify(ovars[i] < pad_before_int32[i] + t->shape[i])); } if (pad_mode == "edge") { pad_idx.push_back(tvm::if_then_else( diff --git a/topi/include/topi/nn/bnn.h b/topi/include/topi/nn/bnn.h index 6bda65317706..c69fc5406e33 100644 --- a/topi/include/topi/nn/bnn.h +++ b/topi/include/topi/nn/bnn.h @@ -25,7 +25,7 @@ #define TOPI_NN_BNN_H_ #include -#include +#include #include #include @@ -55,11 +55,12 @@ inline tvm::te::Tensor binarize_pack(const tvm::te::Tensor& data, CHECK_EQ(GetConstInt(ishape[axis]) % 32, 0) << "binarize_pack: axis size must be a multiple of 32"; + arith::Analyzer analyzer; auto n = ishape.size(); Array oshape; for (size_t i = 0; i < n; ++i) { oshape.push_back(i == static_cast(axis) ? - tvm::tir::Simplify(indexdiv(ishape[i], 32)) : + analyzer.Simplify(indexdiv(ishape[i], 32)) : ishape[i]); } diff --git a/topi/include/topi/nn/dilate.h b/topi/include/topi/nn/dilate.h index a67bf3a300b2..32ee1392e846 100644 --- a/topi/include/topi/nn/dilate.h +++ b/topi/include/topi/nn/dilate.h @@ -25,7 +25,7 @@ #define TOPI_NN_DILATE_H_ #include -#include +#include #include #include @@ -75,8 +75,9 @@ inline Tensor dilate(const Tensor& x, << ") must match dimension of x (" << n << ")"; Array out_shape; + arith::Analyzer analyzer; for (size_t i = 0; i < n; ++i) { - out_shape.push_back(tvm::tir::Simplify( + out_shape.push_back(analyzer.Simplify( (x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1))); } diff --git a/topi/include/topi/nn/pooling.h b/topi/include/topi/nn/pooling.h index 20b7b246317b..324ecadfe95a 100644 --- a/topi/include/topi/nn/pooling.h +++ b/topi/include/topi/nn/pooling.h @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -102,10 +102,10 @@ inline Tensor pool_impl(const Tensor& x, Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); - - auto out_height = tvm::tir::Simplify( + arith::Analyzer analyzer; + auto out_height = analyzer.Simplify( indexdiv(height - kernel_height + pad_top + pad_bottom, stride_height) + 1); - auto out_width = tvm::tir::Simplify( + auto out_width = analyzer.Simplify( indexdiv(width - kernel_width + pad_left + pad_right, stride_width) + 1); auto dheight = tvm::te::reduce_axis(Range(0, kernel_height)); @@ -212,11 +212,11 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, Array pad_after(std::vector(x->shape.size(), 0)); pad_after.Set(height_axis, pad_bottom); pad_after.Set(width_axis, pad_right); - + arith::Analyzer analyzer; auto out_height = - tvm::tir::Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1); + analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1); auto out_width = - tvm::tir::Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1); + analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1); auto dheight = tvm::te::reduce_axis(Range(0, kernel_height)); auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width)); @@ -711,7 +711,8 @@ inline Tensor pool_impl_nd(const Tensor& x, pad_before.Set(ii, pad_head[i]); pad_after.Set(ii, pad_tail[i]); - auto out_dim = tvm::tir::Simplify( + arith::Analyzer analyzer; + auto out_dim = analyzer.Simplify( indexdiv(x->shape[ii] - kernel[i] + pad_head[i] + pad_tail[i], stride[i]) + 1); out_shape.Set(ii, out_dim); diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h index 431ace5bc11e..0609020b5c81 100644 --- a/topi/include/topi/transform.h +++ b/topi/include/topi/transform.h @@ -375,12 +375,12 @@ inline Tensor concatenate(const Array& inputs, for (auto t : inputs) { axis_sizes.push_back(t->shape[axis]); } - + arith::Analyzer analyzer; PrimExpr join_size = axis_sizes[0]; for (size_t i = 1; i < axis_sizes.size(); ++i) { join_size += axis_sizes[i]; } - join_size = tvm::tir::Simplify(join_size); + join_size = analyzer.Simplify(join_size); Array out_shape; for (size_t i = 0; i < inputs[0]->shape.size(); ++i) { out_shape.push_back(i == static_cast(axis) ? join_size : inputs[0]->shape[i]); diff --git a/topi/python/topi/cuda/depthwise_conv2d.py b/topi/python/topi/cuda/depthwise_conv2d.py index db9da844e3af..b7cb32d58d01 100644 --- a/topi/python/topi/cuda/depthwise_conv2d.py +++ b/topi/python/topi/cuda/depthwise_conv2d.py @@ -167,7 +167,7 @@ def _schedule(temp, Filter, DepthwiseConv2d): b, h, w, c = s[Output].op.axis # num_thread here could be 728, it is larger than cuda.max_num_threads - num_thread = tvm.tir.ir_pass.Simplify(temp.shape[3]).value + num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value target = tvm.target.Target.current() if target and (target.target_name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads diff --git a/topi/python/topi/intel_graphics/depthwise_conv2d.py b/topi/python/topi/intel_graphics/depthwise_conv2d.py index a54941315a1a..650809985279 100644 --- a/topi/python/topi/intel_graphics/depthwise_conv2d.py +++ b/topi/python/topi/intel_graphics/depthwise_conv2d.py @@ -168,7 +168,7 @@ def _schedule(temp, Filter, DepthwiseConv2d): b, h, w, c = s[Output].op.axis # num_thread here could be 728, it is larger than cuda.max_num_threads - num_thread = tvm.tir.ir_pass.Simplify(temp.shape[3]).value + num_thread = tvm.arith.Analyzer().simplify(temp.shape[3]).value target = tvm.target.Target.current() if target and (target.target_name not in ["cuda", "nvptx"]): num_thread = target.max_num_threads diff --git a/topi/python/topi/nn/dilate.py b/topi/python/topi/nn/dilate.py index f628fadee96e..ebcf478033fb 100644 --- a/topi/python/topi/nn/dilate.py +++ b/topi/python/topi/nn/dilate.py @@ -45,9 +45,9 @@ def dilate(data, strides, name="DilatedInput"): if len(strides) != n: raise ValueError("data dimension and strides size dismatch : %d vs %d" % ( n, len(strides))) - + ana = tvm.arith.Analyzer() out_shape = tuple( - tvm.tir.ir_pass.Simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n)) + ana.simplify((data.shape[i] - 1) * strides[i] + 1) for i in range(n)) def _dilate(*indices): not_zero = [] diff --git a/topi/python/topi/nn/pad.py b/topi/python/topi/nn/pad.py index 8fe53374f2b5..b298a0a2bb95 100644 --- a/topi/python/topi/nn/pad.py +++ b/topi/python/topi/nn/pad.py @@ -55,9 +55,9 @@ def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"): if len(pad_after) != n: raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % ( n, len(pad_before))) + ana = tvm.arith.Analyzer() out_shape = tuple( - tvm.tir.ir_pass.Simplify( - (data.shape[i] + pad_before[i] + pad_after[i])) for i in range(n)) + ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n)) pad_value = (pad_value if isinstance(pad_value, tvm.tir.PrimExpr) else tvm.tir.const(pad_value, data.dtype)) def _pad(*indices): @@ -115,8 +115,9 @@ def mirror_pad(data, if len(pad_after) != n: raise ValueError("Input dimension and pad_after dismatch : %d vs %d" % (n, len(pad_before))) + ana = tvm.arith.Analyzer() out_shape = tuple( - tvm.tir.ir_pass.Simplify((data.shape[i] + pad_before[i] + pad_after[i])) + ana.simplify(data.shape[i] + pad_before[i] + pad_after[i]) for i in range(n)) assert mode in ('SYMMETRIC', 'REFLECT') mode = int(mode == 'SYMMETRIC') diff --git a/topi/python/topi/util.py b/topi/python/topi/util.py index 50a6a36edc46..cc437325e0d6 100644 --- a/topi/python/topi/util.py +++ b/topi/python/topi/util.py @@ -101,7 +101,8 @@ def get_const_int(expr): if isinstance(expr, Integral): return expr if not isinstance(expr, tvm.tir.IntImm): - expr = tvm.tir.ir_pass.Simplify(expr) + ana = tvm.arith.Analyzer() + expr = ana.simplify(expr) if not isinstance(expr, tvm.tir.IntImm): raise ValueError("Expect value to be constant int") return int(expr.value) @@ -123,7 +124,8 @@ def get_const_float(expr): if isinstance(expr, float): return float(expr) if not isinstance(expr, tvm.tir.FloatImm): - expr = tvm.tir.ir_pass.Simplify(expr) + ana = tvm.arith.Analyzer() + expr = ana.simplify(expr) if not isinstance(expr, tvm.tir.FloatImm): raise ValueError("Expect value to be constant float") return float(expr.value) @@ -145,7 +147,8 @@ def equal_const_int(expr, value): if isinstance(expr, Integral): return expr == value if not isinstance(expr, tvm.tir.IntImm): - expr = tvm.tir.ir_pass.Simplify(expr) + ana = tvm.arith.Analyzer() + expr = ana.simplify(expr) if not isinstance(expr, tvm.tir.IntImm): return False return expr.value == value @@ -165,11 +168,13 @@ def get_const_tuple(in_tuple): The output. """ ret = [] + ana = None for elem in in_tuple: if isinstance(elem, (tvm.tir.Var, tvm.tir.expr.Any)): ret.append(elem) elif not isinstance(elem, (tvm.tir.IntImm, int)): - elem = tvm.tir.ir_pass.Simplify(elem) + ana = tvm.arith.Analyzer() if ana is None else ana + elem = ana.simplify(elem) if not isinstance(elem, tvm.tir.IntImm): ret.append(elem) else: @@ -208,7 +213,7 @@ def simplify(expr): out : Expr or int The simplified output """ - return tvm.tir.ir_pass.Simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr + return tvm.arith.Analyzer().simplify(expr) if isinstance(expr, tvm.tir.PrimExpr) else expr def ravel_index(indices, shape): diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index c2684d159b98..9836d133ceb7 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -364,6 +364,7 @@ def _fold_buffer_dim(buf, scope, elem_block): shape.append(1) strides.append(elem_block) + analyzer = tvm.arith.Analyzer() while base < ndim + 1: x_size = 1 x_stride = buf.strides[ndim - base] @@ -378,7 +379,7 @@ def _fold_buffer_dim(buf, scope, elem_block): break x_size = x_size * buf.shape[k] next_base = i + 1 - shape.append(tvm.tir.ir_pass.Simplify(x_size)) + shape.append(analyzer.simplify(x_size)) strides.append(x_stride) assert next_base != base base = next_base @@ -769,10 +770,11 @@ def inject_alu_intrin(stmt_in): """ env = get_env() idxm = tvm.tir.indexmod + analyzer = tvm.arith.Analyzer() def _do_fold(stmt): def _equal(x, y): - return tvm.ir.structural_equal(tvm.tir.ir_pass.Simplify(x - y), 0) + return tvm.ir.structural_equal(analyzer.simplify(x - y), 0) def _flatten_loop(src_coeff, dst_coeff, extents): src_coeff = list(src_coeff) @@ -791,7 +793,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): next_ext = extents.pop() if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext): - vext = tvm.tir.ir_pass.Simplify(vext * next_ext) + vext = analyzer.simplify(vext * next_ext) else: rev_src_coeff.append(vsrc) rev_dst_coeff.append(vdst) @@ -851,7 +853,7 @@ def _flatten_loop(src_coeff, dst_coeff, extents): if loop_body.value.name == 'shift_left': alu_opcode = env.dev.ALU_OPCODE_SHR lhs = loop_body.value.args[0] - rhs = tvm.tir.ir_pass.Simplify(-loop_body.value.args[1]) + rhs = analyzer.simplify(-loop_body.value.args[1]) elif loop_body.value.name == 'shift_right': alu_opcode = env.dev.ALU_OPCODE_SHR lhs = loop_body.value.args[0] @@ -914,10 +916,10 @@ def _flatten_loop(src_coeff, dst_coeff, extents): assert len(dst_coeff) > 1 assert len(extents) != 0 assert tvm.ir.structural_equal( - tvm.tir.ir_pass.Simplify( + analyzer.simplify( idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) assert tvm.ir.structural_equal( - tvm.tir.ir_pass.Simplify( + analyzer.simplify( idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0) assert tvm.ir.structural_equal(src_coeff[-2], 1) assert tvm.ir.structural_equal(dst_coeff[-2], 1) @@ -942,9 +944,9 @@ def _flatten_loop(src_coeff, dst_coeff, extents): src_coeff.append(src_offset) dst_coeff.append(dst_offset) src_coeff = [ - tvm.tir.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff] + analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff] dst_coeff = [ - tvm.tir.ir_pass.Simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff] + analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff] # Flatten the outer loops if extents: