Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class IterMapExprNode : public PrimExprNode {
void VisitAttrs(tvm::AttrVisitor* v) {}

static constexpr const char* _type_key = "arith.IterMapExpr";
static constexpr const uint32_t _type_child_slots = 3;
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
};

Expand Down
4 changes: 2 additions & 2 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class BaseExprNode : public Object {
static constexpr const char* _type_key = "BaseExpr";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 62;
static constexpr const uint32_t _type_child_slots = 64;
TVM_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

Expand Down Expand Up @@ -104,7 +104,7 @@ class PrimExprNode : public BaseExprNode {
TVM_OBJECT_ENABLE_SCRIPT_PRINTER();

static constexpr const char* _type_key = "PrimExpr";
static constexpr const uint32_t _type_child_slots = 38;
static constexpr const uint32_t _type_child_slots = 40;
TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
TVM_TYPE_FUNCTOR_DISPATCH(TupleTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
vtable.Finalize();
return vtable;
}
};
Expand Down
28 changes: 27 additions & 1 deletion include/tvm/node/functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <tvm/runtime/logging.h>
#include <tvm/runtime/object.h>

#include <cstring>
#include <type_traits>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -72,6 +73,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
using TSelf = NodeFunctor<R(const ObjectRef& n, Args...)>;
/*! \brief internal function table */
std::vector<FPointer> func_;
/*! \brief start range of func index */
uint32_t begin_type_index_{0};

public:
/*! \brief the result type of this functor */
Expand All @@ -83,6 +86,8 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
*/
bool can_dispatch(const ObjectRef& n) const {
uint32_t type_index = n->type_index();
if (type_index < begin_type_index_) return false;
type_index -= begin_type_index_;
return type_index < func_.size() && func_[type_index] != nullptr;
}
/*!
Expand All @@ -94,7 +99,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
R operator()(const ObjectRef& n, Args... args) const {
ICHECK(can_dispatch(n)) << "NodeFunctor calls un-registered function on type "
<< n->GetTypeKey();
return (*func_[n->type_index()])(n, std::forward<Args>(args)...);
return (*func_[n->type_index() - begin_type_index_])(n, std::forward<Args>(args)...);
}
/*!
* \brief set the dispatcher for type TNode
Expand All @@ -109,6 +114,7 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
func_.resize(tindex + 1, nullptr);
}
ICHECK(func_[tindex] == nullptr) << "Dispatch for " << TNode::_type_key << " is already set";
ICHECK_EQ(begin_type_index_, 0) << " Cannot call set_dispatch after calling Finalize";
func_[tindex] = f;
return *this;
}
Expand All @@ -122,9 +128,29 @@ class NodeFunctor<R(const ObjectRef& n, Args...)> {
TSelf& clear_dispatch() { // NOLINT(*)
uint32_t tindex = TNode::RuntimeTypeIndex();
ICHECK_LT(tindex, func_.size()) << "clear_dispatch: index out of range";
ICHECK_EQ(begin_type_index_, 0) << " Cannot call clear_dispatch after calling Finalize";
func_[tindex] = nullptr;
return *this;
}
/*!
* \brief Finalize the functor after calling sequence of set_dispatch
* This function will attempt to find the min type index that is not null
* and optimize the space of the func table so it is more compact
*/
void Finalize() {
ICHECK_EQ(begin_type_index_, 0) << "Can only call Finalize once";
while (begin_type_index_ < func_.size() && func_[begin_type_index_] == nullptr) {
++begin_type_index_;
}
// shift up the function value
size_t new_ftable_size = func_.size() - begin_type_index_;
if (begin_type_index_ != 0) {
std::memmove(func_.data(), func_.data() + begin_type_index_,
new_ftable_size * sizeof(FPointer));
}
func_.resize(new_ftable_size);
func_.shrink_to_fit();
}
};

#define TVM_REG_FUNC_VAR_DEF(ClsName) static TVM_ATTRIBUTE_UNUSED auto& __make_functor##_##ClsName
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/relax/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ TVM_DLL PatternSeq operator>>(const PatternSeq& lhs, const PatternSeq& rhs);
class DFPatternNode : public Object {
public:
static constexpr const char* _type_key = "DFPatternNode";
static constexpr const uint32_t _type_child_slots = 21;
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternNode, Object);
};

Expand Down Expand Up @@ -373,6 +374,7 @@ class VarPatternNode : public DFPatternNode {
void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name", &name); }

static constexpr const char* _type_key = "relax.dpl.VarPattern";
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarPatternNode, DFPatternNode);
};

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relax/dataflow_pattern_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,12 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAX_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);

