Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <tvm/runtime/object.h>

#include <algorithm>
#include <functional>
#include <limits>
#include <string>
#include <type_traits>
Expand Down Expand Up @@ -821,4 +822,32 @@ struct PackedFuncValueConverter<tvm::Bool> {

} // namespace runtime
} // namespace tvm

/* \brief Allow tvm.GLobalVar as key in STL tables
*
* For most IR expressions, it would be ambiguous whether the
* expression should follow reference equality or structural equality.
* This is not the case for variables, which do not contain nested
* internal structure, and are frequently used as keys in lookup
* tables.
*
* Providing `std::hash` and `std::equal_to` specializations for
* `tvm::GlobalVar` allows it to be used as a key in STL tables. For
* other IR expressions, the user must specify the type of equality
* used (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>`
* or `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
*/
template <>
struct std::hash<tvm::GlobalVar> {
std::size_t operator()(const tvm::GlobalVar& var) const {
return tvm::runtime::ObjectPtrHash()(var);
}
};

template <>
struct std::equal_to<tvm::GlobalVar> {
bool operator()(const tvm::GlobalVar& var_a, const tvm::GlobalVar& var_b) const {
return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
}
};
#endif // TVM_IR_EXPR_H_
30 changes: 30 additions & 0 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <functional>

namespace tvm {
namespace relax {

Expand Down Expand Up @@ -1111,4 +1113,32 @@ TVM_DLL Expr GetShapeOf(const Expr& expr);
} // namespace relax
} // namespace tvm

/* \brief Allow relax.Var as key in STL tables
*
* For most Relax expressions, it would be ambiguous whether the
* expression should follow reference equality or structural equality.
* This is not the case for variables, which do not contain nested
* internal structure, and are frequently used as keys in lookup
* tables.
*
* Providing `std::hash` and `std::equal_to` specializations for
* `relax::Var` allows it to be used as a key in STL tables. For
* `relax::Expr`, the user must specify the type of equality used
* (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
* `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
*/
template <>
struct std::hash<tvm::relax::Var> {
std::size_t operator()(const tvm::relax::Var& var) const {
return tvm::runtime::ObjectPtrHash()(var);
}
};

template <>
struct std::equal_to<tvm::relax::Var> {
bool operator()(const tvm::relax::Var& var_a, const tvm::relax::Var& var_b) const {
return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
}
};

#endif // TVM_RELAX_EXPR_H_
30 changes: 30 additions & 0 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>

#include <functional>
#include <string>

namespace tvm {
Expand Down Expand Up @@ -352,4 +353,33 @@ inline const char* IterVarType2String(IterVarType t) {
}
} // namespace tir
} // namespace tvm

/* \brief Allow tir.Var as key in STL tables
*
* For most TIR expressions, it would be ambiguous whether the
* expression should follow reference equality or structural equality.
* This is not the case for variables, which do not contain nested
* internal structure, and are frequently used as keys in lookup
* tables.
*
* Providing `std::hash` and `std::equal_to` specializations for
* `tir::Var` allows it to be used as a key in STL tables. For
* `PrimExpr`, the user must specify the type of equality used
* (e.g. `std::unordered_set<T, StructuralHash, StructuralEqual>` or
* `std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>`).
*/
template <>
struct std::hash<tvm::tir::Var> {
std::size_t operator()(const tvm::tir::Var& var) const {
return tvm::runtime::ObjectPtrHash()(var);
}
};

template <>
struct std::equal_to<tvm::tir::Var> {
bool operator()(const tvm::tir::Var& var_a, const tvm::tir::Var& var_b) const {
return tvm::runtime::ObjectPtrEqual()(var_a, var_b);
}
};

