diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index e5b2c2b6957c..bd590cd144d8 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -141,6 +141,12 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); */ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); +/*! + * \brief Calculate the expresion complexity based on number of symbols it contains. + * \param expr The expr to be calculated. + */ +TVM_DLL size_t CalculateExprComplexity(const PrimExpr& expr); + // Pass variants of verification analysis // directly throws RuntimeError when verification fails. namespace transform { diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index ba549959ac98..732b18debe79 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -544,7 +544,12 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { explicit Impl(Analyzer* parent) : Rewriter(parent) {} PrimExpr CanonicalSimplify(PrimExpr expr) { - expr = operator()(expr); + // same as kMaxFusedOps. + // avoid long compile time of tflite quantized model + constexpr static size_t kMaxPrimOps = 256; + if (CalculateExprComplexity(expr) < kMaxPrimOps) { + expr = operator()(expr); + } return expr; } diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index a58e4433dadd..491351622cce 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1598,10 +1598,15 @@ PrimExpr RewriteSimplifier::operator()(const PrimExpr& expr) { // Run simplification in post order PrimExpr res = expr; int max_iter = 2; - for (int i = 0; i < max_iter; ++i) { - PrimExpr new_expr = impl_->operator()(res); - if (new_expr.same_as(res)) return res; - res = new_expr; + // same as kMaxFusedOps. + // avoid long compile time of tflite quantized model + constexpr static size_t kMaxPrimOps = 256; + if (CalculateExprComplexity(expr) < kMaxPrimOps) { + for (int i = 0; i < max_iter; ++i) { + PrimExpr new_expr = impl_->operator()(res); + if (new_expr.same_as(res)) return res; + res = new_expr; + } } return res; } diff --git a/src/arith/solve_linear_inequality.cc b/src/arith/solve_linear_inequality.cc index dd9044833546..6aad5b7b0a25 100644 --- a/src/arith/solve_linear_inequality.cc +++ b/src/arith/solve_linear_inequality.cc @@ -39,58 +39,9 @@ namespace arith { using namespace tvm::runtime; using namespace tvm::tir; -#define PLUS_ONE(OP) \ - void VisitExpr_(const OP* op) final { num_symbols_++; } - -#define PLUS_ONE_BINARY(OP) \ - void VisitExpr_(const OP* op) final { \ - num_symbols_++; \ - VisitExpr(op->a); \ - VisitExpr(op->b); \ - } - -/*! - * \brief Calculate the expresion complexity based on number of symbols it contains. - */ -class ExprComplexity : public ExprVisitor { - public: - size_t Eval(const PrimExpr& expr) { - VisitExpr(expr); - return num_symbols_; - } - - PLUS_ONE_BINARY(AddNode) - PLUS_ONE_BINARY(SubNode) - PLUS_ONE_BINARY(MulNode) - PLUS_ONE_BINARY(DivNode) - PLUS_ONE_BINARY(ModNode) - PLUS_ONE_BINARY(FloorDivNode) - PLUS_ONE_BINARY(FloorModNode) - PLUS_ONE_BINARY(MinNode) - PLUS_ONE_BINARY(MaxNode) - PLUS_ONE_BINARY(EQNode) - PLUS_ONE_BINARY(NENode) - PLUS_ONE_BINARY(LTNode) - PLUS_ONE_BINARY(LENode) - PLUS_ONE_BINARY(GTNode) - PLUS_ONE_BINARY(GENode) - PLUS_ONE_BINARY(AndNode) - PLUS_ONE_BINARY(OrNode) - PLUS_ONE(VarNode) - PLUS_ONE(FloatImmNode) - PLUS_ONE(IntImmNode) - void VisitExpr_(const NotNode* op) final { - num_symbols_++; - VisitExpr(op->a); - } - - private: - size_t num_symbols_{0}; -}; - struct ExprLess { bool operator()(const PrimExpr& l, const PrimExpr& r) const { - return ExprComplexity().Eval(l) < ExprComplexity().Eval(r); + return CalculateExprComplexity(l) < CalculateExprComplexity(r); } }; diff --git a/src/tir/analysis/expr_complexity.cc b/src/tir/analysis/expr_complexity.cc new file mode 100644 index 000000000000..e39316718e37 --- /dev/null +++ b/src/tir/analysis/expr_complexity.cc @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/expr_complexity.cc + * \brief Calcute expr complexity. + */ +#include +#include + +namespace tvm { +namespace tir { + +#define PLUS_ONE(OP) \ + void VisitExpr_(const OP* op) final { num_symbols_++; } + +#define PLUS_ONE_BINARY(OP) \ + void VisitExpr_(const OP* op) final { \ + num_symbols_++; \ + VisitExpr(op->a); \ + VisitExpr(op->b); \ + } + +/*! + * \brief Calculate the expresion complexity based on number of symbols it contains. + */ +class ExprComplexity : public ExprVisitor { + public: + size_t Eval(const PrimExpr& expr) { + VisitExpr(expr); + return num_symbols_; + } + + PLUS_ONE_BINARY(AddNode) + PLUS_ONE_BINARY(SubNode) + PLUS_ONE_BINARY(MulNode) + PLUS_ONE_BINARY(DivNode) + PLUS_ONE_BINARY(ModNode) + PLUS_ONE_BINARY(FloorDivNode) + PLUS_ONE_BINARY(FloorModNode) + PLUS_ONE_BINARY(MinNode) + PLUS_ONE_BINARY(MaxNode) + PLUS_ONE_BINARY(EQNode) + PLUS_ONE_BINARY(NENode) + PLUS_ONE_BINARY(LTNode) + PLUS_ONE_BINARY(LENode) + PLUS_ONE_BINARY(GTNode) + PLUS_ONE_BINARY(GENode) + PLUS_ONE_BINARY(AndNode) + PLUS_ONE_BINARY(OrNode) + PLUS_ONE(VarNode) + PLUS_ONE(FloatImmNode) + PLUS_ONE(IntImmNode) + void VisitExpr_(const NotNode* op) final { + num_symbols_++; + VisitExpr(op->a); + } + + private: + size_t num_symbols_{0}; +}; + +size_t CalculateExprComplexity(const PrimExpr& expr) { return ExprComplexity().Eval(expr); } + +} // namespace tir +} // namespace tvm