From 949b11c90892bdcd133243740be6728057786bf7 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 30 Jun 2025 08:46:21 -0400 Subject: [PATCH] [REFACTOR] Formalize namespace for all objects This PR formalizes the namespace for all object registered so we do not have object that sits on root namespace Also fixes the Visitor style in TensorMapNode --- docs/arch/runtime.rst | 2 +- ffi/include/tvm/ffi/container/array.h | 2 +- ffi/include/tvm/ffi/container/map.h | 2 +- ffi/include/tvm/ffi/error.h | 2 +- ffi/include/tvm/ffi/function.h | 2 +- ffi/include/tvm/ffi/object.h | 14 ++++-- include/tvm/arith/int_set.h | 2 +- include/tvm/ir/attrs.h | 6 +-- include/tvm/ir/env_func.h | 2 +- include/tvm/ir/expr.h | 16 +++---- include/tvm/ir/function.h | 2 +- include/tvm/ir/global_info.h | 6 +-- include/tvm/ir/global_var_supply.h | 2 +- include/tvm/ir/module.h | 2 +- include/tvm/ir/name_supply.h | 2 +- include/tvm/ir/op.h | 2 +- include/tvm/ir/source_map.h | 10 ++-- include/tvm/ir/type.h | 19 +++++--- include/tvm/node/object_path.h | 16 +++---- include/tvm/node/script_printer.h | 2 +- include/tvm/node/structural_equal.h | 2 +- include/tvm/relax/expr.h | 2 +- include/tvm/runtime/profiling.h | 4 +- include/tvm/target/tag.h | 2 +- include/tvm/target/target.h | 2 +- include/tvm/target/target_info.h | 2 +- include/tvm/target/target_kind.h | 2 +- include/tvm/target/virtual_device.h | 2 +- include/tvm/te/operation.h | 12 ++--- python/tvm/arith/int_set.py | 1 + python/tvm/ffi/container.py | 4 +- python/tvm/ffi/ndarray.py | 2 +- python/tvm/ir/attrs.py | 6 +-- python/tvm/ir/base.py | 10 ++-- python/tvm/ir/expr.py | 46 ++----------------- python/tvm/ir/function.py | 1 + python/tvm/ir/global_info.py | 3 ++ python/tvm/ir/module.py | 4 +- python/tvm/ir/op.py | 2 +- python/tvm/ir/supply.py | 4 +- python/tvm/ir/type.py | 11 +++-- python/tvm/relax/dpl/pattern.py | 2 +- python/tvm/relax/expr.py | 1 + python/tvm/runtime/ndarray.py | 2 +- python/tvm/runtime/object_path.py | 18 ++++---- python/tvm/runtime/script_printer.py | 2 +- python/tvm/script/ir_builder/relax/ir.py | 14 +++--- python/tvm/target/target.py | 4 +- python/tvm/target/virtual_device.py | 2 +- python/tvm/te/tensor.py | 11 +++-- python/tvm/tir/expr.py | 4 +- python/tvm/tir/function.py | 2 +- src/ir/type.cc | 1 + src/runtime/cuda/cuda_device_api.cc | 2 +- src/runtime/hexagon/hexagon_common.cc | 2 +- src/runtime/metal/metal_device_api.mm | 2 +- src/runtime/opencl/opencl_common.h | 2 +- src/runtime/profiling.cc | 4 +- src/runtime/rocm/rocm_device_api.cc | 2 +- tests/python/ir/test_ir_attrs.py | 8 ++-- tests/python/ir/test_ir_container.py | 4 +- tests/python/ir/test_node_reflection.py | 4 +- tests/python/te/test_te_create_primfunc.py | 2 +- ...nsform_lower_device_storage_access_info.py | 4 +- .../test_tir_transform_storage_rewrite.py | 6 ++- .../test_tvmscript_ir_builder_tir.py | 2 +- web/src/runtime.ts | 2 +- 67 files changed, 167 insertions(+), 179 deletions(-) diff --git a/docs/arch/runtime.rst b/docs/arch/runtime.rst index 004429148bb4..613c7d86e19e 100644 --- a/docs/arch/runtime.rst +++ b/docs/arch/runtime.rst @@ -233,7 +233,7 @@ Each ``Object`` subclass will override this to register its members. Here is an hash_reduce(value); } - static constexpr const char* _type_key = "IntImm"; + static constexpr const char* _type_key = "ir.IntImm"; TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); }; // in cc file diff --git a/ffi/include/tvm/ffi/container/array.h b/ffi/include/tvm/ffi/container/array.h index 97fe5916822d..7483fd79b7b7 100644 --- a/ffi/include/tvm/ffi/container/array.h +++ b/ffi/include/tvm/ffi/container/array.h @@ -156,7 +156,7 @@ class ArrayObj : public Object, public details::InplaceArrayBaseVisit("description", &description); } - static constexpr const char* _type_key = "AttrFieldInfo"; + static constexpr const char* _type_key = "ir.AttrFieldInfo"; static constexpr bool _type_has_method_sequal_reduce = false; static constexpr bool _type_has_method_shash_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(AttrFieldInfoNode, Object); @@ -164,7 +164,7 @@ class BaseAttrsNode : public Object { static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; - static constexpr const char* _type_key = "Attrs"; + static constexpr const char* _type_key = "ir.Attrs"; TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object); }; @@ -201,7 +201,7 @@ class DictAttrsNode : public BaseAttrsNode { Array ListFieldInfo() const final; // type info - static constexpr const char* _type_key = "DictAttrs"; + static constexpr const char* _type_key = "ir.DictAttrs"; TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode); }; diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 062a5212de2c..aac4595e0628 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -66,7 +66,7 @@ class EnvFuncNode : public Object { hash_reduce(name); } - static constexpr const char* _type_key = "EnvFunc"; + static constexpr const char* _type_key = "ir.EnvFunc"; static constexpr bool _type_has_method_sequal_reduce = true; static constexpr bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(EnvFuncNode, Object); diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index bcdbea38e41f..07a198cfc33b 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -61,7 +61,7 @@ class BaseExprNode : public Object { refl::ObjectDef().def_ro("span", &BaseExprNode::span, refl::DefaultValue(Span())); } - static constexpr const char* _type_key = "BaseExpr"; + static constexpr const char* _type_key = "ir.BaseExpr"; static constexpr const bool _type_has_method_visit_attrs = true; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; @@ -117,7 +117,7 @@ class PrimExprNode : public BaseExprNode { TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - static constexpr const char* _type_key = "PrimExpr"; + static constexpr const char* _type_key = "ir.PrimExpr"; static constexpr const uint32_t _type_child_slots = 40; TVM_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode); }; @@ -161,7 +161,7 @@ class PrimExprConvertibleNode : public Object { virtual ~PrimExprConvertibleNode() {} virtual PrimExpr ToPrimExpr() const = 0; - static constexpr const char* _type_key = "PrimExprConvertible"; + static constexpr const char* _type_key = "ir.PrimExprConvertible"; TVM_DECLARE_BASE_OBJECT_INFO(PrimExprConvertibleNode, Object); }; @@ -433,7 +433,7 @@ class RelaxExprNode : public BaseExprNode { refl::ObjectDef().def_ro("struct_info_", &RelaxExprNode::struct_info_); } - static constexpr const char* _type_key = "RelaxExpr"; + static constexpr const char* _type_key = "ir.RelaxExpr"; static constexpr const uint32_t _type_child_slots = 22; TVM_DECLARE_BASE_OBJECT_INFO(RelaxExprNode, BaseExprNode); }; @@ -478,7 +478,7 @@ class GlobalVarNode : public RelaxExprNode { hash_reduce.FreeVarHashImpl(this); } - static constexpr const char* _type_key = "GlobalVar"; + static constexpr const char* _type_key = "ir.GlobalVar"; TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarNode, RelaxExprNode); }; @@ -517,7 +517,7 @@ class IntImmNode : public PrimExprNode { hash_reduce(value); } - static constexpr const char* _type_key = "IntImm"; + static constexpr const char* _type_key = "ir.IntImm"; TVM_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode); }; @@ -565,7 +565,7 @@ class FloatImmNode : public PrimExprNode { hash_reduce(value); } - static constexpr const char* _type_key = "FloatImm"; + static constexpr const char* _type_key = "ir.FloatImm"; TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode); }; @@ -718,7 +718,7 @@ class RangeNode : public Object { hash_reduce(extent); } - static constexpr const char* _type_key = "Range"; + static constexpr const char* _type_key = "ir.Range"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(RangeNode, Object); diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index 3bb43a1594de..53f19ed3f17c 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -222,7 +222,7 @@ class BaseFuncNode : public RelaxExprNode { refl::ObjectDef().def_ro("attrs", &BaseFuncNode::attrs); } - static constexpr const char* _type_key = "BaseFunc"; + static constexpr const char* _type_key = "ir.BaseFunc"; static constexpr const uint32_t _type_child_slots = 2; TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelaxExprNode); }; diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h index 4fbfeefa2399..4583b858c7a9 100644 --- a/include/tvm/ir/global_info.h +++ b/include/tvm/ir/global_info.h @@ -42,7 +42,7 @@ using MemoryScope = String; */ class GlobalInfoNode : public Object { public: - static constexpr const char* _type_key = "GlobalInfo"; + static constexpr const char* _type_key = "ir.GlobalInfo"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_BASE_OBJECT_INFO(GlobalInfoNode, Object); @@ -90,7 +90,7 @@ class VDeviceNode : public GlobalInfoNode { hash_reduce(vdevice_id); hash_reduce(memory_scope); } - static constexpr const char* _type_key = "VDevice"; + static constexpr const char* _type_key = "ir.VDevice"; TVM_DECLARE_FINAL_OBJECT_INFO(VDeviceNode, GlobalInfoNode); }; @@ -116,7 +116,7 @@ class DummyGlobalInfoNode : public GlobalInfoNode { static constexpr bool _type_has_method_visit_attrs = false; - static constexpr const char* _type_key = "DummyGlobalInfo"; + static constexpr const char* _type_key = "ir.DummyGlobalInfo"; TVM_DLL bool SEqualReduce(const DummyGlobalInfoNode* other, SEqualReducer equal) const { return true; diff --git a/include/tvm/ir/global_var_supply.h b/include/tvm/ir/global_var_supply.h index 827b643a9b64..29be4482c82f 100644 --- a/include/tvm/ir/global_var_supply.h +++ b/include/tvm/ir/global_var_supply.h @@ -86,7 +86,7 @@ class GlobalVarSupplyNode : public Object { /*! \brief The NameSupply used to generate unique name hints to GlobalVars. */ NameSupply name_supply_; - static constexpr const char* _type_key = "GlobalVarSupply"; + static constexpr const char* _type_key = "ir.GlobalVarSupply"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(GlobalVarSupplyNode, Object); diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 41c8cffbc21f..fa4086327e69 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -238,7 +238,7 @@ class IRModuleNode : public Object { TVM_OBJECT_ENABLE_SCRIPT_PRINTER(); - static constexpr const char* _type_key = "IRModule"; + static constexpr const char* _type_key = "ir.IRModule"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; TVM_DECLARE_FINAL_OBJECT_INFO(IRModuleNode, Object); diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index 2fbf42fd9c1a..ad95c3171ed5 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -86,7 +86,7 @@ class NameSupplyNode : public Object { // Prefix for all GlobalVar names. It can be empty. std::string prefix_; - static constexpr const char* _type_key = "NameSupply"; + static constexpr const char* _type_key = "ir.NameSupply"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; TVM_DECLARE_FINAL_OBJECT_INFO(NameSupplyNode, Object); diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 3e864d2d4bc2..eaf639a5a478 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -115,7 +115,7 @@ class OpNode : public RelaxExprNode { hash_reduce(name); } - static constexpr const char* _type_key = "Op"; + static constexpr const char* _type_key = "ir.Op"; TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelaxExprNode); private: diff --git a/include/tvm/ir/source_map.h b/include/tvm/ir/source_map.h index 27ef33a035ae..f888b6762336 100644 --- a/include/tvm/ir/source_map.h +++ b/include/tvm/ir/source_map.h @@ -61,7 +61,7 @@ class SourceNameNode : public Object { return equal(name, other->name); } - static constexpr const char* _type_key = "SourceName"; + static constexpr const char* _type_key = "ir.SourceName"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceNameNode, Object); }; @@ -122,7 +122,7 @@ class SpanNode : public Object { equal(end_column, other->end_column); } - static constexpr const char* _type_key = "Span"; + static constexpr const char* _type_key = "ir.Span"; TVM_DECLARE_BASE_OBJECT_INFO(SpanNode, Object); }; @@ -151,7 +151,7 @@ class SequentialSpanNode : public SpanNode { static constexpr bool _type_has_method_visit_attrs = false; - static constexpr const char* _type_key = "SequentialSpan"; + static constexpr const char* _type_key = "ir.SequentialSpan"; TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode); bool SEqualReduce(const SequentialSpanNode* other, SEqualReducer equal) const { @@ -208,7 +208,7 @@ class SourceNode : public Object { static constexpr bool _type_has_method_visit_attrs = false; - static constexpr const char* _type_key = "Source"; + static constexpr const char* _type_key = "ir.Source"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceNode, Object); }; @@ -243,7 +243,7 @@ class SourceMapObj : public Object { return equal(source_map, other->source_map); } - static constexpr const char* _type_key = "SourceMap"; + static constexpr const char* _type_key = "ir.SourceMap"; TVM_DECLARE_FINAL_OBJECT_INFO(SourceMapObj, Object); }; diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index d864766d7ff5..a07cdd5b9ebe 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -80,7 +80,7 @@ class TypeNode : public Object { */ mutable Span span; - static constexpr const char* _type_key = "Type"; + static constexpr const char* _type_key = "ir.Type"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const uint32_t _type_child_slots = 14; @@ -124,7 +124,7 @@ class PrimTypeNode : public TypeNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); } - static constexpr const char* _type_key = "PrimType"; + static constexpr const char* _type_key = "ir.PrimType"; TVM_DECLARE_FINAL_OBJECT_INFO(PrimTypeNode, TypeNode); }; @@ -187,7 +187,7 @@ class PointerTypeNode : public TypeNode { hash_reduce(storage_scope.empty() ? "global" : storage_scope); } - static constexpr const char* _type_key = "PointerType"; + static constexpr const char* _type_key = "ir.PointerType"; TVM_DECLARE_FINAL_OBJECT_INFO(PointerTypeNode, TypeNode); }; @@ -233,7 +233,7 @@ class TupleTypeNode : public TypeNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); } - static constexpr const char* _type_key = "TupleType"; + static constexpr const char* _type_key = "ir.TupleType"; TVM_DECLARE_FINAL_OBJECT_INFO(TupleTypeNode, TypeNode); }; @@ -308,7 +308,7 @@ class FuncTypeNode : public TypeNode { hash_reduce(ret_type); } - static constexpr const char* _type_key = "FuncType"; + static constexpr const char* _type_key = "ir.FuncType"; TVM_DECLARE_FINAL_OBJECT_INFO(FuncTypeNode, TypeNode); }; @@ -336,7 +336,12 @@ class FuncType : public Type { */ class TensorMapTypeNode : public TypeNode { public: - void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); } + static void RegisterReflection() { + namespace refl = tvm::ffi::reflection; + refl::ObjectDef().def_ro("span", &TensorMapTypeNode::span); + } + + static constexpr bool _type_has_method_visit_attrs = false; bool SEqualReduce(const TensorMapTypeNode* other, SEqualReducer equal) const { return equal(span, other->span); @@ -344,7 +349,7 @@ class TensorMapTypeNode : public TypeNode { void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(span); } - static constexpr const char* _type_key = "TensorMapType"; + static constexpr const char* _type_key = "ir.TensorMapType"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorMapTypeNode, TypeNode); }; diff --git a/include/tvm/node/object_path.h b/include/tvm/node/object_path.h index 9c17487a1d64..0445c3d3baa2 100644 --- a/include/tvm/node/object_path.h +++ b/include/tvm/node/object_path.h @@ -98,7 +98,7 @@ class ObjectPathNode : public Object { /*! \brief Extend this path with access to a missing map entry. */ ObjectPath MissingMapEntry() const; - static constexpr const char* _type_key = "ObjectPath"; + static constexpr const char* _type_key = "node.ObjectPath"; TVM_DECLARE_BASE_OBJECT_INFO(ObjectPathNode, Object); protected: @@ -139,7 +139,7 @@ class RootPathNode final : public ObjectPathNode { explicit RootPathNode(Optional name = std::nullopt); - static constexpr const char* _type_key = "RootPath"; + static constexpr const char* _type_key = "node.RootPath"; TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode); protected: @@ -161,7 +161,7 @@ class AttributeAccessPathNode final : public ObjectPathNode { explicit AttributeAccessPathNode(const ObjectPathNode* parent, String attr_key); - static constexpr const char* _type_key = "AttributeAccessPath"; + static constexpr const char* _type_key = "node.AttributeAccessPath"; TVM_DECLARE_FINAL_OBJECT_INFO(AttributeAccessPathNode, ObjectPathNode); protected: @@ -181,7 +181,7 @@ class UnknownAttributeAccessPathNode final : public ObjectPathNode { public: explicit UnknownAttributeAccessPathNode(const ObjectPathNode* parent); - static constexpr const char* _type_key = "UnknownAttributeAccessPath"; + static constexpr const char* _type_key = "node.UnknownAttributeAccessPath"; TVM_DECLARE_FINAL_OBJECT_INFO(UnknownAttributeAccessPathNode, ObjectPathNode); protected: @@ -204,7 +204,7 @@ class ArrayIndexPathNode : public ObjectPathNode { explicit ArrayIndexPathNode(const ObjectPathNode* parent, int32_t index); - static constexpr const char* _type_key = "ArrayIndexPath"; + static constexpr const char* _type_key = "node.ArrayIndexPath"; TVM_DECLARE_FINAL_OBJECT_INFO(ArrayIndexPathNode, ObjectPathNode); protected: @@ -226,7 +226,7 @@ class MissingArrayElementPathNode : public ObjectPathNode { explicit MissingArrayElementPathNode(const ObjectPathNode* parent, int32_t index); - static constexpr const char* _type_key = "MissingArrayElementPath"; + static constexpr const char* _type_key = "node.MissingArrayElementPath"; TVM_DECLARE_FINAL_OBJECT_INFO(MissingArrayElementPathNode, ObjectPathNode); protected: @@ -249,7 +249,7 @@ class MapValuePathNode : public ObjectPathNode { explicit MapValuePathNode(const ObjectPathNode* parent, ffi::Any key); - static constexpr const char* _type_key = "MapValuePath"; + static constexpr const char* _type_key = "node.MapValuePath"; TVM_DECLARE_FINAL_OBJECT_INFO(MapValuePathNode, ObjectPathNode); protected: @@ -268,7 +268,7 @@ class MissingMapEntryPathNode : public ObjectPathNode { public: explicit MissingMapEntryPathNode(const ObjectPathNode* parent); - static constexpr const char* _type_key = "MissingMapEntryPath"; + static constexpr const char* _type_key = "node.MissingMapEntryPath"; TVM_DECLARE_FINAL_OBJECT_INFO(MissingMapEntryPathNode, ObjectPathNode); protected: diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index a47db66be553..5fa54c5c8136 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -151,7 +151,7 @@ class PrinterConfigNode : public Object { Array GetBuiltinKeywords(); - static constexpr const char* _type_key = "node.PrinterConfig"; + static constexpr const char* _type_key = "script.PrinterConfig"; TVM_DECLARE_FINAL_OBJECT_INFO(PrinterConfigNode, Object); }; diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 46087f0bda40..249c2dabb64e 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -84,7 +84,7 @@ class ObjectPathPairNode : public Object { ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path); - static constexpr const char* _type_key = "ObjectPathPair"; + static constexpr const char* _type_key = "node.ObjectPathPair"; TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object); }; diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index e4049f23873c..808fbed3cfc7 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -122,7 +122,7 @@ class StructInfoNode : public Object { */ mutable Span span; - static constexpr const char* _type_key = "StructInfo"; + static constexpr const char* _type_key = "ir.StructInfo"; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; static constexpr const uint32_t _type_child_slots = 7; diff --git a/include/tvm/runtime/profiling.h b/include/tvm/runtime/profiling.h index 2a6ecc0e4d43..c43543cc3863 100644 --- a/include/tvm/runtime/profiling.h +++ b/include/tvm/runtime/profiling.h @@ -75,7 +75,7 @@ class TimerNode : public Object { virtual ~TimerNode() {} - static constexpr const char* _type_key = "TimerNode"; + static constexpr const char* _type_key = "runtime.TimerNode"; TVM_DECLARE_BASE_OBJECT_INFO(TimerNode, Object); }; @@ -125,7 +125,7 @@ class Timer : public ObjectRef { * virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); } * virtual ~CPUTimerNode() {} * - * static constexpr const char* _type_key = "CPUTimerNode"; + * static constexpr const char* _type_key = "runtime.CPUTimerNode"; * TVM_DECLARE_FINAL_OBJECT_INFO(CPUTimerNode, TimerNode); * * private: diff --git a/include/tvm/target/tag.h b/include/tvm/target/tag.h index 9af2c8e49732..00542d43dce1 100644 --- a/include/tvm/target/tag.h +++ b/include/tvm/target/tag.h @@ -48,7 +48,7 @@ class TargetTagNode : public Object { .def_ro("config", &TargetTagNode::config); } - static constexpr const char* _type_key = "TargetTag"; + static constexpr const char* _type_key = "target.TargetTag"; static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(TargetTagNode, Object); diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 2d6b1834e228..cc334b785428 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -176,7 +176,7 @@ class TargetNode : public Object { bool SEqualReduce(const TargetNode* other, SEqualReducer equal) const; void SHashReduce(SHashReducer hash_reduce) const; - static constexpr const char* _type_key = "Target"; + static constexpr const char* _type_key = "target.Target"; static constexpr const bool _type_has_method_visit_attrs = false; static constexpr const bool _type_has_method_sequal_reduce = true; static constexpr const bool _type_has_method_shash_reduce = true; diff --git a/include/tvm/target/target_info.h b/include/tvm/target/target_info.h index 0c1c4abf0158..552152dbde87 100644 --- a/include/tvm/target/target_info.h +++ b/include/tvm/target/target_info.h @@ -58,7 +58,7 @@ class MemoryInfoNode : public Object { .def_ro("head_address", &MemoryInfoNode::head_address); } - static constexpr const char* _type_key = "MemoryInfo"; + static constexpr const char* _type_key = "target.MemoryInfo"; static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(MemoryInfoNode, Object); }; diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h index 9875ceef3367..3a451832499f 100644 --- a/include/tvm/target/target_kind.h +++ b/include/tvm/target/target_kind.h @@ -85,7 +85,7 @@ class TargetKindNode : public Object { .def_ro("default_keys", &TargetKindNode::default_keys); } - static constexpr const char* _type_key = "TargetKind"; + static constexpr const char* _type_key = "target.TargetKind"; static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(TargetKindNode, Object); diff --git a/include/tvm/target/virtual_device.h b/include/tvm/target/virtual_device.h index 660f2bf7fcb0..aabd3a2ecaf2 100644 --- a/include/tvm/target/virtual_device.h +++ b/include/tvm/target/virtual_device.h @@ -258,7 +258,7 @@ class VirtualDeviceNode : public AttrsNodeReflAdapter { refl::DefaultValue("")); } - static constexpr const char* _type_key = "VirtualDevice"; + static constexpr const char* _type_key = "target.VirtualDevice"; TVM_FFI_DECLARE_FINAL_OBJECT_INFO(VirtualDeviceNode, BaseAttrsNode); friend class VirtualDevice; diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index e92409df53a5..abf52a2528a1 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -91,7 +91,7 @@ class TVM_DLL OperationNode : public Object { .def_ro("attrs", &OperationNode::attrs); } - static constexpr const char* _type_key = "Operation"; + static constexpr const char* _type_key = "te.Operation"; static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object); @@ -119,7 +119,7 @@ class PlaceholderOpNode : public OperationNode { .def_ro("dtype", &PlaceholderOpNode::dtype); } - static constexpr const char* _type_key = "PlaceholderOp"; + static constexpr const char* _type_key = "te.PlaceholderOp"; static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode); }; @@ -155,7 +155,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { .def_ro("reduce_axis", &BaseComputeOpNode::reduce_axis); } - static constexpr const char* _type_key = "BaseComputeOp"; + static constexpr const char* _type_key = "te.BaseComputeOp"; static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode); }; @@ -179,7 +179,7 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode { refl::ObjectDef().def_ro("body", &ComputeOpNode::body); } - static constexpr const char* _type_key = "ComputeOp"; + static constexpr const char* _type_key = "te.ComputeOp"; static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode); }; @@ -244,7 +244,7 @@ class ScanOpNode : public OperationNode { .def_ro("spatial_axis_", &ScanOpNode::spatial_axis_); } - static constexpr const char* _type_key = "ScanOp"; + static constexpr const char* _type_key = "te.ScanOp"; static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode); }; @@ -293,7 +293,7 @@ class ExternOpNode : public OperationNode { .def_ro("body", &ExternOpNode::body); } - static constexpr const char* _type_key = "ExternOp"; + static constexpr const char* _type_key = "te.ExternOp"; static constexpr const bool _type_has_method_visit_attrs = false; TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode); }; diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py index f779df5d4c92..7a0aae5fdaea 100644 --- a/python/tvm/arith/int_set.py +++ b/python/tvm/arith/int_set.py @@ -20,6 +20,7 @@ from . import _ffi_api +@tvm.ffi.register_object("ir.IntSet") class IntSet(Object): """Represent a set of integer in one dimension.""" diff --git a/python/tvm/ffi/container.py b/python/tvm/ffi/container.py index 66038976f5d2..157840ba9d46 100644 --- a/python/tvm/ffi/container.py +++ b/python/tvm/ffi/container.py @@ -64,7 +64,7 @@ def getitem_helper(obj, elem_getter, length, idx): return elem_getter(obj, idx) -@register_object("object.Array") +@register_object("ffi.Array") class Array(core.Object, collections.abc.Sequence): """Array container""" @@ -148,7 +148,7 @@ def __iter__(self): break -@register_object("object.Map") +@register_object("ffi.Map") class Map(core.Object, collections.abc.Mapping): """Map container.""" diff --git a/python/tvm/ffi/ndarray.py b/python/tvm/ffi/ndarray.py index 6d901fb14b8f..05856bdae7a2 100644 --- a/python/tvm/ffi/ndarray.py +++ b/python/tvm/ffi/ndarray.py @@ -23,7 +23,7 @@ from . import _ffi_api -@registry.register_object("object.Shape") +@registry.register_object("ffi.Shape") class Shape(tuple, core.PyNativeObject): """Shape object that is possibly returned by FFI call.""" diff --git a/python/tvm/ir/attrs.py b/python/tvm/ir/attrs.py index 6565a8de37b4..d8d6188e155c 100644 --- a/python/tvm/ir/attrs.py +++ b/python/tvm/ir/attrs.py @@ -22,7 +22,7 @@ from . import _ffi_api -@tvm.ffi.register_object +@tvm.ffi.register_object("ir.Attrs") class Attrs(Object): """Attribute node, which is mainly use for defining attributes of operators. @@ -93,7 +93,7 @@ def __getitem__(self, item): return self.__getattr__(item) -@tvm.ffi.register_object +@tvm.ffi.register_object("ir.DictAttrs") class DictAttrs(Attrs): """Dictionary attributes.""" @@ -157,7 +157,7 @@ def make_node(type_key, **kwargs): .. code-block:: python - x = tvm.ir.make_node("IntImm", dtype="int32", value=10, span=None) + x = tvm.ir.make_node("ir.IntImm", dtype="int32", value=10, span=None) assert isinstance(x, tvm.tir.IntImm) assert x.value == 10 """ diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index a31be4c40ccb..d34137101119 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -27,13 +27,13 @@ class Node(Object): """Base class of all IR Nodes.""" -@register_object("SourceMap") +@register_object("ir.SourceMap") class SourceMap(Object): def add(self, name, content): return get_global_func("SourceMapAdd")(self, name, content) -@register_object("SourceName") +@register_object("ir.SourceName") class SourceName(Object): """A identifier for a source location. @@ -47,7 +47,7 @@ def __init__(self, name): self.__init_handle_by_constructor__(_ffi_api.SourceName, name) # type: ignore # pylint: disable=no-member -@register_object("Span") +@register_object("ir.Span") class Span(Object): """Specifies a location in a source program. @@ -69,7 +69,7 @@ def __init__(self, source_name, line, end_line, column, end_column): ) -@register_object("SequentialSpan") +@register_object("ir.SequentialSpan") class SequentialSpan(Object): """A sequence of source spans @@ -86,7 +86,7 @@ def __init__(self, spans): self.__init_handle_by_constructor__(_ffi_api.SequentialSpan, spans) -@register_object +@register_object("ir.EnvFunc") class EnvFunc(Object): """Environment function. diff --git a/python/tvm/ir/expr.py b/python/tvm/ir/expr.py index 1d5389827f8e..008924c227b5 100644 --- a/python/tvm/ir/expr.py +++ b/python/tvm/ir/expr.py @@ -25,12 +25,14 @@ from .base import Node, Span +@tvm.ffi.register_object("ir.BaseExpr") class BaseExpr(Node): """Base class of all the expressions.""" span: Optional[Span] +@tvm.ffi.register_object("ir.PrimExpr") class PrimExpr(BaseExpr): """Base class of all primitive expressions. @@ -41,6 +43,7 @@ class PrimExpr(BaseExpr): dtype: str +@tvm.ffi.register_object("ir.RelaxExpr") class RelaxExpr(BaseExpr): """Base class of all non-primitive expressions.""" @@ -56,7 +59,7 @@ def struct_info(self) -> Optional["tvm.relax.StructInfo"]: return _ffi_api.ExprStructInfo(self) -@tvm.ffi.register_object("GlobalVar") +@tvm.ffi.register_object("ir.GlobalVar") class GlobalVar(RelaxExpr): """A global variable in the IR. @@ -102,7 +105,7 @@ def __call__(self, *args: RelaxExpr) -> BaseExpr: raise RuntimeError(f"Do not know how to handle GlobalVar.__call__ for types {arg_types}") -@tvm.ffi.register_object +@tvm.ffi.register_object("ir.Range") class Range(Node, Scriptable): """Represent a range in TVM. @@ -167,42 +170,3 @@ def __eq__(self, other: Object) -> bool: def __ne__(self, other: Object) -> bool: return not self.__eq__(other) - - -# TODO(@relax-team): remove when we have a RelaxExpr base class -def is_relax_expr(expr: RelaxExpr) -> bool: - """check if a RelaxExpr is a Relax expresssion. - - Parameters - ---------- - expr : RelaxExpr - The expression to check. - - Returns - ------- - res : bool - If the expression is Relax expression, return True; otherwise return False. - """ - from tvm import relax # pylint: disable=import-outside-toplevel - - if isinstance( - expr, - ( - relax.Call, - relax.Constant, - relax.Tuple, - relax.TupleGetItem, - relax.If, - relax.Var, - relax.DataflowVar, - relax.ShapeExpr, - relax.SeqExpr, - relax.Function, - relax.ExternFunc, - relax.PrimValue, - relax.StringImm, - relax.DataTypeImm, - ), - ): - return True - return False diff --git a/python/tvm/ir/function.py b/python/tvm/ir/function.py index 8527ce66f0cf..f6fc42ccbc07 100644 --- a/python/tvm/ir/function.py +++ b/python/tvm/ir/function.py @@ -34,6 +34,7 @@ class CallingConv(IntEnum): DEVICE_KERNEL_LAUNCH = 2 +@tvm.ffi.register_object("ir.BaseFunc") class BaseFunc(RelaxExpr): """Base class of all functions.""" diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py index 458a16717bea..d4b4fdca1654 100644 --- a/python/tvm/ir/global_info.py +++ b/python/tvm/ir/global_info.py @@ -20,6 +20,7 @@ from . import _ffi_api +@tvm.ffi.register_object("ir.GlobalInfo") class GlobalInfo(Object): """Base node for all global info that can appear in the IR""" @@ -35,6 +36,7 @@ def same_as(self, other): return super().__eq__(other) +@tvm.ffi.register_object("ir.DummyGlobalInfo") class DummyGlobalInfo(GlobalInfo): def __init__(self) -> None: self.__init_handle_by_constructor__( @@ -42,6 +44,7 @@ def __init__(self) -> None: ) +@tvm.ffi.register_object("ir.VDevice") class VDevice(GlobalInfo): def __init__( self, diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 6033dc6f8066..3b99db85986e 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -30,7 +30,7 @@ from .base import Node -@tvm.ffi.register_object("IRModule") +@tvm.ffi.register_object("ir.IRModule") class IRModule(Node, Scriptable): """IRModule that holds functions and type definitions. @@ -57,7 +57,7 @@ def __init__(self, functions=None, attrs=None, global_infos=None): attrs = None if not attrs else attrs if attrs is not None: - attrs = tvm.ir.make_node("DictAttrs", **attrs) + attrs = tvm.ir.make_node("ir.DictAttrs", **attrs) if global_infos is None: global_infos = {} self.__init_handle_by_constructor__( diff --git a/python/tvm/ir/op.py b/python/tvm/ir/op.py index 41105c4549dd..e5111ccc8220 100644 --- a/python/tvm/ir/op.py +++ b/python/tvm/ir/op.py @@ -22,7 +22,7 @@ from .expr import RelaxExpr -@tvm.ffi.register_object("Op") +@tvm.ffi.register_object("ir.Op") class Op(RelaxExpr): """Primitive operator in the IR.""" diff --git a/python/tvm/ir/supply.py b/python/tvm/ir/supply.py index 046432edfd99..2038df4b3104 100644 --- a/python/tvm/ir/supply.py +++ b/python/tvm/ir/supply.py @@ -20,7 +20,7 @@ from . import _ffi_api -@tvm.ffi.register_object("NameSupply") +@tvm.ffi.register_object("ir.NameSupply") class NameSupply(Object): """NameSupply that can be used to generate unique names. @@ -77,7 +77,7 @@ def contains_name(self, name, add_prefix=True): return _ffi_api.NameSupply_ContainsName(self, name, add_prefix) -@tvm.ffi.register_object("GlobalVarSupply") +@tvm.ffi.register_object("ir.GlobalVarSupply") class GlobalVarSupply(Object): """GlobalVarSupply that holds a mapping between names and GlobalVars. diff --git a/python/tvm/ir/type.py b/python/tvm/ir/type.py index d0bf7014e27b..0f287be96146 100644 --- a/python/tvm/ir/type.py +++ b/python/tvm/ir/type.py @@ -23,6 +23,7 @@ from .base import Node +@tvm.ffi.register_object("ir.Type") class Type(Node, Scriptable): """The base class of all types.""" @@ -38,7 +39,7 @@ def same_as(self, other): return super().__eq__(other) -@tvm.ffi.register_object("PrimType") +@tvm.ffi.register_object("ir.PrimType") class PrimType(Type): """Primitive data type in the low level IR @@ -52,7 +53,7 @@ def __init__(self, dtype): self.__init_handle_by_constructor__(_ffi_api.PrimType, dtype) -@tvm.ffi.register_object("PointerType") +@tvm.ffi.register_object("ir.PointerType") class PointerType(Type): """PointerType used in the low-level TIR. @@ -69,7 +70,7 @@ def __init__(self, element_type, storage_scope=""): self.__init_handle_by_constructor__(_ffi_api.PointerType, element_type, storage_scope) -@tvm.ffi.register_object("TupleType") +@tvm.ffi.register_object("ir.TupleType") class TupleType(Type): """The type of tuple values. @@ -83,7 +84,7 @@ def __init__(self, fields): self.__init_handle_by_constructor__(_ffi_api.TupleType, fields) -@tvm.ffi.register_object("FuncType") +@tvm.ffi.register_object("ir.FuncType") class FuncType(Type): """Function type. @@ -109,7 +110,7 @@ def __init__(self, arg_types, ret_type): ) -@tvm.ffi.register_object("TensorMapType") +@tvm.ffi.register_object("ir.TensorMapType") class TensorMapType(Type): """TensorMapType used in the low-level TIR. diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py index 34aab9c99e77..633c2c6790da 100644 --- a/python/tvm/relax/dpl/pattern.py +++ b/python/tvm/relax/dpl/pattern.py @@ -119,7 +119,7 @@ def has_attr(self, attrs: Dict[str, Object]) -> "AttrPattern": result: AttrPattern The resulting AttrPattern """ - attrs = make_node("DictAttrs", **attrs) + attrs = make_node("ir.DictAttrs", **attrs) return AttrPattern(self, attrs) def has_struct_info(self, struct_info: "StructInfo") -> "StructInfoPattern": diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 0fa8c4df88f8..9ddaf52e722c 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -56,6 +56,7 @@ def __init__(self): # NOTE: place base struct info in expr to avoid cyclic dep # from expr to struct info. +@tvm.ffi.register_object("ir.StructInfo") class StructInfo(Node, Scriptable): """The base class of all StructInfo. diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 538fa15c8a49..1d960d5dda4a 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -70,7 +70,7 @@ def from_dlpack(ext_tensor): ) -@tvm.ffi.register_object("object.NDArray") +@tvm.ffi.register_object("ffi.NDArray") class NDArray(tvm.ffi.core.NDArray): """Lightweight NDArray class of TVM runtime. diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py index 45e4925a3e28..957db558a45b 100644 --- a/python/tvm/runtime/object_path.py +++ b/python/tvm/runtime/object_path.py @@ -40,7 +40,7 @@ ) -@tvm.ffi.register_object("ObjectPath") +@tvm.ffi.register_object("node.ObjectPath") class ObjectPath(Object): """ Path to an object from some root object. @@ -94,42 +94,42 @@ def missing_map_entry(self) -> "ObjectPath": __hash__ = Object.__hash__ -@tvm.ffi.register_object("RootPath") +@tvm.ffi.register_object("node.RootPath") class RootPath(ObjectPath): pass -@tvm.ffi.register_object("AttributeAccessPath") +@tvm.ffi.register_object("node.AttributeAccessPath") class AttributeAccessPath(ObjectPath): pass -@tvm.ffi.register_object("UnknownAttributeAccessPath") +@tvm.ffi.register_object("node.UnknownAttributeAccessPath") class UnknownAttributeAccessPath(ObjectPath): pass -@tvm.ffi.register_object("ArrayIndexPath") +@tvm.ffi.register_object("node.ArrayIndexPath") class ArrayIndexPath(ObjectPath): pass -@tvm.ffi.register_object("MissingArrayElementPath") +@tvm.ffi.register_object("node.MissingArrayElementPath") class MissingArrayElementPath(ObjectPath): pass -@tvm.ffi.register_object("MapValuePath") +@tvm.ffi.register_object("node.MapValuePath") class MapValuePath(ObjectPath): pass -@tvm.ffi.register_object("MissingMapEntryPath") +@tvm.ffi.register_object("node.MissingMapEntryPath") class MissingMapEntryPath(ObjectPath): pass -@tvm.ffi.register_object("ObjectPathPair") +@tvm.ffi.register_object("node.ObjectPathPair") class ObjectPathPair(Object): """ Pair of ObjectPaths, one for each object being tested for structural equality. diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index ade34e1e9b85..2820f7b97bc9 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -25,7 +25,7 @@ from .object_path import ObjectPath -@register_object("node.PrinterConfig") +@register_object("script.PrinterConfig") class PrinterConfig(Object): """Configuration of TVMScript printer""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 1e48e9ea1ad7..c3a91594e7c3 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -426,11 +426,13 @@ def call_packed( sinfo_args = [sinfo_args] sinfo_args = [ - sinfo() - if callable(sinfo) - else sinfo.asobject() - if isinstance(sinfo, ObjectGeneric) - else sinfo + ( + sinfo() + if callable(sinfo) + else sinfo.asobject() + if isinstance(sinfo, ObjectGeneric) + else sinfo + ) for sinfo in sinfo_args ] @@ -439,7 +441,7 @@ def call_packed( attrs_type_key = kwargs["attrs_type_key"] kwargs.pop("attrs_type_key") else: - attrs_type_key = "DictAttrs" + attrs_type_key = "ir.DictAttrs" is_default = True attrs = None if kwargs or not is_default: diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index eb8bf1f9b807..6c83ef6e5bb2 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -30,7 +30,7 @@ from . import _ffi_api -@tvm.ffi.register_object +@tvm.ffi.register_object("target.TargetKind") class TargetKind(Object): """Kind of a compilation target""" @@ -53,7 +53,7 @@ def __getattr__(self, name: str): return _ffi_api.TargetGetFeature(self.target, name) -@tvm.ffi.register_object +@tvm.ffi.register_object("target.Target") class Target(Object): """Target device information, use through TVM API. diff --git a/python/tvm/target/virtual_device.py b/python/tvm/target/virtual_device.py index 3d923a4623d2..b062feb27aeb 100644 --- a/python/tvm/target/virtual_device.py +++ b/python/tvm/target/virtual_device.py @@ -22,7 +22,7 @@ from . import _ffi_api -@tvm.ffi.register_object +@tvm.ffi.register_object("target.VirtualDevice") class VirtualDevice(Object): """A compile time representation for where data is to be stored at runtime, and how to compile code to compute it.""" diff --git a/python/tvm/te/tensor.py b/python/tvm/te/tensor.py index aad18a8b016c..489ec38ba506 100644 --- a/python/tvm/te/tensor.py +++ b/python/tvm/te/tensor.py @@ -112,6 +112,7 @@ def name(self): return f"{op.name}.v{self.value_index}" +@tvm.ffi.register_object("te.Operation") class Operation(Object): """Represent an operation that generates a tensor""" @@ -141,12 +142,12 @@ def input_tensors(self): return _ffi_api.OpInputTensors(self) -@tvm.ffi.register_object +@tvm.ffi.register_object("te.PlaceholderOp") class PlaceholderOp(Operation): """Placeholder operation.""" -@tvm.ffi.register_object +@tvm.ffi.register_object("te.BaseComputeOp") class BaseComputeOp(Operation): """Compute operation.""" @@ -161,12 +162,12 @@ def reduce_axis(self): return self.__getattr__("reduce_axis") -@tvm.ffi.register_object +@tvm.ffi.register_object("te.ComputeOp") class ComputeOp(BaseComputeOp): """Scalar operation.""" -@tvm.ffi.register_object +@tvm.ffi.register_object("te.ScanOp") class ScanOp(Operation): """Scan operation.""" @@ -176,6 +177,6 @@ def scan_axis(self): return self.__getattr__("scan_axis") -@tvm.ffi.register_object +@tvm.ffi.register_object("te.ExternOp") class ExternOp(Operation): """External operation.""" diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index e57c01f23afc..2e07cef9a3d3 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -558,7 +558,7 @@ def __init__( ) -@tvm.ffi.register_object +@tvm.ffi.register_object("ir.FloatImm") class FloatImm(ConstExpr): """Float constant. @@ -585,7 +585,7 @@ def __float__(self) -> float: return self.value -@tvm.ffi.register_object +@tvm.ffi.register_object("ir.IntImm") class IntImm(ConstExpr): """Int constant. diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index 55bae37809f0..b85fb3952249 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -81,7 +81,7 @@ def __init__( raise TypeError("params can only contain Var or Buffer") if attrs is None: - attrs = tvm.ir.make_node("DictAttrs") + attrs = tvm.ir.make_node("ir.DictAttrs") self.__init_handle_by_constructor__( _ffi_api.PrimFunc, diff --git a/src/ir/type.cc b/src/ir/type.cc index cd7b6a523cc2..83cbd962404a 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -30,6 +30,7 @@ TVM_FFI_STATIC_INIT_BLOCK({ PointerTypeNode::RegisterReflection(); TupleTypeNode::RegisterReflection(); FuncTypeNode::RegisterReflection(); + TensorMapTypeNode::RegisterReflection(); }); PrimType::PrimType(runtime::DataType dtype, Span span) { diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index 98a83f4ed7e8..2af0bf159529 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -320,7 +320,7 @@ class CUDATimerNode : public TimerNode { CUDA_CALL(cudaEventCreate(&stop_)); } - static constexpr const char* _type_key = "CUDATimerNode"; + static constexpr const char* _type_key = "runtime.cuda.CUDATimerNode"; TVM_DECLARE_FINAL_OBJECT_INFO(CUDATimerNode, TimerNode); private: diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon_common.cc index 4c95d68b2dc3..27bf4ffc0eab 100644 --- a/src/runtime/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon_common.cc @@ -47,7 +47,7 @@ class HexagonTimerNode : public TimerNode { virtual int64_t SyncAndGetElapsedNanos() { return (end - start) * 1e3; } virtual ~HexagonTimerNode() {} - static constexpr const char* _type_key = "HexagonTimerNode"; + static constexpr const char* _type_key = "runtime.hexagon.HexagonTimerNode"; TVM_DECLARE_FINAL_OBJECT_INFO(HexagonTimerNode, TimerNode); private: diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index 46824b1600ee..8722dcfeb60a 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -388,7 +388,7 @@ virtual void Stop() { } virtual int64_t SyncAndGetElapsedNanos() { return stop_gpu_time_ - start_gpu_time_; } - static constexpr const char* _type_key = "MetalTimerNode"; + static constexpr const char* _type_key = "runtime.metal.MetalTimerNode"; TVM_DECLARE_FINAL_OBJECT_INFO(MetalTimerNode, TimerNode); private: diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h index dbef2f518f5a..3fefae597f21 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/runtime/opencl/opencl_common.h @@ -588,7 +588,7 @@ class OpenCLTimerNode : public TimerNode { OpenCLTimerNode() {} explicit OpenCLTimerNode(Device dev) : dev_(dev) {} - static constexpr const char* _type_key = "OpenCLTimerNode"; + static constexpr const char* _type_key = "runtime.opencl.OpenCLTimerNode"; static size_t count_timer_execs; static std::vector event_start_idxs; TVM_DECLARE_FINAL_OBJECT_INFO(OpenCLTimerNode, TimerNode); diff --git a/src/runtime/profiling.cc b/src/runtime/profiling.cc index bab1d50db6a9..2e1bfba0263a 100644 --- a/src/runtime/profiling.cc +++ b/src/runtime/profiling.cc @@ -54,7 +54,7 @@ class DefaultTimerNode : public TimerNode { virtual ~DefaultTimerNode() {} explicit DefaultTimerNode(Device dev) : device_(dev) {} - static constexpr const char* _type_key = "DefaultTimerNode"; + static constexpr const char* _type_key = "runtime.DefaultTimerNode"; TVM_DECLARE_FINAL_OBJECT_INFO(DefaultTimerNode, TimerNode); private: @@ -75,7 +75,7 @@ class CPUTimerNode : public TimerNode { virtual int64_t SyncAndGetElapsedNanos() { return duration_.count(); } virtual ~CPUTimerNode() {} - static constexpr const char* _type_key = "CPUTimerNode"; + static constexpr const char* _type_key = "runtime.CPUTimerNode"; TVM_DECLARE_FINAL_OBJECT_INFO(CPUTimerNode, TimerNode); private: diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index a5bc3b1a0da5..d0da510389f8 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -283,7 +283,7 @@ class ROCMTimerNode : public TimerNode { ROCM_CALL(hipEventCreate(&stop_)); } - static constexpr const char* _type_key = "ROCMTimerNode"; + static constexpr const char* _type_key = "runtime.rocm.ROCMTimerNode"; TVM_DECLARE_FINAL_OBJECT_INFO(ROCMTimerNode, TimerNode); private: diff --git a/tests/python/ir/test_ir_attrs.py b/tests/python/ir/test_ir_attrs.py index d61538ac2512..48c38c1556ef 100644 --- a/tests/python/ir/test_ir_attrs.py +++ b/tests/python/ir/test_ir_attrs.py @@ -38,7 +38,7 @@ def test_make_attrs(): def test_dict_attrs(): - dattr = tvm.ir.make_node("DictAttrs", x=1, y=10, name="xyz", padding=(0, 0)) + dattr = tvm.ir.make_node("ir.DictAttrs", x=1, y=10, name="xyz", padding=(0, 0)) assert dattr.x == 1 datrr = tvm.ir.load_json(tvm.ir.save_json(dattr)) assert dattr.name == "xyz" @@ -51,9 +51,9 @@ def test_dict_attrs(): def test_attrs_equal(): - dattr0 = tvm.ir.make_node("DictAttrs", x=1, y=[10, 20]) - dattr1 = tvm.ir.make_node("DictAttrs", y=[10, 20], x=1) - dattr2 = tvm.ir.make_node("DictAttrs", x=1, y=None) + dattr0 = tvm.ir.make_node("ir.DictAttrs", x=1, y=[10, 20]) + dattr1 = tvm.ir.make_node("ir.DictAttrs", y=[10, 20], x=1) + dattr2 = tvm.ir.make_node("ir.DictAttrs", x=1, y=None) tvm.ir.assert_structural_equal(dattr0, dattr1) assert not tvm.ir.structural_equal(dattr0, dattr2) assert not tvm.ir.structural_equal({"x": 1}, tvm.runtime.convert(1)) diff --git a/tests/python/ir/test_ir_container.py b/tests/python/ir/test_ir_container.py index c90df0a223d6..1004bad702f6 100644 --- a/tests/python/ir/test_ir_container.py +++ b/tests/python/ir/test_ir_container.py @@ -121,8 +121,8 @@ def test_return_variant_type(): def test_pass_variant_type(): func = tvm.get_global_func("testing.AcceptsVariant") - assert func("string arg") == "object.String" - assert func(17) == "IntImm" + assert func("string arg") == "ffi.String" + assert func(17) == "ir.IntImm" def test_pass_incorrect_variant_type(): diff --git a/tests/python/ir/test_node_reflection.py b/tests/python/ir/test_node_reflection.py index ab8d72841b7a..741e61b2eb48 100644 --- a/tests/python/ir/test_node_reflection.py +++ b/tests/python/ir/test_node_reflection.py @@ -70,7 +70,7 @@ def test_make_smap(): def test_make_node(): - x = tvm.ir.make_node("IntImm", dtype="int32", value=10, span=None) + x = tvm.ir.make_node("ir.IntImm", dtype="int32", value=10, span=None) assert isinstance(x, tvm.tir.IntImm) assert x.value == 10 A = te.placeholder((10,), name="A") @@ -80,7 +80,7 @@ def test_make_node(): assert AA.op == A.op assert AA.value_index == A.value_index - y = tvm.ir.make_node("IntImm", dtype=tvm.runtime.String("int32"), value=10, span=None) + y = tvm.ir.make_node("ir.IntImm", dtype=tvm.runtime.String("int32"), value=10, span=None) def test_make_sum(): diff --git a/tests/python/te/test_te_create_primfunc.py b/tests/python/te/test_te_create_primfunc.py index b0850a89b5c5..b070371b8ac4 100644 --- a/tests/python/te/test_te_create_primfunc.py +++ b/tests/python/te/test_te_create_primfunc.py @@ -335,7 +335,7 @@ def test_error_reporting(): assert False except TypeError as e: error_message = str(e) - assert error_message.find("Unsupported Operation: ScanOp.") != -1 + assert error_message.find("Unsupported Operation: te.ScanOp.") != -1 return assert False diff --git a/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py b/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py index be55ed4777fb..5006efba50b2 100644 --- a/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py +++ b/tests/python/tir-transform/test_tir_transform_lower_device_storage_access_info.py @@ -22,7 +22,7 @@ @tvm.register_func("tvm.info.mem.global.test_with_head_address") def mem_info_with_head_address(): return tvm.ir.make_node( - "MemoryInfo", + "target.MemoryInfo", unit_bits=8, max_simd_bits=32, max_num_bits=128, @@ -33,7 +33,7 @@ def mem_info_with_head_address(): @tvm.register_func("tvm.info.mem.global.test_without_head_address") def mem_info_without_head_address(): return tvm.ir.make_node( - "MemoryInfo", + "target.MemoryInfo", unit_bits=8, max_simd_bits=32, max_num_bits=128, diff --git a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py index 548b199a94ce..e8d21a8dc4f9 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_rewrite.py +++ b/tests/python/tir-transform/test_tir_transform_storage_rewrite.py @@ -29,7 +29,11 @@ def register_mem(scope_tb, max_bits): @tvm.register_func("tvm.info.mem.%s" % scope_tb) def mem_info_inp_buffer(): return tvm.ir.make_node( - "MemoryInfo", unit_bits=16, max_simd_bits=32, max_num_bits=max_bits, head_address=None + "target.MemoryInfo", + unit_bits=16, + max_simd_bits=32, + max_num_bits=max_bits, + head_address=None, ) diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index 31ba6fb164d4..1dece07ed9dd 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -81,7 +81,7 @@ def test_ir_builder_tir_primfunc_complete(): body=tir.Evaluate(0), ret_type=tvm.ir.PrimType("int64"), buffer_map={c_handle: c_buffer, d_handle: d_buffer, e_handle: e_buffer}, - attrs=tvm.ir.make_node("DictAttrs", key="value"), + attrs=tvm.ir.make_node("ir.DictAttrs", key="value"), ) # Check if the generated ir is expected diff --git a/web/src/runtime.ts b/web/src/runtime.ts index 75b54463a745..e0898c95bf41 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -1896,7 +1896,7 @@ export class Instance implements Disposable { /** Register all object factory */ private registerObjectFactoryFuncs(): void { - this.registerObjectConstructor("object.Array", + this.registerObjectConstructor("ffi.Array", (handle: number, lib: FFILibrary, ctx: RuntimeContext) => { return new TVMArray(handle, lib, ctx); });