#endif // TVM_TIR_VAR_H_
2 changes: 1 addition & 1 deletion src/arith/const_int_bound.cc
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ class ConstIntBoundAnalyzer::Impl
private:
friend class ConstIntBoundAnalyzer;
// internal variable map
std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> var_map_;
std::unordered_map<Var, Entry> var_map_;
// additional bound info
std::vector<BoundInfo> additional_info_;
// look up table for memorization
Expand Down
10 changes: 5 additions & 5 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ class IterMapRewriter : public ExprMutator {
// Error messages for each unresolved expression.
Array<String>& errors_;
// The var map
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
std::unordered_map<Var, PrimExpr> var_map_;
// input iter marks
std::vector<IterMark> input_marks_;

Expand Down Expand Up @@ -1419,7 +1419,7 @@ bool MatchBoundConstraints(PrimExpr pred, Map<Var, Range>* input_iters,
}

bool IterRangeSanityCheck(const Map<Var, Range>& iter_ranges) {
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> iters;
std::unordered_set<Var> iters;
for (const auto& it : iter_ranges) iters.insert(it.first);
auto f = [&](const VarNode* var) { return iters.count(GetRef<Var>(var)); };
for (const auto& it : iter_ranges) {
Expand Down Expand Up @@ -2187,7 +2187,7 @@ TVM_REGISTER_GLOBAL("arith.IterMapSimplify")
class SubspaceDivider {
public:
explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector,
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters)
const std::unordered_set<Var>& sub_iters)
: analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {}

size_t unresolved_count() const { return unresolved_count_; }
Expand Down Expand Up @@ -2455,7 +2455,7 @@ class SubspaceDivider {
// collector that collects the outgoing split reference of each IterMark
const IterMarkSplitCollector collector_;
// the set of subspace iters
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters_;
const std::unordered_set<Var>& sub_iters_;
// map from SplitExpr to its corresponding DivisionResult(Y*E(X)+X)
std::unordered_map<IterSplitExpr, DivisionResult, ObjectPtrHash, ObjectPtrEqual> split_map_;
// predicate of outer space and inner space;
Expand All @@ -2473,7 +2473,7 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Array<IterSumExpr>& maps = res->indices;
if (maps.empty()) return {};

std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> inner_iter_set;
std::unordered_set<Var> inner_iter_set;
for (const Var& inner_iter : sub_iters) {
inner_iter_set.insert(inner_iter);
}
Expand Down
2 changes: 1 addition & 1 deletion src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
/*! \brief pointer to parent. */
Analyzer* parent_{nullptr};
// internal variable map
std::unordered_map<Var, Entry, ObjectPtrHash, ObjectPtrEqual> var_map_;
std::unordered_map<Var, Entry> var_map_;
/*!
* \brief Update var by intersecting entry with var's current set.
* \param var The variable.
Expand Down
2 changes: 1 addition & 1 deletion src/arith/rewrite_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer {
// counter to record recursive rewrite depth.
int64_t recur_depth_{0};
// internal variable map
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> var_map_;
std::unordered_map<Var, PrimExpr> var_map_;

std::vector<PrimExpr> literal_constraints_;

Expand Down
2 changes: 1 addition & 1 deletion src/contrib/msc/core/transform/set_expr_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1298,7 +1298,7 @@ class LayoutInfer : public ExprVisitor {
bool infered_;
Map<Var, Expr> var_map_;
Array<Expr> ordered_exprs_;
std::unordered_map<Var, NLayout, ObjectPtrHash, ObjectPtrEqual> var_layout_map_;
std::unordered_map<Var, NLayout> var_layout_map_;
Map<Expr, Function> local_funcs_;
}; // class LayoutInfer

Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/feature_extractor/per_store_feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ Pass SimplifyForFeatureExtraction() {
}
}

std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> unit_vars_;
std::unordered_set<Var> unit_vars_;
};
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
PrimFuncNode* n = f.CopyOnWrite();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -775,9 +775,9 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
const tir::IndexMap& index_map = mapping_info->mappings[0];

// Find the correspondence between block iters and the iters in the index map.
std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> lhs_to_index_map_src;
std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> rhs_to_index_map_tgt;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_index_map_src;
std::unordered_map<tir::Var, tir::Var> lhs_to_index_map_src;
std::unordered_map<tir::Var, PrimExpr> rhs_to_index_map_tgt;
std::unordered_set<tir::Var> unmapped_index_map_src;
ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size());
for (int i = 0; i < static_cast<int>(mapping_info->lhs_iters.size()); ++i) {
lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i];
Expand Down
2 changes: 1 addition & 1 deletion src/relax/analysis/computable_at_compile_time.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class CompileTimeCollector : ExprVisitor {
}

support::OrderedSet<Var> known_relax_vars_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> known_tir_vars_;
std::unordered_set<tir::Var> known_tir_vars_;
};
} // namespace

Expand Down
4 changes: 2 additions & 2 deletions src/relax/analysis/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ static bool AreIdenticalSpatialAccess(const SpatialLayout& s0, const SpatialLayo
* (ignoring reduction dimensions). It checks that the order of spatial iter vars in spatial layout
* of a buffer access is same as the order of spatial iter vars in block domain.
*/
using VarToBlockIndexMap = std::unordered_map<tir::Var, int, ObjectPtrHash, ObjectPtrEqual>;
using VarToBlockIndexMap = std::unordered_map<tir::Var, int>;
static bool IsSequentialAccess(const SpatialLayout& iterators,
const VarToBlockIndexMap& iter_to_block_index) {
int last_value = -1;
Expand Down Expand Up @@ -210,7 +210,7 @@ static bool AreIdenticalTransforms(const IndexMap& t0, const IndexMap& t1) {
* source spatial layout.
* target transformation = lambda dim, C, H, W -> (dim, H, W, C // 4, C %4)
*/
using VarSet = std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>;
using VarSet = std::unordered_set<tir::Var>;
static Optional<IndexMap> InferLayoutTransformation(const SpatialLayout& src_spatial_layout,
const IndexMap& src_transformation,
const SpatialLayout& tgt_spatial_layout) {
Expand Down
4 changes: 2 additions & 2 deletions src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1356,9 +1356,9 @@ class SymbolicVarCollector : public relax::ExprVisitor,
/*! \brief The current visit mode. */
VisitMode mode_ = VisitMode::kRequireDefinition;
/*! \brief The set of defined symbolic vars. */
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> defined_symbolic_var_;
std::unordered_set<tir::Var> defined_symbolic_var_;
/*! \brief The set of free/undefined symbolic vars. */
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> free_symbolic_var_;
std::unordered_set<tir::Var> free_symbolic_var_;
};

Array<tir::Var> DefinedSymbolicVars(const Expr& expr) {
Expand Down
2 changes: 1 addition & 1 deletion src/relax/analysis/udchain.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class UDChain : relax::ExprVisitor {

private:
Map<Var, Expr> bound_values;
std::unordered_map<Var, support::OrderedSet<Var>, ObjectPtrHash, ObjectPtrEqual> usage_map;
std::unordered_map<Var, support::OrderedSet<Var>> usage_map;
support::OrderedSet<Var> outputs;

Optional<Var> cur_user_{nullptr};
Expand Down
16 changes: 7 additions & 9 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -364,9 +364,8 @@ class WellFormedChecker : public relax::ExprVisitor,
Malformed(Diagnostic::Error(op) << "The condition for an if node must be a leaf expression.");
}

std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> previous_var_set = var_set_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> previous_symbolic_var_set =
symbolic_var_set_;
std::unordered_set<Var> previous_var_set = var_set_;
std::unordered_set<tir::Var> previous_symbolic_var_set = symbolic_var_set_;
this->VisitSeqExpr(op->true_branch.get());
var_set_ = previous_var_set;
symbolic_var_set_ = previous_symbolic_var_set;
Expand Down Expand Up @@ -567,13 +566,12 @@ class WellFormedChecker : public relax::ExprVisitor,
// Current visit mode.
VisitMode mode_ = VisitMode::kDefault;
// set of context variables.
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> recur_vars_;
std::unordered_set<Var> var_set_;
std::unordered_set<Var> recur_vars_;
std::unordered_set<DataflowVar, ObjectPtrHash, ObjectPtrEqual> dataflow_var_set_;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> symbolic_var_set_;
std::unordered_map<Var, const FunctionNode*, ObjectPtrHash, ObjectPtrEqual> param_var_func_map_;
std::unordered_map<tir::Var, const FunctionNode*, ObjectPtrHash, ObjectPtrEqual>
symbolic_var_func_map_;
std::unordered_set<tir::Var> symbolic_var_set_;
std::unordered_map<Var, const FunctionNode*> param_var_func_map_;
std::unordered_map<tir::Var, const FunctionNode*> symbolic_var_func_map_;

tvm::OpAttrMap<FNormalize> op_map_normalize_ = Op::GetAttrMap<FNormalize>("FNormalize");
};
Expand Down
2 changes: 1 addition & 1 deletion src/relax/backend/vm/codegen_vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> {
*/
size_t registers_num_ = 0;
/*! \brief Map from var to register number. */
std::unordered_map<Var, Instruction::Arg, ObjectPtrHash, ObjectPtrEqual> var_arg_map_;
std::unordered_map<Var, Instruction::Arg> var_arg_map_;
/*! \brief the context module. */
IRModule ctx_mod_;
/*! \brief Cache ops that need to be frequently used later to reduce lookup overhead. */
Expand Down
2 changes: 1 addition & 1 deletion src/relax/backend/vm/codegen_vm_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ class CodeGenVMTIR : public ExprFunctor<Optional<PrimExpr>(const Expr&)> {
/*! \brief Stack to build up statements */
std::vector<std::vector<tir::Stmt>> stmt_stack_;
/*! \brief Map from var to Expr. */
std::unordered_map<Var, Optional<PrimExpr>, ObjectPtrHash, ObjectPtrEqual> var_map_;
std::unordered_map<Var, Optional<PrimExpr>> var_map_;
/*! \brief the context module. */
IRModule ctx_mod_;
/*! \brief system lib prefix */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,8 +337,8 @@ class DistributedBufferCompactor : StmtExprMutator {
return new_loop;
}

std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual> iter_var_shards_;
std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual> loop_var_shards_;
std::unordered_map<Var, int> iter_var_shards_;
std::unordered_map<Var, int> loop_var_shards_;
Array<Buffer> allocated_buffer_under_root;
BufferAxisGraphExtractor extractor_;
std::vector<ShardingSpec> sharding_specs_;
Expand Down
3 changes: 1 addition & 2 deletions src/relax/transform/adjust_matmul_order.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ namespace {
std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns(
const Function& func) {
auto compile_time_arr = ComputableAtCompileTime(func);
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> compile_time_lookup(
compile_time_arr.begin(), compile_time_arr.end());
std::unordered_set<Var> compile_time_lookup(compile_time_arr.begin(), compile_time_arr.end());

TypedPackedFunc<bool(Expr)> is_compile_time = [compile_time_lookup](Expr arg) -> bool {
if (auto as_var = arg.as<Var>()) {
Expand Down
4 changes: 2 additions & 2 deletions src/relax/transform/canonicalize_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,10 @@ class CanonicalizePlanner : public ExprVisitor {
Map<Var, Var> trivial_bindings_;
Map<Var, Expr> known_bindings_;
Map<Var, Constant> known_bound_to_constant_;
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> defined_inside_dataflow_;
std::unordered_set<Var> defined_inside_dataflow_;
// Set of vars either used outside a dataflow block altogether or outside their
// home dataflow block (the one where they were defined)
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> used_outside_home_dataflow_;
std::unordered_set<Var> used_outside_home_dataflow_;
};

/*! \brief The mutator class to apply a CanonicalizationPlan */
Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/convert_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class LayoutConvertMutator : public ExprMutator {
}
}

std::unordered_map<Var, NLayout, ObjectPtrHash, ObjectPtrEqual> var_layout_map_;
std::unordered_map<Var, NLayout> var_layout_map_;
Map<String, Array<String>> desired_layouts_;
}; // namespace relax

Expand Down
Loading