From 7fdf1f8df338ba93afed6d43412b98dd7bf3ff32 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 10 Mar 2025 19:55:06 -0400 Subject: [PATCH] [IR] Compact Functor vtable This PR add a finalize routine to optionally compact functor vtable dynamically. Also updates child_slots for key types to make sure the IR node type index stay within range and such compact happens. --- include/tvm/arith/iter_affine_map.h | 2 +- include/tvm/ir/expr.h | 4 +-- include/tvm/ir/type_functor.h | 1 + include/tvm/node/functor.h | 28 +++++++++++++++++++- include/tvm/relax/dataflow_pattern.h | 2 ++ include/tvm/relax/dataflow_pattern_functor.h | 2 +- include/tvm/relax/expr.h | 4 +-- include/tvm/relax/expr_functor.h | 1 + include/tvm/relax/struct_info_functor.h | 1 + include/tvm/tir/expr_functor.h | 1 + include/tvm/tir/stmt_functor.h | 1 + src/ir/attr_functor.h | 1 + src/relax/ir/py_expr_functor.cc | 3 +++ src/runtime/object.cc | 10 ++++++- 14 files changed, 53 insertions(+), 8 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 53c5b32dd25d..d2a6f9a745b4 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -69,7 +69,7 @@ class IterMapExprNode : public PrimExprNode { void VisitAttrs(tvm::AttrVisitor* v) {} static constexpr const char* _type_key = "arith.IterMapExpr"; - static constexpr const uint32_t _type_child_slots = 3; + static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode); }; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b3b4e8ab32fd..53af26975648 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -58,7 +58,7 @@ class BaseExprNode : public Object { static constexpr const char* _type_key = "BaseExpr"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; - static constexpr const uint32_t _type_child_slots = 62; + static constexpr const uint32_t _type_child_slots = 64; TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object); }; @@ -104,7 +104,7 @@ class PrimExprNode : public BaseExprNode { TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "PrimExpr"; - static constexpr const uint32_t _type_child_slots = 38; + static constexpr const uint32_t _type_child_slots = 40; TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); }; diff --git a/include/tvm/ir/type_functor.h b/include/tvm/ir/type_functor.h index 2c145e480b84..858226354c66 100644 --- a/include/tvm/ir/type_functor.h +++ b/include/tvm/ir/type_functor.h @@ -93,6 +93,7 @@ class TypeFunctor { TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode); TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode); TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode); + vtable.Finalize(); return vtable; } }; diff --git a/include/tvm/node/functor.h b/include/tvm/node/functor.h index 58d59c81cb16..82ea37566eb5 100644 --- a/include/tvm/node/functor.h +++ b/include/tvm/node/functor.h @@ -26,6 +26,7 @@ #include #include +#include #include #include #include @@ -72,6 +73,8 @@ class NodeFunctor { using TSelf = NodeFunctor; /*! \brief internal function table */ std::vector func_; + /*! \brief start range of func index */ + uint32_t begin_type_index_{0}; public: /*! \brief the result type of this functor */ @@ -83,6 +86,8 @@ class NodeFunctor { */ bool can_dispatch(const ObjectRef& n) const { uint32_t type_index = n->type_index(); + if (type_index < begin_type_index_) return false; + type_index -= begin_type_index_; return type_index < func_.size() && func_[type_index] != nullptr; } /*! @@ -94,7 +99,7 @@ class NodeFunctor { R operator()(const ObjectRef& n, Args... args) const { ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type " << n->GetTypeKey(); - return (*func_[n->type_index()])(n, std::forward(args)...); + return (*func_[n->type_index() - begin_type_index_])(n, std::forward(args)...); } /*! * \brief set the dispatcher for type TNode @@ -109,6 +114,7 @@ class NodeFunctor { func_.resize(tindex + 1, nullptr); } ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set"; + ICHECK_EQ(begin_type_index_, 0) << " Cannot call set_dispatch after calling Finalize"; func_[tindex] = f; return *this; } @@ -122,9 +128,29 @@ class NodeFunctor { TSelf& clear_dispatch() { // NOLINT(*) uint32_t tindex = TNode::RuntimeTypeIndex(); ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range"; + ICHECK_EQ(begin_type_index_, 0) << " Cannot call clear_dispatch after calling Finalize"; func_[tindex] = nullptr; return *this; } + /*! + * \brief Finalize the functor after calling sequence of set_dispatch + * This function will attempt to find the min type index that is not null + * and optimize the space of the func table so it is more compact + */ + void Finalize() { + ICHECK_EQ(begin_type_index_, 0) << "Can only call Finalize once"; + while (begin_type_index_ < func_.size() && func_[begin_type_index_] == nullptr) { + ++begin_type_index_; + } + // shift up the function value + size_t new_ftable_size = func_.size() - begin_type_index_; + if (begin_type_index_ != 0) { + std::memmove(func_.data(), func_.data() + begin_type_index_, + new_ftable_size * sizeof(FPointer)); + } + func_.resize(new_ftable_size); + func_.shrink_to_fit(); + } }; #define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h index df9fdcad9759..b3bbebd0e06c 100644 --- a/include/tvm/relax/dataflow_pattern.h +++ b/include/tvm/relax/dataflow_pattern.h @@ -91,6 +91,7 @@ TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs); class DFPatternNode : public Object { public: static constexpr const char* _type_key = "DFPatternNode"; + static constexpr const uint32_t _type_child_slots = 21; TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object); }; @@ -373,6 +374,7 @@ class VarPatternNode : public DFPatternNode { void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); } static constexpr const char* _type_key = "relax.dpl.VarPattern"; + static constexpr const uint32_t _type_child_slots = 1; TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode); }; diff --git a/include/tvm/relax/dataflow_pattern_functor.h b/include/tvm/relax/dataflow_pattern_functor.h index bbdda4421399..fb67f3cc4aca 100644 --- a/include/tvm/relax/dataflow_pattern_functor.h +++ b/include/tvm/relax/dataflow_pattern_functor.h @@ -135,12 +135,12 @@ class DFPatternFunctor { RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); - RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(PrimArrPatternNode); RELAX_DFPATTERN_FUNCTOR_DISPATCH(UnorderedTuplePatternNode); + vtable.Finalize(); return vtable; } }; diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index fb6f0e40b130..330ff7e8dab0 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -119,7 +119,7 @@ class StructInfoNode : public Object { static constexpr const char* _type_key = "StructInfo"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; - static constexpr const uint32_t _type_child_slots = 5; + static constexpr const uint32_t _type_child_slots = 7; TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object); }; @@ -416,7 +416,7 @@ class VarNode : public LeafExprNode { static constexpr const char* _type_key = "relax.expr.Var"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; - static constexpr const uint32_t _type_child_slots = 2; + static constexpr const uint32_t _type_child_slots = 1; TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode); }; diff --git a/include/tvm/relax/expr_functor.h b/include/tvm/relax/expr_functor.h index 9c867129fdd2..a6419687ee16 100644 --- a/include/tvm/relax/expr_functor.h +++ b/include/tvm/relax/expr_functor.h @@ -176,6 +176,7 @@ class ExprFunctor { RELAX_EXPR_FUNCTOR_DISPATCH(PrimValueNode); RELAX_EXPR_FUNCTOR_DISPATCH(StringImmNode); RELAX_EXPR_FUNCTOR_DISPATCH(DataTypeImmNode); + vtable.Finalize(); return vtable; } }; diff --git a/include/tvm/relax/struct_info_functor.h b/include/tvm/relax/struct_info_functor.h index 8418b48dc182..2ce562754791 100644 --- a/include/tvm/relax/struct_info_functor.h +++ b/include/tvm/relax/struct_info_functor.h @@ -108,6 +108,7 @@ class StructInfoFunctor { TVM_STRUCT_INFO_FUNCTOR_DISPATCH(distributed::DTensorStructInfoNode); TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode); TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode); + vtable.Finalize(); return vtable; } }; diff --git a/include/tvm/tir/expr_functor.h b/include/tvm/tir/expr_functor.h index 3f66164b42c0..7a9cf91a65af 100644 --- a/include/tvm/tir/expr_functor.h +++ b/include/tvm/tir/expr_functor.h @@ -193,6 +193,7 @@ class ExprFunctor { IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode); IR_EXPR_FUNCTOR_DISPATCH(StringImmNode); IR_EXPR_FUNCTOR_DISPATCH(AnyNode); + vtable.Finalize(); return vtable; } }; diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index c5b20f8ec00d..e9a41468d310 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -126,6 +126,7 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode); IR_STMT_FUNCTOR_DISPATCH(BlockNode); IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode); + vtable.Finalize(); return vtable; } }; diff --git a/src/ir/attr_functor.h b/src/ir/attr_functor.h index 12b4f6f65b11..008e63fffc7e 100644 --- a/src/ir/attr_functor.h +++ b/src/ir/attr_functor.h @@ -139,6 +139,7 @@ class AttrFunctor { ATTR_FUNCTOR_DISPATCH(CastNode); ATTR_FUNCTOR_DISPATCH(CallNode); ATTR_FUNCTOR_DISPATCH(SelectNode); + vtable.Finalize(); return vtable; } }; diff --git a/src/relax/ir/py_expr_functor.cc b/src/relax/ir/py_expr_functor.cc index a7ac2456107f..eb286b4ef63c 100644 --- a/src/relax/ir/py_expr_functor.cc +++ b/src/relax/ir/py_expr_functor.cc @@ -161,6 +161,7 @@ class PyExprVisitorNode : public Object, public ExprVisitor { PY_EXPR_VISITOR_DISPATCH(PrimValueNode, f_visit_prim_value_); PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm_); PY_EXPR_VISITOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); + vtable.Finalize(); return vtable; } }; @@ -414,6 +415,7 @@ class PyExprMutatorNode : public Object, public ExprMutator { PY_EXPR_MUTATOR_DISPATCH(PrimValueNode, f_visit_prim_value_); PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm_); PY_EXPR_MUTATOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_); + vtable.Finalize(); return vtable; } @@ -437,6 +439,7 @@ class PyExprMutatorNode : public Object, public ExprMutator { PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(PrimValueNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(StringImmNode); PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataTypeImmNode); + post_order_vtable.Finalize(); return post_order_vtable; } }; diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 05bfd6d1cf80..85ec4f036020 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -170,10 +170,17 @@ class TypeContext { void Dump(int min_children_count) { std::vector num_children(type_table_.size(), 0); + // expected child slots compute the expected slots + // based on the current child slot setting + std::vector expected_child_slots(type_table_.size(), 0); // reverse accumulation so we can get total counts in a bottom-up manner. for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) { if (it->index != 0) { num_children[it->parent_index] += num_children[it->index] + 1; + if (static_cast(expected_child_slots[it->index] + 1) < it->num_slots) { + expected_child_slots[it->index] = it->num_slots - 1; + } + expected_child_slots[it->parent_index] += expected_child_slots[it->index] + 1; } } @@ -182,7 +189,8 @@ class TypeContext { std::cerr << '[' << info.index << "] " << info.name << "\tparent=" << type_table_[info.parent_index].name << "\tnum_child_slots=" << info.num_slots - 1 - << "\tnum_children=" << num_children[info.index] << std::endl; + << "\tnum_children=" << num_children[info.index] + << "\texpected_child_slots=" << expected_child_slots[info.index] << std::endl; } } }