From cc88b491cb4feb053e72d8476382612d6d2e8662 Mon Sep 17 00:00:00 2001 From: tqchen Date: Wed, 18 Jun 2025 15:38:58 -0400 Subject: [PATCH] [REFACTOR] Phase out LegacyReprPrinter and improve CommonSubExprElim This PR phases out LegacyReprPrinter. Previously common subexpr elim relied on sorting on legacy repr for determinism, which is hacky. This PR introduces an ordered_map impl in support to ensure determinism and migrates the CSE pass to use that instead. --- include/tvm/node/repr_printer.h | 32 - src/ir/analysis.cc | 2 +- src/node/repr_printer.cc | 35 - src/node/script_printer.cc | 5 +- .../analysis/computable_at_compile_time.cc | 2 +- src/relax/analysis/udchain.cc | 4 +- src/relax/ir/binding_rewrite.cc | 3 +- src/relax/transform/inline_functions.cc | 2 +- src/relax/transform/run_codegen.cc | 2 +- src/script/printer/legacy_repr.cc | 894 ------------------ src/script/printer/utils.h | 14 +- src/support/ordered_map.h | 145 +++ src/support/ordered_set.h | 57 +- src/tir/transforms/common_subexpr_elim.cc | 42 +- src/tir/transforms/common_subexpr_elim.h | 3 +- .../transforms/common_subexpr_elim_tools.cc | 28 +- .../transforms/common_subexpr_elim_tools.h | 8 +- .../test_tir_transform_common_subexpr_elim.py | 58 +- ...est_tir_transform_inject_ptx_async_copy.py | 20 +- .../test_tir_transform_lower_tvm_builtin.py | 4 +- .../tvmscript/test_tvmscript_roundtrip.py | 10 +- 21 files changed, 249 insertions(+), 1121 deletions(-) delete mode 100644 src/script/printer/legacy_repr.cc create mode 100644 src/support/ordered_map.h diff --git a/include/tvm/node/repr_printer.h b/include/tvm/node/repr_printer.h index 30bfe8e95193..e3baf397f25f 100644 --- a/include/tvm/node/repr_printer.h +++ b/include/tvm/node/repr_printer.h @@ -52,32 +52,6 @@ class ReprPrinter { TVM_DLL static FType& vtable(); }; -/*! \brief Legacy behavior of ReprPrinter. */ -class ReprLegacyPrinter { - public: - /*! \brief The indentation level.
*/ - int indent{0}; - - explicit ReprLegacyPrinter(std::ostream& stream) // NOLINT(*) - : stream(stream) {} - - /*! \brief The node to be printed. */ - TVM_DLL void Print(const ObjectRef& node); - /*! \brief Print indent to the stream */ - TVM_DLL void PrintIndent(); - /*! \brief Could the LegacyPrinter dispatch the node */ - TVM_DLL static bool CanDispatch(const ObjectRef& node); - /*! \brief Return the ostream it maintains */ - TVM_DLL std::ostream& Stream() const; - // Allow registration to be printer. - using FType = NodeFunctor; - TVM_DLL static FType& vtable(); - - private: - /*! \brief The output stream */ - std::ostream& stream; -}; - /*! * \brief Dump the node to stderr, used for debug purposes. * \param node The input node @@ -113,12 +87,6 @@ inline std::ostream& operator<<(std::ostream& os, const Variant& n) { // return os; } -inline std::string AsLegacyRepr(const ObjectRef& n) { - std::ostringstream os; - ReprLegacyPrinter(os).Print(n); - return os.str(); -} } // namespace ffi -using ffi::AsLegacyRepr; } // namespace tvm #endif // TVM_NODE_REPR_PRINTER_H_ diff --git a/src/ir/analysis.cc b/src/ir/analysis.cc index 3a54085c2290..26a348bceee1 100644 --- a/src/ir/analysis.cc +++ b/src/ir/analysis.cc @@ -31,7 +31,7 @@ namespace ir { Map> CollectCallMap(const IRModule& mod) { struct CalleeCollectorImpl : CalleeCollector { void Mark(GlobalVar gvar) override { gvars.push_back(gvar); } - support::OrderedSet gvars; + support::OrderedSet gvars; }; Map> call_map; diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index aa999655c03d..69cb05c12106 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -97,38 +97,6 @@ ReprPrinter::FType& ReprPrinter::vtable() { return inst; } -void ReprLegacyPrinter::Print(const ObjectRef& node) { - static const FType& f = vtable(); - if (!node.defined()) { - stream << "(nullptr)"; - } else if (f.can_dispatch(node)) { - f(node, this); - } else { - try { - stream << node; // Use ReprPrinter - } catch 
(const tvm::Error& e) { - LOG(WARNING) << "ReprPrinter fails"; - stream << node->GetTypeKey() << '(' << node.get() << ')'; - } - } -} - -bool ReprLegacyPrinter::CanDispatch(const ObjectRef& node) { - static const FType& f = vtable(); - return !node.defined() || f.can_dispatch(node); -} - -void ReprLegacyPrinter::PrintIndent() { - for (int i = 0; i < indent; ++i) { - stream << ' '; - } -} - -ReprLegacyPrinter::FType& ReprLegacyPrinter::vtable() { - static FType inst; - return inst; -} - void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } @@ -138,7 +106,4 @@ TVM_FFI_REGISTER_GLOBAL("node.AsRepr").set_body_typed([](ffi::Any obj) { os << obj; return os.str(); }); - -TVM_FFI_REGISTER_GLOBAL("node.AsLegacyRepr").set_body_typed(ffi::AsLegacyRepr); - } // namespace tvm diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index ee7880f4485a..c81543579655 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -32,7 +32,10 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional& cfg) { if (!TVMScriptPrinter::vtable().can_dispatch(node)) { - return AsLegacyRepr(node); + std::ostringstream os; + ReprPrinter printer(os); + printer.Print(node); + return os.str(); } return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig())); } diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index ba163b51d6c9..5825895db7d6 100644 --- a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -83,7 +83,7 @@ class CompileTimeCollector : ExprVisitor { } } - support::OrderedSet known_relax_vars_; + support::OrderedSet known_relax_vars_; std::unordered_set known_tir_vars_; }; } // namespace diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 
f62254b6959d..2f04d8659405 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -56,8 +56,8 @@ class UDChain : relax::ExprVisitor { private: Map bound_values; std::unordered_set forward_declarations; - std::unordered_map> usage_map; - support::OrderedSet outputs; + std::unordered_map> usage_map; + support::OrderedSet outputs; Optional cur_user_; diff --git a/src/relax/ir/binding_rewrite.cc b/src/relax/ir/binding_rewrite.cc index f35b443b5b39..11a0fd29a92f 100644 --- a/src/relax/ir/binding_rewrite.cc +++ b/src/relax/ir/binding_rewrite.cc @@ -321,7 +321,8 @@ Expr RemoveAllUnused(Expr expr) { auto var_usage = CollectVarUsage(expr); // For the purpose of - support::OrderedSet externally_exposed(var_usage.outputs.begin(), var_usage.outputs.end()); + support::OrderedSet externally_exposed( + var_usage.outputs.begin(), var_usage.outputs.end()); for (const auto& [var, expr] : var_usage.bound_values) { if (ContainsImpureCall(expr)) { externally_exposed.insert(var); diff --git a/src/relax/transform/inline_functions.cc b/src/relax/transform/inline_functions.cc index 26b106373ff0..e295226e9e72 100644 --- a/src/relax/transform/inline_functions.cc +++ b/src/relax/transform/inline_functions.cc @@ -138,7 +138,7 @@ class FunctionInliner : public ExprMutator { } const Map, Function>& replacements_; - support::OrderedSet inline_stack_; + std::unordered_set inline_stack_; }; } // namespace diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index d29bdaacb9b0..33d3f485a5e0 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -44,7 +44,7 @@ class CodeGenRunner : ExprMutator { Array entry_function_names) { IRModule mod = builder_->GetContextIRModule(); - support::OrderedSet entry_functions; + support::OrderedSet entry_functions; // Any user-provided functions are treated as entry functions. 
for (const auto& name : entry_function_names) { entry_functions.insert(mod->GetGlobalVar(name)); diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc deleted file mode 100644 index 57dd691b8897..000000000000 --- a/src/script/printer/legacy_repr.cc +++ /dev/null @@ -1,894 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -#include -#include -#include -#include -#include - -#include - -#include "../../support/str_escape.h" - -namespace tvm { - -#define TVM_LEGACY_REPR_PRINTER_DEF_OP(Type) \ - ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, Type value) { \ - p.Stream() << value; \ - return p; \ - } - -TVM_LEGACY_REPR_PRINTER_DEF_OP(int); -TVM_LEGACY_REPR_PRINTER_DEF_OP(int64_t); -TVM_LEGACY_REPR_PRINTER_DEF_OP(float); -TVM_LEGACY_REPR_PRINTER_DEF_OP(double); -TVM_LEGACY_REPR_PRINTER_DEF_OP(char); -TVM_LEGACY_REPR_PRINTER_DEF_OP(const char*); -TVM_LEGACY_REPR_PRINTER_DEF_OP(const std::string&); -TVM_LEGACY_REPR_PRINTER_DEF_OP(runtime::DataType); -TVM_LEGACY_REPR_PRINTER_DEF_OP(const void*); -TVM_LEGACY_REPR_PRINTER_DEF_OP(const String&); - -std::ostream& ReprLegacyPrinter::Stream() const { return stream; } - -ReprLegacyPrinter& operator<<(ReprLegacyPrinter& p, const ObjectRef& value) { - p.Stream() << AsLegacyRepr(value); - return p; -} - -ReprLegacyPrinter& operator<<(ReprLegacyPrinter& out, tir::ForKind type) { // NOLINT(*) - using tvm::tir::ForKind; - switch (type) { - case ForKind::kSerial: - out << "for"; - break; - case ForKind::kParallel: - out << "parallel"; - break; - case ForKind::kUnrolled: - out << "unrolled"; - break; - case ForKind::kVectorized: - out << "vectorized"; - break; - case ForKind::kThreadBinding: - out << "launch_thread"; - break; - } - return out; -} - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '['; - for (size_t i = 0; i < op->size(); ++i) { - if (i != 0) { - (*p) << ", "; - } - p->Print(op->at(i).cast()); - } - (*p) << ']'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '{'; - for (auto it = op->begin(); it != op->end(); ++it) { - if (it != op->begin()) { - (*p) << ", "; - } - if (it->first.as()) { - (*p) 
<< '\"' << Downcast(it->first) << "\": "; - } else { - p->Print(it->first.cast()); - (*p) << ": "; - } - p->Print(it->second.cast()); - } - (*p) << '}'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '['; - for (size_t i = 0; i < op->size; ++i) { - if (i != 0) { - (*p) << ", "; - } - (*p) << op->data[i]; - } - (*p) << ']'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - if (op->dtype == DataType::Int(32)) { - (*p) << op->value; - } else { - (*p) << "(" << op->dtype << ")" << op->value; - } - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - switch (op->dtype.bits()) { - case 64: - (*p) << op->value; - break; - case 32: - (*p) << op->value << 'f'; - break; - case 16: - (*p) << op->value << 'h'; - break; - default: - LOG(FATAL) << "Unknown float type bits=" << op->dtype.bits(); - } - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "range(min=" << op->min << ", ext=" << op->extent << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << node->dtype; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - if (!node->storage_scope.empty()) { - (*p) << node->storage_scope << " "; - } - p->Print(node->element_type); - (*p) << '*'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = 
static_cast(ref.get()); - (*p) << "TupleTypeNode(" << node->fields << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->dict; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << "GlobalVar(" << node->name_hint << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << "IRModule(" << node->functions << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << "FuncType(" << node->arg_types << ", " << node->ret_type << ")"; - }); - -} // namespace tvm - -namespace tvm { -namespace tir { - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "buffer(" << op->name << ", " << op << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - // omit the type - // stream << op->name << "." 
<< op->type; - (*p) << op->name_hint; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "{" << op->name_hint << "|" << op->name_hint << ">=0}"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "iter_var("; - if (op->var->name_hint.length() != 0) { - (*p) << op->var->name_hint << ", "; - } - if (op->dom.defined()) { - (*p) << op->dom; - } - if (op->thread_tag.length() != 0) { - (*p) << ", " << op->thread_tag; - } - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '\"' << support::StrEscape(op->value) << '\"'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->dtype << '('; - p->Print(op->value); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " + "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " - "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << "*"; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = 
static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << "/"; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " % "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "floordiv(" << op->a << ", " << op->b << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "floormod(" << op->a << ", " << op->b << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "min("; - p->Print(op->a); - (*p) << ", "; - p->Print(op->b); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "max("; - p->Print(op->a); - (*p) << ", "; - p->Print(op->b); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " == "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " != "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " < "; - 
p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " <= "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " > "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " >= "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " && "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '('; - p->Print(op->a); - (*p) << " || "; - p->Print(op->b); - (*p) << ')'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << '!'; - p->Print(op->a); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "select("; - p->Print(op->condition); - (*p) << ", "; - p->Print(op->true_value); - (*p) << ", "; - p->Print(op->false_value); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "ramp("; - p->Print(op->base); - (*p) << ", "; 
- p->Print(op->stride); - (*p) << ", " << op->lanes << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "x" << op->lanes << "("; - p->Print(op->value); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "(let " << op->var << " = "; - p->Print(op->value); - (*p) << " in "; - p->Print(op->body); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - if (auto* ptr_op = op->op.as()) { - (*p) << ptr_op->name << "("; - } else { - auto* ptr_gvar = op->op.as(); - ICHECK(ptr_gvar != nullptr); - (*p) << "@" << ptr_gvar->name_hint << "("; - } - for (size_t i = 0; i < op->args.size(); ++i) { - p->Print(op->args[i]); - if (i < op->args.size() - 1) { - (*p) << ", "; - } - } - (*p) << ")"; - }); - -template -void PrintList(const Array& exprs, ReprLegacyPrinter* p) { - for (size_t i = 0; i < exprs.size(); ++i) { - p->Print(exprs[i]); - if (i < exprs.size() - 1) { - (*p) << ", "; - } - } -} - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "shuffle("; - PrintList(op->vectors, p); - (*p) << ", "; - PrintList(op->indices, p); - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << "comm_reducer(result=" << op->result << ", lhs=" << op->lhs << ", rhs=" << op->rhs - << ", identity_element=" << op->identity_element << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = 
static_cast(node.get()); - (*p) << "reduce(combiner=" << op->combiner; - (*p) << ", source=" << op->source; - (*p) << ", init=" << op->init; - (*p) << ", axis=" << op->axis; - (*p) << ", where=" << op->condition; - (*p) << ", value_index=" << op->value_index; - (*p) << ")"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) { - (*p) << ", "; - } - } - (*p) << "]"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->producer->GetNameHint() << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) { - (*p) << ", "; - } - } - (*p) << "]"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprLegacyPrinter* p) { - auto* node = static_cast(ref.get()); - (*p) << "PrimFunc(" << node->params << ") "; - if (node->attrs.defined()) { - (*p) << "attrs=" << node->attrs; - } - (*p) << " {\n"; - p->indent += 2; - p->Print(node->body); - p->indent -= 2; - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "let " << op->var << " = "; - p->Print(op->value); - (*p) << '\n'; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "// attr ["; - p->Print(op->node); - (*p) << "] " << op->attr_key << " = "; - p->Print(op->value); - (*p) << '\n'; - p->Print(op->body); - }); - 
-TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "assert("; - p->Print(op->condition); - (*p) << ", "; - p->Print(op->message); - (*p) << ")\n"; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << op->kind << " (" << op->loop_var << ", "; - p->Print(op->min); - (*p) << ", "; - p->Print(op->extent); - (*p) << ") {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "while(" << op->condition << ") {\n"; - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - const auto* ptr_type = op->buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - p->PrintIndent(); - (*p) << "allocate " << op->buffer_var << "[" << op->dtype; - for (size_t i = 0; i < op->extents.size(); ++i) { - (*p) << " * "; - p->Print(op->extents[i]); - } - (*p) << "], storage_scope = " << ptr_type->storage_scope; - if (!is_one(op->condition)) { - (*p) << " if "; - p->Print(op->condition); - } - (*p) << "\n"; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "constant " << op->buffer_var << "[" << op->dtype; - for (size_t i = 0; i < op->extents.size(); ++i) { - (*p) 
<< " * "; - p->Print(op->extents[i]); - } - (*p) << "]"; - (*p) << "\n"; - p->Print(op->body); - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "decl_buffer " << op->buffer << "\n"; - (*p) << op->body; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - for (Stmt stmt : op->seq) { - p->Print(stmt); - } - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - while (true) { - (*p) << "if (" << op->condition << ") {\n"; - p->indent += 2; - p->Print(op->then_case); - p->indent -= 2; - - if (!op->else_case) { - break; - } - - if (const IfThenElseNode* nested_if = op->else_case.as()) { - p->PrintIndent(); - (*p) << "} else "; - op = nested_if; - } else { - p->PrintIndent(); - (*p) << "} else {\n"; - p->indent += 2; - p->Print(op->else_case); - p->indent -= 2; - break; - } - } - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - p->Print(op->value); - (*p) << "\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << op->buffer->name << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) (*p) << ", "; - } - (*p) << "]"; - (*p) << " = "; - p->Print(op->value); - (*p) << '\n'; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = 
static_cast(node.get()); - p->PrintIndent(); - (*p) << "buffer_realize " << op->buffer->name << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - (*p) << "["; - p->Print(op->bounds[i]->min); - (*p) << ", "; - p->Print(op->bounds[i]->extent); - (*p) << "]"; - if (i < op->bounds.size() - 1) (*p) << ", "; - } - (*p) << ")"; - if (!is_one(op->condition)) { - (*p) << " if "; - p->Print(op->condition); - } - (*p) << " {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - (*p) << op->buffer->name; - (*p) << "["; - for (size_t i = 0; i < op->region.size(); ++i) { - const auto& range = op->region[i]; - p->Print(range->min); - if (!is_one(range->extent)) { - (*p) << ":"; - p->Print(range->min + range->extent); - } - if (i != op->region.size() - 1) (*p) << ", "; - } - (*p) << "]"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << op->buffer->name << " = match_buffer("; - p->Print(op->source); - (*p) << ")\n"; - }); - -void PrintBlockTitle(const BlockNode* op, ReprLegacyPrinter* p) { - (*p) << "block " << op->name_hint << "("; - for (size_t i = 0; i < op->iter_vars.size(); i++) { - p->Print(op->iter_vars[i]); - if (i < op->iter_vars.size() - 1) (*p) << ", "; - } - (*p) << ")"; -} - -void PrintBlockSignature(const BlockNode* op, ReprLegacyPrinter* p) { - // print read/write regions - p->PrintIndent(); - (*p) << "reads("; - p->Print(op->reads); - (*p) << ")\n"; - p->PrintIndent(); - (*p) << "writes("; - p->Print(op->writes); - (*p) << ")\n"; - // Print alloc_buffers - for (const auto& alloc_buf : op->alloc_buffers) { - p->PrintIndent(); - (*p) << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "["; - for 
(size_t i = 0; i < alloc_buf->shape.size(); ++i) { - if (i > 0) (*p) << ", "; - p->Print(alloc_buf->shape[i]); - } - (*p) << "])\n"; - } - // Print match_buffer_regions - for (const auto& match_buf : op->match_buffers) { - p->Print(match_buf); - } - if (!op->annotations.empty()) { - p->PrintIndent(); - (*p) << "annotations(" << op->annotations << ")\n"; - } -} - -void PrintBlockBody(const BlockNode* op, ReprLegacyPrinter* p) { - // Print init - if (op->init.defined()) { - p->PrintIndent(); - (*p) << "with init() {\n"; - p->indent += 2; - p->Print(op->init.value()); - p->indent -= 2; - p->PrintIndent(); - (*p) << "}\n"; - } - // Print body - p->Print(op->body); -} - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - PrintBlockTitle(op, p); - (*p) << " {\n"; - p->indent += 2; - - // Print block elements (e.g. reads/writes, etc) - PrintBlockSignature(op, p); - // Print block init and body - PrintBlockBody(op, p); - - p->indent -= 2; - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - auto* block_op = op->block.get(); - p->PrintIndent(); - PrintBlockTitle(block_op, p); - (*p) << " {\n"; - p->indent += 2; - - // Print binding iter_values - for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { - p->PrintIndent(); - (*p) << "bind("; - p->Print(block_op->iter_vars[i]->var); - (*p) << ", "; - p->Print(op->iter_values[i]); - (*p) << ")\n"; - } - // Print predicate - if (!is_one(op->predicate)) { - p->PrintIndent(); - (*p) << "where("; - p->Print(op->predicate); - (*p) << ")\n"; - } - // Print block elements (e.g. 
reads/writes, etc) - PrintBlockSignature(block_op, p); - // Print block init and body - PrintBlockBody(block_op, p); - - p->indent -= 2; - p->PrintIndent(); - (*p) << "}\n"; - }); - -} // namespace tir -} // namespace tvm diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 03341c4cd90f..95d24c91c41e 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -42,18 +42,8 @@ inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) { try { p->stream << TVMScriptPrinter::Script(obj, std::nullopt); } catch (const tvm::Error& e) { - if (ReprLegacyPrinter::CanDispatch(obj)) { - LOG(WARNING) << "TVMScript printer falls back to the legacy ReprPrinter with the error:\n" - << e.what(); - try { - p->stream << AsLegacyRepr(obj); - } catch (const tvm::Error& e) { - LOG(WARNING) << "AsLegacyRepr fails. Falling back to the basic address printer"; - } - } else { - LOG(WARNING) << "TVMScript printer falls back to the basic address printer with the error:\n" - << e.what(); - } + LOG(WARNING) << "TVMScript printer falls back to the basic address printer with the error:\n" + << e.what(); p->stream << obj->GetTypeKey() << '(' << obj.get() << ')'; } } diff --git a/src/support/ordered_map.h b/src/support/ordered_map.h new file mode 100644 index 000000000000..81b0fd38a7a4 --- /dev/null +++ b/src/support/ordered_map.h @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file support/ordered_map.h + * \brief An STL-like map that preserves insertion order. + */ +#ifndef TVM_SUPPORT_ORDERED_MAP_H_ +#define TVM_SUPPORT_ORDERED_MAP_H_ + +#include <functional> +#include <unordered_map> +#include <utility> +#include <vector> + +namespace tvm { +namespace support { + +/** + * \brief An STL-like map that preserves insertion order. + * + * \tparam K The key type. + * \tparam V The value type. + * \tparam Hash The hash function. + * \tparam KeyEqual The key equality function. + * \note we don't support erase since it is less needed and vector backing is more efficient. + */ +template <typename K, typename V, typename Hash = std::hash<K>, + typename KeyEqual = std::equal_to<K>> +class OrderedMap { + public: + OrderedMap() = default; + + /* \brief Explicit copy constructor + * + * Copies `elements_` and rebuilds `elem_to_index_` from that copy. + * `elem_to_index_` stores positions into `elements_`, not iterators, + * so a member-wise copy would also be valid; the rebuild is kept so + * the index is always derived from the copied `elements_`. + */ + OrderedMap(const OrderedMap& other) : elements_(other.elements_) { + InitElementToIter(); + } + + /* \brief Explicit copy assignment + * + * Implemented in terms of the copy constructor, and the default + * move assignment.
+ */ + OrderedMap& operator=(const OrderedMap& other) { + return *this = OrderedMap(other); + } + + OrderedMap(OrderedMap&&) = default; + OrderedMap& operator=(OrderedMap&&) = default; + + template <typename Iter> + OrderedMap(Iter begin, Iter end) { // go through insert() so duplicate keys are deduplicated + for (; begin != end; ++begin) insert(begin->first, begin->second); + } + + auto find(const K& k) { + auto it = elem_to_index_.find(k); + if (it != elem_to_index_.end()) { + return elements_.begin() + it->second; + } + return elements_.end(); + } + + auto find(const K& k) const { + auto it = elem_to_index_.find(k); + if (it != elem_to_index_.end()) { + return elements_.begin() + it->second; + } + return elements_.end(); + } + + V& operator[](const K& k) { + auto it = elem_to_index_.find(k); + if (it != elem_to_index_.end()) { + return elements_[it->second].second; + } + elements_.emplace_back(k, V()); + elem_to_index_[k] = elements_.size() - 1; + return elements_.back().second; + } + + void insert(const K& k, V v) { + auto it = elem_to_index_.find(k); + if (it != elem_to_index_.end()) { + elements_[it->second].second = std::move(v); + } else { + elements_.emplace_back(k, std::move(v)); // v is a by-value sink: move, do not copy + elem_to_index_[k] = elements_.size() - 1; + } + } + + void clear() { + elements_.clear(); + elem_to_index_.clear(); + } + + size_t count(const K& k) const { return elem_to_index_.count(k); } + + auto begin() const { return elements_.begin(); } + auto end() const { return elements_.end(); } + auto begin() { return elements_.begin(); } + auto end() { return elements_.end(); } + + size_t size() const { return elements_.size(); } + bool empty() const { return elements_.empty(); } + + void reserve(size_t n) { elements_.reserve(n); elem_to_index_.reserve(n); } + + private: + void InitElementToIter() { + for (size_t i = 0; i < elements_.size(); i++) { + elem_to_index_[elements_[i].first] = i; + } + } + + std::vector<std::pair<K, V>> elements_; + std::unordered_map<K, size_t, Hash, KeyEqual> elem_to_index_; +}; + +} // namespace support +} // namespace tvm + +#endif // TVM_SUPPORT_ORDERED_MAP_H_ diff --git a/src/support/ordered_set.h
b/src/support/ordered_set.h index 11acb8c3fef5..169f738e700d 100644 --- a/src/support/ordered_set.h +++ b/src/support/ordered_set.h @@ -26,30 +26,14 @@ #include -#include +#include #include +#include namespace tvm { namespace support { -namespace detail { -/* \brief Utility to allow use for standard and ObjectRef types - * - * \tparam T The type held by the OrderedSet - */ -template -struct OrderedSetLookupType { - using MapType = std::unordered_map::iterator>; -}; - -template -struct OrderedSetLookupType>> { - using MapType = std::unordered_map::iterator, runtime::ObjectPtrHash, - runtime::ObjectPtrEqual>; -}; -} // namespace detail - -template +template , typename KeyEqual = std::equal_to> class OrderedSet { public: OrderedSet() = default; @@ -61,17 +45,21 @@ class OrderedSet { * `elements_`, the copy of `elem_to_iter_` would contain references * to the original's `element_`, rather than to its own */ - OrderedSet(const OrderedSet& other) : elements_(other.elements_) { InitElementToIter(); } + OrderedSet(const OrderedSet& other) : elements_(other.elements_) { + InitElementToIter(); + } /* \brief Explicit copy assignment * * Implemented in terms of the copy constructor, and the default * move assignment. 
*/ - OrderedSet& operator=(const OrderedSet& other) { return *this = OrderedSet(other); } + OrderedSet& operator=(const OrderedSet& other) { + return *this = OrderedSet(other); + } - OrderedSet(OrderedSet&&) = default; - OrderedSet& operator=(OrderedSet&&) = default; + OrderedSet(OrderedSet&&) = default; + OrderedSet& operator=(OrderedSet&&) = default; template OrderedSet(Iter begin, Iter end) : elements_(begin, end) { @@ -79,27 +67,20 @@ class OrderedSet { } void push_back(const T& t) { - if (!elem_to_iter_.count(t)) { + if (!elem_to_index_.count(t)) { elements_.push_back(t); - elem_to_iter_[t] = std::prev(elements_.end()); + elem_to_index_[t] = elements_.size() - 1; } } void insert(const T& t) { push_back(t); } - void erase(const T& t) { - if (auto it = elem_to_iter_.find(t); it != elem_to_iter_.end()) { - elements_.erase(it->second); - elem_to_iter_.erase(it); - } - } - void clear() { elements_.clear(); - elem_to_iter_.clear(); + elem_to_index_.clear(); } - size_t count(const T& t) const { return elem_to_iter_.count(t); } + size_t count(const T& t) const { return elem_to_index_.count(t); } auto begin() const { return elements_.begin(); } auto end() const { return elements_.end(); } @@ -108,13 +89,13 @@ class OrderedSet { private: void InitElementToIter() { - for (auto it = elements_.begin(); it != elements_.end(); it++) { - elem_to_iter_[*it] = it; + for (size_t i = 0; i < elements_.size(); ++i) { + elem_to_index_[elements_[i]] = i; } } - std::list elements_; - typename detail::OrderedSetLookupType::MapType elem_to_iter_; + std::vector elements_; + std::unordered_map elem_to_index_; }; } // namespace support diff --git a/src/tir/transforms/common_subexpr_elim.cc b/src/tir/transforms/common_subexpr_elim.cc index 42409efb0bd1..3fd78a523301 100644 --- a/src/tir/transforms/common_subexpr_elim.cc +++ b/src/tir/transforms/common_subexpr_elim.cc @@ -43,8 +43,7 @@ #include // For the algorithm std::find #include #include -#include // For the hashtable datatype -#include 
// For std::pair and std::move +#include #include #include "../analysis/check_contains.h" // For the visitor CheckContains @@ -131,41 +130,24 @@ bool CommonSubexpressionEliminator::CanContainEligibleComputations(const PrimExp * they appeared in the hashtable was based on some runtime addresses, so it can potentially * change with every execution. */ -bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(std::pair a, - std::pair b) { +bool CommonSubexpressionEliminator::OrderOnExprAndFrequency(const std::pair& a, + const std::pair& b) { size_t a_size = CalculateExprComplexity(a.first); size_t b_size = CalculateExprComplexity(b.first); - - // Criteria 1 - Size of the expression comes first - // `a` comes before `b` if the size of `a` is bigger - if (a_size > b_size) { - return true; - } - // `a` does NOT come before `b` if the size of `b` is bigger - if (b_size > a_size) { - return false; - } - - // Criteria 2 - If they had the same size, use the lexicographic order as a last resort - // as we need a deterministic order - std::stringstream a_stream; - std::stringstream b_stream; - a_stream << AsLegacyRepr(a.first); - b_stream << AsLegacyRepr(b.first); - return (a_stream.str().compare(b_stream.str()) < 0); + return a_size > b_size; } /*! - * \brief Generates a new fresh variable, whose name will be cse_var_i. + * \brief Generates a new fresh variable, whose name will be cse_vi. * \param type_annotation The type of the new variable to generate - * \return A new variable of type `type_annotation` called cse_var_i where i is the first available + * \return A new variable of type `type_annotation` called cse_vi where i is the first available integer. 
*/ Var CommonSubexpressionEliminator::GenerateNewVar(DataType type_annotation) { // Increase `num_last_try_` for this new attempt num_last_try_++; - // Builds the variable name, which is sce_var_i where i will go up from 1 - std::string prefix = "cse_var_"; + // Builds the variable name, which is cse_vi where i will go up from 1 + std::string prefix = "cse_v"; std::string name = prefix.append(std::to_string(num_last_try_)); // Builds a String using the std::string String string_name(name); @@ -241,8 +223,8 @@ PrimExpr CommonSubexpressionEliminator::VisitExpr(const PrimExpr& expr) { SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, identify_equiv_terms_); // Sort the vector of semantic entities by decreasing size - std::sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(), - OrderOnExprAndFrequency); + std::stable_sort(semantic_comp_done_by_expr.begin(), semantic_comp_done_by_expr.end(), + OrderOnExprAndFrequency); // For each computation done (considering them from biggest to smallest) for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) { @@ -421,8 +403,8 @@ Stmt CommonSubexpressionEliminator::VisitStmt(const Stmt& stmt) { SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, identify_equiv_terms_); // Sort the vector of semantic entities by decreasing size - std::sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(), - OrderOnExprAndFrequency); + std::stable_sort(semantic_comp_done_by_stmt.begin(), semantic_comp_done_by_stmt.end(), + OrderOnExprAndFrequency); // For each computation done (considering them from biggest to smallest) for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) { diff --git a/src/tir/transforms/common_subexpr_elim.h b/src/tir/transforms/common_subexpr_elim.h index 5c14caf1a6e3..12a71458e13f 100644 --- a/src/tir/transforms/common_subexpr_elim.h +++ b/src/tir/transforms/common_subexpr_elim.h @@ -83,7 +83,8 @@ class CommonSubexpressionEliminator : public 
StmtExprMutator { static bool ForbiddenComputation(const PrimExpr& expr); static bool IsEligibleComputation(const PrimExpr& expr); static bool CanContainEligibleComputations(const PrimExpr& expr); - static bool OrderOnExprAndFrequency(std::pair a, std::pair b); + static bool OrderOnExprAndFrequency(const std::pair& a, + const std::pair& b); Var GenerateNewVar(DataType type_annotation); }; diff --git a/src/tir/transforms/common_subexpr_elim_tools.cc b/src/tir/transforms/common_subexpr_elim_tools.cc index ce8aef4587dd..f71d2cf42a02 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.cc +++ b/src/tir/transforms/common_subexpr_elim_tools.cc @@ -797,7 +797,7 @@ std::vector> SyntacticToSemanticComputations( // normalized. This normalized table will keep the count for each set of equivalent terms // (i.e. each equivalence class), together with a term that did appear in this equivalence class // (in practice, the first term of the equivalence class that was encoutered). - std::unordered_map, StructuralHash, ExprDeepEqual> + support::OrderedMap, StructuralHash, ExprDeepEqual> norm_table; // In order to avoid frequent rehashing if the norm_table becomes big, we immediately ask for @@ -806,23 +806,7 @@ std::vector> SyntacticToSemanticComputations( // equivalence classes as there are elements) norm_table.reserve(table.size()); - // Transform the input hashtable to a vector and sort it according to some order, as we will be - // iterating through its items soon, and the order of appearance will be used to determine the - // individual representant for each class of equivalence, which we want to be deterministic - // (otherwise {x+y, y+x} could be both replaced by x+y, and on another run by y+x). 
- std::vector> sorted_items_of_table(table.begin(), table.end()); - - // We do the ordering by comparing the string repr of each expr to get a determinstic ordering - sort(sorted_items_of_table.begin(), sorted_items_of_table.end(), - [](std::pair a, std::pair b) { - std::stringstream a_stream; - std::stringstream b_stream; - a_stream << AsLegacyRepr(a.first); - b_stream << AsLegacyRepr(b.first); - return a_stream.str().compare(b_stream.str()) < 0; - }); - - for (const auto& elem : sorted_items_of_table) { + for (const auto& elem : table) { PrimExpr norm_elem = NormalizeTerm(elem.first, identify_equiv_terms); // If the normalized term is not already a key in the normalized table auto it_found = norm_table.find(norm_elem); @@ -831,7 +815,7 @@ std::vector> SyntacticToSemanticComputations( // (i.e. `norm_elem` has been seen `elem`.second many times so far, and the chosen element // to represent the equivalence class will be `elem`.first as it's the first element of the // class that we see) - norm_table[norm_elem] = elem; + norm_table.insert(norm_elem, elem); } else { // Otherwise, it's not the first time we see a term in this equivalence class, so we just // increase the count of this equivalence class as we now have `elem`.second additional items @@ -850,10 +834,8 @@ std::vector> SyntacticToSemanticComputations( // Careful : the pairs will never change (the canonical represantants chosen will always be the // same), but the order in which the pairs are produced can vary as we are iterating through the // hashtable `norm_table`. It is not an issue as the called will be sorting the result anyway. 
- std::unordered_map, StructuralHash, - ExprDeepEqual>::const_iterator it_norm_table; - for (it_norm_table = norm_table.begin(); it_norm_table != norm_table.end(); ++it_norm_table) { - result.push_back(it_norm_table->second); + for (const auto& kv : norm_table) { + result.push_back(kv.second); } return result; diff --git a/src/tir/transforms/common_subexpr_elim_tools.h b/src/tir/transforms/common_subexpr_elim_tools.h index 58014e6a406d..31a81dabdbf2 100644 --- a/src/tir/transforms/common_subexpr_elim_tools.h +++ b/src/tir/transforms/common_subexpr_elim_tools.h @@ -34,10 +34,12 @@ #include // For the class StmtExprVisitor #include -#include // For the hashtable datatype -#include // For pairs datatype +#include +#include // For pairs datatype #include +#include "../../support/ordered_map.h" + namespace tvm { namespace tir { @@ -50,7 +52,7 @@ namespace tir { not do variables remapping), so it is compatible with StructuralHash (intended to be used with StructuralEqual). */ -using ComputationTable = std::unordered_map; +using ComputationTable = support::OrderedMap; /*! 
* \brief A cache of computations is made of a pair of two hashtables, which respectively associate diff --git a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py index e7e64d89168e..1be5e57ba15a 100644 --- a/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py +++ b/tests/python/tir-transform/test_tir_transform_common_subexpr_elim.py @@ -93,14 +93,14 @@ def test_cse(): assert body.var.name == "z2" assert body.value == 2 - # This is the let-in for the first variable generated cse_var_1 + # This is the let-in for the first variable generated cse_v1 assert isinstance(body.body, tvm.tir.LetStmt) body = body.body # And this is the name and value of this variable - cse_var_1 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_var_1" + cse_v1 = body.var # Keep the variable accessible for later checking the replacements + assert body.var.name == "cse_v1" tvm.ir.assert_structural_equal(body.value, z1 + z2) assert isinstance(body.body, tvm.tir.SeqStmt) @@ -118,27 +118,27 @@ def test_cse(): assert body.var.name == "y" assert body.value == 1 - # This is the let-in for the second variable generated cse_var_2 + # This is the let-in for the second variable generated cse_v2 assert isinstance(body.body, tvm.tir.LetStmt) body = body.body # And this is the name and value of this variable - cse_var_2 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_var_2" + cse_v2 = body.var # Keep the variable accessible for later checking the replacements + assert body.var.name == "cse_v2" tvm.ir.assert_structural_equal(body.value, x + y) body = body.body body.var.name == "a" # Check that the replacement has been done correctly! 
- tvm.ir.assert_structural_equal(body.value, cse_var_2 + cse_var_1) + tvm.ir.assert_structural_equal(body.value, cse_v2 + cse_v1) body = body.body body.var.name == "b" # Check that the replacement has been done correctly! - tvm.ir.assert_structural_equal(body.value, cse_var_2 + z3) + tvm.ir.assert_structural_equal(body.value, cse_v2 + z3) assert isinstance(body.body, tvm.tir.BufferStore) @@ -199,7 +199,7 @@ def test_cse_ifNode_1(): body = body.then_case # The let-in introduced by the CSE should appear now, inside the Then branch of the If node - assert body.var.name == "cse_var_1" + assert body.var.name == "cse_v1" # and it should contain the expression (y+z) that was redundant tvm.ir.assert_structural_equal(body.value, y + z) @@ -250,7 +250,7 @@ def test_cse_ifNode_2(): assert isinstance(body, tvm.tir.LetStmt) # The let-in introduced by the CSE should appear now, at the toplevel (i.e. before the If) - assert body.var.name == "cse_var_1" + assert body.var.name == "cse_v1" # and it should contain the expression (y+z) that was redundant tvm.ir.assert_structural_equal(body.value, y + z) @@ -291,8 +291,8 @@ def test_cse_cascade(): assert isinstance(body, tvm.tir.LetStmt) # The second let-in (by order introduced) introduced by the CSE should appear first - cse_var_2 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_var_2" + cse_v2 = body.var # Keep the variable accessible for later checking the replacements + assert body.var.name == "cse_v2" # and it should contain the expression (x+y) tvm.ir.assert_structural_equal(body.value, (x + y)) @@ -301,10 +301,10 @@ def test_cse_cascade(): assert isinstance(body, tvm.tir.LetStmt) # The first let-in (by order introduced) introduced by the CSE should appear now, after the 2nd - cse_var_1 = body.var # Keep the variable accessible for later checking the replacements - assert body.var.name == "cse_var_1" - # and it should contain the expression cse_var_2+z - 
tvm.ir.assert_structural_equal(body.value, cse_var_2 + z) + cse_v1 = body.var # Keep the variable accessible for later checking the replacements + assert body.var.name == "cse_v1" + # and it should contain the expression cse_v2+z + tvm.ir.assert_structural_equal(body.value, cse_v2 + z) body = body.body @@ -317,9 +317,9 @@ def test_cse_cascade(): store2 = body[1] store3 = body[2] - tvm.ir.assert_structural_equal(store1.value, cse_var_1) - tvm.ir.assert_structural_equal(store2.value, cse_var_1) - tvm.ir.assert_structural_equal(store3.value, cse_var_2) + tvm.ir.assert_structural_equal(store1.value, cse_v1) + tvm.ir.assert_structural_equal(store2.value, cse_v1) + tvm.ir.assert_structural_equal(store3.value, cse_v2) # ----------------------------------------------------------------------------------------- @@ -360,9 +360,9 @@ def func_distributivity( def func_distributivity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with T.LetStmt(x * y + x * z) as cse_var_1: - B[i1] = cse_var_1 - B[i2] = cse_var_1 + with T.LetStmt((y + z) * x) as cse_v1: + B[i1] = cse_v1 + B[i2] = cse_v1 @T.prim_func @@ -377,9 +377,9 @@ def func_associativity( def func_associativity_expected( B: T.Buffer((50,), "int32"), i1: T.int32, i2: T.int32, x: T.int32, y: T.int32, z: T.int32 ) -> None: - with T.LetStmt((x + y) + z) as cse_var_1: - B[i1] = cse_var_1 - B[i2] = cse_var_1 + with T.LetStmt(x + y + z) as cse_v1: + B[i1] = cse_v1 + B[i2] = cse_v1 def _check(original, transformed): @@ -410,10 +410,10 @@ def test_deterministic_cse(): result = (x + 1) + (x + 2) + (x + 3) + (x + 1) + (x + 2) + (x + 3) --> - cse_var_3 = (x + 1) - cse_var_2 = (x + 2) - cse_var_1 = (x + 3) - result = cse_var_3 + cse_var_2 + cse_var_1 + cse_var_3 + cse_var_2 + cse_var_1 + cse_v3 = (x + 1) + cse_v2 = (x + 2) + cse_v1 = (x + 3) + result = cse_v3 + cse_v2 + cse_v1 + cse_v3 + cse_v2 + cse_v1 """ NUM_TERMS = 10 REPEATS = 10 diff --git 
a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py index da079f46e38e..13487b42f00f 100644 --- a/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/tir-transform/test_tir_transform_inject_ptx_async_copy.py @@ -329,11 +329,11 @@ def test_inject_async_copy_barrier(): __asm__ __volatile__("cp.async.commit_group;"); for (int i = 0; i < 13; ++i) { - bool cse_var_1 = (i < 12); + bool cse_v1 = (i < 12); { unsigned int addr = cast_smem_ptr_to_int(A_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))); - int pred_guard = (int)cse_var_1; + int pred_guard = (int)cse_v1; __asm__ __volatile__( "{ .reg .pred p;" " setp.ne.b32 p, %0, 0;" @@ -356,7 +356,7 @@ def test_inject_async_copy_barrier(): { unsigned int addr = cast_smem_ptr_to_int(B_shared + ((((i + 3) & 3) * 16) + ((int)threadIdx.x))); - int pred_guard = (int)cse_var_1; + int pred_guard = (int)cse_v1; __asm__ __volatile__( "{ .reg .pred p;" " setp.ne.b32 p, %0, 0;" @@ -954,10 +954,10 @@ def before(A: T.Buffer((32, 128), "float16")): T.attr("default", "async_scope", 1) for i in range(16): - cse_var_1: T.int64 = T.Cast("int64", i) - A_shared[ - T.Ramp(tx * T.int64(128) + cse_var_1 * T.int64(8), T.int64(1), 8) - ] = A_flattened[T.Ramp(tx * T.int64(128) + cse_var_1 * T.int64(8), T.int64(1), 8)] + cse_v1: T.int64 = T.Cast("int64", i) + A_shared[T.Ramp(tx * T.int64(128) + cse_v1 * T.int64(8), T.int64(1), 8)] = A_flattened[ + T.Ramp(tx * T.int64(128) + cse_v1 * T.int64(8), T.int64(1), 8) + ] T.ptx_commit_group() T.ptx_wait_group(0) @@ -965,13 +965,13 @@ def expected(A: T.Buffer((32, 128), "float16")): tx = T.launch_thread("threadIdx.x", T.int64(32)) A_shared = T.decl_buffer((4096,), "float16", scope="shared") for i in range(16): - cse_var_1: T.int64 = T.Cast("int64", i) + cse_v1: T.int64 = T.Cast("int64", i) T.ptx_cp_async( "float16", A_shared.data, - tx * T.int64(128) + cse_var_1 * T.int64(8), + 
tx * T.int64(128) + cse_v1 * T.int64(8), A.data, - tx * T.int64(128) + cse_var_1 * T.int64(8), + tx * T.int64(128) + cse_v1 * T.int64(8), 16, ) T.ptx_commit_group() diff --git a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py index c63d2f8a4137..299c19314654 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/tir-transform/test_tir_transform_lower_tvm_builtin.py @@ -154,8 +154,8 @@ def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112 rxplaceholder_1 = T.Buffer((T.int64(822083584),), data=rxplaceholder.data) T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract) for ax1, ax2 in T.grid(32, 25690112): - cse_var_1: T.int32 = ax1 * 25690112 + ax2 - T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] - rxplaceholder_red_1[ax1] + cse_v1: T.int32 = ax1 * 25690112 + ax2 + T_subtract_1[cse_v1] = rxplaceholder_1[cse_v1] - rxplaceholder_red_1[ax1] func = variance4 tvm.compile(func, target="llvm") # should not crash diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index af2db34415f8..0e1b328844be 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -3679,6 +3679,7 @@ def merge_shape_var_def(): # uninitialized vars @T.prim_func(check_well_formed=False) def main(A: T.handle, B: T.handle): + # fmt: off T.func_attr({"global_symbol": "main", "tir.noalias": True}) m, n = T.int32(), T.int32() A_1 = T.match_buffer(A, (m, n), strides=("A_1_s0", "A_1_s1"), buffer_type="auto") @@ -3687,8 +3688,8 @@ def main(A: T.handle, B: T.handle): if T.likely(i_outer * 10 + i_inner < m): for j_inner in range(5): if T.likely(j_outer * 5 + j_inner < n): - cse_var_2: T.int32 = j_outer * 5 + j_inner - cse_var_1: T.int32 = i_outer * 10 + i_inner + cse_v2: T.int32 = j_outer * 5 + j_inner + cse_v1: 
T.int32 = i_outer * 10 + i_inner B_2 = T.Buffer( (B_1.strides[0] * m,), data=B_1.data, @@ -3701,9 +3702,10 @@ def main(A: T.handle, B: T.handle): strides=("A_2_s0",), buffer_type="auto", ) - B_2[cse_var_1 * B_1.strides[0] + cse_var_2 * B_1.strides[1]] = A_2[ - cse_var_1 * A_1.strides[0] + cse_var_2 * A_1.strides[1] + B_2[cse_v1 * B_1.strides[0] + cse_v2 * B_1.strides[1]] = A_2[ + cse_v1 * A_1.strides[0] + cse_v2 * A_1.strides[1] ] + # fmt: on return main