From b3b75eeca92046f9197d78b66454d38e5e71938f Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 7 Apr 2023 10:47:52 -0500 Subject: [PATCH 1/7] [Arith] Implement statistics counters for RewriteSimplifier Previously, so long as `RewriteSimplifier` produces the same output, unit tests of its behavior would pass. This could have severe performance regressions, such as the one resolved in https://github.com/apache/tvm/pull/14528, which caused the runtime of two test to increase from ~1.5 seconds to ~10 minutes each. This commit implements statistics counts in RewriteSimplifier, which are exposed through both the C++ and Python APIs, and uses these to guard against the known performance regression from https://github.com/apache/tvm/pull/14528. --- include/tvm/arith/analyzer.h | 8 +++ python/tvm/arith/analyzer.py | 9 ++++ src/arith/analyzer.cc | 7 +++ src/arith/rewrite_simplify.cc | 35 ++++++++++++ src/arith/rewrite_simplify.h | 54 +++++++++++++++++++ src/tir/analysis/control_flow_graph.cc | 7 ++- src/tir/analysis/control_flow_graph.h | 6 ++- src/tir/transforms/remove_no_op.cc | 18 +++++-- .../test_tir_transform_remove_no_op.py | 16 ++++++ 9 files changed, 154 insertions(+), 6 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 885c23f49186..8ca56a2eac48 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -330,6 +330,14 @@ class RewriteSimplifier { /*! \brief Return the currently enabled extensions */ TVM_DLL Extension GetEnabledExtensions() const; + /*! \brief Return the statistics counters */ + TVM_DLL ObjectRef GetStatsCounters() const; + + /*! \brief Reset the statistics counters */ + TVM_DLL void ResetStatsCounters(); + + TVM_DLL void SetMaximumRewriteSteps(int maximum); + private: friend class Analyzer; friend class ConstraintContext; diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 28adbe9d815f..53c28df25d31 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -87,6 +87,8 @@ def __init__(self): self._modular_set = _mod("modular_set") self._simplify = _mod("Simplify") self._rewrite_simplify = _mod("rewrite_simplify") + self._get_rewrite_simplify_stats = _mod("get_rewrite_simplify_stats") + self._reset_rewrite_simplify_stats = _mod("reset_rewrite_simplify_stats") self._canonical_simplify = _mod("canonical_simplify") self._int_set = _mod("int_set") self._enter_constraint_context = _mod("enter_constraint_context") @@ -157,6 +159,13 @@ def rewrite_simplify(self, expr): """ return self._rewrite_simplify(expr) + @property + def rewrite_simplify_stats(self): + return self._get_rewrite_simplify_stats() + + def reset_rewrite_simplify_stats(self): + self._reset_rewrite_simplify_stats() + def canonical_simplify(self, expr): """Simplify expression via canonicalization. diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 4714cf1df59f..48aca7c69947 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -175,6 +175,13 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu } else if (name == "rewrite_simplify") { return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { *ret = self->rewrite_simplify(args[0]); }); + } else if (name == "get_rewrite_simplify_stats") { + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + *ret = self->rewrite_simplify.GetStatsCounters(); + }); + } else if (name == "reset_rewrite_simplify_stats") { + return PackedFunc( + [self](TVMArgs args, TVMRetValue* ret) { self->rewrite_simplify.ResetStatsCounters(); }); } else if (name == "canonical_simplify") { return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); }); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index e44ef31da1dd..c6eab073de6c 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -58,25 +58,33 @@ using namespace tir; // macro for doing simple rewrite #define TVM_TRY_REWRITE(SrcExpr, ResExpr) \ + RecordAttemptedRewrite(); \ if ((SrcExpr).Match(ret)) { \ + RecordRewrite(); \ return (ResExpr).Eval(); \ } // macro for rewrite + recursively rewrite ResExpr #define TVM_TRY_RECURSIVE_REWRITE(SrcExpr, ResExpr) \ + RecordAttemptedRewrite(); \ if ((SrcExpr).Match(ret)) { \ + RecordRewrite(); \ return RecursiveRewrite((ResExpr).Eval()); \ } // macro rewrite only if CondExor is true after match. #define TVM_TRY_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + RecordAttemptedRewrite(); \ if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \ + RecordRewrite(); \ return (ResExpr).Eval(); \ } // macro rewrite + recursive_rewrite only if CondExor is true after match. #define TVM_TRY_RECURSIVE_REWRITE_IF(SrcExpr, ResExpr, CondExpr) \ + RecordAttemptedRewrite(); \ if ((SrcExpr).Match(ret, [&]() { return (CondExpr); })) { \ + RecordRewrite(); \ return RecursiveRewrite((ResExpr).Eval()); \ } @@ -203,6 +211,11 @@ CompareResult RewriteSimplifier::Impl::TryCompare(const PrimExpr& x, int64_t val return CompareResult::kUnknown; } +PrimExpr RewriteSimplifier::Impl::VisitExpr(const PrimExpr& e) { + stats_.nodes_visited++; + return IRMutatorWithAnalyzer::VisitExpr(e); +} + void RewriteSimplifier::Impl::Update(const Var& var, const PrimExpr& info, bool can_override) { if (!can_override) { auto it = var_map_.find(var); @@ -342,6 +355,7 @@ std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& c literal_constraints_.push_back(Not(negation)); } } + stats_.constraints_entered++; size_t new_literal_size = literal_constraints_.size(); auto frecover = [old_literal_size, new_literal_size, this]() { ICHECK_EQ(literal_constraints_.size(), new_literal_size); @@ -2133,9 +2147,30 @@ RewriteSimplifier::Extension RewriteSimplifier::GetEnabledExtensions() const { return impl_->GetEnabledExtensions(); } +ObjectRef RewriteSimplifier::GetStatsCounters() const { return impl_->GetStatsCounters(); } + +void RewriteSimplifier::ResetStatsCounters() { impl_->ResetStatsCounters(); } + +void RewriteSimplifier::SetMaximumRewriteSteps(int maximum) { + impl_->SetMaximumRewriteSteps(maximum); +} + RewriteSimplifier::RewriteSimplifier(Analyzer* parent) : impl_(new Impl(parent)) {} RewriteSimplifier::~RewriteSimplifier() { delete impl_; } +TVM_REGISTER_NODE_TYPE(RewriteSimplifierStatsNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* ptr = node.as(); + p->stream << "RewriteSimplifierStats(nodes_visited = " << ptr->nodes_visited + << ", constraints_entered = " << ptr->constraints_entered + << ", rewrites_attempted = " << ptr->rewrites_attempted + << ", rewrites_performed = " << ptr->rewrites_performed + << ", max_recursive_depth = " << ptr->max_recursive_depth + << ", num_recursive_rewrites = " << ptr->num_recursive_rewrites << ")"; + }); + } // namespace arith } // namespace tvm diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index b8e7fcdd9433..2c35ddda1607 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -39,6 +39,36 @@ namespace arith { using namespace tir; +struct RewriteSimplifierStatsNode : Object { + int nodes_visited{0}; + int constraints_entered{0}; + int rewrites_attempted{0}; + int rewrites_performed{0}; + int max_recursive_depth{0}; + int num_recursive_rewrites{0}; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("nodes_visited", &nodes_visited); + v->Visit("constraints_entered", &constraints_entered); + v->Visit("rewrites_attempted", &rewrites_attempted); + v->Visit("rewrites_performed", &rewrites_performed); + v->Visit("max_recursive_depth", &max_recursive_depth); + v->Visit("num_recursive_rewrites", &num_recursive_rewrites); + } + + static constexpr const char* _type_key = "arith.RewriteSimplifierStats"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteSimplifierStatsNode, Object); +}; + +struct RewriteSimplifierStats : ObjectRef { + RewriteSimplifierStats(RewriteSimplifierStatsNode data) { + data_ = make_object(data); + } + + TVM_DEFINE_OBJECT_REF_METHODS(RewriteSimplifierStats, ObjectRef, RewriteSimplifierStatsNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(RewriteSimplifierStatsNode); +}; + /*! * \brief Rewrite-based simplifier. * @@ -50,6 +80,8 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { explicit Impl(Analyzer* parent) : IRMutatorWithAnalyzer(parent) {} + PrimExpr VisitExpr(const PrimExpr& e) override; + void Update(const Var& var, const PrimExpr& info, bool override_info); PrimExpr VisitExpr_(const AddNode* op) override; PrimExpr VisitExpr_(const SubNode* op) override; @@ -87,7 +119,27 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { /*! \brief Return the currently enabled extensions */ Extension GetEnabledExtensions() const; + RewriteSimplifierStats GetStatsCounters() const { return RewriteSimplifierStats(stats_); } + + void ResetStatsCounters() { stats_ = {}; } + + void SetMaximumRewriteSteps(int maximum) { maximum_rewrite_steps_ = maximum; }; + protected: + int maximum_rewrite_steps_{0}; + RewriteSimplifierStatsNode stats_; + + void RecordAttemptedRewrite() { stats_.rewrites_attempted++; } + void RecordRewrite() { + stats_.rewrites_performed++; + + ICHECK(maximum_rewrite_steps_ <= 0 || stats_.rewrites_performed <= maximum_rewrite_steps_) + << "RewriteSimplifier exceeded maximum number of rewrites allowed (" + << maximum_rewrite_steps_ << ")"; + } + + bool is_currently_visiting_{false}; + // counter to record recursive rewrite depth. int recur_depth_{0}; // internal variable map @@ -178,8 +230,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // we limit maximum depth of recursive rewrite allowed to // avoid infinite loop PrimExpr RecursiveRewrite(const PrimExpr& x) { + stats_.num_recursive_rewrites++; if (recur_depth_ >= kMaxRecurDepth) return x; ++recur_depth_; + stats_.max_recursive_depth = std::max(recur_depth_, stats_.max_recursive_depth); PrimExpr res = this->VisitExpr(x); --recur_depth_; return res; diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 86ce4e21351f..22d0e8e4ad08 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -820,8 +820,9 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph return buffer_touch; } -ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, size_t max_revisits) - : max_revisits_(max_revisits) { +ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int max_simplification_steps, + size_t max_revisits) + : max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) { ControlFlowGraphBuilder::Build(this, stmt); ForwardPropagateKnownValues(); BackwardPropagateUnusedValues(); @@ -1377,6 +1378,7 @@ void ControlFlowGraph::ForwardPropagateKnownValues(std::optional flow_fr std::unordered_map visit_count_lookup; Analyzer analyzer; + analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_); analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( arith::RewriteSimplifier::kTransitivelyProveInequalities | arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | @@ -1510,6 +1512,7 @@ void ControlFlowGraph::BackwardPropagateUnusedValues(std::optional flow_ std::unordered_map visit_count_lookup; Analyzer analyzer; + analyzer.rewrite_simplify.SetMaximumRewriteSteps(max_simplification_steps_); analyzer.rewrite_simplify.SetEnabledExtensions(arith::RewriteSimplifier::Extension( arith::RewriteSimplifier::kTransitivelyProveInequalities | arith::RewriteSimplifier::kConvertBooleanToAndOfOrs | diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h index f2e46b2478a3..12a0301f864f 100644 --- a/src/tir/analysis/control_flow_graph.h +++ b/src/tir/analysis/control_flow_graph.h @@ -399,7 +399,8 @@ class ControlFlowGraph { public: /* \brief Extract the touch pattern from a TIR statement */ - explicit ControlFlowGraph(const Stmt& stmt, size_t max_revisits = 5); + explicit ControlFlowGraph(const Stmt& stmt, int max_simplification_steps = 0, + size_t max_revisits = 5); /* \brief Check if a write is overwritten without impacting final results * @@ -655,6 +656,9 @@ class ControlFlowGraph { /*! \brief The maximum number of revisits while flowing constraints */ size_t max_revisits_; + + /*! \brief The maximum number of revisits while flowing constraints */ + int max_simplification_steps_; }; } // namespace tir diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index d35cf8b8d602..996e93285870 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -42,6 +42,7 @@ namespace tir { struct RemoveNoOpConfigNode : public tvm::AttrsNode { bool use_dataflow_analysis; + int max_simplification_steps; TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") { TVM_ATTR_FIELD(use_dataflow_analysis) @@ -49,6 +50,12 @@ struct RemoveNoOpConfigNode : public tvm::AttrsNode { "If true, known buffer values are propagated and used " "to statically prove statements as no-ops.") .set_default(false); + TVM_ATTR_FIELD(max_simplification_steps) + .describe( + "If non-zero, RewriteSimplifier will throw an error " + "after the number of steps specified. " + "For use in debug and testing purposes.") + .set_default(0); } }; @@ -316,14 +323,19 @@ Pass RemoveNoOp() { RemoveNoOpConfig config = ctx->GetConfig("tir.RemoveNoOp") .value_or(AttrsWithDefaultValues()); + if (config->use_dataflow_analysis) { - touch_pattern.emplace(f->body); + touch_pattern.emplace(f->body, config->max_simplification_steps); } arith::Analyzer analyzer; + analyzer.rewrite_simplify.SetMaximumRewriteSteps(config->max_simplification_steps); - auto* n = f.CopyOnWrite(); - n->body = NoOpRemover::Apply(std::move(n->body), &analyzer, std::move(touch_pattern), nullptr); + { + auto* write_ptr = f.CopyOnWrite(); + write_ptr->body = NoOpRemover::Apply(std::move(write_ptr->body), &analyzer, + std::move(touch_pattern), nullptr); + } return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.RemoveNoOp", {}); diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py index 15c5a577f9f5..189423fcbf79 100644 --- a/tests/python/unittest/test_tir_transform_remove_no_op.py +++ b/tests/python/unittest/test_tir_transform_remove_no_op.py @@ -86,12 +86,14 @@ def main(A: T.Buffer((16), "int32"), B: T.Buffer((16), "int32")) -> None: class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): use_dataflow_analysis = False + max_simplification_steps = 0 def transform(self): def inner(mod): config = { "tir.RemoveNoOp": { "use_dataflow_analysis": self.use_dataflow_analysis, + "max_simplification_steps": self.max_simplification_steps, } } with tvm.transform.PassContext(config=config): @@ -319,9 +321,16 @@ class TestRemoveOverwrittenPredicatedLoopWithIdenticalCondition(BaseBeforeAfter) Similar to TestKeepPartiallyOverwrittenLoop, except the first loop has the same predicate as the second, and can therefore be removed. + + In the past, this test has had performance regressions in which + the runtime increased from a few seconds to nearly ten minutes. + The "max_simplification_steps" parameter is set at twice the + current number of steps required, in order to prevent similar + performance regression. """ use_dataflow_analysis = True + max_simplification_steps = 200000 def before(A: T.Buffer(16, "int32")): for i in T.serial(16): @@ -347,9 +356,16 @@ class TestRemoveOverwrittenPredicatedLoopWithProvableCondition(BaseBeforeAfter): loop's predicate. So long as the regions written in the first loop are a subset of those written in the second loop, they can be removed. + + In the past, this test has had performance regressions in which + the runtime increased from a few seconds to nearly ten minutes. + The "max_simplification_steps" parameter is set at twice the + current number of steps required, in order to prevent similar + performance regression. """ use_dataflow_analysis = True + max_simplification_steps = 200000 def before(A: T.Buffer(16, "int32")): for i in T.serial(16): From 82d1068f87436ba2b8f6371b1c6218db35cd5ed9 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 8 Apr 2023 12:49:15 -0500 Subject: [PATCH 2/7] lint fixes --- src/arith/rewrite_simplify.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 2c35ddda1607..5010a503e7ef 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -61,7 +62,7 @@ struct RewriteSimplifierStatsNode : Object { }; struct RewriteSimplifierStats : ObjectRef { - RewriteSimplifierStats(RewriteSimplifierStatsNode data) { + explicit RewriteSimplifierStats(RewriteSimplifierStatsNode data) { data_ = make_object(data); } @@ -123,7 +124,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { void ResetStatsCounters() { stats_ = {}; } - void SetMaximumRewriteSteps(int maximum) { maximum_rewrite_steps_ = maximum; }; + void SetMaximumRewriteSteps(int maximum) { maximum_rewrite_steps_ = maximum; } protected: int maximum_rewrite_steps_{0}; From f5c3c8da6ebcc21bb841b20703e5b586c1c6f84b Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 10 Apr 2023 08:40:30 -0500 Subject: [PATCH 3/7] Updates based on review comments --- include/tvm/arith/analyzer.h | 15 ++++++++++++++- src/arith/rewrite_simplify.cc | 2 +- src/arith/rewrite_simplify.h | 21 +++++++++++++-------- src/tir/transforms/remove_no_op.cc | 2 +- 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 8ca56a2eac48..a1beab9a7a23 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -336,7 +336,20 @@ class RewriteSimplifier { /*! \brief Reset the statistics counters */ TVM_DLL void ResetStatsCounters(); - TVM_DLL void SetMaximumRewriteSteps(int maximum); + /*! \brief Set the maximum allowed number of rewrite steps + * + * By default, the simplifier may perform as many steps as are + * required. If a positive limit is set, then the simplifier will + * throw an exception when exceeding that number of rewrite steps. + * This allows tests to guard against performance regressions. + * + * Note: To maintain accurate usage counters, `Analyzer` instances + * should be re-used wherever possible. For example, TIR + * transformations should declare a single `Analyzer` that is used + * throughout the pass, and utility functions should receive an + * `Analyzer*` from their calling scope. + */ + TVM_DLL void SetMaximumRewriteSteps(int64_t maximum); private: friend class Analyzer; diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index c6eab073de6c..64a46a9daa7a 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -2151,7 +2151,7 @@ ObjectRef RewriteSimplifier::GetStatsCounters() const { return impl_->GetStatsCo void RewriteSimplifier::ResetStatsCounters() { impl_->ResetStatsCounters(); } -void RewriteSimplifier::SetMaximumRewriteSteps(int maximum) { +void RewriteSimplifier::SetMaximumRewriteSteps(int64_t maximum) { impl_->SetMaximumRewriteSteps(maximum); } diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 5010a503e7ef..7f06e323a11c 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -40,13 +40,18 @@ namespace arith { using namespace tir; +/* Record of + * + * These are intended for debug and testing purposes, to ensure that + * PrimExpr simplifications and TIR passes do not require an excessive + */ struct RewriteSimplifierStatsNode : Object { - int nodes_visited{0}; - int constraints_entered{0}; - int rewrites_attempted{0}; - int rewrites_performed{0}; - int max_recursive_depth{0}; - int num_recursive_rewrites{0}; + int64_t nodes_visited{0}; + int64_t constraints_entered{0}; + int64_t rewrites_attempted{0}; + int64_t rewrites_performed{0}; + int64_t max_recursive_depth{0}; + int64_t num_recursive_rewrites{0}; void VisitAttrs(AttrVisitor* v) { v->Visit("nodes_visited", &nodes_visited); @@ -124,10 +129,10 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { void ResetStatsCounters() { stats_ = {}; } - void SetMaximumRewriteSteps(int maximum) { maximum_rewrite_steps_ = maximum; } + void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ = maximum; } protected: - int maximum_rewrite_steps_{0}; + int64_t maximum_rewrite_steps_{0}; RewriteSimplifierStatsNode stats_; void RecordAttemptedRewrite() { stats_.rewrites_attempted++; } diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index 996e93285870..e58a49108910 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -42,7 +42,7 @@ namespace tir { struct RemoveNoOpConfigNode : public tvm::AttrsNode { bool use_dataflow_analysis; - int max_simplification_steps; + int64_t max_simplification_steps; TVM_DECLARE_ATTRS(RemoveNoOpConfigNode, "tir.transform.RemoveNoOpConfig") { TVM_ATTR_FIELD(use_dataflow_analysis) From 60fbf7754898e43e1e28142b0f9d68993382e8df Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 10 Apr 2023 09:24:05 -0500 Subject: [PATCH 4/7] Consistent int64_t with kMaxRecurDepth --- src/arith/rewrite_simplify.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 8c2caef418ff..aedea1205d0d 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -147,7 +147,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { bool is_currently_visiting_{false}; // counter to record recursive rewrite depth. - int recur_depth_{0}; + int64_t recur_depth_{0}; // internal variable map std::unordered_map var_map_; @@ -161,7 +161,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { bool recursively_visiting_boolean_{false}; // maximum number of recursion allowed during a single pass. - static const constexpr int kMaxRecurDepth = 5; + static const constexpr int64_t kMaxRecurDepth = 5; /*! * \brief try to compare x against val. * \param x The expression to be evaluated. From 02a02495cf0be7a378cf34f3a7bf0eaa98ee6c58 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 17 Apr 2023 08:34:01 -0500 Subject: [PATCH 5/7] Removed unused is_currently_visiting_ --- src/arith/rewrite_simplify.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index aedea1205d0d..f435732d21ca 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -144,8 +144,6 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { << maximum_rewrite_steps_ << ")"; } - bool is_currently_visiting_{false}; - // counter to record recursive rewrite depth. int64_t recur_depth_{0}; // internal variable map From 405eab39cd5d89a1b2b0ebec2fa23e205481420d Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 17 Apr 2023 08:34:16 -0500 Subject: [PATCH 6/7] Add missing \brief for RewriteSimplifierStatsNode --- src/arith/rewrite_simplify.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index f435732d21ca..7c4b0eab2224 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -40,7 +40,7 @@ namespace arith { using namespace tir; -/* Record of +/* \brief Usage counters for RewriteSimplifier * * These are intended for debug and testing purposes, to ensure that * PrimExpr simplifications and TIR passes do not require an excessive From 2a8f181025745e4d0d699c9fd1b3ab95f6c5ea78 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 17 Apr 2023 08:34:40 -0500 Subject: [PATCH 7/7] Use int64_t in ControlFlowGraph for max simplification steps --- src/tir/analysis/control_flow_graph.cc | 2 +- src/tir/analysis/control_flow_graph.h | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/analysis/control_flow_graph.cc b/src/tir/analysis/control_flow_graph.cc index 22d0e8e4ad08..59f53ff64bc8 100644 --- a/src/tir/analysis/control_flow_graph.cc +++ b/src/tir/analysis/control_flow_graph.cc @@ -820,7 +820,7 @@ BufferTouch ControlFlowGraph::ControlFlowBlock::MakeBufferTouch(ControlFlowGraph return buffer_touch; } -ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int max_simplification_steps, +ControlFlowGraph::ControlFlowGraph(const tir::Stmt& stmt, int64_t max_simplification_steps, size_t max_revisits) : max_revisits_(max_revisits), max_simplification_steps_(max_simplification_steps) { ControlFlowGraphBuilder::Build(this, stmt); diff --git a/src/tir/analysis/control_flow_graph.h b/src/tir/analysis/control_flow_graph.h index 12a0301f864f..35934351dce0 100644 --- a/src/tir/analysis/control_flow_graph.h +++ b/src/tir/analysis/control_flow_graph.h @@ -399,7 +399,7 @@ class ControlFlowGraph { public: /* \brief Extract the touch pattern from a TIR statement */ - explicit ControlFlowGraph(const Stmt& stmt, int max_simplification_steps = 0, + explicit ControlFlowGraph(const Stmt& stmt, int64_t max_simplification_steps = 0, size_t max_revisits = 5); /* \brief Check if a write is overwritten without impacting final results @@ -658,7 +658,7 @@ class ControlFlowGraph { size_t max_revisits_; /*! \brief The maximum number of revisits while flowing constraints */ - int max_simplification_steps_; + int64_t max_simplification_steps_; }; } // namespace tir