RELAX_DFPATTERN_FUNCTOR_DISPATCH(DataflowVarPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(GlobalVarPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(ExternFuncPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(PrimArrPatternNode);
RELAX_DFPATTERN_FUNCTOR_DISPATCH(UnorderedTuplePatternNode);
vtable.Finalize();
return vtable;
}
};
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class StructInfoNode : public Object {
static constexpr const char* _type_key = "StructInfo";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 5;
static constexpr const uint32_t _type_child_slots = 7;
TVM_DECLARE_BASE_OBJECT_INFO(StructInfoNode, Object);
};

Expand Down Expand Up @@ -416,7 +416,7 @@ class VarNode : public LeafExprNode {
static constexpr const char* _type_key = "relax.expr.Var";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
static constexpr const uint32_t _type_child_slots = 2;
static constexpr const uint32_t _type_child_slots = 1;
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, LeafExprNode);
};

Expand Down
1 change: 1 addition & 0 deletions include/tvm/relax/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAX_EXPR_FUNCTOR_DISPATCH(PrimValueNode);
RELAX_EXPR_FUNCTOR_DISPATCH(StringImmNode);
RELAX_EXPR_FUNCTOR_DISPATCH(DataTypeImmNode);
vtable.Finalize();
return vtable;
}
};
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relax/struct_info_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ class StructInfoFunctor<R(const StructInfo& n, Args...)> {
TVM_STRUCT_INFO_FUNCTOR_DISPATCH(distributed::DTensorStructInfoNode);
TVM_STRUCT_INFO_FUNCTOR_DISPATCH(TupleStructInfoNode);
TVM_STRUCT_INFO_FUNCTOR_DISPATCH(FuncStructInfoNode);
vtable.Finalize();
return vtable;
}
};
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(FloatImmNode);
IR_EXPR_FUNCTOR_DISPATCH(StringImmNode);
IR_EXPR_FUNCTOR_DISPATCH(AnyNode);
vtable.Finalize();
return vtable;
}
};
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(BlockNode);
IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
vtable.Finalize();
return vtable;
}
};
Expand Down
1 change: 1 addition & 0 deletions src/ir/attr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class AttrFunctor<R(const ObjectRef& n, Args...)> {
ATTR_FUNCTOR_DISPATCH(CastNode);
ATTR_FUNCTOR_DISPATCH(CallNode);
ATTR_FUNCTOR_DISPATCH(SelectNode);
vtable.Finalize();
return vtable;
}
};
Expand Down
3 changes: 3 additions & 0 deletions src/relax/ir/py_expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ class PyExprVisitorNode : public Object, public ExprVisitor {
PY_EXPR_VISITOR_DISPATCH(PrimValueNode, f_visit_prim_value_);
PY_EXPR_VISITOR_DISPATCH(StringImmNode, f_visit_string_imm_);
PY_EXPR_VISITOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_);
vtable.Finalize();
return vtable;
}
};
Expand Down Expand Up @@ -414,6 +415,7 @@ class PyExprMutatorNode : public Object, public ExprMutator {
PY_EXPR_MUTATOR_DISPATCH(PrimValueNode, f_visit_prim_value_);
PY_EXPR_MUTATOR_DISPATCH(StringImmNode, f_visit_string_imm_);
PY_EXPR_MUTATOR_DISPATCH(DataTypeImmNode, f_visit_data_type_imm_);
vtable.Finalize();
return vtable;
}

Expand All @@ -437,6 +439,7 @@ class PyExprMutatorNode : public Object, public ExprMutator {
PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(PrimValueNode);
PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(StringImmNode);
PY_EXPR_MUTATOR_VISIT_EXPR_POST_ORDER_DISPATCH(DataTypeImmNode);
post_order_vtable.Finalize();
return post_order_vtable;
}
};
Expand Down
10 changes: 9 additions & 1 deletion src/runtime/object.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,17 @@ class TypeContext {

void Dump(int min_children_count) {
std::vector<int> num_children(type_table_.size(), 0);
// expected child slots compute the expected slots
// based on the current child slot setting
std::vector<int> expected_child_slots(type_table_.size(), 0);
// reverse accumulation so we can get total counts in a bottom-up manner.
for (auto it = type_table_.rbegin(); it != type_table_.rend(); ++it) {
if (it->index != 0) {
num_children[it->parent_index] += num_children[it->index] + 1;
if (static_cast<uint32_t>(expected_child_slots[it->index] + 1) < it->num_slots) {
expected_child_slots[it->index] = it->num_slots - 1;
}
expected_child_slots[it->parent_index] += expected_child_slots[it->index] + 1;
}
}

Expand All @@ -182,7 +189,8 @@ class TypeContext {
std::cerr << '[' << info.index << "] " << info.name
<< "\tparent=" << type_table_[info.parent_index].name
<< "\tnum_child_slots=" << info.num_slots - 1
<< "\tnum_children=" << num_children[info.index] << std::endl;
<< "\tnum_children=" << num_children[info.index]
<< "\texpected_child_slots=" << expected_child_slots[info.index] << std::endl;
}
}
}
Expand Down
Loading