From eb17944f489907382886825eae21b55bf3479b66 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 20 Apr 2024 14:50:37 -0500 Subject: [PATCH 1/2] [QoL][IR] Provide std::hash and std::equal_to for IR Variable types For most IR types, neither `std::hash` nor `std::equal_to` are provided, as it would be ambiguous whether comparisons should be performed with reference equality or structural equality. While this avoids ambiguity in the general case of nested structures, IR variables follow reference equality and are frequently used as lookup keys. This commit implements a specialization of `std::hash` and `std::equal_to` for `tvm::GlobalVar`, `tvm::tir::Var`, and `tvm::relax::Var`. This allows them to be used as lookup keys for `std::unordered_set` and `std::unordered_map` without explicitly specifying explicit `ObjectPtrHash` and `ObjectPtrEqual`. --- include/tvm/ir/expr.h | 28 ++++++++++++++++++ include/tvm/relax/expr.h | 28 ++++++++++++++++++ include/tvm/tir/var.h | 29 +++++++++++++++++++ src/arith/const_int_bound.cc | 2 +- src/arith/iter_affine_map.cc | 10 +++---- src/arith/modular_set.cc | 2 +- src/arith/rewrite_simplify.h | 2 +- .../msc/core/transform/set_expr_layout.cc | 2 +- .../feature_extractor/per_store_feature.cc | 2 +- .../multi_level_tiling_tensor_core.cc | 6 ++-- .../analysis/computable_at_compile_time.cc | 2 +- src/relax/analysis/layout_transformation.cc | 4 +-- src/relax/analysis/struct_info_analysis.cc | 4 +-- src/relax/analysis/udchain.cc | 2 +- src/relax/analysis/well_formed.cc | 16 +++++----- src/relax/backend/vm/codegen_vm.cc | 2 +- src/relax/backend/vm/codegen_vm_tir.cc | 2 +- .../lower_global_view_to_local_view.cc | 4 +-- src/relax/transform/adjust_matmul_order.cc | 3 +- src/relax/transform/canonicalize_bindings.cc | 4 +-- src/relax/transform/convert_layout.cc | 2 +- src/relax/transform/dataflow_inplace.cc | 24 +++++++-------- src/relax/transform/dead_code_elimination.cc | 7 ++--- src/relax/transform/expand_matmul_of_sum.cc | 3 +- src/relax/transform/fuse_tir.cc | 9 +++--- src/relax/transform/infer_amp_utils.h | 2 +- src/relax/transform/lambda_lift.cc | 4 +-- src/relax/transform/lazy_transform_params.cc | 12 ++++---- src/relax/transform/lift_transform_params.cc | 21 +++++--------- .../transform/merge_composite_functions.cc | 2 +- .../transform/split_call_tir_by_pattern.cc | 2 +- src/relax/transform/topological_sort.cc | 2 +- .../transform/update_param_struct_info.cc | 4 +-- src/relay/analysis/call_graph.h | 6 ++-- src/target/llvm/codegen_llvm.h | 2 +- src/target/source/codegen_c.h | 4 +-- src/target/source/codegen_webgpu.cc | 2 +- src/target/spirv/codegen_spirv.h | 2 +- src/tir/analysis/is_pure_function.cc | 2 +- src/tir/analysis/verify_ssa.cc | 2 +- src/tir/analysis/verify_well_formed.cc | 6 ++-- src/tir/ir/specialize.cc | 2 +- src/tir/ir/tir_visitor_with_path.cc | 2 +- src/tir/schedule/analysis/analysis.cc | 4 +-- .../schedule/primitive/cache_read_write.cc | 12 ++++---- src/tir/schedule/primitive/reduction.cc | 4 +-- src/tir/transforms/compact_buffer_region.cc | 20 +++++-------- src/tir/transforms/inject_permuted_layout.cc | 2 +- .../transforms/inject_software_pipeline.cc | 2 +- src/tir/transforms/ir_utils.cc | 9 +++--- src/tir/transforms/ir_utils.h | 3 +- src/tir/transforms/lower_custom_datatypes.cc | 2 +- src/tir/transforms/lower_opaque_block.cc | 4 +-- src/tir/transforms/storage_flatten.cc | 2 +- src/tir/transforms/texture_flatten.cc | 2 +- src/tir/transforms/thread_storage_sync.cc | 2 +- .../transforms/transform_mma_buffer_layout.cc | 2 +- src/tir/transforms/unroll_loop.cc | 7 ++--- .../transforms/unsupported_dtype_legalize.cc | 17 +++++------ src/tir/transforms/vectorize_loop.cc | 2 +- src/tir/usmp/analysis/extract_buffer_info.cc | 2 +- src/tir/usmp/transform/create_io_allocates.cc | 8 ++--- 62 files changed, 220 insertions(+), 164 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 594e2b86e9f9..91280c2dad36 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -821,4 +821,32 @@ struct PackedFuncValueConverter { } // namespace runtime } // namespace tvm + +/* \brief Allow tvm.GLobalVar as key in STL tables + * + * For most IR expressions, it would be ambiguous whether the + * expression should follow reference equality or structural equality. + * This is not the case for variables, which do not contain nested + * internal structure, and are frequently used as keys in lookup + * tables. + * + * Providing `std::hash` and `std::equal_to` specializations for + * `tvm::GlobalVar` allows it to be used as a key in STL tables. For + * other IR expressions, the user must specify the type of equality + * used (e.g. `std::unordered_set` + * or `std::unordered_set`). + */ +template <> +struct std::hash { + std::size_t operator()(const tvm::GlobalVar& var) const { + return tvm::runtime::ObjectPtrHash()(var); + } +}; + +template <> +struct std::equal_to { + bool operator()(const tvm::GlobalVar& var_a, const tvm::GlobalVar& var_b) const { + return tvm::runtime::ObjectPtrEqual()(var_a, var_b); + } +}; #endif // TVM_IR_EXPR_H_ diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 0ca92a01a74b..c41cc7ee877c 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -1111,4 +1111,32 @@ TVM_DLL Expr GetShapeOf(const Expr& expr); } // namespace relax } // namespace tvm +/* \brief Allow relax.Var as key in STL tables + * + * For most Relax expressions, it would be ambiguous whether the + * expression should follow reference equality or structural equality. + * This is not the case for variables, which do not contain nested + * internal structure, and are frequently used as keys in lookup + * tables. + * + * Providing `std::hash` and `std::equal_to` specializations for + * `relax::Var` allows it to be used as a key in STL tables. For + * `relax::Expr`, the user must specify the type of equality used + * (e.g. `std::unordered_set` or + * `std::unordered_set`). + */ +template <> +struct std::hash { + std::size_t operator()(const tvm::relax::Var& var) const { + return tvm::runtime::ObjectPtrHash()(var); + } +}; + +template <> +struct std::equal_to { + bool operator()(const tvm::relax::Var& var_a, const tvm::relax::Var& var_b) const { + return tvm::runtime::ObjectPtrEqual()(var_a, var_b); + } +}; + #endif // TVM_RELAX_EXPR_H_ diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 6c2c6dd5fc86..4d99a09d427e 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -352,4 +352,33 @@ inline const char* IterVarType2String(IterVarType t) { } } // namespace tir } // namespace tvm + +/* \brief Allow tir.Var as key in STL tables + * + * For most TIR expressions, it would be ambiguous whether the + * expression should follow reference equality or structural equality. + * This is not the case for variables, which do not contain nested + * internal structure, and are frequently used as keys in lookup + * tables. + * + * Providing `std::hash` and `std::equal_to` specializations for + * `tir::Var` allows it to be used as a key in STL tables. For + * `PrimExpr`, the user must specify the type of equality used + * (e.g. `std::unordered_set` or + * `std::unordered_set`). + */ +template <> +struct std::hash { + std::size_t operator()(const tvm::tir::Var& var) const { + return tvm::runtime::ObjectPtrHash()(var); + } +}; + +template <> +struct std::equal_to { + bool operator()(const tvm::tir::Var& var_a, const tvm::tir::Var& var_b) const { + return tvm::runtime::ObjectPtrEqual()(var_a, var_b); + } +}; + #endif // TVM_TIR_VAR_H_ diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 8d41f0f2c6e7..4a829393d140 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -447,7 +447,7 @@ class ConstIntBoundAnalyzer::Impl private: friend class ConstIntBoundAnalyzer; // internal variable map - std::unordered_map var_map_; + std::unordered_map var_map_; // additional bound info std::vector additional_info_; // look up table for memorization diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index f90df9941766..77b20fcdf203 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -440,7 +440,7 @@ class IterMapRewriter : public ExprMutator { // Error messages for each unresolved expression. Array& errors_; // The var map - std::unordered_map var_map_; + std::unordered_map var_map_; // input iter marks std::vector input_marks_; @@ -1419,7 +1419,7 @@ bool MatchBoundConstraints(PrimExpr pred, Map* input_iters, } bool IterRangeSanityCheck(const Map& iter_ranges) { - std::unordered_set iters; + std::unordered_set iters; for (const auto& it : iter_ranges) iters.insert(it.first); auto f = [&](const VarNode* var) { return iters.count(GetRef(var)); }; for (const auto& it : iter_ranges) { @@ -2187,7 +2187,7 @@ TVM_REGISTER_GLOBAL("arith.IterMapSimplify") class SubspaceDivider { public: explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector, - const std::unordered_set& sub_iters) + const std::unordered_set& sub_iters) : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {} size_t unresolved_count() const { return unresolved_count_; } @@ -2455,7 +2455,7 @@ class SubspaceDivider { // collector that collects the outgoing split reference of each IterMark const IterMarkSplitCollector collector_; // the set of subspace iters - const std::unordered_set& sub_iters_; + const std::unordered_set& sub_iters_; // map from SplitExpr to its corresponding DivisionResult(Y*E(X)+X) std::unordered_map split_map_; // predicate of outer space and inner space; @@ -2473,7 +2473,7 @@ Array> SubspaceDivide(const Array& bindings, const Array& maps = res->indices; if (maps.empty()) return {}; - std::unordered_set inner_iter_set; + std::unordered_set inner_iter_set; for (const Var& inner_iter : sub_iters) { inner_iter_set.insert(inner_iter); } diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index ac6bf94b1198..197e5ec8b868 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -302,7 +302,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor var_map_; + std::unordered_map var_map_; /*! * \brief Update var by intersecting entry with var's current set. * \param var The variable. diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index e488024ec348..26dee062c4d2 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -147,7 +147,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { // counter to record recursive rewrite depth. int64_t recur_depth_{0}; // internal variable map - std::unordered_map var_map_; + std::unordered_map var_map_; std::vector literal_constraints_; diff --git a/src/contrib/msc/core/transform/set_expr_layout.cc b/src/contrib/msc/core/transform/set_expr_layout.cc index 76775a5ba322..56517fdae8d6 100644 --- a/src/contrib/msc/core/transform/set_expr_layout.cc +++ b/src/contrib/msc/core/transform/set_expr_layout.cc @@ -1298,7 +1298,7 @@ class LayoutInfer : public ExprVisitor { bool infered_; Map var_map_; Array ordered_exprs_; - std::unordered_map var_layout_map_; + std::unordered_map var_layout_map_; Map local_funcs_; }; // class LayoutInfer diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 5ade69101f22..82bc7c2de078 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -288,7 +288,7 @@ Pass SimplifyForFeatureExtraction() { } } - std::unordered_set unit_vars_; + std::unordered_set unit_vars_; }; auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { PrimFuncNode* n = f.CopyOnWrite(); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index d519187d303f..e3b51dda154a 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -775,9 +775,9 @@ Optional MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin( const tir::IndexMap& index_map = mapping_info->mappings[0]; // Find the correspondence between block iters and the iters in the index map. - std::unordered_map lhs_to_index_map_src; - std::unordered_map rhs_to_index_map_tgt; - std::unordered_set unmapped_index_map_src; + std::unordered_map lhs_to_index_map_src; + std::unordered_map rhs_to_index_map_tgt; + std::unordered_set unmapped_index_map_src; ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size()); for (int i = 0; i < static_cast(mapping_info->lhs_iters.size()); ++i) { lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i]; diff --git a/src/relax/analysis/computable_at_compile_time.cc b/src/relax/analysis/computable_at_compile_time.cc index 5ee336ff008f..37bbf3a9775e 100644 --- a/src/relax/analysis/computable_at_compile_time.cc +++ b/src/relax/analysis/computable_at_compile_time.cc @@ -84,7 +84,7 @@ class CompileTimeCollector : ExprVisitor { } support::OrderedSet known_relax_vars_; - std::unordered_set known_tir_vars_; + std::unordered_set known_tir_vars_; }; } // namespace diff --git a/src/relax/analysis/layout_transformation.cc b/src/relax/analysis/layout_transformation.cc index 8f4b91ef55f9..2e850fa9dee3 100644 --- a/src/relax/analysis/layout_transformation.cc +++ b/src/relax/analysis/layout_transformation.cc @@ -150,7 +150,7 @@ static bool AreIdenticalSpatialAccess(const SpatialLayout& s0, const SpatialLayo * (ignoring reduction dimensions). It checks that the order of spatial iter vars in spatial layout * of a buffer access is same as the order of spatial iter vars in block domain. */ -using VarToBlockIndexMap = std::unordered_map; +using VarToBlockIndexMap = std::unordered_map; static bool IsSequentialAccess(const SpatialLayout& iterators, const VarToBlockIndexMap& iter_to_block_index) { int last_value = -1; @@ -210,7 +210,7 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) { * source spatial layout. * target transformation = lambda dim, C, H, W -> (dim, H, W, C // 4, C %4) */ -using VarSet = std::unordered_set; +using VarSet = std::unordered_set; static Optional InferLayoutTransformation(const SpatialLayout& src_spatial_layout, const IndexMap& src_transformation, const SpatialLayout& tgt_spatial_layout) { diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc index 08e2acfbd069..a1077287a0f7 100644 --- a/src/relax/analysis/struct_info_analysis.cc +++ b/src/relax/analysis/struct_info_analysis.cc @@ -1356,9 +1356,9 @@ class SymbolicVarCollector : public relax::ExprVisitor, /*! \brief The current visit mode. */ VisitMode mode_ = VisitMode::kRequireDefinition; /*! \brief The set of defined symbolic vars. */ - std::unordered_set defined_symbolic_var_; + std::unordered_set defined_symbolic_var_; /*! \brief The set of free/undefined symbolic vars. */ - std::unordered_set free_symbolic_var_; + std::unordered_set free_symbolic_var_; }; Array DefinedSymbolicVars(const Expr& expr) { diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc index 95af8f43c982..d7ab4f1031b4 100644 --- a/src/relax/analysis/udchain.cc +++ b/src/relax/analysis/udchain.cc @@ -55,7 +55,7 @@ class UDChain : relax::ExprVisitor { private: Map bound_values; - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> usage_map; + std::unordered_map> usage_map; support::OrderedSet outputs; Optional cur_user_{nullptr}; diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc index a73e6fb233bf..626fadda273d 100644 --- a/src/relax/analysis/well_formed.cc +++ b/src/relax/analysis/well_formed.cc @@ -364,9 +364,8 @@ class WellFormedChecker : public relax::ExprVisitor, Malformed(Diagnostic::Error(op) << "The condition for an if node must be a leaf expression."); } - std::unordered_set previous_var_set = var_set_; - std::unordered_set previous_symbolic_var_set = - symbolic_var_set_; + std::unordered_set previous_var_set = var_set_; + std::unordered_set previous_symbolic_var_set = symbolic_var_set_; this->VisitSeqExpr(op->true_branch.get()); var_set_ = previous_var_set; symbolic_var_set_ = previous_symbolic_var_set; @@ -567,13 +566,12 @@ class WellFormedChecker : public relax::ExprVisitor, // Current visit mode. VisitMode mode_ = VisitMode::kDefault; // set of context variables. - std::unordered_set var_set_; - std::unordered_set recur_vars_; + std::unordered_set var_set_; + std::unordered_set recur_vars_; std::unordered_set dataflow_var_set_; - std::unordered_set symbolic_var_set_; - std::unordered_map param_var_func_map_; - std::unordered_map - symbolic_var_func_map_; + std::unordered_set symbolic_var_set_; + std::unordered_map param_var_func_map_; + std::unordered_map symbolic_var_func_map_; tvm::OpAttrMap op_map_normalize_ = Op::GetAttrMap("FNormalize"); }; diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 329da67e84ec..334e6e5c9a62 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -424,7 +424,7 @@ class CodeGenVM : public ExprFunctor { */ size_t registers_num_ = 0; /*! \brief Map from var to register number. */ - std::unordered_map var_arg_map_; + std::unordered_map var_arg_map_; /*! \brief the context module. */ IRModule ctx_mod_; /*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */ diff --git a/src/relax/backend/vm/codegen_vm_tir.cc b/src/relax/backend/vm/codegen_vm_tir.cc index ec1678e9e0f3..dd34bc63bb31 100644 --- a/src/relax/backend/vm/codegen_vm_tir.cc +++ b/src/relax/backend/vm/codegen_vm_tir.cc @@ -511,7 +511,7 @@ class CodeGenVMTIR : public ExprFunctor(const Expr&)> { /*! \brief Stack to build up statements */ std::vector> stmt_stack_; /*! \brief Map from var to Expr. */ - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> var_map_; + std::unordered_map> var_map_; /*! \brief the context module. */ IRModule ctx_mod_; /*! \brief system lib prefix */ diff --git a/src/relax/distributed/transform/lower_global_view_to_local_view.cc b/src/relax/distributed/transform/lower_global_view_to_local_view.cc index 69c9c3bf2f87..793b9cbe248b 100644 --- a/src/relax/distributed/transform/lower_global_view_to_local_view.cc +++ b/src/relax/distributed/transform/lower_global_view_to_local_view.cc @@ -337,8 +337,8 @@ class DistributedBufferCompactor : StmtExprMutator { return new_loop; } - std::unordered_map iter_var_shards_; - std::unordered_map loop_var_shards_; + std::unordered_map iter_var_shards_; + std::unordered_map loop_var_shards_; Array allocated_buffer_under_root; BufferAxisGraphExtractor extractor_; std::vector sharding_specs_; diff --git a/src/relax/transform/adjust_matmul_order.cc b/src/relax/transform/adjust_matmul_order.cc index 399860987c01..d6bbb669a365 100644 --- a/src/relax/transform/adjust_matmul_order.cc +++ b/src/relax/transform/adjust_matmul_order.cc @@ -41,8 +41,7 @@ namespace { std::tuple)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); - std::unordered_set compile_time_lookup( - compile_time_arr.begin(), compile_time_arr.end()); + std::unordered_set compile_time_lookup(compile_time_arr.begin(), compile_time_arr.end()); TypedPackedFunc is_compile_time = [compile_time_lookup](Expr arg) -> bool { if (auto as_var = arg.as()) { diff --git a/src/relax/transform/canonicalize_bindings.cc b/src/relax/transform/canonicalize_bindings.cc index 6b88446893cf..12eb81ac675d 100644 --- a/src/relax/transform/canonicalize_bindings.cc +++ b/src/relax/transform/canonicalize_bindings.cc @@ -226,10 +226,10 @@ class CanonicalizePlanner : public ExprVisitor { Map trivial_bindings_; Map known_bindings_; Map known_bound_to_constant_; - std::unordered_set defined_inside_dataflow_; + std::unordered_set defined_inside_dataflow_; // Set of vars either used outside a dataflow block altogether or outside their // home dataflow block (the one where they were defined) - std::unordered_set used_outside_home_dataflow_; + std::unordered_set used_outside_home_dataflow_; }; /*! \brief The mutator class to apply a CanonicalizationPlan */ diff --git a/src/relax/transform/convert_layout.cc b/src/relax/transform/convert_layout.cc index 6530d0d2cf0c..cae845e07141 100644 --- a/src/relax/transform/convert_layout.cc +++ b/src/relax/transform/convert_layout.cc @@ -281,7 +281,7 @@ class LayoutConvertMutator : public ExprMutator { } } - std::unordered_map var_layout_map_; + std::unordered_map var_layout_map_; Map> desired_layouts_; }; // namespace relax diff --git a/src/relax/transform/dataflow_inplace.cc b/src/relax/transform/dataflow_inplace.cc index 091298177595..aee2c015fc81 100644 --- a/src/relax/transform/dataflow_inplace.cc +++ b/src/relax/transform/dataflow_inplace.cc @@ -41,9 +41,8 @@ namespace relax { // pairs of indices (the liveness interval, from the starting index to the end index). // A starting index of -1 means the var is defined before the block starts and an end index // of block->bindings.size() (one past the last index) means it is live after the block ends. -std::unordered_map, ObjectPtrHash, ObjectPtrEqual> AnalyzeLiveness( - const DataflowBlock& block) { - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> ret; +std::unordered_map> AnalyzeLiveness(const DataflowBlock& block) { + std::unordered_map> ret; for (int i = block->bindings.size() - 1; i >= 0; i--) { Binding b = block->bindings[i]; Var defined_var = b->var; @@ -103,7 +102,7 @@ class AliasAnalyzer { // that correspond to tuples (this maps to sets of memory locations for each tuple element). // Note: inputs are values that should be assumed not to be aliased and are therefore // (in the case of in-place ops) safe to overwrite. This may not be true of function args. - std::pair, ObjectPtrHash, ObjectPtrEqual>, + std::pair>, std::unordered_map>>> Analyze(const DataflowBlock& block, const Array& inputs) { for (auto input : inputs) { @@ -296,7 +295,7 @@ class AliasAnalyzer { return ret; } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> alias_map_; + std::unordered_map> alias_map_; std::unordered_map>> tuple_map_; int mem_idx_; }; @@ -415,8 +414,7 @@ std::pair SizeMatches(const StructInfo& target_info, const StructInf // Return false if the alias set contains -1, meaning a reference to an unknown or // possibly dangerous value (no checking we can do for that). bool GatherSetsToCheckForLiveness( - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& - alias_sets, + const std::unordered_map>& alias_sets, const std::unordered_map>>& tuple_map, std::vector>* sets_to_check, int alias_idx) { if (tuple_map.count(alias_idx)) { @@ -443,12 +441,10 @@ bool GatherSetsToCheckForLiveness( // Check that the target is not live past the index and that no alias of it is live past the // binding index (if the target is a tuple, check the conditions recursively for the members) bool InplaceConditionsMet( - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& live_ranges, - const std::unordered_map, ObjectPtrHash, ObjectPtrEqual>& - alias_sets, + const std::unordered_map>& live_ranges, + const std::unordered_map>& alias_sets, const std::unordered_map>>& tuple_map, - const std::unordered_set& currently_live, - const Expr& target, int binding_idx) { + const std::unordered_set& currently_live, const Expr& target, int binding_idx) { if (auto* var_node = target.as()) { auto current_var = GetRef(var_node); // if the var is live past this point, we can't use it for in-place computations anyway @@ -586,7 +582,7 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, return live_ranges[var1].first < live_ranges[var2].first; }); - std::unordered_set currently_live; + std::unordered_set currently_live; int last_live = 0; for (size_t i = 0; i < block->bindings.size(); i++) { @@ -602,7 +598,7 @@ FindInplaceOpportunities(const DataflowBlock& block, const Array& inputs, } // remove vars whose range has come to an end // (keep a separate set to avoid changing the set while iterating on it) - std::unordered_set remove; + std::unordered_set remove; for (auto var : currently_live) { auto live_range = live_ranges[var]; if (live_range.second < static_cast(i)) { diff --git a/src/relax/transform/dead_code_elimination.cc b/src/relax/transform/dead_code_elimination.cc index 876c714c61e3..9591b45595f9 100644 --- a/src/relax/transform/dead_code_elimination.cc +++ b/src/relax/transform/dead_code_elimination.cc @@ -106,14 +106,13 @@ class CallTracer : public ExprVisitor { bool all_callees_found_{true}; // Record the names of all encountered functions. - std::unordered_set called_funcs_; + std::unordered_set called_funcs_; // Record the expressions that are being visited. std::unordered_set visiting_; }; -IRModule RemoveUnusedFunctions( - IRModule mod, const std::unordered_set& entry_funcs) { +IRModule RemoveUnusedFunctions(IRModule mod, const std::unordered_set& entry_funcs) { CallTracer tracer(mod); for (const auto& gvar : entry_funcs) { tracer.VisitExpr(gvar); @@ -144,7 +143,7 @@ IRModule DeadCodeElimination(const IRModule& arg_mod, Array ent // S0: Make a list of all user-specified entry functions and // externally-visible entry functions. - std::unordered_set entry_functions; + std::unordered_set entry_functions; for (const auto& name : entry_function_names) { entry_functions.insert(mod->GetGlobalVar(name)); } diff --git a/src/relax/transform/expand_matmul_of_sum.cc b/src/relax/transform/expand_matmul_of_sum.cc index 906620563450..e20f9c59b28b 100644 --- a/src/relax/transform/expand_matmul_of_sum.cc +++ b/src/relax/transform/expand_matmul_of_sum.cc @@ -43,8 +43,7 @@ namespace { std::tuple)>> CreatePatterns( const Function& func) { auto compile_time_arr = ComputableAtCompileTime(func); - std::unordered_set compile_time_lookup( - compile_time_arr.begin(), compile_time_arr.end()); + std::unordered_set compile_time_lookup(compile_time_arr.begin(), compile_time_arr.end()); auto pat_lhs = WildcardPattern(); diff --git a/src/relax/transform/fuse_tir.cc b/src/relax/transform/fuse_tir.cc index cb8d340f7d09..e712b5022a7d 100644 --- a/src/relax/transform/fuse_tir.cc +++ b/src/relax/transform/fuse_tir.cc @@ -447,7 +447,7 @@ class FusedTIRConstructor : public ExprVisitor { // map of input buffers to indices (helpful for detecting in-place inputs) std::unordered_map buffer_to_idx; - std::unordered_map input_to_idx; + std::unordered_map input_to_idx; for (size_t i = 0; i < func_info_.params.size(); i++) { input_to_idx[func_info_.params[i]] = i; } @@ -979,7 +979,7 @@ class TIRFuseMutator : public ExprMutator { mod.CopyOnWrite(); IRModule updates; - std::unordered_map replacements; + std::unordered_map replacements; // Since TIRFuseMutator will delete bunch of PrimFunc, we create an empty block builder. @@ -1024,8 +1024,7 @@ class TIRFuseMutator : public ExprMutator { Array inplace_indices; }; - explicit TIRFuseMutator( - std::unordered_map replacements) + explicit TIRFuseMutator(std::unordered_map replacements) : replacements_(replacements) {} using ExprMutator::VisitExpr_; @@ -1129,7 +1128,7 @@ class TIRFuseMutator : public ExprMutator { * * Has one entry for each primitive relax function in the IRModule. */ - std::unordered_map replacements_; + std::unordered_map replacements_; }; IRModule FuseTIR(IRModule mod) { diff --git a/src/relax/transform/infer_amp_utils.h b/src/relax/transform/infer_amp_utils.h index 3c98af6db965..8d759d204cf1 100644 --- a/src/relax/transform/infer_amp_utils.h +++ b/src/relax/transform/infer_amp_utils.h @@ -69,7 +69,7 @@ NType NTypeFrom(const Expr& expr, DataType dtype = DataType::Void()); NType NTypeMerge(const NType& a, const NType& b); // The map that notes the NType message of each var -using VarDTypeMap = std::unordered_map; +using VarDTypeMap = std::unordered_map; // Call is a call node, out_dtype is the expected output_dtype using FInferMixedPrecision = diff --git a/src/relax/transform/lambda_lift.cc b/src/relax/transform/lambda_lift.cc index 16bd8bfc9110..f45d82129db6 100644 --- a/src/relax/transform/lambda_lift.cc +++ b/src/relax/transform/lambda_lift.cc @@ -482,8 +482,8 @@ class LambdaLifter : public ExprMutator { } private: - std::unordered_map nested_closure_map_; - std::unordered_map rebind_map_; + std::unordered_map nested_closure_map_; + std::unordered_map rebind_map_; std::unordered_set, ObjectPtrHash, ObjectPtrEqual> closures_; Optional current_lambda_var_ = NullOpt; IRModule mod_; diff --git a/src/relax/transform/lazy_transform_params.cc b/src/relax/transform/lazy_transform_params.cc index 37827fbe0e6c..fb401e1b6787 100644 --- a/src/relax/transform/lazy_transform_params.cc +++ b/src/relax/transform/lazy_transform_params.cc @@ -59,7 +59,7 @@ class LazyInputMutator : public ExprMutator { int64_t num_input_params = GetNumInputParams(func).value_or(0); - std::unordered_map param_lookup; + std::unordered_map param_lookup; for (size_t i = num_input_params; i < func->params.size(); i++) { param_lookup.insert({func->params[i], i - num_input_params}); } @@ -73,8 +73,8 @@ class LazyInputMutator : public ExprMutator { auto array_externally_visible_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(new_params.Map(GetStructInfo))); - std::unordered_set externally_visible_vars( - array_externally_visible_vars.begin(), array_externally_visible_vars.end()); + std::unordered_set externally_visible_vars(array_externally_visible_vars.begin(), + array_externally_visible_vars.end()); StructInfo new_ret_struct_info = EraseToWellDefined(func->ret_struct_info, [&](const tir::Var& var) -> Optional { if (externally_visible_vars.count(var)) { @@ -115,7 +115,7 @@ class LazyInputMutator : public ExprMutator { private: struct FunctionPlan { - std::unordered_map param_lookup; + std::unordered_map param_lookup; Expr fget_param; }; std::optional plan_; @@ -128,7 +128,7 @@ class LazyOutputMutator : public ExprMutator { return ExprMutator::VisitExpr_(func); } - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> output_lookup; + std::unordered_map> output_lookup; std::vector> inline_outputs; auto define_lookup = [&](size_t output_index, Expr output_value) { if (auto var = output_value.as()) { @@ -220,7 +220,7 @@ class LazyOutputMutator : public ExprMutator { } struct FunctionPlan { - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> output_lookup; + std::unordered_map> output_lookup; Expr fset_output; }; std::optional plan_; diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc index 7607d690d4cd..937cb8702952 100644 --- a/src/relax/transform/lift_transform_params.cc +++ b/src/relax/transform/lift_transform_params.cc @@ -136,8 +136,7 @@ struct GlobalCollectInfo : public BaseCollectInfo { Array GetPropagatedSymbolicVariables() const { auto vars_from_original_params = DefinableTIRVarsInStructInfo(TupleStructInfo(params.Map(GetStructInfo))); - auto vars_from_transformed_params = - [&]() -> std::unordered_set { + auto vars_from_transformed_params = [&]() -> std::unordered_set { auto tir_vars = DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo))); return {tir_vars.begin(), tir_vars.end()}; @@ -179,15 +178,13 @@ struct LocalCollectInfo : public BaseCollectInfo { auto vars_from_any_param = DefinableTIRVarsInStructInfo(TupleStructInfo(orig_func->params.Map(GetStructInfo))); - auto vars_from_runtime_params = - [&]() -> std::unordered_set { + auto vars_from_runtime_params = [&]() -> std::unordered_set { auto tir_var_vec = DefinableTIRVarsInStructInfo(TupleStructInfo(GetRuntimeInputs().Map(GetStructInfo))); return {tir_var_vec.begin(), tir_var_vec.end()}; }(); - auto vars_from_transformed_params = - [&]() -> std::unordered_set { + auto vars_from_transformed_params = [&]() -> std::unordered_set { auto tir_var_vec = DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo))); return {tir_var_vec.begin(), tir_var_vec.end()}; @@ -287,7 +284,7 @@ struct LocalCollectInfo : public BaseCollectInfo { // Any binding that is computable at compile-time should be // suppressed at run-time. - std::unordered_set to_suppress; + std::unordered_set to_suppress; for (const auto& binding : computable_at_compile_time) { if (requires_compile_time_param.count(binding->var)) { to_suppress.insert(binding->var); @@ -296,8 +293,7 @@ struct LocalCollectInfo : public BaseCollectInfo { class SuppressCompileTime : public ExprMutator { public: - explicit SuppressCompileTime( - const std::unordered_set& to_suppress) + explicit SuppressCompileTime(const std::unordered_set& to_suppress) : to_suppress_(to_suppress) {} void VisitBinding(const Binding& binding) override { @@ -317,7 +313,7 @@ struct LocalCollectInfo : public BaseCollectInfo { } private: - const std::unordered_set& to_suppress_; + const std::unordered_set& to_suppress_; }; Expr body = SuppressCompileTime(to_suppress)(orig_func->body); body = SeqExpr({DataflowBlock(bindings)}, body); @@ -769,8 +765,7 @@ Pass PartitionTransformParams(Variant> shared_transform) { global_collect_info = MakeGlobalLiftPlan(mod, functions); } - std::unordered_map - local_collect_info; + std::unordered_map local_collect_info; for (const auto& [gvar, func] : target_functions) { auto info = LocalLiftableBindingCollector::Collect( func, global_collect_info.has_value() ? &global_collect_info.value() : nullptr); @@ -814,7 +809,7 @@ Pass LiftTransformParams(Variant> shared_transform) { // 3. Post-proc: Expose the compile-time and run-time functions for // external use, replacing the end-to-end functions. auto post_proc_func = [=](IRModule mod, PassContext pc) { - std::unordered_map to_add; + std::unordered_map to_add; for (const auto& [gvar, base_func] : mod->functions) { if (auto opt = base_func.as()) { auto func = opt.value(); diff --git a/src/relax/transform/merge_composite_functions.cc b/src/relax/transform/merge_composite_functions.cc index 9d9d9aa64447..0dd14f5bb1af 100644 --- a/src/relax/transform/merge_composite_functions.cc +++ b/src/relax/transform/merge_composite_functions.cc @@ -376,7 +376,7 @@ class CompositeFunctionAnnotator : public ExprMutator { private: IRModule mod_; CompositeInliner inliner; - std::unordered_map var_map_; + std::unordered_map var_map_; }; } // namespace diff --git a/src/relax/transform/split_call_tir_by_pattern.cc b/src/relax/transform/split_call_tir_by_pattern.cc index 7fcc2cb34a76..4f934916f5ca 100644 --- a/src/relax/transform/split_call_tir_by_pattern.cc +++ b/src/relax/transform/split_call_tir_by_pattern.cc @@ -48,7 +48,7 @@ using relax::TIRPattern; /*! \brief helper to match a for stmt to a pattern*/ class ForMatcher : public TensorizeComparator { public: - using SymbolMap = std::unordered_map; + using SymbolMap = std::unordered_map; explicit ForMatcher(const tir::PrimFunc& pattern, const Array& pattern_vars) : TensorizeComparator(IRModule({{GlobalVar(""), pattern}}), false), pattern_(pattern) { for (const auto& pattern_var : pattern_vars) { diff --git a/src/relax/transform/topological_sort.cc b/src/relax/transform/topological_sort.cc index a366ff4d1271..24ed53948e71 100644 --- a/src/relax/transform/topological_sort.cc +++ b/src/relax/transform/topological_sort.cc @@ -188,7 +188,7 @@ class TopologicalSorter : public ExprMutator { // A map from not-yet-defined variables to the binding that will // define the variable. Items are removed from this map as they // are collected into `new_bindings`. - std::unordered_map to_emit; + std::unordered_map to_emit; for (const auto& binding : block->bindings) { to_emit.insert({binding->var, binding}); } diff --git a/src/relax/transform/update_param_struct_info.cc b/src/relax/transform/update_param_struct_info.cc index b3fa0464bead..eefcf3ba1b64 100644 --- a/src/relax/transform/update_param_struct_info.cc +++ b/src/relax/transform/update_param_struct_info.cc @@ -73,8 +73,8 @@ Pass UpdateParamStructInfo(TypedPackedFunc(Var)> sinfo_func auto pass_func = [=](IRModule mod, PassContext pc) { ParamStructInfoMutator mutator(sinfo_func); - std::unordered_set to_remove; - std::unordered_map to_add; + std::unordered_set to_remove; + std::unordered_map to_add; for (const auto& [gvar, base_func] : mod->functions) { if (auto func = base_func.as()) { diff --git a/src/relay/analysis/call_graph.h b/src/relay/analysis/call_graph.h index 7cc813ebbff1..091891acd414 100644 --- a/src/relay/analysis/call_graph.h +++ b/src/relay/analysis/call_graph.h @@ -47,8 +47,7 @@ class CallGraphEntry; class CallGraph; class CallGraphNode : public Object { - using CallGraphMap = - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>; + using CallGraphMap = std::unordered_map>; // Create iterator alias for a CallGraphNode object. using iterator = CallGraphMap::iterator; using const_iterator = CallGraphMap::const_iterator; @@ -195,8 +194,7 @@ class CallGraphNode : public Object { * a call graph. */ class CallGraph : public ObjectRef { - using CallGraphMap = - std::unordered_map, ObjectPtrHash, ObjectPtrEqual>; + using CallGraphMap = std::unordered_map>; // Create iterator alias for a CallGraph object. using iterator = CallGraphMap::iterator; using const_iterator = CallGraphMap::const_iterator; diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 0f7aa847ecb8..4c19ed9b203b 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -560,7 +560,7 @@ class CodeGenLLVM : public ExprFunctor, // deep comparison of PrimExpr ExprDeepEqual deep_equal_; // binding of let variables. Enables duplicate var defs that map to same value - std::unordered_map let_binding_; + std::unordered_map let_binding_; // debug info for function being compiled llvm::DISubprogram* di_subprogram_{nullptr}; // Cache potential common path ops to slightly improve lookup time. diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 9a20566d5b3e..e739df0ca1c0 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -328,7 +328,7 @@ class CodeGenC : public ExprFunctor, ExprDeepEqual deep_equal_; // binding of let variables. Enables duplicate var defs that map to same value - std::unordered_map let_binding_; + std::unordered_map let_binding_; /* \brief Map of GlobalVar to their symbol. * @@ -337,7 +337,7 @@ class CodeGenC : public ExprFunctor, * functions, this is the name of the function's GlobalVar, possibly * altered to prevent duplicate names. */ - std::unordered_map internal_functions_; + std::unordered_map internal_functions_; /* \brief Name supply to generate unique function names */ NameSupply func_name_supply_{""}; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index a9a23fb999d8..ba925056a379 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -47,7 +47,7 @@ struct WebGPUWorkGroupInfo { // whether we have ref to block index z is used. bool has_block_index_z{false}; // set of handles that have write access - std::unordered_set write_access_set; + std::unordered_set write_access_set; }; class WebGPUWorkgroupInfoCollector : public StmtExprVisitor { diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 8ea90a9c4b80..e5fde107f452 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -227,7 +227,7 @@ class CodeGenSPIRV : public ExprFunctor, ExprDeepEqual deep_equal_; // binding of let variables. Enables duplicate var defs that map to same value - std::unordered_map let_binding_; + std::unordered_map let_binding_; // Running total of the number of bytes of shared memory used. // Checked against the max_shared_memory_per_group diff --git a/src/tir/analysis/is_pure_function.cc b/src/tir/analysis/is_pure_function.cc index c9934c4bcf6f..ee893987c91e 100644 --- a/src/tir/analysis/is_pure_function.cc +++ b/src/tir/analysis/is_pure_function.cc @@ -83,7 +83,7 @@ class PurityChecker : TIRVisitorWithPath { bool assert_on_error_{false}; bool is_pure_{true}; - std::unordered_set internal_allocations_; + std::unordered_set internal_allocations_; }; } // namespace diff --git a/src/tir/analysis/verify_ssa.cc b/src/tir/analysis/verify_ssa.cc index e04dcf90aa79..068f252de3f0 100644 --- a/src/tir/analysis/verify_ssa.cc +++ b/src/tir/analysis/verify_ssa.cc @@ -130,7 +130,7 @@ class SSAVerifier final : public StmtExprVisitor { // deep equal ExprDeepEqual deep_equal_; // def map, for let, maps to the bind value, for others maps to self. - std::unordered_map def_map_; + std::unordered_map def_map_; }; bool VerifySSA(const PrimFunc& func) { diff --git a/src/tir/analysis/verify_well_formed.cc b/src/tir/analysis/verify_well_formed.cc index c001d35054f3..cfdc2f35515a 100644 --- a/src/tir/analysis/verify_well_formed.cc +++ b/src/tir/analysis/verify_well_formed.cc @@ -291,14 +291,14 @@ class UndefinedVarVerifier : public Verifier { } // Variables that are defined in the currently-visited scope. - std::unordered_map currently_defined_; + std::unordered_map currently_defined_; // Variables that were previously defined, and are now out of scope. - std::unordered_map previously_defined_; + std::unordered_map previously_defined_; // Special variables that are allowed to be re-defined, so long as // that re-definition occurs within the same PrimFunc. For example - std::unordered_set redefine_allowed_within_function_; + std::unordered_set redefine_allowed_within_function_; }; /* \brief Verify unique tir::Var for each environment thread diff --git a/src/tir/ir/specialize.cc b/src/tir/ir/specialize.cc index 924ef9a0cdde..b30d0caf6af3 100644 --- a/src/tir/ir/specialize.cc +++ b/src/tir/ir/specialize.cc @@ -35,7 +35,7 @@ namespace tvm { namespace tir { -using VarMap = std::unordered_map; +using VarMap = std::unordered_map; /**************** Helper functions ****************/ diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 37b3ce55a2ca..e0318b21bee3 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -37,7 +37,7 @@ void TIRVisitorWithPath::Visit(const IRModule& mod, ObjectPath path) { // To ensure deterministic order of visits, sort the GlobalVar first // by visibility (public then private), then alphabetically by name. std::vector gvars; - std::unordered_set externally_exposed; + std::unordered_set externally_exposed; for (const auto& [gvar, func] : mod->functions) { gvars.push_back(gvar); if (func->GetAttr(tvm::attr::kGlobalSymbol).defined()) { diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 3f79fed8d25a..b60e60c3cfc9 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -1914,7 +1914,7 @@ class AutoTensorizeMappingProposer { arith::Analyzer* analyzer) : extractor_(extractor), analyzer_(analyzer) {} - using VarSet = std::unordered_set; + using VarSet = std::unordered_set; void CollectFeasibleSet() { // Collect the set of potential iter var mapping between the workload and the tensor intrin. @@ -2076,7 +2076,7 @@ class AutoTensorizeMappingProposer { // The arithmetic analyzer. arith::Analyzer* analyzer_; /*! \brief Potential mappings on RHS for each variable on LHS */ - std::unordered_map lhs_feasible_vars_; + std::unordered_map lhs_feasible_vars_; }; bool CheckAutoTensorizeApplicable(const ScheduleState& state, const tir::StmtSRef& block_sref, diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index eac5500a19b3..b0cb56af4ed4 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -343,7 +343,7 @@ Block MakeCacheStage(const BufferRegion& cache_region, CacheStageInfo* info, * \return The reindex block. */ Block MakeReIndexStage(const Block& block, CacheStageInfo* info, - const std::unordered_set& covered, + const std::unordered_set& covered, const Array& original_indices, int buffer_index, BufferIndexType buffer_index_type) { // iters of the reindex block @@ -1397,7 +1397,7 @@ class ReindexCacheWriteRewriter : public CacheWriteRewriter { * \return The new buffer with target shape. */ Buffer CreateReindexBuffer(const Buffer& buffer, const Array& block_iters, - const std::unordered_set& covered) { + const std::unordered_set& covered) { ObjectPtr new_buffer = make_object(*buffer.get()); ObjectPtr new_var = make_object(*buffer->data.get()); std::vector new_shape; @@ -1541,14 +1541,14 @@ class ReIndexCollector : public StmtExprVisitor { class ReIndexRewriter : public StmtExprMutator { public: static Stmt Rewrite(const StmtSRef& scope_sref, const StmtSRef& block_sref, CacheStageInfo* info, - const std::unordered_set& covered) { + const std::unordered_set& covered) { ReIndexRewriter rewriter(block_sref, info, covered); return rewriter(GetRef(scope_sref->stmt)); } private: explicit ReIndexRewriter(const StmtSRef& block_sref, CacheStageInfo* info, - const std::unordered_set& covered) + const std::unordered_set& covered) : block_sref_(block_sref), info_(info), covered_(covered) { new_buffer_ = info->alloc.value(); old_buffer_ = info->read_buffer.same_as(new_buffer_) ? info->write_buffer : info->read_buffer; @@ -1624,7 +1624,7 @@ class ReIndexRewriter : public StmtExprMutator { /*! \brief The info for inserting reindex stage. */ CacheStageInfo* info_; /*! \brief Whether old block var is covered in the indices */ - const std::unordered_set& covered_; + const std::unordered_set& covered_; /*! \brief Whether the current block is scope block */ bool is_scope_{true}; /*! \brief The buffer to be replaced */ @@ -2253,7 +2253,7 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde [&analyzer](const PrimExpr& expr) { return SimplifyNonTrivialExpr(expr, &analyzer); }); // Collect block iters appearing in the original_indices - std::unordered_set covered; + std::unordered_set covered; for (const PrimExpr& index : original_indices) { PreOrderVisit(index, [&](const ObjectRef& obj) -> bool { if (auto var = obj.as()) { diff --git a/src/tir/schedule/primitive/reduction.cc b/src/tir/schedule/primitive/reduction.cc index e1c90cc645fb..c294f7092516 100644 --- a/src/tir/schedule/primitive/reduction.cc +++ b/src/tir/schedule/primitive/reduction.cc @@ -210,7 +210,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, init_realize->block = Block(init_block); // Step 1. Create new block vars and their bindings // Maps an old block var to the new corresponding block var - std::unordered_map block_var_map; + std::unordered_map block_var_map; block_var_map.reserve(block->iter_vars.size()); for (int i = 0, n = block->iter_vars.size(); i < n; ++i) { const IterVar& iter_var = block->iter_vars[i]; @@ -263,7 +263,7 @@ StmtSRef DecomposeReduction(ScheduleState self, const StmtSRef& block_sref, // We discard predicate that is related to discarded loops init_realize->predicate = RemakePredicate(realize->predicate, discarded_loops); // Step 5. Create new loops above init block - std::unordered_map loop_var_map; + std::unordered_map loop_var_map; Stmt body = BlockRealize(init_realize); for (int i : chosen_loops) { const ForNode* old_loop = TVM_SREF_TO_FOR(loops[i]); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc index c7706212c519..f562a057e595 100644 --- a/src/tir/transforms/compact_buffer_region.cc +++ b/src/tir/transforms/compact_buffer_region.cc @@ -65,9 +65,7 @@ NDIntSet NDIntSetEval(Region region, PrimExpr predicate, class Var2BufferCollector : public StmtExprVisitor { public: /*! \brief Map the buffer var to all aliased buffers. */ - std::unordered_map, ObjectPtrHash, - ObjectPtrEqual> - var2buffer_; + std::unordered_map> var2buffer_; private: void VisitStmt_(const BufferStoreNode* op) final { @@ -465,12 +463,10 @@ class BufferAccessRegionCollector : public StmtExprVisitor { * define point. ancestor_loops_[0: n_ancester_loop] should not be relaxed when * we evaluate this buffer's access regions. */ - std::unordered_map buffer_scope_depth_; + std::unordered_map buffer_scope_depth_; /*! \brief Map the buffer var to all aliased buffers. */ - std::unordered_map, ObjectPtrHash, - ObjectPtrEqual> - var2buffer_; + std::unordered_map> var2buffer_; /*! \brief The map from loop vars to their iter range. */ std::unordered_map dom_map_; @@ -518,8 +514,7 @@ struct BufferAllocInfo { /*! \brief Reallocate the buffers with minimal region. */ class BufferCompactor : public StmtExprMutator { public: - explicit BufferCompactor( - std::unordered_map buffer_info) + explicit BufferCompactor(std::unordered_map buffer_info) : buffer_info_(std::move(buffer_info)) {} Stmt VisitStmt_(const BufferStoreNode* _op) final { @@ -649,7 +644,7 @@ class BufferCompactor : public StmtExprMutator { } /*! \brief Map buffer var to the allocation information about each buffer. */ - std::unordered_map buffer_info_; + std::unordered_map buffer_info_; }; Array CalcStrides(const BufferAllocInfo& alloc_info, const Array& shape) { @@ -678,10 +673,9 @@ Array CalcStrides(const BufferAllocInfo& alloc_info, const Array& regions, - const std::unordered_map& - storage_align) { + const std::unordered_map& storage_align) { // collect buffer allocation info for no-alias buffers - std::unordered_map buffer_info; + std::unordered_map buffer_info; for (const auto& kv : regions) { const Buffer& buffer = kv.first; // set dim alignment info diff --git a/src/tir/transforms/inject_permuted_layout.cc b/src/tir/transforms/inject_permuted_layout.cc index cccf2c505a51..d9479256c527 100644 --- a/src/tir/transforms/inject_permuted_layout.cc +++ b/src/tir/transforms/inject_permuted_layout.cc @@ -280,7 +280,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { static constexpr size_t BANK_SIZE_BYTES = 128; // Mapping from data Var of a Buffer to Buffer, for lookup - std::unordered_map buffer_map_; + std::unordered_map buffer_map_; bool permute_ = false; }; diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 21de2d86070f..c14c2cf4d6ac 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -971,7 +971,7 @@ void BuildDependencyGraph( const Array& blocks, std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_src2dst, std::unordered_map, ObjectPtrHash, ObjectPtrEqual>* dep_dst2src) { - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> buffer_writers; + std::unordered_map> buffer_writers; for (const Block& block : blocks) { for (const BufferRegion& read : block->reads) { diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 584b3cbf58f4..f2d5d54a58c4 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -703,8 +703,8 @@ std::pair GetAsyncWaitAttributes(const AttrStmtNode* op) { /*! \brief Collect storage alignment information from annotations. */ class StorageAlignCollector : public StmtVisitor { private: - friend std::unordered_map - CollectStorageAlignAnnotation(const Stmt& body); + friend std::unordered_map CollectStorageAlignAnnotation( + const Stmt& body); /*! \brief For s-stir, the alignment annotations reside in block annotations. */ void VisitStmt_(const BlockNode* op) final { @@ -737,11 +737,10 @@ class StorageAlignCollector : public StmtVisitor { } /*! \brief The map from buffer var to its storage alignment information. */ - std::unordered_map storage_align_; + std::unordered_map storage_align_; }; -std::unordered_map -CollectStorageAlignAnnotation(const Stmt& body) { +std::unordered_map CollectStorageAlignAnnotation(const Stmt& body) { StorageAlignCollector collector; collector(body); return std::move(collector.storage_align_); diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index a03ad3beb400..423b0ca92237 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -342,8 +342,7 @@ using StorageAlignAnnotation = Array; * \param body The stmt to collect. * \return The result dict from buffer var to storage align annotations. */ -std::unordered_map -CollectStorageAlignAnnotation(const Stmt& body); +std::unordered_map CollectStorageAlignAnnotation(const Stmt& body); /*! * \brief Split string separated by "," to get wmma fragment dimension size. * \param shape_str The string to split. diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 273d37829dcb..3e2dc130e7dd 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -231,7 +231,7 @@ class CustomDatatypesLowerer : public StmtExprMutator { private: std::string target_; // remap buffer vars - std::unordered_map var_remap_; + std::unordered_map var_remap_; std::unordered_map buf_remap_; }; diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index 86892433b42d..08642a598b74 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -190,13 +190,13 @@ class OpaqueBlockLower : public StmtExprMutator { } /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */ - std::unordered_map unit_loop_vars_; + std::unordered_map unit_loop_vars_; /*! \brief Attr keys to preserve into loop annotations. */ std::unordered_set preserved_annotations_; /*! \brief The map from buffer var to its storage alignment information. */ - std::unordered_map storage_align_; + std::unordered_map storage_align_; }; PrimFunc LowerOpaqueBlock(PrimFunc f) { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 9c1244838173..c51dfd7913e4 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -788,7 +788,7 @@ class ThreadScopePropagate : public StmtExprMutator { } } - std::unordered_map buf_remap_; + std::unordered_map buf_remap_; std::unordered_set external_buffers_; // The current thread scope. diff --git a/src/tir/transforms/texture_flatten.cc b/src/tir/transforms/texture_flatten.cc index 3f8f0efd1f20..91e1121ea130 100644 --- a/src/tir/transforms/texture_flatten.cc +++ b/src/tir/transforms/texture_flatten.cc @@ -184,7 +184,7 @@ class TextureFlattener : public TextureLoweringBase { } // Bindings to new texture vars with texture pointer scope - std::unordered_map let_binding_; + std::unordered_map let_binding_; }; PrimFunc TextureFlatten(PrimFunc func) { diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index d92986e51a9c..fd772863f780 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -440,7 +440,7 @@ class ThreadSyncInserter : public StmtExprMutator { StorageScope sync_scope_; const std::unordered_set& syncs_; // The read write statistics of storage - std::unordered_map rw_stats_; + std::unordered_map rw_stats_; // The statistics for global barrier bool in_thread_env_{false}; // memorized results diff --git a/src/tir/transforms/transform_mma_buffer_layout.cc b/src/tir/transforms/transform_mma_buffer_layout.cc index abe0bc3a3d12..899f292b8fe3 100644 --- a/src/tir/transforms/transform_mma_buffer_layout.cc +++ b/src/tir/transforms/transform_mma_buffer_layout.cc @@ -169,7 +169,7 @@ class MmaBufferLayoutTransformer : public StmtExprMutator { private: std::unordered_map buffer_map_; - std::unordered_map buffer_var_map_; + std::unordered_map buffer_var_map_; arith::Analyzer analyzer; }; diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 0c448d8e31f8..a68ebe7e02ff 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -75,14 +75,13 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.UnrollLoop", UnrollLoopConfig); class VarLocalAccessMarker : public ExprVisitor { public: - explicit VarLocalAccessMarker( - std::unordered_set* var_touched_local) + explicit VarLocalAccessMarker(std::unordered_set* var_touched_local) : var_touched_local_(var_touched_local) {} void VisitExpr_(const VarNode* op) final { var_touched_local_->insert(GetRef(op)); } private: - std::unordered_set* var_touched_local_; + std::unordered_set* var_touched_local_; }; // The Visitor is used to check whether var is used as write index in a local memory @@ -259,7 +258,7 @@ class LoopUnroller : public StmtExprMutator { // Number of total steps unrolled int step_count_{0}; // set of indices touched during visit local memory - std::unordered_set var_touched_local_; + std::unordered_set var_touched_local_; // analyzer arith::Analyzer analyzer_; }; diff --git a/src/tir/transforms/unsupported_dtype_legalize.cc b/src/tir/transforms/unsupported_dtype_legalize.cc index 5537c8a409a0..5a14beb6dc4c 100644 --- a/src/tir/transforms/unsupported_dtype_legalize.cc +++ b/src/tir/transforms/unsupported_dtype_legalize.cc @@ -45,8 +45,7 @@ class ComputeLegalizePlanner : public StmtExprVisitor { public: ComputeLegalizePlanner( std::unordered_map* buffer_remap, - std::unordered_map* var_remap, - DataType promote_dtype) + std::unordered_map* var_remap, DataType promote_dtype) : buffer_remap_(buffer_remap), var_remap_(var_remap), promote_dtype_(promote_dtype) {} // run planning to populate buffer remap and var remap. @@ -124,8 +123,8 @@ class ComputeLegalizePlanner : public StmtExprVisitor { } std::unordered_map* buffer_remap_; - std::unordered_map* var_remap_; - std::unordered_set opaque_var_access_; + std::unordered_map* var_remap_; + std::unordered_set opaque_var_access_; DataType promote_dtype_; }; @@ -133,8 +132,7 @@ class BF16ComputeLegalizePlanner : public ComputeLegalizePlanner { public: explicit BF16ComputeLegalizePlanner( std::unordered_map* buffer_remap, - std::unordered_map* var_remap, - DataType promote_dtype) + std::unordered_map* var_remap, DataType promote_dtype) : ComputeLegalizePlanner(buffer_remap, var_remap, promote_dtype) {} bool MatchDType(DataType dtype) const { return dtype.is_bfloat16(); } }; @@ -143,8 +141,7 @@ class FP8ComputeLegalizePlanner : public ComputeLegalizePlanner { public: explicit FP8ComputeLegalizePlanner( std::unordered_map* buffer_remap, - std::unordered_map* var_remap, - DataType promote_dtype) + std::unordered_map* var_remap, DataType promote_dtype) : ComputeLegalizePlanner(buffer_remap, var_remap, promote_dtype) {} bool MatchDType(DataType dtype) const { return dtype.is_float8(); } }; @@ -446,7 +443,7 @@ class ComputeLegalizer : public StmtExprMutator { protected: DataType promote_dtype_; std::unordered_map buffer_remap_; - std::unordered_map var_remap_; + std::unordered_map var_remap_; }; class BF16ComputeLegalizer : public ComputeLegalizer { @@ -678,7 +675,7 @@ class StorageLegalizer : public StmtExprMutator { } std::unordered_map buffer_remap_; - std::unordered_map var_remap_; + std::unordered_map var_remap_; }; class BF16StorageLegalizer : public StorageLegalizer { diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index a9cc4975801a..c02b66a33f97 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -653,7 +653,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor let_binding_; + std::unordered_map let_binding_; // vectorizable property OpAttrMap op_vectorizable_ = Op::GetAttrMap("TVectorizable"); diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index f512bfaffa97..5abfe24f434d 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -125,7 +125,7 @@ class BufferInfoExtractor : public StmtExprVisitor { * \brief Maintains the mapping of buffer variable to their allocate nodes to ensure * that only one BufferInfo object is created. */ - std::unordered_map allocate_infos; + std::unordered_map allocate_infos; /*! * \brief Indicates a count of stmts visited so far to use as a metric of liveness */ diff --git a/src/tir/usmp/transform/create_io_allocates.cc b/src/tir/usmp/transform/create_io_allocates.cc index 0afdacd48fd7..ca06095f0bdc 100644 --- a/src/tir/usmp/transform/create_io_allocates.cc +++ b/src/tir/usmp/transform/create_io_allocates.cc @@ -64,14 +64,14 @@ class IOAllocateCreator : public StmtExprVisitor { /*! \brief The main function that calls into operator subgraphs */ PrimFunc main_func_; /*! \brief The input Vars of the main function */ - std::unordered_set inputs_; + std::unordered_set inputs_; /*! \brief The output Vars of the main function */ - std::unordered_set outputs_; + std::unordered_set outputs_; /*! \brief The buffer vars associated with the I/O Vars */ - std::unordered_set io_buffer_vars_; + std::unordered_set io_buffer_vars_; /*! \brief The aliases that buffer vars inside the primfunc refer * to in terms call arguments */ - std::unordered_map aliases_; + std::unordered_map aliases_; /*! * \brief The TIR main function calls by name to PrimFuncs to be able to * support BYOC. Therefore, this Map records functions that are present From ce22674ae68cbbb2a5e4716ade6853a7f18478a5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Sat, 20 Apr 2024 19:05:05 -0500 Subject: [PATCH 2/2] lint fix --- include/tvm/ir/expr.h | 1 + include/tvm/relax/expr.h | 2 ++ include/tvm/tir/var.h | 1 + 3 files changed, 4 insertions(+) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 91280c2dad36..9b522389227a 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -31,6 +31,7 @@ #include #include +#include #include #include #include diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index c41cc7ee877c..401aaa9248ce 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -29,6 +29,8 @@ #include #include +#include + namespace tvm { namespace relax { diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 4d99a09d427e..0918d12821e1 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -28,6 +28,7 @@ #include #include +#include #include namespace tvm {