From 35795600fdf1b59e72d8408d0e92cbe5407afb55 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 26 Jun 2025 16:20:12 -0400 Subject: [PATCH 1/6] [FFI] Support IntEnum conversion with Any and allow base class field registration --- ffi/include/tvm/ffi/reflection/reflection.h | 14 ++++---- ffi/include/tvm/ffi/type_traits.h | 40 +++++++++++++++++++++ ffi/tests/cpp/test_any.cc | 18 ++++++++++ 3 files changed, 66 insertions(+), 6 deletions(-) diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h index 04d96857cb9a..5f34a65e4e47 100644 --- a/ffi/include/tvm/ffi/reflection/reflection.h +++ b/ffi/include/tvm/ffi/reflection/reflection.h @@ -162,8 +162,9 @@ class ObjectDef : public ReflectionDefBase { * * \return The reflection definition. */ - template - TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T Class::*field_ptr, Extra&&... extra) { + template >> + TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::*field_ptr, Extra&&... extra) { RegisterField(name, field_ptr, false, std::forward(extra)...); return *this; } @@ -181,8 +182,9 @@ class ObjectDef : public ReflectionDefBase { * * \return The reflection definition. */ - template - TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T Class::*field_ptr, Extra&&... extra) { + template >> + TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::*field_ptr, Extra&&... extra) { static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); RegisterField(name, field_ptr, true, std::forward(extra)...); return *this; @@ -239,8 +241,8 @@ class ObjectDef : public ReflectionDefBase { TVM_FFI_CHECK_SAFE_CALL(TVMFFITypeRegisterExtraInfo(type_index_, &info)); } - template - void RegisterField(const char* name, T Class::*field_ptr, bool writable, + template + void RegisterField(const char* name, T BaseClass::*field_ptr, bool writable, ExtraArgs&&... extra_args) { TVMFFIFieldInfo info; info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; diff --git a/ffi/include/tvm/ffi/type_traits.h b/ffi/include/tvm/ffi/type_traits.h index b33a6dad6f26..0cb03a1e7f4d 100644 --- a/ffi/include/tvm/ffi/type_traits.h +++ b/ffi/include/tvm/ffi/type_traits.h @@ -274,6 +274,46 @@ struct TypeTraits>> : public TypeT static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } }; +// Enum Integer POD values +template +struct TypeTraits && + std::is_integral_v>>> + : public TypeTraitsBase { + static constexpr int32_t field_static_type_index = TypeIndex::kTVMFFIInt; + + static TVM_FFI_INLINE void CopyToAnyView(const IntEnum& src, TVMFFIAny* result) { + result->type_index = TypeIndex::kTVMFFIInt; + result->v_int64 = static_cast(src); + } + + static TVM_FFI_INLINE void MoveToAny(IntEnum src, TVMFFIAny* result) { + CopyToAnyView(src, result); + } + + static TVM_FFI_INLINE bool CheckAnyStrict(const TVMFFIAny* src) { + // NOTE: CheckAnyStrict is always strict and should be consistent with MoveToAny + return src->type_index == TypeIndex::kTVMFFIInt; + } + + static TVM_FFI_INLINE IntEnum CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + return static_cast(src->v_int64); + } + + static TVM_FFI_INLINE IntEnum MoveFromAnyAfterCheck(TVMFFIAny* src) { + // POD type, we can just copy the value + return CopyFromAnyViewAfterCheck(src); + } + + static TVM_FFI_INLINE std::optional TryCastFromAnyView(const TVMFFIAny* src) { + if (src->type_index == TypeIndex::kTVMFFIInt || src->type_index == TypeIndex::kTVMFFIBool) { + return static_cast(src->v_int64); + } + return std::nullopt; + } + + static TVM_FFI_INLINE std::string TypeStr() { return StaticTypeKey::kTVMFFIInt; } +}; + // Float POD values template struct TypeTraits>> diff --git a/ffi/tests/cpp/test_any.cc b/ffi/tests/cpp/test_any.cc index eea18c7c644f..a1a2b4514a17 100644 --- a/ffi/tests/cpp/test_any.cc +++ b/ffi/tests/cpp/test_any.cc @@ -60,6 +60,24 @@ TEST(Any, Int) { EXPECT_EQ(view0.CopyToTVMFFIAny().v_int64, 2); } +TEST(Any, Enum) { + enum class ENum : int { + A = 1, + B = 2, + }; + + AnyView view0; + Optional opt_v0 = view0.as(); + EXPECT_TRUE(!opt_v0.has_value()); + + AnyView view1 = ENum::A; + EXPECT_EQ(view1.CopyToTVMFFIAny().type_index, TypeIndex::kTVMFFIInt); + EXPECT_EQ(view1.CopyToTVMFFIAny().v_int64, 1); + + ENum v1 = view1.cast(); + EXPECT_EQ(v1, ENum::A); +} + TEST(Any, bool) { AnyView view0; Optional opt_v0 = view0.as(); From 97e88fad8a23acf9f9ffa133bd0de0934c99df10 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 26 Jun 2025 15:29:17 -0400 Subject: [PATCH 2/6] Upgrade the analyer and tir structures to use new reflection mechanism --- ffi/include/tvm/ffi/reflection/reflection.h | 7 +- include/tvm/arith/analyzer.h | 21 +- include/tvm/arith/int_solver.h | 37 +-- include/tvm/arith/iter_affine_map.h | 55 +++-- include/tvm/ir/expr.h | 67 ++++-- include/tvm/tir/expr.h | 205 ++++++++-------- include/tvm/tir/stmt.h | 245 ++++++++++++-------- include/tvm/tir/var.h | 29 ++- src/arith/canonical_simplify.cc | 3 - src/arith/const_int_bound.cc | 2 + src/arith/int_constraints.cc | 6 + src/arith/iter_affine_map.cc | 7 + src/arith/modular_set.cc | 2 + src/ir/expr.cc | 14 ++ src/node/serialization.cc | 83 ++++--- src/script/printer/ir_docsifier.cc | 25 +- src/tir/ir/expr.cc | 36 +++ src/tir/ir/stmt.cc | 21 ++ 18 files changed, 542 insertions(+), 323 deletions(-) diff --git a/ffi/include/tvm/ffi/reflection/reflection.h b/ffi/include/tvm/ffi/reflection/reflection.h index 5f34a65e4e47..0a5e836e1aa6 100644 --- a/ffi/include/tvm/ffi/reflection/reflection.h +++ b/ffi/include/tvm/ffi/reflection/reflection.h @@ -162,8 +162,7 @@ class ObjectDef : public ReflectionDefBase { * * \return The reflection definition. */ - template >> + template TVM_FFI_INLINE ObjectDef& def_ro(const char* name, T BaseClass::*field_ptr, Extra&&... extra) { RegisterField(name, field_ptr, false, std::forward(extra)...); return *this; @@ -182,8 +181,7 @@ class ObjectDef : public ReflectionDefBase { * * \return The reflection definition. */ - template >> + template TVM_FFI_INLINE ObjectDef& def_rw(const char* name, T BaseClass::*field_ptr, Extra&&... extra) { static_assert(Class::_type_mutable, "Only mutable classes are supported for writable fields"); RegisterField(name, field_ptr, true, std::forward(extra)...); @@ -244,6 +242,7 @@ class ObjectDef : public ReflectionDefBase { template void RegisterField(const char* name, T BaseClass::*field_ptr, bool writable, ExtraArgs&&... extra_args) { + static_assert(std::is_base_of_v, "BaseClass must be a base class of Class"); TVMFFIFieldInfo info; info.name = TVMFFIByteArray{name, std::char_traits::length(name)}; info.field_static_type_index = TypeToFieldStaticTypeIndex::value; diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 7358394d3a25..a1c098a3f61f 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -86,11 +87,15 @@ class ConstIntBoundNode : public Object { int64_t min_value; int64_t max_value; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("min_value", &min_value); - v->Visit("max_value", &max_value); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("min_value", &ConstIntBoundNode::min_value) + .def_ro("max_value", &ConstIntBoundNode::max_value); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const ConstIntBoundNode* other, SEqualReducer equal) const { return equal(min_value, other->min_value) && equal(max_value, other->max_value); } @@ -208,11 +213,15 @@ class ModularSetNode : public Object { /*! \brief The base */ int64_t base; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("coeff", &coeff); - v->Visit("base", &base); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("coeff", &ModularSetNode::coeff) + .def_ro("base", &ModularSetNode::base); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const ModularSetNode* other, SEqualReducer equal) const { return equal(coeff, other->coeff) && equal(base, other->base); } diff --git a/include/tvm/arith/int_solver.h b/include/tvm/arith/int_solver.h index 0ef74ce0d5ce..4716dc7aa274 100644 --- a/include/tvm/arith/int_solver.h +++ b/include/tvm/arith/int_solver.h @@ -62,11 +62,13 @@ class IntGroupBoundsNode : public Object { Array equal; Array upper; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("coef", &coef); - v->Visit("lower", &lower); - v->Visit("equal", &equal); - v->Visit("upper", &upper); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("coef", &IntGroupBoundsNode::coef) + .def_ro("lower", &IntGroupBoundsNode::lower) + .def_ro("equal", &IntGroupBoundsNode::equal) + .def_ro("upper", &IntGroupBoundsNode::upper); } bool SEqualReduce(const IntGroupBoundsNode* other, SEqualReducer eq) const { @@ -81,6 +83,7 @@ class IntGroupBoundsNode : public Object { hash_reduce(upper); } + static constexpr const bool _type_has_method_visit_attrs = false; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntGroupBounds"; TVM_DECLARE_FINAL_OBJECT_INFO(IntGroupBoundsNode, Object); @@ -152,10 +155,12 @@ class IntConstraintsNode : public Object { // e.g., A \alpha = \beta or A \alpha <= \beta Array relations; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("variables", &variables); - v->Visit("ranges", &ranges); - v->Visit("relations", &relations); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("variables", &IntConstraintsNode::variables) + .def_ro("ranges", &IntConstraintsNode::ranges) + .def_ro("relations", &IntConstraintsNode::relations); } bool SEqualReduce(const IntConstraintsNode* other, SEqualReducer equal) const { @@ -169,6 +174,7 @@ class IntConstraintsNode : public Object { hash_reduce(relations); } + static constexpr const bool _type_has_method_visit_attrs = false; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntConstraints"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsNode, Object); @@ -213,11 +219,13 @@ class IntConstraintsTransformNode : public Object { Map src_to_dst; Map dst_to_src; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("src", &src); - v->Visit("dst", &dst); - v->Visit("src_to_dst", &src_to_dst); - v->Visit("dst_to_src", &dst_to_src); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("src", &IntConstraintsTransformNode::src) + .def_ro("dst", &IntConstraintsTransformNode::dst) + .def_ro("src_to_dst", &IntConstraintsTransformNode::src_to_dst) + .def_ro("dst_to_src", &IntConstraintsTransformNode::dst_to_src); } bool SEqualReduce(const IntConstraintsTransformNode* other, SEqualReducer equal) const { @@ -232,6 +240,7 @@ class IntConstraintsTransformNode : public Object { hash_reduce(dst_to_src); } + static constexpr const bool _type_has_method_visit_attrs = false; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const char* _type_key = "arith.IntConstraintsTransform"; TVM_DECLARE_FINAL_OBJECT_INFO(IntConstraintsTransformNode, Object); diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index d2a6f9a745b4..0b6b8e4ba77f 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -52,6 +52,7 @@ #include #include #include +#include namespace tvm { namespace arith { @@ -65,9 +66,7 @@ namespace arith { */ class IterMapExprNode : public PrimExprNode { public: - // overrides - void VisitAttrs(tvm::AttrVisitor* v) {} - + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "arith.IterMapExpr"; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode); @@ -100,12 +99,15 @@ class IterMarkNode : public Object { */ PrimExpr extent; - // overrides - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("source", &source); - v->Visit("extent", &extent); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("source", &IterMarkNode::source) + .def_ro("extent", &IterMarkNode::extent); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const { equal->MarkGraphNode(); return equal(source, other->source) && equal(extent, other->extent); @@ -156,14 +158,17 @@ class IterSplitExprNode : public IterMapExprNode { /*! \brief Additional scale. */ PrimExpr scale; - // overrides - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("source", &source); - v->Visit("lower_factor", &lower_factor); - v->Visit("extent", &extent); - v->Visit("scale", &scale); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("source", &IterSplitExprNode::source) + .def_ro("lower_factor", &IterSplitExprNode::lower_factor) + .def_ro("extent", &IterSplitExprNode::extent) + .def_ro("scale", &IterSplitExprNode::scale); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const { return equal(source, other->source) && equal(lower_factor, other->lower_factor) && equal(extent, other->extent) && equal(scale, other->scale); @@ -223,12 +228,15 @@ class IterSumExprNode : public IterMapExprNode { /*! \brief The base offset. */ PrimExpr base; - // overrides - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("args", &args); - v->Visit("base", &base); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("args", &IterSumExprNode::args) + .def_ro("base", &IterSumExprNode::base); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const { return equal(args, other->args) && equal(base, other->base); } @@ -291,13 +299,16 @@ class IterMapResultNode : public Object { */ PrimExpr padding_predicate; - // overrides - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("errors", &errors); - v->Visit("indices", &indices); - v->Visit("padding_predicate", &padding_predicate); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("indices", &IterMapResultNode::indices) + .def_ro("errors", &IterMapResultNode::errors) + .def_ro("padding_predicate", &IterMapResultNode::padding_predicate); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "arith.IterMapResult"; TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object); }; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 381be8514916..8541983212ca 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -55,6 +56,11 @@ class BaseExprNode : public Object { */ mutable Span span; + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("span", &BaseExprNode::span); + } + static constexpr const char* _type_key = "BaseExpr"; static constexpr const bool _type_has_method_visit_attrs = true; static constexpr const bool _type_has_method_sequal_reduce = true; @@ -102,6 +108,14 @@ class PrimExprNode : public BaseExprNode { */ DataType dtype; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("dtype", &PrimExprNode::dtype); + } + + static constexpr const bool _type_has_method_visit_attrs = false; + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "PrimExpr"; @@ -130,6 +144,12 @@ class PrimExpr : public BaseExpr { DataType dtype() const { return static_cast(get())->dtype; } TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); + + /*! + * \brief construct from string to form a StringImm. + * \param value The value to be constructed. + */ + TVM_DLL static PrimExpr ConvertFallbackValue(String value); // NOLINT(*) }; /*! @@ -168,7 +188,9 @@ struct TypeTraits static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(StrictBool value); static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(int64_t value); static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(double value); - static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(String value); + static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(String value) { + return PrimExpr::ConvertFallbackValue(value); + } static TVM_FFI_INLINE PrimExpr ConvertFallbackValue(PrimExprConvertible value) { return value->ToPrimExpr(); } @@ -407,6 +429,11 @@ class RelaxExprNode : public BaseExprNode { */ mutable Optional struct_info_ = Optional(); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("struct_info_", &RelaxExprNode::struct_info_); + } + static constexpr const char* _type_key = "RelaxExpr"; static constexpr const uint32_t _type_child_slots = 22; TVM_DECLARE_BASE_OBJECT_INFO(RelaxExprNode, BaseExprNode); @@ -435,10 +462,11 @@ class GlobalVarNode : public RelaxExprNode { /*! \brief The name of the variable, this only acts as a hint. */ String name_hint; - void VisitAttrs(AttrVisitor* v) { - v->Visit("name_hint", &name_hint); - v->Visit("span", &span); - v->Visit("struct_info_", &struct_info_); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("name_hint", &GlobalVarNode::name_hint); } bool SEqualReduce(const GlobalVarNode* other, SEqualReducer equal) const { @@ -467,8 +495,6 @@ class GlobalVar : public RelaxExpr { TVM_DEFINE_OBJECT_REF_COW_METHOD(GlobalVarNode); }; -// PrimExprs that are useful as runtime containers. -// /*! * \brief Constant integer literals in the program. * \sa IntImm @@ -478,10 +504,9 @@ class IntImmNode : public PrimExprNode { /*! \brief the Internal value. */ int64_t value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &IntImmNode::value); } bool SEqualReduce(const IntImmNode* other, SEqualReducer equal) const { @@ -525,10 +550,11 @@ class FloatImmNode : public PrimExprNode { /*! \brief The constant value content. */ double value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &FloatImmNode::value); } bool SEqualReduce(const FloatImmNode* other, SEqualReducer equal) const { @@ -675,10 +701,13 @@ class RangeNode : public Object { RangeNode(PrimExpr min, PrimExpr extent, Span span = Span()) : min(min), extent(extent), span(span) {} - void VisitAttrs(AttrVisitor* v) { - v->Visit("min", &min); - v->Visit("extent", &extent); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("min", &RangeNode::min) + .def_ro("extent", &RangeNode::extent); } bool SEqualReduce(const RangeNode* other, SEqualReducer equal) const { diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 5f058f7d5e4c..fb02c9147e96 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -55,10 +55,11 @@ class StringImmNode : public PrimExprNode { /*! \brief The constant value content. */ String value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &StringImmNode::value); } bool SEqualReduce(const StringImmNode* other, SEqualReducer equal) const { @@ -91,10 +92,11 @@ class CastNode : public PrimExprNode { /*! \brief Original data type. */ PrimExpr value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &CastNode::value); } bool SEqualReduce(const CastNode* other, SEqualReducer equal) const { @@ -133,11 +135,13 @@ class BinaryOpNode : public PrimExprNode { /*! \brief The right operand. */ PrimExpr b; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &(this->dtype)); - v->Visit("a", &a); - v->Visit("b", &b); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("a", &T::a) + .def_ro("b", &T::b); } bool SEqualReduce(const T* other, SEqualReducer equal) const { @@ -325,11 +329,13 @@ class CmpOpNode : public PrimExprNode { /*! \brief The right operand. */ PrimExpr b; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &(this->dtype)); - v->Visit("a", &a); - v->Visit("b", &b); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("a", &T::a) + .def_ro("b", &T::b); } bool SEqualReduce(const T* other, SEqualReducer equal) const { @@ -455,11 +461,13 @@ class AndNode : public PrimExprNode { /*! \brief The right operand. */ PrimExpr b; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &(this->dtype)); - v->Visit("a", &a); - v->Visit("b", &b); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("a", &AndNode::a) + .def_ro("b", &AndNode::b); } bool SEqualReduce(const AndNode* other, SEqualReducer equal) const { @@ -495,11 +503,13 @@ class OrNode : public PrimExprNode { /*! \brief The right operand. */ PrimExpr b; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("a", &a); - v->Visit("b", &b); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("a", &OrNode::a) + .def_ro("b", &OrNode::b); } bool SEqualReduce(const OrNode* other, SEqualReducer equal) const { @@ -533,10 +543,11 @@ class NotNode : public PrimExprNode { /*! \brief The input operand. */ PrimExpr a; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("a", &a); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("a", &NotNode::a); } bool SEqualReduce(const NotNode* other, SEqualReducer equal) const { @@ -579,12 +590,14 @@ class SelectNode : public PrimExprNode { /*! \brief value to be returned when condition is false. */ PrimExpr false_value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("condition", &condition); - v->Visit("true_value", &true_value); - v->Visit("false_value", &false_value); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("condition", &SelectNode::condition) + .def_ro("true_value", &SelectNode::true_value) + .def_ro("false_value", &SelectNode::false_value); } bool SEqualReduce(const SelectNode* other, SEqualReducer equal) const { @@ -634,12 +647,12 @@ class BufferLoadNode : public PrimExprNode { /*! \brief The predicate mask for loading values. */ Optional predicate; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &(this->dtype)); - v->Visit("buffer", &buffer); - v->Visit("indices", &indices); - v->Visit("predicate", &predicate); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer", &BufferLoadNode::buffer) + .def_ro("indices", &BufferLoadNode::indices) + .def_ro("predicate", &BufferLoadNode::predicate); } bool SEqualReduce(const BufferLoadNode* other, SEqualReducer equal) const { @@ -702,11 +715,11 @@ class ProducerLoadNode : public PrimExprNode { /*! \brief The location arguments. */ Array indices; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &(this->dtype)); - v->Visit("producer", &producer); - v->Visit("indices", &indices); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("producer", &ProducerLoadNode::producer) + .def_ro("indices", &ProducerLoadNode::indices); } bool SEqualReduce(const ProducerLoadNode* other, SEqualReducer equal) const { @@ -754,12 +767,12 @@ class RampNode : public PrimExprNode { /*! \brief Total number of lanes. */ PrimExpr lanes; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("base", &base); - v->Visit("stride", &stride); - v->Visit("lanes", &lanes); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("base", &RampNode::base) + .def_ro("stride", &RampNode::stride) + .def_ro("lanes", &RampNode::lanes); } bool SEqualReduce(const RampNode* other, SEqualReducer equal) const { @@ -797,11 +810,11 @@ class BroadcastNode : public PrimExprNode { /*! \brief The number of lanes. */ PrimExpr lanes; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("value", &value); - v->Visit("lanes", &lanes); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("value", &BroadcastNode::value) + .def_ro("lanes", &BroadcastNode::lanes); } bool SEqualReduce(const BroadcastNode* other, SEqualReducer equal) const { @@ -841,12 +854,12 @@ class LetNode : public PrimExprNode { /*! \brief The result expression. */ PrimExpr body; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("var", &var); - v->Visit("value", &value); - v->Visit("body", &body); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("var", &LetNode::var) + .def_ro("value", &LetNode::value) + .def_ro("body", &LetNode::body); } bool SEqualReduce(const LetNode* other, SEqualReducer equal) const { @@ -891,11 +904,12 @@ class CallNode : public PrimExprNode { /*! \brief The arguments. */ Array args; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("op", &op); - v->Visit("args", &args); - v->Visit("span", &span); + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("op", &CallNode::op) + .def_ro("args", &CallNode::args); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { @@ -935,11 +949,11 @@ class ShuffleNode : public PrimExprNode { /*! \brief The indices of each element. */ Array indices; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("vectors", &vectors); - v->Visit("indices", &indices); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("vectors", &ShuffleNode::vectors) + .def_ro("indices", &ShuffleNode::indices); } bool SEqualReduce(const ShuffleNode* other, SEqualReducer equal) const { @@ -998,12 +1012,16 @@ class CommReducerNode : public Object { */ mutable Span span; - void VisitAttrs(AttrVisitor* v) { - v->Visit("lhs", &lhs); - v->Visit("rhs", &rhs); - v->Visit("result", &result); - v->Visit("identity_element", &identity_element); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("lhs", &CommReducerNode::lhs) + .def_ro("rhs", &CommReducerNode::rhs) + .def_ro("result", &CommReducerNode::result) + .def_ro("identity_element", &CommReducerNode::identity_element) + .def_ro("span", &CommReducerNode::span); } bool SEqualReduce(const CommReducerNode* other, SEqualReducer equal) const { @@ -1055,15 +1073,15 @@ class ReduceNode : public PrimExprNode { /*! \brief the index of this reduce node */ int value_index; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("combiner", &combiner); - v->Visit("source", &source); - v->Visit("init", &init); - v->Visit("axis", &axis); - v->Visit("condition", &condition); - v->Visit("value_index", &value_index); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("combiner", &ReduceNode::combiner) + .def_ro("source", &ReduceNode::source) + .def_ro("init", &ReduceNode::init) + .def_ro("axis", &ReduceNode::axis) + .def_ro("condition", &ReduceNode::condition) + .def_ro("value_index", &ReduceNode::value_index); } bool SEqualReduce(const ReduceNode* other, SEqualReducer equal) const { @@ -1131,11 +1149,6 @@ struct TypeTraits return tvm::tir::StringImm(value); } }; - -// auto convert String to PrimExpr -TVM_FFI_INLINE PrimExpr TypeTraits::ConvertFallbackValue(String value) { - return TypeTraits::ConvertFallbackValue(value); -} } // namespace ffi } // namespace tvm diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index cb5db7e44f8a..ff8ff12e2379 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -24,6 +24,7 @@ #ifndef TVM_TIR_STMT_H_ #define TVM_TIR_STMT_H_ +#include #include #include @@ -45,6 +46,11 @@ class StmtNode : public Object { StmtNode() = default; explicit StmtNode(Span span) : span(span) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("span", &StmtNode::span); + } + TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); static constexpr const char* _type_key = "tir.Stmt"; @@ -72,13 +78,16 @@ class LetStmtNode : public StmtNode { /*! \brief The body block. */ Stmt body; - void VisitAttrs(AttrVisitor* v) { - v->Visit("var", &var); - v->Visit("value", &value); - v->Visit("body", &body); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("var", &LetStmtNode::var) + .def_ro("value", &LetStmtNode::value) + .def_ro("body", &LetStmtNode::body); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const LetStmtNode* other, SEqualReducer equal) const { return equal.DefEqual(var, other->var) && equal(value, other->value) && equal(body, other->body); @@ -127,14 +136,17 @@ class AttrStmtNode : public StmtNode { /*! \brief The body statement to be executed */ Stmt body; - void VisitAttrs(AttrVisitor* v) { - v->Visit("node", &node); - v->Visit("attr_key", &attr_key); - v->Visit("value", &value); - v->Visit("body", &body); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("node", &AttrStmtNode::node) + .def_ro("attr_key", &AttrStmtNode::attr_key) + .def_ro("value", &AttrStmtNode::value) + .def_ro("body", &AttrStmtNode::body); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const AttrStmtNode* other, SEqualReducer equal) const { return equal(node, other->node) && equal(attr_key, other->attr_key) && equal(value, other->value) && equal(body, other->body); @@ -178,13 +190,16 @@ class AssertStmtNode : public StmtNode { */ Stmt body; - void VisitAttrs(AttrVisitor* v) { - v->Visit("condition", &condition); - v->Visit("message", &message); - v->Visit("body", &body); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("condition", &AssertStmtNode::condition) + .def_ro("message", &AssertStmtNode::message) + .def_ro("body", &AssertStmtNode::body); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const AssertStmtNode* other, SEqualReducer equal) const { return equal(condition, other->condition) && equal(message, other->message) && equal(body, other->body); @@ -233,14 +248,17 @@ class BufferStoreNode : public StmtNode { /*! \brief The predicate mask for storing values. */ Optional predicate; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer", &buffer); - v->Visit("value", &value); - v->Visit("indices", &indices); - v->Visit("predicate", &predicate); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer", &BufferStoreNode::buffer) + .def_ro("value", &BufferStoreNode::value) + .def_ro("indices", &BufferStoreNode::indices) + .def_ro("predicate", &BufferStoreNode::predicate); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const { return equal(buffer, other->buffer) && equal(value, other->value) && equal(indices, other->indices); @@ -292,14 +310,17 @@ class BufferRealizeNode : public StmtNode { /*! \brief The body of realization. */ Stmt body; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer", &buffer); - v->Visit("bounds", &bounds); - v->Visit("condition", &condition); - v->Visit("body", &body); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer", &BufferRealizeNode::buffer) + .def_ro("bounds", &BufferRealizeNode::bounds) + .def_ro("condition", &BufferRealizeNode::condition) + .def_ro("body", &BufferRealizeNode::body); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const BufferRealizeNode* other, SEqualReducer equal) const { return equal(buffer, other->buffer) && equal(bounds, other->bounds) && equal(condition, other->condition) && equal(body, other->body); @@ -357,16 +378,19 @@ class AllocateNode : public StmtNode { */ Map annotations; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer_var", &buffer_var); - v->Visit("dtype", &dtype); - v->Visit("extents", &extents); - v->Visit("condition", &condition); - v->Visit("body", &body); - v->Visit("annotations", &annotations); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer_var", &AllocateNode::buffer_var) + .def_ro("dtype", &AllocateNode::dtype) + .def_ro("extents", &AllocateNode::extents) + .def_ro("condition", &AllocateNode::condition) + .def_ro("body", &AllocateNode::body) + .def_ro("annotations", &AllocateNode::annotations); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const AllocateNode* other, SEqualReducer equal) const { return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && equal(extents, other->extents) && equal(condition, other->condition) && @@ -445,17 +469,20 @@ class AllocateConstNode : public StmtNode { */ Map annotations; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer_var", &buffer_var); - v->Visit("data", &data); - v->Visit("irmod_storage_idx", &irmod_storage_idx); - v->Visit("dtype", &dtype); - v->Visit("extents", &extents); - v->Visit("body", &body); - v->Visit("annotations", &annotations); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer_var", &AllocateConstNode::buffer_var) + .def_ro("data", &AllocateConstNode::data) + .def_ro("irmod_storage_idx", &AllocateConstNode::irmod_storage_idx) + .def_ro("dtype", &AllocateConstNode::dtype) + .def_ro("extents", &AllocateConstNode::extents) + .def_ro("body", &AllocateConstNode::body) + .def_ro("annotations", &AllocateConstNode::annotations); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const AllocateConstNode* other, SEqualReducer equal) const { return equal.DefEqual(buffer_var, other->buffer_var) && equal(dtype, other->dtype) && equal(extents, other->extents) && equal(data, other->data) && equal(body, other->body) && @@ -517,12 +544,15 @@ class DeclBufferNode : public StmtNode { /*! \brief The body to be executed */ Stmt body; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer", &buffer); - v->Visit("body", &body); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer", &DeclBufferNode::buffer) + .def_ro("body", &DeclBufferNode::body); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const DeclBufferNode* other, SEqualReducer equal) const { return equal(buffer, other->buffer) && equal(body, other->body); } @@ -560,11 +590,13 @@ class SeqStmtNode : public StmtNode { */ Stmt operator[](size_t index) const { return seq[index]; } - void VisitAttrs(AttrVisitor* v) { - v->Visit("seq", &seq); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("seq", &SeqStmtNode::seq); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const SeqStmtNode* other, SEqualReducer equal) const { return equal(seq, other->seq); } @@ -586,11 +618,13 @@ class EvaluateNode : public StmtNode { /*! \brief The expression to be evaluated. */ PrimExpr value; - void VisitAttrs(AttrVisitor* v) { - v->Visit("value", &value); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("value", &EvaluateNode::value); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const EvaluateNode* other, SEqualReducer equal) const { return equal(value, other->value); } @@ -778,13 +812,16 @@ class IfThenElseNode : public StmtNode { /*! \brief The branch to be executed when condition is false, can be null. */ Optional else_case; - void VisitAttrs(AttrVisitor* v) { - v->Visit("condition", &condition); - v->Visit("then_case", &then_case); - v->Visit("else_case", &else_case); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("condition", &IfThenElseNode::condition) + .def_ro("then_case", &IfThenElseNode::then_case) + .def_ro("else_case", &IfThenElseNode::else_case); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const IfThenElseNode* other, SEqualReducer equal) const { return equal(condition, other->condition) && equal(then_case, other->then_case) && equal(else_case, other->else_case); @@ -878,17 +915,20 @@ class ForNode : public StmtNode { */ Map annotations; - void VisitAttrs(AttrVisitor* v) { - v->Visit("loop_var", &loop_var); - v->Visit("min", &min); - v->Visit("extent", &extent); - v->Visit("kind", &kind); - v->Visit("body", &body); - v->Visit("thread_binding", &thread_binding); - v->Visit("annotations", &annotations); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("loop_var", &ForNode::loop_var) + .def_ro("min", &ForNode::min) + .def_ro("extent", &ForNode::extent) + .def_ro("kind", &ForNode::kind) + .def_ro("body", &ForNode::body) + .def_ro("thread_binding", &ForNode::thread_binding) + .def_ro("annotations", &ForNode::annotations); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) && equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) && @@ -940,12 +980,15 @@ class WhileNode : public StmtNode { /*! \brief The body of the while loop. */ Stmt body; - void VisitAttrs(AttrVisitor* v) { - v->Visit("condition", &condition); - v->Visit("body", &body); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("condition", &WhileNode::condition) + .def_ro("body", &WhileNode::body); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const WhileNode* other, SEqualReducer equal) const { return equal(condition, other->condition) && equal(body, other->body); } @@ -981,11 +1024,15 @@ class BufferRegionNode : public PrimExprConvertibleNode { /*! \brief The region array of the buffer region. */ Array region; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer", &buffer); - v->Visit("region", ®ion); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer", &BufferRegionNode::buffer) + .def_ro("region", &BufferRegionNode::region); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const { return equal(buffer, other->buffer) && equal(region, other->region); } @@ -1046,11 +1093,15 @@ class MatchBufferRegionNode : public Object { /*! \brief The source buffer region. */ BufferRegion source; - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer", &buffer); - v->Visit("source", &source); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer", &MatchBufferRegionNode::buffer) + .def_ro("source", &MatchBufferRegionNode::source); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const MatchBufferRegionNode* other, SEqualReducer equal) const { return equal(buffer, other->buffer) && equal(source, other->source); } @@ -1126,18 +1177,22 @@ class BlockNode : public StmtNode { /*! \brief The annotation of the block. */ Map annotations; - void VisitAttrs(AttrVisitor* v) { - v->Visit("iter_vars", &iter_vars); - v->Visit("reads", &reads); - v->Visit("writes", &writes); - v->Visit("name_hint", &name_hint); - v->Visit("body", &body); - v->Visit("init", &init); - v->Visit("alloc_buffers", &alloc_buffers); - v->Visit("match_buffers", &match_buffers); - v->Visit("annotations", &annotations); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("iter_vars", &BlockNode::iter_vars) + .def_ro("reads", &BlockNode::reads) + .def_ro("writes", &BlockNode::writes) + .def_ro("name_hint", &BlockNode::name_hint) + .def_ro("body", &BlockNode::body) + .def_ro("init", &BlockNode::init) + .def_ro("alloc_buffers", &BlockNode::alloc_buffers) + .def_ro("match_buffers", &BlockNode::match_buffers) + .def_ro("annotations", &BlockNode::annotations); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const { // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars return equal.DefEqual(iter_vars, other->iter_vars) && @@ -1195,12 +1250,16 @@ class BlockRealizeNode : public StmtNode { /*! \brief The block to be realized. */ Block block; - void VisitAttrs(AttrVisitor* v) { - v->Visit("iter_values", &iter_values); - v->Visit("predicate", &predicate); - v->Visit("block", &block); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("iter_values", &BlockRealizeNode::iter_values) + .def_ro("predicate", &BlockRealizeNode::predicate) + .def_ro("block", &BlockRealizeNode::block); } + static constexpr bool _type_has_method_visit_attrs = false; + bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const { return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) && equal(block, other->block); diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 9682206c3e1d..b40a82e1cfdc 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -61,11 +61,11 @@ class VarNode : public PrimExprNode { */ Type type_annotation; - void VisitAttrs(AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("name", &name_hint); - v->Visit("type_annotation", &type_annotation); - v->Visit("span", &span); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &VarNode::name_hint) + .def_ro("type_annotation", &VarNode::type_annotation); } bool SEqualReduce(const VarNode* other, SEqualReducer equal) const { @@ -143,6 +143,10 @@ class Var : public PrimExpr { */ class SizeVarNode : public VarNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } static constexpr const char* _type_key = "tir.SizeVar"; TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); }; @@ -282,12 +286,15 @@ class IterVarNode : public PrimExprConvertibleNode { PrimExpr ToPrimExpr() const final { return var; } - void VisitAttrs(AttrVisitor* v) { - v->Visit("dom", &dom); - v->Visit("var", &var); - v->Visit("iter_type", &iter_type); - v->Visit("thread_tag", &thread_tag); - v->Visit("span", &span); + static constexpr const bool _type_has_method_visit_attrs = false; + + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("dom", &IterVarNode::dom) + .def_ro("var", &IterVarNode::var) + .def_ro("iter_type", &IterVarNode::iter_type) + .def_ro("thread_tag", &IterVarNode::thread_tag); } bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const { diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 1b82e93eacf7..0f4c773e0d7e 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -52,9 +52,6 @@ class CanonicalExprNode : public PrimExprNode { */ virtual PrimExpr Normalize() const = 0; - // overrides - void VisitAttrs(tvm::AttrVisitor* v) {} - static constexpr const char* _type_key = "arith.CanonicalExpr"; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(CanonicalExprNode, PrimExprNode); diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index 5078e5013865..b57c04752ff2 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -38,6 +38,8 @@ namespace arith { using namespace tir; +TVM_FFI_STATIC_INIT_BLOCK({ ConstIntBoundNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(ConstIntBoundNode); ConstIntBound::ConstIntBound(int64_t min_value, int64_t max_value) { diff --git a/src/arith/int_constraints.cc b/src/arith/int_constraints.cc index 01e7a3096927..afe7f09676b6 100644 --- a/src/arith/int_constraints.cc +++ b/src/arith/int_constraints.cc @@ -38,6 +38,12 @@ namespace tvm { namespace arith { +TVM_FFI_STATIC_INIT_BLOCK({ + IntGroupBoundsNode::RegisterReflection(); + IntConstraintsNode::RegisterReflection(); + IntConstraintsTransformNode::RegisterReflection(); +}); + Array AsConditions(const Array& variables, const Map& bounds, const Array& relations) { Array res; diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 2aa0ca6b6425..01aeba305027 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -40,6 +40,13 @@ namespace arith { using namespace tir; +TVM_FFI_STATIC_INIT_BLOCK({ + IterMarkNode::RegisterReflection(); + IterSplitExprNode::RegisterReflection(); + IterSumExprNode::RegisterReflection(); + IterMapResultNode::RegisterReflection(); +}); + IterMark::IterMark(PrimExpr source, PrimExpr extent) { auto n = make_object(); n->source = std::move(source); diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc index fa4891d5a00b..e4170f6c3c68 100644 --- a/src/arith/modular_set.cc +++ b/src/arith/modular_set.cc @@ -38,6 +38,8 @@ namespace arith { using namespace tir; +TVM_FFI_STATIC_INIT_BLOCK({ ModularSetNode::RegisterReflection(); }); + TVM_REGISTER_NODE_TYPE(ModularSetNode); ModularSet::ModularSet(int64_t coeff, int64_t base) { diff --git a/src/ir/expr.cc b/src/ir/expr.cc index b45bcd968421..386fc32e8d07 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -32,10 +32,24 @@ namespace tvm { +TVM_FFI_STATIC_INIT_BLOCK({ + BaseExprNode::RegisterReflection(); + PrimExprNode::RegisterReflection(); + RelaxExprNode::RegisterReflection(); + GlobalVarNode::RegisterReflection(); + IntImmNode::RegisterReflection(); + FloatImmNode::RegisterReflection(); + RangeNode::RegisterReflection(); +}); + PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {} PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} +PrimExpr PrimExpr::ConvertFallbackValue(String value) { + return tir::StringImm(value); +} + IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype << " was supplied."; diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 08fc32ad3aae..026398b99a75 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -67,8 +67,6 @@ class NodeIndexer : private AttrVisitor { public: std::unordered_map node_index_{{Any(nullptr), 0}}; std::vector node_list_{Any(nullptr)}; - std::unordered_map tensor_index_; - std::vector tensor_list_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); void Visit(const char* key, double* value) final {} @@ -81,11 +79,7 @@ class NodeIndexer : private AttrVisitor { void Visit(const char* key, DataType* value) final {} void Visit(const char* key, runtime::NDArray* value) final { - DLTensor* ptr = const_cast((*value).operator->()); - if (tensor_index_.count(ptr)) return; - ICHECK_EQ(tensor_index_.size(), tensor_list_.size()); - tensor_index_[ptr] = tensor_list_.size(); - tensor_list_.push_back(ptr); + MakeIndex(Any(*value)); } void Visit(const char* key, Optional* value) final {} @@ -109,6 +103,7 @@ class NodeIndexer : private AttrVisitor { if (node_index_.count(node)) { return; } + MakeNodeIndex(node); if (auto opt_array = node.as()) { const ffi::ArrayObj* n = opt_array.value(); @@ -231,7 +226,6 @@ struct JSONNode { class JSONAttrGetter : private AttrVisitor { public: const std::unordered_map* node_index_; - const std::unordered_map* tensor_index_; JSONNode* node_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); @@ -252,8 +246,7 @@ class JSONAttrGetter : private AttrVisitor { } void Visit(const char* key, DataType* value) final { node_->attrs[key] = Type2String(*value); } void Visit(const char* key, runtime::NDArray* value) final { - node_->attrs[key] = - std::to_string(tensor_index_->at(const_cast((*value).operator->()))); + Visit(key, static_cast(value)); } void Visit(const char* key, Optional* value) final { @@ -273,7 +266,11 @@ class JSONAttrGetter : private AttrVisitor { } void Visit(const char* key, ObjectRef* value) final { - node_->attrs[key] = std::to_string(node_index_->at(Any(*value))); + if (value->defined()) { + node_->attrs[key] = std::to_string(node_index_->at(Any(*value))); + } else { + node_->attrs[key] = "null"; + } } // Get the node @@ -416,6 +413,18 @@ class FieldDependencyFinder : private AttrVisitor { LOG(FATAL) << "Wrong value format for field " << key; } } + + template + void ParseOptionalValue(const char* key, Optional* value) const { + std::string value_str = GetValue(key); + if (value_str == "null") { + *value = std::nullopt; + } else { + T temp; + ParseValue(key, &temp); + *value = temp; + } + } void Visit(const char* key, double* value) final {} void Visit(const char* key, int64_t* value) final {} void Visit(const char* key, uint64_t* value) final {} @@ -428,9 +437,11 @@ class FieldDependencyFinder : private AttrVisitor { void Visit(const char* key, Optional* value) final {} void Visit(const char* key, Optional* value) final {} void Visit(const char* key, ObjectRef* value) final { - size_t index; - ParseValue(key, &index); - jnode_->fields.push_back(index); + Optional index; + ParseOptionalValue(key, &index); + if (index.has_value()) { + jnode_->fields.push_back(*index); + } } void Find(Any node, JSONNode* jnode) { // Skip None @@ -445,8 +456,9 @@ class FieldDependencyFinder : private AttrVisitor { reflection_->GetReprBytes(node.cast(), nullptr)) { return; } - // Skip containers - if (jnode->type_key == ffi::ArrayObj::_type_key || jnode->type_key == ffi::MapObj::_type_key) { + // Skip special handling containers + if (jnode->type_key == ffi::ArrayObj::_type_key || jnode->type_key == ffi::MapObj::_type_key || + jnode->type_key == ffi::NDArrayObj::_type_key) { return; } jnode_ = jnode; @@ -479,7 +491,6 @@ class FieldDependencyFinder : private AttrVisitor { class JSONAttrSetter : private AttrVisitor { public: const std::vector* node_list_; - const std::vector* tensor_list_; JSONNode* jnode_; ReflectionVTable* reflection_ = ReflectionVTable::Global(); @@ -550,16 +561,14 @@ class JSONAttrSetter : private AttrVisitor { *value = String2Type(stype); } void Visit(const char* key, runtime::NDArray* value) final { - size_t index; - ParseValue(key, &index); - ICHECK_LE(index, tensor_list_->size()); - *value = tensor_list_->at(index); + Visit(key, static_cast(value)); } void Visit(const char* key, ObjectRef* value) final { - size_t index; - ParseValue(key, &index); - ICHECK_LE(index, node_list_->size()); - *value = node_list_->at(index).cast(); + Optional index; + ParseOptionalValue(key, &index, [this](const char* key, int64_t* value) { ParseValue(key, value); }); + if (index.has_value()) { + *value = node_list_->at(*index).cast(); + } } static Any CreateInitAny(ReflectionVTable* reflection, JSONNode* jnode) { @@ -674,13 +683,13 @@ class JSONAttrSetter : private AttrVisitor { break; } default: { - if (field_info->field_static_type_index >= ffi::TypeIndex::kTVMFFIStaticObjectBegin) { - ObjectRef value; - this->Visit(field_info->name.data, &value); + Optional index; + ParseOptionalValue(field_info->name.data, &index, [this](const char* key, int64_t* value) { ParseValue(key, value); }); + if (index.has_value()) { + Any value = node_list_->at(*index).cast(); setter(obj, value); - break; - } else { - LOG(FATAL) << "Unsupported type: " << field_info->field_static_type_index; + } else{ + setter(obj, Any()); } } } @@ -725,7 +734,6 @@ struct JSONGraph { indexer.MakeIndex(root); JSONAttrGetter getter; getter.node_index_ = &indexer.node_index_; - getter.tensor_index_ = &indexer.tensor_index_; for (Any n : indexer.node_list_) { JSONNode jnode; getter.node_ = &jnode; @@ -733,16 +741,8 @@ struct JSONGraph { g.nodes.emplace_back(std::move(jnode)); } g.attrs["tvm_version"] = TVM_VERSION; + ICHECK(indexer.node_index_.count(root)); g.root = indexer.node_index_.at(root); - // serialize tensor - for (DLTensor* tensor : indexer.tensor_list_) { - std::string blob; - dmlc::MemoryStringStream mstrm(&blob); - support::Base64OutStream b64strm(&mstrm); - runtime::SaveDLTensor(&b64strm, tensor); - b64strm.Finish(); - g.b64ndarrays.emplace_back(std::move(blob)); - } return g; } @@ -830,7 +830,6 @@ Any LoadJSON(std::string json_str) { { JSONAttrSetter setter; setter.node_list_ = &nodes; - setter.tensor_list_ = &tensors; for (size_t i : topo_order) { setter.SetAttrs(&nodes[i], &jgraph.nodes[i]); } diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index d906c1baf54d..b4bd64ada6f5 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -109,18 +109,7 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, class Visitor : private AttrVisitor { public: void operator()(ObjectRef obj) { - const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); - if (tinfo->extra_info != nullptr) { - // visit fields with the new reflection - ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { - Any field_value = ffi::reflection::FieldGetter(field_info)(obj); - this->RecursiveVisitAny(&field_value); - }); - } else { - // NOTE: legacy VisitAttrs mechanism - // TODO(tvm-team): remove this once all objects are transitioned to the new reflection - this->Visit("", &obj); - } + this->Visit("", &obj); } private: @@ -165,7 +154,17 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, this->RecursiveVisitAny(&kv.second); } } else { - vtable_->VisitAttrs(const_cast(obj), this); + const TVMFFITypeInfo* tinfo = TVMFFIGetTypeInfo(obj->type_index()); + if (tinfo->extra_info != nullptr) { + // visit fields with the new reflection + ffi::reflection::ForEachFieldInfo(tinfo, [&](const TVMFFIFieldInfo* field_info) { + Any field_value = ffi::reflection::FieldGetter(field_info)(obj); + this->RecursiveVisitAny(&field_value); + }); + } else { + // legacy VisitAttrs mechanism + vtable_->VisitAttrs(const_cast(obj), this); + } } if (is_var(GetRef(obj))) { HandleVar(obj); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 0ac59b160200..f6657451e511 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -35,6 +35,42 @@ namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ + VarNode::RegisterReflection(); + SizeVarNode::RegisterReflection(); + IterVarNode::RegisterReflection(); + StringImmNode::RegisterReflection(); + CastNode::RegisterReflection(); + AddNode::RegisterReflection(); + SubNode::RegisterReflection(); + MulNode::RegisterReflection(); + DivNode::RegisterReflection(); + ModNode::RegisterReflection(); + FloorDivNode::RegisterReflection(); + FloorModNode::RegisterReflection(); + MinNode::RegisterReflection(); + MaxNode::RegisterReflection(); + EQNode::RegisterReflection(); + NENode::RegisterReflection(); + LTNode::RegisterReflection(); + LENode::RegisterReflection(); + GTNode::RegisterReflection(); + GENode::RegisterReflection(); + AndNode::RegisterReflection(); + OrNode::RegisterReflection(); + NotNode::RegisterReflection(); + SelectNode::RegisterReflection(); + BufferLoadNode::RegisterReflection(); + ProducerLoadNode::RegisterReflection(); + RampNode::RegisterReflection(); + BroadcastNode::RegisterReflection(); + LetNode::RegisterReflection(); + CallNode::RegisterReflection(); + ShuffleNode::RegisterReflection(); + CommReducerNode::RegisterReflection(); + ReduceNode::RegisterReflection(); +}); + /* \brief Convert an object to a PrimExpr * * All conversions to a PrimExpr are performed as part of the FFI, diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index f400ca0d507e..6be07368972d 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -31,6 +31,27 @@ namespace tvm { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ + StmtNode::RegisterReflection(); + LetStmtNode::RegisterReflection(); + AttrStmtNode::RegisterReflection(); + AssertStmtNode::RegisterReflection(); + BufferStoreNode::RegisterReflection(); + BufferRealizeNode::RegisterReflection(); + AllocateNode::RegisterReflection(); + AllocateConstNode::RegisterReflection(); + DeclBufferNode::RegisterReflection(); + SeqStmtNode::RegisterReflection(); + EvaluateNode::RegisterReflection(); + IfThenElseNode::RegisterReflection(); + ForNode::RegisterReflection(); + WhileNode::RegisterReflection(); + BufferRegionNode::RegisterReflection(); + MatchBufferRegionNode::RegisterReflection(); + BlockNode::RegisterReflection(); + BlockRealizeNode::RegisterReflection(); +}); + // LetStmt LetStmt::LetStmt(Var var, PrimExpr value, Stmt body, Span span) { ICHECK(value.defined()); From bc2ea5c91c3f333bf822859af805a0b9b51a66ac Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 27 Jun 2025 10:58:17 -0400 Subject: [PATCH 3/6] Upgrades all parser frames --- include/tvm/script/ir_builder/base.h | 19 +- include/tvm/script/ir_builder/ir/frame.h | 16 +- include/tvm/script/ir_builder/relax/frame.h | 88 +++++--- include/tvm/script/ir_builder/tir/frame.h | 224 +++++++++++++------- src/script/ir_builder/base.cc | 5 + src/script/ir_builder/ir/frame.cc | 4 + src/script/ir_builder/relax/frame.cc | 10 + src/script/ir_builder/tir/frame.cc | 20 ++ 8 files changed, 265 insertions(+), 121 deletions(-) diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 85d6dcce5e1b..0dd1d6d805b8 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -22,6 +22,7 @@ #include #include #include +#include #include @@ -66,10 +67,14 @@ class IRBuilderFrameNode : public runtime::Object { /*! \brief A list of callbacks used when exiting the frame. */ std::vector> callbacks; - void VisitAttrs(tvm::AttrVisitor* v) { - // `callbacks` is not visited. + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + // `callbacks` is not registered as it's not visited. } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.IRBuilderFrame"; TVM_DECLARE_BASE_OBJECT_INFO(IRBuilderFrameNode, runtime::Object); @@ -158,11 +163,15 @@ class IRBuilderNode : public runtime::Object { /*! \brief The outcome of IR construction */ Optional result; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("frames", &frames); - v->Visit("result", &result); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("frames", &IRBuilderNode::frames) + .def_ro("result", &IRBuilderNode::result); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.IRBuilder"; TVM_DECLARE_FINAL_OBJECT_INFO(IRBuilderNode, runtime::Object); diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index 23a7d2ca0394..eca7908d1a5a 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -24,6 +24,7 @@ #include #include #include +#include #include @@ -51,14 +52,17 @@ class IRModuleFrameNode : public IRBuilderFrameNode { /*! \brief IRModule's global_infos */ Map> global_infos; - void VisitAttrs(tvm::AttrVisitor* v) { - IRBuilderFrameNode::VisitAttrs(v); - v->Visit("global_vars", &global_var_map); - v->Visit("functions", &functions); - v->Visit("attrs", &attrs); - v->Visit("global_infos", &global_infos); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("global_vars", &IRModuleFrameNode::global_var_map) + .def_ro("functions", &IRModuleFrameNode::functions) + .def_ro("attrs", &IRModuleFrameNode::attrs) + .def_ro("global_infos", &IRModuleFrameNode::global_infos); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.IRModuleFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleFrameNode, IRBuilderFrameNode); diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 98a51fcb7829..0d9f4031b153 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -24,6 +24,7 @@ #include #include #include +#include namespace tvm { namespace script { @@ -33,7 +34,12 @@ namespace relax { /*! \brief The base ir_builder frame for the relax dialect. */ class RelaxFrameNode : public IRBuilderFrameNode { public: - void VisitAttrs(tvm::AttrVisitor* v) { IRBuilderFrameNode::VisitAttrs(v); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "script.ir_builder.relax.RelaxFrame"; TVM_DECLARE_BASE_OBJECT_INFO(RelaxFrameNode, IRBuilderFrameNode); @@ -57,12 +63,15 @@ class SeqExprFrameNode : public RelaxFrameNode { /*! \brief The frame output expr. `std::nullopt` when undefined. */ Optional output; - void VisitAttrs(tvm::AttrVisitor* v) { - RelaxFrameNode::VisitAttrs(v); - v->Visit("binding_blocks", &binding_blocks); - v->Visit("output", &output); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("binding_blocks", &SeqExprFrameNode::binding_blocks) + .def_ro("output", &SeqExprFrameNode::output); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.relax.SeqExprFrame"; TVM_DECLARE_BASE_OBJECT_INFO(SeqExprFrameNode, RelaxFrameNode); @@ -106,18 +115,21 @@ class FunctionFrameNode : public SeqExprFrameNode { /*! \brief The block builder to create Relax function. */ tvm::relax::BlockBuilder block_builder; - void VisitAttrs(tvm::AttrVisitor* v) { - SeqExprFrameNode::VisitAttrs(v); - v->Visit("name", &name); - v->Visit("params", ¶ms); - v->Visit("ret_struct_info", &ret_struct_info); - v->Visit("is_pure", &is_pure); - v->Visit("attrs", &attrs); - v->Visit("binding_blocks", &binding_blocks); - v->Visit("output", &output); - // `block_builder` is not visited. + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &FunctionFrameNode::name) + .def_ro("params", &FunctionFrameNode::params) + .def_ro("ret_struct_info", &FunctionFrameNode::ret_struct_info) + .def_ro("is_pure", &FunctionFrameNode::is_pure) + .def_ro("attrs", &FunctionFrameNode::attrs) + .def_ro("binding_blocks", &FunctionFrameNode::binding_blocks) + .def_ro("output", &FunctionFrameNode::output); + // `block_builder` is not registered as it's not visited. } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.relax.FunctionFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(FunctionFrameNode, SeqExprFrameNode); @@ -150,14 +162,17 @@ class BlockFrameNode : public RelaxFrameNode { */ Array output_vars; - void VisitAttrs(tvm::AttrVisitor* v) { - RelaxFrameNode::VisitAttrs(v); - v->Visit("is_dataflow", &is_dataflow); - v->Visit("emitted_vars", &emitted_vars); - v->Visit("output_vars", &output_vars); - // `block_ended` is not visited. + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("is_dataflow", &BlockFrameNode::is_dataflow) + .def_ro("emitted_vars", &BlockFrameNode::emitted_vars) + .def_ro("output_vars", &BlockFrameNode::output_vars); + // `block_ended` is not registered as it's not visited. } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.relax.BlockFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, RelaxFrameNode); @@ -189,15 +204,18 @@ class IfFrameNode : public RelaxFrameNode { /*! \brief The binding var name. */ String var_name; - void VisitAttrs(tvm::AttrVisitor* v) { - RelaxFrameNode::VisitAttrs(v); - v->Visit("condition", &condition); - v->Visit("then_expr", &then_expr); - v->Visit("else_expr", &else_expr); - v->Visit("var", &var); - v->Visit("var_name", &var_name); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("condition", &IfFrameNode::condition) + .def_ro("then_expr", &IfFrameNode::then_expr) + .def_ro("else_expr", &IfFrameNode::else_expr) + .def_ro("var", &IfFrameNode::var) + .def_ro("var_name", &IfFrameNode::var_name); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.relax.IfFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, RelaxFrameNode); @@ -231,6 +249,13 @@ class IfFrame : public RelaxFrame { */ class ThenFrameNode : public SeqExprFrameNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.relax.ThenFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, SeqExprFrameNode); @@ -264,6 +289,13 @@ class ThenFrame : public SeqExprFrame { */ class ElseFrameNode : public SeqExprFrameNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.relax.ElseFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, SeqExprFrameNode); diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index 171ac019dd03..a931ae039f07 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -38,11 +38,13 @@ class TIRFrameNode : public IRBuilderFrameNode { /*! \brief The Stmt within in this frame. */ Array stmts; - void VisitAttrs(tvm::AttrVisitor* v) { - IRBuilderFrameNode::VisitAttrs(v); - v->Visit("stmts", &stmts); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("stmts", &TIRFrameNode::stmts); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.TIRFrame"; TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, IRBuilderFrameNode); }; @@ -84,18 +86,21 @@ class PrimFuncFrameNode : public TIRFrameNode { /*! \brief The buffer allocated in root block. */ Array root_alloc_buffers; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("name", &name); - v->Visit("args", &args); - v->Visit("is_private", &is_private); - v->Visit("ret_type", &ret_type); - v->Visit("buffer_map", &buffer_map); - v->Visit("attrs", &attrs); - v->Visit("env_threads", &env_threads); - v->Visit("root_alloc_buffers", &root_alloc_buffers); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &PrimFuncFrameNode::name) + .def_ro("args", &PrimFuncFrameNode::args) + .def_ro("is_private", &PrimFuncFrameNode::is_private) + .def_ro("ret_type", &PrimFuncFrameNode::ret_type) + .def_ro("buffer_map", &PrimFuncFrameNode::buffer_map) + .def_ro("attrs", &PrimFuncFrameNode::attrs) + .def_ro("env_threads", &PrimFuncFrameNode::env_threads) + .def_ro("root_alloc_buffers", &PrimFuncFrameNode::root_alloc_buffers); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.PrimFuncFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimFuncFrameNode, TIRFrameNode); @@ -150,21 +155,24 @@ class BlockFrameNode : public TIRFrameNode { /*! \brief The flag whether to construct BlockRealize or Block. */ bool no_realize; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("name", &name); - v->Visit("iter_vars", &iter_vars); - v->Visit("reads", &reads); - v->Visit("writes", &writes); - v->Visit("init", &init); - v->Visit("alloc_buffers", &alloc_buffers); - v->Visit("match_buffers", &match_buffers); - v->Visit("annotations", &annotations); - v->Visit("iter_values", &iter_values); - v->Visit("predicate", &predicate); - v->Visit("no_realize", &no_realize); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("name", &BlockFrameNode::name) + .def_ro("iter_vars", &BlockFrameNode::iter_vars) + .def_ro("reads", &BlockFrameNode::reads) + .def_ro("writes", &BlockFrameNode::writes) + .def_ro("init", &BlockFrameNode::init) + .def_ro("alloc_buffers", &BlockFrameNode::alloc_buffers) + .def_ro("match_buffers", &BlockFrameNode::match_buffers) + .def_ro("annotations", &BlockFrameNode::annotations) + .def_ro("iter_values", &BlockFrameNode::iter_values) + .def_ro("predicate", &BlockFrameNode::predicate) + .def_ro("no_realize", &BlockFrameNode::no_realize); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.BlockFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(BlockFrameNode, TIRFrameNode); @@ -194,7 +202,12 @@ class BlockFrame : public TIRFrame { */ class BlockInitFrameNode : public TIRFrameNode { public: - void VisitAttrs(tvm::AttrVisitor* v) { TIRFrameNode::VisitAttrs(v); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "script.ir_builder.tir.BlockInitFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(BlockInitFrameNode, TIRFrameNode); @@ -245,13 +258,16 @@ class ForFrameNode : public TIRFrameNode { /*! \brief The for loop generating function. */ FMakeForLoop f_make_for_loop; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("vars", &vars); - v->Visit("doms", &doms); - // `f_make_for_loop` is not visited. + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("vars", &ForFrameNode::vars) + .def_ro("doms", &ForFrameNode::doms); + // `f_make_for_loop` is not registered as it's not visited. } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.ForFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode); @@ -286,12 +302,15 @@ class AssertFrameNode : public TIRFrameNode { /*! \brief The output error message when the assertion failed. */ PrimExpr message; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("condition", &condition); - v->Visit("message", &message); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("condition", &AssertFrameNode::condition) + .def_ro("message", &AssertFrameNode::message); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.AssertFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(AssertFrameNode, TIRFrameNode); @@ -325,12 +344,15 @@ class LetFrameNode : public TIRFrameNode { /*! \brief The value we bind var to */ PrimExpr value; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("var", &var); - v->Visit("value", &value); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("var", &LetFrameNode::var) + .def_ro("value", &LetFrameNode::value); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.LetFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(LetFrameNode, TIRFrameNode); @@ -365,13 +387,16 @@ class LaunchThreadFrameNode : public TIRFrameNode { /*! \brief The iteration variable. */ tvm::tir::IterVar iter_var; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("extent", &extent); - v->Visit("attr_key", &attr_key); - v->Visit("iter_var", &iter_var); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("extent", &LaunchThreadFrameNode::extent) + .def_ro("attr_key", &LaunchThreadFrameNode::attr_key) + .def_ro("iter_var", &LaunchThreadFrameNode::iter_var); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.LaunchThreadFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(LaunchThreadFrameNode, TIRFrameNode); @@ -408,13 +433,16 @@ class RealizeFrameNode : public TIRFrameNode { /*! \brief The condition expression. */ PrimExpr condition; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("buffer_slice", &buffer_slice); - v->Visit("storage_scope", &storage_scope); - v->Visit("condition", &condition); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer_slice", &RealizeFrameNode::buffer_slice) + .def_ro("storage_scope", &RealizeFrameNode::storage_scope) + .def_ro("condition", &RealizeFrameNode::condition); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.RealizeFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(RealizeFrameNode, TIRFrameNode); @@ -456,16 +484,19 @@ class AllocateFrameNode : public TIRFrameNode { /*! \brief The buffer var. */ tvm::tir::Var buffer_var; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("extents", &extents); - v->Visit("dtype", &dtype); - v->Visit("storage_scope", &storage_scope); - v->Visit("condition", &condition); - v->Visit("annotations", &annotations); - v->Visit("buffer_var", &buffer_var); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("extents", &AllocateFrameNode::extents) + .def_ro("dtype", &AllocateFrameNode::dtype) + .def_ro("storage_scope", &AllocateFrameNode::storage_scope) + .def_ro("condition", &AllocateFrameNode::condition) + .def_ro("annotations", &AllocateFrameNode::annotations) + .def_ro("buffer_var", &AllocateFrameNode::buffer_var); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateFrameNode, TIRFrameNode); @@ -505,15 +536,18 @@ class AllocateConstFrameNode : public TIRFrameNode { /*! \brief Additional annotations about the allocation. */ Map annotations; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("dtype", &dtype); - v->Visit("extents", &extents); - v->Visit("data", &data); - v->Visit("buffer_var", &buffer_var); - v->Visit("annotations", &annotations); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("dtype", &AllocateConstFrameNode::dtype) + .def_ro("extents", &AllocateConstFrameNode::extents) + .def_ro("data", &AllocateConstFrameNode::data) + .def_ro("buffer_var", &AllocateConstFrameNode::buffer_var) + .def_ro("annotations", &AllocateConstFrameNode::annotations); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.AllocateConstFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(AllocateConstFrameNode, TIRFrameNode); @@ -549,13 +583,16 @@ class AttrFrameNode : public TIRFrameNode { /*! \brief The value of the attribute. */ PrimExpr value; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("node", &node); - v->Visit("attr_key", &attr_key); - v->Visit("value", &value); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("node", &AttrFrameNode::node) + .def_ro("attr_key", &AttrFrameNode::attr_key) + .def_ro("value", &AttrFrameNode::value); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.AttrFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(AttrFrameNode, TIRFrameNode); @@ -587,11 +624,14 @@ class WhileFrameNode : public TIRFrameNode { /*! \brief The termination condition of while. */ PrimExpr condition; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("condition", &condition); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("condition", &WhileFrameNode::condition); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.WhileFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(WhileFrameNode, TIRFrameNode); @@ -627,13 +667,16 @@ class IfFrameNode : public TIRFrameNode { /*! \brief The stetements in the false branch. */ Optional> else_stmts; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("condition", &condition); - v->Visit("then_stmts", &then_stmts); - v->Visit("else_stmts", &else_stmts); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("condition", &IfFrameNode::condition) + .def_ro("then_stmts", &IfFrameNode::then_stmts) + .def_ro("else_stmts", &IfFrameNode::else_stmts); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.IfFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(IfFrameNode, TIRFrameNode); @@ -662,6 +705,13 @@ class IfFrame : public TIRFrame { */ class ThenFrameNode : public TIRFrameNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.ThenFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(ThenFrameNode, TIRFrameNode); @@ -695,6 +745,13 @@ class ThenFrame : public TIRFrame { */ class ElseFrameNode : public TIRFrameNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.ElseFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(ElseFrameNode, TIRFrameNode); @@ -728,12 +785,15 @@ class DeclBufferFrameNode : public TIRFrameNode { /*! \brief The buffer allocated or not. */ bool allocated; - void VisitAttrs(tvm::AttrVisitor* v) { - TIRFrameNode::VisitAttrs(v); - v->Visit("buffer", &buffer); - v->Visit("allocated", &allocated); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffer", &DeclBufferFrameNode::buffer) + .def_ro("allocated", &DeclBufferFrameNode::allocated); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame"; TVM_DECLARE_FINAL_OBJECT_INFO(DeclBufferFrameNode, TIRFrameNode); diff --git a/src/script/ir_builder/base.cc b/src/script/ir_builder/base.cc index 13f272d7c946..20808635ed0c 100644 --- a/src/script/ir_builder/base.cc +++ b/src/script/ir_builder/base.cc @@ -24,6 +24,11 @@ namespace tvm { namespace script { namespace ir_builder { +TVM_FFI_STATIC_INIT_BLOCK({ + IRBuilderFrameNode::RegisterReflection(); + IRBuilderNode::RegisterReflection(); +}); + void IRBuilderFrameNode::EnterWithScope() { IRBuilder::Current()->frames.push_back(GetRef(this)); } diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index 6cb61147a96a..7006aa25f36f 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -25,6 +25,10 @@ namespace script { namespace ir_builder { namespace ir { +TVM_FFI_STATIC_INIT_BLOCK({ + IRModuleFrameNode::RegisterReflection(); +}); + void IRModuleFrameNode::ExitWithScope() { Map func_map; CHECK_EQ(functions.size(), global_var_map.size()) diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index faf6bd6466ad..0cde34879e64 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -30,6 +30,16 @@ namespace script { namespace ir_builder { namespace relax { +TVM_FFI_STATIC_INIT_BLOCK({ + RelaxFrameNode::RegisterReflection(); + SeqExprFrameNode::RegisterReflection(); + FunctionFrameNode::RegisterReflection(); + BlockFrameNode::RegisterReflection(); + IfFrameNode::RegisterReflection(); + ThenFrameNode::RegisterReflection(); + ElseFrameNode::RegisterReflection(); +}); + void SeqExprFrameNode::ExitWithScope() { // At this moment, there should be at most one BlockFrame which hasn't ended. In this case, call // its `ExitBlockFrame` and check if there is any more unended BlockFrame. diff --git a/src/script/ir_builder/tir/frame.cc b/src/script/ir_builder/tir/frame.cc index 6fe83946ce61..1eb46f70eb71 100644 --- a/src/script/ir_builder/tir/frame.cc +++ b/src/script/ir_builder/tir/frame.cc @@ -28,6 +28,26 @@ namespace script { namespace ir_builder { namespace tir { +TVM_FFI_STATIC_INIT_BLOCK({ + TIRFrameNode::RegisterReflection(); + PrimFuncFrameNode::RegisterReflection(); + BlockFrameNode::RegisterReflection(); + BlockInitFrameNode::RegisterReflection(); + ForFrameNode::RegisterReflection(); + AssertFrameNode::RegisterReflection(); + LetFrameNode::RegisterReflection(); + LaunchThreadFrameNode::RegisterReflection(); + RealizeFrameNode::RegisterReflection(); + AllocateFrameNode::RegisterReflection(); + AllocateConstFrameNode::RegisterReflection(); + AttrFrameNode::RegisterReflection(); + WhileFrameNode::RegisterReflection(); + IfFrameNode::RegisterReflection(); + ThenFrameNode::RegisterReflection(); + ElseFrameNode::RegisterReflection(); + DeclBufferFrameNode::RegisterReflection(); +}); + void PrimFuncFrameNode::ExitWithScope() { TIRFrameNode::ExitWithScope(); // if the prim func is not private and there isn't already a global symbol, From 148f6245c4f515f6f7122d51b1d4862b62c3f35b Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 27 Jun 2025 11:09:35 -0400 Subject: [PATCH 4/6] Fix base span optional --- include/tvm/ir/expr.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 8541983212ca..3545a9d3a57d 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -58,7 +58,7 @@ class BaseExprNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef().def_ro("span", &BaseExprNode::span); + refl::ObjectDef().def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span())); } static constexpr const char* _type_key = "BaseExpr"; From 1ae34fe991319efad793664897e4cb6ba18518d1 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 27 Jun 2025 11:46:07 -0400 Subject: [PATCH 5/6] Upgrade meta-schedule --- include/tvm/meta_schedule/arg_info.h | 11 ++-- include/tvm/meta_schedule/builder.h | 31 +++++++---- include/tvm/meta_schedule/cost_model.h | 14 ++--- include/tvm/meta_schedule/database.h | 47 +++++++++------- include/tvm/meta_schedule/extracted_task.h | 16 +++--- include/tvm/meta_schedule/feature_extractor.h | 15 ++++-- include/tvm/meta_schedule/measure_callback.h | 15 ++++-- include/tvm/meta_schedule/measure_candidate.h | 11 ++-- include/tvm/meta_schedule/mutator.h | 19 ++++--- include/tvm/meta_schedule/postproc.h | 19 ++++--- include/tvm/meta_schedule/profiler.h | 9 ++-- include/tvm/meta_schedule/runner.h | 35 +++++++----- include/tvm/meta_schedule/schedule_rule.h | 19 ++++--- include/tvm/meta_schedule/search_strategy.h | 17 +++--- include/tvm/meta_schedule/space_generator.h | 24 +++++---- include/tvm/meta_schedule/task_scheduler.h | 53 +++++++++++-------- include/tvm/meta_schedule/tune_context.h | 24 +++++---- src/meta_schedule/builder/builder.cc | 6 +++ src/meta_schedule/database/database.cc | 6 +++ src/meta_schedule/database/json_database.cc | 15 ++++-- src/meta_schedule/database/memory_database.cc | 15 ++++-- .../database/ordered_union_database.cc | 13 ++++- .../database/schedule_fn_database.cc | 13 +++-- src/meta_schedule/database/union_database.cc | 13 ++++- src/meta_schedule/extracted_task.cc | 4 ++ .../feature_extractor/feature_extractor.cc | 5 ++ .../feature_extractor/per_store_feature.cc | 19 +++++-- .../measure_callback/measure_callback.cc | 5 ++ .../mutator/mutate_compute_location.cc | 13 ++++- src/meta_schedule/mutator/mutate_parallel.cc | 14 +++-- .../mutator/mutate_thread_binding.cc | 13 ++++- src/meta_schedule/mutator/mutate_tile_size.cc | 13 ++++- src/meta_schedule/mutator/mutate_unroll.cc | 13 ++++- src/meta_schedule/mutator/mutator.cc | 5 ++ src/meta_schedule/postproc/postproc.cc | 5 ++ .../postproc/rewrite_cooperative_fetch.cc | 13 ++++- .../postproc/rewrite_reduction_block.cc | 13 ++++- .../postproc/rewrite_tensorize.cc | 12 ++++- .../postproc/rewrite_unbound_block.cc | 13 +++-- src/meta_schedule/profiler.cc | 4 ++ src/meta_schedule/runner/runner.cc | 7 +++ .../schedule_rule/add_rfactor.cc | 15 ++++-- .../schedule_rule/apply_custom_rule.cc | 12 ++++- src/meta_schedule/schedule_rule/auto_bind.cc | 14 +++-- .../schedule_rule/auto_inline.cc | 33 ++++++++---- .../schedule_rule/cross_thread_reduction.cc | 16 ++++-- .../schedule_rule/multi_level_tiling.cc | 4 ++ .../schedule_rule/multi_level_tiling.h | 20 ++++--- .../parallel_vectorize_unroll.cc | 18 ++++--- .../schedule_rule/random_compute_location.cc | 12 ++++- .../schedule_rule/schedule_rule.cc | 5 ++ .../search_strategy/evolutionary_search.cc | 37 ++++++------- .../search_strategy/replay_func.cc | 11 +++- .../search_strategy/replay_trace.cc | 17 +++--- .../search_strategy/search_strategy.cc | 5 ++ .../space_generator/post_order_apply.cc | 13 +++-- .../space_generator/schedule_fn.cc | 12 +++-- .../space_generator/space_generator.cc | 5 ++ .../space_generator/space_generator_union.cc | 14 +++-- .../task_scheduler/gradient_based.cc | 19 ++++--- .../task_scheduler/round_robin.cc | 15 ++++-- .../task_scheduler/task_scheduler.cc | 6 +++ src/meta_schedule/tune_context.cc | 4 ++ 63 files changed, 669 insertions(+), 269 deletions(-) diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h index 2768ed2737dc..34d8278286fb 100644 --- a/include/tvm/meta_schedule/arg_info.h +++ b/include/tvm/meta_schedule/arg_info.h @@ -26,6 +26,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -83,11 +84,15 @@ class TensorInfoNode : public ArgInfoNode { /*! \brief The shape of the tensor. */ ffi::Shape shape; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("dtype", &dtype); - v->Visit("shape", &shape); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("dtype", &TensorInfoNode::dtype) + .def_ro("shape", &TensorInfoNode::shape); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.TensorInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, ArgInfoNode); diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index 24e136f9d345..df132aa15033 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -29,6 +29,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -43,12 +44,16 @@ class BuilderInputNode : public runtime::Object { /*! \brief Parameters for Relax build module. */ Optional> params; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("mod", &mod); - v->Visit("target", &target); - v->Visit("params", ¶ms); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("mod", &BuilderInputNode::mod) + .def_ro("target", &BuilderInputNode::target) + .def_ro("params", &BuilderInputNode::params); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.BuilderInput"; TVM_DECLARE_FINAL_OBJECT_INFO(BuilderInputNode, runtime::Object); }; @@ -78,11 +83,15 @@ class BuilderResultNode : public runtime::Object { /*! \brief The error message if any. */ Optional error_msg; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("artifact_path", &artifact_path); - v->Visit("error_msg", &error_msg); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("artifact_path", &BuilderResultNode::artifact_path) + .def_ro("error_msg", &BuilderResultNode::error_msg); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.BuilderResult"; TVM_DECLARE_FINAL_OBJECT_INFO(BuilderResultNode, runtime::Object); }; @@ -145,10 +154,14 @@ class PyBuilderNode : public BuilderNode { /*! \brief The packed function to the `Build` function. */ FBuild f_build; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_build` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("f_build", &PyBuilderNode::f_build); } + static constexpr bool _type_has_method_visit_attrs = false; + Array Build(const Array& build_inputs) final { ICHECK(f_build != nullptr) << "PyBuilder's Build method not implemented!"; return f_build(build_inputs); diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index 300f53e113bd..ba7da0ce46d5 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -29,6 +29,7 @@ #include #include #include +#include #include @@ -43,8 +44,6 @@ class CostModelNode : public runtime::Object { /*! \brief Virtual destructor. */ virtual ~CostModelNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) {} - /*! * \brief Load the cost model from given file location. * \param path The file path. @@ -75,6 +74,7 @@ class CostModelNode : public runtime::Object { virtual std::vector Predict(const TuneContext& context, const Array& candidates) = 0; + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.CostModel"; TVM_DECLARE_BASE_OBJECT_INFO(CostModelNode, Object); }; @@ -126,13 +126,7 @@ class PyCostModelNode : public CostModelNode { /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_load` is not visited - // `f_save` is not visited - // `f_update` is not visited - // `f_predict` is not visited - // `f_as_string` is not visited - } + void Load(const String& path); void Save(const String& path); @@ -141,6 +135,8 @@ class PyCostModelNode : public CostModelNode { std::vector Predict(const TuneContext& context, const Array& candidates); + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.PyCostModel"; TVM_DECLARE_FINAL_OBJECT_INFO(PyCostModelNode, CostModelNode); }; diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index 570da2cf0650..e56348322082 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -30,6 +30,7 @@ #include #include #include +#include #include @@ -48,12 +49,13 @@ class WorkloadNode : public runtime::Object { /*! \brief The workload's structural hash. */ THashCode shash; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("mod", &mod); - // `shash` is not visited because TVM FFI doesn't support uint64_t + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("mod", &WorkloadNode::mod); } static constexpr const char* _type_key = "meta_schedule.Workload"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(WorkloadNode, runtime::Object); /*! @@ -124,15 +126,18 @@ class TuningRecordNode : public runtime::Object { /*! \brief The argument information. */ Optional> args_info; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("trace", &trace); - v->Visit("workload", &workload); - v->Visit("run_secs", &run_secs); - v->Visit("target", &target); - v->Visit("args_info", &args_info); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("trace", &TuningRecordNode::trace) + .def_ro("workload", &TuningRecordNode::workload) + .def_ro("run_secs", &TuningRecordNode::run_secs) + .def_ro("target", &TuningRecordNode::target) + .def_ro("args_info", &TuningRecordNode::args_info); } static constexpr const char* _type_key = "meta_schedule.TuningRecord"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(TuningRecordNode, runtime::Object); /*! \brief Construct the measure candidate given the initial IR module and trace @@ -377,21 +382,23 @@ class PyDatabaseNode : public DatabaseNode { /*! \brief The packed function to the `Size` function. */ FSize f_size; - void VisitAttrs(tvm::AttrVisitor* v) { - // ffi::Functions are all not visited, because the reflection system doesn't take care of them, + static void RegisterReflection() { + // ffi::Functions are all not registered, because the reflection system doesn't take care of them, // so it cannot be accessible on the python side. If there is such need from the future, // we can then add corresponding accessor methods to help access on python. - // `f_has_workload` is not visited - // `f_commit_workload` is not visited - // `f_commit_tuning_record` is not visited - // `f_get_top_k` is not visited - // `f_get_all_tuning_records` is not visited - // `f_query_tuning_record` is not visited - // `f_query_schedule` is not visited - // `f_query_ir_module` is not visited - // `f_size` is not visited + // `f_has_workload` is not registered + // `f_commit_workload` is not registered + // `f_commit_tuning_record` is not registered + // `f_get_top_k` is not registered + // `f_get_all_tuning_records` is not registered + // `f_query_tuning_record` is not registered + // `f_query_schedule` is not registered + // `f_query_ir_module` is not registered + // `f_size` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + bool HasWorkload(const IRModule& mod) final { ICHECK(f_has_workload != nullptr) << "PyDatabase's HasWorkload method not implemented!"; return f_has_workload(mod); diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index cfc1f29e8efb..0e78bdd4bf95 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -25,6 +25,7 @@ #include #include #include +#include namespace tvm { namespace tir { @@ -52,15 +53,18 @@ class ExtractedTaskNode : public runtime::Object { /*! \brief Weight of the task */ int weight; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("task_name", &task_name); - v->Visit("mod", &mod); - v->Visit("target", &target); - v->Visit("dispatched", &dispatched); - v->Visit("weight", &weight); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("task_name", &ExtractedTaskNode::task_name) + .def_ro("mod", &ExtractedTaskNode::mod) + .def_ro("target", &ExtractedTaskNode::target) + .def_ro("dispatched", &ExtractedTaskNode::dispatched) + .def_ro("weight", &ExtractedTaskNode::weight); } static constexpr const char* _type_key = "meta_schedule.ExtractedTask"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ExtractedTaskNode, runtime::Object); }; diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index e45cb4eab195..cdf510c8caf2 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -27,6 +27,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -39,7 +40,11 @@ class FeatureExtractorNode : public runtime::Object { /*! \brief Virtual destructor. */ virtual ~FeatureExtractorNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + // No fields to register + } + + static constexpr const bool _type_has_method_visit_attrs = false; /*! * \brief Extract features from the given measure candidate. @@ -76,11 +81,13 @@ class PyFeatureExtractorNode : public FeatureExtractorNode { /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_extract_from` is not visited - // `f_as_string` is not visited + static void RegisterReflection() { + // `f_extract_from` is not registered + // `f_as_string` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + Array ExtractFrom(const TuneContext& context, const Array& candidates) final; diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index 3a3d83cbf996..c74ce9ec8fc1 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -30,6 +30,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -42,7 +43,11 @@ class MeasureCallbackNode : public runtime::Object { /*! \brief Virtual destructor. */ virtual ~MeasureCallbackNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + // No fields to register + } + + static constexpr const bool _type_has_method_visit_attrs = false; /*! * \brief Apply a measure callback rule with given arguments. @@ -90,11 +95,13 @@ class PyMeasureCallbackNode : public MeasureCallbackNode { /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_apply` is not visited - // `f_as_string` is not visited + static void RegisterReflection() { + // `f_apply` is not registered + // `f_as_string` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + void Apply(const TaskScheduler& task_scheduler, // int task_id, // const Array& measure_candidates, // diff --git a/include/tvm/meta_schedule/measure_candidate.h b/include/tvm/meta_schedule/measure_candidate.h index 9bfc9d0da954..a2dbb6f943dd 100644 --- a/include/tvm/meta_schedule/measure_candidate.h +++ b/include/tvm/meta_schedule/measure_candidate.h @@ -25,6 +25,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -37,11 +38,15 @@ class MeasureCandidateNode : public runtime::Object { /*! \brief The argument information, e.g., (shape, dtype) for tensors. */ Array args_info; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("sch", &sch); - v->Visit("args_info", &args_info); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("sch", &MeasureCandidateNode::sch) + .def_ro("args_info", &MeasureCandidateNode::args_info); } + static constexpr const bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.MeasureCandidate"; TVM_DECLARE_FINAL_OBJECT_INFO(MeasureCandidateNode, Object); }; diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 0f8e446784f3..0e2dc066a06e 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -27,6 +27,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -40,7 +41,11 @@ class MutatorNode : public runtime::Object { /*! \brief Virtual destructor. */ virtual ~MutatorNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + // No fields to register + } + + static constexpr const bool _type_has_method_visit_attrs = false; /*! * \brief Initialize the design space generator with tuning context. @@ -157,13 +162,15 @@ class PyMutatorNode : public MutatorNode { /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_apply` is not visited - // `f_clone` is not visited - // `f_as_string` is not visited + static void RegisterReflection() { + // `f_initialize_with_tune_context` is not registered + // `f_apply` is not registered + // `f_clone` is not registered + // `f_as_string` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + void InitializeWithTuneContext(const TuneContext& context) final; Optional Apply(const tir::Trace& trace, support::LinearCongruentialEngine::TRandState* rand_state) final; diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index e8648f038e61..fded08949d89 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -24,6 +24,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -39,7 +40,11 @@ class PostprocNode : public runtime::Object { /*! \brief Virtual destructor. */ virtual ~PostprocNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + // No fields to register + } + + static constexpr const bool _type_has_method_visit_attrs = false; /*! * \brief Initialize the design space generator with tuning context. @@ -190,13 +195,15 @@ class PyPostprocNode : public PostprocNode { /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_apply` is not visited - // `f_clone` is not visited - // `f_as_string` is not visited + static void RegisterReflection() { + // `f_initialize_with_tune_context` is not registered + // `f_apply` is not registered + // `f_clone` is not registered + // `f_as_string` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + void InitializeWithTuneContext(const TuneContext& context) final; bool Apply(const tir::Schedule& sch) final; Postproc Clone() const final; diff --git a/include/tvm/meta_schedule/profiler.h b/include/tvm/meta_schedule/profiler.h index 6f8072b3f367..6ea64e51abf4 100644 --- a/include/tvm/meta_schedule/profiler.h +++ b/include/tvm/meta_schedule/profiler.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -59,11 +60,13 @@ class ProfilerNode : public runtime::Object { /*! \brief Counter for the total time used */ ffi::Function total_timer; - void VisitAttrs(tvm::AttrVisitor* v) { - // `stats_sec` is not visited. - // `total_timer` is not visited. + static void RegisterReflection() { + // `stats_sec` is not registered + // `total_timer` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.Profiler"; TVM_DECLARE_FINAL_OBJECT_INFO(ProfilerNode, runtime::Object); diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index c8331a3a60e3..c1a4fb84669a 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -27,6 +27,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -41,13 +42,16 @@ class RunnerInputNode : public runtime::Object { /*! \brief The argument information. */ Array args_info; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("artifact_path", &artifact_path); - v->Visit("device_type", &device_type); - v->Visit("args_info", &args_info); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("artifact_path", &RunnerInputNode::artifact_path) + .def_ro("device_type", &RunnerInputNode::device_type) + .def_ro("args_info", &RunnerInputNode::args_info); } static constexpr const char* _type_key = "meta_schedule.RunnerInput"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(RunnerInputNode, runtime::Object); }; @@ -75,12 +79,15 @@ class RunnerResultNode : public runtime::Object { /*! \brief The error message, if any. */ Optional error_msg; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("run_secs", &run_secs); - v->Visit("error_msg", &error_msg); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("run_secs", &RunnerResultNode::run_secs) + .def_ro("error_msg", &RunnerResultNode::error_msg); } static constexpr const char* _type_key = "meta_schedule.RunnerResult"; + static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(RunnerResultNode, runtime::Object); }; @@ -122,11 +129,13 @@ class RunnerFutureNode : public runtime::Object { /*! \brief The packed function to fetch runner output if it is ready. */ FResult f_result; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_done` is not visited - // `f_result` is not visited + static void RegisterReflection() { + // `f_done` is not registered + // `f_result` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + /*! * \brief Check whether the runner has finished. * \return A boolean indicating whether the runner has finished. @@ -215,10 +224,12 @@ class PyRunnerNode : public RunnerNode { /*! \brief The packed function to run the built artifacts and get runner futures. */ FRun f_run; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_run` is not visited + static void RegisterReflection() { + // `f_run` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + Array Run(Array runner_inputs) final { ICHECK(f_run != nullptr) << "PyRunner's Run method not implemented!"; return f_run(runner_inputs); diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 1a759c1b50fc..4f15d3d74dd8 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -29,6 +29,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -42,7 +43,11 @@ class ScheduleRuleNode : public runtime::Object { /*! \brief Virtual destructor. */ virtual ~ScheduleRuleNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + // No fields to register + } + + static constexpr const bool _type_has_method_visit_attrs = false; /*! * \brief Initialize the design space generator with tuning context. @@ -320,13 +325,15 @@ class PyScheduleRuleNode : public ScheduleRuleNode { /*! \brief The packed function to the `Clone` function. */ FClone f_clone; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_apply` is not visited - // `f_as_string` is not visited - // `f_clone` is not visited + static void RegisterReflection() { + // `f_initialize_with_tune_context` is not registered + // `f_apply` is not registered + // `f_as_string` is not registered + // `f_clone` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + void InitializeWithTuneContext(const TuneContext& context) final; Array Apply(const tir::Schedule& sch, const tir::BlockRV& block) final; ScheduleRule Clone() const final; diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index c0b4677f84b5..923abd18d24c 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -30,6 +30,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -241,15 +242,17 @@ class PySearchStrategyNode : public SearchStrategyNode { /*! \brief The packed function to the `Clone` method. */ FClone f_clone; - void VisitAttrs(tvm::AttrVisitor* v) { - // `f_initialize_with_tune_context` is not visited - // `f_pre_tuning` is not visited - // `f_post_tuning` is not visited - // `f_generate_measure_candidates` is not visited - // `f_notify_runner_results` is not visited - // `f_clone` is not visited + static void RegisterReflection() { + // `f_initialize_with_tune_context` is not registered + // `f_pre_tuning` is not registered + // `f_post_tuning` is not registered + // `f_generate_measure_candidates` is not registered + // `f_notify_runner_results` is not registered + // `f_clone` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + void InitializeWithTuneContext(const TuneContext& context) final; void PreTuning(int max_trials, int num_trials_per_iter, const Array& design_spaces, const Optional& database, const Optional& cost_model) final; diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 4ba3c0b089fc..5ce6c9473d88 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -29,6 +29,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -82,12 +83,16 @@ class SpaceGeneratorNode : public runtime::Object { /*! \brief The probability of using certain mutator. */ Optional> mutator_probs; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("sch_rules", &sch_rules); - v->Visit("postprocs", &postprocs); - v->Visit("mutator_probs", &mutator_probs); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("sch_rules", &SpaceGeneratorNode::sch_rules) + .def_ro("postprocs", &SpaceGeneratorNode::postprocs) + .def_ro("mutator_probs", &SpaceGeneratorNode::mutator_probs); } + static constexpr const bool _type_has_method_visit_attrs = false; + /*! \brief Default destructor */ virtual ~SpaceGeneratorNode() = default; @@ -212,13 +217,14 @@ class PySpaceGeneratorNode : public SpaceGeneratorNode { /*! \brief The packed function to the `Clone` function. */ FClone f_clone; - void VisitAttrs(tvm::AttrVisitor* v) { - SpaceGeneratorNode::VisitAttrs(v); - // `f_initialize_with_tune_context` is not visited - // `f_generate_design_space` is not visited - // `f_clone` is not visited + static void RegisterReflection() { + // `f_initialize_with_tune_context` is not registered + // `f_generate_design_space` is not registered + // `f_clone` is not registered } + static constexpr const bool _type_has_method_visit_attrs = false; + void InitializeWithTuneContext(const TuneContext& context) final; Array GenerateDesignSpace(const IRModule& mod) final; SpaceGenerator Clone() const final; diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index 7bf36873b3ce..d88a2cc00c72 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -60,19 +61,22 @@ class TaskRecordNode : public runtime::Object { /*! \brief Packed functions to fetch the runner results asynchronously. */ Optional> runner_futures = std::nullopt; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("ctx", &ctx); - v->Visit("task_weight", &task_weight); - v->Visit("flop", &flop); - v->Visit("is_terminated", &is_terminated); - v->Visit("build_error_count", &build_error_count); - v->Visit("run_error_count", &run_error_count); - // `latency_ms` is not visited - v->Visit("measure_candidates", &measure_candidates); - v->Visit("builder_results", &builder_results); - v->Visit("runner_futures", &runner_futures); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("ctx", &TaskRecordNode::ctx) + .def_ro("task_weight", &TaskRecordNode::task_weight) + .def_ro("flop", &TaskRecordNode::flop) + .def_ro("is_terminated", &TaskRecordNode::is_terminated) + .def_ro("build_error_count", &TaskRecordNode::build_error_count) + .def_ro("run_error_count", &TaskRecordNode::run_error_count) + .def_ro("measure_candidates", &TaskRecordNode::measure_candidates) + .def_ro("builder_results", &TaskRecordNode::builder_results) + .def_ro("runner_futures", &TaskRecordNode::runner_futures); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.TaskRecord"; TVM_DECLARE_FINAL_OBJECT_INFO(TaskRecordNode, Object); }; @@ -143,15 +147,18 @@ class TaskSchedulerNode : public runtime::Object { /*! \brief The default destructor. */ virtual ~TaskSchedulerNode() = default; - void VisitAttrs(tvm::AttrVisitor* v) { - // `logger` is not visited - v->Visit("tasks_", &tasks_); - v->Visit("measure_callbacks_", &measure_callbacks_); - v->Visit("database_", &database_); - v->Visit("cost_model_", &cost_model_); - v->Visit("remaining_tasks_", &remaining_tasks_); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("tasks_", &TaskSchedulerNode::tasks_) + .def_ro("measure_callbacks_", &TaskSchedulerNode::measure_callbacks_) + .def_ro("database_", &TaskSchedulerNode::database_) + .def_ro("cost_model_", &TaskSchedulerNode::cost_model_) + .def_ro("remaining_tasks_", &TaskSchedulerNode::remaining_tasks_); } + static constexpr bool _type_has_method_visit_attrs = false; + /*! * \brief Fetch the next task id. * \return The next task id. @@ -237,13 +244,13 @@ class PyTaskSchedulerNode : public TaskSchedulerNode { /*! \brief The packed function to the `Tune` function. */ FTune f_tune; - void VisitAttrs(tvm::AttrVisitor* v) { - TaskSchedulerNode::VisitAttrs(v); - // `f_next_task_id` is not visited - // `f_join_running_task` is not visited - // `f_tune` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } + static constexpr bool _type_has_method_visit_attrs = false; + int NextTaskId() final; Array JoinRunningTask(int task_id) final; void Tune(Array tasks, Array task_weights, int max_trials_global, diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 9045d4188ac1..1742f9424523 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -34,6 +34,7 @@ #include #include #include +#include namespace tvm { namespace meta_schedule { @@ -64,16 +65,21 @@ class TuneContextNode : public runtime::Object { /*! \brief The tuning task's logging function. t*/ ffi::Function logger; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("mod", &mod); - v->Visit("target", &target); - v->Visit("space_generator", &space_generator); - v->Visit("search_strategy", &search_strategy); - v->Visit("task_name", &task_name); - v->Visit("num_threads", &num_threads); - v->Visit("rand_state", &rand_state); - // `logger` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("mod", &TuneContextNode::mod) + .def_ro("target", &TuneContextNode::target) + .def_ro("space_generator", &TuneContextNode::space_generator) + .def_ro("search_strategy", &TuneContextNode::search_strategy) + .def_ro("task_name", &TuneContextNode::task_name) + .def_ro("num_threads", &TuneContextNode::num_threads) + .def_ro("rand_state", &TuneContextNode::rand_state); + // `logger` is not registered } + + static constexpr const bool _type_has_method_visit_attrs = false; + /*! * \brief Initialize members that needs initialization with tune context. */ diff --git a/src/meta_schedule/builder/builder.cc b/src/meta_schedule/builder/builder.cc index 85e189e73228..68c5f4c9c1ab 100644 --- a/src/meta_schedule/builder/builder.cc +++ b/src/meta_schedule/builder/builder.cc @@ -47,6 +47,12 @@ Builder Builder::PyBuilder(BuilderNode::FBuild f_build) { /******** FFI ********/ +TVM_FFI_STATIC_INIT_BLOCK({ + BuilderInputNode::RegisterReflection(); + BuilderResultNode::RegisterReflection(); + PyBuilderNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(BuilderInputNode); TVM_REGISTER_NODE_TYPE(BuilderResultNode); TVM_REGISTER_OBJECT_TYPE(BuilderNode); diff --git a/src/meta_schedule/database/database.cc b/src/meta_schedule/database/database.cc index 034294eedcd3..a9c04409c530 100644 --- a/src/meta_schedule/database/database.cc +++ b/src/meta_schedule/database/database.cc @@ -278,6 +278,12 @@ Database Database::PyDatabase(PyDatabaseNode::FHasWorkload f_has_workload, /******** FFI ********/ +TVM_FFI_STATIC_INIT_BLOCK({ + WorkloadNode::RegisterReflection(); + TuningRecordNode::RegisterReflection(); + PyDatabaseNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(WorkloadNode); TVM_REGISTER_NODE_TYPE(TuningRecordNode); TVM_REGISTER_OBJECT_TYPE(DatabaseNode); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index 2a6b93f8cb3b..c25853eb7004 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -22,6 +22,7 @@ #include "../module_equality.h" #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -81,12 +82,13 @@ class JSONDatabaseNode : public DatabaseNode { /*! \brief All the tuning records in the database */ std::multiset tuning_records_; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("path_workload", &path_workload); - v->Visit("path_tuning_record", &path_tuning_record); - // `workloads2idx_` is not visited - // `tuning_records_` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("path_workload", &JSONDatabaseNode::path_workload) + .def_ro("path_tuning_record", &JSONDatabaseNode::path_tuning_record); } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.JSONDatabase"; TVM_DECLARE_FINAL_OBJECT_INFO(JSONDatabaseNode, DatabaseNode); @@ -213,6 +215,9 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, return Database(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + JSONDatabaseNode::RegisterReflection(); +}); TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase") .set_body_typed(Database::JSONDatabase); diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index cbc811752cad..8a7a2da09b17 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -18,6 +18,7 @@ */ #include "../module_equality.h" #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -29,11 +30,13 @@ class MemoryDatabaseNode : public DatabaseNode { Array records; Array workloads; - void VisitAttrs(AttrVisitor* v) { - v->Visit("records", &records); - v->Visit("workloads", &workloads); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("records", &MemoryDatabaseNode::records) + .def_ro("workloads", &MemoryDatabaseNode::workloads); } - + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.MemoryDatabase"; TVM_DECLARE_FINAL_OBJECT_INFO(MemoryDatabaseNode, DatabaseNode); @@ -100,5 +103,9 @@ TVM_REGISTER_NODE_TYPE(MemoryDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase") .set_body_typed(Database::MemoryDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + MemoryDatabaseNode::RegisterReflection(); +}); + } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/ordered_union_database.cc b/src/meta_schedule/database/ordered_union_database.cc index 87f5c03a71eb..99aecf5d8632 100644 --- a/src/meta_schedule/database/ordered_union_database.cc +++ b/src/meta_schedule/database/ordered_union_database.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -25,8 +26,12 @@ class OrderedUnionDatabaseNode : public DatabaseNode { public: Array databases; - void VisitAttrs(AttrVisitor* v) { v->Visit("databases", &databases); } - + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("databases", &OrderedUnionDatabaseNode::databases); + } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.OrderedUnionDatabase"; TVM_DECLARE_FINAL_OBJECT_INFO(OrderedUnionDatabaseNode, DatabaseNode); @@ -82,5 +87,9 @@ TVM_REGISTER_NODE_TYPE(OrderedUnionDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseOrderedUnionDatabase") .set_body_typed(Database::OrderedUnionDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + OrderedUnionDatabaseNode::RegisterReflection(); +}); + } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc index c66ec5f4f0c1..2235d7768209 100644 --- a/src/meta_schedule/database/schedule_fn_database.cc +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -27,10 +28,12 @@ class ScheduleFnDatabaseNode : public DatabaseNode { ffi::TypedFunction schedule_fn; - void VisitAttrs(AttrVisitor* v) { - // `schedule_fn` is not visited. + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("schedule_fn", &ScheduleFnDatabaseNode::schedule_fn); } - + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.ScheduleFnDatabase"; TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleFnDatabaseNode, DatabaseNode); @@ -102,5 +105,9 @@ TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase") .set_body_typed(Database::ScheduleFnDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + ScheduleFnDatabaseNode::RegisterReflection(); +}); + } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/union_database.cc b/src/meta_schedule/database/union_database.cc index 2bc82b459cad..4b843aab2b8e 100644 --- a/src/meta_schedule/database/union_database.cc +++ b/src/meta_schedule/database/union_database.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -25,8 +26,12 @@ class UnionDatabaseNode : public DatabaseNode { public: Array databases; - void VisitAttrs(AttrVisitor* v) { v->Visit("databases", &databases); } - + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("databases", &UnionDatabaseNode::databases); + } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.UnionDatabase"; TVM_DECLARE_FINAL_OBJECT_INFO(UnionDatabaseNode, DatabaseNode); @@ -85,5 +90,9 @@ TVM_REGISTER_NODE_TYPE(UnionDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseUnionDatabase") .set_body_typed(Database::UnionDatabase); +TVM_FFI_STATIC_INIT_BLOCK({ + UnionDatabaseNode::RegisterReflection(); +}); + } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index fb26e6eb693c..acfb29b8de30 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -38,6 +38,10 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, data_ = n; } +TVM_FFI_STATIC_INIT_BLOCK({ + ExtractedTaskNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ExtractedTask") .set_body_typed([](String task_name, IRModule mod, Target target, Array dispatched, diff --git a/src/meta_schedule/feature_extractor/feature_extractor.cc b/src/meta_schedule/feature_extractor/feature_extractor.cc index 9a3cecf4ce26..eda856d1bdcf 100644 --- a/src/meta_schedule/feature_extractor/feature_extractor.cc +++ b/src/meta_schedule/feature_extractor/feature_extractor.cc @@ -45,6 +45,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + FeatureExtractorNode::RegisterReflection(); + PyFeatureExtractorNode::RegisterReflection(); +}); + TVM_REGISTER_OBJECT_TYPE(FeatureExtractorNode); TVM_REGISTER_NODE_TYPE(PyFeatureExtractorNode); diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 2fc8878546d8..9fdb9c9adc16 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -1367,12 +1368,16 @@ class PerStoreFeatureNode : public FeatureExtractorNode { bool extract_workload; int feature_vector_length; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("buffers_per_store", &buffers_per_store); - v->Visit("arith_intensity_curve_num_samples", &arith_intensity_curve_num_samples); - v->Visit("cache_line_bytes", &cache_line_bytes); - v->Visit("feature_vector_length", &feature_vector_length); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("buffers_per_store", &PerStoreFeatureNode::buffers_per_store) + .def_ro("arith_intensity_curve_num_samples", &PerStoreFeatureNode::arith_intensity_curve_num_samples) + .def_ro("cache_line_bytes", &PerStoreFeatureNode::cache_line_bytes) + .def_ro("extract_workload", &PerStoreFeatureNode::extract_workload) + .def_ro("feature_vector_length", &PerStoreFeatureNode::feature_vector_length); } + static constexpr bool _type_has_method_visit_attrs = false; void ExtractSingle(IRModule mod, bool is_gpu, std::vector>* results) { static transform::Sequential passes = tir::transform::PassListForPerStoreFeature(); @@ -1441,6 +1446,10 @@ FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, return FeatureExtractor(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + PerStoreFeatureNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature") .set_body_typed(FeatureExtractor::PerStoreFeature); diff --git a/src/meta_schedule/measure_callback/measure_callback.cc b/src/meta_schedule/measure_callback/measure_callback.cc index 0ee49f2ab4f9..76e5fcf7276c 100644 --- a/src/meta_schedule/measure_callback/measure_callback.cc +++ b/src/meta_schedule/measure_callback/measure_callback.cc @@ -56,6 +56,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + MeasureCallbackNode::RegisterReflection(); + PyMeasureCallbackNode::RegisterReflection(); +}); + TVM_REGISTER_OBJECT_TYPE(MeasureCallbackNode); TVM_REGISTER_NODE_TYPE(PyMeasureCallbackNode); diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index 8f8c077aa815..959b5a52850a 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -31,7 +32,13 @@ class MutateComputeLocationNode : public MutatorNode { /*! \brief JSON representation of the workload */ std::string json_mod_; - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.MutateComputeLocation"; TVM_DECLARE_FINAL_OBJECT_INFO(MutateComputeLocationNode, MutatorNode); @@ -126,6 +133,10 @@ Mutator Mutator::MutateComputeLocation() { return Mutator(make_object()); } +TVM_FFI_STATIC_INIT_BLOCK({ + MutateComputeLocationNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") .set_body_typed(Mutator::MutateComputeLocation); diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index a6a34e47a9d9..8d7e8884661c 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -20,6 +20,7 @@ #include #include "../utils.h" +#include namespace tvm { namespace tir { @@ -169,12 +170,12 @@ class MutateParallelNode : public MutatorNode { /*! \brief JSON representation of the workload */ std::string json_mod_; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("max_jobs_per_core", &max_jobs_per_core); - // `max_parallel_extent_` is not visited. - // `json_mod` is not visited. + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("max_jobs_per_core", &MutateParallelNode::max_jobs_per_core); } - + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.MutateParallel"; TVM_DECLARE_FINAL_OBJECT_INFO(MutateParallelNode, MutatorNode); @@ -311,6 +312,9 @@ Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { return Mutator(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + MutateParallelNode::RegisterReflection(); +}); TVM_REGISTER_NODE_TYPE(MutateParallelNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel") .set_body_typed(Mutator::MutateParallel); diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index 269b05240443..f62658ff79fd 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -31,7 +32,13 @@ class MutateThreadBindingNode : public MutatorNode { /*! \brief JSON representation of the workload */ std::string json_mod_; - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.MutateThreadBinding"; TVM_DECLARE_FINAL_OBJECT_INFO(MutateThreadBindingNode, MutatorNode); @@ -164,6 +171,10 @@ Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* r Mutator Mutator::MutateThreadBinding() { return Mutator(make_object()); } +TVM_FFI_STATIC_INIT_BLOCK({ + MutateThreadBindingNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(MutateThreadBindingNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding") .set_body_typed(Mutator::MutateThreadBinding); diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index e8a728d05033..50cfe89a24df 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -20,6 +20,7 @@ #include #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -54,7 +55,13 @@ int64_t Product(const std::vector& array) { /*! \brief A mutator that mutates the tile size */ class MutateTileSizeNode : public MutatorNode { public: - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.MutateTileSize"; TVM_DECLARE_FINAL_OBJECT_INFO(MutateTileSizeNode, MutatorNode); @@ -268,6 +275,10 @@ Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s Mutator Mutator::MutateTileSize() { return Mutator(make_object()); } +TVM_FFI_STATIC_INIT_BLOCK({ + MutateTileSizeNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(MutateTileSizeNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize") .set_body_typed(Mutator::MutateTileSize); diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 28fcf3668f27..812af8c447f8 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace tir { @@ -50,7 +51,13 @@ using tir::Trace; /*! \brief Create a Mutator that mutates auto unroll step */ class MutateUnrollNode : public MutatorNode { public: - void VisitAttrs(tvm::AttrVisitor* v) {} + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.MutateUnroll"; TVM_DECLARE_FINAL_OBJECT_INFO(MutateUnrollNode, MutatorNode); @@ -137,6 +144,10 @@ Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_sta Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } +TVM_FFI_STATIC_INIT_BLOCK({ + MutateUnrollNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(MutateUnrollNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); diff --git a/src/meta_schedule/mutator/mutator.cc b/src/meta_schedule/mutator/mutator.cc index e415b3909f10..5f2ccf24fe77 100644 --- a/src/meta_schedule/mutator/mutator.cc +++ b/src/meta_schedule/mutator/mutator.cc @@ -85,6 +85,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + MutatorNode::RegisterReflection(); + PyMutatorNode::RegisterReflection(); +}); + TVM_REGISTER_OBJECT_TYPE(MutatorNode); TVM_REGISTER_NODE_TYPE(PyMutatorNode); diff --git a/src/meta_schedule/postproc/postproc.cc b/src/meta_schedule/postproc/postproc.cc index e29f9dd54c5a..8434cbf808e8 100644 --- a/src/meta_schedule/postproc/postproc.cc +++ b/src/meta_schedule/postproc/postproc.cc @@ -109,6 +109,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + PostprocNode::RegisterReflection(); + PyPostprocNode::RegisterReflection(); +}); + TVM_REGISTER_OBJECT_TYPE(PostprocNode); TVM_REGISTER_NODE_TYPE(PyPostprocNode); diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index d23e07795cad..e06b983a6a5f 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace tir { @@ -115,6 +116,12 @@ namespace meta_schedule { */ class RewriteCooperativeFetchNode : public PostprocNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + static constexpr bool _type_has_method_visit_attrs = false; + // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { if (Optional v = context->target.value()->GetAttr("thread_warp_size")) { @@ -132,8 +139,6 @@ class RewriteCooperativeFetchNode : public PostprocNode { return Postproc(n); } - void VisitAttrs(tvm::AttrVisitor* v) {} - static constexpr const char* _type_key = "meta_schedule.RewriteCooperativeFetch"; TVM_DECLARE_FINAL_OBJECT_INFO(RewriteCooperativeFetchNode, PostprocNode); @@ -226,6 +231,10 @@ Postproc Postproc::RewriteCooperativeFetch() { return Postproc(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + RewriteCooperativeFetchNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") .set_body_typed(Postproc::RewriteCooperativeFetch); diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index 3ffe0f9234d2..42570a595f80 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace tir { @@ -109,6 +110,12 @@ namespace meta_schedule { /*! \brief Rewrite reduction block by moving the init block out */ class RewriteReductionBlockNode : public PostprocNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + static constexpr bool _type_has_method_visit_attrs = false; + // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final {} // Inherited from PostprocNode @@ -119,8 +126,6 @@ class RewriteReductionBlockNode : public PostprocNode { return Postproc(n); } - void VisitAttrs(tvm::AttrVisitor* v) {} - static constexpr const char* _type_key = "meta_schedule.RewriteReductionBlock"; TVM_DECLARE_FINAL_OBJECT_INFO(RewriteReductionBlockNode, PostprocNode); }; @@ -175,5 +180,9 @@ TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") .set_body_typed(Postproc::RewriteReductionBlock); +TVM_FFI_STATIC_INIT_BLOCK({ + RewriteReductionBlockNode::RegisterReflection(); +}); + } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 0f98484dd44e..69f5fa2ff8dd 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include @@ -62,12 +63,16 @@ void CollectTensorizationJobs( class RewriteTensorizeNode : public PostprocNode { public: + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + static constexpr bool _type_has_method_visit_attrs = false; + void InitializeWithTuneContext(const TuneContext& context) final {} bool Apply(const tir::Schedule& sch) final; - void VisitAttrs(tvm::AttrVisitor* v) {} - Postproc Clone() const { ObjectPtr n = make_object(*this); return Postproc(n); @@ -106,6 +111,9 @@ Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { return Postproc(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + RewriteTensorizeNode::RegisterReflection(); +}); TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize") .set_body_typed(Postproc::RewriteTensorize); diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index a2c9d1364ab6..cdc9ff12db9a 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include "../utils.h" @@ -85,6 +86,7 @@ namespace meta_schedule { /*! \brief Add thread binding to unbound blocks */ class RewriteUnboundBlockNode : public PostprocNode { public: + // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { CHECK(context->target.defined()) << "ValueError: target is not defined"; @@ -109,10 +111,11 @@ class RewriteUnboundBlockNode : public PostprocNode { /*! \brief The max number of threadblocks in the cuda device */ int max_threadblocks_ = -1; - void VisitAttrs(tvm::AttrVisitor* v) { - // `max_threads_per_block_` is not visited - // `max_threadblocks_` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.RewriteUnboundBlock"; TVM_DECLARE_FINAL_OBJECT_INFO(RewriteUnboundBlockNode, PostprocNode); @@ -145,6 +148,10 @@ Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) { return Postproc(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + RewriteUnboundBlockNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") .set_body_typed(Postproc::RewriteUnboundBlock); diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc index 2a034a7be297..ca01e1003f76 100644 --- a/src/meta_schedule/profiler.cc +++ b/src/meta_schedule/profiler.cc @@ -120,6 +120,10 @@ Optional Profiler::Current() { } } +TVM_FFI_STATIC_INIT_BLOCK({ + ProfilerNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(ProfilerNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler { return Profiler(); diff --git a/src/meta_schedule/runner/runner.cc b/src/meta_schedule/runner/runner.cc index 38d4225f0fbd..009a2786a983 100644 --- a/src/meta_schedule/runner/runner.cc +++ b/src/meta_schedule/runner/runner.cc @@ -51,6 +51,13 @@ Runner Runner::PyRunner(Runner::FRun f_run) { /******** FFI ********/ +TVM_FFI_STATIC_INIT_BLOCK({ + RunnerInputNode::RegisterReflection(); + RunnerResultNode::RegisterReflection(); + RunnerFutureNode::RegisterReflection(); + PyRunnerNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(RunnerInputNode); TVM_REGISTER_NODE_TYPE(RunnerResultNode); TVM_REGISTER_NODE_TYPE(RunnerFutureNode); diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index 48149ed871e4..2fb2a9c90d71 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include "../utils.h" namespace tvm { @@ -56,12 +57,13 @@ class AddRFactorNode : public ScheduleRuleNode { /*! \brief The number of cores. */ int max_parallel_basic_; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("max_jobs_per_core", &max_jobs_per_core); - v->Visit("max_innermost_factor", &max_innermost_factor); - // `max_parallel_extent_` is not visited - // `max_parallel_basic_` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("max_jobs_per_core", &AddRFactorNode::max_jobs_per_core) + .def_ro("max_innermost_factor", &AddRFactorNode::max_innermost_factor); } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.AddRFactor"; TVM_DECLARE_FINAL_OBJECT_INFO(AddRFactorNode, ScheduleRuleNode); @@ -119,6 +121,9 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: return res; } +TVM_FFI_STATIC_INIT_BLOCK({ + AddRFactorNode::RegisterReflection(); +}); TVM_REGISTER_NODE_TYPE(AddRFactorNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") .set_body_typed(ScheduleRule::AddRFactor); diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index 92de19163af5..7cc70dfe4733 100644 --- a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include "../utils.h" namespace tvm { @@ -71,8 +72,12 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { public: Optional target_ = std::nullopt; - void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("target_", &target_); } - + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("target_", &ApplyCustomRuleNode::target_); + } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.ApplyCustomRule"; TVM_DECLARE_FINAL_OBJECT_INFO(ApplyCustomRuleNode, ScheduleRuleNode); }; @@ -86,6 +91,9 @@ bool ScheduleRule::IsApplyCustomRule(const ScheduleRule& rule) { return rule->IsInstance(); } +TVM_FFI_STATIC_INIT_BLOCK({ + ApplyCustomRuleNode::RegisterReflection(); +}); TVM_REGISTER_NODE_TYPE(ApplyCustomRuleNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApplyCustomRule") .set_body_typed(ScheduleRule::ApplyCustomRule); diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index 892a79ea926d..ddb92273da74 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include @@ -55,12 +56,11 @@ class AutoBindNode : public ScheduleRuleNode { /*! \brief thread_extents Candidates of thread axis extent. */ Array thread_extents_; - void VisitAttrs(tvm::AttrVisitor* v) { - // `max_threads_per_block_` is not visited - // `max_threadblocks_` is not visited - // `thread_extents_` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); } - + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.AutoBind"; TVM_DECLARE_FINAL_OBJECT_INFO(AutoBindNode, ScheduleRuleNode); }; @@ -81,6 +81,10 @@ ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_ return ScheduleRule(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + AutoBindNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(AutoBindNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind") .set_body_typed(ScheduleRule::AutoBind); diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index 948632e580e6..b30a82eb06b0 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include "../utils.h" namespace tvm { @@ -82,16 +83,18 @@ class AutoInlineNode : public ScheduleRuleNode { /*! \brief The operators that are disallowed in auto inline */ Array disallow_op; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("into_producer", &into_producer); - v->Visit("into_consumer", &into_consumer); - v->Visit("inline_const_tensor", &inline_const_tensor); - v->Visit("disallow_if_then_else", &disallow_if_then_else); - v->Visit("require_injective", &require_injective); - v->Visit("require_ordered", &require_ordered); - v->Visit("disallow_op", &disallow_op); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("into_producer", &AutoInlineNode::into_producer) + .def_ro("into_consumer", &AutoInlineNode::into_consumer) + .def_ro("inline_const_tensor", &AutoInlineNode::inline_const_tensor) + .def_ro("disallow_if_then_else", &AutoInlineNode::disallow_if_then_else) + .def_ro("require_injective", &AutoInlineNode::require_injective) + .def_ro("require_ordered", &AutoInlineNode::require_ordered) + .def_ro("disallow_op", &AutoInlineNode::disallow_op); } - + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.AutoInline"; TVM_DECLARE_FINAL_OBJECT_INFO(AutoInlineNode, ScheduleRuleNode); }; @@ -190,6 +193,9 @@ ScheduleRule ScheduleRule::AutoInline(bool into_producer, // return ScheduleRule(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + AutoInlineNode::RegisterReflection(); +}); TVM_REGISTER_NODE_TYPE(AutoInlineNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") .set_body_typed(ScheduleRule::AutoInline); @@ -222,6 +228,12 @@ class InlineConstantScalarsNode : public ScheduleRuleNode { return ScheduleRule(n); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars"; TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode); }; @@ -231,6 +243,9 @@ ScheduleRule ScheduleRule::InlineConstantScalars() { return ScheduleRule(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + InlineConstantScalarsNode::RegisterReflection(); +}); TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") .set_body_typed(ScheduleRule::InlineConstantScalars); diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index e06817e37c4c..f418cef8346c 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include "../utils.h" namespace tvm { @@ -271,12 +272,14 @@ class CrossThreadReductionNode : public ScheduleRuleNode { /*! \brief Candidates of thread axis extent (values are required to be positive). */ Array thread_extents; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("max_threads_per_block", &max_threads_per_block); - v->Visit("warp_size", &warp_size); - v->Visit("thread_extents", &thread_extents); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("max_threads_per_block", &CrossThreadReductionNode::max_threads_per_block) + .def_ro("warp_size", &CrossThreadReductionNode::warp_size) + .def_ro("thread_extents", &CrossThreadReductionNode::thread_extents); } - + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.CrossThreadReduction"; TVM_DECLARE_FINAL_OBJECT_INFO(CrossThreadReductionNode, ScheduleRuleNode); }; @@ -290,6 +293,9 @@ ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { return ScheduleRule(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + CrossThreadReductionNode::RegisterReflection(); +}); TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction") .set_body_typed(ScheduleRule::CrossThreadReduction); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index f020c8efd08a..b702d04f45bc 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -54,6 +54,10 @@ using tir::IterVarType; using tir::LoopRV; using tir::Schedule; +TVM_FFI_STATIC_INIT_BLOCK({ + MultiLevelTilingNode::RegisterReflection(); +}); + TVM_REGISTER_OBJECT_TYPE(StateNode); State::State(tir::Schedule sch, tir::BlockRV block_rv, Array> tiles) { diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index b46eac23ad7e..025a24148a37 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -220,19 +221,16 @@ class MultiLevelTilingNode : public ScheduleRuleNode { /*! \brief The function to overwrite the default condition for applying MultiLevelTiling. */ Optional filter_fn_; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("structure", &structure); - v->Visit("tile_binds", &tile_binds); - v->Visit("max_innermost_factor", &max_innermost_factor); - // `vector_load_lens` is not visited - // `reuse_read_` is not visited - // `reuse_write_` is not visited - // `s_indices_` is not visited - // `r_indices_` is not visited - // `thread_warp_size_` is not visited - // `max_threads_per_block` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("structure", &MultiLevelTilingNode::structure) + .def_ro("tile_binds", &MultiLevelTilingNode::tile_binds) + .def_ro("max_innermost_factor", &MultiLevelTilingNode::max_innermost_factor); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.MultiLevelTiling"; TVM_DECLARE_BASE_OBJECT_INFO(MultiLevelTilingNode, ScheduleRuleNode); }; diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 905f8d8ce65f..44c217ea4969 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace tir { @@ -108,13 +109,15 @@ class ParallelizeVectorizeUnrollNode : public ScheduleRuleNode { /*! \brief The number of maximum available jobs in CPU. */ int64_t max_parallel_extent_; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("max_jobs_per_core", &max_jobs_per_core); - v->Visit("max_vectorize_extent", &max_vectorize_extent); - v->Visit("unroll_max_steps", &unroll_max_steps); - v->Visit("unroll_explicit", &unroll_explicit); - // `max_parallel_extent_` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("max_jobs_per_core", &ParallelizeVectorizeUnrollNode::max_jobs_per_core) + .def_ro("max_vectorize_extent", &ParallelizeVectorizeUnrollNode::max_vectorize_extent) + .def_ro("unroll_max_steps", &ParallelizeVectorizeUnrollNode::unroll_max_steps) + .def_ro("unroll_explicit", &ParallelizeVectorizeUnrollNode::unroll_explicit); } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.ParallelizeVectorizeUnroll"; TVM_DECLARE_FINAL_OBJECT_INFO(ParallelizeVectorizeUnrollNode, ScheduleRuleNode); @@ -133,6 +136,9 @@ ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, return ScheduleRule(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + ParallelizeVectorizeUnrollNode::RegisterReflection(); +}); TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll") .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll); diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index ed71baade06a..ce6fce57e816 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include "../utils.h" namespace tvm { @@ -111,8 +112,11 @@ class RandomComputeLocationNode : public ScheduleRuleNode { } public: - void VisitAttrs(tvm::AttrVisitor* v) {} - + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef(); + } + static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.RandomComputeLocation"; TVM_DECLARE_FINAL_OBJECT_INFO(RandomComputeLocationNode, ScheduleRuleNode); }; @@ -121,6 +125,10 @@ ScheduleRule ScheduleRule::RandomComputeLocation() { return ScheduleRule(make_object()); } +TVM_FFI_STATIC_INIT_BLOCK({ + RandomComputeLocationNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") .set_body_typed(ScheduleRule::RandomComputeLocation); diff --git a/src/meta_schedule/schedule_rule/schedule_rule.cc b/src/meta_schedule/schedule_rule/schedule_rule.cc index 3640694b4e5f..e72be72520e7 100644 --- a/src/meta_schedule/schedule_rule/schedule_rule.cc +++ b/src/meta_schedule/schedule_rule/schedule_rule.cc @@ -399,6 +399,11 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << f_as_string(); }); +TVM_FFI_STATIC_INIT_BLOCK({ + ScheduleRuleNode::RegisterReflection(); + PyScheduleRuleNode::RegisterReflection(); +}); + TVM_REGISTER_OBJECT_TYPE(ScheduleRuleNode); TVM_REGISTER_NODE_TYPE(PyScheduleRuleNode); diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 4872f3aa5f6e..5f0b9405431a 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -19,6 +19,7 @@ #include "../module_equality.h" #include "../utils.h" +#include #define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ @@ -378,26 +379,22 @@ class EvolutionarySearchNode : public SearchStrategyNode { /*! \brief The ratio of measurements to use randomly sampled states. */ double eps_greedy; - void VisitAttrs(tvm::AttrVisitor* v) { - // `context_` is not visited - // `rand_state_` is not visited - // `state_` is not visited - - /*** Configuration: global ***/ - v->Visit("population_size", &population_size); - v->Visit("num_empty_iters_before_early_stop", &num_empty_iters_before_early_stop); - /*** Configuration: the initial population ***/ - v->Visit("init_measured_ratio", &init_measured_ratio); - v->Visit("init_min_unmeasured", &init_min_unmeasured); - v->Visit("max_fail_count", &max_fail_count); - /*** Configuration: evolution ***/ - v->Visit("genetic_num_iters", &genetic_num_iters); - v->Visit("genetic_mutate_prob", &genetic_mutate_prob); - v->Visit("genetic_max_fail_count", &genetic_max_fail_count); - /*** Configuration: pick states for measurement ***/ - v->Visit("eps_greedy", &eps_greedy); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("population_size", &EvolutionarySearchNode::population_size) + .def_ro("num_empty_iters_before_early_stop", &EvolutionarySearchNode::num_empty_iters_before_early_stop) + .def_ro("init_measured_ratio", &EvolutionarySearchNode::init_measured_ratio) + .def_ro("init_min_unmeasured", &EvolutionarySearchNode::init_min_unmeasured) + .def_ro("max_fail_count", &EvolutionarySearchNode::max_fail_count) + .def_ro("genetic_num_iters", &EvolutionarySearchNode::genetic_num_iters) + .def_ro("genetic_mutate_prob", &EvolutionarySearchNode::genetic_mutate_prob) + .def_ro("genetic_max_fail_count", &EvolutionarySearchNode::genetic_max_fail_count) + .def_ro("eps_greedy", &EvolutionarySearchNode::eps_greedy); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.EvolutionarySearch"; TVM_DECLARE_FINAL_OBJECT_INFO(EvolutionarySearchNode, SearchStrategyNode); @@ -801,6 +798,10 @@ Array EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, return result; } +TVM_FFI_STATIC_INIT_BLOCK({ + EvolutionarySearchNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") .set_body_typed(SearchStrategy::EvolutionarySearch); diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 51cc40839195..5716816be7e1 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -60,7 +60,12 @@ class ReplayFuncNode : public SearchStrategyNode { /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; - void VisitAttrs(tvm::AttrVisitor* v) {} + + static void RegisterReflection() { + // No fields to register + } + + static constexpr const bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.ReplayFunc"; TVM_DECLARE_FINAL_OBJECT_INFO(ReplayFuncNode, SearchStrategyNode); @@ -156,6 +161,10 @@ SearchStrategy SearchStrategy::ReplayFunc() { return SearchStrategy(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + ReplayFuncNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(ReplayFuncNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") .set_body_typed(SearchStrategy::ReplayFunc); diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index c9a7459fdf61..1eaee10aec19 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -75,15 +75,14 @@ class ReplayTraceNode : public SearchStrategyNode { /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("max_fail_count", &max_fail_count); - // `rand_state_` is not visited - // `mod_` is not visited - // `num_threads_` is not visited - // `postprocs_` is not visited - // `state_` is not visited + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("max_fail_count", &ReplayTraceNode::max_fail_count); } + static constexpr const bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.ReplayTrace"; TVM_DECLARE_FINAL_OBJECT_INFO(ReplayTraceNode, SearchStrategyNode); @@ -190,6 +189,10 @@ SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) { return SearchStrategy(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + ReplayTraceNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(ReplayTraceNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") .set_body_typed(SearchStrategy::ReplayTrace); diff --git a/src/meta_schedule/search_strategy/search_strategy.cc b/src/meta_schedule/search_strategy/search_strategy.cc index 8fc6538b59f5..b1ebfd784951 100644 --- a/src/meta_schedule/search_strategy/search_strategy.cc +++ b/src/meta_schedule/search_strategy/search_strategy.cc @@ -82,6 +82,11 @@ SearchStrategy SearchStrategy::PySearchStrategy( return SearchStrategy(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + MeasureCandidateNode::RegisterReflection(); + PySearchStrategyNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(MeasureCandidateNode); TVM_REGISTER_OBJECT_TYPE(SearchStrategyNode); TVM_REGISTER_NODE_TYPE(PySearchStrategyNode); diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 91d5ba53d551..0a30ff09ac50 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -34,12 +35,12 @@ class PostOrderApplyNode : public SpaceGeneratorNode { /*! \brief The random state. -1 means using random number. */ TRandState rand_state_ = -1; - void VisitAttrs(tvm::AttrVisitor* v) { - SpaceGeneratorNode::VisitAttrs(v); - // `rand_state_` is not visited - // `sch_rules_` is not visited + static void RegisterReflection() { + // No fields to register } + static constexpr const bool _type_has_method_visit_attrs = false; + void InitializeWithTuneContext(const TuneContext& context) final { SpaceGeneratorNode::InitializeWithTuneContext(context); this->rand_state_ = ForkSeed(&context->rand_state); @@ -115,6 +116,10 @@ SpaceGenerator SpaceGenerator::PostOrderApply(ffi::Function f_block_filter, return SpaceGenerator(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + PostOrderApplyNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") .set_body_typed(SpaceGenerator::PostOrderApply); diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index f7f2a3ba19de..f533cc815913 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -29,11 +30,12 @@ class ScheduleFnNode : public SpaceGeneratorNode { /*! \brief The schedule function. */ ffi::Function schedule_fn_; - void VisitAttrs(tvm::AttrVisitor* v) { - SpaceGeneratorNode::VisitAttrs(v); - // `schedule_fn_` is not visited. + static void RegisterReflection() { + // `schedule_fn_` is not registered. } + static constexpr const bool _type_has_method_visit_attrs = false; + void InitializeWithTuneContext(const TuneContext& context) final { SpaceGeneratorNode::InitializeWithTuneContext(context); this->rand_state_ = ForkSeed(&context->rand_state); @@ -96,6 +98,10 @@ SpaceGenerator SpaceGenerator::ScheduleFn(ffi::Function schedule_fn, return SpaceGenerator(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + ScheduleFnNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(ScheduleFnNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn") .set_body_typed(SpaceGenerator::ScheduleFn); diff --git a/src/meta_schedule/space_generator/space_generator.cc b/src/meta_schedule/space_generator/space_generator.cc index 7306fffcb1af..bd94d6804f2c 100644 --- a/src/meta_schedule/space_generator/space_generator.cc +++ b/src/meta_schedule/space_generator/space_generator.cc @@ -187,6 +187,11 @@ SpaceGenerator SpaceGenerator::PySpaceGenerator( return SpaceGenerator(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + SpaceGeneratorNode::RegisterReflection(); + PySpaceGeneratorNode::RegisterReflection(); +}); + TVM_REGISTER_OBJECT_TYPE(SpaceGeneratorNode); TVM_REGISTER_NODE_TYPE(PySpaceGeneratorNode); diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 12bf75349430..5355010d1cd4 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -27,11 +28,14 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { /*! \brief The array of design space generators unioned, could be recursive. */ Array space_generators; - void VisitAttrs(tvm::AttrVisitor* v) { - SpaceGeneratorNode::VisitAttrs(v); - v->Visit("space_generators", &space_generators); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("space_generators", &SpaceGeneratorUnionNode::space_generators); } + static constexpr const bool _type_has_method_visit_attrs = false; + void InitializeWithTuneContext(const TuneContext& context) final { SpaceGeneratorNode::InitializeWithTuneContext(context); for (const SpaceGenerator& space_generator : space_generators) { @@ -81,6 +85,10 @@ SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_g return SpaceGenerator(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + SpaceGeneratorUnionNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(SpaceGeneratorUnionNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion") .set_body_typed(SpaceGenerator::SpaceGeneratorUnion); diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index 23d23e624394..667166ec3845 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -17,6 +17,7 @@ * under the License. */ #include "../utils.h" +#include namespace tvm { namespace meta_schedule { @@ -31,15 +32,15 @@ class GradientBasedNode final : public TaskSchedulerNode { int round_robin_rounds_; std::vector> best_latency_history_; - void VisitAttrs(tvm::AttrVisitor* v) { - TaskSchedulerNode::VisitAttrs(v); - v->Visit("alpha", &alpha); - v->Visit("window_size", &window_size); - // `rand_state` is not visited. - // `num_rounds_already_` is not visited. - // `best_latency_history_` is not visited. + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("alpha", &GradientBasedNode::alpha) + .def_ro("window_size", &GradientBasedNode::window_size); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.GradientBased"; TVM_DECLARE_FINAL_OBJECT_INFO(GradientBasedNode, TaskSchedulerNode); @@ -144,6 +145,10 @@ TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, i return TaskScheduler(n); } +TVM_FFI_STATIC_INIT_BLOCK({ + GradientBasedNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(GradientBasedNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased") .set_body_typed(TaskScheduler::GradientBased); diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 9792fa7e7c25..4b48e8c8a582 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include "../utils.h" namespace tvm { @@ -27,11 +28,14 @@ class RoundRobinNode final : public TaskSchedulerNode { /*! \brief The current task id processed. */ int task_id = -1; - void VisitAttrs(tvm::AttrVisitor* v) { - TaskSchedulerNode::VisitAttrs(v); - v->Visit("task_id", &task_id); + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef() + .def_ro("task_id", &RoundRobinNode::task_id); } + static constexpr bool _type_has_method_visit_attrs = false; + static constexpr const char* _type_key = "meta_schedule.RoundRobin"; TVM_DECLARE_FINAL_OBJECT_INFO(RoundRobinNode, TaskSchedulerNode); @@ -62,6 +66,11 @@ TaskScheduler TaskScheduler::RoundRobin(ffi::Function logger) { return TaskScheduler(n); } + +TVM_FFI_STATIC_INIT_BLOCK({ + RoundRobinNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(RoundRobinNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin") .set_body_typed(TaskScheduler::RoundRobin); diff --git a/src/meta_schedule/task_scheduler/task_scheduler.cc b/src/meta_schedule/task_scheduler/task_scheduler.cc index 85a406365377..a787c4456b82 100644 --- a/src/meta_schedule/task_scheduler/task_scheduler.cc +++ b/src/meta_schedule/task_scheduler/task_scheduler.cc @@ -21,6 +21,12 @@ namespace tvm { namespace meta_schedule { +TVM_FFI_STATIC_INIT_BLOCK({ + TaskRecordNode::RegisterReflection(); + TaskSchedulerNode::RegisterReflection(); + PyTaskSchedulerNode::RegisterReflection(); +}); + TaskRecord::TaskRecord(TuneContext ctx, double task_weight) { ObjectPtr n = ffi::make_object(); n->ctx = ctx; diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 31120ce45d4a..877e348f2148 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -62,6 +62,10 @@ void TuneContextNode::Initialize() { } } +TVM_FFI_STATIC_INIT_BLOCK({ + TuneContextNode::RegisterReflection(); +}); + TVM_REGISTER_NODE_TYPE(TuneContextNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContext") .set_body_typed([](Optional mod, Optional target, From 87be0e7ce3b765ce536c510a8e7b8967eadfa3c6 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 27 Jun 2025 12:18:12 -0400 Subject: [PATCH 6/6] cpplint --- include/tvm/arith/analyzer.h | 6 ++--- include/tvm/arith/iter_affine_map.h | 24 +++++++++---------- include/tvm/ir/expr.h | 3 +-- include/tvm/meta_schedule/arg_info.h | 2 +- include/tvm/meta_schedule/builder.h | 5 ++-- include/tvm/meta_schedule/cost_model.h | 4 +--- include/tvm/meta_schedule/database.h | 6 ++--- include/tvm/meta_schedule/extracted_task.h | 2 +- include/tvm/meta_schedule/feature_extractor.h | 2 +- include/tvm/meta_schedule/measure_callback.h | 2 +- include/tvm/meta_schedule/measure_candidate.h | 2 +- include/tvm/meta_schedule/mutator.h | 2 +- include/tvm/meta_schedule/postproc.h | 2 +- include/tvm/meta_schedule/profiler.h | 2 +- include/tvm/meta_schedule/runner.h | 2 +- include/tvm/meta_schedule/schedule_rule.h | 2 +- include/tvm/meta_schedule/search_strategy.h | 2 +- include/tvm/meta_schedule/space_generator.h | 2 +- include/tvm/meta_schedule/tune_context.h | 2 +- include/tvm/script/ir_builder/base.h | 2 +- include/tvm/script/ir_builder/ir/frame.h | 2 +- include/tvm/script/ir_builder/relax/frame.h | 2 +- include/tvm/script/ir_builder/tir/frame.h | 3 +-- include/tvm/tir/expr.h | 20 ++++------------ src/ir/expr.cc | 4 +--- src/meta_schedule/arg_info.cc | 1 + src/meta_schedule/database/json_database.cc | 7 +++--- src/meta_schedule/database/memory_database.cc | 11 ++++----- .../database/ordered_union_database.cc | 11 ++++----- .../database/schedule_fn_database.cc | 11 ++++----- src/meta_schedule/database/union_database.cc | 10 ++++---- src/meta_schedule/extracted_task.cc | 4 +--- .../feature_extractor/per_store_feature.cc | 9 ++++--- .../mutator/mutate_compute_location.cc | 7 +++--- src/meta_schedule/mutator/mutate_parallel.cc | 11 ++++----- .../mutator/mutate_thread_binding.cc | 7 +++--- src/meta_schedule/mutator/mutate_tile_size.cc | 7 +++--- src/meta_schedule/mutator/mutate_unroll.cc | 7 +++--- .../postproc/rewrite_cooperative_fetch.cc | 7 +++--- .../postproc/rewrite_reduction_block.cc | 7 +++--- .../postproc/rewrite_tensorize.cc | 6 ++--- .../postproc/rewrite_unbound_block.cc | 7 ++---- src/meta_schedule/profiler.cc | 4 +--- .../schedule_rule/add_rfactor.cc | 5 ++-- .../schedule_rule/apply_custom_rule.cc | 8 +++---- src/meta_schedule/schedule_rule/auto_bind.cc | 6 ++--- .../schedule_rule/auto_inline.cc | 9 +++---- .../schedule_rule/cross_thread_reduction.cc | 5 ++-- .../schedule_rule/multi_level_tiling.cc | 4 +--- .../schedule_rule/multi_level_tiling.h | 2 +- .../parallel_vectorize_unroll.cc | 7 +++--- .../schedule_rule/random_compute_location.cc | 5 ++-- .../search_strategy/evolutionary_search.cc | 10 ++++---- .../search_strategy/replay_func.cc | 5 +--- .../search_strategy/replay_trace.cc | 7 ++---- .../space_generator/post_order_apply.cc | 7 +++--- .../space_generator/schedule_fn.cc | 7 +++--- .../space_generator/space_generator_union.cc | 11 ++++----- .../task_scheduler/gradient_based.cc | 7 +++--- .../task_scheduler/round_robin.cc | 9 +++---- src/meta_schedule/tune_context.cc | 4 +--- src/node/serialization.cc | 12 +++++----- src/script/ir_builder/ir/frame.cc | 4 +--- src/script/printer/ir_docsifier.cc | 4 +--- 64 files changed, 154 insertions(+), 225 deletions(-) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index a1c098a3f61f..84e5fd94f861 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -25,9 +25,9 @@ #define TVM_ARITH_ANALYZER_H_ #include +#include #include #include -#include #include #include @@ -90,8 +90,8 @@ class ConstIntBoundNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("min_value", &ConstIntBoundNode::min_value) - .def_ro("max_value", &ConstIntBoundNode::max_value); + .def_ro("min_value", &ConstIntBoundNode::min_value) + .def_ro("max_value", &ConstIntBoundNode::max_value); } static constexpr bool _type_has_method_visit_attrs = false; diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 0b6b8e4ba77f..3e6226278580 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -49,10 +49,10 @@ #define TVM_ARITH_ITER_AFFINE_MAP_H_ #include +#include #include #include #include -#include namespace tvm { namespace arith { @@ -102,8 +102,8 @@ class IterMarkNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("source", &IterMarkNode::source) - .def_ro("extent", &IterMarkNode::extent); + .def_ro("source", &IterMarkNode::source) + .def_ro("extent", &IterMarkNode::extent); } static constexpr bool _type_has_method_visit_attrs = false; @@ -161,10 +161,10 @@ class IterSplitExprNode : public IterMapExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("source", &IterSplitExprNode::source) - .def_ro("lower_factor", &IterSplitExprNode::lower_factor) - .def_ro("extent", &IterSplitExprNode::extent) - .def_ro("scale", &IterSplitExprNode::scale); + .def_ro("source", &IterSplitExprNode::source) + .def_ro("lower_factor", &IterSplitExprNode::lower_factor) + .def_ro("extent", &IterSplitExprNode::extent) + .def_ro("scale", &IterSplitExprNode::scale); } static constexpr bool _type_has_method_visit_attrs = false; @@ -231,8 +231,8 @@ class IterSumExprNode : public IterMapExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("args", &IterSumExprNode::args) - .def_ro("base", &IterSumExprNode::base); + .def_ro("args", &IterSumExprNode::args) + .def_ro("base", &IterSumExprNode::base); } static constexpr bool _type_has_method_visit_attrs = false; @@ -302,9 +302,9 @@ class IterMapResultNode : public Object { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("indices", &IterMapResultNode::indices) - .def_ro("errors", &IterMapResultNode::errors) - .def_ro("padding_predicate", &IterMapResultNode::padding_predicate); + .def_ro("indices", &IterMapResultNode::indices) + .def_ro("errors", &IterMapResultNode::errors) + .def_ro("padding_predicate", &IterMapResultNode::padding_predicate); } static constexpr bool _type_has_method_visit_attrs = false; diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 3545a9d3a57d..bcdbea38e41f 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -24,12 +24,12 @@ #ifndef TVM_IR_EXPR_H_ #define TVM_IR_EXPR_H_ +#include #include #include #include #include #include -#include #include #include @@ -108,7 +108,6 @@ class PrimExprNode : public BaseExprNode { */ DataType dtype; - static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef().def_ro("dtype", &PrimExprNode::dtype); diff --git a/include/tvm/meta_schedule/arg_info.h b/include/tvm/meta_schedule/arg_info.h index 34d8278286fb..aa10bdf5e209 100644 --- a/include/tvm/meta_schedule/arg_info.h +++ b/include/tvm/meta_schedule/arg_info.h @@ -20,13 +20,13 @@ #define TVM_META_SCHEDULE_ARG_INFO_H_ #include +#include #include #include #include #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/builder.h b/include/tvm/meta_schedule/builder.h index df132aa15033..a603aed158f2 100644 --- a/include/tvm/meta_schedule/builder.h +++ b/include/tvm/meta_schedule/builder.h @@ -23,13 +23,13 @@ #include #include #include +#include #include #include #include #include #include #include -#include namespace tvm { namespace meta_schedule { @@ -156,8 +156,7 @@ class PyBuilderNode : public BuilderNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("f_build", &PyBuilderNode::f_build); + refl::ObjectDef().def_ro("f_build", &PyBuilderNode::f_build); } static constexpr bool _type_has_method_visit_attrs = false; diff --git a/include/tvm/meta_schedule/cost_model.h b/include/tvm/meta_schedule/cost_model.h index ba7da0ce46d5..5e9a62274058 100644 --- a/include/tvm/meta_schedule/cost_model.h +++ b/include/tvm/meta_schedule/cost_model.h @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -29,7 +30,6 @@ #include #include #include -#include #include @@ -126,8 +126,6 @@ class PyCostModelNode : public CostModelNode { /*! \brief The packed function to the `AsString` function. */ FAsString f_as_string; - - void Load(const String& path); void Save(const String& path); void Update(const TuneContext& context, const Array& candidates, diff --git a/include/tvm/meta_schedule/database.h b/include/tvm/meta_schedule/database.h index e56348322082..db8ef0a348ce 100644 --- a/include/tvm/meta_schedule/database.h +++ b/include/tvm/meta_schedule/database.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -30,7 +31,6 @@ #include #include #include -#include #include @@ -383,8 +383,8 @@ class PyDatabaseNode : public DatabaseNode { FSize f_size; static void RegisterReflection() { - // ffi::Functions are all not registered, because the reflection system doesn't take care of them, - // so it cannot be accessible on the python side. If there is such need from the future, + // ffi::Functions are all not registered, because the reflection system doesn't take care of + // them, so it cannot be accessible on the python side. If there is such need from the future, // we can then add corresponding accessor methods to help access on python. // `f_has_workload` is not registered // `f_commit_workload` is not registered diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index 0e78bdd4bf95..e3e1d8272327 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -20,12 +20,12 @@ #define TVM_META_SCHEDULE_EXTRACTED_TASK_H_ #include +#include #include #include #include #include #include -#include namespace tvm { namespace tir { diff --git a/include/tvm/meta_schedule/feature_extractor.h b/include/tvm/meta_schedule/feature_extractor.h index cdf510c8caf2..d04189700516 100644 --- a/include/tvm/meta_schedule/feature_extractor.h +++ b/include/tvm/meta_schedule/feature_extractor.h @@ -22,12 +22,12 @@ #include #include +#include #include #include #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/measure_callback.h b/include/tvm/meta_schedule/measure_callback.h index c74ce9ec8fc1..862606640947 100644 --- a/include/tvm/meta_schedule/measure_callback.h +++ b/include/tvm/meta_schedule/measure_callback.h @@ -22,6 +22,7 @@ #include #include +#include #include #include #include @@ -30,7 +31,6 @@ #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/measure_candidate.h b/include/tvm/meta_schedule/measure_candidate.h index a2dbb6f943dd..79feda757688 100644 --- a/include/tvm/meta_schedule/measure_candidate.h +++ b/include/tvm/meta_schedule/measure_candidate.h @@ -21,11 +21,11 @@ #define TVM_META_SCHEDULE_MEASURE_CANDIDATE_H_ #include +#include #include #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/mutator.h b/include/tvm/meta_schedule/mutator.h index 0e2dc066a06e..7e00f9d72e3a 100644 --- a/include/tvm/meta_schedule/mutator.h +++ b/include/tvm/meta_schedule/mutator.h @@ -22,12 +22,12 @@ #include #include +#include #include #include #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index fded08949d89..cfbf9c702e65 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -21,10 +21,10 @@ #define TVM_META_SCHEDULE_POSTPROC_H_ #include +#include #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/profiler.h b/include/tvm/meta_schedule/profiler.h index 6ea64e51abf4..21b77109ecc1 100644 --- a/include/tvm/meta_schedule/profiler.h +++ b/include/tvm/meta_schedule/profiler.h @@ -22,12 +22,12 @@ #include #include #include +#include #include #include #include #include #include -#include #include #include diff --git a/include/tvm/meta_schedule/runner.h b/include/tvm/meta_schedule/runner.h index c1a4fb84669a..80d5816db031 100644 --- a/include/tvm/meta_schedule/runner.h +++ b/include/tvm/meta_schedule/runner.h @@ -22,12 +22,12 @@ #include #include #include +#include #include #include #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 4f15d3d74dd8..e702e1d2cbfe 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -24,12 +24,12 @@ #include #include #include +#include #include #include #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 923abd18d24c..5133e9fd8a48 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -30,7 +31,6 @@ #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index 5ce6c9473d88..68f26a6bfee9 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -21,6 +21,7 @@ #include #include +#include #include #include #include @@ -29,7 +30,6 @@ #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/meta_schedule/tune_context.h b/include/tvm/meta_schedule/tune_context.h index 1742f9424523..65d48e000dcc 100644 --- a/include/tvm/meta_schedule/tune_context.h +++ b/include/tvm/meta_schedule/tune_context.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -34,7 +35,6 @@ #include #include #include -#include namespace tvm { namespace meta_schedule { diff --git a/include/tvm/script/ir_builder/base.h b/include/tvm/script/ir_builder/base.h index 0dd1d6d805b8..4fe5c519eedb 100644 --- a/include/tvm/script/ir_builder/base.h +++ b/include/tvm/script/ir_builder/base.h @@ -19,10 +19,10 @@ #ifndef TVM_SCRIPT_IR_BUILDER_BASE_H_ #define TVM_SCRIPT_IR_BUILDER_BASE_H_ +#include #include #include #include -#include #include diff --git a/include/tvm/script/ir_builder/ir/frame.h b/include/tvm/script/ir_builder/ir/frame.h index eca7908d1a5a..764ff0507b04 100644 --- a/include/tvm/script/ir_builder/ir/frame.h +++ b/include/tvm/script/ir_builder/ir/frame.h @@ -19,12 +19,12 @@ #ifndef TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_ #define TVM_SCRIPT_IR_BUILDER_IR_FRAME_H_ +#include #include #include #include #include #include -#include #include diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 0d9f4031b153..a56e5305535e 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -19,12 +19,12 @@ #ifndef TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ #define TVM_SCRIPT_IR_BUILDER_RELAX_FRAME_H_ +#include #include #include #include #include #include -#include namespace tvm { namespace script { diff --git a/include/tvm/script/ir_builder/tir/frame.h b/include/tvm/script/ir_builder/tir/frame.h index a931ae039f07..c3d3d46e7f98 100644 --- a/include/tvm/script/ir_builder/tir/frame.h +++ b/include/tvm/script/ir_builder/tir/frame.h @@ -626,8 +626,7 @@ class WhileFrameNode : public TIRFrameNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("condition", &WhileFrameNode::condition); + refl::ObjectDef().def_ro("condition", &WhileFrameNode::condition); } static constexpr bool _type_has_method_visit_attrs = false; diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index fb02c9147e96..6e283575c67c 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -139,9 +139,7 @@ class BinaryOpNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("a", &T::a) - .def_ro("b", &T::b); + refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); } bool SEqualReduce(const T* other, SEqualReducer equal) const { @@ -333,9 +331,7 @@ class CmpOpNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("a", &T::a) - .def_ro("b", &T::b); + refl::ObjectDef().def_ro("a", &T::a).def_ro("b", &T::b); } bool SEqualReduce(const T* other, SEqualReducer equal) const { @@ -465,9 +461,7 @@ class AndNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("a", &AndNode::a) - .def_ro("b", &AndNode::b); + refl::ObjectDef().def_ro("a", &AndNode::a).def_ro("b", &AndNode::b); } bool SEqualReduce(const AndNode* other, SEqualReducer equal) const { @@ -507,9 +501,7 @@ class OrNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("a", &OrNode::a) - .def_ro("b", &OrNode::b); + refl::ObjectDef().def_ro("a", &OrNode::a).def_ro("b", &OrNode::b); } bool SEqualReduce(const OrNode* other, SEqualReducer equal) const { @@ -907,9 +899,7 @@ class CallNode : public PrimExprNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("op", &CallNode::op) - .def_ro("args", &CallNode::args); + refl::ObjectDef().def_ro("op", &CallNode::op).def_ro("args", &CallNode::args); } bool SEqualReduce(const CallNode* other, SEqualReducer equal) const { diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 386fc32e8d07..f5effa2ba522 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -46,9 +46,7 @@ PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) { PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {} -PrimExpr PrimExpr::ConvertFallbackValue(String value) { - return tir::StringImm(value); -} +PrimExpr PrimExpr::ConvertFallbackValue(String value) { return tir::StringImm(value); } IntImm::IntImm(DataType dtype, int64_t value, Span span) { ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " << dtype diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index 58c4a7b33c4f..eb1e52a17d2c 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -158,6 +158,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); /******** FFI ********/ +TVM_FFI_STATIC_INIT_BLOCK({ TensorInfoNode::RegisterReflection(); }); TVM_REGISTER_OBJECT_TYPE(ArgInfoNode); TVM_REGISTER_NODE_TYPE(TensorInfoNode); diff --git a/src/meta_schedule/database/json_database.cc b/src/meta_schedule/database/json_database.cc index c25853eb7004..afaf641126ca 100644 --- a/src/meta_schedule/database/json_database.cc +++ b/src/meta_schedule/database/json_database.cc @@ -16,13 +16,14 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include #include #include #include "../module_equality.h" #include "../utils.h" -#include namespace tvm { namespace meta_schedule { @@ -215,9 +216,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, return Database(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - JSONDatabaseNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ JSONDatabaseNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(JSONDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseJSONDatabase") .set_body_typed(Database::JSONDatabase); diff --git a/src/meta_schedule/database/memory_database.cc b/src/meta_schedule/database/memory_database.cc index 8a7a2da09b17..37cc9aa15a86 100644 --- a/src/meta_schedule/database/memory_database.cc +++ b/src/meta_schedule/database/memory_database.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include "../module_equality.h" #include "../utils.h" -#include namespace tvm { namespace meta_schedule { @@ -33,8 +34,8 @@ class MemoryDatabaseNode : public DatabaseNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("records", &MemoryDatabaseNode::records) - .def_ro("workloads", &MemoryDatabaseNode::workloads); + .def_ro("records", &MemoryDatabaseNode::records) + .def_ro("workloads", &MemoryDatabaseNode::workloads); } static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.MemoryDatabase"; @@ -103,9 +104,7 @@ TVM_REGISTER_NODE_TYPE(MemoryDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseMemoryDatabase") .set_body_typed(Database::MemoryDatabase); -TVM_FFI_STATIC_INIT_BLOCK({ - MemoryDatabaseNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ MemoryDatabaseNode::RegisterReflection(); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/ordered_union_database.cc b/src/meta_schedule/database/ordered_union_database.cc index 99aecf5d8632..6e07c4763deb 100644 --- a/src/meta_schedule/database/ordered_union_database.cc +++ b/src/meta_schedule/database/ordered_union_database.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace meta_schedule { @@ -28,8 +29,8 @@ class OrderedUnionDatabaseNode : public DatabaseNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("databases", &OrderedUnionDatabaseNode::databases); + refl::ObjectDef().def_ro("databases", + &OrderedUnionDatabaseNode::databases); } static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.OrderedUnionDatabase"; @@ -87,9 +88,7 @@ TVM_REGISTER_NODE_TYPE(OrderedUnionDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseOrderedUnionDatabase") .set_body_typed(Database::OrderedUnionDatabase); -TVM_FFI_STATIC_INIT_BLOCK({ - OrderedUnionDatabaseNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ OrderedUnionDatabaseNode::RegisterReflection(); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/schedule_fn_database.cc b/src/meta_schedule/database/schedule_fn_database.cc index 2235d7768209..a037cd861c62 100644 --- a/src/meta_schedule/database/schedule_fn_database.cc +++ b/src/meta_schedule/database/schedule_fn_database.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace meta_schedule { @@ -30,8 +31,8 @@ class ScheduleFnDatabaseNode : public DatabaseNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("schedule_fn", &ScheduleFnDatabaseNode::schedule_fn); + refl::ObjectDef().def_ro("schedule_fn", + &ScheduleFnDatabaseNode::schedule_fn); } static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.ScheduleFnDatabase"; @@ -105,9 +106,7 @@ TVM_REGISTER_NODE_TYPE(ScheduleFnDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseScheduleFnDatabase") .set_body_typed(Database::ScheduleFnDatabase); -TVM_FFI_STATIC_INIT_BLOCK({ - ScheduleFnDatabaseNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ ScheduleFnDatabaseNode::RegisterReflection(); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/database/union_database.cc b/src/meta_schedule/database/union_database.cc index 4b843aab2b8e..bab6a4ca6f3a 100644 --- a/src/meta_schedule/database/union_database.cc +++ b/src/meta_schedule/database/union_database.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace meta_schedule { @@ -28,8 +29,7 @@ class UnionDatabaseNode : public DatabaseNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("databases", &UnionDatabaseNode::databases); + refl::ObjectDef().def_ro("databases", &UnionDatabaseNode::databases); } static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.UnionDatabase"; @@ -90,9 +90,7 @@ TVM_REGISTER_NODE_TYPE(UnionDatabaseNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.DatabaseUnionDatabase") .set_body_typed(Database::UnionDatabase); -TVM_FFI_STATIC_INIT_BLOCK({ - UnionDatabaseNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ UnionDatabaseNode::RegisterReflection(); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index acfb29b8de30..da8a61eb8603 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -38,9 +38,7 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, data_ = n; } -TVM_FFI_STATIC_INIT_BLOCK({ - ExtractedTaskNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ ExtractedTaskNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ExtractedTask") diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc b/src/meta_schedule/feature_extractor/per_store_feature.cc index 9fdb9c9adc16..1a18d1c70039 100644 --- a/src/meta_schedule/feature_extractor/per_store_feature.cc +++ b/src/meta_schedule/feature_extractor/per_store_feature.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#include #include +#include #include #include @@ -1372,7 +1372,8 @@ class PerStoreFeatureNode : public FeatureExtractorNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("buffers_per_store", &PerStoreFeatureNode::buffers_per_store) - .def_ro("arith_intensity_curve_num_samples", &PerStoreFeatureNode::arith_intensity_curve_num_samples) + .def_ro("arith_intensity_curve_num_samples", + &PerStoreFeatureNode::arith_intensity_curve_num_samples) .def_ro("cache_line_bytes", &PerStoreFeatureNode::cache_line_bytes) .def_ro("extract_workload", &PerStoreFeatureNode::extract_workload) .def_ro("feature_vector_length", &PerStoreFeatureNode::feature_vector_length); @@ -1446,9 +1447,7 @@ FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store, return FeatureExtractor(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - PerStoreFeatureNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ PerStoreFeatureNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(PerStoreFeatureNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.FeatureExtractorPerStoreFeature") diff --git a/src/meta_schedule/mutator/mutate_compute_location.cc b/src/meta_schedule/mutator/mutate_compute_location.cc index 959b5a52850a..276fd89605d2 100644 --- a/src/meta_schedule/mutator/mutate_compute_location.cc +++ b/src/meta_schedule/mutator/mutate_compute_location.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace meta_schedule { @@ -133,9 +134,7 @@ Mutator Mutator::MutateComputeLocation() { return Mutator(make_object()); } -TVM_FFI_STATIC_INIT_BLOCK({ - MutateComputeLocationNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ MutateComputeLocationNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(MutateComputeLocationNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateComputeLocation") diff --git a/src/meta_schedule/mutator/mutate_parallel.cc b/src/meta_schedule/mutator/mutate_parallel.cc index 8d7e8884661c..afad83634f8c 100644 --- a/src/meta_schedule/mutator/mutate_parallel.cc +++ b/src/meta_schedule/mutator/mutate_parallel.cc @@ -16,11 +16,12 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include #include #include "../utils.h" -#include namespace tvm { namespace tir { @@ -172,8 +173,8 @@ class MutateParallelNode : public MutatorNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("max_jobs_per_core", &MutateParallelNode::max_jobs_per_core); + refl::ObjectDef().def_ro("max_jobs_per_core", + &MutateParallelNode::max_jobs_per_core); } static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.MutateParallel"; @@ -312,9 +313,7 @@ Mutator Mutator::MutateParallel(int64_t max_jobs_per_core) { return Mutator(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - MutateParallelNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ MutateParallelNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(MutateParallelNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateParallel") .set_body_typed(Mutator::MutateParallel); diff --git a/src/meta_schedule/mutator/mutate_thread_binding.cc b/src/meta_schedule/mutator/mutate_thread_binding.cc index f62658ff79fd..ade84fd3363a 100644 --- a/src/meta_schedule/mutator/mutate_thread_binding.cc +++ b/src/meta_schedule/mutator/mutate_thread_binding.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace meta_schedule { @@ -171,9 +172,7 @@ Optional MutateThreadBindingNode::Apply(const Trace& trace, TRandState* r Mutator Mutator::MutateThreadBinding() { return Mutator(make_object()); } -TVM_FFI_STATIC_INIT_BLOCK({ - MutateThreadBindingNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ MutateThreadBindingNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(MutateThreadBindingNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutateThreadBinding") diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 50cfe89a24df..728a081e28cc 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -16,11 +16,12 @@ * specific language governing permissions and limitations * under the License. */ +#include + #include #include #include "../utils.h" -#include namespace tvm { namespace meta_schedule { @@ -275,9 +276,7 @@ Optional MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s Mutator Mutator::MutateTileSize() { return Mutator(make_object()); } -TVM_FFI_STATIC_INIT_BLOCK({ - MutateTileSizeNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ MutateTileSizeNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(MutateTileSizeNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateTileSize") diff --git a/src/meta_schedule/mutator/mutate_unroll.cc b/src/meta_schedule/mutator/mutate_unroll.cc index 812af8c447f8..1948493a0e42 100644 --- a/src/meta_schedule/mutator/mutate_unroll.cc +++ b/src/meta_schedule/mutator/mutate_unroll.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace tir { @@ -144,9 +145,7 @@ Optional MutateUnrollNode::Apply(const Trace& trace, TRandState* rand_sta Mutator Mutator::MutateUnroll() { return Mutator(make_object()); } -TVM_FFI_STATIC_INIT_BLOCK({ - MutateUnrollNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ MutateUnrollNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(MutateUnrollNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.MutatorMutateUnroll").set_body_typed(Mutator::MutateUnroll); diff --git a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc index e06b983a6a5f..46a478a1aa5f 100644 --- a/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace tir { @@ -231,9 +232,7 @@ Postproc Postproc::RewriteCooperativeFetch() { return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - RewriteCooperativeFetchNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ RewriteCooperativeFetchNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RewriteCooperativeFetchNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteCooperativeFetch") diff --git a/src/meta_schedule/postproc/rewrite_reduction_block.cc b/src/meta_schedule/postproc/rewrite_reduction_block.cc index 42570a595f80..f182d710f807 100644 --- a/src/meta_schedule/postproc/rewrite_reduction_block.cc +++ b/src/meta_schedule/postproc/rewrite_reduction_block.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace tir { @@ -180,9 +181,7 @@ TVM_REGISTER_NODE_TYPE(RewriteReductionBlockNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteReductionBlock") .set_body_typed(Postproc::RewriteReductionBlock); -TVM_FFI_STATIC_INIT_BLOCK({ - RewriteReductionBlockNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ RewriteReductionBlockNode::RegisterReflection(); }); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_tensorize.cc b/src/meta_schedule/postproc/rewrite_tensorize.cc index 69f5fa2ff8dd..cc1aa1e02cb7 100644 --- a/src/meta_schedule/postproc/rewrite_tensorize.cc +++ b/src/meta_schedule/postproc/rewrite_tensorize.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#include #include +#include #include @@ -111,9 +111,7 @@ Postproc Postproc::RewriteTensorize(bool vectorize_init_loop) { return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - RewriteTensorizeNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ RewriteTensorizeNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RewriteTensorizeNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteTensorize") .set_body_typed(Postproc::RewriteTensorize); diff --git a/src/meta_schedule/postproc/rewrite_unbound_block.cc b/src/meta_schedule/postproc/rewrite_unbound_block.cc index cdc9ff12db9a..846a74833d14 100644 --- a/src/meta_schedule/postproc/rewrite_unbound_block.cc +++ b/src/meta_schedule/postproc/rewrite_unbound_block.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#include #include +#include #include "../utils.h" @@ -86,7 +86,6 @@ namespace meta_schedule { /*! \brief Add thread binding to unbound blocks */ class RewriteUnboundBlockNode : public PostprocNode { public: - // Inherited from PostprocNode void InitializeWithTuneContext(const TuneContext& context) final { CHECK(context->target.defined()) << "ValueError: target is not defined"; @@ -148,9 +147,7 @@ Postproc Postproc::RewriteUnboundBlock(int max_threadblocks) { return Postproc(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - RewriteUnboundBlockNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ RewriteUnboundBlockNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RewriteUnboundBlockNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.PostprocRewriteUnboundBlock") diff --git a/src/meta_schedule/profiler.cc b/src/meta_schedule/profiler.cc index ca01e1003f76..d92991fcbc34 100644 --- a/src/meta_schedule/profiler.cc +++ b/src/meta_schedule/profiler.cc @@ -120,9 +120,7 @@ Optional Profiler::Current() { } } -TVM_FFI_STATIC_INIT_BLOCK({ - ProfilerNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ ProfilerNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ProfilerNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.Profiler").set_body_typed([]() -> Profiler { diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index 2fb2a9c90d71..78f8289d280c 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -17,6 +17,7 @@ * under the License. */ #include + #include "../utils.h" namespace tvm { @@ -121,9 +122,7 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: return res; } -TVM_FFI_STATIC_INIT_BLOCK({ - AddRFactorNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ AddRFactorNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(AddRFactorNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAddRFactor") .set_body_typed(ScheduleRule::AddRFactor); diff --git a/src/meta_schedule/schedule_rule/apply_custom_rule.cc b/src/meta_schedule/schedule_rule/apply_custom_rule.cc index 7cc70dfe4733..28fa488b0ebc 100644 --- a/src/meta_schedule/schedule_rule/apply_custom_rule.cc +++ b/src/meta_schedule/schedule_rule/apply_custom_rule.cc @@ -17,6 +17,7 @@ * under the License. */ #include + #include "../utils.h" namespace tvm { @@ -74,8 +75,7 @@ class ApplyCustomRuleNode : public ScheduleRuleNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("target_", &ApplyCustomRuleNode::target_); + refl::ObjectDef().def_ro("target_", &ApplyCustomRuleNode::target_); } static constexpr bool _type_has_method_visit_attrs = false; static constexpr const char* _type_key = "meta_schedule.ApplyCustomRule"; @@ -91,9 +91,7 @@ bool ScheduleRule::IsApplyCustomRule(const ScheduleRule& rule) { return rule->IsInstance(); } -TVM_FFI_STATIC_INIT_BLOCK({ - ApplyCustomRuleNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ ApplyCustomRuleNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ApplyCustomRuleNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleApplyCustomRule") .set_body_typed(ScheduleRule::ApplyCustomRule); diff --git a/src/meta_schedule/schedule_rule/auto_bind.cc b/src/meta_schedule/schedule_rule/auto_bind.cc index ddb92273da74..7e264da4a981 100644 --- a/src/meta_schedule/schedule_rule/auto_bind.cc +++ b/src/meta_schedule/schedule_rule/auto_bind.cc @@ -16,8 +16,8 @@ * specific language governing permissions and limitations * under the License. */ -#include #include +#include #include #include @@ -81,9 +81,7 @@ ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array thread_ return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - AutoBindNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ AutoBindNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(AutoBindNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind") diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc b/src/meta_schedule/schedule_rule/auto_inline.cc index b30a82eb06b0..e0e9386ca344 100644 --- a/src/meta_schedule/schedule_rule/auto_inline.cc +++ b/src/meta_schedule/schedule_rule/auto_inline.cc @@ -17,6 +17,7 @@ * under the License. */ #include + #include "../utils.h" namespace tvm { @@ -193,9 +194,7 @@ ScheduleRule ScheduleRule::AutoInline(bool into_producer, // return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - AutoInlineNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ AutoInlineNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(AutoInlineNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline") .set_body_typed(ScheduleRule::AutoInline); @@ -243,9 +242,7 @@ ScheduleRule ScheduleRule::InlineConstantScalars() { return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - InlineConstantScalarsNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ InlineConstantScalarsNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars") .set_body_typed(ScheduleRule::InlineConstantScalars); diff --git a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc index f418cef8346c..571a1375a546 100644 --- a/src/meta_schedule/schedule_rule/cross_thread_reduction.cc +++ b/src/meta_schedule/schedule_rule/cross_thread_reduction.cc @@ -17,6 +17,7 @@ * under the License. */ #include + #include "../utils.h" namespace tvm { @@ -293,9 +294,7 @@ ScheduleRule ScheduleRule::CrossThreadReduction(Array thread_extents) { return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - CrossThreadReductionNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ CrossThreadReductionNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(CrossThreadReductionNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleCrossThreadReduction") .set_body_typed(ScheduleRule::CrossThreadReduction); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index b702d04f45bc..0d17477e2b94 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -54,9 +54,7 @@ using tir::IterVarType; using tir::LoopRV; using tir::Schedule; -TVM_FFI_STATIC_INIT_BLOCK({ - MultiLevelTilingNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ MultiLevelTilingNode::RegisterReflection(); }); TVM_REGISTER_OBJECT_TYPE(StateNode); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 025a24148a37..a85e88aad6f8 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -19,9 +19,9 @@ #ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ #define TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_H_ +#include #include #include -#include #include #include diff --git a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc index 44c217ea4969..d123e33e71f9 100644 --- a/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc +++ b/src/meta_schedule/schedule_rule/parallel_vectorize_unroll.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace tir { @@ -136,9 +137,7 @@ ScheduleRule ScheduleRule::ParallelizeVectorizeUnroll(int max_jobs_per_core, return ScheduleRule(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - ParallelizeVectorizeUnrollNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ ParallelizeVectorizeUnrollNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ParallelizeVectorizeUnrollNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleParallelizeVectorizeUnroll") .set_body_typed(ScheduleRule::ParallelizeVectorizeUnroll); diff --git a/src/meta_schedule/schedule_rule/random_compute_location.cc b/src/meta_schedule/schedule_rule/random_compute_location.cc index ce6fce57e816..27ab8f4ad026 100644 --- a/src/meta_schedule/schedule_rule/random_compute_location.cc +++ b/src/meta_schedule/schedule_rule/random_compute_location.cc @@ -17,6 +17,7 @@ * under the License. */ #include + #include "../utils.h" namespace tvm { @@ -125,9 +126,7 @@ ScheduleRule ScheduleRule::RandomComputeLocation() { return ScheduleRule(make_object()); } -TVM_FFI_STATIC_INIT_BLOCK({ - RandomComputeLocationNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ RandomComputeLocationNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RandomComputeLocationNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.ScheduleRuleRandomComputeLocation") diff --git a/src/meta_schedule/search_strategy/evolutionary_search.cc b/src/meta_schedule/search_strategy/evolutionary_search.cc index 5f0b9405431a..f942d28e52a1 100644 --- a/src/meta_schedule/search_strategy/evolutionary_search.cc +++ b/src/meta_schedule/search_strategy/evolutionary_search.cc @@ -17,9 +17,10 @@ * under the License. */ +#include + #include "../module_equality.h" #include "../utils.h" -#include #define TVM_META_SCHEDULE_CHECK_PROB_RANGE(p, name) \ CHECK(0.0 <= (p) && (p) <= 1.0) << "ValueError: name should be within [0, 1], " \ @@ -383,7 +384,8 @@ class EvolutionarySearchNode : public SearchStrategyNode { namespace refl = tvm::ffi::reflection; refl::ObjectDef() .def_ro("population_size", &EvolutionarySearchNode::population_size) - .def_ro("num_empty_iters_before_early_stop", &EvolutionarySearchNode::num_empty_iters_before_early_stop) + .def_ro("num_empty_iters_before_early_stop", + &EvolutionarySearchNode::num_empty_iters_before_early_stop) .def_ro("init_measured_ratio", &EvolutionarySearchNode::init_measured_ratio) .def_ro("init_min_unmeasured", &EvolutionarySearchNode::init_min_unmeasured) .def_ro("max_fail_count", &EvolutionarySearchNode::max_fail_count) @@ -798,9 +800,7 @@ Array EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self, return result; } -TVM_FFI_STATIC_INIT_BLOCK({ - EvolutionarySearchNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ EvolutionarySearchNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch") diff --git a/src/meta_schedule/search_strategy/replay_func.cc b/src/meta_schedule/search_strategy/replay_func.cc index 5716816be7e1..e5dbc27f6411 100644 --- a/src/meta_schedule/search_strategy/replay_func.cc +++ b/src/meta_schedule/search_strategy/replay_func.cc @@ -60,7 +60,6 @@ class ReplayFuncNode : public SearchStrategyNode { /*! \brief The state of the search strategy. */ std::unique_ptr state_ = nullptr; - static void RegisterReflection() { // No fields to register } @@ -161,9 +160,7 @@ SearchStrategy SearchStrategy::ReplayFunc() { return SearchStrategy(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - ReplayFuncNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ ReplayFuncNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ReplayFuncNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayFunc") diff --git a/src/meta_schedule/search_strategy/replay_trace.cc b/src/meta_schedule/search_strategy/replay_trace.cc index 1eaee10aec19..f6ad9e3770d0 100644 --- a/src/meta_schedule/search_strategy/replay_trace.cc +++ b/src/meta_schedule/search_strategy/replay_trace.cc @@ -77,8 +77,7 @@ class ReplayTraceNode : public SearchStrategyNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("max_fail_count", &ReplayTraceNode::max_fail_count); + refl::ObjectDef().def_ro("max_fail_count", &ReplayTraceNode::max_fail_count); } static constexpr const bool _type_has_method_visit_attrs = false; @@ -189,9 +188,7 @@ SearchStrategy SearchStrategy::ReplayTrace(int max_fail_count) { return SearchStrategy(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - ReplayTraceNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ ReplayTraceNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ReplayTraceNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SearchStrategyReplayTrace") diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 0a30ff09ac50..d716e4f3e488 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace meta_schedule { @@ -116,9 +117,7 @@ SpaceGenerator SpaceGenerator::PostOrderApply(ffi::Function f_block_filter, return SpaceGenerator(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - PostOrderApplyNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ PostOrderApplyNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(PostOrderApplyNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorPostOrderApply") diff --git a/src/meta_schedule/space_generator/schedule_fn.cc b/src/meta_schedule/space_generator/schedule_fn.cc index f533cc815913..06695cbb3f2c 100644 --- a/src/meta_schedule/space_generator/schedule_fn.cc +++ b/src/meta_schedule/space_generator/schedule_fn.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace meta_schedule { @@ -98,9 +99,7 @@ SpaceGenerator SpaceGenerator::ScheduleFn(ffi::Function schedule_fn, return SpaceGenerator(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - ScheduleFnNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ ScheduleFnNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(ScheduleFnNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorScheduleFn") diff --git a/src/meta_schedule/space_generator/space_generator_union.cc b/src/meta_schedule/space_generator/space_generator_union.cc index 5355010d1cd4..24464ad31e31 100644 --- a/src/meta_schedule/space_generator/space_generator_union.cc +++ b/src/meta_schedule/space_generator/space_generator_union.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace meta_schedule { @@ -30,8 +31,8 @@ class SpaceGeneratorUnionNode : public SpaceGeneratorNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("space_generators", &SpaceGeneratorUnionNode::space_generators); + refl::ObjectDef().def_ro("space_generators", + &SpaceGeneratorUnionNode::space_generators); } static constexpr const bool _type_has_method_visit_attrs = false; @@ -85,9 +86,7 @@ SpaceGenerator SpaceGenerator::SpaceGeneratorUnion(Array space_g return SpaceGenerator(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - SpaceGeneratorUnionNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ SpaceGeneratorUnionNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(SpaceGeneratorUnionNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.SpaceGeneratorSpaceGeneratorUnion") diff --git a/src/meta_schedule/task_scheduler/gradient_based.cc b/src/meta_schedule/task_scheduler/gradient_based.cc index 667166ec3845..207e8b6616fa 100644 --- a/src/meta_schedule/task_scheduler/gradient_based.cc +++ b/src/meta_schedule/task_scheduler/gradient_based.cc @@ -16,9 +16,10 @@ * specific language governing permissions and limitations * under the License. */ -#include "../utils.h" #include +#include "../utils.h" + namespace tvm { namespace meta_schedule { @@ -145,9 +146,7 @@ TaskScheduler TaskScheduler::GradientBased(ffi::Function logger, double alpha, i return TaskScheduler(n); } -TVM_FFI_STATIC_INIT_BLOCK({ - GradientBasedNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ GradientBasedNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(GradientBasedNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerGradientBased") diff --git a/src/meta_schedule/task_scheduler/round_robin.cc b/src/meta_schedule/task_scheduler/round_robin.cc index 4b48e8c8a582..35685bc7f229 100644 --- a/src/meta_schedule/task_scheduler/round_robin.cc +++ b/src/meta_schedule/task_scheduler/round_robin.cc @@ -17,6 +17,7 @@ * under the License. */ #include + #include "../utils.h" namespace tvm { @@ -30,8 +31,7 @@ class RoundRobinNode final : public TaskSchedulerNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; - refl::ObjectDef() - .def_ro("task_id", &RoundRobinNode::task_id); + refl::ObjectDef().def_ro("task_id", &RoundRobinNode::task_id); } static constexpr bool _type_has_method_visit_attrs = false; @@ -66,10 +66,7 @@ TaskScheduler TaskScheduler::RoundRobin(ffi::Function logger) { return TaskScheduler(n); } - -TVM_FFI_STATIC_INIT_BLOCK({ - RoundRobinNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ RoundRobinNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(RoundRobinNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.TaskSchedulerRoundRobin") diff --git a/src/meta_schedule/tune_context.cc b/src/meta_schedule/tune_context.cc index 877e348f2148..179a7ac1d6ff 100644 --- a/src/meta_schedule/tune_context.cc +++ b/src/meta_schedule/tune_context.cc @@ -62,9 +62,7 @@ void TuneContextNode::Initialize() { } } -TVM_FFI_STATIC_INIT_BLOCK({ - TuneContextNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ TuneContextNode::RegisterReflection(); }); TVM_REGISTER_NODE_TYPE(TuneContextNode); TVM_FFI_REGISTER_GLOBAL("meta_schedule.TuneContext") diff --git a/src/node/serialization.cc b/src/node/serialization.cc index 026398b99a75..843ffcecbe29 100644 --- a/src/node/serialization.cc +++ b/src/node/serialization.cc @@ -78,9 +78,7 @@ class NodeIndexer : private AttrVisitor { void Visit(const char* key, void** value) final {} void Visit(const char* key, DataType* value) final {} - void Visit(const char* key, runtime::NDArray* value) final { - MakeIndex(Any(*value)); - } + void Visit(const char* key, runtime::NDArray* value) final { MakeIndex(Any(*value)); } void Visit(const char* key, Optional* value) final {} void Visit(const char* key, Optional* value) final {} @@ -565,7 +563,8 @@ class JSONAttrSetter : private AttrVisitor { } void Visit(const char* key, ObjectRef* value) final { Optional index; - ParseOptionalValue(key, &index, [this](const char* key, int64_t* value) { ParseValue(key, value); }); + ParseOptionalValue(key, &index, + [this](const char* key, int64_t* value) { ParseValue(key, value); }); if (index.has_value()) { *value = node_list_->at(*index).cast(); } @@ -684,11 +683,12 @@ class JSONAttrSetter : private AttrVisitor { } default: { Optional index; - ParseOptionalValue(field_info->name.data, &index, [this](const char* key, int64_t* value) { ParseValue(key, value); }); + ParseOptionalValue(field_info->name.data, &index, + [this](const char* key, int64_t* value) { ParseValue(key, value); }); if (index.has_value()) { Any value = node_list_->at(*index).cast(); setter(obj, value); - } else{ + } else { setter(obj, Any()); } } diff --git a/src/script/ir_builder/ir/frame.cc b/src/script/ir_builder/ir/frame.cc index 7006aa25f36f..3b4abb955a35 100644 --- a/src/script/ir_builder/ir/frame.cc +++ b/src/script/ir_builder/ir/frame.cc @@ -25,9 +25,7 @@ namespace script { namespace ir_builder { namespace ir { -TVM_FFI_STATIC_INIT_BLOCK({ - IRModuleFrameNode::RegisterReflection(); -}); +TVM_FFI_STATIC_INIT_BLOCK({ IRModuleFrameNode::RegisterReflection(); }); void IRModuleFrameNode::ExitWithScope() { Map func_map; diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index b4bd64ada6f5..6b58d1e03a2a 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -108,9 +108,7 @@ void IRDocsifierNode::SetCommonPrefix(const ObjectRef& root, ffi::TypedFunction is_var) { class Visitor : private AttrVisitor { public: - void operator()(ObjectRef obj) { - this->Visit("", &obj); - } + void operator()(ObjectRef obj) { this->Visit("", &obj); } private: void RecursiveVisitAny(ffi::Any* value) {